In [None]:
import os
from os import path, getcwd, listdir, mkdir

import sys
sys.path.insert(0, '../../')

import torch as t
from torch.distributions import uniform
import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt

from experiment_helper import chop_and_shuffle_data, generate_sequence, dataset_dist
from diffusion_gumbel import diffuse_STE
from reaction_diff import  rho_STE

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")
print('Using device:', device)

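## Generate dataset

Run the reaction–diffusion simulation with known diffusion coefficients (DA = 0.1, DB = 0.4) and known reaction rates, saving the intermediate states to disk so they can be used as reference data.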

In [None]:
# Initial condition: both channels start at a uniform value of 25, then
# channel 0 is raised to 40 over indices 29-34 of the first spatial axis.
grid = t.full((2, 64, 64), 25.0, device=device)
grid[0, 29:35] = 40

# Reaction parameters; each k_i is set equal to its k_i_bar counterpart.
gamma = 0.005
k1, k1_bar = 0.98, 0.98
k2, k2_bar = 0.1, 0.1
k3, k3_bar = 0.2, 0.2

# Simulation settings. NOTE: N is also passed to diffuse_STE / rho_STE in the
# gradient-sweep cell further down, so this cell must run before that one.
N = 50
num_steps = 50_000
# Diffusion coefficients used to generate the data (presumably the values the
# later (DA, DB) sweep is expected to recover — confirm).
DA, DB = 0.1, 0.4

generate_sequence(
    grid, num_steps, N,
    # diffusion settings
    use_diffusion=True, DA=DA, DB=DB,
    # reaction settings
    use_reaction=True, gamma=gamma,
    k1=k1, k1_bar=k1_bar,
    k2=k2, k2_bar=k2_bar,
    k3=k3, k3_bar=k3_bar,
    # output settings: save the steps, skip visualisation / sequence creation
    create_vis=False, save_steps=True, create_seq=False,
)

In [None]:
# Load one saved snapshot to serve as the reference state.
# NOTE(review): the run directory is a hard-coded timestamp; presumably it was
# written by an earlier generate_sequence(save_steps=True) run — confirm it
# matches the parameters in the generation cell before trusting the results.
data_dir = path.join(getcwd(), "data", "1674905783.5314908")
snapshot_file = path.join(data_dir, "batch_500", "0.pt")
ref_state = t.load(snapshot_file, map_location=device)

In [None]:
print(ref_state.shape)

# Show ref_state[0] and ref_state[1] side by side on a shared 0-50 grey scale.
fig, axs = plt.subplots(1, 2)
for species_idx, title in enumerate(("A species", "B species")):
    ax = axs[species_idx]
    ax.imshow(
        ref_state[species_idx].cpu(),
        cmap="Greys",
        interpolation="nearest",
        vmin=0,
        vmax=50,
    )
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
def loss_fn(X, Y):
    """Log of the summed squared difference between X and Y on channels 0 and 1.

    Only ``[:, 0]`` and ``[:, 1]`` are compared; any further channels along
    dim 1 are ignored. The per-channel slices are reduced over dims (0, 1, 2),
    so X and Y must be at least 4-D; for exactly 4-D inputs the result is a
    scalar tensor.

    Returns -inf when X and Y agree exactly on channels 0 and 1 (log of 0).
    """
    diff_a = X[:, 0] - Y[:, 0]
    diff_b = X[:, 1] - Y[:, 1]
    squared = diff_a.pow(2) + diff_b.pow(2)
    return squared.sum(dim=(0, 1, 2)).log()

In [None]:
# Sweep a grid of (DA, DB) diffusion coefficients. For each pair, start from
# ref_state, simulate N_SIM_STEPS diffusion + reaction steps, then record the
# loss against ref_state and its gradient w.r.t. DA and DB.
# NOTE: N comes from the dataset-generation cell; run that cell first.
DA_MIN, DA_MAX, N_DA = 0.0001, 0.9999, 100
DB_MIN, DB_MAX, N_DB = 0.0001, 0.9999, 100
N_SIM_STEPS = 100  # diffusion + reaction steps simulated per (DA, DB) pair

# define a sequence of DA values to test
DA_vals = t.linspace(DA_MIN, DA_MAX, N_DA, device=device)
# define a sequence of DB values to test
DB_vals = t.linspace(DB_MIN, DB_MAX, N_DB, device=device)
# with indexing="ij", grid_DA and grid_DB both have shape (N_DA, N_DB)
grid_DA, grid_DB = t.meshgrid(DA_vals, DB_vals, indexing="ij")

# rate coefficients
k1 = t.tensor(0.98, device=device)
k1_bar = t.tensor(0.98, device=device)
k2 = t.tensor(0.1, device=device)
k2_bar = t.tensor(0.1, device=device)
k3 = t.tensor(0.2, device=device)
k3_bar = t.tensor(0.2, device=device)
# reaction time constant
gamma = t.tensor(0.005, device=device)


# collect the results of each test run here (flat lists, DA-major order)
grads_DA = []
grads_DB = []
distances = []
# iterate over the DA values (rows of the grid)
for DA_idx in tqdm(range(grid_DA.shape[0])):
    # iterate over the DB values (columns of the grid). This must be shape[1]:
    # shape[0] is the number of DA values and is only correct when N_DA == N_DB.
    for DB_idx in range(grid_DB.shape[1]):
        X = ref_state.detach().clone()

        # Fresh leaf tensors for this pair, rather than flipping requires_grad
        # on 0-d views of the shared meshgrid tensors.
        DA = grid_DA[DA_idx, DB_idx].detach().clone().requires_grad_()
        DB = grid_DB[DA_idx, DB_idx].detach().clone().requires_grad_()
        for _ in range(N_SIM_STEPS):
            # 1. run the diffusion step
            X = diffuse_STE(X, N, DA, DB)
            # 2. run the reaction step
            X = rho_STE(X, N, gamma, k1, k1_bar, k2, k2_bar, k3, k3_bar)
        # log of the summed squared distance to the reference (see loss_fn)
        dist_val = loss_fn(X, ref_state)
        distances.append(dist_val.detach().cpu().numpy())
        # find the gradient of the distance w.r.t. the diffusion coefficients
        grad_DA, grad_DB = t.autograd.grad(dist_val, (DA, DB))
        grads_DA.append(grad_DA.detach().cpu().numpy())
        grads_DB.append(grad_DB.detach().cpu().numpy())