In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

In [None]:

class GameOfLifeNN(nn.Module):
    """Minimal Game of Life model from "It's Hard For Neural Networks to Learn the Game of Life".

    The network stacks ``n`` blocks, each made of a 3x3 convolution with
    ``2 * m`` filters followed by a 1x1 convolution with ``m`` filters (both
    followed by a ReLU), and finishes with a 1x1 convolution down to a single
    channel passed through a sigmoid.
    """

    def __init__(self, n: int, m: int):
        """Build the convolutional stack.

        Args:
            n: Number of Game of Life steps the model should simulate. The
                number of convolutional layers grows linearly with it
                (``2n + 1`` in total).
            m: "where m is the factor of overcompleteness". Describes how much
                additional model capacity to add above the theoretical minimum
                (``m=1``).
        """
        super().__init__()
        self.n = n
        self.m = m

        # One (3x3 conv, 1x1 reduce) pair per simulated step. Only the first
        # pair sees the single-channel board; every later pair consumes the
        # m channels produced by the previous reduce layer.
        layers = []
        channels = 1
        for i in range(n):
            layers.append(
                (f"conv_{i}", nn.Conv2d(channels, 2 * m, (3, 3), padding='same'))
            )
            layers.append(
                (f"reduce_{i}", nn.Conv2d(2 * m, m, (1, 1), padding='same'))
            )
            channels = m

        self.steps = nn.Sequential(OrderedDict(layers))
        # `channels` is still 1 when n == 0, so the final projection always
        # matches whatever the last layer actually outputs.
        self.reduce_final = nn.Conv2d(channels, 1, (1, 1))

    def forward(self, x):
        """Run the model on a batch of boards.

        Args:
            x: Tensor with a single channel in dimension 1, i.e.
                ``[batch, 1, height, width]``.

        Returns:
            Tensor of the same shape as ``x`` holding per-cell values in
            ``[0, 1]`` (sigmoid output).
        """
        # ReLU after every layer of every step, including the 1x1 reductions.
        for conv in self.steps:
            x = F.relu(conv(x))

        return torch.sigmoid(self.reduce_final(x))


### From paper
- Initialize the weights randomly from a unit normal distribution.
- Adam optimizer (α = 0.001, β1 = 0.9, β2 = 0.999)
- Binary cross-entropy loss function applied to the output of the model.
- Each instance is trained with 1 million randomly generated training examples.
    - 100 epochs. 10,000 training examples per epoch
    - 32 x 32 binary input. Each cell is 1 with probability d, where d is drawn uniformly from [0, 1] once per example.
- Batch size of 8.


In [None]:
# Model size: one simulated step, no overcompleteness.
n, m = 1, 1

# Training schedule.
batch_size, epochs = 8, 100

model = GameOfLifeNN(n, m)
loss_fn = nn.BCELoss()
# Adam hyper-parameters written out explicitly; these are the library defaults.
optim = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))


In [None]:
def create_random_inputs(examples: int = 10000, grid_size: int = 32):
    """Sample random binary Game of Life boards.

    Every board gets its own alive probability, drawn uniformly from [0, 1];
    each cell of that board is then set to 1 with that probability.

    Args:
        examples: Number of boards to generate.
        grid_size: Side length of each (square) board.

    Returns:
        Float tensor of zeros and ones with shape
        ``[examples, 1, grid_size, grid_size]``.
    """
    board_shape = (examples, 1, grid_size, grid_size)
    # One probability per board, broadcast over that board's cells.
    alive_prob = torch.rand(examples, 1, 1, 1)
    noise = torch.rand(board_shape)
    return torch.lt(noise, alive_prob).to(torch.float32)

def run_single_GOL_step(examples: torch.Tensor) -> torch.Tensor:
    """Advance a batch of Game of Life boards by one generation.

    Cells outside the grid are treated as dead (zero padding). Input values
    greater than 0.5 count as alive.

    Args:
        examples: Tensor of shape ``[batch, 1, height, width]``.

    Returns:
        Float tensor of zeros and ones with the same shape as ``examples``
        holding the next generation of every board.

    Raises:
        ValueError: If ``examples`` is not a 4-D tensor with one channel.
    """
    if examples.dim() != 4 or examples.shape[1] != 1:
        raise ValueError(
            "expected a tensor of shape [batch, 1, height, width], "
            f"got {tuple(examples.shape)}"
        )

    alive = examples > 0.5
    boards = alive.to(torch.float32)

    # 3x3 kernel of ones with a zero centre: the convolution counts the eight
    # neighbours of every cell without counting the cell itself.
    kernel = torch.ones(1, 1, 3, 3, dtype=torch.float32, device=examples.device)
    kernel[0, 0, 1, 1] = 0.0
    neighbours = F.conv2d(boards, kernel, padding=1)

    # Standard rules: a cell is alive next step if it has exactly three live
    # neighbours, or if it is alive now and has exactly two.
    born = neighbours == 3
    survives = alive & (neighbours == 2)
    return (born | survives).to(torch.float32)


def GOL_step(example: torch.Tensor) -> torch.Tensor:
    """Advance a single Game of Life board by one generation.

    Args:
        example: One board, shaped either ``[height, width]`` or
            ``[1, height, width]``.

    Returns:
        Float tensor of zeros and ones with the same shape as ``example``.

    Raises:
        ValueError: If ``example`` has any other shape.
    """
    if example.dim() == 2:
        batched = example[None, None]
    elif example.dim() == 3 and example.shape[0] == 1:
        batched = example[None]
    else:
        raise ValueError(
            "expected a board of shape [height, width] or [1, height, width], "
            f"got {tuple(example.shape)}"
        )
    return run_single_GOL_step(batched).reshape(example.shape)

In [None]:

def train(epoch: int, model: nn.Module):
    """Train ``model`` for one epoch on freshly generated boards.

    Uses the module-level ``batch_size``, ``optim`` and ``loss_fn``. The loss
    is printed for every 100th batch.

    Args:
        epoch: Epoch index, used only for the progress header.
        model: Network to optimise; it is switched to training mode.
    """
    print(f"Epoch: {epoch}")

    # Generate the whole epoch's data at once, then slice it into mini-batches.
    inputs = create_random_inputs()
    targets = run_single_GOL_step(inputs)
    batches = zip(inputs.split(batch_size), targets.split(batch_size))

    model.train()

    for step, (batch_x, batch_y) in enumerate(batches):
        optim.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()
        optim.step()

        if step % 100 != 0:
            continue
        print(f'\t batch: {step} \tLoss: {loss.item():.6f}')


# Run a handful of training epochs on the model built above.
for i in range(4):
    train(epoch=i, model=model)