In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

In [50]:

class GameOfLifeNN(nn.Module):
    """ Implementation of minimal, Game of Life model from: "It’s Hard For Neural Networks to Learn the Game of Life"
    """
    def __init__(self, n: int, m: int):
        """ Constructor. 
        
        Args:
            n: Both the number of steps in the Game of Life for the model to simulation, but also
                linear w.r.t. the number of convolutional layers (well, 2n+1).
            m: "where m is the factor of overcompleteness". Describes how much additional model
                capacity to add to the model above theoretical minimum bound (m=1).
        """
        super().__init__()
        self.n = n
        self.m = m
        
        # n layers for each of n life steps in GOL.
        # m filters for overcompleteness affects input/output shape
        layers = [
            ((f"conv_{i}", nn.Conv2d(1, 2 * m, (3, 3), padding='same')), 
            (f"reduce_{i}", nn.Conv2d(2 * m, m, (1, 1), padding='same'))) for i in range(n) 
        ]
        layers = [l for nn in layers for l in nn]
        
        self.steps = nn.Sequential(OrderedDict(layers))
        self.reduce_final = nn.Conv2d(m, 1, (1, 1))
        

    def forward(self, x):
        for conv in self.steps:
            x = F.relu(conv(x))

        return torch.sigmoid(self.reduce_final(x))


### From paper
- Initialize the weights randomly from a unit normal distribution.
- Adam optimizer (α = 0.001, β1 = 0.9, β2 = 0.999)
- binary cross-entropy loss function on the output of the model.
- Each instance is trained with 1 million randomly generated training examples.
    - 100 epochs. 10,000 training examples per epoch
    - 32 x 32 binary input. Value is 1 with probability d, drawn uniformly from [0, 1]
- Batch size of 8.


In [24]:
class GameOfLife(nn.Module):
    """ Implementation of predesigned Game of Life model for a single iteration.
    """
    def __init__(self):
        """ Constructor. """
        super().__init__()

        # n layers for each of n life steps in GOL.
        # m filters for overcompleteness affects input/output shape
        layers = [
            ((f"conv_{i}", nn.Conv2d(1, 2 * m, (3, 3), padding='same')), 
            (f"reduce_{i}", nn.Conv2d(2 * m, m, (1, 1), padding='same'))) for i in range(n) 
        ]
        layers = [l for nn in layers for l in nn]
        
        self.steps = nn.Sequential(OrderedDict(layers))
        self.reduce_final = nn.Conv2d(m, 1, (1, 1))
       
    def neighbours(self, x):
        """Create a fixed Conv layer that counts direct (up/down/left/right) number of neighbours."""
        ws = torch.zeros(3, 3)
        ws[0][1] = 1.0
        ws[1][0] = 1.0
        ws[1][2] = 1.0
        ws[2][1] = 1.0
        return F.conv2d(x, ws.expand(1,1, 3,3), bias=None, stride=1, padding=1)

    def tensor_equals_v(self, x, v: float):
        """Pair-wise equality between tensor x and scalar value v."""
        return torch.mul(x, torch.eq(torch.Tensor([v]).expand(x.shape), x).float())

    def forward(self, x):
        nbhs = self.neighbours(x)

        alive_twos = self.tensor_equals_v(x, 1.0) * self.tensor_equals_v(nbhs, 2.0)
        
        threes = self.tensor_equals_v(nbhs, 3.0)

        x = alive_twos + threes
        return (x > 0).float()
             

tensor([[[1., 0., 0., 0., 1.],
         [1., 0., 1., 1., 0.],
         [0., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1.],
         [0., 1., 0., 0., 0.]]])
tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])


In [53]:
def create_random_inputs(examples: int = 10000, grid_size: int = 32):
    """Create a set of random Game of Life inputs, where each example has an alive
    probability (for its whole board), uniformly sampled from [0, 1].
    
    
    Output shape: [examples, grid_size, grid_size]
    """
    d = torch.rand(examples, 1, 1, 1)
    x = torch.rand(examples, 1, grid_size, grid_size)
    return (x < d).float()

def run_single_GOL_step(examples: torch.Tensor) -> torch.Tensor:
    """"""
    # TODO: Implement correctly.
    shape = examples.shape
    return create_random_inputs(examples = shape[0], grid_size = shape[2])
    
def GOL_step(example: torch.Tensor) -> torch.Tensor:
    pass 


game_of_life = GameOfLifeNN(1, 1)
data = create_random_inputs(examples=4, grid_size=5)
target = game_of_life(data)
(target== target).reshape(4, -1).all(dim=-1)



