# Setup

In [None]:
%%capture
#@title External Imports
import torch
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-geometric


In [None]:
#@title Imports
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

from torch_geometric.nn import MessagePassing
from torch_geometric.nn.conv import SAGEConv
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from torch_geometric.loader import DataLoader as ShannonDataLoader

from tqdm import tqdm
import random

from torch_scatter import scatter, scatter_softmax

from math import sqrt, ceil
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from google.colab import drive
# Mount Google Drive: the cached datasets and model checkpoints used below are read from
# (and written to) paths under /content/drive/MyDrive.
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
#@title Hyperparameters and Constants
board_width = 11
num_tiles = board_width**2

# Define the hyperparameters we are going to use.
# Each *_params dict configures one model family; `params` (at the bottom) holds the
# shared training settings plus the flags that select which models are built or loaded.
gnn_params = {
    'input_dim': 3,
    'hidden_dim': 128,
    'num_layers': 9,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

gin_params = {
    'input_dim': 3,
    'hidden_dim': 128,
    'num_layers': 9,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

cnn_params = {
    'input_dim': 3,
    'hidden_dim': 128,
    'output_dim': num_tiles+1,
    'num_layers': 10,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

hex_grid_params = {
    'input_dim': 3,
    'hidden_dim': 128,
    'num_layers': 10,
    'learning_rate': 1e-4,  # alternative value: 5e-5
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

# Shannon-graph models use 2-dim node features (playable tile vs. edge node).
det_shannon_params = {
    'input_dim': 2,
    'hidden_dim': 128,
    'num_layers': 10,
    'learning_rate': 1e-4,  # alternative value: 5e-5
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

rni_shannon_params = {
    'input_dim': 2,
    'hidden_dim': 128,
    'num_layers': 10,
    'random_dims': 6,  # extra random node features appended on every forward pass
    'learning_rate': 1e-4,  # alternative value: 5e-5
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

sage_params = {
    'input_dim': 2,
    'hidden_dim': 128,
    'num_layers': 10,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'drop_ratio': 0
}

params = {
    'num_epochs': 100,
    'save_interval': 2, # 0 to disable saves entirely
    'plot_interval': 1, # 0 to disable plots entirely
    'early_stop_threshold': 10,
    'policy_loss_fn': nn.CrossEntropyLoss(reduction='mean'),
    'value_loss_fn': nn.MSELoss(),
    'loss_weight': 0.5,

    # Data preparation switches.
    'generate_and_save_triples': False,
    'generate_and_save_shannon_triples': False,
    'plot_initial_outputs': False,

    # Load a saved checkpoint for the given model.
    # (Each key must appear only once: a repeated dict key silently overrides the earlier one.)
    'load_gnn': False,
    'load_cnn': False,

    'load_hex_grid':    False,
    'load_det_shannon': False,
    'load_rni_shannon': False,
    'load_sage': False,

    # Build (and train) a fresh model of the given kind.
    'new_hex_grid': True,
    'new_det_shannon': False,
    'new_rni_shannon': False,
    'new_cnn': True,
    'new_sage': False,

    'act_fun': F.relu,
    'batch_size': 1024,
}


In [None]:
#@title Load Models
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Directory on the mounted Drive that holds the trained checkpoints.
CHECKPOINT_DIR = '/content/drive/MyDrive/Colab Notebooks/hexbot/temp'

# One row per loadable model:
#   (params flag, action accuracy, outcome accuracy, training epoch, network name).
# The accuracy/epoch strings are only used to rebuild the checkpoint filename.
CHECKPOINT_SPECS = [
    ('load_hex_grid',    "0.37", "0.8",  "75", "HexGridGNN"),
    ('load_det_shannon', "0.42", "0.76", "75", "ShannonGNN"),
    ('load_rni_shannon', "0.42", "0.74", "75", "ShannonGNN-RNI"),
    ('load_cnn',         "0.28", "0.84", "76", "CNN"),
    ('load_sage',        "0.32", "0.6",  "75", "GraphSAGE"),
]

model_checkpoint_filenames_to_load = []

# Only enabled rows are unpacked, so action_acc / outcome_acc / tr_epoch / network_name
# end up holding the values of the last enabled model (and stay unset if none is enabled).
for load_flag, action_acc, outcome_acc, tr_epoch, network_name in (
    spec for spec in CHECKPOINT_SPECS if params[spec[0]]
):
    model_checkpoint_filenames_to_load.append(
        f'{CHECKPOINT_DIR}/action-{action_acc}_E{tr_epoch}_{network_name}_D10_outcome{outcome_acc}.pt'
    )


In [None]:
def is_coord_in_board(coord):
    """Return True iff ``coord`` = (x, y) lies on the board_width x board_width grid.

    Always returns a bool (False for off-board coordinates), never None.
    """
    x, y = coord
    return 0 <= x < board_width and 0 <= y < board_width

def index_from_position(row, col):
    """Flatten a (row, col) board position into its row-major tile index."""
    tiles_in_earlier_rows = row * board_width
    return tiles_in_earlier_rows + col

def generate_hex_grid_adj_matrix():
    """Build the dense (num_tiles x num_tiles) 0/1 adjacency matrix of the hex board.

    Tile (r, c) touches (r, c+1), (r+1, c), (r-1, c), (r, c-1), (r-1, c+1) and
    (r+1, c-1), clipped to the board; entry [a][b] is 1 when tile b neighbours tile a.
    """
    neighbour_offsets = ((0, 1), (1, 0), (-1, 0), (0, -1), (-1, 1), (1, -1))
    adjacency = torch.zeros(num_tiles, num_tiles)
    for row in range(board_width):
        for col in range(board_width):
            here = index_from_position(row, col)
            for d_row, d_col in neighbour_offsets:
                neighbour = (row + d_row, col + d_col)
                if not is_coord_in_board(neighbour):
                    continue
                adjacency[here][index_from_position(*neighbour)] = 1
    return adjacency


from queue import Queue

def generate_3d_node_coords():
    """Return cube ("3D") hex coordinates for every tile, centred on the middle tile.

    Returns a float tensor of shape (num_tiles, 3). Row ``index_from_position(r, c)``
    holds (dr, dc, -dr - dc), where dr = r - centre and dc = c - centre, with
    centre = (board_width - 1) // 2. Each row sums to zero. The board neighbour
    steps used by generate_hex_grid_adj_matrix, (0, +-1), (+-1, 0), (-1, +1) and
    (+1, -1), map to the six unit cube directions (0, +-1, -+1), (+-1, 0, -+1),
    (-1, 1, 0) and (1, -1, 0).

    Why the closed form: the previous BFS paired the cube steps (1, -1, 0) and
    (-1, 1, 0) with the 2D offsets (1, 1) and (-1, -1). Those offsets are not
    neighbours on this board, so the coordinates it assigned depended on the
    traversal path.
    """
    centre = (board_width - 1) // 2
    offsets = torch.arange(board_width) - centre
    # Row-major flattening, matching index_from_position: index = col + row * board_width.
    d_row = offsets.repeat_interleave(board_width)
    d_col = offsets.repeat(board_width)
    return torch.stack([d_row, d_col, -d_row - d_col], dim=1).float()

def display_matrix(matrix):
    """Print a 2D tensor of integer-valued entries, one row per line.

    Each entry is printed as ``int(value)`` followed by a single space. Each row
    is iterated over its own length, so non-square matrices are printed in full.
    """
    for row in matrix:
        print(''.join(f'{int(value.item())} ' for value in row))

adj_matrix = generate_hex_grid_adj_matrix().to(device)
all_3d_node_coords = generate_3d_node_coords().to(device)

# Edge list of shape (2, num_edges), built from the non-zero entries of the dense adjacency.
global_edge_index = adj_matrix.nonzero().t().contiguous()

# Per-tile features derived from the cube coordinates (x, y, z):
#   (|x|, |y|, (sign(x) + sign(y)) / 2 + 1).
# Computed in one vectorised step rather than per tile; the result is kept on the CPU.
cube_x = all_3d_node_coords[:, 0]
cube_y = all_3d_node_coords[:, 1]
tile_relevances = torch.stack(
    [cube_x.abs(), cube_y.abs(), (cube_x.sign() + cube_y.sign()) / 2 + 1], dim=1
).cpu()

# NOTE(review): 520 is an unexplained size for this zero buffer of shape (520, num_tiles, 3);
# confirm what its first axis indexes and give it a named constant.
positions = torch.zeros(520, num_tiles, 3).to(device)



# Data

In [None]:
#@title Board Object
# We need the data input to our model to be a tensor of shape (num_tiles x input_embedding_dim).
# We want our data labels (the policy targets) to be probability vectors indexed by range(num_tiles)
# One-hot tile encodings: [empty, white, black].
EMPTY_VECTOR = torch.tensor([1.,0.,0.])
WHITE_VECTOR = torch.tensor([0.,1.,0.])
BLACK_VECTOR = torch.tensor([0.,0.,1.])

class Board:
    """One-hot encoding of a Hex position: one row per tile (EMPTY/WHITE/BLACK_VECTOR).

    Positions are strings such as 'a1' (column letter, 1-based row number) and map
    to the row-major tile index ``row * width + col``.
    """
    WHITE = 'W'
    BLACK = 'B'

    def __init__(self, width=None):
        # `width` defaults to the global board_width, looked up at construction time.
        self.width = board_width if width is None else width
        self.state = torch.tile(EMPTY_VECTOR, (self.width ** 2, 1))

    def get_index_from_position(self, position):
        """Convert e.g. 'c4' into its row-major tile index."""
        col = ord(position[0]) - ord('a')
        row = int(position[1:]) - 1
        return row * self.width + col

    def update_state(self, position, colour):
        """Place a stone at `position`: white if `colour` == 'W', black otherwise.

        Raises AssertionError if the tile is already occupied.
        """
        index = self.get_index_from_position(position)
        assert torch.equal(self.state[index], EMPTY_VECTOR), f'tile {position} is already occupied'
        self.state[index] = WHITE_VECTOR if colour == self.WHITE else BLACK_VECTOR
        #assert self.state.shape == (num_tiles, len(EMPTY_VECTOR))


In [None]:
#@title Get State, Action, Outcome Triples
def get_state_action_outcome_triples_from_file(game_data_filename):
    """Parse a MoHex game log into (state, action, outcome) training triples.

    The first line of the file is a header and is skipped. Each following line is
    one game: space-separated move tokens, then a final outcome number. That number
    is treated as relative to the player of the last move. Every move token contains
    ``[position]``. From the second move onwards it also carries a second
    ``[pos:prob;pos:prob;...;]`` group with the bot's move probabilities.

    Returns three parallel lists:
      * states   -- (num_tiles, 3) one-hot boards for the player to move
                    (colours swapped on black's turns),
      * actions  -- length-num_tiles probability vectors over tiles,
      * outcomes -- the game result w.r.t. the player to move.
    Blank lines are skipped. A move with no probability mass is kept as an all-zero
    vector rather than being divided by zero (which would produce NaNs).
    """
    states = []
    actions = []
    outcomes = []

    with open(game_data_filename, 'r') as file:
        file.readline()  # header line, not a game
        for line in tqdm(file.readlines()):
            tokens = line.split(' ')
            raw_game_string_array = tokens[:-1]
            if not raw_game_string_array:
                continue  # blank / malformed line without any moves
            last_player = raw_game_string_array[-1][0]
            outcome_wrt_last_player = int(float(tokens[-1]))
            outcome_wrt_white_player = outcome_wrt_last_player if last_player == 'W' else -outcome_wrt_last_player
            outcome_wrt_black_player = -outcome_wrt_white_player

            # normal_board: stones in their true colours (white's view);
            # inverse_board: colours swapped (black's view).
            normal_board = Board()
            inverse_board = Board()
            for i, raw_move_string in enumerate(raw_game_string_array):
                first_open_bracket = raw_move_string.find('[')
                first_close_bracket = raw_move_string.find(']')

                # Moves alternate, starting with white.
                white_turn = i % 2 == 0
                position = raw_move_string[first_open_bracket+1:first_close_bracket]

                # We only create the state, action pairs for the second move onwards.
                # This is because the first move is random, and not chosen by the bot.
                if i > 0:
                    second_open_bracket = raw_move_string.find('[', first_close_bracket+1)
                    second_close_bracket = raw_move_string.find(']', first_close_bracket+1)

                    # Entries are ';'-terminated, so the last split element is empty and dropped.
                    position_probability_pairs = raw_move_string[second_open_bracket+1:second_close_bracket].split(';')[:-1]
                    probability_vector = torch.zeros(num_tiles)

                    for s in position_probability_pairs:
                        action, probability = s.split(':')
                        # Tile indices do not depend on colour, so the normal board's mapping serves both players.
                        index = normal_board.get_index_from_position(action)
                        probability_vector[index] = float(probability)

                    # Scale the non-zero elements so that they sum to 1 (when there is any mass at all).
                    total_probability = torch.sum(probability_vector)
                    if total_probability > 0:
                        probability_vector /= total_probability

                    actions.append(probability_vector)
                    state_to_add = normal_board.state if white_turn else inverse_board.state
                    states.append(torch.clone(state_to_add).detach())

                    outcomes.append(outcome_wrt_white_player if white_turn else outcome_wrt_black_player)

                # Update the board state, ready for the next move
                normal_board.update_state(position, 'W' if white_turn else 'B')
                inverse_board.update_state(position, 'B' if white_turn else 'W')

    return states, actions, outcomes

# Cache of the parsed grid (state, action, outcome) triples.
filename = 'drive/MyDrive/Colab Notebooks/hexbot/game_data_triples.pt'
if params['generate_and_save_triples']:
    # Re-parse the raw game log and refresh the cache.
    states, actions, outcomes = get_state_action_outcome_triples_from_file('drive/MyDrive/Colab Notebooks/hexbot/game_data')
    torch.save([states, actions, outcomes], filename)
elif any(params[flag] for flag in ('new_hex_grid', 'new_cnn', 'load_cnn', 'load_hex_grid')):
    # A grid-based model is being built or loaded, so it needs the cached triples.
    states, actions, outcomes = torch.load(filename)


In [None]:
#@title Shannon Board Object
def generate_node_adjacency_sets(inverse=False):
    """Adjacency sets for the Shannon-game graph of an empty board.

    Nodes 0..num_tiles-1 are tiles (row-major) linked to their hex neighbours.
    Nodes num_tiles and num_tiles+1 are two virtual edge nodes. With inverse=False
    they attach to the first and last column; with inverse=True they attach to the
    first and last row. Every edge is stored in both directions.
    """
    adjacency = {node: set() for node in range(num_tiles + 2)}
    neighbour_offsets = ((0, 1), (1, 0), (-1, 0), (0, -1), (-1, 1), (1, -1))

    for row in range(board_width):
        for col in range(board_width):
            here = index_from_position(row, col)
            for d_row, d_col in neighbour_offsets:
                there = (row + d_row, col + d_col)
                if is_coord_in_board(there):
                    adjacency[here].add(index_from_position(*there))

    first_edge, second_edge = num_tiles, num_tiles + 1
    if inverse:
        # Top row (0..W-1) and bottom row (num_tiles-W..num_tiles-1).
        first_side = {k for k in range(board_width)}
        second_side = {num_tiles - k for k in range(1, board_width + 1)}
    else:
        # Column 0 and column W-1 of every row.
        first_side = {r * board_width for r in range(board_width)}
        second_side = {r * board_width - 1 for r in range(1, board_width + 1)}

    adjacency[first_edge] = first_side
    adjacency[second_edge] = second_side
    for tile in first_side:
        adjacency[tile].add(first_edge)
    for tile in second_side:
        adjacency[tile].add(second_edge)
    return adjacency

# Node features of the Shannon graph (each a 1x2 row): a playable tile vs. one of the two edge nodes.
PLAYABLE_VECTOR = torch.tensor([1.,0.]).unsqueeze(0)
EDGE_VECTOR = torch.tensor([0.,1.]).unsqueeze(0)

class ShannonEquivBoard:
    """Hex position represented as a Shannon-game graph that shrinks as stones are played.

    ``state`` maps node id -> set of neighbouring node ids. It starts from
    ``generate_node_adjacency_sets(inverse)``: tiles 0..num_tiles-1 plus two virtual
    edge nodes (num_tiles, num_tiles+1). Playing a stone removes its node; a 'W'
    stone additionally connects all of that node's former neighbours to each other.
    ``node_name_lookup`` renumbers the surviving nodes so that ``get_graph`` emits
    contiguous node ids.
    """
    def __init__(self, inverse=False):
        self.state = generate_node_adjacency_sets(inverse)
        # original node id -> current (compacted) node id
        self.node_name_lookup = { n:n for n in range(num_tiles+2) }
        self.original_num_nodes = num_tiles + 2
        self.num_playable_tiles = num_tiles

    def get_index_from_position(self, position):
        """Convert e.g. 'c4' (column letter, 1-based row) into the row-major tile index."""
        col = ord(position[0])-ord('a')
        row = int(position[1:])-1
        return row*board_width + col

    def update_state(self, position, colour):
        """Remove the node at ``position``; for colour 'W' also link its neighbours pairwise.

        NOTE(review): the removed node's own ``node_name_lookup`` entry is left unchanged
        (never marked), so afterwards it aliases another node's compacted id. Callers must
        not look up positions that have already been played.
        """
        self.num_playable_tiles -= 1
        index = self.get_index_from_position(position)
        neighbours = self.state[index]
        # Every node after the removed one shifts down by one in the compacted numbering.
        for node_to_make_smaller in range(index+1, self.original_num_nodes):
            self.node_name_lookup[node_to_make_smaller] -= 1

        for node in neighbours:
            self.state[node].remove(index)
            if colour == 'W':
                # Contract the node: connect this neighbour to all the others, then drop the self-loop.
                self.state[node] = self.state[node] | neighbours
                self.state[node].remove(node)
        self.state[index] = set()

    def get_graph(self, move_index, outcome):
        """Build a torch_geometric ``Data`` graph of the current position.

        Edges come from ``state``, relabelled through ``node_name_lookup``. Node features
        are PLAYABLE_VECTOR for each remaining tile, followed by EDGE_VECTOR for the two
        edge nodes. ``move_index`` is stored as ``y`` and ``outcome`` as ``outcome``.
        """
        top = []
        bottom = []
        for node, neighbours in self.state.items():
            for neighbour in neighbours:
                top.append(self.node_name_lookup[node])
                bottom.append(self.node_name_lookup[neighbour])
        edge_index = torch.tensor([top,bottom]).long()
        x = torch.cat([torch.tile(PLAYABLE_VECTOR, (self.num_playable_tiles,1)), EDGE_VECTOR, EDGE_VECTOR], dim=0)
        return Data(x=x, edge_index=edge_index, y=move_index, outcome=outcome)


In [None]:
#@title Get Shannon Graph, Action, Outcome Triples
def get_shannon_graphs_from_file(game_data_filename):
    """Parse a MoHex game log into Shannon-graph training samples.

    The log format is the same as for the grid triples: a header line, then one game
    per line of move tokens followed by an outcome number (treated as relative to the
    player of the last move). For every move after the first, the graph of the position
    is taken from the mover's board: the normal board for white, the inverse board for
    black. It is returned as a ``Data`` object whose ``y`` is the compacted node id of
    the highest-probability move and whose ``outcome`` is the result for the player to
    move. Blank lines are skipped.
    """
    graphs = []

    with open(game_data_filename, 'r') as file:
        file.readline()  # header line, not a game
        for line in tqdm(file.readlines()):
            tokens = line.split(' ')
            raw_game_string_array = tokens[:-1]
            if not raw_game_string_array:
                continue  # blank / malformed line without any moves
            last_player = raw_game_string_array[-1][0]
            outcome_wrt_last_player = int(float(tokens[-1]))
            outcome_wrt_white_player = outcome_wrt_last_player if last_player == 'W' else -outcome_wrt_last_player
            outcome_wrt_black_player = -outcome_wrt_white_player

            normal_board = ShannonEquivBoard()
            inverse_board = ShannonEquivBoard(inverse=True)
            for i, raw_move_string in enumerate(raw_game_string_array):
                first_open_bracket = raw_move_string.find('[')
                first_close_bracket = raw_move_string.find(']')

                # Moves alternate, starting with white.
                white_turn = i % 2 == 0
                position = raw_move_string[first_open_bracket+1:first_close_bracket]

                # We only create samples for the second move onwards.
                # This is because the first move is random, and not chosen by the bot.
                if i > 0:
                    mover_board = normal_board if white_turn else inverse_board
                    mover_outcome = outcome_wrt_white_player if white_turn else outcome_wrt_black_player

                    second_open_bracket = raw_move_string.find('[', first_close_bracket+1)
                    second_close_bracket = raw_move_string.find(']', first_close_bracket+1)

                    # Entries are ';'-terminated, so the last split element is empty and dropped.
                    position_probability_pairs = raw_move_string[second_open_bracket+1:second_close_bracket].split(';')[:-1]
                    # One slot per node of the current graph: remaining tiles plus the two edge nodes.
                    probability_vector = torch.zeros(mover_board.num_playable_tiles + 2)

                    for s in position_probability_pairs:
                        action, probability = s.split(':')
                        # Tile indices are shared by both boards; compacted node ids are per board.
                        node = mover_board.node_name_lookup[mover_board.get_index_from_position(action)]
                        # NOTE(review): ShannonEquivBoard never stores 'X' in its lookup, so this guard is inert.
                        if node != 'X':
                            probability_vector[node] = float(probability)

                    # Only the most probable move is kept, as a hard policy target.
                    # NOTE(review): with no parsed probabilities this silently picks node 0.
                    move_index = torch.argmax(probability_vector).unsqueeze(0)

                    graph_to_add = mover_board.get_graph(move_index, mover_outcome)
                    graphs.append(graph_to_add.clone().detach())

                # Update the board state, ready for the next move
                normal_board.update_state(position, 'W' if white_turn else 'B')
                inverse_board.update_state(position, 'B' if white_turn else 'W')

    return graphs

# Cache of the parsed Shannon graphs (a list of torch_geometric Data objects).
filename = 'drive/MyDrive/Colab Notebooks/hexbot/shannon_graphs.pt'
if params['generate_and_save_shannon_triples']:
    print('generating and saving')
    shannon_graphs = get_shannon_graphs_from_file('drive/MyDrive/Colab Notebooks/hexbot/game_data')
    torch.save(shannon_graphs, filename)
elif params['new_det_shannon'] or params['new_rni_shannon'] or params['new_sage'] or params['load_det_shannon'] or params['load_rni_shannon'] or params['load_sage']:
    # The cache holds pickled Data objects, which torch.load rejects under weights_only=True
    # (the default since PyTorch 2.6). This file is produced by the branch above, so full
    # unpickling is acceptable here. Never point this at untrusted files.
    shannon_graphs = torch.load(filename, weights_only=False)



In [None]:
#@title Visualise State Action Pairs
def plot_hexagonal_grid_with_colors(size, color_keys, probability_mode=False, rotated=False):
    """Draw a size x size Hex board with one coloured hexagon per tile.

    Args:
        size: board width in tiles.
        color_keys: indexable of length size*size in row-major order. In the default
            mode each entry is colour-coded by its argmax (0 grey, 1 white, 2 black).
            With ``probability_mode=True`` each entry is passed to the plasma colormap.
        probability_mode: see ``color_keys``.
        rotated: swap which pair of board sides is drawn white and which black.

    Saves the figure to 'triplestate.pdf' ('tripleaction.pdf' in probability mode) in
    the working directory, then shows it.
    """
    colors = np.array([(0.75,0.75,0.75), (1,1,1), (0,0,0)])
    left_and_right_color = colors[1] if not rotated else colors[2]
    top_and_bottom_color = colors[2] if not rotated else colors[1]

    _, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.axis('off')

    hex_size = 1
    horizontal_spacing = np.sqrt(3) * hex_size
    vertical_spacing = 3/2 * hex_size

    def hexagon_at(xy, **style):
        # Regular hexagon of the tile radius centred at xy (orientation 0, like every hexagon here).
        return patches.RegularPolygon(xy, numVertices=6, radius=hex_size, orientation=np.radians(0), **style)

    def add_side_marker(xy, color):
        # Thick black outline (zorder 2) under a slightly thinner coloured ring (zorder 3).
        # The tiles (zorder 4) cover the centre, leaving a coloured rim along that board side.
        ax.add_patch(hexagon_at(xy, edgecolor='black', facecolor='none', lw=20, zorder=2))
        ax.add_patch(hexagon_at(xy, edgecolor=color, facecolor='none', lw=18, zorder=3))

    # Side markers along the last and first rows (corner tiles excluded).
    for col in range(1, size-1):
        last_row_xy = (col * horizontal_spacing + (size-1)*horizontal_spacing/2, 0)
        first_row_xy = (col * horizontal_spacing, (size-1) * vertical_spacing)
        add_side_marker(last_row_xy, top_and_bottom_color)
        add_side_marker(first_row_xy, top_and_bottom_color)

    # Side markers along the first and last columns (corner tiles excluded).
    for row in range(1, size-1):
        y = (size-(row+1))*vertical_spacing
        add_side_marker((row * horizontal_spacing/2, y), left_and_right_color)
        add_side_marker((row * horizontal_spacing/2 + (size-1) * horizontal_spacing, y), left_and_right_color)

    # The tiles: each successive row is shifted right by half a tile and drawn one step lower.
    for row in range(size):
        for col in range(size):
            x = col * horizontal_spacing + row*horizontal_spacing/2
            y = (size-(row+1)) * vertical_spacing
            color_key = color_keys[row * size + col]
            if probability_mode:
                color = plt.cm.plasma(color_key)
            else:
                color = colors[np.argmax(color_key)]
            ax.add_patch(hexagon_at((x, y), edgecolor='k', facecolor=color, lw=1, zorder=4))

    ax.autoscale_view()
    plt.savefig('triplestate.pdf' if not probability_mode else 'tripleaction.pdf')
    plt.show()

if params['new_hex_grid']:
    # One-hot "all empty" board, only used by ad-hoc plotting experiments.
    to_plot = torch.tile(torch.tensor([1,0,0]), (num_tiles, 1))
    # Visual sanity check on sample 80: its board, then its target move distribution as a heat map...
    plot_hexagonal_grid_with_colors(board_width, states[80], rotated=False)
    plot_hexagonal_grid_with_colors(board_width, actions[80], probability_mode=True)
    # ...and its outcome label for the player to move.
    print(outcomes[80])

In [None]:
#@title Build Datasets
class MoHexGamesDataset(Dataset):
    """Map-style dataset over three parallel lists of states, actions and outcomes."""

    def __init__(self, states, actions, outcomes):
        # The lists are indexed together, so they must line up one-to-one.
        assert len(states) == len(actions) == len(outcomes)
        self.states, self.actions, self.outcomes = states, actions, outcomes

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``(state, action, outcome)`` tuple."""
        return self.states[idx], self.actions[idx], self.outcomes[idx]

# Seed the global generator so the random_split calls below are reproducible.
torch.manual_seed(0)

if params['new_hex_grid'] or params['new_cnn'] or params['load_cnn'] or params['load_hex_grid']:
    # 90/10 train/test split of the grid triples.
    train_length = int(len(actions)*0.9)
    dataset = MoHexGamesDataset(states, actions, outcomes)
    training_data, test_data = random_split(dataset, [train_length, len(dataset)-train_length])
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only runtime it just warns.
    kwargs = {'num_workers': 1, 'pin_memory': device.type == 'cuda'}
    training_dataloader = DataLoader(training_data, batch_size=params['batch_size'], shuffle=True, **kwargs)
    test_dataloader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=True, **kwargs)

if params['new_det_shannon'] or params['new_rni_shannon'] or params['new_sage'] or params['load_det_shannon'] or params['load_rni_shannon'] or params['load_sage']:
    # 90/10 train/test split of the Shannon graphs; the PyG loader batches graphs together.
    train_length = int(len(shannon_graphs)*0.9)
    shannon_training_data, shannon_test_data = random_split(shannon_graphs, [train_length, len(shannon_graphs)-train_length])
    kwargs = {'num_workers': 1, 'pin_memory': device.type == 'cuda'}
    shannon_training_dataloader = ShannonDataLoader(shannon_training_data, batch_size=params['batch_size'], shuffle=True, **kwargs)
    shannon_test_dataloader = ShannonDataLoader(shannon_test_data, batch_size=params['batch_size'], shuffle=True, **kwargs)


# Networks

In [None]:
# Not used by GNNLayer's active code (it appeared only in a disabled weight initialisation).
GNN_SCALING_FACTOR = 13.9123535

class GNNLayer(nn.Module):
    """Dense message-passing layer: self_map(h) + neighbour_map(adj_matrix @ h).

    Both maps are hidden_dim -> hidden_dim affine layers with bias.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.my_linear_map        = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.neighbour_linear_map = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.output_values = []

    def forward(self, node_embeddings, adj_matrix):
        # For a 0/1 adjacency this is the sum of each node's neighbour embeddings.
        aggregated = adj_matrix @ node_embeddings
        return self.my_linear_map(node_embeddings) + self.neighbour_linear_map(aggregated)


In [None]:
#@title MLPs Within Layers
class MLP(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Linear with a hidden layer twice as wide as the input.

    Maps hidden_dim features back to hidden_dim. BatchNorm1d normalises dim 1, so this
    assumes 2-D (N, hidden_dim) input.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        wide_dim = 2 * hidden_dim
        self.first_layer = nn.Linear(hidden_dim, wide_dim)
        self.batch_norm  = nn.BatchNorm1d(wide_dim)
        self.last_layer  = nn.Linear(wide_dim, hidden_dim)

    def forward(self, x):
        hidden = F.relu(self.batch_norm(self.first_layer(x)))
        return self.last_layer(hidden)

class MessageMLP(nn.Module):
    """Linear(input_dim -> hidden_dim) -> BatchNorm1d -> ReLU -> Linear(hidden_dim -> output_dim)."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.first_layer = nn.Linear(input_dim, hidden_dim)
        self.batch_norm  = nn.BatchNorm1d(hidden_dim)
        self.last_layer  = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Apply the four stages in sequence.
        for stage in (self.first_layer, self.batch_norm, F.relu, self.last_layer):
            x = stage(x)
        return x

class TransposeMLP(nn.Module):
    """MessageMLP variant for inputs whose feature axis is last, e.g. (batch, nodes, features).

    BatchNorm1d normalises dim 1, so the hidden activations are swapped to
    (..., hidden_dim, nodes) around the norm and swapped back afterwards.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.first_layer = nn.Linear(input_dim, hidden_dim)
        self.batch_norm  = nn.BatchNorm1d(hidden_dim)
        self.last_layer  = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.first_layer(x)
        hidden = self.batch_norm(hidden.transpose(-2, -1)).transpose(-2, -1)
        return self.last_layer(F.relu(hidden))

In [None]:
#@title Base Message Passing Layer
class MPNNLayer(MessagePassing):
    """Message-passing layer: out_v = self_upd(h_v) + mlp_msg(AGG over edges u -> v of h_u).

    Messages are the raw neighbour embeddings ``h_j``, reduced with ``aggr`` (default
    'add'). The node's own embedding and the aggregate each pass through their own
    small MLP before being summed. ``transpose=True`` selects TransposeMLP
    (feature axis last) instead of MessageMLP.
    """
    def __init__(self, input_dim, output_dim, transpose=False, aggr='add'):
        super().__init__(aggr=aggr)
        self.emb_dim = output_dim  # reported by __repr__
        mlp_cls = TransposeMLP if transpose else MessageMLP
        self.mlp_msg = mlp_cls(input_dim, output_dim, output_dim)
        self.self_upd = mlp_cls(input_dim, output_dim, output_dim)

    def forward(self, h, edge_index):
        return self.propagate(edge_index, h=h)

    def message(self, h_i, h_j):
        return h_j

    def aggregate(self, inputs, index, dim_size=None):
        # Pass dim_size through so nodes without incoming edges (including trailing ones)
        # still get a zero row. Without it, scatter sizes its output by the largest target
        # index present, and the result can be shorter than `h` in update().
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

    def update(self, aggr_out, h):
        return self.self_upd(h) + self.mlp_msg(aggr_out)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})'



In [None]:
#@title GIN Layer
class GINLayer(nn.Module):
    """GIN-style layer on a dense adjacency: MLP((1 + eps) * h + adj_matrix @ h).

    ``eps`` is a learnable scalar parameter initialised to 0.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.mlp = MLP(hidden_dim)
        self.eps = nn.Parameter(torch.zeros(1))

    def forward(self, node_embeddings, adj_matrix):
        neighbour_sum = torch.matmul(adj_matrix, node_embeddings)
        self_term = (1 + self.eps) * node_embeddings
        return self.mlp(self_term + neighbour_sum)


# Modules


In [None]:
#@title Shannon Module
class ShannonModule(nn.Module):
    """GNN over Shannon graphs with a per-node policy head and a pooled per-graph value head.

    Args:
        layer_type: message-passing layer class, built as ``layer_type(hidden_dim, hidden_dim)``
            and called as ``layer(h, edge_index)``.
        num_layers: number of message-passing layers.
        input_dim: size of the raw node features.
        hidden_dim: embedding width.
        drop_ratio: dropout probability applied after each layer.
        random_dims: number of N(0, 1) features appended to every node on each forward
            pass; 0 disables this.
        use_batch_norm: apply BatchNorm1d after each layer.
        use_jump_and_skip_connections: add residual connections and sum all layer outputs.

    ``forward(x, edge_index, batch)`` returns ``(policy, values)``: one logit per node,
    and one tanh-squashed value per graph id in ``batch``.
    """
    def __init__(self, layer_type, num_layers, input_dim, hidden_dim, drop_ratio, random_dims=0, use_batch_norm=True, use_jump_and_skip_connections=True):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.use_jump_and_skip_connections = use_jump_and_skip_connections
        self.random_dims = random_dims
        self.act_fun = F.relu

        self.drop_ratio = drop_ratio
        self.first_layer = nn.Linear(input_dim+random_dims, hidden_dim)

        self.hidden_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if self.use_batch_norm else None
        for _ in range(num_layers):
            self.hidden_layers.append(layer_type(hidden_dim, hidden_dim))
            if self.use_batch_norm: self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.num_layers = len(self.hidden_layers)

        self.policy_layer = MessageMLP(hidden_dim, hidden_dim, 1)
        # Value head over the concatenated [sum, max, mean] graph readouts.
        self.head_layer = nn.Sequential(
            nn.Linear(3*hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index, batch):
        if self.random_dims > 0:
            # Append fresh random features to every node, created on x's own device/dtype.
            random_noise = torch.randn(x.shape[0], self.random_dims, device=x.device, dtype=x.dtype)
            x = torch.cat([x, random_noise], dim=1)
        x = self.first_layer(x)
        h_list = []

        for i, l in enumerate(self.hidden_layers):
            prev_h = h_list[-1] if i > 0 else x
            h = l(prev_h, edge_index)
            if self.use_batch_norm: h = self.batch_norms[i](h)

            # No ReLU after the final layer.
            if i != len(self.hidden_layers)-1:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            else:
                h = F.dropout(h, self.drop_ratio, training = self.training)

            if self.use_jump_and_skip_connections: h = prev_h + h
            h_list.append(h)

        # Jumping knowledge: sum the outputs of every layer (or take only the last one).
        node_representations = sum(h_list) if self.use_jump_and_skip_connections else h_list[-1]

        # Graph-level readout. Each pooling uses its named reduction; these were previously
        # all 'sum'. NOTE(review): checkpoints trained with the all-'sum' readout will behave
        # differently with this one.
        pooled_sums_by_graph  = scatter(node_representations, batch, dim=0, reduce='sum')
        pooled_means_by_graph = scatter(node_representations, batch, dim=0, reduce='mean')
        pooled_maxs_by_graph  = scatter(node_representations, batch, dim=0, reduce='max')
        pooled_representations = torch.cat([pooled_sums_by_graph, pooled_maxs_by_graph, pooled_means_by_graph], dim=-1)

        # Squeeze only the trailing singleton so a batch of one graph still yields shape (1,).
        values = torch.tanh(self.head_layer(pooled_representations)).squeeze(-1)

        policy = torch.squeeze(self.policy_layer(node_representations), dim=-1)
        return (policy, values)



In [None]:
#@title Hex Grid Module
class HexGridModule(nn.Module):
    """Dense message-passing network over the hex board, with policy and value heads.

    ``forward`` pools over ``dim=1`` and transposes the last two dims around
    ``BatchNorm1d``, which implies dense node tensors shaped
    ``[batch, nodes, features]`` (TODO confirm against the data pipeline).

    Args:
        num_layers: number of ``MPNNLayer`` message-passing layers.
        input_dim: input features per node (before any random features).
        hidden_dim: width of every hidden representation.
        drop_ratio: dropout probability applied after each hidden layer.
        random_dims: Gaussian random features appended to each node on every
            forward pass (0 disables random node initialisation).
        use_batch_norm: apply ``BatchNorm1d`` over hidden channels after each layer.
        use_jump_and_skip_connections: add residual connections between layers and
            sum all layer outputs into the final node representation.
    """
    def __init__(self, num_layers, input_dim, hidden_dim, drop_ratio, random_dims=0, use_batch_norm=True, use_jump_and_skip_connections=True):
        super().__init__()
        self.use_batch_norm = use_batch_norm
        self.use_jump_and_skip_connections = use_jump_and_skip_connections
        self.random_dims = random_dims
        # Exposed as an attribute; forward itself calls F.relu directly.
        self.act_fun = F.relu

        self.drop_ratio = drop_ratio
        self.first_layer = nn.Linear(input_dim+random_dims, hidden_dim)

        self.hidden_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if self.use_batch_norm else None
        for _ in range(num_layers):
            self.hidden_layers.append(MPNNLayer(hidden_dim, hidden_dim, transpose=True))
            if self.use_batch_norm: self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.num_layers = len(self.hidden_layers)

        self.policy_layer = TransposeMLP(hidden_dim, hidden_dim, 1)
        # Value head over the concatenated [sum, max, mean] pooled features.
        self.head_layer = nn.Sequential(
            nn.Linear(3*hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index):
        """Return ``(policy, values)`` for a batch of dense board graphs.

        Args:
            x: node features with the feature axis last.
            edge_index: connectivity passed to every ``MPNNLayer``.

        Returns:
            ``policy``: ``policy_layer`` output with its last dim squeezed
            (presumably one logit per node); ``values``: one ``tanh`` value per
            board, shape ``[batch]``.
        """
        if self.random_dims > 0:
            # Random node initialisation: one noise vector per node, appended on
            # the feature (last) axis. Shaping the noise from x.shape[:-1] works
            # for both [nodes, F] and [batch, nodes, F] inputs.
            random_noise = torch.randn(*x.shape[:-1], self.random_dims, device=x.device)
            x = torch.cat([x, random_noise], dim=-1)
        x = self.first_layer(x)
        h_list = []

        for i, l in enumerate(self.hidden_layers):
            prev_h = h_list[-1] if i > 0 else x

            h = l(prev_h, edge_index)
            if self.use_batch_norm:
                # BatchNorm1d normalises dim 1, so move the hidden channels there
                # and back.
                h = torch.transpose(h, -2, -1)
                h = self.batch_norms[i](h)
                h = torch.transpose(h, -2, -1)

            # No ReLU after the final hidden layer, so its output stays signed.
            if i != len(self.hidden_layers)-1:
                h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
            else:
                h = F.dropout(h, self.drop_ratio, training = self.training)

            if self.use_jump_and_skip_connections: h = prev_h + h
            h_list.append(h)

        if self.use_jump_and_skip_connections:
            node_representations = sum(h_list)
        else:
            node_representations = h_list[-1]

        # Board readout: sum, max and mean over the node axis.
        pooled_sum = torch.sum(node_representations, dim=1)
        pooled_max, _ = torch.max(node_representations, dim=1)
        pooled_mean = torch.mean(node_representations, dim=1)
        pooled_representations = torch.cat([pooled_sum, pooled_max, pooled_mean], dim=-1)

        values = self.head_layer(pooled_representations)
        # Squeeze only the trailing dim so a batch of one keeps shape [1].
        values = torch.squeeze(torch.tanh(values), dim=-1)
        policy = torch.squeeze(self.policy_layer(node_representations), dim=-1)

        return (policy, values)


In [None]:
def forward(self, x, adj_matrix, positions=None):
    """Dense-GNN forward pass defined at module level (not bound to any class).

    ``self`` must provide ``first_layer``, ``hidden_layers``, ``batch_norms``,
    ``drop_ratio``, ``training``, ``policy_layer`` and ``head_mlp``.

    Args:
        x: node features; pooling over ``dim=1`` implies
            ``[batch, nodes, features]`` (TODO confirm).
        adj_matrix: adjacency passed to every hidden layer.
        positions: optional extra per-node input. When given, each layer is
            called as ``layer(h, positions, adj_matrix)`` instead of
            ``layer(h, adj_matrix)``.

    Returns:
        ``torch.cat((policy, value), dim=-1)``: per-node policy outputs followed by
        a ``tanh`` value as the last entry.
    """
    x = self.first_layer(x)
    h_list = []

    for i, l in enumerate(self.hidden_layers):
        prev_h = h_list[-1] if i > 0 else x
        if positions is None:
            h = l(prev_h, adj_matrix)
        else:
            h = l(prev_h, positions, adj_matrix)

        # BatchNorm1d normalises dim 1: move hidden channels there and back.
        h = torch.transpose(h, -2, -1)
        h = self.batch_norms[i](h)
        h = torch.transpose(h, -2, -1)

        # No ReLU after the final hidden layer, so its output stays signed.
        if i != len(self.hidden_layers)-1:
            h = F.dropout(F.relu(h), self.drop_ratio, training = self.training)
        else:
            h = F.dropout(h, self.drop_ratio, training = self.training)

        h = prev_h + h
        h_list.append(h)

    # Jumping knowledge: sum of every layer's output.
    node_representations = sum(h_list)

    pooled_sum = torch.sum(node_representations, dim=1)
    pooled_max, _ = torch.max(node_representations, dim=1)
    pooled_mean = torch.mean(node_representations, dim=1)
    pooled_representation = torch.cat([pooled_sum, pooled_max, pooled_mean], dim=-1)

    policy = torch.squeeze(self.policy_layer(node_representations), dim=-1)

    value = torch.tanh(self.head_mlp(pooled_representation))

    return torch.cat((policy, value), dim=-1)


In [None]:
#@title CNN Module
# Blue too low means factor too high, blue too high means factor too low.
# (4.125 was also tried; 4 is the value in effect.)
CNN_SCALING_FACTOR = 4

class CNNModule(nn.Module):
    """Convolutional baseline with the same policy/value head structure as the graph models.

    All hyper-parameters come from the global ``cnn_params`` dict.
    - The trunk is a 5x5 input conv followed by ``num_layers - 2`` 3x3 convs, all
      padded to preserve the board size.
    - A final 1x1 conv to one channel is also built and kept in ``self.layers``,
      so it counts toward ``num_layers`` and the parameter count.
    - ``forward`` does not apply that final conv.
    """
    def __init__(self):
        super().__init__()
        self.name = 'CNN'
        self.act_fun = F.relu
        self.layers = nn.ModuleList()
        self.drop_ratio = cnn_params['drop_ratio']

        self.layers.append(
            nn.Conv2d(
                in_channels=cnn_params['input_dim'],
                out_channels=cnn_params['hidden_dim'],
                kernel_size=5,
                stride=1,
                padding=2
            )
        )

        for _ in range(cnn_params['num_layers']-2):
            self.layers.append(
                nn.Conv2d(
                    in_channels=cnn_params['hidden_dim'],
                    out_channels=cnn_params['hidden_dim'],
                    kernel_size=3,
                    stride=1,
                    padding=1
                )
            )

        self.layers.append(
            nn.Conv2d(
                in_channels=cnn_params['hidden_dim'],
                out_channels=1,
                kernel_size=1,
                stride=1,
            )
        )

        self.num_layers = len(self.layers)

        self.policy_layer = TransposeMLP(cnn_params['hidden_dim'], cnn_params['hidden_dim'], 1)
        # Value head over the concatenated [sum, max, mean] pooled features.
        self.head_layer = nn.Sequential(
            nn.Linear(3*cnn_params['hidden_dim'], cnn_params['hidden_dim']),
            nn.ReLU(),
            nn.Linear(cnn_params['hidden_dim'], 1)
        )

    def forward(self, x):
        """Return ``(policy, values)`` for a batch of board images.

        Args:
            x: channels-first image batch ``[batch, cnn_params['input_dim'], H, W]``
                (required by ``Conv2d``).

        Returns:
            ``policy``: ``policy_layer`` output per board cell with its last dim
            squeezed; ``values``: one ``tanh`` value per board, shape ``[batch]``.
        """
        # The final 1x1 conv in self.layers is intentionally not applied.
        for layer in self.layers[:-1]:
            x = layer(x)
            x = self.act_fun(x)
            x = F.dropout(x, self.drop_ratio, training = self.training)

        # [batch, hidden, H, W] -> [batch, H*W, hidden]: one row per board cell.
        # A bare reshape to (-1, num_tiles, hidden) would interleave channel and
        # spatial values, so flatten the spatial dims and swap them with channels.
        x = torch.flatten(x, start_dim=2).transpose(1, 2).contiguous()

        pooled_sum = torch.sum(x, dim=1)
        pooled_max, _ = torch.max(x, dim=1)
        pooled_mean = torch.mean(x, dim=1)
        pooled_representation = torch.cat([pooled_sum, pooled_max, pooled_mean], dim=-1)

        values = self.head_layer(pooled_representation)
        # Squeeze only the trailing dim so a batch of one keeps shape [1].
        values = torch.squeeze(torch.tanh(values), dim=-1)
        policy = torch.squeeze(self.policy_layer(x), dim=-1)

        return (policy, values)


# Helper Functions

In [None]:
#@title Model Wrapper
class Model:
    """Wrap a module and its optimiser together with per-epoch metrics and early stopping.

    Args:
        name: display name. Other cells dispatch on it (e.g. ``'CNN'``,
            ``'HexGridGNN'``).
        module: the network being trained.
        optimiser: optimiser over ``module``'s parameters.
        early_stop_threshold: number of consecutive epochs without a new lowest
            combined test loss after which training stops. ``None`` (default)
            reads ``params['early_stop_threshold']``.
    """
    def __init__(self, name, module, optimiser, early_stop_threshold=None):
        self.name = name
        self.module = module
        self.optimiser = optimiser
        self.epochs_trained = 0
        self.performance_degradation_streak = 0
        if early_stop_threshold is None:
            early_stop_threshold = params['early_stop_threshold']
        self.early_stop_threshold = early_stop_threshold

        self.test_losses = {'action':[], 'outcome':[], 'combined':[]}
        self.test_accuracies = {'action':[], 'outcome':[]}

        self.training_losses = {'action':[], 'outcome':[], 'combined':[]}
        self.training_accuracies = {'action':[], 'outcome':[]}

        # +inf so the first real test loss always counts as an improvement,
        # however large it is.
        self.lowest_combined_loss = float('inf')

        self.continue_training = self.performance_degradation_streak < self.early_stop_threshold

    def update(self, new_test_losses, new_test_accuracies, new_training_losses, new_training_accuracies):
        """Record one epoch of metrics and advance the early-stopping state.

        Metrics are appended only while ``continue_training`` is True. The
        early-stopping streak is driven by ``new_test_losses['combined']``:
        - it resets to 0 on a new lowest value;
        - it increments when the value is not lower;
        - training stops once the streak reaches ``early_stop_threshold``.
        """
        if self.continue_training:
            self.epochs_trained += 1

            for key in ('action', 'outcome', 'combined'):
                self.test_losses[key].append(new_test_losses[key])
                self.training_losses[key].append(new_training_losses[key])
            for key in ('action', 'outcome'):
                self.test_accuracies[key].append(new_test_accuracies[key])
                self.training_accuracies[key].append(new_training_accuracies[key])

        # Use the argument, not a same-named global from the training loop.
        new_combined_loss = new_test_losses['combined']
        # A NaN loss matches neither comparison and leaves the streak unchanged.
        if new_combined_loss >= self.lowest_combined_loss:
            self.performance_degradation_streak += 1
        elif new_combined_loss < self.lowest_combined_loss:
            self.lowest_combined_loss = new_combined_loss
            self.performance_degradation_streak = 0

        if self.performance_degradation_streak >= self.early_stop_threshold:
            self.continue_training = False


In [None]:
#@title Build and Load Models
all_models = []

if params['new_hex_grid']:
    mpnn_module = HexGridModule(num_layers=hex_grid_params['num_layers'],
                                input_dim=hex_grid_params['input_dim'],
                                hidden_dim=hex_grid_params['hidden_dim'],
                                drop_ratio=hex_grid_params['drop_ratio'],
                                use_batch_norm=True,
                                use_jump_and_skip_connections=False).to(device)
    mpnn_optimiser = torch.optim.Adam(mpnn_module.parameters(), lr=hex_grid_params['learning_rate'], weight_decay=hex_grid_params['weight_decay'])
    all_models.append(Model('HexGridGNN', mpnn_module, mpnn_optimiser))

if params['new_det_shannon']:
    det_shannon_module = ShannonModule(layer_type=MPNNLayer,
                                       num_layers=det_shannon_params['num_layers'],
                                       input_dim=det_shannon_params['input_dim'],
                                       hidden_dim=det_shannon_params['hidden_dim'],
                                       drop_ratio=det_shannon_params['drop_ratio'],
                                       use_batch_norm=True,
                                       use_jump_and_skip_connections=False).to(device)
    det_shannon_optimiser = torch.optim.Adam(det_shannon_module.parameters(), lr=det_shannon_params['learning_rate'], weight_decay=det_shannon_params['weight_decay'])
    all_models.append(Model('ShannonGNN', det_shannon_module, det_shannon_optimiser))

if params['new_rni_shannon']:
    rni_shannon_module = ShannonModule(layer_type=MPNNLayer,
                                       num_layers=rni_shannon_params['num_layers'],
                                       input_dim=rni_shannon_params['input_dim'],
                                       hidden_dim=rni_shannon_params['hidden_dim'],
                                       drop_ratio=rni_shannon_params['drop_ratio'],
                                       random_dims=rni_shannon_params['random_dims'],
                                       use_batch_norm=True,
                                       use_jump_and_skip_connections=False).to(device)

    rni_shannon_optimiser = torch.optim.Adam(rni_shannon_module.parameters(),
                                             lr=rni_shannon_params['learning_rate'],
                                             weight_decay=rni_shannon_params['weight_decay'])

    all_models.append(Model('ShannonGNN-RNI', rni_shannon_module, rni_shannon_optimiser))

if params['new_cnn']:
    cnn_module = CNNModule().to(device)
    cnn_optimiser = torch.optim.Adam(cnn_module.parameters(), lr=cnn_params['learning_rate'], weight_decay=cnn_params['weight_decay'])
    all_models.append(Model('CNN', cnn_module, cnn_optimiser))

if params['new_sage']:
    sage_module = ShannonModule(layer_type = SAGEConv,
                            num_layers = sage_params['num_layers'],
                            input_dim = sage_params['input_dim'],
                            hidden_dim = sage_params['hidden_dim'],
                            drop_ratio = sage_params['drop_ratio'],
                            use_batch_norm=False,
                            use_jump_and_skip_connections=False,
                            ).to(device)
    sage_optimiser = torch.optim.Adam(sage_module.parameters(), lr=sage_params['learning_rate'], weight_decay=sage_params['weight_decay'])
    all_models.append(Model('GraphSAGE', sage_module, sage_optimiser))

# Append checkpointed models rather than resetting `all_models`, which would
# silently discard every model built above. The `new_*` flags already control
# which fresh models are created.
# Checkpoints are whole pickled `Model` objects (see save_model), so torch.load
# needs weights_only=False. Unpickling can execute arbitrary code: only load
# files you trust. map_location lets GPU-saved checkpoints load on a CPU runtime.
for filename in model_checkpoint_filenames_to_load:
    model = torch.load(filename, map_location=device, weights_only=False)
    all_models.append(model)


In [None]:
#@title Plot Initial Layer Outputs
def get_weights_by_layer(model):
    """Collect flattened weight values grouped by top-level layer.

    Args:
        model: object with an iterable ``layers`` attribute of ``nn.Module``s.

    Returns:
        One list per layer. Each list holds the flattened ``weight`` values of
        that layer and all of its sub-modules, as NumPy scalars.
    """
    weights_by_layer = []
    for layer in model.layers:
        weights_this_layer = []
        for child in layer.modules():
            weight = getattr(child, 'weight', None)
            # Some modules define `weight` but set it to None
            # (e.g. BatchNorm1d(affine=False)); skip those.
            if isinstance(weight, torch.Tensor):
                weights_this_layer.extend(weight.detach().reshape(-1).cpu().numpy())
        weights_by_layer.append(weights_this_layer)
    return weights_by_layer

def plot_weights_by_layer_histogram(weights_by_layer, title, bins=None):
    """Overlay one normalised step histogram per layer on the current figure and show it.

    Args:
        weights_by_layer: sequence of per-layer value collections. Layer ``i`` is
            labelled ``i + 1`` in the legend.
        title: figure title.
        bins: histogram bins.
            - ``None`` (default): 10 bins for every layer.
            - int or str: applied to every layer.
            - any other sequence: one entry per layer.
    """
    if bins is None:
        bins = 10
    applies_to_all_layers = isinstance(bins, (int, np.integer, str))
    for i, weights in enumerate(weights_by_layer):
        layer_bins = bins if applies_to_all_layers else bins[i]
        plt.hist(weights, bins=layer_bins, label=str(i + 1), histtype='step', density=True)
    plt.title(title)
    plt.legend()
    plt.show()

def get_outputs_by_layer(model):
    """Sample each layer's raw (pre-activation) outputs over the training set.

    Every training batch is pushed through ``model.module.layers`` one layer at
    a time, with ``model.module.act_fun`` applied between layers. For each layer
    and batch, about ``batch_size // num_layers`` output values (at least 1) are
    sampled with replacement.

    Args:
        model: ``Model`` wrapper whose ``module`` exposes ``layers`` and ``act_fun``.

    Returns:
        One flat list of sampled output values per layer.
    """
    layers = model.module.layers
    outputs_by_layer = [[] for _ in range(len(layers))]
    # Diagnostics only: don't build an autograd graph for every batch.
    with torch.no_grad():
        for data in training_dataloader:
            x = data[0].to(device)
            if model.name == 'CNN':
                # The loader's tile features are channels-last: the same tensor
                # feeds HexGridModule's Linear over its last dim (TODO confirm).
                # Conv2d needs [batch, channels, H, W], so permute the feature
                # axis forward instead of reinterpreting memory with a reshape.
                x = x.reshape(-1, board_width, board_width, cnn_params['input_dim']).permute(0, 3, 1, 2)
            for i, layer in enumerate(layers):
                # GNN layers also take the adjacency matrix; others take x alone.
                x = layer(x, adj_matrix) if model.name == 'GNN' else layer(x)
                output_values_array = torch.reshape(x, (-1,)).detach().cpu().numpy()
                num_to_sample = max(len(x)//len(layers), 1)
                outputs_by_layer[i].extend(np.random.choice(output_values_array, num_to_sample))
                # Activate before feeding the next layer.
                x = model.module.act_fun(x)

    return outputs_by_layer

# Optionally visualise each untrained model's per-layer output distribution.
if params['plot_initial_outputs']:
    for model in all_models:
        initial_outputs_by_layer = get_outputs_by_layer(model)
        histogram_title = f'{model.name} Initial Outputs by Layer'
        plot_weights_by_layer_histogram(initial_outputs_by_layer, title=histogram_title)


In [None]:
#@title Plot all Models
def plot_results(title, xlabel, ylabel, metric, loss_mode = True, models=None):
    """Plot training (solid) and test (dashed) curves of one metric for each model, then save a PDF.

    Args:
        title, xlabel, ylabel: figure labels.
        metric: key into each model's loss/accuracy dicts (``'action'``,
            ``'outcome'``, and for losses also ``'combined'``).
        loss_mode: plot losses when True, accuracies when False.
        models: models to plot; defaults to the global ``all_models``.

    The PDF goes into the Drive ``diagrams`` folder, which is created if missing.
    Its name records the longest history length, the quantity and the metric.
    A random suffix makes collisions between runs unlikely.
    """
    import os

    if models is None:
        models = all_models

    max_e = 0
    for model in models:
        if loss_mode:
            training_values, test_values = model.training_losses[metric], model.test_losses[metric]
        else:
            training_values, test_values = model.training_accuracies[metric], model.test_accuracies[metric]
        max_e = max(max_e, len(training_values), len(test_values))
        training_line = plt.plot(range(1, len(training_values)+1), training_values, label=f'{model.name} training')[0]
        # Test curve shares the training curve's colour.
        plt.plot(range(1, len(test_values)+1), test_values, '--', label=f'{model.name} test', color=training_line.get_color())
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    quantity = 'losses' if loss_mode else 'accuracies'
    filename = f'drive/MyDrive/Colab Notebooks/hexbot/diagrams/E{max_e}-{quantity}-{metric}-{random.randint(1,1000)}.pdf'
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    plt.savefig(filename)
    plt.show()

def display_all_model_results():
    """Draw every loss curve (policy, value, combined) and then every accuracy curve (policy, value)."""
    loss_figures = (
        ('Training and Test Policy Loss by Epoch', 'action'),
        ('Training and Test Value Loss by Epoch', 'outcome'),
        ('Training and Test Combined Loss by Epoch', 'combined'),
    )
    accuracy_figures = (
        ('Training and Test Policy Accuracy by Epoch', 'action'),
        ('Training and Test Value Accuracy by Epoch', 'outcome'),
    )
    for figure_title, metric_key in loss_figures:
        plot_results(figure_title, 'Epoch', 'Loss', metric_key, True)
    for figure_title, metric_key in accuracy_figures:
        plot_results(figure_title, 'Epoch', 'Accuracy', metric_key, False)


In [None]:
#@title Model Saver
def save_model(model):
    """Pickle the whole ``Model`` wrapper (module, optimiser, history) with ``torch.save``.

    The filename records:
    - the best test action and outcome accuracies (0.0 if none are recorded yet);
    - the number of epochs trained;
    - the model name and the module's ``num_layers``.

    For models named ``'GNN'``/``'GIN'``, a TorchScript trace of the module is
    also saved next to the checkpoint with a ``_traced.pt`` suffix.
    """
    import os

    # default= avoids max() raising on a model that has not been tested yet.
    best_action = round(max(model.test_accuracies["action"], default=0.0), 2)
    best_outcome = round(max(model.test_accuracies["outcome"], default=0.0), 2)
    filename = f'drive/MyDrive/Colab Notebooks/hexbot/temp/action-{best_action}_E{model.epochs_trained}_{model.name}_D{model.module.num_layers}_outcome{best_outcome}.pt'
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(model, filename)

    if model.name == 'GNN' or model.name == 'GIN':
        example = torch.rand(1, num_tiles, gnn_params['input_dim']).to(device)
        traced_script_module = torch.jit.trace(model.module, example_kwarg_inputs={'edge_index':adj_matrix, 'x':example})
        # Persist the trace so the module can be loaded without its Python class.
        traced_script_module.save(filename[:-len('.pt')] + '_traced.pt')


# Training

In [None]:
#@title Get Model Outputs on Data
def get_model_outputs_on_data(model, x, true_actions, true_outcomes):
    """Run a dense-input model (``'CNN'`` or ``'HexGridGNN'``) on a batch and score it.

    Args:
        model: ``Model`` wrapper; ``model.name`` selects how inputs are fed.
        x: batch of per-tile features with the feature axis last.
        true_actions: per-example target distribution over moves; its argmax
            along dim 1 is the target move.
        true_outcomes: per-example game outcome, compared with ``sign(value)``.

    Returns:
        ``(losses, accuracies)``.
        - ``losses`` has ``'action'``, ``'outcome'`` and ``'combined'``, where
          combined is ``loss_weight * action + (1 - loss_weight) * outcome``.
        - ``accuracies`` has ``'action'`` and ``'outcome'`` as fractions of the batch.

    Raises:
        ValueError: if ``model.name`` is not a supported dense-input model.
    """
    if model.name == 'CNN':
        # Tile features are channels-last: HexGridModule applies its Linear to
        # the last dim of the same tensor (TODO confirm). Conv2d needs
        # [batch, channels, H, W], so permute the feature axis forward; a bare
        # reshape would mix features from different tiles.
        x = x.reshape(-1, board_width, board_width, cnn_params['input_dim']).permute(0, 3, 1, 2)
        policy_head_output, value_head_output = model.module(x)
    elif model.name == 'HexGridGNN':
        policy_head_output, value_head_output = model.module(x, global_edge_index)
    else:
        raise ValueError(f'get_model_outputs_on_data does not support model {model.name!r}')

    losses = {}
    losses['action']   = params['policy_loss_fn'](policy_head_output, true_actions)
    losses['outcome']  = params['value_loss_fn'](value_head_output, true_outcomes.float())
    losses['combined'] = torch.add(params['loss_weight'] * losses['action'], (1-params['loss_weight']) * losses['outcome'])

    # Predicted move = highest-logit move per example. A softmax over dim 0 (the
    # batch) would rescale each move column differently and change this argmax.
    action_votes  = torch.argmax(policy_head_output, dim=1)
    outcome_votes = torch.sign(value_head_output)

    true_actions = torch.argmax(true_actions, dim=1)

    actions_predicted_correctly = torch.sum(true_actions == action_votes).item()
    outcomes_predicted_correctly = torch.sum(true_outcomes == outcome_votes).item()
    accuracies = {
        'action':actions_predicted_correctly/len(x),
        'outcome':outcomes_predicted_correctly/len(x)
    }

    return losses, accuracies


In [None]:
#@title Get Shannon Model Outputs on Data
def sum_pool(x, node_to_graph_map):
    """Sum node rows of ``x`` that share a graph id.

    Args:
        x: node feature tensor, pooled along dim 0.
        node_to_graph_map: graph id for each row of ``x``.

    Returns:
        One summed row per graph id (``scatter`` sum over dim 0).
    """
    pooled_by_graph = scatter(x, node_to_graph_map, dim=0, reduce="sum")
    return pooled_by_graph

def get_shannon_model_outputs_on_data(model, x, edge_index, true_actions, true_outcomes, first_index_of_each_batch=None, batch=None):
    """Run a graph-batched model on one batch of graphs and score it.

    Args:
        model: ``Model`` wrapper whose module takes ``(x, edge_index, batch)``.
        x: batched node features.
        edge_index: batched connectivity.
        true_actions: one target per graph. It is indexed per graph below and
            compared with each graph's argmax node (presumably a node index;
            TODO confirm against the dataset).
        true_outcomes: one outcome per graph, compared with ``sign(value)``.
        first_index_of_each_batch: node offsets of the graphs in a PyG ``ptr``
            layout ``[0, n_0, n_0 + n_1, ..., N]``. When ``None``, graph sizes are
            counted from ``batch`` instead.
        batch: graph id of every node.

    Returns:
        ``(losses, accuracies)``.
        - ``losses['action']`` is the mean per-graph policy loss.
        - ``losses['combined']`` weights action and outcome by ``params['loss_weight']``.
        - ``accuracies`` are fractions of graphs.

    Raises:
        ValueError: if neither offsets nor ``batch`` are given, or the batch has no graphs.
    """
    policy_head_output, value_head_output = model.module(x, edge_index, batch)

    # Split the flat per-node policy into one logit vector per graph.
    if first_index_of_each_batch is not None:
        graph_sizes = torch.diff(torch.as_tensor(first_index_of_each_batch)).tolist()
    elif batch is not None:
        graph_sizes = torch.bincount(batch).tolist()
    else:
        raise ValueError('first_index_of_each_batch or batch is required to split the policy by graph')
    if not graph_sizes:
        raise ValueError('cannot score an empty batch of graphs')
    policy_by_graph = torch.split(policy_head_output, graph_sizes)

    per_graph_action_losses = [params['policy_loss_fn'](policy, true_actions[i]) for i, policy in enumerate(policy_by_graph)]
    losses = {'action': torch.stack(per_graph_action_losses).mean()}
    # Predicted move: each graph's highest-logit node. Softmax is monotonic, so
    # it is not needed for the argmax.
    action_votes = torch.stack([torch.argmax(policy, dim=-1) for policy in policy_by_graph])

    losses['outcome']  = params['value_loss_fn'](value_head_output, true_outcomes.float())
    losses['combined'] = torch.add(params['loss_weight'] * losses['action'], (1-params['loss_weight']) * losses['outcome'])

    outcome_votes = torch.sign(value_head_output)

    actions_predicted_correctly = torch.sum(true_actions == action_votes).item()
    outcomes_predicted_correctly = torch.sum(true_outcomes == outcome_votes).item()
    accuracies = {
        'action':actions_predicted_correctly/len(true_actions),
        'outcome':outcomes_predicted_correctly/len(true_actions)
    }

    return losses, accuracies

In [None]:
#@title Model Training Iteration
def train_one_epoch(model):
    """Train ``model`` for one pass over its training loader, printing progress about ten times.

    Returns:
        ``(avg_losses, avg_accuracies)`` averaged over the last completed chunk
        of batches (about 10% of the epoch). If no chunk completed, the partial
        chunk's averages are returned.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    # 'HexGridGNN' and 'CNN' consume dense board tensors; all other models are
    # fed PyG graph batches.
    uses_dense_inputs = model.name in ('HexGridGNN', 'CNN')
    loader = training_dataloader if uses_dense_inputs else shannon_training_dataloader
    # len(loader) is the number of batches; report roughly every 10% of them.
    num_batches_in_chunk = max(1, ceil(len(loader) * 0.1))

    total_chunk_losses = {'action':0.0, 'outcome':0.0, 'combined':0.0}
    sum_of_chunk_accuracies = {'action':0.0, 'outcome':0.0}
    batches_in_current_chunk = 0
    last_chunk_averages = None

    model.module.train()
    for i, data in enumerate(loader):
        # Every data instance is an input + label pair.
        if uses_dense_inputs:
            x, true_actions, game_outcomes = data[0].to(device), data[1].to(device), data[2].to(device)
            batch_losses, batch_accuracies = get_model_outputs_on_data(model, x, true_actions, game_outcomes)
        else:
            x, edge_index, true_actions, game_outcomes, first_index_of_each_batch, batch = data.x.to(device), data.edge_index.to(device), data.y.to(device), data.outcome.to(device), data.ptr, data.batch.to(device)
            batch_losses, batch_accuracies = get_shannon_model_outputs_on_data(model, x, edge_index, true_actions, game_outcomes, first_index_of_each_batch, batch)

        model.optimiser.zero_grad()
        batch_losses['combined'].backward()
        model.optimiser.step()

        for key in total_chunk_losses:
            total_chunk_losses[key] += batch_losses[key].item()
        for key in sum_of_chunk_accuracies:
            sum_of_chunk_accuracies[key] += batch_accuracies[key]
        batches_in_current_chunk += 1

        if batches_in_current_chunk == num_batches_in_chunk:
            avg_losses_this_chunk = {key: total / num_batches_in_chunk for key, total in total_chunk_losses.items()}
            avg_accuracies_this_chunk = {key: total / num_batches_in_chunk for key, total in sum_of_chunk_accuracies.items()}
            last_chunk_averages = (avg_losses_this_chunk, avg_accuracies_this_chunk)

            print(f'    batch {(i+1)//num_batches_in_chunk} loss: {tuple(avg_losses_this_chunk.values())} acc: {tuple(avg_accuracies_this_chunk.values())}')
            total_chunk_losses = {'action':0.0, 'outcome':0.0, 'combined':0.0}
            sum_of_chunk_accuracies = {'action':0.0, 'outcome':0.0}
            batches_in_current_chunk = 0

    if last_chunk_averages is None:
        # No full chunk completed (e.g. len(loader) over-reported): fall back to
        # the partial chunk instead of returning unbound names.
        if batches_in_current_chunk == 0:
            raise ValueError(f'training loader for {model.name} yielded no batches')
        last_chunk_averages = (
            {key: total / batches_in_current_chunk for key, total in total_chunk_losses.items()},
            {key: total / batches_in_current_chunk for key, total in sum_of_chunk_accuracies.items()},
        )

    return last_chunk_averages


In [None]:
#@title Model Testing Iteration
def test_model(model):
    """Evaluate ``model`` on its test loader in eval mode without gradients.

    Returns:
        ``(avg_losses, avg_accuracies)``: per-batch values averaged over the
        number of test batches, so every batch (including a smaller final one)
        counts equally.

    Raises:
        ValueError: if the test loader yields no batches.
    """
    total_losses = {'action':0, 'outcome':0, 'combined':0}
    sum_of_accuracies = {'action':0, 'outcome':0}

    # Loader choice and input handling use the same condition, so a model is
    # never given a loader whose batches its branch cannot unpack.
    uses_dense_inputs = model.name in ('HexGridGNN', 'CNN')
    loader = test_dataloader if uses_dense_inputs else shannon_test_dataloader

    num_batches = 0
    model.module.eval()
    with torch.no_grad():
        for data in loader:
            if uses_dense_inputs:
                x, label_distributions, game_outcomes = data[0].to(device), data[1].to(device), data[2].to(device)
                losses, accuracies = get_model_outputs_on_data(model, x, label_distributions, game_outcomes)
            else:
                x, edge_index, label_distributions, game_outcomes, first_index_of_each_batch, batch = data.x.to(device), data.edge_index.to(device), data.y.to(device), data.outcome.to(device), data.ptr, data.batch.to(device)
                losses, accuracies = get_shannon_model_outputs_on_data(model, x, edge_index, label_distributions, game_outcomes, first_index_of_each_batch, batch)

            for key in total_losses:
                total_losses[key] += losses[key].item()
            for key in sum_of_accuracies:
                sum_of_accuracies[key] += accuracies[key]
            num_batches += 1

    if num_batches == 0:
        raise ValueError(f'test loader for {model.name} yielded no batches')

    avg_losses = {key: total / num_batches for key, total in total_losses.items()}
    avg_accuracies = {key: total / num_batches for key, total in sum_of_accuracies.items()}

    return avg_losses, avg_accuracies


In [None]:
#@title Training Loop

for epoch in range(params['num_epochs']):
    print('EPOCH {}:'.format(epoch + 1))
    for model in all_models:
        if model.continue_training:
            print(f'  {model.name}:')
            training_losses, training_accuracies = train_one_epoch(model)
            test_losses, test_accuracies = test_model(model)

            model.update(test_losses, test_accuracies,
                         training_losses, training_accuracies)

            print(f'    {model.name} TEST LOSS {tuple(test_losses.values())} TEST ACC {tuple(test_accuracies.values())}')
        else:
            print(f'{model.name} Skipped')

    if params['plot_interval'] > 0:
        if epoch % params['plot_interval'] == params['plot_interval']-1:
            display_all_model_results()
    if params['save_interval'] > 0:
        if epoch % params['save_interval'] == params['save_interval']-1:
            for model in all_models:
                save_model(model)

    # Once every model has early-stopped, further epochs would only print
    # "Skipped" and re-plot / re-save unchanged models.
    if not any(m.continue_training for m in all_models):
        print('No models left to train; ending training early.')
        break


# Results

In [None]:
#@title Final Results Display
if params['save_interval'] > 0:
    for model in all_models:
        save_model(model)

display_all_model_results()

for model in all_models:
    # default=None: a model with no recorded test epochs prints None instead of
    # max() raising ValueError on an empty list.
    print(f'{model.name} - Best test action accuracy:', max(model.test_accuracies['action'], default=None))
    print(f'{model.name} - Best test outcome accuracy:', max(model.test_accuracies['outcome'], default=None))


In [None]:
#@title Output Parameter Count
from itertools import chain

# Report the total number of parameter elements (trainable or not) per model.
for model in all_models:
    total_params = 0
    for parameter in model.module.parameters():
        total_params += parameter.numel()
    print(model.name, total_params)
