In [None]:
from google.colab import drive
drive.mount('/content/drive')
import sys, os

# Make the project packages (`utils`, `models`, `data`, ...) importable.
# The paths are absolute so they do not depend on the kernel's working
# directory, and the membership test keeps re-runs of this cell from
# appending duplicate entries to sys.path.
PROJECT_ROOT = '/content/drive/MyDrive/CNEEP_v2'
for _path in (f'{PROJECT_ROOT}/', f'{PROJECT_ROOT}/data/beads'):
    if _path not in sys.path:
        sys.path.append(_path)

from argparse import Namespace
import numpy as np
import torch
import torchvision
from datetime import datetime
from utils.sampler import CartesianSeqSampler
from models.train import train
from models.validate import validate
# tqdm.auto renders a notebook progress widget when available and falls back to text.
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [None]:
#
# Hyper parameters
#
opt = Namespace()
# Fall back to the CPU so the notebook still runs on a runtime without a GPU.
opt.device = "cuda" if torch.cuda.is_available() else "cpu"

# number of beads (passed to NBeadsModel as n_beads)
opt.N = 2

# alpha-NEEP parameter
opt.alpha   = -0.5
# Masking Regularization parameters
opt.lam     = 0.0
opt.threshold = 0.01

opt.positional = True

# latent size (also used as the number of components in the PCA study)
opt.latent_size = 10

# gradient descent parameters
opt.n_iter = 200            # training iterations (one train + one validate call each)
opt.train_batch_size = 1024
opt.test_batch_size = 4096
opt.video_batch_size = 512  # batch size of the single-trajectory pass used for PCA / EP maps
opt.n_hidden = 512
opt.lr = 1e-3
opt.wd = 1e-5               # Adam weight decay

opt.record_freq = 1000
opt.seed = 3

# dataset configurations
opt.n_layer = 2
opt.n_channel = 32
opt.cell_size = 20          # passed to generate_brownian_frames
# Frame shape; presumably one cell_size x cell_size cell per bead, side by side — TODO confirm.
opt.input_shape = (opt.cell_size, opt.cell_size * opt.N)
opt.M = 10                  # number of trajectories in each of the train / test sets
opt.L = 1000                # length of each trajectory
opt.burn_in = 2000          # passed as burn_in to NBeadsModel.generate_trajectories
opt.seq_len = 2
opt.time_step = 0.01        # NBeadsModel dt

# Seed torch's global RNG and NumPy's global RNG (the trajectory generator may
# draw from the latter), so a fresh Run All reproduces the same data.
torch.manual_seed(opt.seed)
np.random.seed(opt.seed)


#
# paths for results
#
data_folder = "drive/MyDrive/CNEEP_v2/data"
result_folder = "drive/MyDrive/CNEEP_v2/results"
# The timestamp is built separately: nesting the f-string's own quote character
# inside a replacement field is a SyntaxError before Python 3.12.
run_name = datetime.now().strftime("Beads-%Y-%m-%d-%H%M%S")
current_result_folder = f"{result_folder}/{run_name}"
# makedirs also creates result_folder when missing, so no separate check is needed.
os.makedirs(current_result_folder, exist_ok=True)

current_checkpoint_path = f"{current_result_folder}/model_parameter.pth.tar"

In [None]:
from data.beads.generate_trajectories import NBeadsModel
from data.beads.generate_animations import generate_brownian_frames


def render_video(trajectories, n_traj, cell_size):
    """Render the first `n_traj` trajectories and stack them into one float32 tensor.

    Each trajectory is rendered with `generate_brownian_frames`; index 1 of the
    last axis is kept and a singleton axis is inserted at dim 2.  Assuming each
    rendered clip is (frames, H, W, channels) — TODO confirm — the result has
    shape (n_traj, frames, 1, H, W).
    """
    clips = [generate_brownian_frames(trajectories[i], cell_size) for i in range(n_traj)]
    return torch.tensor(np.array(clips), dtype=torch.float32)[:, :, :, :, 1].unsqueeze(2)


# Bound to `bead_model` (not `model`) so re-running this cell cannot clobber the
# CNEEP network that the model-building cell binds to `model`.
bead_model = NBeadsModel(n_beads=opt.N, dt=opt.time_step)

train_trajectories = bead_model.generate_trajectories(opt.M, opt.L, burn_in=opt.burn_in)
train_video = render_video(train_trajectories, opt.M, opt.cell_size)
print(train_video.shape)

test_trajectories = bead_model.generate_trajectories(opt.M, opt.L, burn_in=opt.burn_in)
test_video = render_video(test_trajectories, opt.M, opt.cell_size)

