In [None]:

import random
import torch
import gym
from gym import spaces
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
from collections import namedtuple, deque

'''
Uses OpenAI's Gym environment interface to create a Game of Life environment,
modified to use PyTorch tensors and convolutions.

methods:
    __init__(): initializes the environment with a grid size of 1024, and creates a grid of random 1s and 0s
    step(action): updates the grid based on the action, calculates the reward, and checks if the game is done
    check_done(next_state): checks if the game is done
    reset(): resets the grid to a random state
    render(state): renders the grid on a matplotlib plot
    update_grid(actions): updates the grid based on the actions
    calc_reward(next_state): calculates the reward based on the next state
    count_living(grid): counts the number of living cells in the grid
'''
class GameOfLifeEnv(gym.Env):
    """Conway's Game of Life on a square torch grid with a gym-style interface.

    The board is a ``(grid_size, grid_size)`` float32 tensor of 0s (dead) and
    1s (alive). Each step first forces the cells named by the action to be
    alive and then advances the board by one generation, counting neighbours
    with a 3x3 convolution (zero padding, so cells beyond the border count as
    dead).

    Note that ``step`` returns a 3-tuple ``(next_state, reward, done)``.
    """

    def __init__(self, grid_size=1024):
        """Create a randomly initialised board.

        Args:
            grid_size: side length of the square board.
        """
        super().__init__()
        self.grid_size = grid_size
        # run on Apple's MPS backend when it is available, otherwise on the CPU
        self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
        # initialize the grid with random 1s and 0s
        self.grid = self._random_grid()
        # 3x3 kernel that sums the eight neighbours (the centre weight is 0)
        self.kernel = torch.tensor(
            [[1, 1, 1], [1, 0, 1], [1, 1, 1]],
            dtype=torch.float32,
            device=self.device,
        ).view(1, 1, 3, 3)

    def _random_grid(self):
        """Return a fresh board of uniformly random 0s and 1s on ``self.device``."""
        return torch.randint(
            low=0,
            high=2,
            size=(self.grid_size, self.grid_size),
            dtype=torch.float32,
            device=self.device,
        )

    def step(self, action):
        """Apply ``action`` and advance the board by one generation.

        Args:
            action: 2-D tensor whose first row holds flattened (x, y) pairs;
                see ``update_grid``.

        Returns:
            ``(next_state, reward, done)`` where ``next_state`` is the new
            board (also stored as ``self.grid``), ``reward`` is the value from
            ``calc_reward`` and ``done`` is the value from ``check_done``.
        """
        # update the grid based on the action
        self.update_grid(action)
        board = self.grid.unsqueeze(0).unsqueeze(0)
        # count the living neighbours of every cell
        neighbours = F.conv2d(board, self.kernel, padding=1)
        # a cell lives next turn with exactly 3 neighbours, or with 2 if it is alive now
        survives = (neighbours == 2) & (board == 1)
        # drop only the two dimensions added above so a 1x1 board keeps its 2-D shape
        next_state = ((neighbours == 3) | survives).float().squeeze(0).squeeze(0)
        # the reward compares against the post-action board, so compute it before replacing it
        reward = self.calc_reward(next_state)
        self.grid = next_state
        done = self.check_done(next_state)

        return next_state, reward, done

    def check_done(self, next_state):
        """Return True when every cell of ``next_state`` is alive."""
        return self.count_living(next_state) == self.grid_size * self.grid_size

    def reset(self):
        """Replace the board with a new random one and return it."""
        self.grid = self._random_grid()
        return self.grid

    def render(self, state):
        """Show ``state`` as an image in a matplotlib figure (living cells are dark)."""
        grid = state.detach().to('cpu').numpy()
        plt.imshow(grid, cmap='gray_r')
        plt.colorbar()
        plt.show()

    def update_grid(self, actions):
        """Set the cells addressed by ``actions`` to alive.

        ``actions[0]`` is read as consecutive (x, y) pairs; a trailing unpaired
        value is ignored. Coordinates are truncated toward zero, NaN is treated
        as 0, and values outside the board are clamped to the nearest valid
        index, so arbitrary float values cannot raise or wrap around.

        The board is copied before it is written to: tensors previously
        returned by ``reset``/``step`` may still be referenced by the caller
        (for example inside a replay memory) and must not change afterwards.
        """
        # one host round trip for all coordinates instead of one sync per value
        coords = actions[0].detach().reshape(-1).float().cpu()
        usable = coords.numel() - coords.numel() % 2
        if usable == 0:
            return
        # sanitise while still floating point: converting NaN/inf to integers is undefined
        coords = torch.nan_to_num(coords[:usable], nan=0.0)
        coords = coords.clamp(0, self.grid_size - 1).long().to(self.grid.device)
        grid = self.grid.clone()
        grid[coords[0::2], coords[1::2]] = 1
        self.grid = grid

    def calc_reward(self, next_state):
        """Return the ratio of living cells in ``next_state`` to those in ``self.grid``.

        Returns 0.0 when the current board has no living cells, instead of
        dividing by zero.
        """
        state_count = self.count_living(self.grid)
        next_state_count = self.count_living(next_state)
        if state_count == 0:
            return 0.0
        return next_state_count / state_count

    def count_living(self, grid):
        """Return the number of cells in ``grid`` equal to 1, as a Python int."""
        return torch.sum(grid == 1).item()

