## Graph

In [114]:
class Node:
    def __init__(self, row, col):
        self.val = (row, col) # Value of node = its coordinates ((0, 0) = top-left/start, (m - 1, n - 1) = bottom-right/end).
        self.adj_list = set()

# create a list of edges for Game
class Graph:
    def __init__(self, num_rows, num_cols):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.nodes = self._create_nodes()
        
        self.edges = set()
        for row in range(self.num_rows): # Kinda like initialising a 2D matrix. From the nodes generated, add the edges to the set.
            for col in range(self.num_cols):
                node = self.nodes[row][col]

                adj_list = node.adj_list
                for other_node in adj_list:
                    self.edges.add((node, other_node))
                    self.edges.add((other_node, node))
        
    def init_start_and_end(self):
        self.start = self.nodes[0][0] # top-left
        self.end = self.nodes[self.num_rows - 1][self.num_cols - 1] # bottom-right
        
    def _create_nodes(self):
        nodes = [[Node(row, col) for col in range(self.num_cols)] for row in range(self.num_rows)] # create normal node for entire matrix
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                node = nodes[row][col] # this stuff below is where we initialise the neighbours.
                if row > 0:
                    node.adj_list.add(nodes[row - 1][col])  # Upper neighbour.
                    nodes[row - 1][col].adj_list.add(node)
                if row < self.num_rows - 1:
                    node.adj_list.add(nodes[row + 1][col])  # Lower neighbour.
                    nodes[row + 1][col].adj_list.add(node)
                if col > 0:
                    node.adj_list.add(nodes[row][col - 1])  # Left neighbour.
                    nodes[row][col - 1].adj_list.add(node)
                if col < self.num_cols - 1:
                    node.adj_list.add(nodes[row][col + 1])  # Right neighbour.
                    nodes[row][col + 1].adj_list.add(node)
        return nodes
        
    def print_graph(self): # for debugging purposes
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                print(self.nodes[row][col].val)
            print()
            
    def print_adjacencies(self): # for debugging purposes
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                print((row, col), [node.val for node in self.nodes[row][col].adj_list])

In [115]:
g = Graph(3, 3)
g.init_start_and_end()

## Game

Rules:
1. s is in top-left (0, 0), t is bottom-right (m - 1, n - 1).
2. Fix-type player wants is to secure a path from s to t; to do this, the fix-type player secures an edge in the graph in each iteration.
3. Cut-type player wants to disconnect s and t; to do this, the cut-type player deletes an unsecured edge in the graph.
4. Game ends when there is a secured path from s to t (fix) or there are no paths between s and t (cut).

In [126]:
# fix: check for a path (dfs)
# cut: no more valid edges to choose
import random

