In [None]:
from IPython.display import HTML
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
import torch
from phiml.dataset import LandscapeSimulationDataset
from phiml.helpers import get_binary_function
from model_loader import load_model_directory, load_sigma_history
from plotting import plot_training_loss_history, plot_validation_loss_history
from plotting import plot_loss_history, plot_train_vs_valid_history
from plotting import plot_sigma_history
from plotting import build_video

In [None]:
# Identifier of the trained model under analysis. It selects both the input
# directory under ../models and the output directory under ../out/model_analysis.
MODEL: str = "model7342176"

In [None]:
# Input: directory holding the trained model's files.
# Output: directory receiving every figure and video this notebook saves.
MODELDIR = "../models/{}".format(MODEL)
OUTDIR = "../out/model_analysis/{}".format(MODEL)
os.makedirs(OUTDIR, exist_ok=True)  # no error if it already exists, so the cell is re-runnable

In [None]:
# Print the run arguments logged next to the model (log_args.txt in MODELDIR).
# This is a pure-Python equivalent of `!cat`, so the cell also works in kernels
# or on platforms (e.g. Windows) where no `cat` binary is available.
with open(f"{MODELDIR}/log_args.txt") as log_args_file:
    print(log_args_file.read(), end="")

In [None]:
# Load the model directory. The returned values are the model, its run
# arguments, and the training and validation loss histories.
model, model_args, loss_hist_train, loss_hist_valid = load_model_directory(
    MODELDIR,
    MODEL,
    verbosity=0,
)

# List every run argument, alphabetically by name.
for arg_name in sorted(model_args):
    print(f"{arg_name} : {model_args[arg_name]}")

In [None]:
print("*** Inferred Model Parameters ***")
# The model stores log(sigma); exponentiate to report sigma itself.
sigma_inferred = np.exp(model.logsigma.item())
print(f"Sigma: {sigma_inferred:.4g}")
# First parameter tensor of the tilt network (presumably its weight matrix - TODO confirm).
tilt_weights = list(model.tilt_nn.parameters())[0].detach().numpy()
print(f"Tilt map:\n{tilt_weights}")

## Training/Validation Loss History


In [None]:
logplot = True   # plot losses on a log scale
startidx = 0     # first epoch index to include in the plots
loss_method = model_args['loss']
optimizer = model_args['optimizer']

# One row per figure: (plotting function, loss histories it takes, title, output path).
# Figures are produced in the same order as listed.
# NOTE(review): the last two figures share the title "Loss History (...)";
# consider distinguishing the train-vs-valid plot.
loss_plot_specs = [
    (
        plot_training_loss_history,
        (loss_hist_train,),
        f"Training Loss ({loss_method}, {optimizer})",
        f"{OUTDIR}/loss_hist_training.png",
    ),
    (
        plot_validation_loss_history,
        (loss_hist_valid,),
        f"Validation Loss ({loss_method}, {optimizer})",
        f"{OUTDIR}/loss_hist_validation.png",
    ),
    (
        plot_loss_history,
        (loss_hist_train, loss_hist_valid),
        f"Loss History ({loss_method}, {optimizer})",
        f"{OUTDIR}/loss_hist.png",
    ),
    (
        plot_train_vs_valid_history,
        (loss_hist_train, loss_hist_valid),
        f"Loss History ({loss_method}, {optimizer})",
        f"{OUTDIR}/loss_train_vs_valid.png",
    ),
]

for plot_fn, histories, plot_title, plot_path in loss_plot_specs:
    plot_fn(
        *histories,
        startidx=startidx, log=logplot,
        title=plot_title,
        saveas=plot_path,
    )

## Evolution of Inferred Parameters

In [None]:
# Trajectory of the inferred sigma over training, as recorded in MODELDIR.
sigma_history = load_sigma_history(MODELDIR)
sigma_plot_path = f"{OUTDIR}/sigma_history.png"
plot_sigma_history(sigma_history, saveas=sigma_plot_path);

## Visualizing the Inferred Landscape


In [None]:
# Heatmap of the inferred potential at zero signal (the "untilted" landscape).
# r presumably sets the half-width of the plotted window - TODO confirm.
heatmap_opts = dict(
    signal=[0, 0],
    r=4,
    res=400,
    show=True,
    normalize=True,
    log_normalize=True,
    clip=None,
    saveas=f"{OUTDIR}/phi_untilted_heatmap.png",
)
model.plot_phi(**heatmap_opts);

In [None]:
# 3D surface of the untilted (zero-signal) potential, viewed from a fixed camera angle.
landscape_opts = dict(
    signal=[0, 0],
    r=2,
    res=200,
    normalize=True,
    log_normalize=False,
    clip=None,
    view_init=(60, -20),
    plot3d=True,
    show=True,
    saveas=f"{OUTDIR}/phi_untilted_landscape.png",
)
model.plot_phi(**landscape_opts);

In [None]:
# Vector field F(x, y | s) at zero signal; the colorbar reports |F|.
sig = [0, 0]
field_title = f"$\\vec{{F}}(x,y|\\vec{{s}}=\\langle{sig[0]:.2g},{sig[1]:.2g}\\rangle)$"
model.plot_f(
    signal=sig,
    r=2,
    res=20,
    show=True,
    title=field_title,
    cbar_title="$|\\vec{F}|$",
    saveas=f"{OUTDIR}/field_untilted.png",
);

## Animation of Landscape Evolution

