## Node and Graph

In [90]:
class Node:
    def __init__(self, val):
        self.val = val # Value of node = its position (left, right, zigzag, continue).
        self.adj_list = set()

# create a list of edges for Game
class Graph:
    def __init__(self, num_rows, num_cols):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.nodes_int = self._create_nodes()
        
        self.edges = set()
        for row in range(self.num_rows): # Kinda like initialising a 2D matrix. From the nodes generated, add the edges to the set.
            for col in range(self.num_cols):
                node = self.nodes_mat[row][col]
                
                adj_list = self.mapper[node.val].adj_list
                node = self.mapper[node.val]
                for other_node in adj_list:
                    self.edges.add((node, other_node))
                    self.edges.add((other_node, node))
        
    def init_start_and_end(self):
        self.start = 1 # top-left
        self.end = self.num_rows * self.num_cols # bottom-right
        
    def _create_nodes(self):
        nodes_int, i = [], 1
        nodes = [[Node((row, col)) for col in range(self.num_cols)] for row in range(self.num_rows)] 
        mapper = dict()
        int_mapper = dict()
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                node = Node(i) # Create the nodes, number 1 to (m * n).
                nodes_int.append(node)
                mapper[(row, col)] = node
                int_mapper[i] = node
                i += 1
        
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                node = nodes[row][col] # this stuff below is where we initialise the neighbours.
                if row > 0:
                    node.adj_list.add(nodes[row - 1][col])  # Upper neighbour.
                    nodes[row - 1][col].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row - 1, col)]
                    
                    current_mapping.adj_list.add(nbr_mapping)
                    nbr_mapping.adj_list.add(current_mapping)
                    
                if row < self.num_rows - 1:
                    node.adj_list.add(nodes[row + 1][col])  # Lower neighbour.
                    nodes[row + 1][col].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row + 1, col)]
                    
                    current_mapping.adj_list.add(nbr_mapping)
                    nbr_mapping.adj_list.add(current_mapping)
                    
                if col > 0:
                    node.adj_list.add(nodes[row][col - 1])  # Left neighbour.
                    nodes[row][col - 1].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row, col - 1)]
                    
                    current_mapping.adj_list.add(nbr_mapping)
                    nbr_mapping.adj_list.add(current_mapping)
                    
                if col < self.num_cols - 1:
                    node.adj_list.add(nodes[row][col + 1])  # Right neighbour.
                    nodes[row][col + 1].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row, col + 1)]
                    
                    current_mapping.adj_list.add(nbr_mapping)
                    nbr_mapping.adj_list.add(current_mapping)
                
        self.nodes_mat = nodes
        self.mapper = mapper
        self.int_mapper = int_mapper
                    
        return nodes_int
        
    def print_graph(self): # for debugging purposes
        print([node.val for row in self.nodes_mat for node in row])
        print()
        print([node.val for node in self.nodes_int])

## GameAI