class Game:
    def __init__(self, graph):
        self.graph = graph
        self.m = self.graph.num_rows
        self.n = self.graph.num_cols
        self.unsecured_count = (2 * self.m * self.n) - self.m - self.n # this is for the CUT player
        self.secured = [] # this is what the FIX player chooses; nodes
        self.secured_edges = [] # fix  
        self.removed_edges = [] # cut
        
        # these are the remaining unsecured edges    
        # i've used a set comprehension so it's easier to see
        # need to ensure both directions of edges are deleted when work is done (e.g. both ((0, 0), (1, 0)) and ((1, 0), (0, 0))
        self.remaining = {(node1.val, node2.val) for node1, node2 in self.graph.edges} 
        
        self.fix_win = False
        self.end = False
        
    def reset(self):
        self.unsecured_count = (2 * self.m * self.n) - self.m - self.n # this is for the CUT player
        self.secured = [] # this is what the FIX player chooses; nodes
        self.secured_edges = [] # fix  
        self.removed_edges = [] # cut
        self.remaining = {(node1.val, node2.val) for node1, node2 in self.graph.edges} 
        
        self.fix_win = False
        self.end = False
        
    # Plays step-by-step. This is what we'll use for "learning".
    def next_step_player(self):
        if self.unsecured_count > 0:
            if not self.end:
                # 1. CUT player's turn
                if len(self.remaining) == 0:
                    # No more valid edges to choose.
                    self.fix_win = False
                    self.end = True
                else:
                    edge_to_cut = self.choose_edge_to_cut()
                    self.cut(edge_to_cut)

            if not self.end:
                # 2. FIX player's turn
                if len(self.remaining) == 0:
                    # No more valid edges to choose.
                    self.fix_win = False
                    self.end = True
                else:
                    edge_to_fix = self.choose_edge_to_fix()
                    self.fix(edge_to_fix)
            
            if self.is_fix_path_complete():
                self.fix_win = True
                self.end = True
        else:
            self.end = True
        
    # Plays the entire thing.
    def play(self):
        while self.unsecured_count > 0:
            # 1. CUT player's turn
            if len(self.remaining) == 0:
                # No more valid edges to choose.
                self.fix_win = False
                self.end = True
                break
                
            edge_to_cut = self.choose_edge_to_cut()
            self.cut(edge_to_cut)
        
            # 2. FIX player's turn
            if len(self.remaining) == 0:
                # No more valid edges to choose.
                self.fix_win = False
                self.end = True
                break

            edge_to_fix = self.choose_edge_to_fix()
            self.fix(edge_to_fix)
            
            if self.is_fix_path_complete():
                self.fix_win = True
                self.end = True
                break
            
    def choose_edge_to_cut(self):
        # Need to implement some strategy here. Return as a tuple of coordinates.
        edge_to_cut = random.choice(list(self.remaining))
        return edge_to_cut

    def choose_edge_to_fix(self):
        # Need to implement some strategy here. Return as a tuple of coordinates.
        edge_to_fix = random.choice(list(self.remaining))
        return edge_to_fix

    def is_fix_path_complete(self): # this does BFS to check if there is a path from the start to the end.
        visited = set()
        stack = [(0, 0)]

        while stack:
            current_node = stack.pop()
            if current_node == (self.m - 1, self.n - 1):
                return True

            for i in range(len(self.secured) - 1):
                edge = (self.secured[i], self.secured[i + 1])
                reverse_edge = (self.secured[i + 1], self.secured[i])

                if (edge in self.secured_edges or reverse_edge in self.secured_edges) and current_node == self.secured[i]:
                    next_node = self.secured[i + 1]
                    if next_node not in visited:
                        visited.add(next_node)
                        stack.append(next_node)

        return False

    # 1. CUT player's function; removes unsecured edge in question (and its reverse).
    # Ideally we don't check if the edge is in self.remaining (we just assume it is).
    # But perhaps the choose function might fuck up.
    def cut(self, edge):
        # edge = ex: ((0, 0), (1,0))
        if edge in self.remaining:
            self.remaining.remove(edge)
            self.removed_edges.append(edge)
            self.unsecured_count -= 1
            
        # Also remove the reverse direction of the edge.
        reverse_edge = (edge[1], edge[0])
        if reverse_edge in self.remaining:
            self.remaining.remove(reverse_edge)
            self.unsecured_count -= 1

    def fix(self, edge):
        # edge = ex: ((0, 0), (1,0))
        if edge in self.remaining:
            self.remaining.remove(edge)
            self.secured.append(edge[0])
            self.secured.append(edge[1])
            self.secured_edges.append(edge)
            
        # Also remove the reverse direction of the edge.
        reverse_edge = (edge[1], edge[0])
        if reverse_edge in self.remaining:
            self.remaining.remove(reverse_edge)
            
    # Reward function
    def get_reward(self):
        if self.fix_win:
            # Positive reward when the FIX player wins
            reward = 1.0
        elif self.end:
            # Negative reward when the FIX player loses
            reward = -1.0
        else:
            # Intermediate reward for the ongoing game
            reward = 0.0            

        return reward
            
    def get_state(self):
        # Define the state representation based on the game state.
        secured_count = len(self.secured_edges)
        remaining_count = int(len(self.remaining) / 2) # Because reverse edges are here too.
        secured_edges = self.secured_edges
        deleted_edges = self.removed_edges
        remaining_edges = list(self.remaining) # Yet for this, we'll keep the reverse edges. Bit hypocritical, but fuck it.
        
        state = (secured_edges, deleted_edges, remaining_edges, secured_count, remaining_count)
    
        return state