In [None]:
# Ramp the signal linearly from sig0 (at t = 0) to sig1 (at t = tfin) and render
# the inferred landscape at every time point as one frame of an animation.
tfin = 1
dt = 1e-2
# round() rather than int(): tfin/dt is computed in floating point and can land
# just below the intended integer (e.g. 0.3/0.1 == 2.9999999999999996), in which
# case int() would silently drop the final time point.
ts = np.linspace(0, tfin, 1 + round((tfin - 0) / dt))
sig0 = [0, 1]  # signal value at ts[0]
sig1 = [1, 0]  # signal value at ts[-1]


def fs0(t):
    """First signal component: linear ramp from sig0[0] at ts[0] to sig1[0] at ts[-1]."""
    return (sig1[0] - sig0[0]) / (ts[-1] - ts[0]) * (t - ts[0]) + sig0[0]


def fs1(t):
    """Second signal component: linear ramp from sig0[1] at ts[0] to sig1[1] at ts[-1]."""
    # The intercept must be this component's starting value, sig0[1], so that
    # fs1(ts[0]) == sig0[1] for any choice of sig0/sig1.
    return (sig1[1] - sig0[1]) / (ts[-1] - ts[0]) * (t - ts[0]) + sig0[1]


signal_hist = [[fs0(t), fs1(t)] for t in ts]


def plot_landscape_frame(signal, t):
    """Draw the inferred landscape in 3D for a single signal value.

    Named distinctly from the cell-animation frame function further down so
    that neither definition silently shadows the other.

    Parameters
    ----------
    signal : sequence of two floats
        Signal passed to ``model.plot_phi``; also shown in the title.
    t : float
        Time of the frame, shown in the title only.

    Returns
    -------
    The object returned by ``model.plot_phi`` (presumably a matplotlib Axes).
    """
    ax = model.plot_phi(
        signal=signal,
        r=2, res=100,
        plot3d=True,
        log_normalize=True,
        title=f"$t={t:.3f}$, $\\vec{{s}}=[{signal[0]:.3f}, {signal[1]:.3f}]$"
    )
    return ax


anim = build_video(
    lambda i: plot_landscape_frame(signal_hist[i], ts[i]),
    nframes=len(ts),
    interval=50
)

anim.save(f"{OUTDIR}/landscape_animation.mp4")

HTML(anim.to_html5_video())

## Animation of Cell Evolution

In [None]:
# Simulate one batch of cells under a binary (switching) signal and animate
# their positions on top of the inferred landscape.
y0 = np.zeros([model.get_ncells(), 2])
y0[:,1] = -0.5  # every cell starts at (0, -0.5)
tfin = 20
# Split below as (sigparams[0], sigparams[1:3], sigparams[3:]); presumably
# (switch time, initial signal, final signal) - confirm against get_binary_function.
sigparams = [10, 0, 1, 1, 0]
dt = 1e-1
y, yhist = model.simulate_single_batch(0, tfin, y0, sigparams, dt)
# round() rather than int(): float division tfin/dt can fall just short of the
# intended integer, and truncation would drop the final time point.
ts = np.linspace(0, tfin, 1 + round((tfin - 0) / dt))
fsig = get_binary_function(sigparams[0], sigparams[1:3], sigparams[3:])
signal_hist = [fsig(t) for t in ts]


def plot_cell_frame(data, signal, t):
    """Draw the landscape for ``signal`` and overlay cell positions ``data``.

    Named distinctly from the landscape-animation frame function above so that
    neither definition silently shadows the other.

    Parameters
    ----------
    data : array
        Cell positions; column 0 is plotted as x and column 1 as y.
    signal :
        Signal value passed to ``model.plot_phi``.
    t : float
        Time of the frame, shown in the title.

    Returns
    -------
    The object returned by ``model.plot_phi`` (presumably a matplotlib Axes).
    """
    ax = model.plot_phi(
        signal=signal,
        r=2, res=100,
        log_normalize=False,
    )
    ax.plot(data[:,0], data[:,1], '.', c='k', markersize=2)
    ax.set_title(f"$t={t:.3f}$")
    return ax


# NOTE(review): frames are counted from yhist but times/signals are indexed
# from ts/signal_hist, so this assumes len(ts) >= len(yhist) - confirm that
# simulate_single_batch records one state per entry of ts.
anim = build_video(
    lambda i: plot_cell_frame(yhist[i], signal=signal_hist[i], t=ts[i]),
    nframes=len(yhist),
    interval=50
)

anim.save(f"{OUTDIR}/cell_animation.mp4")

HTML(anim.to_html5_video())

## Examining the Training and Validation Datasets

In [None]:
# Dataset locations recorded in the run arguments are relative to the project
# root, while this notebook runs one directory below it; hence the "../" prefix.
datdir_train, datdir_valid = (
    "../" + model_args[key] for key in ('training_data', 'validation_data')
)
nsims_train, nsims_valid = (
    model_args[key] for key in ('nsims_training', 'nsims_validation')
)

In [None]:
# Both splits share the same dimensionality and input/target transforms.
dataset_kwargs = dict(transform='tensor', target_transform='tensor')

train_dataset = LandscapeSimulationDataset(
    datdir_train, nsims_train, model_args['ndims'], **dataset_kwargs
)
validation_dataset = LandscapeSimulationDataset(
    datdir_valid, nsims_valid, model_args['ndims'], **dataset_kwargs
)

In [None]:
# Preview one training sample inside the square window [-r, r] x [-r, r].
idx, r = 1, 2
train_dataset.preview(
    idx,
    xlims=[-r, r],
    ylims=[-r, r],
);

In [None]:
# Animate one training simulation inside the window [-r, r] x [-r, r].
# The result is passed straight to HTML, so animate() presumably returns an
# HTML/video string when show=False - confirm.
simidx, r = 0, 2
ani = train_dataset.animate(
    simidx,
    xlims=[-r, r],
    ylims=[-r, r],
    show=False,
    interval=1000,
)

HTML(ani)