<a href="https://colab.research.google.com/github/Mootha-sri-harshit/tic-tac-toe_agent/blob/main/tic_tac_toe_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque

# ==================== Tic-Tac-Toe Environment ====================
class TicTacToe:
    """Minimal 3x3 tic-tac-toe board.

    The board is a flat list of nine cells indexed 0-8 in row-major order.
    An empty cell holds a single space; an occupied cell holds the mark that
    was passed to ``make_move``.
    """

    def __init__(self):
        self.board = [' '] * 9
        # Every index triple that forms a winning line: rows, columns, diagonals.
        self.win_combinations = [
            (0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
            (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)
        ]

    def reset(self):
        """Clear the board and return the empty position as a tuple."""
        fresh = [' ' for _ in range(9)]
        self.board = fresh
        return tuple(fresh)

    def available_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        free = []
        for index, mark in enumerate(self.board):
            if mark == ' ':
                free.append(index)
        return free

    def make_move(self, position, player):
        """Place ``player`` at ``position``; return False if the cell is taken."""
        if self.board[position] != ' ':
            return False
        self.board[position] = player
        return True

    def check_winner(self):
        """Return the winning mark, ``"Draw"`` on a full board, or None."""
        cells = self.board
        for first, second, third in self.win_combinations:
            mark = cells[first]
            if mark == cells[second] and cells[second] == cells[third] and cells[third] != ' ':
                return mark
        if ' ' in cells:
            return None
        return "Draw"

    def print_board(self):
        """Print the three rows followed by a dashed separator line."""
        for start in (0, 3, 6):
            left, middle, right = self.board[start], self.board[start + 1], self.board[start + 2]
            print(f"{left} | {middle} | {right}")
        print("-" * 9)

# ==================== Noisy Network Layer ====================
class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases carry learnable noise.

    Each weight/bias is ``mu + sigma * epsilon``. ``mu`` and ``sigma`` are
    trainable parameters; ``epsilon`` is a non-trainable buffer that is
    re-sampled by ``reset_noise``. The weight noise is the outer product of a
    per-output and a per-input noise vector. In eval mode the noise term is
    dropped and only ``mu`` is used.
    """

    def __init__(self, in_features, out_features, std_init=0.4):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.empty(weight_shape))
        self.weight_sigma = nn.Parameter(torch.empty(weight_shape))
        self.register_buffer('weight_epsilon', torch.empty(weight_shape))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        # The tensors above are uninitialised until these two calls fill them.
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialise the means uniformly and the sigmas to constants."""
        bound = 1 / np.sqrt(self.in_features)
        weight_sigma_init = self.std_init / np.sqrt(self.in_features)
        bias_sigma_init = self.std_init / np.sqrt(self.out_features)
        with torch.no_grad():
            self.weight_mu.uniform_(-bound, bound)
            self.weight_sigma.fill_(weight_sigma_init)
            self.bias_mu.uniform_(-bound, bound)
            self.bias_sigma.fill_(bias_sigma_init)

    def reset_noise(self):
        """Draw fresh noise vectors and store them in the epsilon buffers."""
        noise_in = self._scale_noise(self.in_features)
        noise_out = self._scale_noise(self.out_features)
        self.weight_epsilon.copy_(torch.outer(noise_out, noise_in))
        self.bias_epsilon.copy_(noise_out)

    def _scale_noise(self, size):
        """Return ``sign(x) * sqrt(|x|)`` for ``size`` standard-normal samples."""
        sample = torch.randn(size)
        return torch.sign(sample) * torch.sqrt(torch.abs(sample))

    def forward(self, x):
        """Apply the layer; noise is only added while the module is training."""
        if not self.training:
            return F.linear(x, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(x, noisy_weight, noisy_bias)

# ==================== DQN Architecture ====================
class DQN(nn.Module):
    """Three-layer noisy MLP mapping a 9-value board encoding to 9 Q-values."""

    def __init__(self):
        super().__init__()
        # 9 inputs -> 256 -> 256 -> 9 outputs, ReLU between the noisy layers.
        stack = [
            NoisyLinear(9, 256),
            nn.ReLU(),
            NoisyLinear(256, 256),
            nn.ReLU(),
            NoisyLinear(256, 9),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return its output."""
        return self.net(x)

    def reset_noise(self):
        """Re-sample the noise of every NoisyLinear layer in the stack."""
        for module in self.net.children():
            if not isinstance(module, NoisyLinear):
                continue
            module.reset_noise()

# ==================== DQN Agent ====================
class DQNAgent:
    """Double-DQN agent with noisy networks and a uniform replay buffer.

    The agent owns a policy network (trained) and a target network (used to
    evaluate bootstrap targets), an AdamW optimiser, and a bounded deque of
    ``(state, action, reward, next_state, done)`` transitions.
    """

    # Polyak averaging coefficient applied to the target network after every
    # optimisation step.
    TARGET_TAU = 0.005
    # Maximum number of transitions kept in the replay buffer.
    MEMORY_CAPACITY = 50000

    def __init__(self):
        self.policy_net = DQN()
        self.target_net = DQN()
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=5e-4, weight_decay=1e-5)
        self.memory = deque(maxlen=self.MEMORY_CAPACITY)
        self.batch_size = 128
        self.gamma = 0.97
        self.train_step = 0

    def get_state_tensor(self, state):
        """Encode a board as a float tensor: 'X' -> 1, 'O' -> -1, other -> 0."""
        return torch.tensor(
            [1 if cell == 'X' else -1 if cell == 'O' else 0 for cell in state],
            dtype=torch.float32,
        )

    def choose_action(self, state, available_moves):
        """Sample a legal move from a softmax over the policy Q-values.

        Args:
            state: iterable of nine board cells.
            available_moves: indices of the cells that may be played.

        Returns:
            The sampled cell index as a Python int.
        """
        state_tensor = self.get_state_tensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(state_tensor)

        # Boltzmann exploration: temperature decays linearly with the number
        # of optimisation steps, from 2.0 down to a floor of 0.5.
        temp = max(0.5, 2.0 - self.train_step / 10000)
        # -inf on illegal cells gives them exactly zero probability.
        mask = torch.full((9,), -np.inf)
        mask[available_moves] = 0
        probs = F.softmax((q_values + mask) / temp, dim=1)
        return torch.multinomial(probs, 1).item()

    def store_experience(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def update_model(self):
        """Run one Double-DQN optimisation step on a random minibatch.

        Does nothing until the buffer holds at least ``batch_size``
        transitions. After the gradient step the target network is moved
        towards the policy network and both networks re-sample their noise.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.stack([self.get_state_tensor(s) for s in states])
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.stack([self.get_state_tensor(s) for s in next_states])
        not_done = 1.0 - torch.tensor(dones, dtype=torch.float32)

        current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            # Double DQN: the policy net picks the next action, the target net
            # scores it. Occupied cells (non-zero encoding) are excluded from
            # the argmax: the agent never plays them, so their Q-values are
            # never trained and must not leak into the bootstrap target.
            occupied = next_states != 0
            next_policy_q = self.policy_net(next_states).masked_fill(occupied, float('-inf'))
            next_actions = next_policy_q.argmax(1, keepdim=True)
            next_q = self.target_net(next_states).gather(1, next_actions).squeeze(1)
            # Terminal transitions contribute the reward only.
            target_q = rewards + not_done * self.gamma * next_q

        loss = F.smooth_l1_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 5.0)
        self.optimizer.step()

        self.train_step += 1

        # Soft (Polyak) target update on every step. With tau this small the
        # update has to run each step; applying it only once every few hundred
        # steps would leave the target network close to its initial weights.
        tau = self.TARGET_TAU
        with torch.no_grad():
            for target_param, policy_param in zip(self.target_net.parameters(),
                                                  self.policy_net.parameters()):
                target_param.mul_(1.0 - tau).add_(policy_param, alpha=tau)

        self.policy_net.reset_noise()
        self.target_net.reset_noise()

    def save(self, path):
        """Write networks, optimiser state, step counter and buffer to ``path``."""
        torch.save({
            'policy_state': self.policy_net.state_dict(),
            'target_state': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_step': self.train_step,
            'memory': list(self.memory)
        }, path)

    def load(self, path):
        """Restore everything written by ``save`` from ``path``."""
        # map_location lets a checkpoint written on a GPU machine be loaded on
        # a CPU-only one.
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        self.policy_net.load_state_dict(checkpoint['policy_state'])
        self.target_net.load_state_dict(checkpoint['target_state'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.train_step = checkpoint['train_step']
        self.memory = deque(checkpoint['memory'], maxlen=self.MEMORY_CAPACITY)

# ==================== Training and Evaluation ====================
def find_blocking_move(env, player):
    """Return the cell ``player`` must take to block an immediate loss.

    Scans ``env.win_combinations`` in order and returns the empty index of
    the first line in which the other mark already holds two cells. Returns
    None when no such line exists.
    """
    threat = 'O' if player != 'O' else 'X'
    for line in env.win_combinations:
        threat_count = 0
        gaps = []
        for index in line:
            mark = env.board[index]
            if mark == threat:
                threat_count += 1
            elif mark == ' ':
                gaps.append(index)
        if threat_count == 2 and len(gaps) == 1:
            return gaps[0]
    return None

def train_agent(episodes=10000):
    """Train a DQNAgent as 'X' against a scripted 'O' opponent.

    Each loop iteration is one agent move followed (unless the game just
    ended) by one opponent move; the resulting transition is stored and one
    optimisation step is attempted. The opponent plays a uniformly random
    move 70% of the time and otherwise blocks an immediate X win, falling
    back to a random move when there is nothing to block.

    A checkpoint is written to ``dqn_tic_tac_toe.pth`` every 500 episodes and
    once more when training finishes.

    Args:
        episodes: number of games to play.

    Returns:
        The trained DQNAgent.
    """
    env = TicTacToe()
    agent = DQNAgent()

    for episode in range(episodes):
        state = env.reset()
        done = False

        while not done:
            # Agent's move (X)
            avail_moves = env.available_moves()
            action = agent.choose_action(state, avail_moves)
            env.make_move(action, 'X')
            next_state = env.board.copy()
            winner = env.check_winner()

            # Reward shaping: the bonus/penalty terms scale with the number
            # of cells that were already filled before the deciding move.
            if winner == 'X':
                reward = 1 + 0.2*(9 - len(avail_moves))
                done = True
            elif winner == 'Draw':
                reward = 0.5
                done = True
            else:
                # The game is not over, so at least one cell is empty and the
                # opponent always has a move.
                opp_moves = env.available_moves()
                opp_action = None
                if np.random.rand() >= 0.7:
                    opp_action = find_blocking_move(env, 'O')
                # Random move by choice, or because there was nothing to
                # block. Without this fallback the opponent would skip its
                # turn and X would move twice in a row.
                if opp_action is None:
                    opp_action = random.choice(opp_moves)

                env.make_move(opp_action, 'O')
                winner = env.check_winner()
                next_state = env.board.copy()

                if winner == 'O':
                    reward = -1 - 0.2*(9 - len(opp_moves))
                    done = True
                elif winner == 'Draw':
                    reward = 0.3
                    done = True
                else:
                    reward = -0.05 * (9 - len(opp_moves))

            agent.store_experience(state, action, reward, next_state, done)
            state = next_state

            if len(agent.memory) > agent.batch_size:
                agent.update_model()

        # Save model periodically
        if episode % 500 == 0:
            agent.save("dqn_tic_tac_toe.pth")
            print(f"Episode {episode} | Model Saved")

    # Persist the final weights too; otherwise everything learned after the
    # last multiple of 500 is missing from the checkpoint on disk.
    agent.save("dqn_tic_tac_toe.pth")
    print("Training complete | Model Saved")

    return agent

def play_vs_ai(agent_path="dqn_tic_tac_toe.pth"):
    """Play an interactive console game against a saved agent.

    Loads the checkpoint at ``agent_path``, asks the human which mark to
    play ('X' always moves first), then alternates turns until
    ``check_winner`` reports a result.
    """
    agent = DQNAgent()
    agent.load(agent_path)
    env = TicTacToe()

    human = input("Choose X or O: ").upper().strip()
    while human not in ['X', 'O']:
        human = input("Invalid! Choose X/O: ").upper().strip()

    current_player = 'X'
    state = env.reset()

    while True:
        env.print_board()

        if current_player != human:
            # NOTE(review): the board is handed to the agent unchanged
            # whichever mark the AI plays - confirm the network was trained
            # for both sides.
            move = agent.choose_action(state, env.available_moves())
            print(f"AI chooses: {move}")
        else:
            try:
                move = int(input("Your move (0-8): "))
            except ValueError:
                print("Enter a number 0-8")
                continue
            else:
                # Rejects out-of-range numbers as well as occupied cells.
                if move not in env.available_moves():
                    print("Invalid move!")
                    continue

        env.make_move(move, current_player)
        winner = env.check_winner()
        if winner:
            env.print_board()
            print(f"Game Over! Winner: {winner}")
            break

        # Hand the turn to the other mark and snapshot the board for the agent.
        current_player = 'X' if current_player != 'X' else 'O'
        state = env.board.copy()

# ==================== Main Execution ====================
if __name__ == "__main__":
    # Train the agent for 5000 episodes (lower this, e.g. to 1000, for a quick test run).
    trained_agent = train_agent(episodes=5000)

    # Play against the AI. play_vs_ai() is called with its default checkpoint
    # path, so it reloads the agent from disk; `trained_agent` is not passed in.
    play_vs_ai()


Episode 0 | Model Saved
Episode 500 | Model Saved
Episode 1000 | Model Saved
Episode 1500 | Model Saved
Episode 2000 | Model Saved
Episode 2500 | Model Saved
Episode 3000 | Model Saved
Episode 3500 | Model Saved
Episode 4000 | Model Saved
Episode 4500 | Model Saved
Choose X or O: x
  |   |  
  |   |  
  |   |  
---------
Your move (0-8): 4
  |   |  
  | X |  
  |   |  
---------
AI chooses: 0
O |   |  
  | X |  
  |   |  
---------
Your move (0-8): 1
O | X |  
  | X |  
  |   |  
---------
AI chooses: 7
O | X |  
  | X |  
  | O |  
---------
Your move (0-8): 3
O | X |  
X | X |  
  | O |  
---------
AI chooses: 5
O | X |  
X | X | O
  | O |  
---------
Your move (0-8): 2
O | X | X
X | X | O
  | O |  
---------
AI chooses: 6
O | X | X
X | X | O
O | O |  
---------
Your move (0-8): 8
O | X | X
X | X | O
O | O | X
---------
Game Over! Winner: Draw


In [1]:
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque

# ==================== Tic-Tac-Toe Environment ====================
class TicTacToe:
    """Minimal 3x3 tic-tac-toe board.

    The board is a flat list of nine cells indexed 0-8 in row-major order.
    An empty cell holds a single space; an occupied cell holds the mark that
    was passed to ``make_move``.
    """

    def __init__(self):
        self.board = [' '] * 9
        # Every index triple that forms a winning line: rows, columns, diagonals.
        self.win_combinations = [
            (0, 1, 2), (3, 4, 5), (6, 7, 8), (0, 3, 6),
            (1, 4, 7), (2, 5, 8), (0, 4, 8), (2, 4, 6)
        ]

    def reset(self):
        """Clear the board and return the empty position as a tuple."""
        fresh = [' ' for _ in range(9)]
        self.board = fresh
        return tuple(fresh)

    def available_moves(self):
        """Return the indices of all empty cells, in ascending order."""
        free = []
        for index, mark in enumerate(self.board):
            if mark == ' ':
                free.append(index)
        return free

    def make_move(self, position, player):
        """Place ``player`` at ``position``; return False if the cell is taken."""
        if self.board[position] != ' ':
            return False
        self.board[position] = player
        return True

    def check_winner(self):
        """Return the winning mark, ``"Draw"`` on a full board, or None."""
        cells = self.board
        for first, second, third in self.win_combinations:
            mark = cells[first]
            if mark == cells[second] and cells[second] == cells[third] and cells[third] != ' ':
                return mark
        if ' ' in cells:
            return None
        return "Draw"

    def print_board(self):
        """Print the three rows followed by a dashed separator line."""
        for start in (0, 3, 6):
            left, middle, right = self.board[start], self.board[start + 1], self.board[start + 2]
            print(f"{left} | {middle} | {right}")
        print("-" * 9)

# ==================== Noisy Network Layer ====================
class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases carry learnable noise.

    Each weight/bias is ``mu + sigma * epsilon``. ``mu`` and ``sigma`` are
    trainable parameters; ``epsilon`` is a non-trainable buffer that is
    re-sampled by ``reset_noise``. The weight noise is the outer product of a
    per-output and a per-input noise vector. In eval mode the noise term is
    dropped and only ``mu`` is used.
    """

    def __init__(self, in_features, out_features, std_init=0.4):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.std_init = std_init

        weight_shape = (out_features, in_features)
        self.weight_mu = nn.Parameter(torch.empty(weight_shape))
        self.weight_sigma = nn.Parameter(torch.empty(weight_shape))
        self.register_buffer('weight_epsilon', torch.empty(weight_shape))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        # The tensors above are uninitialised until these two calls fill them.
        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialise the means uniformly and the sigmas to constants."""
        bound = 1 / np.sqrt(self.in_features)
        weight_sigma_init = self.std_init / np.sqrt(self.in_features)
        bias_sigma_init = self.std_init / np.sqrt(self.out_features)
        with torch.no_grad():
            self.weight_mu.uniform_(-bound, bound)
            self.weight_sigma.fill_(weight_sigma_init)
            self.bias_mu.uniform_(-bound, bound)
            self.bias_sigma.fill_(bias_sigma_init)

    def reset_noise(self):
        """Draw fresh noise vectors and store them in the epsilon buffers."""
        noise_in = self._scale_noise(self.in_features)
        noise_out = self._scale_noise(self.out_features)
        self.weight_epsilon.copy_(torch.outer(noise_out, noise_in))
        self.bias_epsilon.copy_(noise_out)

    def _scale_noise(self, size):
        """Return ``sign(x) * sqrt(|x|)`` for ``size`` standard-normal samples."""
        sample = torch.randn(size)
        return torch.sign(sample) * torch.sqrt(torch.abs(sample))

    def forward(self, x):
        """Apply the layer; noise is only added while the module is training."""
        if not self.training:
            return F.linear(x, self.weight_mu, self.bias_mu)
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(x, noisy_weight, noisy_bias)

# ==================== DQN Architecture ====================
class DQN(nn.Module):
    """Three-layer noisy MLP mapping a 9-value board encoding to 9 Q-values."""

    def __init__(self):
        super().__init__()
        # 9 inputs -> 256 -> 256 -> 9 outputs, ReLU between the noisy layers.
        stack = [
            NoisyLinear(9, 256),
            nn.ReLU(),
            NoisyLinear(256, 256),
            nn.ReLU(),
            NoisyLinear(256, 9),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return its output."""
        return self.net(x)

    def reset_noise(self):
        """Re-sample the noise of every NoisyLinear layer in the stack."""
        for module in self.net.children():
            if not isinstance(module, NoisyLinear):
                continue
            module.reset_noise()

# ==================== DQN Agent ====================
class DQNAgent:
    """Play-time variant of the agent: action selection plus save/load.

    It builds the same networks, optimiser and replay buffer that a
    checkpoint contains so that ``load`` can restore every stored entry, but
    it defines no training step.
    """

    def __init__(self):
        policy = DQN()
        target = DQN()
        target.load_state_dict(policy.state_dict())
        self.policy_net = policy
        self.target_net = target

        self.optimizer = optim.AdamW(policy.parameters(), lr=5e-4, weight_decay=1e-5)
        self.memory = deque(maxlen=50000)
        self.batch_size = 128
        self.gamma = 0.97
        self.train_step = 0

    def get_state_tensor(self, state):
        """Encode a board as a float tensor: 'X' -> 1, 'O' -> -1, other -> 0."""
        encoded = []
        for cell in state:
            if cell == 'X':
                encoded.append(1)
            elif cell == 'O':
                encoded.append(-1)
            else:
                encoded.append(0)
        return torch.FloatTensor(encoded)

    def choose_action(self, state, available_moves):
        """Sample a legal move from a softmax over the policy Q-values."""
        batch = self.get_state_tensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.policy_net(batch)

        # Temperature falls linearly with train_step, from 2.0 to a floor of 0.5.
        temperature = max(0.5, 2.0 - self.train_step/10000)
        # -inf on illegal cells gives them exactly zero probability.
        legal_mask = torch.full((9,), -np.inf)
        legal_mask[available_moves] = 0
        logits = (q_values + legal_mask) / temperature
        probs = F.softmax(logits, dim=1)
        choice = torch.multinomial(probs, 1)
        return choice.item()

    def save(self, path):
        """Write networks, optimiser state, step counter and buffer to ``path``."""
        checkpoint = {
            'policy_state': self.policy_net.state_dict(),
            'target_state': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'train_step': self.train_step,
            'memory': list(self.memory)
        }
        torch.save(checkpoint, path)

    def load(self, path):
        """Restore everything written by ``save``, mapping tensors to the CPU."""
        cpu = torch.device('cpu')
        checkpoint = torch.load(path, map_location=cpu)
        self.policy_net.load_state_dict(checkpoint['policy_state'])
        self.target_net.load_state_dict(checkpoint['target_state'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.train_step = checkpoint['train_step']
        self.memory = deque(checkpoint['memory'], maxlen=50000)

# ==================== Play Against AI ====================
def play_vs_ai(agent_path="dqn_tic_tac_toe.pth"):
    """Play an interactive console game against a saved agent.

    Loads the checkpoint at ``agent_path``, asks the human which mark to
    play ('X' always moves first), then alternates turns until
    ``check_winner`` reports a result.
    """
    agent = DQNAgent()
    agent.load(agent_path)
    env = TicTacToe()

    human = input("Choose X or O: ").upper().strip()
    while human not in ['X', 'O']:
        human = input("Invalid! Choose X/O: ").upper().strip()

    current_player = 'X'
    state = env.reset()

    while True:
        env.print_board()

        if current_player != human:
            # NOTE(review): the board is handed to the agent unchanged
            # whichever mark the AI plays - confirm the network was trained
            # for both sides.
            move = agent.choose_action(state, env.available_moves())
            print(f"AI chooses: {move}")
        else:
            try:
                move = int(input("Your move (0-8): "))
            except ValueError:
                print("Enter a number 0-8")
                continue
            else:
                # Rejects out-of-range numbers as well as occupied cells.
                if move not in env.available_moves():
                    print("Invalid move!")
                    continue

        env.make_move(move, current_player)
        winner = env.check_winner()
        if winner:
            env.print_board()
            print(f"Game Over! Winner: {winner}")
            break

        # Hand the turn to the other mark and snapshot the board for the agent.
        current_player = 'X' if current_player != 'X' else 'O'
        state = env.board.copy()

# ==================== Main Execution ====================
# Runs at cell execution time: loads the checkpoint file named below and
# starts an interactive game that reads moves from standard input.
play_vs_ai("dqn_tic_tac_toe.pth")


Choose X or O: x
  |   |  
  |   |  
  |   |  
---------
Your move (0-8): 1
  | X |  
  |   |  
  |   |  
---------
AI chooses: 0
O | X |  
  |   |  
  |   |  
---------
Your move (0-8): 4
O | X |  
  | X |  
  |   |  
---------
AI chooses: 7
O | X |  
  | X |  
  | O |  
---------
Your move (0-8): 2
O | X | X
  | X |  
  | O |  
---------
AI chooses: 6
O | X | X
  | X |  
O | O |  
---------
Your move (0-8): 3
O | X | X
X | X |  
O | O |  
---------
AI chooses: 5
O | X | X
X | X | O
O | O |  
---------
Your move (0-8): 8
O | X | X
X | X | O
O | O | X
---------
Game Over! Winner: Draw
