In [157]:
import numpy as np
import pandas as pd

import random

import multiprocessing as mp
from joblib import Parallel, delayed

from tic_env import TictactoeEnv, OptimalPlayer

import plotly.express as px
import plotly.graph_objects as go

In [158]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import deque

from tqdm.notebook import tqdm, trange

In [4]:
# Select the compute device: use CUDA when it is available, otherwise fall
# back to the CPU. A missing GPU only makes training slower, so it is reported
# as a warning instead of aborting the notebook (raising here would make the
# CPU fallback below unreachable).
import warnings

if not torch.cuda.is_available():
    warnings.warn("Things will go much quicker if you use a GPU")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [223]:
# Scratch check: build a single-sample batch holding a 2x3 grid, collapse
# everything but the batch axis, and take the row-wise maximum (values and
# indices) the same way the Q-network output is reduced.
x = torch.arange(6).reshape(1, 2, 3)
y = x.reshape(x.shape[0], -1).max(dim=1)
y

torch.return_types.max(
values=tensor([5]),
indices=tensor([5]))

In [212]:
# Scratch data: four replay-buffer-like tuples whose last field is either a
# 2x3 integer tensor (scaled by 1, 7 and 99) or None, in that order.
L = [
    ("a", "b", "c", None if scale is None else torch.arange(6).reshape(2, 3) * scale)
    for scale in (1, 7, None, 99)
]

# Append a second copy of the same four entries (same tuple objects).
L += list(L)

