In [1]:
import torch as t
from mcs import MCS
from PIL import Image
from scipy import ndimage
import numpy as np
import random

In [2]:
class STESpike(t.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


In [28]:
MOORE_OFFSETS = t.tensor([(1, 1), (1, -1), (1, 0), (-1, 0), (-1, 1), (-1, -1), (0, 1), (0, -1)])

def get_moore_nbh(batch:t.tensor):
    batch_size, batch_height, _ = batch.shape
    cell_pixel_coords = (batch == 1).nonzero()
    moore_nbhs = []
    for batch_idx in range(batch_size):
        batch_slice = cell_pixel_coords[:, 0] == batch_idx
        #print(cell_pixel_coords[batch_slice])
        coords = cell_pixel_coords[batch_slice]
        batch_moore_nbhs = []
        for coord in coords:
            nbh_coords = coord[1:]+MOORE_OFFSETS
            idx_pad = t.zeros(8,1) + batch_idx
            nbh_coords = t.vstack((coord, t.hstack((idx_pad, nbh_coords))))
            #print(nbh_coords)
            batch_moore_nbhs.append(nbh_coords)
        res = t.vstack(batch_moore_nbhs).unique(dim=0)
        res[res == -1] = 1
        res[res == batch_height] = batch_height - 2
        moore_nbhs.append(res)
    return moore_nbhs


In [56]:
def MCS(batch, target_vol, temperature):
    batch_size, batch_height, batch_width = batch.shape
    frontiers = get_moore_nbh(batch)
    for batch_idx, frontier in enumerate(frontiers):
        print(frontier)
        src_coords = random.choice(frontier).type(t.long)
        step_size = random.choice(MOORE_OFFSETS)
        tgt_coords = src_coords.clone()
        tgt_coords[1:] += step_size
        tgt_coords[tgt_coords == -1] = 1
        tgt_coords[tgt_coords == batch_height] = batch_height -2
        print(f"coordinates of the source pixel on the grid: {src_coords}")
        print(f"step size: {step_size}")
        print(f"coordinates of the target pixel on the grid: {tgt_coords}")
        src_i, src_j, src_k = src_coords
        tgt_i, tgt_j, tgt_k = tgt_coords
        
        cur_vol = t.sum(batch[batch_idx])
        vol_change = (-1 * batch[tgt_i, tgt_j, tgt_k]) + batch[src_i, src_j, src_k]
        print("vol change", vol_change)
        adjusted_vol = cur_vol + vol_change
        print(f"adjusted vol: {adjusted_vol}")
        if batch[tgt_i, tgt_j, tgt_k] == batch[src_i, src_j, src_k]:
            # source is equal to target, no update
            print("source is equal to target, no update")
            pass
        elif adjusted_vol > 2 or adjusted_vol <= 0:
            # Changes would violate the hard constraints, no update
            print("Changes would violate the hard constraints, no update")
            pass
        elif cur_vol == 2 and vol_change == -1:
            # Negative Hamiltonian, accepted
            print("Negative Hamiltonian, accepted")
            batch[tgt_i, tgt_j, tgt_k] += vol_change
        else:
            update_probability = t.exp(-((target_vol - adjusted_vol) ** 2) / temperature)
            print(f"update probability: {update_probability}")
            
            residual = t.rand(1)
            print(residual)
            
            upd_val = STESpike.apply(update_probability - residual) * vol_change
            
            print(upd_val)
            
            print(batch[tgt_i, tgt_j, tgt_k])
            print(batch[tgt_i, tgt_j, tgt_k] + upd_val)

            batch[tgt_i, tgt_j, tgt_k] += upd_val.squeeze()
    return batch

In [None]:
def step(batch:t.Tensor, dist_matrix:t.Tensor, temperature:t.Tensor, target_vol:float, eta:float):
    temperature.requires_grad_()
    for i in range(100):
        print(f"-------- MCS {i} --------------")
        batch = MCS(batch, target_vol, temperature)
        if t.any(t.sum(batch, dim=(-1,-2)) == 0) or t.any(t.sum(batch, dim=(-1,-2)) > 2):
            print("ISSUE DETECTED, STOP SIM")
            return None, None
    # calc the distance after 1000 steps    
    dist = t.sum(t.sum(batch * dist_matrix, dim=(-1,-2)) / t.sum(batch, dim=(-1,-2)))
    grad = t.autograd.grad(dist, temperature)[0]
    return temperature - eta * grad, grad

In [62]:
init_state = t.zeros(2,8,8)
init_state[:,4,4] += 1

dist_matrix:np.ndarray = ndimage.distance_transform_edt(1-init_state[0], return_indices=False)
dist_matrix = dist_matrix**2
dist_matrix_t = t.from_numpy(dist_matrix)

print(dist_matrix_t)

target_vol = 1.
temperature = t.tensor(27.)
temperature.requires_grad_()

tensor([[32.0000, 25.0000, 20.0000, 17.0000, 16.0000, 17.0000, 20.0000, 25.0000],
        [25.0000, 18.0000, 13.0000, 10.0000,  9.0000, 10.0000, 13.0000, 18.0000],
        [20.0000, 13.0000,  8.0000,  5.0000,  4.0000,  5.0000,  8.0000, 13.0000],
        [17.0000, 10.0000,  5.0000,  2.0000,  1.0000,  2.0000,  5.0000, 10.0000],
        [16.0000,  9.0000,  4.0000,  1.0000,  0.0000,  1.0000,  4.0000,  9.0000],
        [17.0000, 10.0000,  5.0000,  2.0000,  1.0000,  2.0000,  5.0000, 10.0000],
        [20.0000, 13.0000,  8.0000,  5.0000,  4.0000,  5.0000,  8.0000, 13.0000],
        [25.0000, 18.0000, 13.0000, 10.0000,  9.0000, 10.0000, 13.0000, 18.0000]],
       dtype=torch.float64)


tensor(27., requires_grad=True)

In [58]:
states = [init_state.detach().clone().squeeze().numpy()]
state = init_state.detach().clone()
for i in range(100):
    print(f"-------- MCS {i} --------------")
    state = MCS(state, target_vol, temperature)
    print(state)
    states.append(state.detach().clone().squeeze().numpy())
    if t.any(t.sum(state, dim=(-1,-2)) == 0) or t.any(t.sum(state, dim=(-1,-2)) > 2):
        print("ISSUE DETECTED, STOP SIM",t.sum(state))
        break

-------- MCS 0 --------------
tensor([[0., 3., 3.],
        [0., 3., 4.],
        [0., 3., 5.],
        [0., 4., 3.],
        [0., 4., 4.],
        [0., 4., 5.],
        [0., 5., 3.],
        [0., 5., 4.],
        [0., 5., 5.]])
coordinates of the source pixel on the grid: tensor([0, 4, 3])
step size: tensor([1, 1])
coordinates of the target pixel on the grid: tensor([0, 5, 4])
vol change tensor(0.)
adjusted vol: 1.0
source is equal to target, no update
tensor([[1., 3., 3.],
        [1., 3., 4.],
        [1., 3., 5.],
        [1., 4., 3.],
        [1., 4., 4.],
        [1., 4., 5.],
        [1., 5., 3.],
        [1., 5., 4.],
        [1., 5., 5.]])
coordinates of the source pixel on the grid: tensor([1, 4, 3])
step size: tensor([1, 1])
coordinates of the target pixel on the grid: tensor([1, 5, 4])
vol change tensor(0.)
adjusted vol: 1.0
source is equal to target, no update
tensor([[[0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.

In [60]:
batch_idx = 1
imgs = [Image.fromarray((1-state[batch_idx])*255) for state in states]
imgs[0].save("t_27_reflective.gif", save_all=True, append_images=imgs[1:], duration=10, loop=100)