In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
%matplotlib inline

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
# Every tensor and network created below is placed on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Connect4:
    """Connect Four on a 6x7 board for two players encoded as 1 and -1.

    ``board`` is a (6, 7) int8 array in which 0 marks an empty cell. Row 0 is
    the top row as printed by ``print_board``; pieces settle at the highest
    row index that is still empty. ``current_player`` is the player to move.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the board and give the first turn to player 1."""
        self.board = np.zeros((6, 7), dtype=np.int8)
        self.current_player = 1

    def make_move(self, col):
        """Drop the current player's piece into ``col`` and pass the turn.

        Returns a ``(reward, finished)`` tuple:
          * ``(-20, False)`` - ``col`` is invalid; the board is unchanged but
            the turn is still handed to the other player.
          * ``(10, True)``   - the move completed four in a row.
          * ``(0, True)``    - the move filled the board without a winner.
          * ``(0, False)``   - the game continues.
        """
        if not self.is_valid_move(col):
            self.current_player = -1 if self.current_player == 1 else 1
            return -20, False  # Return a strong punishment

        # Place the piece, look for a winner, then hand over the turn.
        row = self.get_open_row(col)
        self.board[row][col] = self.current_player
        winner = self.check_winner()
        self.current_player = -1 if self.current_player == 1 else 1

        if winner == self.current_player:
            # The player who did NOT just move cannot have completed a line.
            raise Exception("Impossible stuff")
        elif winner == 0:
            # No winner: the game only ends if every cell is occupied.
            return 0, bool(self.board.all())
        else:
            # The player who just moved won the game.
            return 10, True

    def is_valid_move(self, col):
        """Return True when ``col`` is in 0..6 and its top cell is empty."""
        if col < 0 or col >= 7:
            return False

        if self.board[0][col] != 0:
            return False

        return True

    def get_open_row(self, col):
        """Return the largest row index that is empty in ``col``, or -1 if full."""
        rows = np.where(self.board[:,col] == 0)[0]
        if len(rows) == 0:
            return -1
        else:
            return rows[-1]

    def get_state(self):
        """Return the board reshaped to (1, 42)."""
        return np.reshape(self.board, (1, 42))

    def check_winner(self):
        """Return the value (1 or -1) of a player with four in a row, else 0."""
        # Check horizontal
        for r in range(6):
            row = self.board[r,:]
            for c in range(4):
                if row[c] != 0 and np.all(row[c] == row[c+1:c+4]):
                    return row[c]

        # Check vertical
        for c in range(7):
            col = self.board[:,c]
            for r in range(3):
                if col[r] != 0 and np.all(col[r] == col[r+1:r+4]):
                    return col[r]

        # Check diagonal (row and column indices increase together)
        for r in range(3):
            for c in range(4):
                if self.board[r][c] != 0 and np.all(np.array([self.board[r+i][c+i] for i in range(4)]) == self.board[r][c]):
                    return self.board[r][c]

        # Check other diagonal (row index increases while column index decreases)
        for r in range(3):
            for c in range(4):
                if self.board[r][c+3] != 0 and np.all(np.array([self.board[r+i][c+3-i] for i in range(4)]) == self.board[r][c+3]):
                    return self.board[r][c+3]

        return 0

    def is_game_over(self):
        """Return True when a player has won or no empty cell is left."""
        if self.check_winner() != 0:
            return True

        return np.all(self.board != 0)

    def print_board(self):
        """Print the board using X for player 1, O for player -1 and - for empty."""
        print("Connect 4")
        print("-----------------")
        for r in range(6):
            row = ""
            for c in range(7):
                if self.board[r][c] == 0:
                    row += " -"
                elif self.board[r][c] == 1:
                    row += " X"
                else:
                    row += " O"
            print(row)
        print("-----------------")
        print(" 0 1 2 3 4 5 6")
        print()

    def play_game(self):
        """Run an interactive two-player game on stdin/stdout until it ends."""
        while not self.is_game_over():
            self.print_board()

            answer = input("Player %d: Enter a column (0-6) to place your piece: " % self.current_player)
            print()
            try:
                col = int(answer)
            except ValueError:
                # Non-numeric input is handled like an out-of-range column.
                col = -1

            # Validate here instead of inspecting make_move's return value:
            # make_move returns a (reward, finished) tuple, which is always
            # truthy, and on an invalid column it would also pass the turn,
            # contradicting the "try again" prompt.
            if not self.is_valid_move(col):
                print("Invalid move. Please try again.\n")
                continue

            self.make_move(col)

        self.print_board()

        winner = self.check_winner()
        if winner == 0:
            print("The game ended in a tie.")
        else:
            print("Player %d wins!" % winner)


