# Script for testing DeepPhi models trained on FACS data

In [None]:
import argparse
from argparse import Namespace
import os
import warnings
import numpy as np
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrandom

import equinox as eqx

from plnn.models import DeepPhiPLNN
from plnn.dataset import get_dataloaders
from plnn.data_generation.plnn_animator import PLNNSimulationAnimator
from plnn.io import load_model_from_directory, load_model_training_metadata
from plnn.loss_functions import select_loss_function
from plnn.optimizers import get_dt_schedule

In [None]:
# Notebook defaults; these may be replaced by command line values when the
# file is run as a script.
SEED = None
SAVE_ANIMATION = False

BASEOUTDIR = "./out/test_models_facs"    # Output directory
BASEDIR = "../data/trained_models/facs"  # Directory containing models

# Model directory
MODEL_DIR = "model_facs_v3_dec1b_2dpca_v12_20240716_142138"

# The model name is the directory name with its trailing 16-character
# timestamp suffix removed.
MODEL_NAME = MODEL_DIR[:-len("_YYYYMMDD_HHMMSS")]

In [None]:
# Directory containing training data
# Start from the default location and replace it when MODEL_DIR carries a
# dataset version tag (tags are checked in the order listed).
DATDIRBASE = "../data/training_data/facs"
for version_tag in ("facs_v2", "facs_v3"):
    if version_tag in MODEL_DIR:
        DATDIRBASE = f"../data/training_data/{version_tag}"
        break


# Model-name tag -> data subdirectory beneath DATDIRBASE, checked in order.
datdir_by_tag = {
    "dec1a_2dpca": "pca/dec1/transition1_subset_epi_tr_ce_an_pc12",
    "dec1b_2dpca": "pca/dec1_fitonsubset/transition1_subset_epi_tr_ce_an_pc12",
    "dec2a_2dpca": "pca/dec2/transition2_subset_ce_pn_m_pc12",
    "dec2b_2dpca": "pca/dec2_fitonsubset/transition2_subset_ce_pn_m_pc12",
}
for decision_tag, subdir in datdir_by_tag.items():
    if decision_tag in MODEL_DIR:
        DATDIR = f"{DATDIRBASE}/{subdir}"
        break
else:
    # No tag matched, so the data directory cannot be inferred.
    raise RuntimeError("Cannot determine DATDIR from MODEL_DIR!")

In [None]:
# If running as a script, overwrite parameters with command line args

def is_notebook() -> bool:
    """Report whether the code is running in a ZMQ-based IPython shell.

    Returns:
        True only when `get_ipython` is defined and the class of the shell
        it returns is named 'ZMQInteractiveShell' (Jupyter notebook or
        qtconsole). False for every other shell class, and False when
        `get_ipython` is not defined at all.
    """
    try:
        shell_name = get_ipython().__class__.__name__  # type: ignore
    except NameError:
        # `get_ipython` is undefined: probably the standard Python interpreter.
        return False
    # A terminal IPython session ('TerminalInteractiveShell') and any other
    # shell type are not treated as notebooks.
    return shell_name == 'ZMQInteractiveShell'

SCRIPT = not is_notebook()

