## DQN training notebook

In [1]:
import sys
sys.path.append('../')

import numpy as np
from cube import Cube
from tqdm import tqdm
from q_network import QNetwork2DConv, QNetwork3DConv
import torch
from copy import deepcopy
import matplotlib.pyplot as plt

# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [2]:
# --- Rollout settings --------------------------------------------------------
episodes = 10_000        # number of training episodes (outer loop length)
episode_len = 15         # maximum number of moves taken in one episode
scramble_steps = 5       # argument passed to Cube.scramble() at episode start

# --- Learning settings -------------------------------------------------------
lr = 1e-4                        # learning rate handed to the optimiser
buffer_size = 10_000             # replay buffer is trimmed once it exceeds this
batch_len = 32                   # transitions sampled per gradient step
gamma = 0.9                      # discount applied to the bootstrapped target
epsilon = 0.1                    # probability of taking a uniformly random action
target_update_interval = 1_000   # gradient steps between target-network syncs
learning_starts = 1_000          # environment steps collected before training begins

# --- Logging settings --------------------------------------------------------
avg_interval = 25        # episodes aggregated into one logged reward/loss point

In [3]:
# Replay memory: starts empty; the training cell appends transition dicts to it.
replay_buffer = []

# Online network: the one that is optimised and used to pick actions.
Q_net = QNetwork2DConv()
Q_net = Q_net.to(device)
Q_net.train()

# Target network: begins as an independent copy of the online network and is
# only re-synchronised from it periodically by the training cell.
Q_target_net = deepcopy(Q_net)

# Plain SGD over the online network's parameters only.
optim = torch.optim.SGD(params = Q_net.parameters(), lr = lr)

In [4]:
# The twelve face turns in index order: each face letter followed by its
# primed variant. A move's position in this tuple is its integer action id.
_FACE_MOVES = (
    'F', 'F\'',
    'B', 'B\'',
    'L', 'L\'',
    'R', 'R\'',
    'U', 'U\'',
    'D', 'D\'',
)

# move string -> action id
action_encode = {move: idx for idx, move in enumerate(_FACE_MOVES)}
# action id -> move string (inverse of action_encode)
action_decode = dict(enumerate(_FACE_MOVES))

def normalize_state2D(state):
    """Scale a raw cube state and return it as a tensor on `device`.

    Every entry is shifted by 2.5 and divided by 2.5 (so inputs in 0..5 land
    in [-1, 1] -- assumes facelets are colour codes 0-5, TODO confirm), then
    wrapped with ``torch.Tensor`` and moved to the module-level `device`.

    Args:
        state: numeric array supporting scalar arithmetic (the notebook passes
            NumPy arrays, either a single state or a stacked batch).

    Returns:
        The scaled values as a tensor placed on `device`.
    """
    scaled = (state - 2.5) / 2.5
    return torch.Tensor(scaled).to(device)

def normalize_state3D(state):
    """Scale a raw cube state and return it as an integer tensor on `device`.

    Applies the same shift/scale as the 2D variant (subtract 2.5, divide by
    2.5), moves the result to `device`, and finally casts it to int64.

    NOTE(review): the integer cast happens *after* scaling, so every
    fractional scaled value is truncated toward zero. Confirm this is really
    what the 3D network expects before switching `normalize_state` to it.

    Args:
        state: numeric array supporting scalar arithmetic.

    Returns:
        An int64 tensor on `device`.
    """
    scaled = (state - 2.5) / 2.5
    as_tensor = torch.Tensor(scaled).to(device)
    return as_tensor.long()

# State pre-processing used by the training loop. Change this when changing
# architecture (presumably normalize_state3D pairs with QNetwork3DConv -- verify).
normalize_state = normalize_state2D

In [5]:
# Epsilon-greedy DQN training loop.
#
# Each episode scrambles a fresh cube and plays up to `episode_len` moves.
# Every transition is pushed into `replay_buffer`; once `learning_starts`
# environment steps have been taken, each step also performs one gradient
# update on a random minibatch, bootstrapping from `Q_target_net`.
#
# Logged every `avg_interval` episodes:
#   ep_average_rewards -- total reward collected, divided by avg_interval
#   ep_average_losses  -- sum of per-update losses, divided by avg_interval
#                         (0.0 until learning has started)
iters_since_target_update = 0   # gradient steps since the last target sync
it = 0                          # environment steps taken across all episodes
eps_since_last_avg = 0
ep_average_reward = 0
ep_average_loss = 0
ep_average_rewards = []
ep_average_losses = []

loss_fn = torch.nn.MSELoss()     # built once instead of on every update
n_actions = len(action_decode)