In [151]:
import random
class GameAI:
    def __init__(self, graph):
        self.graph = graph
        self.m = self.graph.num_rows
        self.n = self.graph.num_cols
        
        self.node_mapping = dict()
        for i in range(1, (self.m * self.n) + 1):
            self.node_mapping[i] = self.graph.nodes_int[i - 1] # e.g. 1 is in index 0, 2 is index 1, etc.
        
        self.edges = []
        for i in range(1, (self.graph.num_rows * self.graph.num_cols) + 1):
            adj_list = [node.val for node in self.graph.int_mapper[i].adj_list]
            for adj_node in adj_list:
                if (adj_node, i) in self.edges:
                    continue
                self.edges.append((i, adj_node))
        
        self.edges = sorted(self.edges)
        
        self.unsecured_count = (2 * self.m * self.n) - self.m - self.n # this is for the CUT player
        self.secured_edges = [] # fix  
        self.removed_edges = [] # cut
        
        self.remaining = {(node1.val, node2.val) for node1, node2 in self.graph.edges} 
        self.fix_win = False
        self.end = False
        
    def reset(self): # Reset everything for the next training iteration.
        self.unsecured_count = (2 * self.m * self.n) - self.m - self.n 
        self.secured_edges = []
        self.removed_edges = []
        
        self.remaining = {(node1.val, node2.val) for node1, node2 in self.graph.edges} 
        self.fix_win = False
        self.end = False
        
    # This is what our AI will train against. Random shit.
    def choose_edge_to_cut(self):
        edge_to_cut = random.choice(list(self.remaining))
        return edge_to_cut    
    
    # Plays step-by-step. This is what we'll use for "learning".
    def next_step_player(self, chosen_edge):
        if self.unsecured_count > 0 and not self.end:
            if not self.end:
                # 1. FIX player's turn: where the magic happens. (HE NOW GOES FIRST)
                if len(self.remaining) == 0:
                    # No more valid edges to choose.
                    self.fix_win = False
                    self.end = True
                else:
                    self.fix(chosen_edge)
            
            if self.is_fix_path_complete():
                self.fix_win = True
                self.end = True
                return
                
            # 2. CUT/bot player's turn.
            if not self.end:
                if len(self.remaining) == 0:
                    # No more valid edges to choose.
                    self.fix_win = False
                    self.end = True
                else:
                    edge_to_cut = self.choose_edge_to_cut()
                    self.cut(edge_to_cut)
                    
            if len(self.remaining) == 0:
                self.end = True
        else:
            self.end = True

    def is_fix_path_complete(self): # This does BFS to check if there is a path from the start to the end.
        visited = set()
        stack = [1]

        while stack:
            current_node = stack.pop()
            if current_node == self.m * self.n: # e.g. 4x4, 16 is the bottom-right.
                return True

            if current_node not in visited:
                visited.add(current_node)
                adj_list = self.node_mapping[current_node].adj_list
                
                for nbr in adj_list:
                    edge = (current_node, nbr.val)
                    reverse_edge = (nbr.val, current_node)

                    if edge in self.secured_edges or reverse_edge in self.secured_edges:
                        if nbr.val not in visited:
                            stack.append(nbr.val)

        return False
    
    # 1. FIX player's function; secures unsecured edge in question (and its reverse).
    def fix(self, edge):
        # edge = ex: (1, 4)
        if edge in self.remaining:
            self.remaining.remove(edge)
            self.secured_edges.append(edge)
            self.unsecured_count -= 1
            
        # Also remove the reverse direction of the edge.
        reverse_edge = (edge[1], edge[0])
        if reverse_edge in self.remaining:
            self.remaining.remove(reverse_edge)

    # 2. CUT player's function; removes unsecured edge in question (and its reverse).
    def cut(self, edge):
        # edge = ex: (1, 4)
        if edge in self.remaining:
            self.remaining.remove(edge)
            self.removed_edges.append(edge)
            self.unsecured_count -= 1
            
        # Also remove the reverse direction of the edge.
        reverse_edge = (edge[1], edge[0])
        if reverse_edge in self.remaining:
            self.remaining.remove(reverse_edge)
  
    # Reward function
    def get_reward(self):
        if self.fix_win:
            # Positive reward when the FIX player wins
            reward = 1.0
        elif self.end:
            # Negative reward when the FIX player loses
            reward = -1.0
        else:
            # Intermediate reward for the ongoing game
            reward = 0.0            

        return reward
            
    def get_state(self):
        # Define the state representation based on the game state.
        secured_count = len(self.secured_edges)
        remaining_count = self.unsecured_count
        secured_edges = self.secured_edges
        deleted_edges = self.removed_edges
        remaining_edges = list(self.remaining) # Yet for this, we'll keep the reverse edges. Bit hypocritical, but fuck it.
        
        state = (secured_edges, deleted_edges, remaining_edges, secured_count, remaining_count)
    
        return state
    
    # This is still here purely for debugging purposes.
    def play(self):
        while self.unsecured_count > 0:
            # 1. CUT player's turn
            if len(self.remaining) == 0:
                # No more valid edges to choose.
                self.fix_win = False
                self.end = True
                break
                
            edge_to_cut = self.choose_edge_to_cut()
            self.cut(edge_to_cut)
        
            # 2. FIX player's turn
            if len(self.remaining) == 0:
                # No more valid edges to choose.
                self.fix_win = False
                self.end = True

            edge_to_fix = self.choose_edge_to_fix()
            self.fix(edge_to_fix)
            
            if self.is_fix_path_complete():
                self.fix_win = True
                self.end = True
                break
            elif self.unsecured_count == 0:
                self.fix_win = False
                self.end = True

    # This is still here purely for debugging purposes.
    def choose_edge_to_fix(self):
        edge_to_fix = random.choice(list(self.remaining))
        return edge_to_fix

