# Experiment: Random walk CPM 
The goal of this experiment is to apply gradient-based learning to a simplified Cellular Potts Model (CPM) that simulates a random walk.

Import modules:

In [2]:
import torch as t
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../../")
from periodic_padding import periodic_padding
from hamiltonian_diff import model
from cell_typing import CellKind, CellMap

## 1. Splitting the grid into subdomains for parallel checkerboarding update scheme

In [3]:
def split_grids(batch:t.Tensor, kernel_width: int, kernel_height: int) -> t.Tensor:
    """Cut a periodically padded batch of grids into overlapping, flattened windows.

    The batch is padded with ``periodic_padding`` (presumably one wrap-around
    cell per side: an 8x8 grid yields 7x7 = 49 windows of size 4x4). Every
    ``kernel_height x kernel_width`` window at stride 1 then becomes one row
    of the result.

    Args:
        batch: stack of 2D grids, shape (batch, height, width).
        kernel_width: width of each extracted window.
        kernel_height: height of each extracted window.

    Returns:
        Float tensor of shape (batch, n_windows, kernel_height * kernel_width).
    """
    # unfold wants an explicit channel axis: (batch, 1, H_pad, W_pad)
    padded = periodic_padding(batch).float()[:, None]
    windows = F.unfold(padded, kernel_size=(kernel_height, kernel_width))
    # (batch, kh*kw, n_windows) -> (batch, n_windows, kh*kw)
    return windows.transpose(1, 2)

The goal is to split the grid into smaller subdomains of a given size, e.g. subgrids of size 2x2. To update these subdomains, we also need the Moore neighborhood around them.

To achieve that, we can apply periodic padding and then use an unfold layer (`torch.nn.Unfold`), which extracts sliding, convolution-style blocks. Using a kernel size of $k=(4,4)$ (subdomain dims + 2), we get convolutional blocks containing the desired 2x2 subgrids plus their Moore neighborhood, each unrolled into a vector (one vector per convolutional block).

In [4]:
test = t.arange(32).reshape(2, 4, 4)  # two 4x4 grids holding 0..31
print(test)

# 4x4 windows = 2x2 subgrid + one-cell Moore ring on every side
split = split_grids(test, kernel_height=4, kernel_width=4)
print(split.size())

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]])
torch.Size([2, 9, 16])


We can also use different kernel sizes for the width and height. E.g., we may want to consider only 1x2 subgrids (1 row, 2 columns), which requires a kernel of height 3 and width 4:

In [5]:
test = t.arange(32).reshape(2, 4, 4)  # two 4x4 grids holding 0..31
print(test)

# 3x4 windows = 1x2 subgrid (1 row, 2 columns) + Moore ring
split = split_grids(test, kernel_height=3, kernel_width=4)
print(split.size())

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]])
torch.Size([2, 12, 12])


To apply checkerboarding, we want to choose sets of subgrids that are separated by at least 2 columns/rows:

In [6]:
test = t.arange(64.).view(1, 8, 8)  # one 8x8 float grid holding 0..63
print(test)

tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11., 12., 13., 14., 15.],
         [16., 17., 18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29., 30., 31.],
         [32., 33., 34., 35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44., 45., 46., 47.],
         [48., 49., 50., 51., 52., 53., 54., 55.],
         [56., 57., 58., 59., 60., 61., 62., 63.]]])


On this 8x8 grid, we want to consider a split into 16 2x2 subdomains:

In [7]:
split = split_grids(test, kernel_height=4, kernel_width=4)
print(f"split shape: {split.shape}")
# Fold the first flattened window back into its 4x4 layout for inspection.
print(split[0][0].reshape(4, 4))

split shape: torch.Size([1, 49, 16])
tensor([[63., 56., 57., 58.],
        [ 7.,  0.,  1.,  2.],
        [15.,  8.,  9., 10.],
        [23., 16., 17., 18.]])


In the cell above, we see that when we reshape the vector into a 2D block, we indeed get the top-left 2x2 subgrid with its Moore neighborhood.

The 3rd convolutional block (index 2), on the other hand, corresponds to the next 2x2 subgrid to the right:

In [8]:
print(split[0][2].reshape(4, 4))  # window index 2: next 2x2 subgrid to the right

tensor([[57., 58., 59., 60.],
        [ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.],
        [17., 18., 19., 20.]])


We can create 4 mutually exclusive sets of subgrids, where the subgrids within each set are spaced at least two columns/rows apart:

