# Experiment: Random walk CPM 
The goal of this experiment is to apply gradient-based learning to a simplified Cellular Potts Model (CPM) that simulates a random walk.

Import modules

In [1]:
import torch as t
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../")
from periodic_padding import periodic_padding
from hamiltonian_diff import model
from cell_typing import CellKind, CellMap

## 1. Splitting the grid into subdomains for a parallel checkerboard update scheme

In [28]:
def split_grids(batch: t.Tensor, kernel_width: int, kernel_height: int) -> t.Tensor:
    """Cut a batch of grids into flattened sliding windows.

    Each grid is padded with ``periodic_padding`` and cast to float32. Every
    ``kernel_height`` x ``kernel_width`` window is then extracted with a
    stride of 1 and flattened into a vector. The outputs shown in this
    notebook are consistent with a one-cell wrap-around border, so a window
    of size (subgrid + 2) holds a subgrid plus its Moore neighbourhood.
    NOTE(review): confirm the padding width in ``periodic_padding``.

    Args:
        batch: stack of grids. A channel axis is inserted at dim 1 before
            unfolding, so this presumably expects (batch, height, width).
            TODO confirm.
        kernel_width: number of columns in each window.
        kernel_height: number of rows in each window.

    Returns:
        Tensor of shape (batch, num_windows, kernel_height * kernel_width).
        Window positions are enumerated row by row over the padded grid.
    """
    # Unfold needs a floating dtype and an explicit (single) channel axis.
    padded = periodic_padding(batch).float().unsqueeze(1)
    # Result is (batch, kernel_height * kernel_width, num_windows) ...
    windows = F.unfold(padded, kernel_size=(kernel_height, kernel_width))
    # ... so swap the last two axes to get one row per window.
    return windows.transpose(1, 2)

The goal is to split the grid into smaller subdomains of a given size. For example, we may want to create subgrids of size 2x2. To update these subdomains, we also need the Moore neighborhood around them.

To achieve that, we apply periodic padding and then use an unfold layer (`torch.nn.Unfold`), which extracts sliding blocks from the padded grid (the same extraction step a convolution performs internally). Using a kernel size of $k=(4,4)$ (subdomain dimensions + 2), we get convolutional blocks that each contain a 2x2 subgrid plus its Moore neighborhood, unrolled into vectors (one vector per convolutional block).

In [40]:
# Two 4x4 grids filled with the consecutive integers 0..31.
test = t.arange(32).reshape(2, 4, 4)
print(test)

# 4x4 windows: a 2x2 subgrid plus a one-cell border around it.
split = split_grids(test, kernel_height=4, kernel_width=4)
print(split.shape)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]])
torch.Size([2, 9, 16])


We can also use different kernel sizes for the width and the height. For example, we may want to consider only 1x2 subgrids (1 row, 2 columns), which corresponds to a kernel height of 3 and a kernel width of 4:

In [42]:
# The same two 4x4 grids as before.
test = t.arange(32).reshape(2, 4, 4)
print(test)

# 3x4 windows: a 1x2 subgrid (1 row, 2 columns) plus a one-cell border.
split = split_grids(test, kernel_height=3, kernel_width=4)
print(split.shape)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]])
torch.Size([2, 12, 12])


To apply checkerboarding, we want to choose a set of subgrids that are separated from each other by at least 2 columns/rows:

In [43]:
# A single 8x8 grid holding the floats 0.0..63.0.
test = t.arange(64.0).reshape(1, 8, 8)
print(test)

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11., 12., 13., 14., 15.],
         [16., 17., 18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29., 30., 31.],
         [32., 33., 34., 35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44., 45., 46., 47.],
         [48., 49., 50., 51., 52., 53., 54., 55.],
         [56., 57., 58., 59., 60., 61., 62., 63.]]])


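On this 8x8 grid, we want to consider a split into 16 2x2 subdomains. Note that the unfold layer uses a stride of 1, so it returns every 4x4 window of the padded 10x10 grid. That gives 7x7 = 49 overlapping blocks, and the 16 subdomains we are interested in are a subset of these: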

In [55]:
split = split_grids(test, kernel_height=4, kernel_width=4)
print(f"split shape: {split.shape}")
# Fold the first flattened window back into its 4x4 layout to inspect it.
print(split[0][0].reshape(4, 4))

split shape: torch.Size([1, 49, 16])
tensor([[63., 56., 57., 58.],
        [ 7.,  0.,  1.,  2.],
        [15.,  8.,  9., 10.],
        [23., 16., 17., 18.]])


In the cell above, we see that when we reshape the vector into a 2D block, we indeed get the top-left 2x2 subgrid (0, 1, 8, 9) with its Moore neighborhood.

The third convolutional block (index 2), on the other hand, is shifted two columns to the right and therefore contains the next 2x2 subgrid (2, 3, 10, 11):

In [47]:
# Third window (index 2), folded back into its 4x4 layout.
print(split[0][2].reshape(4, 4))

tensor([[57., 58., 59., 60.],
        [ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.],
        [17., 18., 19., 20.]])


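We can create 4 mutually exclusive sets of subgrids whose members are spaced at least two columns/rows apart. The blocks are enumerated row by row over the 7x7 window positions, so the block at window position (row r, column c) has index 7r + c: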

In [60]:
# Window positions (0, 0), (0, 4), (4, 0), (4, 4) on the 7x7 grid of windows.
subset1 = split[0][[0, 4, 28, 32]]
print(subset1.reshape(4, 4, 4))

tensor([[[63., 56., 57., 58.],
         [ 7.,  0.,  1.,  2.],
         [15.,  8.,  9., 10.],
         [23., 16., 17., 18.]],

        [[59., 60., 61., 62.],
         [ 3.,  4.,  5.,  6.],
         [11., 12., 13., 14.],
         [19., 20., 21., 22.]],

        [[31., 24., 25., 26.],
         [39., 32., 33., 34.],
         [47., 40., 41., 42.],
         [55., 48., 49., 50.]],

        [[27., 28., 29., 30.],
         [35., 36., 37., 38.],
         [43., 44., 45., 46.],
         [51., 52., 53., 54.]]])


In [65]:
# Window positions (0, 2), (0, 6), (4, 2), (4, 6) on the 7x7 grid of windows.
subset2 = split[0][[2, 6, 30, 34]]
print(subset2.reshape(4, 4, 4))

tensor([[[57., 58., 59., 60.],
         [ 1.,  2.,  3.,  4.],
         [ 9., 10., 11., 12.],
         [17., 18., 19., 20.]],

        [[61., 62., 63., 56.],
         [ 5.,  6.,  7.,  0.],
         [13., 14., 15.,  8.],
         [21., 22., 23., 16.]],

        [[25., 26., 27., 28.],
         [33., 34., 35., 36.],
         [41., 42., 43., 44.],
         [49., 50., 51., 52.]],

        [[29., 30., 31., 24.],
         [37., 38., 39., 32.],
         [45., 46., 47., 40.],
         [53., 54., 55., 48.]]])


In [71]:
# Window positions (2, 2), (2, 6), (6, 2), (6, 6) on the 7x7 grid of windows.
subset3 = split[0][[16, 20, 44, 48]]
print(subset3.reshape(4, 4, 4))

tensor([[[ 9., 10., 11., 12.],
         [17., 18., 19., 20.],
         [25., 26., 27., 28.],
         [33., 34., 35., 36.]],

        [[13., 14., 15.,  8.],
         [21., 22., 23., 16.],
         [29., 30., 31., 24.],
         [37., 38., 39., 32.]],

        [[41., 42., 43., 44.],
         [49., 50., 51., 52.],
         [57., 58., 59., 60.],
         [ 1.,  2.,  3.,  4.]],

        [[45., 46., 47., 40.],
         [53., 54., 55., 48.],
         [61., 62., 63., 56.],
         [ 5.,  6.,  7.,  0.]]])