# Ground-truth quantities for each test trajectory, computed by the simulator.
# Note: these methods must exist on NBeadsModel.
GT_ent = np.array(
    [bead_model.compute_entropy_production_rate(test_trajectories[i]) for i in range(opt.M)])
GT_ent_per_bead = np.array(
    [bead_model.compute_entropy_production_per_bead(test_trajectories[i]) for i in range(opt.M)])
GT_heat = np.array(
    [bead_model.compute_heat_per_bead(test_trajectories[i]) for i in range(opt.M)])

# NOTE(review): the normalization statistics come from a single frame
# (trajectory 0, frame 0) of the training video, not the whole training set.
mean = torch.mean(train_video[0][0])
std = torch.std(train_video[0][0])


def transform(x):
    """Standardize `x` with the training-frame `mean` and `std` computed above."""
    return (x - mean) / std

In [None]:
#
# Building our model
#
from models.UNEEP_1P import CNEEP

# CNEEP network moved to the configured device, trained with Adam + weight decay.
model = CNEEP(opt).to(opt.device)
optim = torch.optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.wd)

# Both samplers take the same leading (M, L, seq_len) arguments; they differ in
# batch size and in the `train` flag (its exact effect lives in utils.sampler).
train_sampler = CartesianSeqSampler(
    opt.M,
    opt.L,
    opt.seq_len,
    opt.train_batch_size,
    device=opt.device,
)
test_sampler = CartesianSeqSampler(
    opt.M,
    opt.L,
    opt.seq_len,
    opt.test_batch_size,
    device=opt.device,
    train=False,
)

In [None]:
#
# Training the model
#
# Every iteration runs one `train` step and one `validate` pass. The parameters
# with the lowest validation loss seen so far are checkpointed, then reloaded
# once training ends.

train_losses = []
R_values = []
valid_losses = []
best_valid_loss = float('inf')
for i in tqdm(range(1, opt.n_iter + 1)):
    train_loss, R_value = train(
        opt, model, optim, train_video, train_sampler, transform
    )
    train_losses.append(train_loss)
    R_values.append(R_value)

    _, _, valid_loss = validate(
        opt, model, test_video, test_sampler, transform
    )
    valid_losses.append(valid_loss)

    # A NaN loss never compares smaller, so a diverged step is never checkpointed.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        state = {
            'epoch': i,
            'settings': opt.__dict__,
            'state_dict': model.state_dict(),
            'optimizer': optim.state_dict(),
        }
        torch.save(state, current_checkpoint_path)

# Without this guard a run that never improved would either crash on a missing
# file or silently reload a checkpoint left by an earlier run of this cell.
if best_valid_loss == float('inf'):
    raise RuntimeError(
        "No checkpoint was saved: the validation loss never dropped below +inf "
        "(it may be NaN). Inspect `valid_losses`."
    )

# Restore the best (not the last) parameters for all downstream analysis;
# map_location keeps loading working if the checkpoint's device differs.
model.load_state_dict(
    torch.load(current_checkpoint_path, map_location=opt.device)['state_dict'])

for values, file_stem, label in (
    (train_losses, "train_loss", "training loss"),
    (valid_losses, "valid_loss", "validation loss"),
    (R_values, "R_values", "R"),
):
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set(xlabel="iteration", ylabel=label, title=label)
    fig.savefig(f"{current_result_folder}/{file_stem}.png")
    plt.close(fig)

In [None]:
#
# R2-scatter plot
#
from scipy import stats
import seaborn as sns

# Predicted EP of the best checkpoint on the test set vs. the simulator's ground truth.
pred_ent, _, _ = validate(opt, model, test_video, test_sampler, transform)

pred = pred_ent.flatten()
ent = GT_ent.flatten()
# linregress pairs the two arrays element-wise, so they must line up one-to-one.
assert pred.shape == ent.shape, (
    f"prediction shape {tuple(pred.shape)} != ground-truth shape {tuple(ent.shape)}")

pred_rate, _, r_value, p_value, _ = stats.linregress(ent, pred)

fig, ax = plt.subplots(figsize=(3, 3), dpi=100)
# x = ground truth, y = prediction: the same pairing linregress used above, so the
# drawn fit and the R^2 label describe the same regression.
sns.regplot(
    x=ent, y=pred, ax=ax,
    color='C3',
    line_kws={
        'lw': 2.0,
        'label': '$R^2=$ %.4f' % (r_value**2)},
    scatter_kws={
        'color': 'grey',
        'alpha': 0.03,
        's': 3,
        'rasterized': True})
ax.set_xlabel('ground-truth EP')
ax.set_ylabel('predicted EP')
# The R^2 text is the regression line's label; it only appears in a legend.
ax.legend()
fig.savefig(f"{current_result_folder}/R2_plot.png")


