In [1]:
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from board import Board

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%matplotlib inline

In [2]:
device

device(type='cuda')

In [3]:
Transition = namedtuple('Transition', ['current_state', 'current_action', 'next_state', 'reward', 'done'])

In [4]:
agentConfig = {'epsilon': 0.7, 'step_size': 0.4, 'gamma': 0.999}

In [5]:
class ExperienceReplay():
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.pos = 0
    
    def push(self, current_state, current_action, next_state, reward, done):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        
        self.memory[self.pos] = Transition(current_state, current_action, next_state, reward, done)
        self.pos = (self.pos + 1) % self.capacity
    
    def sample(self, batch_size):
        sample_idx = np.random.choice(np.linspace(0, self.capacity - 1, self.capacity, dtype=int), batch_size)

        x_cs = np.zeros((batch_size, 9))
        x_ns = np.zeros((batch_size, 9))
        a_t = np.zeros((batch_size, 1))
        r = np.zeros((batch_size, 1))  
        d = np.zeros((batch_size, 1))  
        
        for iter_tr in range(batch_size):
            x_cs[iter_tr, :] = self.memory[iter_tr][0].reshape(-1, 9)
            x_ns[iter_tr, :] = self.memory[iter_tr][2].reshape(-1, 9)
            a_t[iter_tr, :] = np.array([self.memory[iter_tr][1][0] * 3 + self.memory[iter_tr][1][1]], dtype=int).reshape(-1, 1)
            r[iter_tr, :] = np.array(self.memory[iter_tr][3]).reshape(-1, 1)
            d[iter_tr, :] = np.array(self.memory[iter_tr][4]).reshape(-1, 1)
        
        return [x_cs, a_t, x_ns, r, d]

In [6]:
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(9, 27)
        self.fc2 = nn.Linear(27, 27)
        self.fc3 = nn.Linear(27, 27)
        self.fc4 = nn.Linear(27, 9)
        
    def forward(self, x):
        y1 = self.fc1(x)
#         y1 = F.relu(y1)
        
        y2 = self.fc2(y1)
#         y2 = F.relu(y2)
        
        y3 = self.fc3(y2)
#         y3 = F.relu(y3)
        
        y4 = self.fc4(y3)

        return y4

In [7]:
def select_action(current_state, Q, ValidActionSpace, agentConfig):
    action_values = Q.forward(torch.from_numpy(current_state.reshape(-1, 9)).float().cuda()).cpu().detach().numpy().reshape(-1,)
    valid_actions = ValidActionSpace[:, 0] * 3 + ValidActionSpace[:, 1]
    
    
    pr = np.random.rand(1)[0]
    
    if pr <= 1 - agentConfig['epsilon']:
        next_action_idx = np.argmax(action_values[valid_actions])
        next_action = valid_actions[next_action_idx]
    else:
        next_action_idx = np.random.choice(np.linspace(0, len(valid_actions) - 1, len(valid_actions)))
        next_action = valid_actions[int(next_action_idx)]

    return [int(next_action / 3), next_action % 3]    

In [8]:
def train(numEpisodes = 500000):
    replayMemory = ExperienceReplay(50000)
    
    env = Board()
    env.reset()
        
    for iter_episodes in range(numEpisodes):
        done = False
        env.reset()
        
        if iter_episodes <= 50000:
            agentConfig['epsilon'] = 0.9
        elif iter_episodes <= 60000:
            agentConfig['epsilon'] = 0.5
        elif iter_episodes <= 70000:
            agentConfig['epsilon'] = 0.5
        else:
            agentConfig['epsilon'] = 0.1
        
        while not done:
            current_state1 = np.copy(env.current_state)
            action_space = env.getValidActionSpace()
            current_action1 = select_action(current_state1, policy_net, action_space, agentConfig)
            [next_state1, reward1, done] = env.step(current_action1, 1)
            replayMemory.push(current_state1, current_action1, next_state1, reward1, done)
            
            if not done:
                current_state2 = np.copy(next_state1) * -1
                action_space = env.getValidActionSpace()
                current_action2 = select_action(current_state2, policy_net, action_space, agentConfig)
                [next_state2, reward2, done] = env.step(current_action2, -1)
                replayMemory.push(current_state2, current_action2, next_state2 * -1, reward2, done)
            
        if (iter_episodes + 1) % 10000 == 0:
            loss_m = 0
            for iter_batches in range(100):
                [x_cs, a_t, x_ns, r, d] = replayMemory.sample(1000)
                q_cs = policy_net.forward(torch.from_numpy(x_cs).float().cuda()).cpu().reshape(-1, 9)
                q_ns = target_net.forward(torch.from_numpy(x_ns).float().cuda()).cpu().reshape(-1, 9)
                r = torch.from_numpy(r).float()
                a_t = torch.from_numpy(a_t)
                d = torch.from_numpy(d).float()

                estimate = torch.stack([q[int(i[0].numpy())] for q, i in zip(q_cs, a_t)]).reshape(-1, 1)
                target = r + agentConfig['gamma'] * torch.argmax(q_ns, axis = 1).float().reshape(-1, 1) * (1 - d)
                
                loss = F.smooth_l1_loss(estimate, target)
                
                loss_m += loss.item()
                
                opt.zero_grad()
                loss.backward()
                
                for param in policy_net.parameters():
                    param.grad.data.clamp_(-1, 1)
                
                if (iter_batches + 1) % 25 == 0:
                    target_net.load_state_dict(policy_net.state_dict())
                opt.step()   

            now = datetime.now()
            print("Loss: ", loss_m / 100, " | Time: ", now.strftime("%H:%M:%S"))
            torch.save(policy_net.state_dict(), 'policy_net.pt')
            torch.save(target_net.state_dict(), 'target_net.pt')
            torch.save(opt.state_dict(), 'opt.pt')

In [9]:
policy_net = DQN().cuda()
target_net = DQN().cuda()

In [10]:
opt = optim.Adam(params = policy_net.parameters())

In [None]:
train()

Loss:  1.982891103029251  | Time:  20:25:08