## Model

In [175]:
# 1. model.py
# 2. agent.py
# model is the FFNN, agent is what trains the model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os

# This is the Feedforward Neural Network.
class ShannonModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, edges, game):
        self.edges = edges
        self.game = game
        
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        
        sorted_x, indices = torch.sort(x, descending = True)
        max_probability = None
        chosen_edge = None
    
        for index in indices:
            edge = self.edges[index.item()]
            reverse_edge = (edge[1], edge[0])
            
            # Could be a problem here. Don't know what, but just troubleshooting for future.
            if ((edge not in self.game.removed_edges and edge not in self.game.secured_edges) and 
                (reverse_edge not in self.game.removed_edges and reverse_edge not in self.game.secured_edges)): 
                # print("Edge:", edge, "Removed:", game.removed_edges, "Secured:", game.secured_edges)
                max_probability = x[index]
                chosen_edge = edge
                break
                
        return x, max_probability, chosen_edge # Return the probabilities and the valid edge with the max probability.

## Training the Model

In [176]:
import random
import torch
from collections import deque
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import StepLR

def format_state(state):
    num_nodes = find_max_number([state[0], state[1], state[2]])
        
    secured_edges = convert_to_adj_matrix(state[0], num_nodes)
    deleted_edges = convert_to_adj_matrix(state[1], num_nodes)
    remaining_edges = convert_to_adj_matrix(state[2], num_nodes)
    secured_count, remaining_count = state[3], state[4]
        
    secured_edges, deleted_edges, remaining_edges, secured_count, remaining_count = (torch.tensor(secured_edges).flatten(), 
                                                                                     torch.tensor(deleted_edges).flatten(), 
                                                                                     torch.tensor(remaining_edges).flatten(), 
                                                                                     torch.tensor([secured_count]), 
                                                                                     torch.tensor([remaining_count]))
            
    formatted_state = np.concatenate([secured_edges, deleted_edges, remaining_edges, secured_count, remaining_count]).tolist()
    formatted_state = torch.tensor(formatted_state)
    
    return formatted_state

def convert_to_adj_matrix(edges, num_nodes):
    nodes = set()
    for edge in edges:
        nodes.add(edge[0])
        nodes.add(edge[1])

    adj_matrix = np.zeros((num_nodes + 1, num_nodes + 1)) # 0th col and 0th row will just be to pad.

    # Populate the adjacency matrix
    for edge in edges:
        adj_matrix[edge[0]][edge[1]] = 1
        
    return adj_matrix

def find_max_number(lists_of_tuples):
    max_number = float('-inf')

    for list_of_tuples in lists_of_tuples:
        for tup in list_of_tuples:
            numbers = [x for x in tup if isinstance(x, (int, float))]
            if numbers:
                current_max = max(numbers)
                if current_max > max_number:
                    max_number = current_max

    return max_number

In [184]:
# Parameters.
num_epochs = 512
batch_size = 64
learning_rate = 0.001
input_size = 302
hidden_size = 256
output_size = 12
gamma = 0.90
target_update = 16

# Create the Graph and GameAI instances.
num_rows = 3
num_cols = 3
graph = Graph(num_rows, num_cols)
game = GameAI(graph)
edges = game.edges