In [None]:
#
# PCA study
#
# Record the output of the model's `latent` submodule on every forward pass.
# Handles left by an earlier run of this cell are removed first, so re-running
# it does not register the hook twice (which would duplicate every entry).
for _handle in globals().get("hooks", []):
    _handle.remove()
latent_results = []
hooks = []


def hook_latent(module, input, output):
    """Forward hook: append the module's output, detached, as a CPU NumPy array."""
    latent_results.append(output.detach().cpu().numpy())


latent_module = model._modules.get("latent")
if latent_module is None:
    raise AttributeError("model has no submodule named 'latent'; nothing to hook")
hooks.append(latent_module.register_forward_hook(hook_latent))

In [None]:
# Sampler over a single trajectory (M = 1) for the latent-PCA / EP-map analysis.
# It gets its own name so the multi-trajectory `test_sampler` used by the
# training and R^2 cells is not overwritten if those cells are re-run.
video_sampler = CartesianSeqSampler(
    1, opt.L, opt.seq_len, opt.video_batch_size,
    device=opt.device, train=False)
# Drop latents captured by any earlier forward pass so the PCA sees only this pass.
latent_results.clear()
# NOTE(review): test_video[0] is one trajectory without the leading trajectory
# axis, paired with an M = 1 sampler — confirm validate() indexes it as intended.
ent, ent_map, _ = validate(opt, model, test_video[0], video_sampler, transform)

In [None]:
# Stack the latent vectors captured by the forward hook across *all* batches of
# the preceding validate() pass, not just the first batch.
if not latent_results:
    raise RuntimeError("No latent vectors captured; run the hook and validate cells first.")
latent_raw = np.concatenate(latent_results, axis=0)
latent_vectors = latent_raw - np.mean(latent_raw, axis=0)
latent_std = np.std(latent_vectors, axis=0)
# Leave constant (zero-variance) latent dimensions unscaled instead of producing NaNs.
latent_vectors = latent_vectors / np.where(latent_std > 0, latent_std, 1.0)
U, S, V = torch.pca_lowrank(torch.tensor(latent_vectors), q=opt.latent_size)

# First two principal-component scores, colored by the standardized third.
x = U[:, 0]
y = U[:, 1]
color_data = U[:, 2]
colors = (color_data - color_data.mean()) / color_data.std()
fig, ax = plt.subplots()
points = ax.scatter(x, y, c=colors, cmap='viridis')
ax.set(xlabel='PC 1', ylabel='PC 2', title='Latent PCA')
fig.colorbar(points, ax=ax, label='PC 3 (standardized)')
fig.savefig(f"{current_result_folder}/PCA_sccater(0, 1, 2).png")
plt.close(fig)

# Singular values. S is a torch tensor, whose repr ignores NumPy print options,
# so print it as an ndarray inside a local (non-global) options context.
with np.printoptions(precision=2, suppress=True):
    print(S.numpy())

In [None]:
#
# Visualizing results
#

# animation
# Rescale the EP maps to [0, 1] over all frames; a constant map would divide by
# zero, so fall back to a divisor of 1 (an all-zero overlay) in that case.
ent_map_span = ent_map.max() - ent_map.min()
ent_map_normalized = (ent_map - ent_map.min()) / (ent_map_span if ent_map_span > 0 else 1)
# Test trajectory 0 as (frames, H, W): drop only the singleton channel axis.
test_video_np = test_video[0][:, 0].clone().cpu().numpy()

fig, ax = plt.subplots()
# Pin the color limits: imshow derives them from the first frame only, and
# set_array() in update() does not rescale them for later frames.
im = ax.imshow(test_video_np[0], cmap='gray', animated=True,
               vmin=test_video_np.min(), vmax=test_video_np.max())
overlay = ax.imshow(ent_map_normalized[0], cmap='viridis', alpha=0.5, animated=True,
                    vmin=0.0, vmax=1.0)
ax.set_title('Local EP density (test trajectory 0)')


def update(frame):
    """Show video frame `frame` with its EP map overlaid; return the artists to blit."""
    im.set_array(test_video_np[frame])
    overlay.set_array(ent_map_normalized[frame])
    return [im, overlay]


# NOTE(review): assumes ent_map has no more entries than the video has frames.
ani = FuncAnimation(fig, update, frames=len(ent_map), blit=True)
ani.save(f"{current_result_folder}/ent_map_animation.mp4", fps=10)
plt.close(fig)

# mean local EP density
mean_map = ent_map.mean(axis=0)
fig, ax = plt.subplots()
mean_im = ax.imshow(mean_map, cmap='viridis')
fig.colorbar(mean_im, ax=ax, label='Mean Local EP Density')
ax.set_title('Mean Local EP Density Map')
fig.savefig(f"{current_result_folder}/mean_map.png")
plt.close(fig)