Importing Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import matplotlib.pyplot as plt
import pickle

In [None]:
# Board size: 6 rows x 7 columns.
ROWS, COLS = 6, 7

# Reward shaping constants. Penalties are stored as NEGATIVE numbers,
# so they should be added to a reward (never subtracted).
WIN_REWARD = 10
LOSS_PENALTY = -10
THREE_IN_ROW_REWARD = 1
MISS_BLOCK_PENALTY = -3
MISS_WIN_PENALTY = -3

In [None]:
def _record_outcome(metrics, agent1_won, agent2_won, drawn):
    """Append one finished game's outcome to the per-episode 0/1 flag lists in `metrics`."""
    metrics['agent1_wins'].append(int(agent1_won))
    metrics['agent2_wins'].append(int(agent2_won))
    metrics['draws'].append(int(drawn))


def self_play(agent1, agent2, env, episodes=1000, batch_size=32, log_every=100, save_every=5000):
    """Train two agents by letting them play against each other on a shared `env`.

    Agents alternate moves. After every move the mover pushes a transition
    ``(state, action, reward, next_state, done)`` into its replay memory and runs one
    ``train_step``. When a game is won, the losing agent also receives a terminal
    transition for its last move carrying ``LOSS_PENALTY``, so it can learn from losses.

    Requires ``save_model`` (defined in a later notebook cell) to have been executed
    before this function is called.

    Args:
        agent1: Agent that moves first each episode. Must expose ``select_action``,
            ``memory.push`` and ``train_step``, plus what ``save_model`` reads.
        agent2: Agent that moves second each episode (same interface). May be the
            same object as ``agent1``.
        env: Environment exposing ``reset``, ``make_move``, ``get_state``,
            ``check_winner``, ``count_threes``, ``switch_player`` and ``current_player``.
        episodes: Number of games to play.
        batch_size: Batch size forwarded to ``train_step``.
        log_every: Print cumulative results every this many episodes (0 disables).
        save_every: Checkpoint both agents every this many episodes, including
            episode 0 (0 disables).

    Returns:
        dict of per-episode lists: ``episodes``; ``agent1_wins`` / ``agent2_wins`` /
        ``draws`` as 0/1 flags; ``avg_rewards_1`` / ``avg_rewards_2`` holding each
        agent's *total* shaped reward for that episode (key names kept for compatibility).
    """
    metrics = {
        'episodes': [],
        'agent1_wins': [],
        'agent2_wins': [],
        'draws': [],
        'avg_rewards_1': [],
        'avg_rewards_2': []
    }
    agents = (agent1, agent2)

    for episode in range(episodes):
        state = env.reset()
        done = False
        turn = 0  # index into `agents` of the agent about to move
        episode_rewards = [0, 0]
        # (state, action) of each agent's most recent move, so the loser can be given
        # a terminal transition when the other agent wins.
        last_move = [None, None]

        while not done:
            current_agent = agents[turn]
            action = current_agent.select_action(state)
            env.make_move(action)

            # switch_player() only runs at the end of the turn, so the mover is still
            # env.current_player. Players are numbered 1 and 2, so the rival is 3 - mover.
            mover = env.current_player
            rival = 3 - mover
            # Assumed contract (consistent with play_against_agent): winning player id on
            # a win, 0 on a draw, falsy otherwise — confirm against the env class.
            winner = env.check_winner()

            current_threes = env.count_threes(mover)
            opponent_threes = env.count_threes(rival)

            # Reward shaping: bonus per three-in-a-row. When the move did not win, penalize
            # unconverted threes and unblocked opponent threes. The penalty constants are
            # negative, so they are ADDED.
            reward = current_threes * THREE_IN_ROW_REWARD
            if not winner:
                if current_threes > 0:
                    reward += MISS_WIN_PENALTY
                if opponent_threes > 0:
                    reward += MISS_BLOCK_PENALTY

            next_state = env.get_state()

            if winner == mover:
                reward += WIN_REWARD
                done = True
                _record_outcome(metrics, turn == 0, turn == 1, False)
                loser = 1 - turn
                if last_move[loser] is not None:
                    # Close out the loser's trajectory with a terminal loss transition.
                    loser_state, loser_action = last_move[loser]
                    agents[loser].memory.push(
                        (loser_state, loser_action, LOSS_PENALTY, next_state, True))
                    episode_rewards[loser] += LOSS_PENALTY
            elif winner == 0:
                done = True
                reward -= 5  # Mild penalty for draws
                _record_outcome(metrics, False, False, True)

            episode_rewards[turn] += reward
            current_agent.memory.push((state, action, reward, next_state, done))
            current_agent.train_step(batch_size)
            last_move[turn] = (state, action)

            state = next_state
            env.switch_player()
            turn = 1 - turn

        metrics['episodes'].append(episode)
        metrics['avg_rewards_1'].append(episode_rewards[0])
        metrics['avg_rewards_2'].append(episode_rewards[1])

        if log_every and episode % log_every == 0:
            # Running totals: the latest per-episode flag alone is just 0 or 1.
            print(f"Episode {episode}/{episodes} - Agent 1 Wins: {sum(metrics['agent1_wins'])}, "
                  f"Agent 2 Wins: {sum(metrics['agent2_wins'])}, Draws: {sum(metrics['draws'])}")

        if save_every and episode % save_every == 0:
            save_model(agent1, f"agent1_{episode}.pth")
            save_model(agent2, f"agent2_{episode}.pth")

    return metrics


In [None]:
def _trailing_mean(values, window):
    """Mean of the last `window` entries at each position (a shorter prefix at the start)."""
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return values
    window = max(1, int(window))
    running = np.cumsum(values)
    totals = running.copy()
    # For i >= window, the window sum is running[i] - running[i - window].
    totals[window:] = running[window:] - running[:-window]
    counts = np.minimum(np.arange(1, values.size + 1), window)
    return totals / counts


