In [None]:
from IPython.display import HTML
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
import torch
from phiml.dataset import LandscapeSimulationDataset
from model_loader import load_model_directory
from plotting import plot_training_loss_history, plot_validation_loss_history

In [None]:
# Name of the trained model to analyze; must match a directory under ../models/.
MODEL: str = "model7296801"

In [None]:
# Model artifacts are read from MODELDIR; every figure this notebook saves goes to OUTDIR.
MODELDIR = "../models/" + MODEL
OUTDIR = "../out/model_analysis/" + MODEL + "_tester"

# exist_ok=True keeps this cell safe to re-run when the directory already exists.
os.makedirs(OUTDIR, exist_ok=True)

In [None]:
# Show the arguments logged with the trained model. Plain Python (instead of the
# `!cat` shell magic) so this cell works on any OS and outside IPython.
with open(os.path.join(MODELDIR, "log_args.txt")) as log_args_file:
    print(log_args_file.read())

In [None]:
# Load the trained model together with its stored arguments and loss histories.
model, model_args, loss_hist_train, loss_hist_valid = load_model_directory(
    MODELDIR, MODEL, verbosity=0,
)

# List the stored arguments alphabetically by name.
for arg_name in sorted(model_args):
    print(f"{arg_name} : {model_args[arg_name]}")

In [None]:
# Report the model's sigma (stored as log-sigma) and the first parameter tensor of
# `tilt_nn` (presumably its weight matrix; confirm). The tensor is moved to the CPU
# before `.numpy()` so this also works when the model was loaded onto a GPU.
sigma = np.exp(model.logsigma.item())
tilt_map = next(iter(model.tilt_nn.parameters())).detach().cpu().numpy()

print("*** Inferred Model Parameters ***")
print(f"Sigma: {sigma:.4g}")
print(f"Tilt map:\n{tilt_map}")

In [None]:
# Plot the training and validation loss curves with identical display settings.
logplot = False  # forwarded as `log=`; presumably toggles a log-scale axis (confirm)
startidx = 0     # forwarded as `startidx=`; presumably the first history index plotted
loss_method = model_args['loss']
optimizer = model_args['optimizer']

shared_opts = {"startidx": startidx, "log": logplot}
run_desc = f"({loss_method}, {optimizer})"

plot_training_loss_history(
    loss_hist_train,
    title=f"Training Loss {run_desc}",
    saveas=f"{OUTDIR}/loss_hist_training.png",
    **shared_opts,
);

plot_validation_loss_history(
    loss_hist_valid,
    title=f"Validation Loss {run_desc}",
    saveas=f"{OUTDIR}/loss_hist_validation.png",
    **shared_opts,
);

In [None]:
# Heatmap of the learned potential phi (normalized, linear color scale, no clipping).
phi_heatmap_kwargs = {
    "r": 2,
    "res": 400,
    "show": True,
    "normalize": True,
    "log_normalize": False,
    "clip": None,
}
model.plot_phi(**phi_heatmap_kwargs, saveas=f"{OUTDIR}/phi_heatmap.png");

In [None]:
# 3D surface view of the learned potential phi (larger r and coarser res than the heatmap).
phi_landscape_kwargs = dict(
    r=3,
    res=200,
    show=True,
    normalize=True,
    log_normalize=False,
    plot3d=True,
)
model.plot_phi(saveas=f"{OUTDIR}/phi_landscape.png", **phi_landscape_kwargs);

In [None]:
# Plot the vector field F for a fixed signal vector s = <sig[0], sig[1]>.
sig = [0, 0]
sig_label = ",".join(format(component, ".2g") for component in sig)
model.plot_f(
    signal=sig,
    r=2,
    res=20,
    show=True,
    title="$\\vec{F}(x,y|\\vec{s}=\\langle" + sig_label + "\\rangle)$",
    cbar_title="$|\\vec{F}|$",
    saveas=f"{OUTDIR}/f_plot.png",
);

In [None]:
# Initial condition: one row per cell (count from the model), every coordinate set to -0.5.
ncells = model.get_ncells()
y0 = np.full((ncells, 2), -0.5)

# Signal-function parameters; their meaning is defined by the model. TODO: document.
sigparams = [10, 0, 0, 1, 1]
dt = 1e-1

# The leading (0, 20) are presumably the start/end times of the run; confirm.
y, yhist = model.simulate_single_batch(0, 20, y0, sigparams, dt)

