In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from collections import namedtuple

In [None]:
# Select the compute device once for the whole notebook: use CUDA when a GPU
# is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [None]:
# One training sample: (network input, action label, reward target).
# NOTE: the first field is spelled "netwrok_input" (sic). The spelling is kept
# verbatim because the dataset code in this notebook reads the field by that name.
TrainingData = namedtuple("TrainingData", "netwrok_input action reward")

In [None]:
!unzip training_data_cleaned.zip

Archive:  training_data_cleaned.zip
replace training_data_cleaned.pt? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
# Read the whole training set into memory.
# SECURITY: pickle.load can execute arbitrary code while deserialising, so only
# run this on an archive from a trusted source.
# Presumably the file holds a sequence of TrainingData records (TODO confirm);
# if so, the TrainingData namedtuple must already be defined for unpickling to work.
with open("training_data_cleaned.pt", mode="rb") as pickled_file:
    data = pickle.load(pickled_file)

In [None]:
from torch.utils.data import DataLoader, Dataset, random_split


# Names of the six possible agent actions. Not referenced elsewhere in the
# visible cells; presumably the position of each name is the integer class
# label stored in TrainingData.action -- TODO confirm against the data producer.
ACTIONS = ["UP", "DOWN", "LEFT", "RIGHT", "WAIT", "BOMB"]

class Custom_Dataset(Dataset):
    """Map-style dataset over a sequence of training records.

    Each record must expose ``netwrok_input`` (sic), ``action`` and ``reward``
    attributes. Items are returned as ``(board, (action, reward))`` with every
    tensor already moved to the notebook-level ``device``.
    """

    def __init__(self, data):
        # Keep a reference to the underlying indexable collection; nothing is copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Squeeze dimension 0 of the stored input before handing it to the model.
        board = sample.netwrok_input.squeeze(0).to(device)

        # Targets: the action label (dtype inferred by torch.tensor) and the
        # reward as a float32 scalar.
        action = torch.tensor(sample.action).to(device)
        reward = torch.tensor(sample.reward, dtype=torch.float32).to(device)

        return board, (action, reward)

# 80/20 train/validation split.
# The split is driven by a dedicated, seeded generator so that the same samples
# land in the same subset on every run; without it the membership changes with
# each kernel restart and validation losses are not comparable between runs.
SPLIT_SEED = 42
train_size = int(0.8 * len(data))
test_size = len(data) - train_size