In [9]:
subset1 = split[0][[0, 4, 28, 32]]  # tiles at (even row, even col)
print(subset1.reshape(-1, 4, 4))

tensor([[[63., 56., 57., 58.],
         [ 7.,  0.,  1.,  2.],
         [15.,  8.,  9., 10.],
         [23., 16., 17., 18.]],

        [[59., 60., 61., 62.],
         [ 3.,  4.,  5.,  6.],
         [11., 12., 13., 14.],
         [19., 20., 21., 22.]],

        [[31., 24., 25., 26.],
         [39., 32., 33., 34.],
         [47., 40., 41., 42.],
         [55., 48., 49., 50.]],

        [[27., 28., 29., 30.],
         [35., 36., 37., 38.],
         [43., 44., 45., 46.],
         [51., 52., 53., 54.]]])


In [10]:
subset2 = split[0][[2, 6, 30, 34]]  # tiles at (even row, odd col)
print(subset2.reshape(-1, 4, 4))

tensor([[[57., 58., 59., 60.],
         [ 1.,  2.,  3.,  4.],
         [ 9., 10., 11., 12.],
         [17., 18., 19., 20.]],

        [[61., 62., 63., 56.],
         [ 5.,  6.,  7.,  0.],
         [13., 14., 15.,  8.],
         [21., 22., 23., 16.]],

        [[25., 26., 27., 28.],
         [33., 34., 35., 36.],
         [41., 42., 43., 44.],
         [49., 50., 51., 52.]],

        [[29., 30., 31., 24.],
         [37., 38., 39., 32.],
         [45., 46., 47., 40.],
         [53., 54., 55., 48.]]])


In [11]:
subset3 = split[0][[16, 20, 44, 48]]  # tiles at (odd row, odd col)
print(subset3.reshape(-1, 4, 4))

tensor([[[ 9., 10., 11., 12.],
         [17., 18., 19., 20.],
         [25., 26., 27., 28.],
         [33., 34., 35., 36.]],

        [[13., 14., 15.,  8.],
         [21., 22., 23., 16.],
         [29., 30., 31., 24.],
         [37., 38., 39., 32.]],

        [[41., 42., 43., 44.],
         [49., 50., 51., 52.],
         [57., 58., 59., 60.],
         [ 1.,  2.,  3.,  4.]],

        [[45., 46., 47., 40.],
         [53., 54., 55., 48.],
         [61., 62., 63., 56.],
         [ 5.,  6.,  7.,  0.]]])


In [12]:
subset4 = split[0][[14, 18, 42, 46]]  # tiles at (odd row, even col)
print(subset4.reshape(-1, 4, 4))

tensor([[[15.,  8.,  9., 10.],
         [23., 16., 17., 18.],
         [31., 24., 25., 26.],
         [39., 32., 33., 34.]],

        [[11., 12., 13., 14.],
         [19., 20., 21., 22.],
         [27., 28., 29., 30.],
         [35., 36., 37., 38.]],

        [[47., 40., 41., 42.],
         [55., 48., 49., 50.],
         [63., 56., 57., 58.],
         [ 7.,  0.,  1.,  2.]],

        [[43., 44., 45., 46.],
         [51., 52., 53., 54.],
         [59., 60., 61., 62.],
         [ 3.,  4.,  5.,  6.]]])


In [13]:
# One (set, tile, 4, 4) tensor holding all four checkerboard sets.
grid_subsets = t.stack(
    [subset.reshape(4, 4, 4) for subset in (subset1, subset2, subset3, subset4)]
)
grid_subsets.shape

torch.Size([4, 4, 4, 4])

## Simple update probability function