In [None]:
def plot_frame(data, col='b', size=2, xlabel='$x$', ylabel='$y$',
               xlim=(-2, 2), ylim=(-2, 2)):
    """Scatter-plot one snapshot of 2D points on a fresh figure.

    Parameters
    ----------
    data : array-like, indexable as ``data[:, 0]`` / ``data[:, 1]``
        Column 0 is drawn on the horizontal axis, column 1 on the vertical axis.
    col : str
        Marker color.
    size : float
        Marker size.
    xlabel, ylabel : str
        Axis labels (may contain mathtext).
    xlim, ylim : pair of float
        Axis limits as ``(lower, upper)``. Tuples are used as defaults so no
        mutable object is shared between calls; lists are still accepted.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on. The caller owns ``ax.figure`` and should
        close it when done.
    """
    _, ax = plt.subplots(1, 1)
    ax.plot(data[:, 0], data[:, 1], '.', c=col, markersize=size)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax

def _figure_to_rgb(figure):
    """Render ``figure`` and return its pixels as an (H, W, 3) uint8 array."""
    figure.canvas.draw()
    # `buffer_rgba()` supersedes `tostring_rgb()` (deprecated in Matplotlib 3.8,
    # removed in 3.10). The buffer already has the true pixel shape, so there is
    # no manual reshape from `get_width_height()`, which reports logical pixels
    # and mismatches the buffer on HiDPI canvases.
    rgba = np.asarray(figure.canvas.buffer_rgba())
    # Drop alpha and copy so the frame no longer references the canvas buffer.
    return rgba[..., :3].copy()


# Rasterize one figure per simulation snapshot into a stack of RGB frames.
frames = []
for snapshot in yhist:
    frame_ax = plot_frame(snapshot)
    frames.append(_figure_to_rgb(frame_ax.figure))
    # Close exactly this figure (not whatever pyplot considers "current").
    plt.close(frame_ax.figure)
if not frames:
    raise ValueError("yhist contains no snapshots; nothing to animate")
video = np.array(frames)

# Replay the pre-rendered frames as an image sequence. The hosting figure is
# closed so Jupyter does not also show it as a static image under the cell.
fig, anim_ax = plt.subplots()
anim_ax.axis('off')
im = anim_ax.imshow(video[0])
fig.tight_layout()
plt.close(fig)


def init():
    """FuncAnimation init hook: reset the image to the first frame."""
    im.set_data(video[0])
    return (im,)


def ani_func(i):
    """FuncAnimation update hook: show frame ``i`` of ``video``."""
    im.set_data(video[i])
    return (im,)


anim = animation.FuncAnimation(
    fig, ani_func, init_func=init,
    frames=video.shape[0],
    interval=50,  # milliseconds between frames
)

# HTML5 <video> markup for display in the notebook (requires ffmpeg).
ani = anim.to_html5_video()

In [None]:
# Display the HTML5 video produced by `anim.to_html5_video()`. This must stay the
# cell's last bare expression so Jupyter renders it as rich output.
HTML(ani)

## Examine training and validation datasets

In [None]:
# Dataset locations stored in the model args are prefixed with "../" to resolve them
# from this notebook's directory (presumably they are repo-root relative; confirm).
datdir_train, datdir_valid = (
    "../" + model_args[key] for key in ("training_data", "validation_data")
)
nsims_train, nsims_valid = (
    model_args[key] for key in ("nsims_training", "nsims_validation")
)

In [None]:
def _load_dataset(datdir, nsims):
    """Build a LandscapeSimulationDataset for ``datdir`` with tensor input/target transforms."""
    return LandscapeSimulationDataset(
        datdir,
        nsims,
        model_args['ndims'],
        transform='tensor',
        target_transform='tensor',
    )


# Rebuild the exact training and validation sets the model was fitted on.
train_dataset = _load_dataset(datdir_train, nsims_train)
validation_dataset = _load_dataset(datdir_valid, nsims_valid)

We can preview an individual training datum. Each datum consists of a distribution of cells $X_0\in\mathbb{R}^{n\times d}$ at time $t_0$, a subsequent distribution of cells $X_1\in\mathbb{R}^{n\times d}$ at time $t_1$, and a parameter vector $\vec{p}\in\mathbb{R}^{n_p}$. Here $n$ is the number of cells, $d$ is the dimension of each cell's state, and $n_p$ is the number of parameters of the signal function $f_{\text{sig}}(t,\cdot)$.


In [None]:
# Preview training datum `idx`, plotted on the square window [-r, r] x [-r, r].
idx, r = 1, 2
train_dataset.preview(idx, xlims=[-r, r], ylims=[-r, r]);

We can view an animation of a full simulation, which consists of many consecutive time-point pairs.

In [None]:
# Build (without showing) the animation of simulation `simidx` on the window [-r, r]^2.
# `interval` is presumably milliseconds per frame; confirm against the dataset API.
simidx, r = 0, 2
ani = train_dataset.animate(
    simidx,
    xlims=[-r, r],
    ylims=[-r, r],
    show=False,
    interval=1000,
)

In [None]:
# Display the animation returned by `train_dataset.animate` (presumably an HTML5 video
# string; confirm). This must stay the cell's last bare expression so Jupyter renders it.
HTML(ani)