In [153]:
for i in range(1000):
    g = Graph(4, 4)
    g.init_start_and_end()
    game = Game(g)
    game.play()
    if game.fix_win:
        break

In [165]:
game.secured_edges, game.remaining

([((0, 2), (1, 2)),
  ((1, 1), (1, 0)),
  ((1, 1), (1, 2)),
  ((1, 3), (2, 3)),
  ((3, 3), (2, 3)),
  ((1, 2), (1, 3)),
  ((2, 1), (1, 1)),
  ((2, 2), (1, 2)),
  ((0, 0), (1, 0))],
 {((0, 1), (1, 1)),
  ((0, 2), (0, 3)),
  ((0, 3), (0, 2)),
  ((1, 1), (0, 1)),
  ((2, 0), (3, 0)),
  ((2, 1), (2, 2)),
  ((2, 1), (3, 1)),
  ((2, 2), (2, 1)),
  ((3, 0), (2, 0)),
  ((3, 1), (2, 1)),
  ((3, 1), (3, 2)),
  ((3, 2), (3, 1))})

## Game (Updated to Include AI)

In [300]:
class Node:
    def __init__(self, val):
        self.val = val # Value of node = its position (left, right, zigzag, continue).
        self.adj_list = set()

# create a list of edges for Game
class Graph:
    def __init__(self, num_rows, num_cols):
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.nodes_int = self._create_nodes()
        
        self.edges = set()
        for row in range(self.num_rows): # Kinda like initialising a 2D matrix. From the nodes generated, add the edges to the set.
            for col in range(self.num_cols):
                node = self.nodes_mat[row][col]

                adj_list = self.mapper[node.val].adj_list
                node = self.mapper[node.val]
                for other_node in adj_list:
                    self.edges.add((node, other_node))
                    self.edges.add((other_node, node))
        
    def init_start_and_end(self):
        self.start = 1 # top-left
        self.end = self.num_rows * self.num_cols # bottom-right
        
    def _create_nodes(self):
        nodes_int, i = [], 1
        nodes = [[Node((row, col)) for col in range(self.num_cols)] for row in range(self.num_rows)] 
        mapper = dict()
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                node = Node(i) # Create the nodes, number 1 to (m * n).
                nodes_int.append(node)
                mapper[(row, col)] = node
                i += 1
        
        for row in range(self.num_rows):
            for col in range(self.num_cols):
                node = nodes[row][col] # this stuff below is where we initialise the neighbours.
                if row > 0:
                    node.adj_list.add(nodes[row - 1][col])  # Upper neighbour.
                    nodes[row - 1][col].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row - 1, col)]
                    
                if row < self.num_rows - 1:
                    node.adj_list.add(nodes[row + 1][col])  # Lower neighbour.
                    nodes[row + 1][col].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row + 1, col)]
                    
                if col > 0:
                    node.adj_list.add(nodes[row][col - 1])  # Left neighbour.
                    nodes[row][col - 1].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row, col - 1)]
                    
                if col < self.num_cols - 1:
                    node.adj_list.add(nodes[row][col + 1])  # Right neighbour.
                    nodes[row][col + 1].adj_list.add(node)
                    
                    current_mapping = mapper[(row, col)]
                    nbr_mapping = mapper[(row, col + 1)]
                    
                current_mapping.adj_list.add(nbr_mapping)
                nbr_mapping.adj_list.add(current_mapping)
                
        self.nodes_mat = nodes
        self.mapper = mapper
                    
        return nodes_int
        
    def print_graph(self): # for debugging purposes
        print([node.val for row in self.nodes_mat for node in row])
        print()
        print([node.val for node in self.nodes_int])

In [301]:
g = Graph(3, 3)
g.init_start_and_end()
# gameAI_test = GameAI(g)

In [313]:
for item in g.edges:
    print(str(item[0].val) + " - " + str(item[1].val))