In [413]:
def p_update(subdomains, cur_vol, target_vol, temperature, src_coords, tgt_coords, max_vol=2):
    """Acceptance probability for a set of simultaneous copy attempts.

    Tile ``i`` proposes copying the value at ``(src_y[i], src_x[i])`` onto
    ``(tgt_y[i], tgt_x[i])``. On a 0/1 grid that changes the cell volume by
    ``src - tgt``; the changes of all tiles are summed.

    Args:
        subdomains: tiles, shape (n_tiles, height, width).
        cur_vol: current total cell volume.
        target_vol: preferred cell volume.
        temperature: noise level. For ``temperature <= 0`` the zero-temperature
            limit of ``exp(-d**2 / T)`` is used: 1 if the new volume equals
            ``target_vol``, otherwise 0.
        src_coords: ``(src_x, src_y)`` index tensors of length n_tiles.
        tgt_coords: ``(tgt_x, tgt_y)`` index tensors of length n_tiles.
        max_vol: largest admissible volume after the update (default 2).

    Returns:
        0-d tensor ``exp(-(target_vol - new_vol)**2 / temperature)`` when
        ``0 < new_vol <= max_vol``, otherwise 0.
    """
    tiles = t.arange(subdomains.shape[0])
    src_x, src_y = src_coords
    tgt_x, tgt_y = tgt_coords
    # The copy overwrites the target value with the source value.
    vol_change = t.sum(subdomains[tiles, src_y, src_x] - subdomains[tiles, tgt_y, tgt_x])
    adjusted_vol = cur_vol + vol_change

    probability = t.tensor(0.)
    if 0 < adjusted_vol <= max_vol:
        if temperature > 0:
            probability += t.exp(-(target_vol - adjusted_vol) ** 2 / temperature)
        elif adjusted_vol == target_vol:
            # T -> 0 limit; dividing by T = 0 here would give exp(nan) = nan
            # and the move could never be accepted.
            probability += 1.
    return probability

## Conversion function for subdomain coords to batch coords

In [414]:
def get_grid_coords(grid_dim, tile_dim, tile_row, tile_col, tile_coords):
    """Map positions inside padded tiles to (y, x) positions on the periodic grid.

    Tile coordinates include the one-cell neighbourhood ring: index 0 is the
    row/column just before the tile's interior. Results are wrapped
    periodically onto the grid.

    Args:
        grid_dim: ``(grid_height, grid_width)`` of the full grid.
        tile_dim: ``(tile_height, tile_width)`` of a tile without its ring.
        tile_row: tile row index/indices (int or tensor).
        tile_col: tile column index/indices (int or tensor).
        tile_coords: ``(tile_y, tile_x)`` position(s) inside the padded tile.

    Returns:
        Tensor of shape (n, 2) with rows ``(grid_y, grid_x)``; n = 1 for scalar inputs.
    """
    tile_y, tile_x = tile_coords
    tile_height, tile_width = tile_dim
    grid_height, grid_width = grid_dim
    # "- 1" removes the ring offset; `%` wraps any out-of-range index
    # periodically (torch/Python remainder is non-negative for a positive divisor).
    grid_x = t.as_tensor(tile_col * tile_width + tile_x - 1) % grid_width
    grid_y = t.as_tensor(tile_row * tile_height + tile_y - 1) % grid_height
    return t.vstack((grid_y, grid_x)).T

In [415]:
test_grid = t.arange(64.).reshape(8, 8)
grid_dim = (8, 8)
tile_dim = (2, 2)

# (tile_row, tile_col, (tile_y, tile_x), expected (grid_y, grid_x) rows, expected grid values)
coord_cases = [
    (t.tensor(0), t.tensor(0), (t.tensor(1), t.tensor(1)), [(0, 0)], (0.,)),
    (t.tensor(0), t.tensor(0), (t.tensor(0), t.tensor(0)), [(7, 7)], (63.,)),
    (t.tensor(1), t.tensor(1), (t.tensor(1), t.tensor(1)), [(2, 2)], (18.,)),
    (t.tensor(1), t.tensor(1), (t.tensor(2), t.tensor(1)), [(3, 2)], (26.,)),
    (t.tensor((0, 1)), t.tensor((0, 1)), (t.tensor((0, 2)), t.tensor((0, 1))), [(7, 7), (3, 2)], (63, 26)),
]
for tile_row, tile_col, tile_coords, expected_coords, expected_values in coord_cases:
    coords = get_grid_coords(grid_dim, tile_dim, tile_row, tile_col, tile_coords)
    assert t.all(coords == t.tensor(expected_coords))
    assert t.all(test_grid[coords[:, 0], coords[:, 1]] == t.tensor(expected_values))

## Monte Carlo step scheme

In [416]:
def get_subdomains(grid):
    """Return all four checkerboard tile sets of ``grid``.

    Delegates to ``get_tile_batch`` so the tile/window index tables are
    defined in one place instead of being duplicated here.

    Returns:
        Tuple of four ``(tile_set, tile_rows, tile_cols)`` triples, one per
        set index 0..3.
    """
    return tuple(get_tile_batch(grid, idx) for idx in range(4))