'''
DQN model for the game of life environment
convolutional neural network with 5 convolutional layers and 4 fully connected layers
methods:
    __init__(): initializes the model with 5 convolutional layers and 4 fully connected layers
    forward(x): forward pass of the model
'''
class DQN(nn.Module):
    def __init__(self):
        super().__init__()  
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)  
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)  
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)  
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
         
        self.fc1 = nn.Linear(4096, 2048)  
        self.fc2 = nn.Linear(2048, 1024)  
        self.fc3 = nn.Linear(1024, 256)
        self.fc4 = nn.Linear(256, 200)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
# one replay-memory record; used when pushing to the memory
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


# memory class for storing the transitions
class Memory:
    """Fixed-capacity store of Transition records."""

    def __init__(self, capacity):
        # a bounded deque drops its oldest entry once `capacity` is exceeded
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a Transition from ``args`` and append it to the store."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        picked = random.sample(self.memory, batch_size)
        return picked

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)



'''
Trainer class for training the DQN model
methods:
    __init__(): initializes the trainer with the environment, model, device, episodes, iterations, batch size, replay memory, learning rate, tau, optimizer, losses, and rewards
    train(): trains the model over the specified number of episodes and iterations
    optimize(): optimizes the model by calculating the loss and updating the weights
    sanity_check(): prints the model, environment, number of parameters, and model size
    get_action(state): gets the action based on the state
    save_model(path): saves the model to the specified path
    load_model(path): loads the model from the specified path
    training_bar(iteration): prints a progress bar for the training
    plot_metrics(): plots the losses per episode
attributes:
    env: the game of life environment
    model: the DQN model
    device: the device to run the computations on
    episodes: the number of episodes to train the model
    iterations: the number of iterations per episode
    batch_size: the batch size for training
    memory: the memory for storing the transitions
    target_net: the target network for updating the weights (not currently created by __init__)
    lr: the learning rate
    tau: the tau value for updating the target network
    optimizer: the optimizer for training the model
    losses: the losses per episode
    rewards: the rewards per episode'''

