# Import Libraries and Define Game Environment

In [1]:
# Cell 1: Import Libraries and Define Game Environment

import os
import numpy as np
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import deque
from multiprocessing import Pool, cpu_count
import random
import logging
import threading

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Root logger at INFO so progress messages from every cell are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Notebook cells get re-run. Adding a fresh StreamHandler on every run would
# print each message once per run, so the console handler is tagged with a
# name and reused when it is already attached to the root logger.
_CONSOLE_HANDLER_NAME = "connect4_console"
ch = next(
    (h for h in logger.handlers if h.get_name() == _CONSOLE_HANDLER_NAME),
    None,
)
if ch is None:
    ch = logging.StreamHandler()
    ch.set_name(_CONSOLE_HANDLER_NAME)
    logger.addHandler(ch)
ch.setLevel(logging.INFO)
# Timestamp - level - message
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)

# For Logging and Visualization
from torch.utils.tensorboard import SummaryWriter

# For Interactive Widgets
import ipywidgets as widgets
from IPython.display import display, clear_output

# Define the Connect Four game environment
class ConnectFour:
    """Connect Four game state on a ROWS x COLS grid.

    ``board`` holds 0 for an empty cell and 1 / -1 for the two players.
    Row 0 is the top of the board; pieces settle at the highest row index
    that is still empty. ``current_player`` is the player (1 or -1) who
    moves next.
    """

    ROWS = 6
    COLS = 7

    # (row step, column step) of the four line orientations that can win:
    # horizontal, vertical, and the two diagonals.
    _DIRECTIONS = ((0, 1), (1, 0), (1, 1), (-1, 1))

    def __init__(self):
        self.board = np.zeros((self.ROWS, self.COLS), dtype=int)
        self.current_player = 1  # Player 1 starts

    def make_move(self, col):
        """Drop the current player's piece into ``col``.

        Returns True and switches the player on success. Returns False,
        leaving the state untouched, when the column is full or outside
        ``0 .. COLS - 1``.
        """
        # Reject out-of-range columns explicitly: a negative index would
        # otherwise wrap around and silently play in a different column.
        if not 0 <= col < self.COLS:
            return False
        for row in reversed(range(self.ROWS)):
            if self.board[row, col] == 0:
                self.board[row, col] = self.current_player
                self.current_player *= -1  # Switch player
                return True
        return False  # Column is full

    def valid_moves(self):
        """Return the columns whose top cell is still empty."""
        return [col for col in range(self.COLS) if self.board[0, col] == 0]

    def is_full(self):
        """Return True when no empty cell is left on the board."""
        return bool(np.all(self.board != 0))

    def check_winner(self):
        """Return 1 or -1 if that player has four in a row, otherwise 0.

        The result is a plain Python ``int`` so it can be used directly as a
        dictionary key or in arithmetic without carrying a NumPy scalar.
        """
        for d_row, d_col in self._DIRECTIONS:
            for c in range(self.COLS):
                for r in range(self.ROWS):
                    # Skip starting cells whose 4-cell line leaves the board.
                    end_r = r + 3 * d_row
                    end_c = c + 3 * d_col
                    if not (0 <= end_r < self.ROWS and 0 <= end_c < self.COLS):
                        continue
                    piece = self.board[r, c]
                    if piece != 0 and all(
                        self.board[r + i * d_row, c + i * d_col] == piece
                        for i in range(4)
                    ):
                        return int(piece)
        return 0  # No winner

    def reset(self):
        """Clear the board and give the first move back to player 1."""
        self.board = np.zeros((self.ROWS, self.COLS), dtype=int)
        self.current_player = 1

# Define Neural Network with Residual Blocks

In [2]:
# Cell 2: Define Neural Network with Residual Blocks