if SCRIPT:
    # Outside a notebook, rebind `tqdm` to the plain tqdm module.
    import tqdm as tqdm

    # (flag, add_argument keyword options), in registration order.
    cli_options = [
        ("--basedir", {"type": str, "default": "data/trained_models/facs"}),
        ("--modeldir", {"type": str, "required": True}),
        ("--modelname", {"type": str, "required": True}),
        ("--datdirbase", {"type": str, "default": "data/training_data/facs"}),
        ("--datdir", {"type": str, "required": True}),
        ("--baseoutdir", {"type": str,
                          "default": "notebooks/out/test_models_facs"}),
        ("--save_animation", {"action": 'store_true'}),
        ("--seed", {"type": int, "default": None}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Command line values replace the notebook defaults.
    BASEDIR, MODEL_DIR, MODEL_NAME = args.basedir, args.modeldir, args.modelname
    DATDIRBASE = args.datdirbase
    # --datdir is given relative to --datdirbase.
    DATDIR = f"{DATDIRBASE}/{args.datdir}"
    BASEOUTDIR = args.baseoutdir
    SAVE_ANIMATION = args.save_animation
    SEED = args.seed

In [None]:
def _read_nsims(datdir):
    """Load the integer contents of `<datdir>/nsims.txt`."""
    return np.genfromtxt(f"{datdir}/nsims.txt", dtype=int)


rng = np.random.default_rng(seed=SEED)

modeldir = f"{BASEDIR}/{MODEL_DIR}"

# Results are written to a per-model subdirectory of the base output directory.
OUTDIR = f"{BASEOUTDIR}/{MODEL_DIR}"
os.makedirs(OUTDIR, exist_ok=True)

datdir_train, datdir_valid, datdir_test = (
    f"{DATDIR}/{split}" for split in ("training", "validation", "testing")
)

nsims_train = _read_nsims(datdir_train)
nsims_valid = _read_nsims(datdir_valid)

try:
    nsims_test = _read_nsims(datdir_test)
except FileNotFoundError as e:
    # No testing data found: warn and test on the validation data instead.
    warnings.warn(f"{e} Reverting to validation data instead.")
    datdir_test = datdir_valid
    nsims_test = _read_nsims(datdir_test)

### Load the model

In [None]:
# Load the model
# `load_model_from_directory` is asked for the entry idx='best' under the
# "states" subdirectory of `modeldir`, as a DeepPhiPLNN with float64 dtype.
# NOTE(review): how 'best' is resolved is defined in plnn.io — confirm there.
model, hyperparams, idx, model_name, model_fpath = load_model_from_directory(
    modeldir, 
    subdir="states",
    idx='best',
    model_class=DeepPhiPLNN,
    dtype=jnp.float64,
)

# Load the argument dictionary and training run dictionary
logged_args, run_dict = load_model_training_metadata(
    modeldir,
    load_all=True
)

# Rebuild the loss function from the logged 'loss' identifier. 'kernel' and
# 'bw_range' are read with .get(), so they are passed as None when absent
# from the logged arguments.
loss_id = logged_args['loss']
loss_fn = select_loss_function(
    loss_id, 
    kernel=logged_args.get('kernel'),
    bw_range=logged_args.get('bw_range'),
)

# Histories recorded in the training run dictionary.
loss_train = run_dict['loss_hist_train']
loss_valid = run_dict['loss_hist_valid']
sigma_hist = run_dict['sigma_hist']
lr_hist = run_dict['learning_rate_hist']
dt_hist = run_dict['dt_hist']

# If `dt_hist` is missing or shorter than `sigma_hist`, regenerate it from
# the logged dt schedule, with one entry per entry of `sigma_hist`.
# A RuntimeError or TypeError is only reported; `dt_hist` then keeps the
# value read from `run_dict`.
try:
    if dt_hist is None or len(dt_hist) < len(sigma_hist):
        print("Calculuating `dt_hist` to match length of `sigma_hist`")
        dt_schedule = get_dt_schedule(logged_args['dt_schedule'], logged_args)
        dt_hist = np.array([dt_schedule(i) for i in range(len(sigma_hist))])
except (RuntimeError, TypeError) as e:
    print("Could not calculate `dt_hist` to match length of `sigma_hist`")
    print(e)

print(f"Loading model `{model_name}` at epoch {idx} from file: {model_fpath}.")

### Load testing data

In [None]:
# Sampling settings are taken from the arguments logged at training time.
ncells_sample = logged_args['ncells_sample']
length_multiplier = logged_args['passes_per_epoch']

# Keyword options passed to `get_dataloaders`.
dataloader_kwargs = {
    "shuffle_train": False,
    "return_datasets": True,
    "include_test_data": True,
    "datdir_test": datdir_test,
    "nsims_test": nsims_test,
    "shuffle_test": True,
    "batch_size_test": 20,  # TODO: Batch Testing
    "ncells_sample": ncells_sample,
    "length_multiplier": length_multiplier,
    "seed": rng.integers(2**32),
}

# Only the third and sixth returned values (test loader and test dataset)
# are kept.
_, _, test_loader, _, _, test_dset = get_dataloaders(
    datdir_train, datdir_valid, nsims_train, nsims_valid,
    **dataloader_kwargs,
)

print("Loaded datasets using parameters:")
for label, value in (
    ("ncells_sample", ncells_sample),
    ("length_multiplier", length_multiplier),
):
    print(f"\t{label}:", value)

# Perform model evaluation on the testing data
Perform one pass through the testing dataset, computing the loss as done in the training process.

In [None]:
# Build the JAX PRNG key from a seed drawn from the NumPy generator `rng`.
jax_seed = rng.integers(2**32)
key = jrandom.PRNGKey(seed=jax_seed)

In [None]:
@eqx.filter_jit
def compute_loss(model, x, y, loss_fn, key):
    """Run the model on one input and score its prediction against `y`.

    Args:
        model: Callable invoked as `model(t0, t1, y0, sigparams, key)`.
        x: Sequence unpacked as `(t0, y0, t1, sigparams)`.
        y: Target passed as the second argument of `loss_fn`.
        loss_fn: Callable invoked as `loss_fn(y_pred, y)`.
        key: Forwarded unchanged as the last argument of `model`.

    Returns:
        Tuple `(loss, y_pred)`: the value of `loss_fn(y_pred, y)` and the
        model output it was computed from.
    """
    t_start, state_start, t_end, signal_params = x
    prediction = model(t_start, t_end, state_start, signal_params, key)
    loss_value = loss_fn(prediction, y)
    return loss_value, prediction

In [None]:
from plnn.model_training import validation_step
import time

# Timestamp taken before evaluation starts (not read within this cell).
time0 = time.time()
n = len(test_loader)
running_vloss = 0.0

# Wrap the validation step once up front instead of re-wrapping it for
# every batch; the wrapped callable is loop-invariant.
jit_validation_step = eqx.filter_jit(validation_step)

# One pass over the test loader, accumulating the per-batch loss. The
# progress bar is disabled when running as a script.
for data in tqdm.tqdm(test_loader, disable=SCRIPT):
    inputs, y1 = data
    # Split off a fresh subkey for each batch; `key` carries forward.
    key, subkey = jrandom.split(key, 2)
    loss = jit_validation_step(model, inputs, y1, loss_fn, subkey)
    running_vloss += loss.item()

# Average over len(test_loader).
avg_loss = running_vloss / n
jax.block_until_ready(avg_loss)

print("Average loss:", avg_loss)

In [None]:
# Write the average testing loss to a .npy file in the output directory.
avg_loss_path = f"{OUTDIR}/avg_testing_loss.npy"
np.save(avg_loss_path, avg_loss)