def plot_metrics(metrics, window=100):
    """Plot self-play training curves from the dict returned by ``self_play``.

    Left panel: cumulative wins per agent and draws. The stored per-episode values are
    0/1 flags, which are unreadable when plotted raw. Right panel: each agent's
    per-episode reward total, smoothed with a trailing mean. Outcomes and rewards use
    separate axes because their scales are unrelated.

    Args:
        metrics: dict with keys ``episodes``, ``agent1_wins``, ``agent2_wins``,
            ``draws``, ``avg_rewards_1``, ``avg_rewards_2`` (equal-length lists).
        window: trailing-mean window (in episodes) for the reward curves; 1 plots raw values.
    """
    episodes = np.asarray(metrics['episodes'])
    fig, (ax_games, ax_rewards) = plt.subplots(1, 2, figsize=(14, 5))

    for key, label in (('agent1_wins', 'Agent 1 Wins'),
                       ('agent2_wins', 'Agent 2 Wins'),
                       ('draws', 'Draws')):
        ax_games.plot(episodes, np.cumsum(metrics[key]), label=label)
    ax_games.set(title='Cumulative game outcomes', xlabel='Episodes', ylabel='Games')
    ax_games.legend()

    for key, label in (('avg_rewards_1', 'Agent 1 Reward'),
                       ('avg_rewards_2', 'Agent 2 Reward')):
        ax_rewards.plot(episodes, _trailing_mean(metrics[key], window), label=label)
    ax_rewards.set(title=f'Episode reward (trailing mean, window={window})',
                   xlabel='Episodes', ylabel='Reward')
    ax_rewards.legend()

    fig.tight_layout()
    plt.show()

In [None]:
def _read_human_move():
    """Prompt until the user types an integer column index in ``[0, COLS)``."""
    while True:
        raw = input(f"Enter your move (0-{COLS - 1}): ")
        try:
            col = int(raw)
        except ValueError:
            print("Invalid move. Try again.")
            continue
        if 0 <= col < COLS:
            return col
        # Reject out-of-range values here: a negative index could otherwise be
        # silently wrapped to a real column by array indexing.
        print("Invalid move. Try again.")


def play_against_agent(agent, env):
    """Play an interactive console game: the human is Player 1 (X), `agent` is Player 2 (O).

    Assumes ``env.reset()`` makes Player 1 the current player, ``env.make_move(col)``
    returns a falsy value for an illegal move without changing the board, and
    ``env.check_winner()`` returns 1 or 2 for a win and 0 for a draw — confirm
    against the env class.

    NOTE(review): ``agent.select_action`` may still explore (e.g. epsilon-greedy);
    disable exploration on the agent before calling if you want its greedy play.
    """
    print("You are Player 1 (X). Agent is Player 2 (O).")
    env.reset()
    env.render()

    while True:
        # Human turn: keep asking until the env accepts the column.
        while not env.make_move(_read_human_move()):
            print("Invalid move. Try again.")
        env.render()

        result = env.check_winner()
        if result == 1:
            print("You win!")
            break
        if result == 0:
            print("It's a draw!")
            break

        env.switch_player()
        print("Agent is thinking...")
        action = agent.select_action(env.get_state())
        if not env.make_move(action):
            # Previously an illegal agent move silently forfeited its turn; fall back
            # to the first column the env accepts instead.
            print(f"Agent chose invalid column {action}; playing the first legal column instead.")
            for col in range(COLS):
                if env.make_move(col):
                    break
        env.render()

        result = env.check_winner()
        if result == 2:
            print("Agent wins!")
            break
        if result == 0:
            print("It's a draw!")
            break

        env.switch_player()

In [2]:
def save_model(agent, path):
    """Write a training checkpoint for `agent` to `path` using ``torch.save``.

    The checkpoint is a dict with the ``state_dict`` of ``agent.model``,
    ``agent.target_model`` and ``agent.optimizer``, plus the agent's ``epsilon``
    value. A confirmation line is printed afterwards.
    """
    checkpoint = dict(
        model_state_dict=agent.model.state_dict(),
        target_model_state_dict=agent.target_model.state_dict(),
        optimizer_state_dict=agent.optimizer.state_dict(),
        epsilon=agent.epsilon,
    )
    torch.save(checkpoint, path)
    print(f"Model saved at {path}")


def load_model(agent, path, map_location=None):
    """Restore `agent` in place from a checkpoint dict saved with ``torch.save``.

    Loads the ``model_state_dict``, ``target_model_state_dict`` and
    ``optimizer_state_dict`` entries into ``agent.model``, ``agent.target_model`` and
    ``agent.optimizer``, and sets ``agent.epsilon`` from the ``epsilon`` entry.

    Args:
        agent: object with ``model`` / ``target_model`` (torch modules), ``optimizer``
            and an ``epsilon`` attribute.
        path: checkpoint file path.
        map_location: forwarded to ``torch.load``. When None, tensors are mapped onto
            the device of ``agent.model``'s first parameter. This lets a checkpoint
            saved on a GPU load into a CPU-only agent, where a plain ``torch.load``
            would fail.

    Security: ``torch.load`` is pickle-based; only load checkpoint files you trust.
    """
    if map_location is None:
        first_param = next(agent.model.parameters(), None)
        if first_param is not None:
            map_location = first_param.device
    checkpoint = torch.load(path, map_location=map_location)
    agent.model.load_state_dict(checkpoint['model_state_dict'])
    agent.target_model.load_state_dict(checkpoint['target_model_state_dict'])
    agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    agent.epsilon = checkpoint['epsilon']
    print(f"Model loaded from {path}")