11 - 12
14 - 13
2 - 3
6 - 5
11 - 10
16 - 15
7 - 6
3 - 2
7 - 8
8 - 7
3 - 4
6 - 7
14 - 15
5 - 6
10 - 9
9 - 10
1 - 2
2 - 1
12 - 11
10 - 11
13 - 14
4 - 3
15 - 14
15 - 16


In [302]:
g.print_graph()

[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 3), (2, 0), (2, 1), (2, 2), (2, 3), (3, 0), (3, 1), (3, 2), (3, 3)]

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]


In [195]:
class GameAI:
    def __init__(self, graph):
        self.graph = graph
        self.m = self.graph.num_rows
        self.n = self.graph.num_cols
        self.unsecured_count = (2 * self.m * self.n) - self.m - self.n # this is for the CUT player
        self.secured = [] # this is what the FIX player chooses; nodes
        self.secured_edges = [] # fix  
        self.removed_edges = [] # cut
        
        self.remaining = {(node1.val, node2.val) for node1, node2 in self.graph.edges} 
        
        self.fix_win = False
        self.end = False
        
    def reset(self): # Reset everything for the next training iteration.
        self.unsecured_count = (2 * self.m * self.n) - self.m - self.n 
        self.secured = []
        self.secured_edges = []
        self.removed_edges = []
        self.remaining = {(node1.val, node2.val) for node1, node2 in self.graph.edges} 
        
        self.fix_win = False
        self.end = False
        
    # This is what our AI will train against. Random shit.
    def choose_edge_to_cut(self):
        edge_to_cut = random.choice(list(self.remaining))
        return edge_to_cut    
    
    # Plays step-by-step. This is what we'll use for "learning".
    def next_step_player(self, chosen_edge):
        if self.unsecured_count > 0:
            if not self.end:
                # 1. CUT/bot player's turn.
                if len(self.remaining) == 0:
                    # No more valid edges to choose.
                    self.fix_win = False
                    self.end = True
                else:
                    edge_to_cut = self.choose_edge_to_cut()
                    self.cut(edge_to_cut)

            if not self.end:
                # 2. FIX player's turn: where the magic happens.
                if len(self.remaining) == 0:
                    # No more valid edges to choose.
                    self.fix_win = False
                    self.end = True
                else:
                    self.fix(chosen_edge)
            
            if self.is_fix_path_complete():
                self.fix_win = True
                self.end = True
        else:
            self.end = True

    def is_fix_path_complete(self): # This does BFS to check if there is a path from the start to the end.
        visited = set()
        stack = [(0, 0)]

        while stack:
            current_node = stack.pop()
            if current_node == (self.m - 1, self.n - 1):
                return True

            for i in range(len(self.secured) - 1):
                edge = (self.secured[i], self.secured[i + 1])
                reverse_edge = (self.secured[i + 1], self.secured[i])

                if (edge in self.secured_edges or reverse_edge in self.secured_edges) and current_node == self.secured[i]:
                    next_node = self.secured[i + 1]
                    if next_node not in visited:
                        visited.add(next_node)
                        stack.append(next_node)

        return False

    # 1. CUT player's function; removes unsecured edge in question (and its reverse).
    def cut(self, edge):
        # edge = ex: ((0, 0), (1,0))
        if edge in self.remaining:
            self.remaining.remove(edge)
            self.removed_edges.append(edge)
            self.unsecured_count -= 1
            
        # Also remove the reverse direction of the edge.
        reverse_edge = (edge[1], edge[0])
        if reverse_edge in self.remaining:
            self.remaining.remove(reverse_edge)
            self.unsecured_count -= 1

    # 2. FIX player's function; secures unsecured edge in question (and its reverse).
    def fix(self, edge):
        # edge = ex: ((0, 0), (1,0))
        if edge in self.remaining:
            self.remaining.remove(edge)
            self.secured.append(edge[0])
            self.secured.append(edge[1])
            self.secured_edges.append(edge)
            
        # Also remove the reverse direction of the edge.
        reverse_edge = (edge[1], edge[0])
        if reverse_edge in self.remaining:
            self.remaining.remove(reverse_edge)
            
    # Reward function
    def get_reward(self):
        if self.fix_win:
            # Positive reward when the FIX player wins
            reward = 1.0
        elif self.end:
            # Negative reward when the FIX player loses
            reward = -1.0
        else:
            # Intermediate reward for the ongoing game
            reward = 0.0            

        return reward
            
    def get_state(self):
        # Define the state representation based on the game state.
        secured_count = len(self.secured_edges)
        remaining_count = int(len(self.remaining) / 2) # Because reverse edges are here too.
        secured_edges = self.secured_edges
        deleted_edges = self.removed_edges
        remaining_edges = list(self.remaining) # Yet for this, we'll keep the reverse edges. Bit hypocritical, but fuck it.
        
        state = (secured_edges, deleted_edges, remaining_edges, secured_count, remaining_count)
    
        return state