In [6]:
class Qnetwork(nn.Module):
    """
    Deep Q-learning agent for a 3x3 board.

    The network is a fully connected model with two hidden layers of 128
    units: it takes a flattened 3x3x2 board encoding (18 inputs) and outputs
    one Q-value per cell (9 outputs). The instance also holds everything
    needed for training: an epsilon-greedy policy, a replay buffer ``R``, a
    Huber loss, an Adam optimizer and a periodically synchronised target
    network ``Qtarget``.

    Args:
        epsilon: exploration rate. Either a number (constant exploration) or
            a ``(epsilon_min, epsilon_max)`` tuple, in which case exploration
            starts at ``epsilon_max`` and decays linearly with the number of
            finished games (see ``decrease_epsilon``).
        player: ``'X'`` or ``'O'``.
        learningRate: learning rate of the Adam optimizer.
        discountFactor: discount applied to the bootstrapped target value.
        batch_size: number of transitions sampled per training step.
        C: number of training steps between two target-network updates.
        n_max: number of finished games over which epsilon decays.
        R: replay buffer to use. When ``None`` (default) a fresh
            ``deque(maxlen=10_000)`` is created for this instance, so that
            separate agents do not silently share their experience.
        _is_target: internal flag, set when this instance is itself the
            target network of another ``Qnetwork``; it then does not create
            a target network of its own.
    """
    def __init__(self, epsilon=0.2, player='X', learningRate=0.0005, discountFactor=1.0, batch_size=64, C=500, n_max=100, R=None, *, _is_target=False):
        super().__init__()
        self.flattener = nn.Flatten()
        # NOTE: no trailing commas here, otherwise the attributes become
        # tuples and the layers are neither registered nor callable.
        self.inputLayer = nn.Linear(3*3*2, 128)
        self.fullyConnected = nn.Linear(128, 128)
        self.outputLayer = nn.Linear(128, 9)

        if isinstance(epsilon, tuple):
            self.epsilon_min, self.epsilon_max = epsilon
            self.epsilon = self.epsilon_max
        else:
            self.epsilon = epsilon
            self.epsilon_min = epsilon
            self.epsilon_max = epsilon
        self.discountFactor = discountFactor

        # Last encoded state seen by `act` and the flat index (0..8) of the
        # action chosen there; both are consumed by `learn`.
        self.state = None
        self.action = None

        # Number of finished games, used for the epsilon decay.
        self.n = 0
        self.n_max = n_max

        self.isLearning = True

        self.player = player # 'X' or 'O'

        self.batch_size = batch_size

        # Training steps since the last target-network synchronisation.
        self.t = 0
        self.C = C

        # Replay buffer; a bounded deque drops its oldest entries by itself.
        self.R = deque(maxlen=10_000) if R is None else R

        # criterion is Huber loss (with delta = 1)
        self.criterion = nn.HuberLoss()

        # If a GPU is available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # `forward` moves its input to `self.device`, so the parameters must
        # live there too.
        self.to(self.device)

        # optimizer is Adam
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learningRate)

        # Target network: a frozen copy of the online network. It is built
        # with `_is_target=True` so that it does not recursively create its
        # own target, and it is stored through `object.__setattr__` so that
        # nn.Module does not register it as a submodule (its weights must not
        # show up in `self.state_dict()` / `self.parameters()`).
        target = None
        if not _is_target:
            target = Qnetwork(epsilon, player, learningRate, discountFactor, batch_size, C, n_max, self.R, _is_target=True)
            target.load_state_dict(self.state_dict())
            target.requires_grad_(False)
        object.__setattr__(self, "Qtarget", target)

    def forward(self, x):
        """ Return the Q-values, shape (batch, 9), for a batch of encoded boards. """
        x = x.to(self.device)
        x = self.flattener(x)
        x = F.relu(self.inputLayer(x))
        x = F.relu(self.fullyConnected(x))
        x = self.outputLayer(x)
        return x

    def decrease_epsilon(self):
        """ Linearly decay epsilon with the number of finished games, never below epsilon_min. """
        self.epsilon = max(self.epsilon_min, self.epsilon_max * (1 - self.n / self.n_max))

    def set_player(self, player = 'X', j=-1):
        """
        Set the mark played by the agent. When `j` is given (not -1) it takes
        precedence: even `j` selects 'X', odd `j` selects 'O'.
        """
        self.player = player
        if j != -1:
            self.player = 'X' if j % 2 == 0 else 'O'

    def empty(self, state):
        """ Return all empty positions, as (x, y) tuples of cells equal to 0. """
        return [(x, y) for x in range(3) for y in range(3) if state[(x, y)] == 0]

    def randomAction(self, state):
        """ Choose a random action from the available options. """
        availableActions = self.empty(state)

        return random.choice(availableActions)

    def bestAction(self, state):
        """
        Return the flat index (0..8) of the action with the highest Q-value
        for the given encoded state (batch of size one).

        NOTE(review): occupied cells are not masked out, so the returned
        action may be an unavailable move -- confirm this is intended.
        """
        with torch.no_grad():
            q_values = self.forward(state)
        return q_values.argmax(dim=1).item()

    def _encode(self, grid):
        """
        Convert a 3x3 grid with values in {-1, 0, 1} into a float tensor of
        shape (1, 3, 3, 2) on `self.device`: channel 0 marks the cells equal
        to +1 and channel 1 the cells equal to -1.

        NOTE(review): the encoding does not depend on `self.player` --
        confirm that this is what the environment expects.
        """
        cells = torch.as_tensor(grid, dtype=torch.int64)
        # num_classes must be explicit: otherwise it is inferred from the
        # largest value present and boards without a +1 mark get too few
        # channels for the selection below.
        planes = F.one_hot(cells + 1, num_classes=3)
        planes = planes[..., [2, 0]]
        return planes.unsqueeze(0).to(dtype=torch.float, device=self.device)

    def act(self, grid):
        """
        epsilon-greedy action selection, according to the Q-network.

        Stores the encoded state and the flat action index on the instance
        (for the next call to `learn`) and returns the move as an (x, y) tuple.
        """
        self.state = self._encode(grid)

        # whether move in random or not
        if random.random() < self.epsilon:
            action = self.randomAction(grid)
            self.action = action[0] * 3 + action[1]
        else:
            # Get the best move
            self.action = self.bestAction(self.state)
            # action is a tuple of (x, y) from self.action
            action = (self.action // 3, self.action % 3)

        return action

    def learn(self, grid, reward, end=False):
        """
        Record the transition that follows the last `act` call and run one
        training step on a minibatch sampled from the replay buffer.

        Args:
            grid: board observed after the move (ignored when `end` is True).
            reward: reward obtained for the last action.
            end: whether the game is over; the transition is then stored
                with no successor state and the per-game bookkeeping
                (stored state/action, game counter, epsilon) is updated.
        """
        if not self.isLearning:
            if end:
                self.state = None
                self.action = None
            return

        # Only store a transition when `act` has actually been called;
        # a None state would make the buffer unusable for batching.
        if self.state is not None:
            s_prime = None if end else self._encode(grid)
            self.R.append((self.state, self.action, reward, s_prime))

        if end:
            self.state = None
            self.action = None

            self.n += 1
            self.decrease_epsilon()

        # Wait until the buffer holds a full minibatch (random.sample raises
        # when asked for more items than available).
        if len(self.R) < self.batch_size:
            return

        self._train_step(random.sample(self.R, self.batch_size))

        self.t += 1
        if self.t == self.C:
            self.t = 0
            self.Qtarget.load_state_dict(self.state_dict())

    def _train_step(self, batch):
        """ Run one optimisation step on a list of (state, action, reward, next_state) transitions. """
        states = torch.cat([tr[0] for tr in batch]).to(self.device)
        actions = torch.tensor([tr[1] for tr in batch], dtype=torch.int64, device=self.device)
        rewards = torch.tensor([tr[2] for tr in batch], dtype=torch.float32, device=self.device)
        s_prime_mask = torch.tensor([tr[3] is not None for tr in batch], device=self.device, dtype=torch.bool)

        self.optimizer.zero_grad()

        # Q-value of the action actually taken in each sampled state.
        Q_theta_sj_aj = self.forward(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped value of the successor state; stays 0 for terminal
        # transitions. torch.cat cannot take an empty list, hence the check.
        maxQtarget = torch.zeros(len(batch), device=self.device)
        if s_prime_mask.any():
            s_primes = torch.cat([tr[3] for tr in batch if tr[3] is not None]).to(self.device)
            with torch.no_grad():
                maxQtarget[s_prime_mask] = self.Qtarget.forward(s_primes).max(dim=1).values

        loss = self.criterion(Q_theta_sj_aj, rewards + self.discountFactor*maxQtarget)

        loss.backward()
        self.optimizer.step()
                


.