# Reinforcement Learning with Convolutional Neural Network

This notebook demonstrates a basic Deep Q-Network (DQN) style reinforcement learning setup with a convolutional neural network (CNN) as the Q-function approximator.

---

## General Considerations

> ⚠️ Run this notebook **locally**, not in Colab. Audio processing with `ffmpeg` does **not** work in Colab. Additionally, the models and metrics routing is based on the repo relative locations, so make sure to **clone** the Repo

1. Clone this repo.
2. Install dependencies, including `ffmpeg`:
   ```bash
   brew install ffmpeg  # for macOS
   ```
3. Open the notebook locally. Set `TRAIN_MODEL = True` to retrain, or `False` to use the provided model (in `/models/`) and metrics (in `/metrics/`).

In [None]:
TRAIN_MODEL = False

Just check weather `ffmpeg` is installed or not

In [7]:
!which ffmpeg || echo "FFmpeg not found!" # else install it using e.g. brew install ffmpeg

/opt/homebrew/bin/ffmpeg


## 1. Setup and Imports

Import necessary libraries and set up the environment.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
from copy import deepcopy
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, HTML
import time
import io
import re
import tempfile
from contextlib import suppress
import os

Device configuration

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    print("CUDA device name:", torch.cuda.get_device_name(0))

False
Using device: cpu


## 2. Define the Convolutional Q-Network

A simple CNN that takes in a board and outputs Q-values for each action.

In [6]:
class ConvQNetwork(nn.Module):
    def __init__(self, input_shape=(1,6,7), num_actions=7):
        super().__init__()
        c, h, w = input_shape
        self.conv = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=(3,3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3,3), stride=1, padding=1),
            nn.ReLU(),
        )
        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions)
        )

    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x).view(x.size(0), -1)
        return self.fc(conv_out)

# Define Environment

This class defines a Connect 4 environment tailored for reinforcement learning using OpenAI Gym. 

In [None]:
import requests
import gym
from gym import spaces

class BoardEnv(gym.Env):

    metadata = {'render.modes': ['human']}

    def __init__(self):
        super().__init__()
        # Observation: 6×7 matrix with values in {-1, 0, +1}
        self.observation_space = spaces.Box(low=-1, high=1, shape=(6,7), dtype=np.int8)
        # Actions: drop in one of 7 columns
        self.action_space = spaces.Discrete(7)

        self.state = {}
        self.current_player = +1

    def reset(self):
        board = np.zeros((6,7), dtype=np.int8)
        self.state["board"] = board
        self.state["move-sequence"] = ""
        self.current_player = +1
        # Return a fresh copy for safety:
        return {"board": board.copy(),
                "move-sequence": self.state["move-sequence"]}

    def _get_action_reward(self, sequence, action):
      scores = [0] * 7
      # GET request
      response = requests.get('https://ludolab.net/solve/connect4?position=' + sequence)
      for score in response.json():
          scores[int(score['move'])-1] = score['score']
      return (scores[action]+20)/40

    def step(self, action):
        """
        action: integer 0–6, the column to drop your piece into.
        Returns: (next_state, reward, done, info)
        """
        # 1) Check legality
        if not self._is_valid_action(action):
            # Illegal move: immediate loss
            return {"board": self.state["board"].copy(), "move-sequence": self.state["move-sequence"]}, -1.0, True, {"illegal_move": True}

        # 2) Apply move
        row = self._get_drop_row(action)
        self.state["board"][row, action] = self.current_player
        reward = self._get_action_reward(self.state["move-sequence"], action)
        self.state["move-sequence"] += str(action+1)

        # 3) Check for win
        if self._check_win(self.state["board"], action, self.current_player):
            return {"board": self.state["board"].copy(), "move-sequence": self.state["move-sequence"]}, reward, True, {}

        # 4) Check for draw
        if np.all(self.state["board"] != 0):
            return {"board": self.state["board"].copy(), "move-sequence": self.state["move-sequence"]}, reward, True, {"draw": True}

        # 5) Otherwise, game continues
        self.current_player *= -1
        return {"board": self.state["board"].copy(), "move-sequence": self.state["move-sequence"]}, reward, False, {}

    def render(self, mode='human'):
        # Simple text render
        print(self.state["board"])

    def _is_valid_action(self, action):
        return 0 <= action < 7 and self.state["board"][0, action] == 0

    def _get_drop_row(self, action):
        # Find the lowest empty row in the chosen column
        col = self.state["board"][:, action]
        empties = np.where(col == 0)[0]
        return empties[-1]

    def _check_win(self, board: np.ndarray, action: int, player: int) -> bool:
      """
      Check for a four-in-a-row involving the most recent move in column `action`
      by `player` (±1). Returns True if that move created a win.
      """
      rows, cols = board.shape

      # 1) Find the row index where the last piece landed
      col_vals = board[:, action]
      # indices where the board equals the player in that column
      player_positions = np.where(col_vals == player)[0]
      row = player_positions[0]

      # 2) Define a helper to count in one direction
      def count_dir(dr: int, dc: int) -> int:
          r, c = row + dr, action + dc
          count = 0
          while 0 <= r < rows and 0 <= c < cols and board[r, c] == player:
              count += 1
              r += dr
              c += dc
          return count

      # 3) Check horizontal (← & →)
      horiz = 1 + count_dir(0, -1) + count_dir(0, +1)
      if horiz >= 4:
          return True

      # 4) Check vertical
      vert = 1 + count_dir(1, 0) + count_dir(-1, 0)
      if vert >= 4:
          return True

      # 5) Check diagonal up-right / down-left
      diag1 = 1 + count_dir(-1, +1) + count_dir(+1, -1)
      if diag1 >= 4:
          return True

      # 6) Check diagonal up-left / down-right
      diag2 = 1 + count_dir(-1, -1) + count_dir(+1, +1)
      if diag2 >= 4:
          return True

      return False