In [245]:
i = 0
rows = gameAI_test.graph.num_rows
columns = gameAI_test.graph.num_cols
depth = rows * columns
matrix = np.zeros((depth, rows, columns), dtype = np.int32)

for row in range(gameAI_test.graph.num_rows):
    for col in range(gameAI_test.graph.num_cols):
        i += 1

In [227]:
gameAI_test.get_state()

([],
 [],
 [((1, 2), (1, 3)),
  ((1, 2), (0, 2)),
  ((1, 1), (1, 0)),
  ((2, 2), (3, 2)),
  ((2, 1), (1, 1)),
  ((3, 2), (3, 1)),
  ((1, 1), (2, 1)),
  ((1, 3), (1, 2)),
  ((1, 0), (2, 0)),
  ((2, 1), (2, 2)),
  ((3, 1), (2, 1)),
  ((1, 3), (2, 3)),
  ((0, 1), (0, 2)),
  ((1, 2), (1, 1)),
  ((0, 0), (1, 0)),
  ((2, 3), (1, 3)),
  ((1, 1), (1, 2)),
  ((1, 2), (2, 2)),
  ((0, 2), (0, 3)),
  ((3, 0), (3, 1)),
  ((0, 3), (1, 3)),
  ((1, 1), (0, 1)),
  ((2, 0), (1, 0)),
  ((3, 1), (3, 0)),
  ((0, 3), (0, 2)),
  ((2, 0), (2, 1)),
  ((2, 1), (3, 1)),
  ((3, 0), (2, 0)),
  ((3, 3), (2, 3)),
  ((0, 1), (1, 1)),
  ((2, 2), (2, 1)),
  ((2, 1), (2, 0)),
  ((2, 3), (3, 3)),
  ((1, 0), (1, 1)),
  ((0, 1), (0, 0)),
  ((0, 2), (1, 2)),
  ((3, 2), (3, 3)),
  ((0, 0), (0, 1)),
  ((3, 1), (3, 2)),
  ((2, 2), (1, 2)),
  ((0, 2), (0, 1)),
  ((2, 3), (2, 2)),
  ((2, 0), (3, 0)),
  ((1, 0), (0, 0)),
  ((1, 3), (0, 3)),
  ((3, 3), (3, 2)),
  ((3, 2), (2, 2)),
  ((2, 2), (2, 3))],
 0,
 24)

## Agent

In [217]:
# https://www.youtube.com/watch?v=L8ypSXwyBds
import random
from collections import deque

LR = 0.001
MAX_MEMORY = 1_000_000
BATCH_SIZE = 1000

