# Imports

In [None]:
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import Iterable
import gym

from kaggle_environments import evaluate, make, utils
from kaggle_environments.utils import Struct

# Model architecture

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers followed by a linear head.

    Args:
        learning_rate: Step size of the Adam optimizer stored as ``self.optimizer``.
        input_dims: Iterable unpacked as the leading positional arguments of the
            first ``nn.Linear``; in practice a one-element sequence holding the
            number of input features.
        middle_dims: Width of the first hidden layer.
        output_dims: Width of the second hidden layer (despite its name, this is
            not the size of the network output).
        n_actions: Number of values produced per input row.
    """

    def __init__(
        self,
        learning_rate: float,
        input_dims: Iterable,
        middle_dims: int,
        output_dims: int,
        n_actions: int
    ):
        super().__init__()
        # Creation order fixes the order in which random initial weights are drawn.
        self.input_layer = nn.Linear(*input_dims, middle_dims)
        self.middle_layer = nn.Linear(middle_dims, output_dims)
        self.output_layer = nn.Linear(output_dims, n_actions)

        # Training helpers are exposed as plain attributes for the training code.
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.loss = nn.MSELoss()

        # The model is pinned to the CPU.
        self.device = 'cpu'
        self.to(self.device)

    def forward(self, state):
        """Map a batch of feature rows to one output value per action."""
        hidden = state
        for layer in (self.input_layer, self.middle_layer):
            hidden = F.relu(layer(hidden))
        return self.output_layer(hidden)

# Agent

In [None]:
class Agent:
    """DQN agent for ConnectX backed by a fixed-size replay buffer.

    Transitions are written into preallocated NumPy arrays that are reused as a
    ring buffer. ``learn`` samples a random minibatch from them and takes one
    optimizer step towards the one-step Q-learning target.
    """

    # Board geometry used to reshape the flat ``board`` list (row-major).
    ROWS = 6
    COLS = 7

    def __init__(
        self,
        gamma: float,
        epsilon: float,
        learning_rate: float,
        input_dims: Iterable,
        batch_size: int,
        n_actions: int,
        max_mem_size: int = 100000
    ):
        """Create the Q-network and allocate the replay buffer.

        Args:
            gamma: Discount factor applied to the bootstrapped next-state value.
            epsilon: Stored on the instance; not read by any method of this class.
            learning_rate: Learning rate handed to the underlying ``DQN``.
            input_dims: Feature shape of one state; sizes the network input and
                the state arrays of the replay buffer.
            batch_size: Minibatch size sampled by ``learn``.
            n_actions: Number of Q-values produced (one per board column).
            max_mem_size: Replay-buffer capacity; once full, the oldest slots
                are overwritten.
        """
        self.gamma = gamma
        # NOTE(review): epsilon is never read, so action selection is always
        # greedy. Add epsilon-greedy exploration here if that is intended.
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.action_space = [i for i in range(n_actions)]
        self.mem_size = max_mem_size
        self.batch_size = batch_size
        self.mem_counter = 0

        self.Q_eval = DQN(
            learning_rate=self.learning_rate,
            input_dims=input_dims,
            middle_dims=32,
            output_dims=32,
            n_actions=n_actions
        )
        self.state_memory = np.zeros((self.mem_size, *input_dims), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *input_dims), dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

    def store_transition(self, state, action, reward, state_, done):
        """Record one (s, a, r, s', done) transition in the replay buffer."""
        # Modulo indexing overwrites the oldest entry once the buffer is full.
        index = self.mem_counter % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.reward_memory[index] = reward
        self.action_memory[index] = action
        self.terminal_memory[index] = done

        self.mem_counter += 1

    def choose_action(self, observation: list, board: list) -> int:
        """Return the greedy column for ``observation``, never picking a full column.

        Args:
            observation: Flat numeric feature list fed to the network as a
                batch of one row.
            board: Flat list of ``ROWS * COLS`` cells used to detect full columns.
        """
        state = T.tensor([observation], dtype=T.float32).to(self.Q_eval.device)
        # Action selection needs no autograd graph.
        with T.no_grad():
            actions = self.Q_eval.forward(state)
        return self._load_available_action(actions, board)

    def learn(self):
        """Sample a minibatch and take one Q-learning gradient step.

        Does nothing until at least ``batch_size`` transitions have been stored.
        """
        if self.mem_counter < self.batch_size:
            return

        device = self.Q_eval.device
        max_mem = min(self.mem_counter, self.mem_size)
        batch = np.random.choice(max_mem, self.batch_size, replace=False)
        # Advanced indexing needs int64 indices (int32 is rejected by older PyTorch).
        batch_index = T.arange(self.batch_size, device=device)

        state_batch = T.tensor(self.state_memory[batch]).to(device)
        new_state_batch = T.tensor(self.new_state_memory[batch]).to(device)
        reward_batch = T.tensor(self.reward_memory[batch]).to(device)
        terminal_batch = T.tensor(self.terminal_memory[batch]).to(device)
        action_batch = T.as_tensor(self.action_memory[batch], dtype=T.long, device=device)

        # Q(s, a) for the actions actually taken.
        q_eval = self.Q_eval.forward(state_batch)[batch_index, action_batch]
        # The bootstrap target is treated as a constant: no gradient flows into it.
        with T.no_grad():
            q_next = self.Q_eval.forward(new_state_batch)
            q_target = self._compute_targets(reward_batch, q_next, terminal_batch)

        self.Q_eval.optimizer.zero_grad()
        loss = self.Q_eval.loss(q_eval, q_target)
        loss.backward()
        self.Q_eval.optimizer.step()

    def _compute_targets(self, reward_batch, q_next, terminal_batch):
        """Return r + gamma * max_a' Q(s', a'), with the bootstrap term zeroed for terminal rows.

        Args:
            reward_batch: 1-D tensor of rewards, one per row.
            q_next: 2-D tensor (rows x actions) of next-state Q-values; not modified.
            terminal_batch: 1-D boolean tensor marking terminal transitions.
        """
        q_next = q_next.clone()
        q_next[terminal_batch] = 0.0
        return reward_batch + self.gamma * q_next.max(dim=1).values

    def _load_available_action(self, actions, board) -> int:
        """Pick the highest-valued column that still has room.

        Row 0 of the reshaped board is treated as the top row: a non-zero cell
        there marks that column as full. Ties go to the lowest column index.
        Returns 0 when every column is full.

        Args:
            actions: Tensor whose first row holds one Q-value per column
                (assumes ``n_actions == COLS`` — TODO confirm if that changes).
            board: Flat list of ``ROWS * COLS`` cells in row-major order.
        """
        grid = np.array(board).reshape(self.ROWS, self.COLS)
        full_columns = grid[0] != 0
        if full_columns.all():
            return 0

        q_values = actions.detach().cpu().numpy()[0]
        # Mask full columns so argmax can never select them.
        masked = np.where(full_columns, -np.inf, q_values)
        return int(np.argmax(masked))


# Train agent

In [None]:
# Number of training episodes played by the loop below.
n_games = 1
# Build the ConnectX environment (debug=True is passed straight to kaggle_environments).
env = make("connectx", debug=True)
env.render()
# The `None` seat is the one driven through trainer.reset()/trainer.step();
# the other seat is the agent named "random".
trainer = env.train([None, "random"])

In [None]:
agent = Agent(
    # One feature per board cell, plus the step counter and the player mark.
    input_dims=[Agent.ROWS * Agent.COLS + 2],
    # One Q-value per column a piece can be dropped into.
    n_actions=Agent.COLS,
    batch_size=64,
    learning_rate=0.003,
    gamma=0.99,
    epsilon=1.0,
)

In [None]:
def get_model_input(observation: Struct) -> list:
    """Flatten an observation into the network's feature list.

    The result is a new list holding the board cells followed by
    ``observation.step`` and ``observation.mark``.
    """
    board_cells = observation.board
    metadata = [observation.step, observation.mark]
    return board_cells + metadata

In [None]:
# Training loop: play n_games episodes, storing every transition and calling
# agent.learn() after each move.
for i in range(n_games):
    done  = False
    observation = trainer.reset()
    
    while not done:
        # Flat feature list: board cells followed by step and mark.
        model_input = get_model_input(observation)
        # The raw board is passed separately so full columns can be excluded.
        action = agent.choose_action(model_input, observation.board)
        observation_, reward, done, info = trainer.step(int(action))

        model_next_input = get_model_input(observation_)
        agent.store_transition(model_input, action, reward, model_next_input, done)
        # learn() returns immediately until batch_size transitions are stored.
        agent.learn()
        observation = observation_
        
    # Rendered inside the episode loop, i.e. once after every episode.
    env.render(mode="ipython", width=500, height=500, header=False, controls=False)