In [53]:
# DQN model
import random
import torch
import gym
from gym import spaces
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


class GameOfLifeEnv(gym.Env):
    def __init__(self, grid_size=1024):
        super(GameOfLifeEnv, self).__init__()
        self.grid_size = grid_size
        #initialize the grid to be a boolean matrix of zeros of the grid size
        #self.grid = torch.(self.grid_size, self.grid_size, dtype=torch.float32)
        #for testing starts with a rand initialization
        self.grid = torch.randint(low=0, high=2, size=(grid_size,grid_size), dtype=torch.float32)

    def step(self, action):
        #action is used to update the grid before the step
        self.update_grid(action)
        kernel = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        result = F.conv2d(self.grid.unsqueeze(0).unsqueeze(0), kernel, padding=1)
        output = ((result == 3) | ((result == 2) & (self.grid.unsqueeze(0).unsqueeze(0) == 1))).float()
        next_state = output.squeeze()
        reward = self.calc_reward(next_state)
        self.grid = next_state
        return next_state, reward

    def reset(self):
        self.grid = torch.zeros(self.grid_size, self.grid_size, dtype=torch.float32)
        return self.grid

    def render(self, mode='human'):
        plt.imshow(self.grid, cmap='gray_r')
        plt.colorbar()
        plt.show()

    def update_grid(self, actions):
        for i in range(0, actions.size(1), 2):
            x = int(actions[0][i])
            y = int(actions[0][i+1])
            self.grid[x][y] = 1


    def calc_reward(self, next_state):
        state_count = self.count_living(self.grid)
        next_state_count = self.count_living(next_state)
        return next_state_count / state_count

    def count_living(self, grid):
        count = 0
        for i in range(grid.size(0)):
            for j in range(grid.size(1)):
                if(grid[i][j] == 1):
                    count += 1
        return count

class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)

        self.fc1 = nn.Linear(128*64*64, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 200)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    

class AutomataSolver:
    def __init__(self, n_episodes=1000, gamma=1.0, epsilon=1.0, epsilon_min=0.01, alpha=0.01, alpha_decay=0.01, batch_size=64, n_steps=1000):
        self.n_episodes = n_episodes
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.batch_size = batch_size
        self.env = GameOfLifeEnv()
        self.model = DQN()
        self.memory = []
        self.memory_capacity = 10000
        self.memory_position = 0
        self.opt = torch.optim.Adam(self.model.parameters(), lr=alpha)
        self.criterion = nn.MSELoss()
        self.n_steps = n_steps

    def train(self):
        for episode in range(self.n_episodes):
            state = self.env.reset().unsqueeze(0)
            for step in range(self.n_steps):
                action = self.get_action(state)
                next_state, reward = self.env.step(action)
                next_state = next_state.unsqueeze(0) 
                self.remember(state, next_state, reward)
                
                if len(self.memory) >= self.batch_size:
                    states, next_states, rewards = self.sample(self.batch_size)
                    

                    current_q_values = self.model(states).gather(1, action.unsqueeze(-1)).squeeze(-1)
                    max_next_q_values = self.model(next_states).detach().max(1)[0]
                    expected_q_values = rewards + (self.gamma * max_next_q_values)
                    
                    # Compute loss
                    loss = self.criterion(current_q_values, expected_q_values)
                    
                    # Gradient descent
                    self.opt.zero_grad()
                    loss.backward()
                    self.opt.step()
                
                state = next_state
                
                # Epsilon decay
                if self.epsilon > self.epsilon_min:
                    self.epsilon -= (self.epsilon - self.epsilon_min) * self.alpha_decay

                
    def test(self):
        for episode in range(self.n_episodes):
            state = self.env.reset().unsqueeze(0)
            for step in range(self.n_steps):
                action = self.get_action(state)
                print(f'actionshape: {action.shape}')
                next_state, reward = self.env.step(action)
                next_state = next_state.unsqueeze(0) 
                self.remember(state, next_state, reward)
                
                if len(self.memory) >= self.batch_size:
                    states, next_states, rewards = self.sample(self.batch_size)
                    

                    current_q_values = self.model(states).gather(1, action.unsqueeze(-1)).squeeze(-1)
                    max_next_q_values = self.model(next_states).detach().max(1)[0]
                    expected_q_values = rewards + (self.gamma * max_next_q_values)
                    
                    # Compute loss
                    loss = self.criterion(current_q_values, expected_q_values)
                    
                    # Gradient descent
                    self.opt.zero_grad()
                    loss.backward()
                    self.opt.step()
                
                state = next_state
                
                # Epsilon decay
                if self.epsilon > self.epsilon_min:
                    self.epsilon -= (self.epsilon - self.epsilon_min) * self.alpha_decay

    def remember(self, state, next_state, reward):
        if len(self.memory) < self.memory_capacity: 
            self.memory.append(None)
        reward_tensor = torch.tensor([reward], dtype=torch.float32)
        self.memory[self.memory_position] = (state, next_state, reward_tensor)
        self.memory_position = (self.memory_position + 1) % self.memory_capacity

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        state, next_state, reward = zip(*batch)
        return torch.stack(state), torch.stack(next_state), torch.stack(reward) 

    def get_action(self, state):
        state_tensor = torch.FloatTensor(state).unsqueeze(0) 
        with torch.no_grad():
            q_values = self.model(state_tensor)
        return q_values
        


In [54]:
automata = AutomataSolver()
automata.test()

actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: torch.Size([1, 200])
actionshape: t

RuntimeError: gather(): Expected dtype int64 for index

In [None]:
#idea run 1000 episodes of the game of life over 100 timesteps based on the original game rules 
# then run the model on the same game of life and see how well it does compared to the random initialization