class Agent:
    def __init__(self, state_size, game):
        self.n_games = 0 # Iteration number
        self.epsilon = 0 # Randomness for epsilon-greedy.
        self.gamma = 0.9 # "Discount rate"
        self.memory = deque(maxlen = MAX_MEMORY) # popleft()
        
        # Should set secured edges to -1 so agent doesn't select them. ShannonModel needs a function.
        self.model = ShannonModel(state_size, 256, (2 * game.m * game.n) - game.m - game.n)
        self.trainer = Trainer(self.model, lr = LR, gamma = self.gamma)
        
    def get_state(self, game):
        # (secured_edges, deleted_edges, remaining_edges, secured_count, remaining_count)
        state = game.get_state() # Could probably clean this up a bit and turn it into tensors in here.

        return np.array(state, dtype = object)
    
    def remember(self, state, action, reward, next_state, done): # Stores this shit into the deque so it can be used for training later.
        self.memory.append((state, action, reward, next_state, done)) # popleft if MAX_MEMORY is reached

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done) # Pretty much the same shit as remember, but trainer uses data.
        
    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample) # From remember/memory
        self.trainer.train_step(states, actions, rewards, next_states, dones)
        
    def get_action(self, state):
        # Exploration / exploitation
        self.epsilon = 80 - self.n_games # Hardcoded, can change this shit
        
        remaining_edges = state[2]
        
        if random.randint(0, 200) < self.epsilon: # This works.
            final_edge = random.choice(remaining_edges)
        else: # This, not so much.
            state0 = torch.tensor(state) # I need to turn state into a tensor.
            prediction = self.model(state0)
            final_edge = torch.argmax(prediction).item()

        return final_edge

In [224]:
g = Graph(4, 4)
g.init_start_and_end()
gameAI = GameAI(g)
agent = Agent(5, gameAI)

In [225]:
agent.get_action(agent.get_state(gameAI))

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

array([list([]), list([]),
       list([((1, 2), (1, 3)), ((1, 2), (0, 2)), ((1, 1), (1, 0)), ((2, 2), (3, 2)), ((2, 1), (1, 1)), ((3, 2), (3, 1)), ((1, 1), (2, 1)), ((1, 3), (1, 2)), ((1, 0), (2, 0)), ((2, 1), (2, 2)), ((3, 1), (2, 1)), ((1, 3), (2, 3)), ((0, 1), (0, 2)), ((1, 2), (1, 1)), ((0, 0), (1, 0)), ((1, 2), (2, 2)), ((0, 2), (0, 3)), ((2, 3), (1, 3)), ((1, 1), (1, 2)), ((3, 0), (3, 1)), ((0, 3), (1, 3)), ((2, 0), (1, 0)), ((1, 1), (0, 1)), ((3, 1), (3, 0)), ((0, 3), (0, 2)), ((2, 0), (2, 1)), ((2, 1), (3, 1)), ((3, 0), (2, 0)), ((3, 3), (2, 3)), ((0, 1), (1, 1)), ((2, 2), (2, 1)), ((2, 1), (2, 0)), ((2, 3), (3, 3)), ((1, 0), (1, 1)), ((0, 0), (0, 1)), ((0, 2), (1, 2)), ((3, 2), (3, 3)), ((0, 1), (0, 0)), ((3, 1), (3, 2)), ((0, 2), (0, 1)), ((2, 0), (3, 0)), ((2, 3), (2, 2)), ((2, 2), (2, 3)), ((1, 0), (0, 0)), ((1, 3), (0, 3)), ((3, 3), (3, 2)), ((3, 2), (2, 2)), ((2, 2), (1, 2))]),
       0, 24], dtype=object)

## Model

In [173]:
# 1. model.py
# 2. agent.py
# model is the FFNN, agent is what trains the model
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import os

In [174]:
# This is the Feedforward Neural Network.
class ShannonModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size) 
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        x = torch.relu(x)
        x = self.fc4(x)
        return x

## Trainer

In [180]:
# This trains the ShannonModel() initialised above.
class Trainer: 
    def __init__(self, model, lr, gamma):
        self.lr = lr # Learning Rate
        self.gamma = gamma # https://ai.stackexchange.com/questions/8100/what-is-the-purpose-of-the-gamma-parameter-in-svms
        self.model = model # ShannonModel()
        self.optimizer = optim.Adam(model.parameters(), lr = self.lr) # Adam Algorithm (some shit idek)
        self.criterion = nn.MSELoss() # Standard MSE

    def train_step(self, state, action, reward, next_state, done): # Where the magic happens
        pass