# Initialize the Feedforward Neural Network (policy network) and target network.
policy_net = ShannonModel(input_size, hidden_size, output_size, edges, game)
target_net = ShannonModel(input_size, hidden_size, output_size, edges, game)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)

# Training loop
for epoch in range(num_epochs):
    state_list = []
    action_list = []
    reward_list = []
    next_state_list = []
    done_list = []

    while not game.end:
        # Get the current state
        state = game.get_state()
        formatted_state = format_state(state)

        # Choose an action using the policy network.
        _, _, action = policy_net(formatted_state)

        # Perform the action and get the reward
        game.next_step_player(action)
        reward = game.get_reward()
        next_state = game.get_state()
        formatted_next_state = format_state(next_state)
        done = game.end

        # Store the transition
        state_list.append(formatted_state)
        action_list.append(action)
        reward_list.append(reward)
        next_state_list.append(formatted_next_state)
        done_list.append(done)
                
    # Reset the game if it has ended
    game.reset()

    # Create a DataLoader for the collected data
    dataset = TensorDataset(torch.stack(state_list), 
                            torch.tensor(action_list), 
                            torch.tensor(reward_list, dtype=torch.float32), 
                            torch.stack(next_state_list), 
                            torch.tensor(done_list))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Train the policy network
    for batch_idx, (states, actions, rewards, next_states, dones) in enumerate(dataloader):
        for i in range(len(states)):           
            q_values, max_q, chosen_edge = policy_net(states[i])
            next_q_values, max_next_q, chosen_next_edge = target_net(next_states[i])
            
            target_q_values = rewards[i] + (1 - dones[i].float()) * gamma * next_q_values.max()

            loss = criterion(q_values, target_q_values.detach())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Update the target network
    if epoch % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())

    # Update the learning rate
    scheduler.step()

    print(f"Epoch: {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

Epoch: 1/512, Loss: 0.0031
Epoch: 2/512, Loss: 0.0012
Epoch: 3/512, Loss: 0.0026
Epoch: 4/512, Loss: 0.0004
Epoch: 5/512, Loss: 0.7156
Epoch: 6/512, Loss: 0.0060
Epoch: 7/512, Loss: 0.7488
Epoch: 8/512, Loss: 0.0008
Epoch: 9/512, Loss: 0.4214
Epoch: 10/512, Loss: 2.2766
Epoch: 11/512, Loss: 0.0125
Epoch: 12/512, Loss: 0.0256
Epoch: 13/512, Loss: 0.0000
Epoch: 14/512, Loss: 0.0003
Epoch: 15/512, Loss: 2.6585
Epoch: 16/512, Loss: 0.0008
Epoch: 17/512, Loss: 0.0001
Epoch: 18/512, Loss: 0.0536
Epoch: 19/512, Loss: 0.0018
Epoch: 20/512, Loss: 0.0003
Epoch: 21/512, Loss: 0.4842
Epoch: 22/512, Loss: 0.0004
Epoch: 23/512, Loss: 0.0002
Epoch: 24/512, Loss: 0.0016
Epoch: 25/512, Loss: 0.0008
Epoch: 26/512, Loss: 0.1354
Epoch: 27/512, Loss: 0.0005
Epoch: 28/512, Loss: 0.0001
Epoch: 29/512, Loss: 0.0005
Epoch: 30/512, Loss: 0.0005
Epoch: 31/512, Loss: 0.0002
Epoch: 32/512, Loss: 0.0070
Epoch: 33/512, Loss: 0.0000
Epoch: 34/512, Loss: 0.0011
Epoch: 35/512, Loss: 0.3557
Epoch: 36/512, Loss: 0.0060
E

In [143]:
torch.tensor([ 4,  3,  9, 10,  0, 11,  5,  6,  2,  1,  8,  7])

tensor([ 4,  3,  9, 10,  0, 11,  5,  6,  2,  1,  8,  7])