train_data, test_data = random_split(
    data,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


In [None]:
# Sanity check: number of samples that ended up in the training split.
len(train_data)

302047

In [None]:
# Mini-batch size shared by the training and evaluation loaders.
BATCH_SIZE = 32

# Training batches are reshuffled every epoch.
train_loader = DataLoader(Custom_Dataset(train_data), batch_size=BATCH_SIZE, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps the validation
# batches (and therefore the averaged per-batch loss) identical between epochs.
test_loader = DataLoader(Custom_Dataset(test_data), batch_size=BATCH_SIZE, shuffle=False)

In [None]:
from deep_network import MCTSNetwork, save_model, load_model

# Build the network with its default constructor arguments and place it on the
# selected device (the dataset yields tensors on that same device).
model = MCTSNetwork().to(device)

In [None]:
# Two training objectives, one per network output:
#   * value  -> mean squared error against the reward target
#   * policy -> cross-entropy against the action label
loss_value = nn.MSELoss()
loss_policy = nn.CrossEntropyLoss()

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

In [None]:
def train_one_epoch(epoch_index, tb_writer, report_every=1000):
    """Run one optimisation pass over ``train_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_policy``,
    ``loss_value`` and ``train_loader``. The loss minimised for each batch is
    the sum of the policy loss and the value loss.

    Args:
        epoch_index: Zero-based epoch counter; only used to derive the global
            step under which losses are logged.
        tb_writer: Object exposing ``add_scalar(tag, value, step)`` (a
            TensorBoard ``SummaryWriter`` in this notebook).
        report_every: Size, in batches, of the reporting window. After every
            ``report_every`` batches the window's mean loss is printed and
            logged under ``'Loss/train'``. Defaults to 1000.

    Returns:
        float: Mean per-batch loss of the last *completed* reporting window.
        If the loader is shorter than one window, the mean over all batches
        seen is returned instead (previously this case yielded a misleading
        0.0). An empty loader returns 0.0.

    Raises:
        ValueError: If ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError("report_every must be a positive integer")

    running_loss = 0.0
    last_loss = 0.0
    pending_batches = 0   # batches accumulated since the last report
    reported = False      # becomes True once a full window has been logged

    for i, batch in enumerate(train_loader):
        inputs, labels = batch
        action_labels, value_labels = labels

        # Gradients accumulate by default, so clear them for every batch.
        optimizer.zero_grad()

        # The model returns the value output first, then the policy output.
        value_outputs, action_outputs = model(inputs)
        # Squeeze dimension 1 of the value output; presumably (batch, 1) ->
        # (batch,) so it lines up with the reward labels -- TODO confirm.
        value_outputs = torch.squeeze(value_outputs, 1)

        action_loss = loss_policy(action_outputs, action_labels)
        value_loss = loss_value(value_outputs, value_labels)
        loss = action_loss + value_loss
        loss.backward()

        optimizer.step()

        # Gather data and report once per window.
        running_loss += loss.item()
        pending_batches += 1
        if pending_batches == report_every:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            # Global step = number of batches processed since epoch 0.
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            pending_batches = 0
            reported = True

    # Short loaders never complete a window; report what was actually seen
    # rather than the 0.0 initial value.
    if not reported and pending_batches:
        last_loss = running_loss / pending_batches

    return last_loss

In [None]:
# Run bookkeeping. Note that this lives in the same cell as the epoch loop, so
# re-running the cell starts a fresh run: the counters below are reset and a
# new TensorBoard directory is created.
EPOCHS = 5
epoch_number = 0

# Sentinel that any real validation loss is expected to beat.
best_vloss = 1_000_000.

# The timestamp tags both the TensorBoard run directory and saved checkpoints.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/train_mcts_model{}'.format(timestamp))

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # --- Training pass -----------------------------------------------------
    # Switch the model to training mode and run one pass over the training data.
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # --- Validation pass ---------------------------------------------------
    # Evaluation mode changes the behaviour of layers such as dropout and
    # batch normalisation, if the model contains any.
    model.eval()

    # Accumulate plain Python floats rather than device tensors, so the
    # averages, the TensorBoard scalars and `best_vloss` are ordinary numbers.
    running_vloss = 0.0
    num_vbatches = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, (vaction_labels, vvalue_labels) in test_loader:
            # The model returns the value output first, then the policy output.
            vvalue_outputs, vaction_outputs = model(vinputs)

            vaction_loss = loss_policy(vaction_outputs, vaction_labels)
            # Add a trailing dimension to the reward labels; presumably this
            # matches a (batch, 1) value output -- TODO confirm.
            vvalue_loss = loss_value(vvalue_outputs, vvalue_labels[:, None])

            running_vloss += (vaction_loss + vvalue_loss).item()
            num_vbatches += 1

    # Count batches explicitly: an empty validation loader yields NaN instead
    # of failing on an undefined loop index.
    avg_vloss = running_vloss / num_vbatches if num_vbatches else float('nan')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state. A NaN validation
    # loss compares False here and therefore never overwrites the best value.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 1000 loss: 1.4940115375518799
  batch 2000 loss: 1.423982370853424
  batch 3000 loss: 1.4111077930927276
  batch 4000 loss: 1.4111725059747695
  batch 5000 loss: 1.4075127046108247
  batch 6000 loss: 1.4021588814258577
  batch 7000 loss: 1.402689187169075
  batch 8000 loss: 1.4052076382637024
  batch 9000 loss: 1.3951623042821883
LOSS train 1.3951623042821883 valid 1.3920483589172363
EPOCH 2:
  batch 1000 loss: 1.3972736817598344
  batch 2000 loss: 1.3956752136945725
  batch 3000 loss: 1.3996170933246612
  batch 4000 loss: 1.391513397336006
  batch 5000 loss: 1.3956113802194596
  batch 6000 loss: 1.3929737962484359
  batch 7000 loss: 1.3984960366487502
  batch 8000 loss: 1.3917562190294266
  batch 9000 loss: 1.391571696281433
LOSS train 1.391571696281433 valid 1.3862462043762207
EPOCH 3:
  batch 1000 loss: 1.3939559390544891
  batch 2000 loss: 1.3865705012083054
  batch 3000 loss: 1.385343130350113
  batch 4000 loss: 1.3932771669626236
  batch 5000 loss: 1.388243237018

In [None]:
# Persist the model as it currently is in memory, using the project helper
# from deep_network. This is not necessarily the best-validation checkpoint;
# those are written separately by the epoch loop via torch.save.
save_model(model, "mcts_model.pt")

In [None]:
# Recursively bundle the TensorBoard logs under runs/ into a single archive
# (presumably to download them from the remote runtime).
!zip -r runs.zip runs/

  adding: runs/ (stored 0%)
  adding: runs/train_mcts_model20240923_013213/ (stored 0%)
  adding: runs/train_mcts_model20240923_013213/events.out.tfevents.1727055133.eb3f025b953f.2111.3 (deflated 9%)
  adding: runs/train_mcts_model20240923_013057/ (stored 0%)
  adding: runs/train_mcts_model20240923_013057/events.out.tfevents.1727055057.eb3f025b953f.2111.1 (deflated 9%)
  adding: runs/train_mcts_model20240923_013710/ (stored 0%)
  adding: runs/train_mcts_model20240923_013710/events.out.tfevents.1727055430.eb3f025b953f.4818.0 (deflated 58%)
  adding: runs/train_mcts_model20240923_013710/Training vs. Validation Loss_Training/ (stored 0%)
  adding: runs/train_mcts_model20240923_013710/Training vs. Validation Loss_Training/events.out.tfevents.1727055506.eb3f025b953f.4818.1 (deflated 48%)
  adding: runs/train_mcts_model20240923_013710/Training vs. Validation Loss_Validation/ (stored 0%)
  adding: runs/train_mcts_model20240923_013710/Training vs. Validation Loss_Validation/events.out.tfevents