## 3. Replay Buffer

A simple replay buffer to store and sample transitions.


In [8]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.stack(next_states),
            np.array(dones, dtype=np.uint8)
        )

    def __len__(self):
        return len(self.buffer)

## Illegal Moves Mask Helper

In [10]:
def get_illegal_moves_mask(state):
    """
    Given a board state (6×7 numpy array with 0=empty, ±1=player tokens),
    return a boolean list of length 7 where True indicates the column is full/illegal.
    """
    mask = [False] * 7
    # Top row index 0 corresponds to the highest (first-placed) slot in each column
    for col in range(7):
        if state[0, col] != 0:
            mask[col] = True
    return mask

# Epsilon Decay Helper

In [11]:
def get_epsilon(start, end, period, step):
    # linearly anneal from start → end over decay_steps
    fraction = min(step / period, 1.0)
    return start + fraction * (end - start)

## 4. Training Loop

A basic training loop.


In [None]:
import os
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

os.makedirs("checkpoints", exist_ok=True)

if TRAIN_MODEL == True:
  # create a log directory
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  log_dir = f"runs/c4_dqn_{timestamp}"
  checkpoint_dir = f"checkpoints/c4_dqn_{timestamp}"
  os.makedirs(checkpoint_dir, exist_ok=True)
  os.makedirs(log_dir, exist_ok=True)

  # instantiate the writer
  writer = SummaryWriter(log_dir)
  print(f"Logging to: {log_dir}")

  # Hyperparameters
  learning_rate = 1e-4
  gamma = 0.99
  buffer_capacity = 3000
  batch_size = 32
  sync_target_steps = 1000
  num_episodes = 500

  env = BoardEnv()

  # Initialize networks and optimizer
  policy_net1 = ConvQNetwork(input_shape=(1, 6, 7), num_actions=7).to(device)
  target_net1 = ConvQNetwork(input_shape=(1, 6, 7), num_actions=7).to(device)
  target_net1.load_state_dict(policy_net1.state_dict())
  optimizer1 = optim.Adam(policy_net1.parameters(), lr=learning_rate)
  replay_buffer1 = ReplayBuffer(buffer_capacity)

  policy_net2 = ConvQNetwork(input_shape=(1, 6, 7), num_actions=7).to(device)
  target_net2 = ConvQNetwork(input_shape=(1, 6, 7), num_actions=7).to(device)
  target_net2.load_state_dict(policy_net2.state_dict())
  optimizer2 = optim.Adam(policy_net2.parameters(), lr=learning_rate)
  replay_buffer2 = ReplayBuffer(buffer_capacity)

  steps_done = 0


  for episode in range(num_episodes):
      state = env.reset()  # obtain initial state
      done = False
      total_reward1 = 0
      total_reward2 = 0
      loss1Val = 0
      loss2Val = 0

      epsilon = get_epsilon(1.0, 0.05, 3000, steps_done)

      while not done:
          # Select action (epsilon-greedy with action masking)
          # Convert current board to tensor
          state_tensor = torch.tensor(state["board"], dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(1)  # shape [1,1,6,7]
          with torch.no_grad():
              qvals = policy_net1(state_tensor)  # shape [1,7]
              # Mask out illegal moves (e.g., full columns)
              illegal_mask = get_illegal_moves_mask(state["board"])  # bool array of length 7
              qvals[0][illegal_mask] = -1e9
              # Epsilon-greedy selection among legal actions
              if random.random() < epsilon:
                  valid_actions = [a for a, illegal in enumerate(illegal_mask) if not illegal]
                  action = random.choice(valid_actions)
              else:
                  action = qvals.argmax(dim=1).item()
          # Step environment
          try:
            next_state, reward, done, _ = env.step(action)
          except:
            action = -1
            break


          # Compute placeholder reward
          total_reward1 += reward

          # Store transition
          replay_buffer1.push(state["board"], action, reward, next_state["board"], done)

          state = next_state

          if not done:
            #Model2 turn
            with torch.no_grad():
                qvals = policy_net2(state_tensor)  # shape [1,7]
                # Mask out illegal moves (e.g., full columns)
                illegal_mask = get_illegal_moves_mask(state["board"])  # bool array of length 7
                qvals[0][illegal_mask] = -1e9
                # Epsilon-greedy selection among legal actions
                if random.random() < epsilon:
                    valid_actions = [a for a, illegal in enumerate(illegal_mask) if not illegal]
                    action = random.choice(valid_actions)
                else:
                    action = qvals.argmax(dim=1).item()
            # Step environment
            try:
              next_state, reward, done, _ = env.step(action)
            except:
              action = -1
              break

            # Compute placeholder reward
            total_reward2 += reward

            # Store transition
            replay_buffer2.push(state["board"], action, reward, next_state["board"], done)

          steps_done += 1

          if done and episode % 10 == 0:
            env.render()

          # Sample and learn
          if len(replay_buffer1) >= batch_size:
            # 1) Sample a batch
            states_b, actions_b, rewards_b, next_states_b, dones_b = replay_buffer1.sample(batch_size)

            # 2) Convert to tensors
            states_v      = torch.tensor(states_b,      dtype=torch.float32, device=device).unsqueeze(1)   # [B,1,6,7]
            actions_v     = torch.tensor(actions_b,     dtype=torch.int64,   device=device).unsqueeze(1)   # [B,1]
            rewards_v     = torch.tensor(rewards_b,     dtype=torch.float32, device=device)               # [B]
            next_states_v = torch.tensor(next_states_b, dtype=torch.float32, device=device).unsqueeze(1)   # [B,1,6,7]
            dones_v       = torch.tensor(dones_b,       dtype=torch.uint8,   device=device)               # [B]

            # 3) Compute current Q-values
            q_vals = policy_net1(states_v).gather(1, actions_v).squeeze(1)        # [B]

            # 4) Compute target Q-values
            with torch.no_grad():
                # max over next actions (masking illegal moves if desired)
                next_q = target_net1(next_states_v).max(1)[0]                     # [B]
                target_q = rewards_v + gamma * next_q * (1 - dones_v.float())   # [B]

            # 5) Compute loss & backpropagate
            loss1 = nn.MSELoss()(q_vals, target_q)

            optimizer1.zero_grad()
            loss1.backward()
            optimizer1.step()
            loss1Val = loss1.item()

          # Sample and learn
          if len(replay_buffer2) >= batch_size:
            # 1) Sample a batch
            states_b, actions_b, rewards_b, next_states_b, dones_b = replay_buffer2.sample(batch_size)

            # 2) Convert to tensors
            states_v      = torch.tensor(states_b,      dtype=torch.float32, device=device).unsqueeze(1)   # [B,1,6,7]
            actions_v     = torch.tensor(actions_b,     dtype=torch.int64,   device=device).unsqueeze(1)   # [B,1]
            rewards_v     = torch.tensor(rewards_b,     dtype=torch.float32, device=device)               # [B]
            next_states_v = torch.tensor(next_states_b, dtype=torch.float32, device=device).unsqueeze(1)   # [B,1,6,7]
            dones_v       = torch.tensor(dones_b,       dtype=torch.uint8,   device=device)               # [B]

            # 3) Compute current Q-values
            q_vals = policy_net2(states_v).gather(1, actions_v).squeeze(1)        # [B]

            # 4) Compute target Q-values
            with torch.no_grad():
                # max over next actions (masking illegal moves if desired)
                next_q = target_net2(next_states_v).max(1)[0]                     # [B]
                target_q = rewards_v + gamma * next_q * (1 - dones_v.float())   # [B]

            # 5) Compute loss & backpropagate
            loss2 = nn.MSELoss()(q_vals, target_q)

            optimizer2.zero_grad()
            loss2.backward()
            optimizer2.step()
            loss2Val = loss2.item()

            # Periodically sync target network
            if steps_done % sync_target_steps == 0:
                target_net1.load_state_dict(policy_net1.state_dict())
                target_net2.load_state_dict(policy_net2.state_dict())

      writer.add_scalar("Rewards/Agent1", total_reward1, episode)
      writer.add_scalar("Rewards/Agent2", total_reward2, episode)

      writer.add_scalar("Loss/Agent1", loss1Val, episode)
      writer.add_scalar("Loss/Agent2", loss2Val, episode)

      writer.add_scalar("Epsilon", epsilon, episode)

      print(f"Episode {episode} - R1: {total_reward1} - R2: {total_reward2} - E: {epsilon} - A {action}")

      if episode % 100 == 0:  # every 100 episodes
        torch.save(policy_net1.state_dict(),
                  f"checkpoints/c4_dqn_{timestamp}/policy_net1_ep{episode:04d}.pth")
        torch.save(policy_net2.state_dict(),
                  f"checkpoints/c4_dqn_{timestamp}/policy_net2_ep{episode:04d}.pth")
        print(f"  → Saved models at episode {episode}")
  episode += 1
  torch.save(policy_net1.state_dict(),
            f"checkpoints/c4_dqn_{timestamp}/policy_net1_ep{episode:04d}.pth")
  torch.save(policy_net2.state_dict(),
            f"checkpoints/c4_dqn_{timestamp}/policy_net2_ep{episode:04d}.pth")
  print(f"  → Saved models at episode {episode}")
  writer.close()

## UI Class

In [23]:
class Connect4UI:
    def __init__(self, model_p1, model_p2, ai_player=-1):
        # Initialize the game environment
        self.env = BoardEnv()
        self.model_p1 = model_p1  # Model for player 1 (X)
        self.model_p2 = model_p2  # Model for player -1 (O)
        self.ai_player = ai_player  # -1 means AI plays as O, 1 means AI plays as X
        self.model = self.model_p2 if ai_player == -1 else self.model_p1  # Current model

        # Initialize state first
        self.state = self.env.reset()
        self.done = False
        
        # Create UI elements
        self.create_ui()

        # Complete game state initialization
        self.update_display()

        # If AI goes first, make its move
        if self.env.current_player == self.ai_player:
            time.sleep(0.5)
            self.make_ai_move()

    def initialize_game_state(self):
        """Initialize the game state based on who starts first"""
        # Reset game
        self.state = self.env.reset()
        self.done = False

        # Update display
        self.update_display()

        # If AI goes first, make its move
        if self.env.current_player == self.ai_player:
            time.sleep(0.5)
            self.make_ai_move()

    def create_ui(self):
        # Title
        self.title = widgets.HTML(value="<h1 style='text-align: center;'>Connect 4</h1>")

        # Status message
        self.status = widgets.HTML(value="<h3 style='text-align: center;'>Game ready! Make your move</h3>")

        # Create buttons for each column
        self.buttons = []
        for col in range(7):
            btn = widgets.Button(description=str(col),
                                layout=widgets.Layout(width='60px', height='40px'))
            btn.on_click(lambda b, col=col: self.make_move(col))
            self.buttons.append(btn)

        # Button container (top row)
        self.button_container = widgets.HBox(self.buttons,
                                           layout=widgets.Layout(justify_content='center'))

        # Game board display
        self.board_display = widgets.HTML(value=self.render_board_html())

        # Add Read Board button
        self.read_board_button = widgets.Button(
            description="Read Board",
            button_style='info',
            layout=widgets.Layout(width='150px')
        )
        self.read_board_button.on_click(self.read_board_aloud)
        
        # Who starts selector
        self.player_options = [('You start (X)', 1), ('AI starts (O)', -1)]
        self.player_starter = widgets.RadioButtons(
            options=self.player_options,
            value=-1,  # Default to AI starting
            description='New Game:',
            layout=widgets.Layout(width='300px')
        )

        # New Game button
        self.new_game_button = widgets.Button(
            description="Start New Game",
            button_style='primary',
            layout=widgets.Layout(width='150px')
        )
        self.new_game_button.on_click(self.start_new_game)

        # Game controls
        self.game_controls = widgets.HBox([
            self.player_starter,
            self.new_game_button,
            self.read_board_button  # Add the Read Board button to the controls
        ], layout=widgets.Layout(justify_content='center', margin='20px 0'))

        # Add file upload widget for voice commands
        self.file_upload = widgets.FileUpload(
            accept='',  # Accept all file types
            multiple=False,  # Only allow single file upload
            description='Voice Command:',
            layout=widgets.Layout(width='250px')
        )
        self.file_upload.observe(self.handle_file_upload, names='value')

        # Add submit button for processing the uploaded file
        self.submit_button = widgets.Button(
            description="Process Command",
            button_style='success',
            layout=widgets.Layout(width='150px')
        )
        self.submit_button.on_click(self.process_audio_command)

        # Audio controls
        self.audio_controls = widgets.HBox([
            self.file_upload,
            self.submit_button
        ], layout=widgets.Layout(justify_content='center', margin='10px 0'))

        # Add upload status indicator
        self.upload_status = widgets.HTML(value="<p>No file uploaded</p>")

        # Add status for speech synthesis
        self.speech_status = widgets.HTML(value="")

        # Combine all widgets
        self.app = widgets.VBox([
            self.title,
            self.status,
            self.button_container,
            self.board_display,
            self.game_controls,
            self.audio_controls,
            self.upload_status,
            self.speech_status
        ], layout=widgets.Layout(width='100%', align_items='center'))

        # Display the UI
        display(self.app)
        
        # Add JavaScript for text-to-speech functionality
        display(HTML("""
        <script>
        function speakText(text) {
            if ('speechSynthesis' in window) {
                const utterance = new SpeechSynthesisUtterance(text);
                utterance.rate = 1.0;  // Speech rate
                utterance.pitch = 1.0; // Speech pitch
                window.speechSynthesis.cancel(); // Cancel any ongoing speech
                window.speechSynthesis.speak(utterance);
                return "Speaking...";
            } else {
                return "Text-to-speech not supported in this browser.";
            }
        }
        
        // Make the function available to Python
        window.speakText = speakText;
        </script>
        """))

    def read_board_aloud(self, _=None):
        """Convert board state to spoken text and read it aloud"""
        board_text = self.generate_board_description()
        
        # Use JavaScript to speak the text
        js_code = f"""
        var result = "";
        if (typeof window.speakText === 'function') {{
            result = window.speakText("{board_text}");
        }} else {{
            result = "Text-to-speech function not available.";
        }}
        result;
        """
        
        # Execute the JavaScript to speak the text
        try:
            from IPython.display import Javascript
            display(Javascript(js_code))
            self.speech_status.value = "<p>Reading board state aloud...</p>"
            
            # Clear the status after 3 seconds
            def clear_status():
                time.sleep(3)
                self.speech_status.value = ""
            
            import threading
            threading.Thread(target=clear_status).start()
            
        except Exception as e:
            self.speech_status.value = f"<p style='color:red;'>Error with text-to-speech: {str(e)}</p>"

    def generate_board_description(self):
        """Generate a textual description of the board state"""
        # Get the board from state
        board = self.state["board"]
        rows, cols = board.shape
        
        # Start with the game status
        if hasattr(self, 'done') and self.done:
            if self._check_winner(1):
                winner = 1
            elif self._check_winner(-1):
                winner = -1
            else:
                winner = 0
                
            if winner == 1:
                status = "Player X has won. " if self.ai_player == -1 else "AI has won. "
            elif winner == -1:
                status = "AI has won. " if self.ai_player == -1 else "Player X has won. "
            else:
                status = "The game is a draw. "
        else:
            human_player = -self.ai_player
            if self.env.current_player == human_player:
                player_name = "Your"
                player_symbol = "X" if human_player == 1 else "O"
            else:
                player_name = "AI's"
                player_symbol = "X" if self.ai_player == 1 else "O"
                
            status = f"It is {player_name} turn with {player_symbol}. "
        
        # Describe the board
        board_desc = "Board state: "
        
        # Count pieces by column
        for col in range(cols):
            pieces = []
            for row in range(rows-1, -1, -1):  # Start from bottom row
                if board[row, col] == 1:
                    pieces.append("X")
                elif board[row, col] == -1:
                    pieces.append("O")
            
            if pieces:
                board_desc += f"Column {col} has {len(pieces)} pieces: {', '.join(pieces)} from bottom to top. "
            else:
                board_desc += f"Column {col} is empty. "
        
        # Escape quotes and special characters
        full_description = status + board_desc
        full_description = full_description.replace('"', '\\"').replace("'", "\\'")
        
        return full_description

    def handle_file_upload(self, change):
        """Handle file upload event"""
        if change['new']:
            try:
                # Check if change['new'] is a tuple or a dictionary
                if isinstance(change['new'], tuple):
                    # If it's a tuple, extract the first element
                    uploaded_file = change['new'][0]
                else:
                    # If it's a dictionary, use the original code
                    uploaded_file = next(iter(change['new'].values()))
                
                # Check if 'metadata' exists in the structure
                if 'metadata' in uploaded_file and 'name' in uploaded_file['metadata']:
                    filename = uploaded_file['metadata']['name']
                    # Just acknowledge the upload
                    self.upload_status.value = f"<p>File uploaded: {filename}</p>"
                else:
                    # Handle case where metadata or name is missing
                    self.upload_status.value = f"<p>File uploaded successfully</p>"
                
                self.upload_status.value += f"<p>Click 'Process Command' to execute the command.</p>"
            except Exception as e:
                # Fallback for any unexpected structure
                self.upload_status.value = f"<p>File uploaded, but couldn't read file details: {str(e)}</p>"
                self.upload_status.value += f"<p>Click 'Process Command' to execute the command.</p>"

    def process_audio_command(self, _=None):
        """Process an uploaded audio file and dispatch the spoken command."""
        try:
            # ── 1. Validate upload ────────────────────────────────────────────────
            if not self.file_upload.value:
                self.upload_status.value = (
                    "<p style='color:orange;'>Please upload an audio file first.</p>"
                )
                return

            try:
                # Check if file_upload.value is a tuple or a dictionary
                if isinstance(self.file_upload.value, tuple):
                    # If it's a tuple, extract the first element
                    uploaded = self.file_upload.value[0]
                else:
                    # If it's a dictionary, use the original code
                    uploaded = next(iter(self.file_upload.value.values()))
                
                # Check if 'content' exists
                if 'content' not in uploaded:
                    self.upload_status.value = (
                        "<p style='color:red;'>Invalid file format: missing content</p>"
                    )
                    return
                    
                raw_bytes = uploaded["content"]
                
                # Try to get filename but provide default if not available
                fname = "uploaded_audio"
                if 'metadata' in uploaded and 'name' in uploaded['metadata']:
                    fname = uploaded['metadata']['name']
            except Exception as e:
                self.upload_status.value = (
                    f"<p style='color:red;'>Error reading file: {str(e)}</p>"
                )
                return

            self.upload_status.value = "<p style='color:blue;'>Processing audio file...</p>"

            # You need to import these libraries in your notebook
            # If these imports are failing, install the libraries first:
            # !pip install SpeechRecognition pydub
            try:
                import speech_recognition as sr
                from pydub import AudioSegment
            except ImportError as e:
                self.upload_status.value = (
                    "<p style='color:red;'>Missing required Python libraries. Please run the following in a cell:</p>"
                    "<pre>!pip install SpeechRecognition pydub</pre>"
                )
                return

            # ── 2. Convert to mono-WAV in-memory (handles mp3, wav, m4a, etc.) ────
            try:
                # This will fail if ffmpeg/ffprobe is not installed
                audio = AudioSegment.from_file(io.BytesIO(raw_bytes))
                audio = audio.set_frame_rate(16_000).set_channels(1)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_tmp:
                    audio.export(wav_tmp.name, format="wav")
                    wav_path = wav_tmp.name
            except FileNotFoundError as err:
                if 'ffprobe' in str(err) or 'ffmpeg' in str(err):
                    self.upload_status.value = (
                        "<p style='color:red;'>Missing FFmpeg tools. This feature requires FFmpeg to be installed.</p>"
                        "<p>Please install FFmpeg by running the following command in a cell:</p>"
                        "<pre>!apt-get update && apt-get install -y ffmpeg</pre>"
                        "<p>If using Google Colab, run:</p>"
                        "<pre>!apt-get update && apt-get install -y ffmpeg</pre>"
                        "<p>If using a local environment, install FFmpeg using your package manager, e.g.:</p>"
                        "<p>- Ubuntu/Debian: <code>sudo apt install ffmpeg</code></p>"
                        "<p>- macOS: <code>brew install ffmpeg</code></p>"
                        "<p>- Windows: <a href='https://ffmpeg.org/download.html'>Download from ffmpeg.org</a></p>"
                    )
                else:
                    self.upload_status.value = (
                        f"<p style='color:red;'>Error: {str(err)}</p>"
                    )
                return
            except Exception as exc:
                self.upload_status.value = (
                    "<p style='color:red;'>Cannot process audio \"" + fname + "\": " + str(exc) + "</p>"
                )
                return

            # ── 3. Speech-to-text (Google Web API) ────────────────────────────────
            try:
                recog = sr.Recognizer()
                with sr.AudioFile(wav_path) as source:
                    # remove ambient-noise adjustment for prerecorded files
                    audio_data = recog.record(source)      # grab the whole file
                text = recog.recognize_google(audio_data, language='en-US').lower()
                self.upload_status.value = f"<p>Heard: \"{text}\"</p>"
            except sr.UnknownValueError:
                self.upload_status.value = (
                    "<p style='color:red;'>Sorry, I couldn't understand that.</p>"
                )
                return
            except sr.RequestError as exc:
                self.upload_status.value = (
                    "<p style='color:red;'>Speech-service error: " + str(exc) + "</p>"
                )
                return
            finally:
                # Clean temp file
                with suppress(FileNotFoundError):
                    os.remove(wav_path)

            # ── 4. Command routing ────────────────────────────────────────────────
            digit_words = {
                "zero": 0,
                "one": 1,
                "two": 2,
                "three": 3,
                "four": 4,
                "five": 5,
                "six": 6,
            }

            # Add handling for "read board" command
            if re.search(r"\bread\b.*\bboard\b", text):
                self.upload_status.value += "<p>Reading board state...</p>"
                self.read_board_aloud(None)
                return

            # new-game
            if re.search(r"\bnew\b.*\bgame\b", text):
                self.upload_status.value += "<p>Starting a new game!</p>"
                self.start_new_game(None)
                return

            # column n / just n
            # 1️⃣ match "column three", "col 3", "place in five" …
            col_match = re.search(r"\b(col(?:umn)?|place(?:\sin)?)\s*(\w+)", text)
            word_or_digit = None
            if col_match:
                word_or_digit = col_match.group(2)
            else:
                # 2️⃣ plain "three" / "3"
                word_or_digit = text.strip()

            # map to integer 0-6
            if word_or_digit in digit_words:
                col = digit_words[word_or_digit]
            elif word_or_digit.isdigit() and 0 <= int(word_or_digit) <= 6:
                col = int(word_or_digit)
            else:
                self.upload_status.value += (
                    "<p style='color:orange;'>Command not recognized. "
                    "Say e.g. \"column three\", \"new game\", or \"read board\".</p>"
                )
                return

            # ── 5. Execute move ───────────────────────────────────────────────────
            self.upload_status.value += f"<p>Placing piece in column {col}</p>"
            self.make_move(col)

            # ── 6. Reset uploader for next use ────────────────────────────────────
            try:
                if isinstance(self.file_upload.value, tuple):
                    self.file_upload.value = ()  # Clear tuple
                else:
                    self.file_upload.value.clear()  # Clear dictionary
            except Exception as e:
                # Just ignore errors in clearing
                pass

        except ImportError as e:
            self.upload_status.value = (
                "<p style='color:red;'>Missing required libraries: " + str(e) + ". "
                "Please install the required libraries using !pip install.</p>"
            )
        except Exception as e:
            self.upload_status.value = (
                "<p style='color:red;'>Error processing audio: " + str(e) + "</p>"
            )

    def start_new_game(self, b):
        """Start a new game with the selected starting player"""
        # Get who starts from the radio buttons
        selected_value = self.player_starter.value

        # Figure out who the AI player is based on selected value
        self.ai_player = -selected_value
        
        # Assign the right model based on which player the AI is
        self.model = self.model_p2 if self.ai_player == -1 else self.model_p1
        
        # Reset the game
        self.state = self.env.reset()
        self.done = False

        # Re-enable buttons
        for btn in self.buttons:
            btn.disabled = False

        # Update the display
        self.update_display()

        # If AI goes first, make its move
        if self.env.current_player == self.ai_player:
            time.sleep(0.5)
            self.make_ai_move()

    def render_board_html(self):
        """Render the Connect 4 board as HTML for display"""
        html = """
        <style>
        .board {
            background-color: #0052cc;
            display: inline-block;
            padding: 10px;
            border-radius: 10px;
        }
        .cell {
            width: 60px;
            height: 60px;
            background-color: #ffffff;
            border-radius: 50%;
            display: inline-block;
            margin: 5px;
        }
        .player1 {
            background-color: #ff0000;
        }
        .player-1 {
            background-color: #ffff00;
        }
        </style>
        <div class="board">
        """

        # Get board from state
        board = self.state["board"]
        rows, cols = board.shape
        
        for row in range(rows):
            html += "<div>"
            for col in range(cols):
                cell_value = board[row, col]
                cell_class = f"cell player{cell_value}" if cell_value != 0 else "cell"
                html += f'<div class="{cell_class}"></div>'
            html += "</div>"

        html += "</div>"
        return html

    def update_display(self):
        """Update the board display and status message"""
        self.board_display.value = self.render_board_html()

        # Check game status
        if hasattr(self, 'done') and self.done:
            # Game is over
            if self._check_winner(1):
                winner = 1
            elif self._check_winner(-1):
                winner = -1
            else:
                winner = 0  # Draw

            if winner == 1:
                message = "You win! 🎉" if self.ai_player == -1 else "AI wins! 🤖"
                color = "green" if self.ai_player == -1 else "red"
            elif winner == -1:
                message = "AI wins! 🤖" if self.ai_player == -1 else "You win! 🎉"
                color = "red" if self.ai_player == -1 else "green"
            else:
                message = "Draw game! 🤝"
                color = "blue"

            self.status.value = f"<h3 style='text-align: center; color: {color};'>{message}</h3>"

            # Disable column buttons
            for btn in self.buttons:
                btn.disabled = True
        else:
            # Game is ongoing
            current_player = self.env.current_player
            human_player = -self.ai_player

            if current_player == human_player:
                player_name = "Your"
                player_symbol = "(X)" if human_player == 1 else "(O)"
            else:
                player_name = "AI's"
                player_symbol = "(X)" if self.ai_player == 1 else "(O)"

            self.status.value = f"<h3 style='text-align: center;'>{player_name} turn {player_symbol}</h3>"

    def valid_actions(self):
        """Helper method to get valid actions from BoardEnv"""
        # Add this method to simulate the Connect4Env's valid_actions method
        return [c for c in range(7) if self.env._is_valid_action(c)]

    def make_move(self, column):
        """Handle player's move on column click"""
        if hasattr(self, 'done') and self.done:
            # Game is already over
            return

        # Determine if it's human's turn
        human_player = -self.ai_player
        if self.env.current_player != human_player:
            self.status.value = "<h3 style='text-align: center; color: orange;'>Not your turn!</h3>"
            return

        # Check if move is valid
        if not self.env._is_valid_action(column):  # Changed from valid_actions to _is_valid_action
            self.status.value = "<h3 style='text-align: center; color: orange;'>Invalid move! Column is full</h3>"
            return

        # Make the move
        next_state, reward, done, _ = self.env.step(column)
        self.state = next_state
        self.done = done

        # Update the display
        self.update_display()

        # If game not over and AI's turn, make AI move
        if not done and self.env.current_player == self.ai_player:
            # Add a small delay for better UX
            time.sleep(0.5)
            self.make_ai_move()

    def make_ai_move(self):
        """Make an AI move using the model"""
        if hasattr(self, 'done') and self.done:
            return

        # Use the model to select an action
        # Prepare tensor
        state_tensor = torch.tensor(self.state["board"], dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)
        
        with torch.no_grad():
            qvals = self.model(state_tensor)
            # Mask out illegal moves
            illegal_mask = get_illegal_moves_mask(self.state["board"])
            qvals[0][illegal_mask] = -1e9
            action = qvals.argmax(dim=1).item()

        # Make the move
        next_state, reward, done, _ = self.env.step(action)
        self.state = next_state
        self.done = done

        # Update the display
        self.update_display() 

    def _check_winner(self, player):
        """Check if given player has won by having 4 in a row anywhere on the board"""
        # Can only win if player is 1 (X) or -1 (O)
        if player == 0:
            return False
            
        board = self.state["board"]  # Use self.state["board"] instead of self.env.state["board"]
        rows, cols = board.shape
    
        # Horizontal check
        for r in range(rows):
            for c in range(cols - 3):
                if (board[r, c] == player and 
                    board[r, c+1] == player and 
                    board[r, c+2] == player and 
                    board[r, c+3] == player):
                    return True
        
        # Vertical check
        for r in range(rows - 3):
            for c in range(cols):
                if (board[r, c] == player and 
                    board[r+1, c] == player and 
                    board[r+2, c] == player and 
                    board[r+3, c] == player):
                    return True
        
        # Diagonal down-right
        for r in range(rows - 3):
            for c in range(cols - 3):
                if (board[r, c] == player and 
                    board[r+1, c+1] == player and 
                    board[r+2, c+2] == player and 
                    board[r+3, c+3] == player):
                    return True
        
        # Diagonal up-right
        for r in range(3, rows):
            for c in range(cols - 3):
                if (board[r, c] == player and 
                    board[r-1, c+1] == player and 
                    board[r-2, c+2] == player and 
                    board[r-3, c+3] == player):
                    return True
        
        return False

def start_connect4_ui(model_p1, model_p2, ai_player=-1):
    """
    Start the Connect4 UI with the trained models

    Args:
        model_p1: The trained DQN model for player 1
        model_p2: The trained DQN model for player 2
        ai_player: The player ID for the AI (-1 or 1)
    """
    ui = Connect4UI(model_p1, model_p2, ai_player=ai_player)
    return ui

## Actions

### Plot Metrics (WE NEED TO MODIFY THE PATH WE LOAD FROM)

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs

### Play (WE NEED TO MODIFY THE PATH WE LOAD FROM)

In [24]:
def start_connect4_ui(model_p1, model_p2, ai_player=-1):
    """
    Start the Connect4 UI with the trained models

    Args:
        model_p1: The trained DQN model for player 1
        model_p2: The trained DQN model for player 2
        ai_player: The player ID for the AI (-1 or 1)
    """
    ui = Connect4UI(model_p1, model_p2, ai_player=ai_player)
    return ui

# Load your trained models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_p1 = ConvQNetwork(input_shape=(1,6,7), num_actions=7).to(device)
model_p2 = ConvQNetwork(input_shape=(1,6,7), num_actions=7).to(device)

# Load saved model weights
model_p1.load_state_dict(torch.load("checkpoints/c4_dqn_20250511_143151/policy_net1_ep0500.pth", map_location=device))
model_p2.load_state_dict(torch.load("checkpoints/c4_dqn_20250511_143151/policy_net2_ep0500.pth", map_location=device))

# Set models to evaluation mode
model_p1.eval()
model_p2.eval()

ui = start_connect4_ui(model_p1, model_p2, ai_player=-1)


  model_p1.load_state_dict(torch.load("checkpoints/c4_dqn_20250511_143151/policy_net1_ep0500.pth", map_location=device))
  model_p2.load_state_dict(torch.load("checkpoints/c4_dqn_20250511_143151/policy_net2_ep0500.pth", map_location=device))


VBox(children=(HTML(value="<h1 style='text-align: center;'>Connect 4</h1>"), HTML(value="<h3 style='text-align…

<IPython.core.display.Javascript object>

## 6. Extensions and Next Steps

- Define a real environment and replace the placeholder reward function.
- Implement action selection strategies (epsilon-greedy, softmax).  
- Add saving/loading model checkpoints.  
- Incorporate advanced techniques: Double DQN, Prioritized Experience Replay, Dueling Networks.