tensor([[[[0.4370, 0.4370, 0.4309, 0.4351, 0.4289],
          [0.4370, 0.4370, 0.4235, 0.4478, 0.4410],
          [0.4610, 0.4239, 0.4453, 0.4309, 0.4351],
          [0.4309, 0.4351, 0.4289, 0.4235, 0.4340],
          [0.4235, 0.4340, 0.4430, 0.4370, 0.4370]]],


        [[[0.4309, 0.4472, 0.4269, 0.4379, 0.4453],
          [0.4235, 0.4279, 0.4258, 0.4270, 0.4289],
          [0.4370, 0.4235, 0.4221, 0.4308, 0.4430],
          [0.4370, 0.4610, 0.4543, 0.4631, 0.4379],
          [0.4370, 0.4309, 0.4290, 0.4221, 0.4270]]],


        [[[0.4370, 0.4370, 0.4370, 0.4370, 0.4370],
          [0.4370, 0.4370, 0.4370, 0.4370, 0.4370],
          [0.4370, 0.4370, 0.4610, 0.4543, 0.4379],
          [0.4370, 0.4370, 0.4309, 0.4290, 0.4270],
          [0.4370, 0.4370, 0.4235, 0.4221, 0.4308]]],


        [[[0.4290, 0.4270, 0.4228, 0.4351, 0.4289],
          [0.4221, 0.4308, 0.4497, 0.4411, 0.4518],
          [0.4370, 0.4610, 0.4262, 0.4372, 0.4270],
          [0.4370, 0.4572, 0.4221, 0.4221, 0.4308],


tensor([True, True, True, True])

In [59]:
game_of_life = GameOfLife()
n=1
m=2
model = GameOfLifeNN(n, m)
loss_fn = nn.BCELoss()
batch_size = 8
epochs = 100
optim = torch.optim.Adam(model.parameters())


def train(epoch: int, model: nn.Module):
    print(f"Epoch: {epoch}")
    # Create data in bulk, split into batches
    data = create_random_inputs()
    target = game_of_life(data)

    # target = run_single_GOL_step(data)
    xx = torch.split(data, batch_size)
    yy = torch.split(target, batch_size)

    model.train()

    for batch_idx, (x, y) in enumerate(zip(xx, yy)):
        optim.zero_grad()

        output = model(x)
        loss = loss_fn(output, y)

        loss.backward()
        optim.step()

        if batch_idx % 1000 ==0:
            eval_data = create_random_inputs()

            eval_target = game_of_life(eval_data)
            eval_y = model(eval_data)

            correct_y = ((output >= 0.5).float() == y).reshape(y.shape[0], -1).all(dim=-1)
            correct_percent = torch.mean(correct_y.float())
    
            correct_eval_y = (eval_target == (eval_y >0.5).float()).reshape(eval_y.shape[0], -1).all(dim=-1)
            correct_eval_percent = torch.mean(correct_eval_y.float())
            eval_loss = loss_fn(eval_target, eval_y)

            print(f'\t batch: {batch_idx} \tLoss: {loss.data:.6f}')
            print(f'\t batch: {batch_idx} \t Eval Loss: {eval_loss.data:.6f}')
            print(f'\t batch: {batch_idx} \t Eval percent correct: {correct_percent:.6f}')
            print(f'\t batch: {batch_idx} \t Eval percent correct: {correct_eval_percent:.6f}')

for i in range(4):
    train(i, model)

Epoch: 0
	 batch: 0 	Loss: 0.651987
	 batch: 0 	 Eval Loss: 46.561619
	 batch: 0 	 Eval percent correct: 0.000000
	 batch: 0 	 Eval percent correct: 0.041000
	 batch: 1000 	Loss: 0.451410
	 batch: 1000 	 Eval Loss: 25.891346
	 batch: 1000 	 Eval percent correct: 0.000000
	 batch: 1000 	 Eval percent correct: 0.048000
Epoch: 1
	 batch: 0 	Loss: 0.386452
	 batch: 0 	 Eval Loss: 24.829203
	 batch: 0 	 Eval percent correct: 0.000000
	 batch: 0 	 Eval percent correct: 0.171700
	 batch: 1000 	Loss: 0.386394
	 batch: 1000 	 Eval Loss: 23.068624
	 batch: 1000 	 Eval percent correct: 0.000000
	 batch: 1000 	 Eval percent correct: 0.171400
Epoch: 2
	 batch: 0 	Loss: 0.172227
	 batch: 0 	 Eval Loss: 22.697010
	 batch: 0 	 Eval percent correct: 0.375000
	 batch: 0 	 Eval percent correct: 0.173600
	 batch: 1000 	Loss: 0.243832
	 batch: 1000 	 Eval Loss: 22.890638
	 batch: 1000 	 Eval percent correct: 0.250000
	 batch: 1000 	 Eval percent correct: 0.162100
Epoch: 3
	 batch: 0 	Loss: 0.431553
	 batch