class Trainer:
    """Train a DQN on the Game of Life environment using a replay memory.

    Attributes:
        env: the GameOfLifeEnv instance being played.
        model: the DQN being trained.
        device: device the model lives on (taken from the environment so that
            states and model always share a device).
        episodes: number of episodes to run.
        iterations: maximum number of environment steps per episode.
        batch_size: number of transitions sampled per optimisation step.
        memory: replay memory holding Transition records.
        lr: learning rate of the Adam optimiser.
        tau: stored for later use; nothing in this class reads it.
        epsilon: probability of taking a random action in ``get_action``.
        optimizer: Adam optimiser over the model parameters.
        losses: one loss value per optimisation step actually performed.
        rewards: mean per-step reward of each episode.
    """

    # number of values in one action: 100 (x, y) pairs; must equal the DQN output width
    ACTION_DIM = 200

    def __init__(self, episodes=200, iterations=1000, batch_size=32, lr=1e-4,
                 memory_size=1000, tau=0.005, epsilon=0.9):
        """Build the environment, model, replay memory and optimiser.

        Args:
            episodes: number of training episodes.
            iterations: maximum environment steps per episode.
            batch_size: transitions sampled per optimisation step.
            lr: Adam learning rate.
            memory_size: capacity of the replay memory.
            tau: kept on the instance; unused by the current training loop.
            epsilon: probability of choosing a random action. The default of
                0.9 keeps the previous hard-coded behaviour (random action
                90% of the time).
        """
        self.env = GameOfLifeEnv()
        self.model = DQN()
        # reuse the environment's device so states and model can never disagree
        self.device = self.env.device
        self.episodes = episodes
        self.iterations = iterations
        self.batch_size = batch_size
        self.memory = Memory(memory_size)

        self.lr = lr
        self.tau = tau
        self.epsilon = epsilon
        # move the model first so the optimiser is built over the on-device parameters
        self.model.to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, amsgrad=True)
        self.losses = []
        self.rewards = []

    def train(self):
        """Run the training loop for ``episodes`` episodes, then plot the metrics.

        Each step picks an action, advances the environment, stores the
        transition and performs one optimisation step. An episode ends after
        ``iterations`` steps or as soon as the environment reports ``done``.
        """
        print("Training Started...")
        for episode in range(self.episodes):
            # reset the environment
            state = self.env.reset()
            cum_reward = 0
            steps = 0
            print(f'Episode: {episode} of {self.episodes}')
            for iteration in range(self.iterations):
                # print the progress bar
                self.training_bar(iteration)
                # add batch and channel dimensions: (H, W) -> (1, 1, H, W)
                state = state.unsqueeze(0).unsqueeze(1)
                action = self.get_action(state)
                # take a step in the environment
                next_state, reward, done = self.env.step(action)
                steps += 1
                cum_reward += reward
                reward = torch.tensor([reward], dtype=torch.float32, device=self.device)
                # push the transition to the memory
                self.memory.push(state, action, next_state, reward)
                state = next_state
                # optimize() returns None until the memory holds a full batch
                loss = self.optimize()
                if loss is not None:
                    self.losses.append(loss)

                if done:
                    break
            # average over the steps actually taken (guards zero-step episodes)
            self.rewards.append(cum_reward / steps if steps else 0.0)

        print("Training Ended...")
        self.plot_metrics()

    def optimize(self):
        """Perform one gradient step on a random batch from the replay memory.

        The regression target for every model output is the reward stored
        with the sampled transition.

        Returns:
            The batch loss as a float, or None when the memory holds fewer
            than ``batch_size`` transitions and no step was taken.
        """
        # if the memory is less than the batch size dont optimize
        if len(self.memory) < self.batch_size:
            return None
        # sample the transitions and regroup them field by field
        transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))
        # each stored state is (1, 1, H, W); concatenating gives (batch, 1, H, W)
        state_batch = torch.cat(batch.state)
        reward_batch = torch.cat(batch.reward)
        model_output = self.model(state_batch)
        # broadcast each reward across all outputs, whatever the output width is
        expected_values = reward_batch.to(model_output.dtype).unsqueeze(1).expand_as(model_output)
        # mean squared error, identical to nn.MSELoss() with default reduction
        loss = F.mse_loss(model_output, expected_values)
        self.optimizer.zero_grad()
        loss.backward()
        # clip the gradients
        torch.nn.utils.clip_grad_value_(self.model.parameters(), 100)
        self.optimizer.step()

        return loss.item()

    def sanity_check(self):
        """Print the model, the environment, the parameter count and the saved size.

        Writes the model weights to ``sanity_check_model.pth`` in the current
        directory in order to measure their size on disk.
        """
        print(f"model: {self.model}")
        print(f"env: {self.env}")
        n_params = sum(p.numel() for p in self.model.parameters())
        print(f"Number of parameters: {n_params}")
        torch.save(self.model.state_dict(), 'sanity_check_model.pth')
        model_size_mb = os.path.getsize('sanity_check_model.pth') / 1e6
        print(f"Model size: {model_size_mb} MB")

    def get_action(self, state):
        """Choose an action for ``state``.

        With probability ``epsilon`` the action is ``ACTION_DIM`` random
        coordinates in ``[0, grid_size)``; otherwise it is the model output
        for ``state``.

        Returns:
            A ``(1, ACTION_DIM)`` float tensor for the random branch, or the
            model's output tensor for the greedy branch.
        """
        if random.random() < self.epsilon:
            return torch.randint(
                low=0,
                high=self.env.grid_size,
                size=(1, self.ACTION_DIM),
                dtype=torch.float32,
                device=self.env.device,
            )
        # NOTE(review): the raw network outputs are used directly as grid
        # coordinates here - confirm this is the intended action encoding.
        # no_grad: the action is stored in the replay memory and must not keep
        # an autograd graph (and its activations) alive.
        with torch.no_grad():
            return self.model(state)

    def save_model(self, path):
        """Save the model weights (state dict) to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        """Load model weights from ``path`` onto this trainer's device."""
        # map_location lets a checkpoint saved on another device load here
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def training_bar(self, iteration):
        """Print one progress tick every tenth iteration."""
        if iteration % 10 == 0:
            # flush so the tick appears immediately despite the missing newline
            print('|#', end='', flush=True)

    def plot_metrics(self):
        """Plot the recorded losses and the mean reward of each episode."""
        plt.figure(figsize=(12, 5))
        # left panel: one point per optimisation step
        plt.subplot(1, 2, 1)
        plt.plot(self.losses, label='Loss')
        plt.title('Loss per Optimization Step')
        plt.legend()
        # right panel: one point per episode
        plt.subplot(1, 2, 2)
        plt.plot(self.rewards, label='Mean Reward')
        plt.title('Mean Reward per Episode')
        plt.legend()
        plt.tight_layout()
        plt.show()
# Driver: build a trainer for 200 episodes, print the model/environment summary,
# run the full training loop, then write the trained weights to model.pth.
automata = Trainer(episodes=200)
# prints the model and environment and writes sanity_check_model.pth to measure its size
automata.sanity_check()
# train() also shows the metric plot once all episodes have finished
automata.train()
automata.save_model('model.pth')

model: DQN(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=1024, bias=True)
  (fc3): Linear(in_features=1024, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=200, bias=True)
)
env: <GameOfLifeEnv instance>
Number of parameters: 11194952
Model size: 44.786002 MB
Tr

In [None]:
# Idea: run 1000 episodes of the Game of Life for 100 timesteps each using only the original game rules (no actions),
# then run the model on the same starting grids and compare how well it does against that random-initialization baseline.