In [417]:
def get_tile_batch(grid, idx):
    """Extract one of the four checkerboard sets of 2x2 tiles with their Moore ring.

    The grid is partitioned into 2x2 tiles indexed by (tile_row, tile_col).
    Set ``idx`` holds every tile whose (row parity, col parity) is
    0: (even, even), 1: (even, odd), 2: (odd, odd), 3: (odd, even).
    Tiles of one set are two tiles apart, so their 4x4 padded windows do not
    overlap. For an 8x8 grid this reproduces the window indices
    [0, 4, 28, 32], [2, 6, 30, 34], [16, 20, 44, 48] and [14, 18, 42, 46].

    Args:
        grid: tensor of shape (batch, H, W); only batch element 0 is used.
            H and W must be multiples of 4, otherwise the periodic wrap puts
            two tiles of the same set next to each other.
        idx: set index in 0..3.

    Returns:
        ``(tile_set, tile_rows, tile_cols)``: the padded tiles, shape
        (n_tiles, 4, 4), and each tile's row/column in row-major order.

    Raises:
        ValueError: if ``idx`` is not in 0..3 or H/W is not a multiple of 4.
    """
    tile_h, tile_w = 2, 2
    # (row parity, col parity) of the tiles in each checkerboard set
    set_parities = ((0, 0), (0, 1), (1, 1), (1, 0))
    if idx not in range(len(set_parities)):
        raise ValueError(f"idx must be in 0..3, got {idx}")
    grid_height, grid_width = grid.shape[-2:]
    if grid_height % (2 * tile_h) or grid_width % (2 * tile_w):
        raise ValueError(
            f"grid dims must be multiples of {2 * tile_h}x{2 * tile_w}, "
            f"got {grid_height}x{grid_width}"
        )

    subgrids = split_grids(grid, kernel_width=tile_w + 2, kernel_height=tile_h + 2)
    # Windows per padded row: (W + 2) - (tile_w + 2) + 1, assuming
    # periodic_padding adds one cell per side (8x8 -> 49 = 7x7 windows).
    windows_per_row = grid_width - tile_w + 1

    row_parity, col_parity = set_parities[idx]
    rows = t.arange(row_parity, grid_height // tile_h, 2)
    cols = t.arange(col_parity, grid_width // tile_w, 2)
    tile_rows = rows.repeat_interleave(len(cols))
    tile_cols = cols.repeat(len(rows))
    # A tile's padded window starts at padded position (row*tile_h, col*tile_w).
    window_idx = tile_rows * tile_h * windows_per_row + tile_cols * tile_w
    tile_set = subgrids[0, window_idx].reshape(-1, tile_h + 2, tile_w + 2)
    return tile_set, tile_rows, tile_cols

In [453]:
import random

# The eight Moore-neighbourhood offsets (dy, dx) a copy attempt can move by.
_MOORE_STEPS = [
    (1, 0),
    (1, 1),
    (0, 1),
    (-1, 1),
    (-1, 0),
    (-1, -1),
    (0, -1),
    (1, -1),
]


def _sample_moves(n_tiles, subgrid_height, subgrid_width):
    """Draw one copy attempt per padded tile; returns (src_x, src_y, tgt_x, tgt_y).

    src is uniform over the tile interior (indices 1..size-2, i.e. excluding the
    one-cell ring); tgt is src shifted by a random Moore step, so it may lie in the ring.
    """
    src_x = t.randint(1, subgrid_width - 1, (n_tiles,))
    src_y = t.randint(1, subgrid_height - 1, (n_tiles,))
    step_sizes = t.tensor(random.choices(_MOORE_STEPS, k=n_tiles))
    return src_x, src_y, src_x + step_sizes[:, 1], src_y + step_sizes[:, 0]


def MCS(batch, target_vol, temperature):
    """Run one Monte Carlo step of the simplified random-walk CPM.

    Visits the four checkerboard tile sets in turn. For each set, every tile
    proposes copying one interior pixel onto a Moore neighbour. The whole set's
    proposal is accepted with the probability returned by ``p_update`` and then
    written into ``batch``.

    Args:
        batch: grid tensor of shape (batch, H, W); only element 0 is written and
            it is modified in place. Presumably a binary 0/1 grid (cell = 1), as
            in the simulation cells — p_update's volume bookkeeping relies on it.
        target_vol: preferred cell volume, forwarded to ``p_update``.
        temperature: noise level, forwarded to ``p_update``.

    Returns:
        ``(batch, steps)``: the same (updated) tensor and a list of squeezed
        numpy snapshots, one per accepted set update.
    """
    _, batch_height, batch_width = batch.shape
    grid_dim = (batch_height, batch_width)
    steps = []
    for domain_idx in range(4):
        # Volume of the current state; earlier sets in this sweep may have changed it.
        cur_vol = batch.sum()
        subgrids, tile_rows, tile_cols = get_tile_batch(batch, domain_idx)
        tiles_per_batch, subgrid_height, subgrid_width = subgrids.shape
        src_x, src_y, tgt_x, tgt_y = _sample_moves(tiles_per_batch, subgrid_height, subgrid_width)

        update_probability = p_update(
            subgrids, cur_vol, target_vol, temperature, (src_x, src_y), (tgt_x, tgt_y)
        )
        # One acceptance draw for the whole set: all of its copies are applied or none.
        if update_probability > t.rand_like(update_probability):
            tile_dim = (subgrid_height - 2, subgrid_width - 2)  # strip the 1-cell ring
            batch_src = get_grid_coords(grid_dim, tile_dim, tile_rows, tile_cols, (src_y, src_x))
            batch_tgt = get_grid_coords(grid_dim, tile_dim, tile_rows, tile_cols, (tgt_y, tgt_x))
            # The right-hand side is gathered before the write, so every target
            # receives its source's pre-update value.
            batch[0, batch_tgt[:, 0], batch_tgt[:, 1]] = batch[0, batch_src[:, 0], batch_src[:, 1]]
            steps.append(batch.detach().squeeze().clone().numpy())
    return batch, steps
        

In [454]:
# One Monte Carlo step starting from a single-pixel cell at (3, 3) on an 8x8 grid.
test = t.zeros(1, 8, 8)
test[0, 3, 3] = 1
target_vol, temperature = 1., 27.

test, steps = MCS(test, target_vol, temperature)
test

tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.]]])