In [4]:
class ReplayMemory():
    """Fixed-capacity buffer of transitions with random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest entry once capacity is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, next_state, reward):
        """Append one transition as a (state, action, next_state, reward) tuple."""
        transition = (state, action, next_state, reward)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` stored transitions drawn without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

In [5]:
class Human():
    """Player whose moves are typed in at the console."""

    def choose_move(self, state, eps):
        """Prompt for a column and return it as an int; ``state`` and ``eps`` are unused."""
        return int(input("Elige tu jugada: "))

class RNDAgent():
    """Opponent that picks a column at random, without looking at the board."""

    def choose_move(self, state, eps):
        """Return a 1-element integer tensor on ``device`` holding a column in 0..6.

        ``state`` and ``eps`` are accepted for interface parity with the other
        agents and are ignored.
        """
        return torch.randint(low=0, high=7, size=(1,), device=device)

class DQNAgent():
    """Q-learning agent with a replay buffer, a policy network and a target network."""

    def __init__(self, policy, target):
        """Store the two networks and create an empty replay buffer of size MAX_MEMO."""
        self.memo = ReplayMemory(MAX_MEMO)
        self.policy = policy
        self.target = target

    @torch.no_grad()
    def choose_move(self, state, eps):
        """Pick a column epsilon-greedily.

        With probability ``eps`` a random column in 0..6 is returned;
        otherwise the column with the highest policy-network output for
        ``state``. Puts the policy network in eval mode as a side effect.
        """
        self.policy.eval()
        if random.random() < eps:
            return torch.randint(7, (1,), device = device)
        else:
            # max(1) yields (values, indices); the indices are the chosen columns.
            return self.policy(state).max(1)[1]

    def replay(self):
        """Sample BATCH_SIZE transitions and return the Huber loss of the policy net.

        The caller is responsible for the backward pass and the optimizer step.
        """
        states, actions, next_states, rewards = zip(*self.memo.sample(BATCH_SIZE))

        states = torch.cat(states)
        actions = torch.stack(actions)
        rewards = torch.stack(rewards)

        # Terminal transitions are stored with next_state = None.
        non_final_mask = torch.tensor([s is not None for s in next_states], device = device, dtype=torch.bool)
        non_final_next_states = [s for s in next_states if s is not None]

        self.policy.train()
        pred_Q = self.policy(states).gather(1, actions)

        self.target.eval()
        max_next_Q = torch.zeros(BATCH_SIZE, 1, device = device)
        # torch.cat raises on an empty list, so skip the target network when
        # every sampled transition is terminal; their next-state value stays 0.
        if non_final_next_states:
            with torch.no_grad():
                next_Q = self.target(torch.cat(non_final_next_states))
                max_next_Q[non_final_mask] = next_Q.max(1, keepdim=True)[0]

        # The next-state value is subtracted rather than added. The training
        # loop stores next states multiplied by the player who moves next, so
        # this looks like a negamax-style target - TODO confirm intent.
        target_Q = rewards - GAMMA * max_next_Q

        loss = F.huber_loss(pred_Q, target_Q)
        return loss
        

class Network(nn.Module):
    """Convolutional network mapping a flattened 6x7 board to 7 per-column outputs."""

    def __init__(self):
        super().__init__()
        # Two unpadded convolutions shrink the 6x7 board to 3x4 and then 2x3.
        feature_layers = [
            nn.Conv2d(1, 128, 4, bias=False),
            nn.BatchNorm2d(128),
            nn.Tanh(),
            nn.Conv2d(128, 128, 2, bias=False),
            nn.BatchNorm2d(128),
            nn.Tanh(),
        ]
        # Flattened 2x3x128 feature map -> 512 hidden units -> one output per column.
        head_layers = [
            nn.Flatten(),
            nn.Linear(2*3*128, 512, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 7),
        ]
        # Layers keep the positions 0..10 inside the Sequential container.
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, xin):
        """Reshape a 2-D (batch, 42) input to (batch, 1, 6, 7) and run the network."""
        batch_size, _ = xin.shape
        boards = xin.view(batch_size, 1, 6, 7)
        return self.net(boards)

@torch.no_grad()
def battle(agent1, agent2, render = False):
    """Play one game between two agents, with ``agent1`` moving first.

    Both agents act greedily (``eps`` is 0.0). Returns 1 when the first mover
    wins, -1 when the second mover wins, and 0 for a full board without a
    winner or when MAX_TURNS moves pass without the game finishing. With
    ``render`` set, the board is printed before the first move and after
    every move.
    """
    game = Connect4()

    def observe():
        # Board multiplied by the player to move, as a float tensor on `device`.
        return torch.from_numpy(game.get_state() * game.current_player).float().to(device)

    mover = agent1
    state = observe()
    turns_played = 0
    if render:
        game.print_board()

    while True:
        reward, finished = game.make_move(mover.choose_move(state, 0.0))

        if render:
            game.print_board()

        if finished:
            # make_move has already passed the turn, so the player who just
            # moved (the winner, when reward is non-zero) is -current_player.
            return 0 if reward == 0 else -game.current_player

        state = observe()
        mover = agent2 if mover == agent1 else agent1

        turns_played += 1
        if turns_played >= MAX_TURNS:
            return 0

In [8]:
# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 128      # transitions sampled per DQNAgent.replay() call
LR = 1e-4             # AdamW learning rate
# Exploration rate: decays exponentially from EPS_START towards EPS_END,
# with EPS_DECAY as the time constant measured in training steps.
EPS_START = 0.9
EPS_END = 0.1
EPS_DECAY = 25000
TAU = 0.001           # weight of the policy net in each soft target-net update

GAMMA = 0.95          # discount applied to the next-state value
MAX_MEMO = 100000     # replay buffer capacity
MIN_MEMO = 1000       # transitions required before training starts
MAX_ITERS = 300000    # total environment steps of the training loop

# --- Evaluation -------------------------------------------------------------
BATTLE_FREQ = 3500    # evaluate against the random agent every this many steps
BATTLE_NUM = 1000     # games per evaluation
MAX_TURNS = 100       # battle() returns 0 after this many moves

policy_net = Network().to(device)
# NOTE(review): this resumes from an existing checkpoint, so the cell fails
# with FileNotFoundError on a machine that has no 'net.pth' yet.
policy_net.load_state_dict(torch.load('net.pth', map_location=device))

target_net = Network().to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.AdamW(policy_net.parameters(), lr=LR)
agent = DQNAgent(policy_net, target_net)
rnd_agent = RNDAgent()
player = Human()

record = 90           # win rate (%) a model must exceed to be saved as champion
loss_record = 0.5     # mean loss to beat when the win rate only ties the record
# state_dict() returns references to the live parameters, so clone them;
# otherwise this "snapshot" would keep changing as the optimizer updates the net.
record_holder = {key: value.detach().clone() for key, value in policy_net.state_dict().items()}
loss_mem = []         # loss of every training step
wrate_mem = []        # win rate of every evaluation

In [None]:
# Self-play training: the same agent moves for both sides. Every state is the
# board multiplied by the player to move, so the network always sees the
# position from the mover's point of view.
game = Connect4()
state = torch.from_numpy(game.get_state() * game.current_player).float().to(device)

for step in range(MAX_ITERS):
    eps = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * step / EPS_DECAY)

    # Periodic evaluation: the agent moves second against the random agent.
    if step % BATTLE_FREQ == 0:
        w1 = 0
        w2 = 0
        for _ in range(BATTLE_NUM):
            winner = battle(rnd_agent, agent)
            w1 += winner == 1
            w2 += winner == -1
        wrate = (w2 / BATTLE_NUM)*100
        wrate_mem.append(wrate)
        print('-------------------------------')
        print(f'Step {step} | epsilon {eps:.2f}')
        print(f'Player 1: {w1} | Player 2: {w2} | wrate: {wrate:.1f}%')

        if len(loss_mem) >= BATTLE_FREQ:
            meanloss = sum(loss_mem[-BATTLE_FREQ:])/BATTLE_FREQ
            print(f'Mean loss: {meanloss:.6f}')

            # A new champion either ties the best win rate with a lower mean
            # loss or beats the best win rate outright. The weights are cloned
            # because state_dict() returns references to the live parameters,
            # which later optimizer steps would silently overwrite.
            if wrate == record and meanloss < loss_record:
                loss_record = meanloss
                record_holder = {key: value.detach().clone() for key, value in policy_net.state_dict().items()}
                torch.save(record_holder, 'net.pth')
                print(f'The champion got {loss_record} mean loss')
            elif wrate > record:
                loss_record = meanloss
                record = wrate
                record_holder = {key: value.detach().clone() for key, value in policy_net.state_dict().items()}
                torch.save(record_holder, 'net.pth')
                print(f'The champion got {record} winrate')

    #play the game
    action = agent.choose_move(state, eps)
    reward, finished = game.make_move(action)
    reward = torch.tensor([reward], device = device)

    if finished:
        # Terminal transitions are stored with next_state = None.
        next_state = None
        game.reset()
    else:
        # make_move has already passed the turn, so this is the opponent's view.
        next_state = torch.from_numpy(game.get_state() * game.current_player).float().to(device)

    #remember
    agent.memo.push(state, action, next_state, reward)

    #replay
    if len(agent.memo) >= MIN_MEMO:
        #Train policy
        loss = agent.replay()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #Train target: move every target entry a fraction TAU towards the policy net
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        loss_mem.append(loss.item())

    #prepare next iter
    if finished:
        # The game was reset above, so start again from the empty board.
        state = torch.from_numpy(game.get_state() * game.current_player).float().to(device)
    else:
        state = next_state

-------------------------------
Step 0 | epsilon 0.60
Player 1: 8 | Player 2: 992 | wrate: 99.2%
Learning Rate: 3.5e-05
Mean loss: 0.055575
-------------------------------
Step 3500 | epsilon 0.52
Player 1: 15 | Player 2: 985 | wrate: 98.5%
Learning Rate: 3.5e-05
Mean loss: 0.060350
-------------------------------
Step 7000 | epsilon 0.45
Player 1: 8 | Player 2: 992 | wrate: 99.2%
Learning Rate: 3.5e-05
Mean loss: 0.056073
-------------------------------
Step 10500 | epsilon 0.40
Player 1: 15 | Player 2: 985 | wrate: 98.5%
Learning Rate: 3.5e-05
Mean loss: 0.054917
-------------------------------
Step 14000 | epsilon 0.35
Player 1: 12 | Player 2: 988 | wrate: 98.8%
Learning Rate: 3.5e-05
Mean loss: 0.054472
-------------------------------
Step 17500 | epsilon 0.31
Player 1: 14 | Player 2: 986 | wrate: 98.6%
Learning Rate: 3.5e-05
Mean loss: 0.054948
-------------------------------
Step 21000 | epsilon 0.27
Player 1: 23 | Player 2: 977 | wrate: 97.7%
Learning Rate: 2.4499999999999996e-0

In [19]:
# Interactive game: the trained agent moves first (printed as X) and the human
# answers the console prompt as the second player (O); the board is printed
# after every move.
battle(agent, player, render = True)


Connect 4
-----------------
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
-----------------
 0 1 2 3 4 5 6

Connect 4
-----------------
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - X - - - - -
-----------------
 0 1 2 3 4 5 6

Elige tu jugada: 6
Connect 4
-----------------
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - X - - - - O
-----------------
 0 1 2 3 4 5 6

Connect 4
-----------------
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - X - X - - O
-----------------
 0 1 2 3 4 5 6

Elige tu jugada: 2
Connect 4
-----------------
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - X O X - - O
-----------------
 0 1 2 3 4 5 6

Connect 4
-----------------
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - - - - - -
 - - X - - - -
 - X O X - - O
-----------------
 0 1 2 3 4 5 6

Elige tu jugada: 3
Connect 4
-----------------
 - 

1