In [153]:
import sys
import torch as t
sys.path.append("../../")
from periodic_padding import periodic_padding
import random

## grid splitting with reflective boundary neighborhood

In [2]:
def split_grids(batch:t.Tensor, kernel_width: int, kernel_height: int) -> t.Tensor:
    unfold_transform = t.nn.Unfold(kernel_size=(kernel_height,kernel_width))
    reflective_padding = t.nn.ReflectionPad2d(1)
    
    padded_batch = reflective_padding(batch).float()
    padded_batch = padded_batch.unsqueeze(1)
    
    return t.transpose(unfold_transform(padded_batch), dim0=1, dim1=2)

In [22]:
test1 = t.arange(64).reshape(1,8,8).float()
split1 = split_grids(test1, kernel_width=4, kernel_height=4)
assert split1.shape == (1, 49, 16)
assert t.all(split1[0,0] == t.tensor([
    9,  8,  9,  10,
    1,  0,  1,  2,
    9,  8,  9,  10,
    17, 16, 17, 18
])) 

test2 = t.arange(32**2).reshape(1,32,32).float()
split2 = split_grids(test2, kernel_width=6, kernel_height=6)
assert split2.shape == (1, 841, 36)
assert t.all(split2[0,0] == t.tensor([
    33,  32,  33,  34,  35,  36,
    1,   0,   1,   2,   3,   4,
    33,  32,  33,  34,  35,  36,
    65,  64,  65,  66,  67,  68,
    97,  96,  97,  98,  99,  100,
    129, 128, 129, 130, 131, 132
]))

## converting tile indices of the checkerboard to index in the unfolded conv list

In [41]:
def tile_idx2conv_idx(row, col, tile_size, batch_size):
    _, batch_height, batch_width = batch_size
    tile_height, tile_width = tile_size
    tile_width -= 2
    tile_height -= 2
    tiles_per_row = batch_width-tile_width+1
    return int(tile_height*row*tiles_per_row + tile_width*col)

In [42]:
test1 = t.arange(64).reshape(1,8,8).float()
split1 = split_grids(test1, kernel_width=4, kernel_height=4)
idx = tile_idx2conv_idx(1,0,(4,4),test1.shape)
assert idx == 14
assert t.all(split1[0,idx] == t.tensor([
    9,  8,  9,  10,
    17, 16, 17, 18,
    25, 24, 25, 26,
    33, 32, 33, 34
]))

In [None]:
def tile_coords2grid_coords(tile_idx, tile_coords, grid_dim, tile_dim)

## creating independent checkerboard sets

In [120]:
def indep_chkboard_sets(batch, tile_height, tile_width):
    _, batch_height, batch_width = batch.shape
    tiles_per_row = batch_width//(tile_width-2)
    tiles_per_col = batch_height//(tile_height-2)
    
    print(tiles_per_row)
    
    assert tiles_per_row % 2 == 0, "please provide a tile_width that allows for even tiling of the row"
    assert tiles_per_col % 2 == 0, "please provide a tile_height that allows for even tiling of the columns"
    
    meshgrid_coords = []
    for row_start_idx in range(2):
        for col_start_idx in range(2):
            row_idxs = t.tensor(range(row_start_idx, tiles_per_row,2))
            col_idxs = t.tensor(range(col_start_idx, tiles_per_col,2))
            meshgrid_coords.append(t.meshgrid((row_idxs, col_idxs), indexing="ij"))
    return meshgrid_coords
     

In [121]:
checkerboard_sets = indep_chkboard_sets(test1, tile_height=4, tile_width=4)

4


## Update probability function

In [160]:
def p_update(tiles, cur_vol, target_vol, temperature, src_coords, tgt_coords):
    batch_size, _, _ = tiles.shape 
    src_x, src_y = src_coords
    tgt_x, tgt_y = tgt_coords
    vol_changes = (-1 * tiles[range(batch_size), tgt_y, tgt_x]) + tiles[range(batch_size), src_y, src_x]
    total_vol_change = t.sum(vol_changes)
    #print(f"vol change: {total_vol_change}")
    adjusted_vol = cur_vol + total_vol_change
    p_update = t.tensor(0.)
    if adjusted_vol <= 2 and adjusted_vol > 0:
        p_update += t.exp(-(target_vol - adjusted_vol)**2/temperature)
    return p_update

## MCS algorithm 

In [163]:
def MCS(batch, checkerboard_sets, target_vol, temperature):
    _, batch_height, batch_width = batch.shape
    vol_kernel = t.tensor([[[
        [0., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.]
    ]]])
    
    steps = []
    for chkboard_set in checkerboard_sets:
        padded_batch = t.nn.ReflectionPad2d(1)(batch).float()
        cur_vol = t.sum(t.nn.functional.conv2d(padded_batch, vol_kernel))
        conv_stack = split_grids(batch, kernel_width=4, kernel_height=4)
        n_rows, n_cols = chkboard_set[0].shape
        print(n_rows, n_cols)
        conv_stack_idxs = [
            tile_idx2conv_idx(
                chkboard_set[0][i,j],
                chkboard_set[1][i,j],
                tile_size=(4,4),
                batch_size=batch.shape
            ) for i in range(n_rows) for j in range(n_cols)
        ]
        print(conv_stack_idxs)
        tiles = t.vstack([conv_stack[0, conv_stack_idxs]]).reshape(len(conv_stack_idxs),4,4)
        print(tiles.shape)
        src_x = t.randint_like(t.zeros(tiles.shape[0],), low=1, high=3)
        src_x = src_x.type(t.long)
        src_y = t.randint_like(src_x, low=1, high=3)
        # For each random sample in src, we sample a random value from [-1, 0, 1]
        # and add it on to the src_idx
        step_sizes = t.tensor(random.choices([
            (1, 0),
            (1, 1),
            (0, 1),
            (-1, 1),
            (-1, 0),
            (-1, -1),
            (0, -1),
            (1, -1)
        ], k=tiles.shape[0]))
        tgt_x = src_x + step_sizes[:, 1]
        tgt_y = src_y + step_sizes[:, 0]
        
        print(src_x, src_y)
        print(tgt_x, tgt_y)
        
        update_probability = p_update(
            tiles,
            cur_vol,
            target_vol,
            temperature,
            (src_x, src_y),
            (tgt_x, tgt_y)
        )
        print(update_probability)
        
        residual = t.rand_like(update_probability)
        
        if update_probability > residual:
            # update the batch here
            pass
        
    

In [164]:
MCS(test1, checkerboard_sets, 1, 1)

2 2
[0, 4, 28, 32]
torch.Size([4, 4, 4])
tensor([1, 1, 2, 1]) tensor([1, 2, 1, 1])
tensor([1, 2, 2, 0]) tensor([2, 2, 2, 0])
tensor(0.)
2 2
[2, 6, 30, 34]
torch.Size([4, 4, 4])
tensor([1, 2, 2, 1]) tensor([1, 1, 2, 2])
tensor([0, 2, 2, 0]) tensor([0, 0, 3, 3])
tensor(0.)
2 2
[14, 18, 42, 46]
torch.Size([4, 4, 4])
tensor([2, 2, 1, 1]) tensor([2, 2, 1, 1])
tensor([3, 3, 1, 2]) tensor([1, 2, 2, 1])
tensor(0.)
2 2
[16, 20, 44, 48]
torch.Size([4, 4, 4])
tensor([1, 2, 2, 1]) tensor([2, 2, 1, 1])
tensor([0, 1, 2, 2]) tensor([2, 2, 2, 0])
tensor(0.)