In [455]:
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image

In [463]:
test = t.zeros(1, 8, 8)
test[0, 3, 3] = 1  # single-pixel cell
target_vol, temperature = 1., 27.

# Frame 0 is the initial state; MCS returns one extra frame per accepted update.
states = [test.detach().clone().squeeze().numpy()]
for i in tqdm(range(1000)):
    test, steps = MCS(test, target_vol, temperature)
    states.extend(steps)
    volume = t.sum(test)
    # Abort once the cell has vanished or grown past two pixels.
    if volume == 0 or volume > 2:
        print(test)
        break

# Medium (0) renders white, cell (1) black.
imgs = [Image.fromarray(255 * (1 - state)) for state in states]
imgs[0].save("t_27.gif", save_all=True, append_images=imgs[1:], duration=10, loop=100)

100%|██████████| 1000/1000 [00:04<00:00, 217.77it/s]


In [464]:
test = t.zeros(1, 8, 8)
test[0, 3, 3] = 1  # single-pixel cell
target_vol, temperature = 1., 13.

# Frame 0 is the initial state; MCS returns one extra frame per accepted update.
states = [test.detach().clone().squeeze().numpy()]
for i in tqdm(range(1000)):
    test, steps = MCS(test, target_vol, temperature)
    states.extend(steps)
    volume = t.sum(test)
    # Abort once the cell has vanished or grown past two pixels.
    if volume == 0 or volume > 2:
        print(test)
        break

# Medium (0) renders white, cell (1) black.
imgs = [Image.fromarray(255 * (1 - state)) for state in states]
imgs[0].save("t_13.gif", save_all=True, append_images=imgs[1:], duration=10, loop=100)

100%|██████████| 1000/1000 [00:04<00:00, 214.82it/s]


In [467]:
test = t.zeros(1, 8, 8)
test[0, 3, 3] = 1  # single-pixel cell
target_vol, temperature = 1., 0.

# Frame 0 is the initial state; MCS returns one extra frame per accepted update.
states = [test.detach().clone().squeeze().numpy()]
for i in tqdm(range(1000)):
    test, steps = MCS(test, target_vol, temperature)
    states.extend(steps)
    volume = t.sum(test)
    # Abort once the cell has vanished or grown past two pixels.
    if volume == 0 or volume > 2:
        print(test)
        break

# Medium (0) renders white, cell (1) black.
imgs = [Image.fromarray(255 * (1 - state)) for state in states]
imgs[0].save("t_0.gif", save_all=True, append_images=imgs[1:], duration=10, loop=100)

100%|██████████| 1000/1000 [00:02<00:00, 387.02it/s]