# Define a Residual Block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection.

    The channel count and spatial size of the input are preserved
    (``kernel_size=3, padding=1``).
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``."""
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Genuinely out-of-place addition: `out += residual` would modify the
        # batch-norm output tensor in place, which is what the original
        # comment set out to avoid.
        out = out + residual
        out = F.relu(out)
        return out

# Define the ConnectFour Neural Network with Residual Blocks
class ConnectNet(nn.Module):
    """Residual CNN with a shared trunk and separate policy / value heads.

    Input boards are reshaped to ``(N, 1, 6, 7)``. The network returns
    ``(policy_logits, value)`` where ``policy_logits`` has 7 entries per board
    (one per column) and ``value`` is squashed into [-1, 1] by ``tanh``.
    """

    def __init__(self, num_residual_blocks=6):
        super().__init__()
        # Stem: lift the single-plane board to 64 feature maps.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        # Residual tower
        tower = [ResidualBlock(64) for _ in range(num_residual_blocks)]
        self.residual_blocks = nn.Sequential(*tower)
        # Final convolution before flattening
        self.conv_final = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn_final = nn.BatchNorm2d(64)
        # Shared fully connected layer followed by the two heads
        self.fc1 = nn.Linear(64 * 6 * 7, 512)
        self.dropout = nn.Dropout(p=0.6)
        self.fc_policy = nn.Linear(512, 7)
        self.fc_value = nn.Linear(512, 1)

    def forward(self, x):
        boards = x.view(-1, 1, 6, 7)
        features = F.relu(self.bn1(self.conv1(boards)))
        features = self.residual_blocks(features)
        features = F.relu(self.bn_final(self.conv_final(features)))
        # Flatten every board's feature maps into one vector.
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        policy_logits = self.fc_policy(hidden)
        value = torch.tanh(self.fc_value(hidden))
        return policy_logits, value

# Define MCTS and Related Functions

In [3]:
# Cell 3: Define MCTS and Related Functions

# Define the MCTS Node
class MCTSNode:
    """One node of the search tree.

    Holds a private deep copy of the game state plus the search statistics:
    visit count, accumulated value, and the prior assigned at expansion.
    ``children`` maps a move to the child node reached by that move.
    """

    def __init__(self, state, parent=None):
        # Search statistics start at zero.
        self.visits, self.value, self.prior = 0, 0.0, 0.0
        # Tree links
        self.parent = parent
        self.children = {}
        # Copy so later moves on the caller's state do not alter this node.
        self.state = copy.deepcopy(state)

# UCB Score Calculation
def ucb_score(parent, child, c_puct=2.0):
    """Score ``child`` for selection from ``parent``.

    The exploration term grows with the child's prior and the parent's visit
    count and shrinks with the child's own visits. The child's averaged value
    is subtracted from it.
    """
    denominator = 1 + child.visits
    exploration = c_puct * child.prior * math.sqrt(parent.visits) / denominator
    averaged_value = child.value / denominator
    return exploration - averaged_value

# Backpropagate the value up the path
def backpropagate(path, value):
    """Add ``value`` to every node on ``path``, walking from the last node back.

    Each node's visit count goes up by one. The sign of the value alternates
    at every step so consecutive nodes receive it from opposite sides.
    """
    for steps_from_end, node in enumerate(reversed(path)):
        node.visits += 1
        # Even distance from the end keeps the sign, odd distance flips it.
        node.value += value if steps_from_end % 2 == 0 else -value

# Monte Carlo Tree Search
def mcts_search(root, net, num_simulations=800):
    """Run Monte Carlo Tree Search from ``root`` and return the most-visited move.

    Each simulation walks down the tree by maximising ``ucb_score`` and then
    either expands the reached leaf or scores a finished game. A leaf is
    expanded using the network's policy as child priors and its value as the
    evaluation. A finished game is scored by its result. The evaluation is
    then propagated back along the visited path.

    Args:
        root: ``MCTSNode`` for the position to search. Its ``children`` and
            statistics are filled in place, so callers can read visit counts
            afterwards.
        net: network returning ``(policy_logits, value)`` for a board tensor.
            It is switched to eval mode for inference and restored to the
            mode it was in on entry.
        num_simulations: number of simulations to run.

    Returns:
        The move (column index) whose child of ``root`` has the most visits.

    Raises:
        ValueError: if ``root`` has no children after the search, e.g. the
            root position is already finished or ``num_simulations`` is 0.
    """
    dirichlet_alpha = 0.3
    epsilon = 0.25
    # Remember the caller's mode: unconditionally calling net.train() after
    # inference would silently undo a caller's net.eval().
    was_training = net.training

    for _ in range(num_simulations):
        node = root
        # The path starts at the root and holds every visited node exactly
        # once, so the root's visit count grows with the simulations and the
        # leaf is not backed up twice with opposite signs.
        path = [root]

        # Selection: descend to a leaf along the highest-scoring children.
        while node.children:
            parent = node
            node = max(parent.children.values(),
                       key=lambda child: ucb_score(parent, child))
            path.append(node)

        # Expansion / evaluation
        winner = node.state.check_winner()
        if winner == 0 and not node.state.is_full():
            # The raw board (absolute 1 / -1 encoding) is the network input.
            state_tensor = torch.tensor(node.state.board, dtype=torch.float32).unsqueeze(0).to(device)
            net.eval()
            try:
                with torch.no_grad():
                    policy_logits, value = net(state_tensor)
            finally:
                net.train(was_training)
            policy = F.softmax(policy_logits, dim=1).cpu().numpy()[0]

            # Keep only legal moves and renormalise their probabilities.
            valid_moves = node.state.valid_moves()
            priors = {idx: policy[idx] for idx in valid_moves}
            prior_sum = sum(priors.values())
            for idx in priors:
                priors[idx] /= prior_sum

            # Add Dirichlet noise at the root node for exploration.
            if node is root:
                dirichlet_noise = np.random.dirichlet([dirichlet_alpha] * len(valid_moves))
                for i, idx in enumerate(valid_moves):
                    priors[idx] = (1 - epsilon) * priors[idx] + epsilon * dirichlet_noise[i]

            # Expand children
            for idx in valid_moves:
                child_state = copy.deepcopy(node.state)
                child_state.make_move(idx)
                child_node = MCTSNode(child_state, node)
                child_node.prior = priors[idx]
                node.children[idx] = child_node

            # Use the value estimate from the neural network.
            backpropagate(path, value.item())
        else:
            # Terminal node. backpropagate() gives the last node the value as
            # is and ucb_score() subtracts a child's value, so the value must
            # be from the point of view of the player to move at this node.
            # That player is the one who did NOT make the winning move, so a
            # win yields -1 here. A draw yields 0.
            value = float(winner * node.state.current_player)
            backpropagate(path, value)

    if not root.children:
        raise ValueError("mcts_search: root has no children to choose a move from")
    # Choose the move with the most visits
    best_move = max(root.children.items(), key=lambda item: item[1].visits)[0]
    return best_move

# Define Replay Buffer, Data Augmentation, and Self-Play Functions

In [4]:
# Cell 4: Define Replay Buffer, Data Augmentation, and Self-Play Functions

# Experience Replay Buffer
class ReplayBuffer:
    """Bounded FIFO store of training samples.

    Once ``max_size`` items are held, adding more drops the oldest ones.
    """

    def __init__(self, max_size=1000000):
        self.buffer = deque(maxlen=max_size)

    def __len__(self):
        return len(self.buffer)

    def add(self, memory):
        """Append every item of the iterable ``memory`` to the buffer."""
        self.buffer.extend(memory)

    def sample(self, batch_size):
        """Return ``batch_size`` random items, or everything if fewer are stored."""
        if len(self.buffer) >= batch_size:
            return random.sample(self.buffer, batch_size)
        return list(self.buffer)

# Data Augmentation Functions

def augment_board(board):
    """
    Return the board together with its left-right mirror image.
    The result is a two-element list: the board itself, then the mirrored board.
    """
    mirrored = np.fliplr(board)
    return [board, mirrored]

def augment_memory(memory):
    """
    Apply left-right mirror augmentation to every sample in memory.

    Each ``(state, probs, reward)`` sample yields two entries: the original
    sample and a mirrored one. In the mirrored sample the board is flipped
    left-right and the per-column move probabilities are reversed, so the
    policy target still refers to the same squares as on the original board.
    The reward is unchanged by mirroring.

    Returns a new list. The input list is not modified.
    """
    augmented_memory = []
    for state, probs, reward in memory:
        augmented_memory.append((state, probs, reward))
        # The board and its policy must be mirrored together: pairing a
        # flipped board with the un-flipped policy would train the network
        # towards moves on the wrong side of the board.
        mirrored_state = np.fliplr(state).copy()
        mirrored_probs = np.asarray(probs)[::-1].copy()
        augmented_memory.append((mirrored_state, mirrored_probs, reward))
    return augmented_memory

# Self-Play Function
def self_play(net, num_games=50, num_simulations=800):
    """Generate training samples by letting the network play against itself.

    Before every move a fresh search tree is built with ``mcts_search`` and
    the normalised root visit counts are stored as the policy target for the
    position.

    Args:
        net: network passed to ``mcts_search``.
        num_games: number of games to play.
        num_simulations: MCTS simulations per move.

    Returns:
        ``(memory, win_cnt)``. ``memory`` is a list of
        ``(board, mcts_probs, reward)`` tuples with one entry per position
        played. ``reward`` is +1 if the player to move in that position went
        on to win, -1 if that player lost, and 0 for a draw. ``win_cnt``
        counts results under the keys ``1``, ``-1`` and ``'draw'``.
    """
    memory = []
    win_cnt = {1: 0, -1: 0, 'draw': 0}
    for _ in range(num_games):
        game = ConnectFour()
        # Randomize starting player
        game.current_player = np.random.choice([1, -1])
        states = []
        mcts_probs = []
        current_players = []
        # Defined up front so the invalid-move break below cannot leave it
        # unbound. An aborted game is scored as a draw.
        winner = 0
        while True:
            root = MCTSNode(game)
            move = mcts_search(root, net, num_simulations=num_simulations)
            states.append(game.board.copy())

            # Policy target: visit counts of the root's children, normalised.
            mcts_prob = np.zeros(game.COLS)
            for idx, child in root.children.items():
                mcts_prob[idx] = child.visits
            total_visits = np.sum(mcts_prob)
            if total_visits > 0:
                mcts_prob = mcts_prob / total_visits
            else:
                # If no children were expanded, use uniform probabilities
                mcts_prob = np.ones(game.COLS) / game.COLS
            mcts_probs.append(mcts_prob)
            current_players.append(game.current_player)

            if not game.make_move(move):
                logger.warning(f"Invalid move attempted at column {move} by player {game.current_player}.")
                break  # Avoid infinite loops
            winner = int(game.check_winner())
            if winner != 0 or game.is_full():
                break

        if winner == 0:
            win_cnt['draw'] += 1
        else:
            win_cnt[winner] += 1

        # Reward from the mover's point of view. The previous expression
        # (`winner if winner == player else -winner`) always evaluated to the
        # mover's own id, regardless of who actually won.
        for state, probs, player in zip(states, mcts_probs, current_players):
            reward = int(winner * player)
            memory.append((state, probs, reward))
    return memory, win_cnt

# Training Function
def train(net, replay_buffer, optimizer, batch_size=64, epochs=1, clip_grad=1.0):
    """Train ``net`` on the contents of the replay buffer.

    Every epoch visits each stored sample once, in a fresh random order, in
    mini-batches of at most ``batch_size``. The loss is the value MSE plus
    the cross-entropy between the stored MCTS probabilities and the network
    policy. Gradients are clipped to ``clip_grad`` before each step.

    Args:
        net: network returning ``(policy_logits, value)``.
        replay_buffer: object whose ``buffer`` attribute holds
            ``(state, probs, reward)`` samples.
        optimizer: optimizer over ``net``'s parameters.
        batch_size: maximum mini-batch size.
        epochs: number of passes over the buffer.
        clip_grad: max gradient norm.

    Returns:
        ``(avg_value_loss, avg_policy_loss, avg_loss)`` averaged over all
        mini-batches, or ``(0, 0, 0)`` when there was nothing to train on.
    """
    net.train()
    dataset_size = len(replay_buffer.buffer)
    if dataset_size < batch_size:
        logger.info("Not enough data to train. Skipping this epoch.")
        return 0, 0, 0  # Not enough data to train

    # Snapshot into a list once: random positional access into a deque walks
    # the container, which makes every lookup linear in the buffer size.
    samples = list(replay_buffer.buffer)

    total_value_loss = 0.0
    total_policy_loss = 0.0
    total_loss = 0.0
    num_batches = 0

    for _ in range(epochs):
        # Shuffle the data
        perm = np.random.permutation(dataset_size)
        for start in range(0, dataset_size, batch_size):
            batch_indices = perm[start:start + batch_size]

            # Validate and filter states
            states, probs, rewards = [], [], []
            for idx in batch_indices:
                state, prob, reward = samples[idx]
                state_array = np.array(state)
                if state_array.shape == (6, 7):
                    states.append(state_array)
                    probs.append(prob)
                    rewards.append(reward)
                else:
                    logger.warning(f"Invalid state shape detected: {state_array.shape}. Skipping this state.")

            if not states:
                logger.warning("No valid game states in this batch. Skipping training for this epoch.")
                continue

            # Convert to tensors
            state_tensor = torch.tensor(np.array(states), dtype=torch.float32).to(device)
            mcts_prob_tensor = torch.tensor(np.array(probs), dtype=torch.float32).to(device)
            reward_tensor = torch.tensor(np.array(rewards), dtype=torch.float32).to(device)

            # Forward pass
            policy_logits, value = net(state_tensor)

            # Calculate losses. squeeze(-1) drops only the trailing dimension
            # so a batch of one stays shape (1,) and matches reward_tensor. A
            # bare squeeze() would collapse it to a 0-d tensor.
            value_loss = F.mse_loss(value.squeeze(-1), reward_tensor)
            policy_loss = -torch.mean(torch.sum(mcts_prob_tensor * F.log_softmax(policy_logits, dim=1), dim=1))
            loss = value_loss + policy_loss

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=clip_grad)
            optimizer.step()

            total_value_loss += value_loss.item()
            total_policy_loss += policy_loss.item()
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        return 0, 0, 0
    return (
        total_value_loss / num_batches,
        total_policy_loss / num_batches,
        total_loss / num_batches,
    )

# Define Evaluation and Checkpointing Functions

In [5]:
# Cell 5: Define Evaluation and Checkpointing Functions

# Evaluation Function
def evaluate(net, num_games=20, num_simulations=800):
    """Play the network (as player 1) against a simple heuristic opponent.

    The network chooses moves with ``mcts_search``. The opponent (player -1)
    picks at random among the valid columns holding the fewest pieces. The
    starting player is chosen at random for every game.

    Args:
        net: network to evaluate. It is put into eval mode first.
        num_games: number of games to play.
        num_simulations: MCTS simulations per network move.

    Returns:
        Dict counting results under the keys ``1`` (network wins), ``-1``
        (opponent wins) and ``'draw'``.
    """
    net.eval()
    win_cnt = {1: 0, -1: 0, 'draw': 0}
    for _ in range(num_games):
        game = ConnectFour()
        game.current_player = np.random.choice([1, -1])  # Randomize starting player
        # Defined before the loop so the invalid-move break cannot leave it
        # unbound. A game aborted that way is counted as a draw.
        winner = 0
        while True:
            if game.current_player == 1:
                # AI move
                root = MCTSNode(game)
                ai_move = mcts_search(root, net, num_simulations=num_simulations)
                valid_move = game.make_move(ai_move)
                if not valid_move:
                    logger.warning(f"AI attempted invalid move at column {ai_move}.")
                    break
            else:
                # Opponent move (heuristic-based)
                valid_moves = game.valid_moves()
                # Simple heuristic: choose the column with the least filled spots
                heights = [np.sum(game.board[:, col] != 0) for col in valid_moves]
                min_height = min(heights)
                candidate_cols = [col for col, height in zip(valid_moves, heights) if height == min_height]
                move = np.random.choice(candidate_cols)
                game.make_move(move)
            winner = game.check_winner()
            if winner != 0 or game.is_full():
                break
        if winner == 0:
            win_cnt['draw'] += 1
        else:
            win_cnt[winner] += 1
    return win_cnt

# Save Checkpoint
def save_checkpoint(state, filename='connect4_model.pth'):
    """Serialize ``state`` to ``filename`` with ``torch.save`` and log the path."""
    torch.save(state, filename)
    logger.info(f"Checkpoint saved to {filename}")

# Load Checkpoint
def load_checkpoint(filename='connect4_model.pth'):
    """Load a checkpoint from ``filename`` onto the active device.

    Returns the loaded object, or ``None`` (after logging a warning) when the
    file does not exist.
    """
    if os.path.exists(filename):
        checkpoint = torch.load(filename, map_location=device)
        logger.info(f"Checkpoint loaded from {filename}")
        return checkpoint
    logger.warning(f"No checkpoint found at {filename}.")
    return None

# Logging and Visualization Setup

In [None]:
# Cell 6: Logging and Visualization Setup

# TensorBoard writer shared by every logging call in this notebook
writer = SummaryWriter(log_dir='connect4_logs')

def log_training_metrics(iteration, avg_value_loss, avg_policy_loss, avg_loss, self_play_stats, eval_stats=None):
    """Write one iteration's losses and game statistics to TensorBoard.

    Losses go under ``Loss/*``, each self-play counter under ``SelfPlay/<key>``
    and, when ``eval_stats`` is non-empty, each evaluation counter under
    ``Evaluation/<key>``. The writer is flushed before returning.
    """
    loss_scalars = (
        ('Loss/Value', avg_value_loss),
        ('Loss/Policy', avg_policy_loss),
        ('Loss/Total', avg_loss),
    )
    for tag, scalar in loss_scalars:
        writer.add_scalar(tag, scalar, iteration)

    # Self-play counters are always logged; evaluation ones only if given.
    stat_groups = [('SelfPlay', self_play_stats)]
    if eval_stats:
        stat_groups.append(('Evaluation', eval_stats))
    for prefix, stats in stat_groups:
        for key, value in stats.items():
            writer.add_scalar(f'{prefix}/{key}', value, iteration)

    writer.flush()  # Push pending events to disk

# Launch TensorBoard within Jupyter
%load_ext tensorboard
%tensorboard --logdir='C:\Users\Hephzibah\OneDrive\Documents\VScodeprojects\Connect4_project\connect4_logs'

Reusing TensorBoard on port 6008 (pid 15508), started 1 day, 20:55:45 ago. (Use '!kill 15508' to kill it.)

# Hyperparameter Configuration

In [7]:
# Cell 7: Hyperparameter Configuration

# Define Hyperparameters
class Hyperparameters:
    # Central table of tunable settings; read through the `hp` instance below.

    # Training Parameters
    num_residual_blocks = 6   # residual blocks in ConnectNet
    learning_rate = 0.001     # Adam learning rate
    weight_decay = 1e-4       # Adam weight decay
    batch_size = 64           # mini-batch size used by train()
    epochs = 1                # passes over the replay buffer per iteration
    clip_grad = 1.0           # max gradient norm
    total_iterations = 1000   # self-play / train iterations in the main loop
    
    # Self-Play Parameters
    num_self_play_games = 50     # games generated per iteration
    num_mcts_simulations = 800   # MCTS simulations per move
    
    # Replay Buffer
    replay_buffer_max_size = 1000000  # oldest samples are dropped beyond this
    
    # Evaluation Parameters
    evaluation_interval = 10    # evaluate every N iterations
    num_evaluation_games = 20   # games per evaluation
    
    # MCTS Parameters
    # NOTE(review): ucb_score() and mcts_search() use their own literals
    # (2.0, 0.3, 0.25) and do not read these three values.
    c_puct = 2.0
    dirichlet_alpha = 0.3
    epsilon = 0.25

# Instantiate Hyperparameters
hp = Hyperparameters()

# Training Loop

In [8]:
# Cell 8: Training Loop with Enhancements

# Initialize or load the model
net = ConnectNet(num_residual_blocks=hp.num_residual_blocks).to(device)  # Network depth comes from hp
optimizer = optim.Adam(net.parameters(), lr=hp.learning_rate, weight_decay=hp.weight_decay)  # L2 regularization
# Halve the learning rate when the monitored loss stops decreasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

start_iteration = 0  # First iteration index; overwritten when a checkpoint is loaded

# Check if a saved model exists
model_filename = 'connect4_model.pth'            # latest checkpoint, rewritten every iteration
best_model_filename = 'connect4_best_model.pth'  # checkpoint with the best evaluation score
load_existing = False
best_eval_score = -math.inf  # any real evaluation score beats this

# Ask before resuming so an existing checkpoint is never loaded silently.
if os.path.exists(model_filename):
    user_input = input("A saved model was found. Do you want to load it and continue training? (y/n): ")
    if user_input.lower() == 'y':
        load_existing = True
else:
    logger.info("No saved model found. Training a new model.")

if load_existing:
    checkpoint = load_checkpoint(model_filename)
    if checkpoint:
        # Restore network, optimizer and scheduler state, then resume from
        # the stored iteration counter.
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_iteration = checkpoint['iteration']
        logger.info(f"Loaded model from iteration {start_iteration}")
        # Optionally load best_eval_score from checkpoint if stored
        if 'best_eval_score' in checkpoint:
            best_eval_score = checkpoint['best_eval_score']
    else:
        logger.warning("Failed to load checkpoint. Starting training from scratch.")
else:
    logger.info("Starting training from scratch.")

total_iterations = hp.total_iterations  # Total number of training iterations you want

# Initialize Replay Buffer
# The buffer is created empty here; nothing in this cell restores samples
# from an earlier run.
replay_buffer = ReplayBuffer(max_size=hp.replay_buffer_max_size)

# Main loop: self-play -> augment -> train -> (periodic) evaluate -> log -> checkpoint
for iteration in range(start_iteration, total_iterations):
    step = iteration + 1  # 1-based iteration number used for logs and checkpoints
    logger.info(f"\n=== Iteration {step}/{total_iterations} ===")

    # Self-Play
    memory, self_play_stats = self_play(net, num_games=hp.num_self_play_games, num_simulations=hp.num_mcts_simulations)

    # Apply Data Augmentation
    augmented_memory = augment_memory(memory)
    replay_buffer.add(augmented_memory)
    logger.info(f"Self-Play completed. Stats: {self_play_stats}")
    logger.info(f"Added {len(augmented_memory)} augmented game states to replay buffer.")

    # Train the network
    avg_value_loss, avg_policy_loss, avg_loss = train(net, replay_buffer, optimizer,
                                                      batch_size=hp.batch_size,
                                                      epochs=hp.epochs,
                                                      clip_grad=hp.clip_grad)
    logger.info(f"Training completed. Avg Value Loss: {avg_value_loss:.4f}, Avg Policy Loss: {avg_policy_loss:.4f}, Avg Total Loss: {avg_loss:.4f}")

    # Update the learning rate scheduler
    scheduler.step(avg_loss)

    # Periodic Evaluation
    eval_stats = None
    is_new_best = False
    if step % hp.evaluation_interval == 0:
        eval_stats = evaluate(net, num_games=hp.num_evaluation_games, num_simulations=hp.num_mcts_simulations)
        logger.info(f"Evaluation after {step} iterations: {eval_stats}")

        # Calculate a simple evaluation score (e.g., wins minus losses)
        eval_score = eval_stats.get(1, 0) - eval_stats.get(-1, 0)
        if eval_score > best_eval_score:
            best_eval_score = eval_score
            is_new_best = True

    # Log once per iteration. Logging before the evaluation and again after
    # it wrote the loss and self-play scalars twice for the same step.
    log_training_metrics(step, avg_value_loss, avg_policy_loss, avg_loss, self_play_stats, eval_stats)

    # Same payload for the "best" and the "latest" checkpoint files.
    checkpoint_state = {
        'iteration': step,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_eval_score': best_eval_score
    }

    # Save the best model
    if is_new_best:
        save_checkpoint(checkpoint_state, filename=best_model_filename)
        logger.info(f"New best model saved at iteration {step} with evaluation score {eval_score}")

    # Save the latest model checkpoint
    save_checkpoint(checkpoint_state, filename=model_filename)
    logger.info(f"Model checkpoint saved at iteration {step}")

  checkpoint = torch.load(filename, map_location=device)
2024-11-22 13:04:12,736 - INFO - Checkpoint loaded from connect4_model.pth
2024-11-22 13:04:12,745 - INFO - Loaded model from iteration 33
2024-11-22 13:04:12,745 - INFO - 
=== Iteration 34/1000 ===


KeyboardInterrupt: 

# Playing Against the AI

In [9]:
# Cell 9: Playing Against the AI

# Function to Play Against the AI
def play_against_ai(net, num_simulations=800):
    """Interactive console game: the human is player 1, the network player -1.

    The human is prompted for a column until a valid one is entered. The
    network answers with the move returned by ``mcts_search``. The loop ends
    when either side wins or the board is full, and the result is printed.
    """
    net.eval()
    game = ConnectFour()
    winner = 0
    game_over = False

    while not game_over:
        print("\nCurrent Board:")
        print(game.board)

        # Human move
        try:
            move = int(input("Enter your move (0-6): "))
        except ValueError:
            print("Invalid input. Please enter an integer between 0 and 6.")
            continue
        if move not in game.valid_moves():
            print("Invalid move. Try again.")
            continue
        game.make_move(move)
        winner = game.check_winner()
        game_over = winner != 0 or game.is_full()

        if not game_over:
            # AI move
            root = MCTSNode(game)
            ai_move = mcts_search(root, net, num_simulations=num_simulations)
            print(f"AI selects column {ai_move}")
            game.make_move(ai_move)
            winner = game.check_winner()
            game_over = winner != 0 or game.is_full()

    print("\nFinal Board:")
    print(game.board)
    outcome_messages = {1: "You win!", -1: "AI wins!"}
    print(outcome_messages.get(int(winner), "It's a draw."))

# Function to Load the Best Trained Model
def load_best_model(model_filename='connect4_best_model.pth'):
    """Build a ``ConnectNet`` from the best-model checkpoint.

    Returns the network in eval mode, or ``None`` (after logging a warning)
    when no checkpoint could be loaded.
    """
    checkpoint = load_checkpoint(model_filename)
    if not checkpoint:
        logger.warning("Best model not found.")
        return None

    net_loaded = ConnectNet(num_residual_blocks=hp.num_residual_blocks).to(device)
    net_loaded.load_state_dict(checkpoint['model_state_dict'])
    net_loaded.eval()
    logger.info(f"Loaded best model from iteration {checkpoint['iteration']} with evaluation score {checkpoint.get('best_eval_score', 'N/A')}")
    return net_loaded
    
# Try to load the best checkpoint saved during training
best_net = load_best_model()

# Without a loaded model there is nothing to play against
if best_net is None:
    logger.warning("Cannot play against AI without a trained best model.")
else:
    play_against_ai(best_net, num_simulations=hp.num_mcts_simulations)

  checkpoint = torch.load(filename, map_location=device)
2024-11-22 13:04:24,472 - INFO - Checkpoint loaded from connect4_best_model.pth
2024-11-22 13:04:24,506 - INFO - Loaded best model from iteration 10 with evaluation score 6



Current Board:
[[0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0]]
AI selects column 3

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0 -1  0  0  0]
 [ 0  0  0  1  0  0  0]]
AI selects column 4

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0 -1  0  0  0]
 [ 0  0  1  1 -1  0  0]]
AI selects column 1

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0 -1  0  0  0]
 [ 1 -1  1  1 -1  0  0]]
AI selects column 3

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0]
 [ 0  0  0 -1  0  0  0]
 [ 0  1  0 -1  0  0  0]
 [ 1 -1  1  1 -1  0  0]]
AI selects column 3

Current Board:
[[ 0  0  0  0  0  0  0]
 [ 0  0  0 -1  0  0  0]
 [ 0  0  0  1  0  0  0]
 [ 0  0  0 -1  0  0  0]
 [ 0  1  0 -1  0  0  0]
 

# Interactive Widgets for Training and Playing

In [None]:
# Cell 10: Interactive Widgets for Training and Playing

# Define Buttons
train_button = widgets.Button(description="Start Training")
play_button = widgets.Button(description="Play Against AI")
stop_training_button = widgets.Button(description="Stop Training", button_style='danger')

# Define Output Areas
# Used as context managers by the button callbacks defined below.
train_output = widgets.Output()
play_output = widgets.Output()

# Render the three buttons followed by the two output areas.
display(train_button, play_button, stop_training_button, train_output, play_output)

# Global flag to control training
# Set to True by stop_training_func and checked at the top of every
# iteration in start_training_thread.
stop_training_flag = False

# Define Training Function Wrapper with Threading
def start_training_thread():
    """Run the training loop until finished or stopped by ``stop_training_flag``.

    Each iteration does self-play, augmentation, training, periodic
    evaluation, logging and checkpointing. The stop flag is checked once at
    the start of every iteration, so a stop request takes effect only after
    the current iteration finishes.
    """
    # best_eval_score is assigned below. Without the global declaration the
    # assignment would make it a local name and the first comparison against
    # it would raise UnboundLocalError.
    global stop_training_flag, best_eval_score
    stop_training_flag = False
    with train_output:
        clear_output()
        logger.info("Starting Training...")
    for iteration in range(start_iteration, hp.total_iterations):
        if stop_training_flag:
            logger.info("Training stopped by user.")
            break
        step = iteration + 1  # 1-based iteration number used for logs and checkpoints
        logger.info(f"\n=== Iteration {step}/{hp.total_iterations} ===")

        # Self-Play
        memory, self_play_stats = self_play(net, num_games=hp.num_self_play_games, num_simulations=hp.num_mcts_simulations)

        # Apply Data Augmentation
        augmented_memory = augment_memory(memory)
        replay_buffer.add(augmented_memory)
        logger.info(f"Self-Play completed. Stats: {self_play_stats}")
        logger.info(f"Added {len(augmented_memory)} augmented game states to replay buffer.")

        # Train the network
        avg_value_loss, avg_policy_loss, avg_loss = train(net, replay_buffer, optimizer,
                                                          batch_size=hp.batch_size,
                                                          epochs=hp.epochs,
                                                          clip_grad=hp.clip_grad)
        logger.info(f"Training completed. Avg Value Loss: {avg_value_loss:.4f}, Avg Policy Loss: {avg_policy_loss:.4f}, Avg Total Loss: {avg_loss:.4f}")

        # Update the learning rate scheduler
        scheduler.step(avg_loss)

        # Periodic Evaluation
        eval_stats = None
        is_new_best = False
        if step % hp.evaluation_interval == 0:
            eval_stats = evaluate(net, num_games=hp.num_evaluation_games, num_simulations=hp.num_mcts_simulations)
            logger.info(f"Evaluation after {step} iterations: {eval_stats}")

            # Calculate a simple evaluation score (e.g., wins minus losses)
            eval_score = eval_stats.get(1, 0) - eval_stats.get(-1, 0)
            if eval_score > best_eval_score:
                best_eval_score = eval_score
                is_new_best = True

        # Log once per iteration. Logging before the evaluation and again
        # after it wrote the loss and self-play scalars twice per step.
        log_training_metrics(step, avg_value_loss, avg_policy_loss, avg_loss, self_play_stats, eval_stats)

        # Same payload for the "best" and the "latest" checkpoint files.
        checkpoint_state = {
            'iteration': step,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_eval_score': best_eval_score
        }

        # Save the best model
        if is_new_best:
            save_checkpoint(checkpoint_state, filename=best_model_filename)
            logger.info(f"New best model saved at iteration {step} with evaluation score {eval_score}")

        # Save the latest model checkpoint
        save_checkpoint(checkpoint_state, filename=model_filename)
        logger.info(f"Model checkpoint saved at iteration {step}")

    logger.info("Training Completed.")

# Define Play Function Wrapper
def start_playing(b):
    """Button callback: load the best model and start a game inside ``play_output``.

    ``b`` is the clicked button and is not used.
    """
    with play_output:
        clear_output()
        logger.info("Loading Best Model...")
        loaded_net = load_best_model()
        if loaded_net is None:
            logger.warning("Cannot play against AI without a trained best model.")
            return
        play_against_ai(loaded_net, num_simulations=hp.num_mcts_simulations)

# Define Stop Training Function
def stop_training_func(b):
    """Button callback: request that the training loop stop.

    Only sets the global ``stop_training_flag``; the training loop reads the
    flag at the start of each iteration. ``b`` is the clicked button and is
    not used.
    """
    global stop_training_flag
    stop_training_flag = True
    with train_output:
        logger.info("Stopping Training...")

# Assign Functions to Buttons
# Handle of the running training thread, if any
_training_thread = None

def on_train_button_clicked(b):
    """Button callback: start training in a background thread.

    A click is ignored while a previously started training thread is still
    alive, because two loops would otherwise update the same network,
    optimizer and replay buffer at the same time. ``b`` is the clicked button
    and is not used.
    """
    global _training_thread
    if _training_thread is not None and _training_thread.is_alive():
        logger.info("Training is already running; ignoring click.")
        return
    _training_thread = threading.Thread(target=start_training_thread)
    _training_thread.start()

# Wire each button to its callback
train_button.on_click(on_train_button_clicked)
play_button.on_click(start_playing)
stop_training_button.on_click(stop_training_func)

Button(description='Start Training', style=ButtonStyle())

Button(description='Play Against AI', style=ButtonStyle())

Button(button_style='danger', description='Stop Training', style=ButtonStyle())

Output()

Output()

2024-11-20 22:12:45,188 - INFO - Starting Training...
2024-11-20 22:12:45,188 - INFO - 
=== Iteration 8/1000 ===
