In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from collections import namedtuple

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [3]:
# Record type with the fields network_input, action and discounted_reward.
Datapoint = namedtuple("Datapoint", "network_input action discounted_reward")

In [5]:
# `Datapoint` is defined in an earlier cell and is not redefined here.
# Presumably the pickled records are Datapoint instances, in which case that
# name must be defined before loading -- verify against the script that wrote
# the file.
#
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load training data that comes from a trusted source.
# NOTE(review): despite the ".pt" suffix this file is read with plain pickle,
# not torch.load.
with open("training_data_cleaned.pt", "rb") as f:
    data = pickle.load(f)

In [7]:
from torch.utils.data import DataLoader, Dataset, random_split


# Action vocabulary; an action's position in this list is used as its class index.
ACTIONS = [
    "UP",
    "DOWN",
    "LEFT",
    "RIGHT",
    "WAIT",
    "BOMB",
]

class Custom_Dataset(Dataset):
    """Map-style dataset yielding ``(board, (action, reward))`` training samples.

    Every element of ``data`` must expose ``network_input``, ``action`` and
    ``discounted_reward`` attributes. All returned tensors are placed on the
    notebook-level ``device``.
    """

    def __init__(self, data):
        """Store ``data``, an indexable and sized collection of datapoints."""
        self.data = data

    def __len__(self):
        """Return the number of stored datapoints."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as ``(board, (action, reward))``.

        Returns:
            board: ``network_input`` with dimension 0 squeezed away when it
                has size 1 (assumed to be a stored batch axis -- TODO confirm).
            action: 0-d int64 tensor holding the index of the action name in
                ``ACTIONS``; a falsy action (``None`` or ``""``) maps to "WAIT".
            reward: 0-d float32 tensor holding ``discounted_reward``.

        Raises:
            ValueError: if the action name is not present in ``ACTIONS``.
        """
        datapoint = self.data[idx]
        board = torch.squeeze(datapoint.network_input, 0).to(device)

        action_name = datapoint.action or "WAIT"
        # Class-index targets for nn.CrossEntropyLoss must be int64; state the
        # dtype explicitly instead of relying on inference.
        action = torch.tensor(ACTIONS.index(action_name), dtype=torch.long, device=device)
        # Pin float32 so an integer-valued or float64 reward cannot produce a
        # target whose dtype differs from the (float32) value prediction.
        reward = torch.tensor(datapoint.discounted_reward, dtype=torch.float32, device=device)

        return board, (action, reward)

# Fraction of the samples used for training; the remainder is held out.
TRAIN_FRACTION = 0.8
# Fixed seed so the train/test partition is identical on every Restart & Run All.
SPLIT_SEED = 0

train_size = int(TRAIN_FRACTION * len(data))
test_size = len(data) - train_size