for ep in tqdm(range(episodes)):
    cube = Cube()
    cube.scramble(scramble_steps)
    s = np.copy(cube.facelets)

    for i in range(episode_len):
        # Epsilon-greedy action selection. The network is only evaluated when
        # the greedy branch is actually taken.
        if np.random.uniform() < epsilon:
            a = np.random.randint(n_actions)
        else:
            with torch.no_grad():
                # normalize_state already returns a tensor on `device`; just
                # add the leading batch dimension.
                Q_pred = Q_net(normalize_state(s)[None, :])
            a = torch.argmax(Q_pred).item()

        # Apply the chosen move through the cube API; its return value is
        # used as the reward, and the facelets are re-read as the next state.
        r = cube.rotate_code_get_reward(action_decode[a])
        s_prime = np.copy(cube.facelets)

        replay_buffer.append({'s': s, 'a': a, 'r': r, 's_prime': s_prime})
        if len(replay_buffer) > buffer_size:
            # Drop the oldest transition (FIFO eviction).
            replay_buffer.pop(0)

        s = s_prime
        it += 1
        # Track reward on every step, including the warm-up phase before
        # learning starts, so the reward curve reflects all episodes.
        ep_average_reward += r

        if it >= learning_starts and len(replay_buffer) >= batch_len:
            # Sample distinct indices rather than the dicts themselves, which
            # avoids converting the whole buffer to an object array each step.
            picks = np.random.choice(len(replay_buffer), batch_len, replace = False)
            batch = [replay_buffer[k] for k in picks]

            batch_s = np.array([x['s'] for x in batch])
            batch_a = torch.Tensor(np.array([x['a'] for x in batch])).to(device).long()
            batch_r = torch.Tensor(np.array([x['r'] for x in batch])).to(device)
            batch_s_prime = np.array([x['s_prime'] for x in batch])

            # Q(s, a) for the actions that were actually taken.
            Q_hat = Q_net(normalize_state(batch_s))[range(batch_len), batch_a]

            with torch.no_grad():
                next_q = Q_target_net(normalize_state(batch_s_prime))
                Q_target = batch_r + gamma * torch.max(next_q, dim = 1).values
                # Solved next-states are terminal: no bootstrapping, the
                # target is the immediate reward only.
                solved = torch.tensor(
                    [bool(cube.is_solved_state(sp)) for sp in batch_s_prime],
                    dtype = torch.bool,
                    device = device,
                )
                Q_target[solved] = batch_r[solved]

            optim.zero_grad()
            loss = loss_fn(Q_hat, Q_target)
            loss.backward()
            optim.step()

            iters_since_target_update += 1
            if iters_since_target_update >= target_update_interval:
                Q_target_net = deepcopy(Q_net)
                iters_since_target_update = 0

            ep_average_loss += loss.item()

        if cube.is_solved_state(s):
            break

    eps_since_last_avg += 1
    if eps_since_last_avg >= avg_interval:
        ep_average_rewards.append(ep_average_reward / avg_interval)
        ep_average_losses.append(ep_average_loss / avg_interval)
        # tqdm.write prints without breaking up the progress bar.
        tqdm.write(str(ep_average_losses[-1]))
        ep_average_reward = 0
        ep_average_loss = 0
        eps_since_last_avg = 0

  0%|          | 37/10000 [00:03<06:58, 23.83it/s] 

0.0


  1%|          | 64/10000 [00:03<03:24, 48.52it/s]

0.0


  1%|          | 73/10000 [00:04<06:04, 27.23it/s]

66.375627784729


  1%|          | 102/10000 [00:06<10:38, 15.51it/s]

190.8638435935974


  1%|▏         | 126/10000 [00:07<11:33, 14.24it/s]

181.39197227478027


  2%|▏         | 152/10000 [00:09<11:36, 14.15it/s]

182.20812915802003


  2%|▏         | 176/10000 [00:11<12:59, 12.60it/s]

174.70095946311952


  2%|▏         | 202/10000 [00:13<12:54, 12.65it/s]

173.39609645843507


  2%|▏         | 206/10000 [00:14<11:10, 14.61it/s]


KeyboardInterrupt: 

In [None]:
# Reward curve: one point per logging interval of `avg_interval` episodes.
fig, ax = plt.subplots()
ax.plot(ep_average_rewards)
ax.set_title('Average episode reward')
ax.set_xlabel(f'Logging interval ({avg_interval} episodes each)')
ax.set_ylabel('Reward per episode')
plt.show()

In [None]:
# Loss curve: one point per logging interval of `avg_interval` episodes. Each
# value is the sum of per-update losses in the interval divided by the number
# of episodes, so it stays at 0.0 until learning has started.
fig, ax = plt.subplots()
ax.plot(ep_average_losses)
ax.set_title('Average episode loss')
ax.set_xlabel(f'Logging interval ({avg_interval} episodes each)')
ax.set_ylabel('Summed loss per episode')
plt.show()