# Without an explicit generator random_split draws from the global RNG, so the
# held-out set (and every validation number) would change from run to run.
train_data, test_data = random_split(
    data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


In [8]:
# Display the number of samples in the training split.
len(train_data)

307024

In [9]:
# Mini-batch size shared by the training and evaluation loaders.
BATCH_SIZE = 32

# Training batches are reshuffled every epoch.
train_loader = DataLoader(Custom_Dataset(train_data), batch_size=BATCH_SIZE, shuffle=True)
# Evaluation keeps a fixed order: the validation loss is an average of per-batch
# means, so shuffling would change which samples land in the final partial
# batch and make the metric vary slightly between otherwise identical passes.
test_loader = DataLoader(Custom_Dataset(test_data), batch_size=BATCH_SIZE, shuffle=False)

In [11]:
from deep_network import MCTSNetwork, save_model, load_model

# Instantiate the network, then move it to the selected device.
model = MCTSNetwork()
model = model.to(device)

In [12]:
# Two criteria, one per network output: mean-squared error for the value
# prediction and cross-entropy for the action prediction.
loss_value = nn.MSELoss()
loss_policy = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [13]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [14]:
def train_one_epoch(epoch_index, tb_writer, report_every=1000):
    """Run one optimisation pass over ``train_loader`` and return a recent mean loss.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_policy``,
    ``loss_value`` and ``train_loader``.

    Args:
        epoch_index: Zero-based epoch counter; only used to compute the global
            step written to TensorBoard.
        tb_writer: Object providing ``add_scalar(tag, value, step)``.
        report_every: Number of batches per reporting window (default 1000).

    Returns:
        Mean per-batch loss of the last *completed* reporting window. If the
        loader yields fewer than ``report_every`` batches, the mean over all
        batches seen is returned instead (0.0 only for an empty loader).

    Raises:
        ValueError: if ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError("report_every must be a positive integer")

    running_loss = 0.0
    last_loss = 0.0
    batches_in_window = 0
    window_completed = False

    for i, batch in enumerate(train_loader):
        inputs, labels = batch
        action_labels, value_labels = labels

        optimizer.zero_grad()
        # The model returns the value prediction first, then the action scores.
        value_outputs, action_outputs = model(inputs)
        # Drop dimension 1 when it has size 1 so the value predictions line up
        # with the value labels (assumes a (batch, 1) value output -- TODO confirm).
        value_outputs = torch.squeeze(value_outputs, 1)

        action_loss = loss_policy(action_outputs, action_labels)
        value_loss = loss_value(value_outputs, value_labels)
        loss = action_loss + value_loss
        loss.backward()

        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        batches_in_window += 1
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            batches_in_window = 0
            window_completed = True

    # A short epoch never completes a window; report the mean over what was
    # seen rather than the misleading initial value of 0.0.
    if not window_completed and batches_in_window:
        last_loss = running_loss / batches_in_window

    return last_loss

In [None]:
# Run-level state. NOTE: these initialisers share a cell with the loop below,
# so re-running the cell starts a fresh run (new timestamp, epoch counter and
# best loss reset) instead of adding epochs to the previous one.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/train_mcts_model{}'.format(timestamp))
epoch_number = 0

EPOCHS = 5

# Sentinel larger than any expected loss, so the first epoch always checkpoints.
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    num_val_batches = 0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in test_loader:
            vaction_labels, vvalue_labels = vlabels
            # Add a trailing axis so the value labels match the un-squeezed
            # value output (assumes a (batch, 1) value output -- TODO confirm).
            vvalue_labels = vvalue_labels[:, None]
            vvalue_outputs, vaction_outputs = model(vinputs)

            vaction_loss = loss_policy(vaction_outputs, vaction_labels)
            vvalue_loss = loss_value(vvalue_outputs, vvalue_labels)
            vloss = vaction_loss + vvalue_loss
            # Accumulate a plain Python float so the average, the logged value
            # and best_vloss are numbers rather than tensors.
            running_vloss += vloss.item()
            num_val_batches += 1

    # Count batches explicitly instead of reusing the loop index, which is
    # undefined when the validation loader is empty.
    avg_vloss = running_vloss / num_val_batches if num_val_batches else float('nan')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    # (a NaN validation loss never compares as smaller, so nothing is saved then).
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 1000 loss: 1.6036082782745362
  batch 2000 loss: 1.50031951212883
  batch 3000 loss: 1.494853812098503
  batch 4000 loss: 1.4809586511850357
  batch 5000 loss: 1.4789516509771348
  batch 6000 loss: 1.47475180375576
  batch 7000 loss: 1.4679318443536757
  batch 8000 loss: 1.4727025296688079
  batch 9000 loss: 1.465678373336792
LOSS train 1.465678373336792 valid 1.4628123044967651
EPOCH 2:
  batch 1000 loss: 1.4573805178403854
  batch 2000 loss: 1.4609615030288696
  batch 3000 loss: 1.4654378856420518
  batch 4000 loss: 1.4545039182901383
  batch 5000 loss: 1.4567536331415176
  batch 6000 loss: 1.4586019105911254
  batch 7000 loss: 1.450035817027092
  batch 8000 loss: 1.4432693802118302
  batch 9000 loss: 1.4478336571455002
LOSS train 1.4478336571455002 valid 1.4484941959381104
EPOCH 3:
  batch 1000 loss: 1.4429887256622314
  batch 2000 loss: 1.4398429354429245
  batch 3000 loss: 1.4470537085533142
  batch 4000 loss: 1.436425262451172
  batch 5000 loss: 1.440939012765884

In [None]:
# Hand the model to the project's save_model helper from deep_network
# (presumably writes it to "mcts_model.pt" -- confirm against that module).
save_model(model, "mcts_model.pt")