# Script for testing DeepPhi models trained on FACS data

In [None]:
import os
import argparse
from argparse import Namespace
import warnings
import time
import numpy as np
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrandom

import equinox as eqx

from plnn.models import DeepPhiPLNN
from plnn.dataset import get_dataloaders
from plnn.data_generation.plnn_animator import PLNNSimulationAnimator
from plnn.io import load_model_from_directory, load_model_training_metadata
from plnn.loss_functions import select_loss_function
from plnn.optimizers import get_dt_schedule
from plnn.model_training import validation_step

In [None]:
# --- Notebook defaults (replaced by command-line values in script mode) ---
SEED = None             # Seed handed to the NumPy generator
SAVE_ANIMATION = False  # Animation toggle

# Where results are written, and where trained models are read from.
BASEOUTDIR = "./out/test_models_facs"
BASEDIR = "../data/trained_models/facs"

# Name of the trained-model directory inside BASEDIR.
MODEL_DIR = "model_facs_v3_dec1b_2dpca_v12b_20240719_005108"

# Drop the trailing 16 characters (the "_20240719_005108"-style time suffix)
# to recover the bare model name.
MODEL_NAME = MODEL_DIR[:-16]

In [None]:
# Directory containing training data: picked from the dataset-version tag
# found in MODEL_DIR, falling back to the unversioned "facs" directory.
DATDIRBASE = next(
    (
        f"../data/training_data/{version}"
        for version in ("facs_v2", "facs_v3")
        if version in MODEL_DIR
    ),
    "../data/training_data/facs",
)

# Tag in MODEL_DIR -> dataset location relative to DATDIRBASE. Entries are
# tried in order and the first tag contained in MODEL_DIR wins.
try:
    DATDIR = next(
        f"{DATDIRBASE}/{relpath}"
        for tag, relpath in (
            ("dec1a_2dpca", "pca/dec1/transition1_subset_epi_tr_ce_an_pc12"),
            ("dec1b_2dpca", "pca/dec1_fitonsubset/transition1_subset_epi_tr_ce_an_pc12"),
            ("dec2a_2dpca", "pca/dec2/transition2_subset_ce_pn_m_pc12"),
            ("dec2b_2dpca", "pca/dec2_fitonsubset/transition2_subset_ce_pn_m_pc12"),
        )
        if tag in MODEL_DIR
    )
except StopIteration:
    # No known tag matched; there is no sensible default dataset.
    raise RuntimeError("Cannot determine DATDIR from MODEL_DIR!") from None

In [None]:
# If running as a script, overwrite parameters with command line args

def is_notebook() -> bool:
    """Report whether the code is running inside a Jupyter kernel.

    Returns:
        True only when `get_ipython()` exists and its class is named
        'ZMQInteractiveShell' (Jupyter notebook or qtconsole). A terminal
        IPython session, any other shell class, or an interpreter in which
        `get_ipython` is undefined all yield False.
    """
    try:
        shell_name = get_ipython().__class__.__name__  # type: ignore
    except NameError:
        # `get_ipython` is not defined: probably the standard interpreter.
        return False
    return shell_name == 'ZMQInteractiveShell'

# True whenever the code is not running inside a Jupyter kernel.
SCRIPT = not is_notebook()

if SCRIPT:
    # Swap the notebook progress bar for the plain tqdm module.
    import tqdm as tqdm

    parser = argparse.ArgumentParser()
    parser.add_argument("--basedir", type=str, default="data/trained_models/facs")
    parser.add_argument("--modeldir", type=str, required=True)
    parser.add_argument("--modelname", type=str, required=True)
    parser.add_argument("--datdirbase", type=str, default="data/training_data/facs")
    parser.add_argument("--datdir", type=str, required=True)
    parser.add_argument("--baseoutdir", type=str, default="notebooks/out/test_models_facs")
    parser.add_argument("--save_animation", action='store_true')
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args()

    # Command-line values replace the notebook defaults.
    SEED = args.seed
    SAVE_ANIMATION = args.save_animation
    BASEOUTDIR = args.baseoutdir
    BASEDIR = args.basedir
    MODEL_DIR = args.modeldir
    MODEL_NAME = args.modelname
    DATDIRBASE = args.datdirbase
    # `--datdir` is given relative to `--datdirbase`.
    DATDIR = f"{DATDIRBASE}/{args.datdir}"

In [None]:
# NumPy generator from which the dataloader seeds and the JAX key are drawn.
rng = np.random.default_rng(seed=SEED)

modeldir = f"{BASEDIR}/{MODEL_DIR}"

# Results for this model go in their own subdirectory.
OUTDIR = f"{BASEOUTDIR}/{MODEL_DIR}"
os.makedirs(OUTDIR, exist_ok=True)

datdir_train, datdir_valid, datdir_test = (
    f"{DATDIR}/{split}" for split in ("training", "validation", "testing")
)

# Each split directory carries an `nsims.txt` file, read as integer data.
nsims_train, nsims_valid = (
    np.genfromtxt(f"{split_dir}/nsims.txt", dtype=int)
    for split_dir in (datdir_train, datdir_valid)
)

try:
    nsims_test = np.genfromtxt(f"{datdir_test}/nsims.txt", dtype=int)
except FileNotFoundError as e:
    # No testing split on disk: warn, then use the validation split instead.
    msg = f"{e} Reverting to validation data instead."
    warnings.warn(msg)
    datdir_test = f"{DATDIR}/validation"
    nsims_test = np.genfromtxt(f"{datdir_test}/nsims.txt", dtype=int)

### Load the model

In [None]:
# Load the model state selected by idx='best' from the "states" subdirectory.
model, hyperparams, idx, model_name, model_fpath = load_model_from_directory(
    modeldir,
    model_class=DeepPhiPLNN,
    subdir="states",
    idx='best',
    dtype=jnp.float64,
)

# Load the argument dictionary and training run dictionary
logged_args, run_dict = load_model_training_metadata(modeldir, load_all=True)

# Rebuild the loss function from the logged training arguments.
loss_id = logged_args['loss']
loss_fn = select_loss_function(
    loss_id,
    bw_range=logged_args.get('bw_range'),
    kernel=logged_args.get('kernel'),
)

# Histories recorded during the training run.
loss_train, loss_valid = run_dict['loss_hist_train'], run_dict['loss_hist_valid']
sigma_hist = run_dict['sigma_hist']
lr_hist = run_dict['learning_rate_hist']
dt_hist = run_dict['dt_hist']

try:
    # A missing or too-short `dt_hist` is regenerated from the logged schedule,
    # one value per entry of `sigma_hist`.
    if dt_hist is None or len(dt_hist) < len(sigma_hist):
        print("Calculuating `dt_hist` to match length of `sigma_hist`")
        dt_schedule = get_dt_schedule(logged_args['dt_schedule'], logged_args)
        dt_hist = np.array(list(map(dt_schedule, range(len(sigma_hist)))))
except (RuntimeError, TypeError) as e:
    # Best effort only: report the problem and carry on with `dt_hist` as-is.
    print("Could not calculate `dt_hist` to match length of `sigma_hist`")
    print(e)

print(f"Loading model `{model_name}` at epoch {idx} from file: {model_fpath}.")

### Load testing data

In [None]:
# Sampling parameters are taken from the logged training arguments.
ncells_sample = logged_args['ncells_sample']
length_multiplier = logged_args['passes_per_epoch']
batch_size = 20

# Of the six return values only the test loader and test dataset are kept.
_, _, test_loader, _, _, test_dset = get_dataloaders(
    datdir_train, datdir_valid, nsims_train, nsims_valid,
    include_test_data=True,
    return_datasets=True,
    datdir_test=datdir_test,
    nsims_test=nsims_test,
    batch_size_test=batch_size,  # TODO: Batch Testing
    shuffle_train=False,
    shuffle_test=True,
    ncells_sample=ncells_sample,
    length_multiplier=length_multiplier,
    seed=rng.integers(2**32),
)

# Summary of the resulting dataset and loader sizes.
print("Loaded datasets using parameters:")
print("\tncells_sample:", ncells_sample)
print("\tdataset base length:", test_dset.get_baselength())
print("\tlength_multiplier:", length_multiplier)
print("\tdataset length:", len(test_dset))
print("\tbatch size:", batch_size)
print("\tdataloader length:", len(test_loader))

# Perform model evaluation on the testing data
Perform one pass through the testing dataset, computing the loss as done in the training process.

In [None]:
# Root JAX PRNG key, seeded with an integer drawn from the NumPy generator `rng`.
# Later cells split this key before each stochastic model evaluation.
key = jrandom.PRNGKey(seed=rng.integers(2**32))

In [None]:
@eqx.filter_jit
def compute_loss(model, x, y, loss_fn, key):
    """Run the model on one input tuple and score the prediction.

    Args:
        model: Callable invoked as `model(t0, t1, y0, sigparams, key)`.
        x: Iterable unpacking to `(t0, y0, t1, sigparams)`.
        y: Target passed to `loss_fn` together with the prediction.
        loss_fn: Callable invoked as `loss_fn(y_pred, y)`.
        key: PRNG key forwarded to the model.

    Returns:
        Tuple `(loss, y_pred)`.
    """
    t_start, y_start, t_end, sig_params = x
    prediction = model(t_start, t_end, y_start, sig_params, key)
    loss = loss_fn(prediction, y)
    return loss, prediction

In [None]:
# One pass over the (shuffled) testing loader with `model`, accumulating the
# per-batch loss returned by `validation_step`.
time0 = time.time()
n = len(test_loader)
running_vloss = 0.0

# Wrap `validation_step` once, outside the loop, instead of re-wrapping it on
# every batch.
jitted_validation_step = eqx.filter_jit(validation_step)

for inputs, y1 in tqdm.tqdm(test_loader, disable=SCRIPT):
    # Fresh subkey per batch; `key` itself is only ever split, never consumed.
    key, subkey = jrandom.split(key, 2)
    loss = jitted_validation_step(model, inputs, y1, loss_fn, subkey)
    # `.item()` pulls the scalar to the host, so no separate
    # `block_until_ready` call is needed on the resulting Python float.
    running_vloss += loss.item()

# Mean of the per-batch losses.
avg_loss = running_vloss / n

print("Average loss:", avg_loss)

# Save the resulting average loss value in the output directory.
np.save(f"{OUTDIR}/avg_testing_loss.npy", avg_loss)

In [None]:
# Copy of `model` whose `dt0` field is replaced by 0.01; `model` itself is
# left unchanged.
faster_model = eqx.tree_at(lambda m: m.dt0, model, 0.01)

In [None]:
# Same pass over the testing loader, this time evaluating `faster_model`.
time0 = time.time()
n = len(test_loader)
running_vloss = 0.0

# Wrap `validation_step` once, outside the loop, instead of re-wrapping it on
# every batch.
jitted_validation_step = eqx.filter_jit(validation_step)

for inputs, y1 in tqdm.tqdm(test_loader, disable=SCRIPT):
    # Fresh subkey per batch; `key` itself is only ever split, never consumed.
    key, subkey = jrandom.split(key, 2)
    loss = jitted_validation_step(faster_model, inputs, y1, loss_fn, subkey)
    running_vloss += loss.item()

# Mean of the per-batch losses.
avg_loss = running_vloss / n

In [None]:
# Mean per-batch loss from the most recent pass over the testing loader
# (`running_vloss` summed over `n` batches).
avg_loss = running_vloss / n
print("Average loss:", avg_loss)

## Deterministic Testing

Construct a testing dataloader that does not shuffle the data.
For each datapoint, run the forward model multiple times (for multiple values of `dt`).
Determine the distribution of the loss for each datapoint.
See if there is a correlation between the loss and time, or between the loss and experimental condition.
Run the same sort of trial for the training data.

In [None]:
# Rebuild the loaders with shuffling disabled everywhere. This rebinds
# `test_loader` / `test_dset` and also keeps the validation loader and dataset.
_, valid_loader, test_loader, _, valid_dset, test_dset = get_dataloaders(
    datdir_train, datdir_valid, nsims_train, nsims_valid,
    include_test_data=True,
    return_datasets=True,
    datdir_test=datdir_test,
    nsims_test=nsims_test,
    batch_size_test=1,  # Needs to be 1 otherwise loss is averaged
    shuffle_train=False,
    shuffle_valid=False,
    shuffle_test=False,
    ncells_sample=ncells_sample,
    length_multiplier=1,
    seed=rng.integers(2**32),
)

In [None]:
NITERS = 10  # Number of iterations per datapoint

loader = test_loader

# NaN-filled buffers, one row per item yielded by the loader.
times = np.full(len(loader), np.nan)
conditions = np.full([len(loader), 2, 4], np.nan)
results = np.full([len(loader), NITERS], np.nan)

for i, (inputs, y1) in enumerate(tqdm.tqdm(loader)):
    # First batch entry of the first and last input components
    # (presumably the start time and the signal parameters — TODO confirm).
    times[i] = inputs[0][0]
    conditions[i, :] = inputs[-1][0]
    for k in range(NITERS):
        # Fresh subkey for every repetition on the same datapoint.
        key, subkey = jrandom.split(key, 2)
        results[i, k] = validation_step(
            faster_model, inputs, y1, loss_fn, subkey
        ).item()
    

In [None]:
# Plot the loss of each datapoint, averaged over axis 1 of `results`.
# The seaborn lines below are a disabled alternative and are not executed.
# import seaborn as sns

plt.plot(results.mean(axis=1))

# sns.lineplot(data=results, x="timepoint", y="signal", hue="event", errorbar=('sd', 1))

In [None]:
def run_trial(
        model,
        loader,
        niters,
        key,
):
    """Evaluate every item of `loader` `niters` times with independent keys.

    For each item the loss is computed `niters` times in one vectorised call
    to `validation_step`, each repetition receiving its own PRNG subkey. The
    loss function is read from the module-level `loss_fn`.

    Args:
        model: Model passed through to `validation_step`.
        loader: Sized iterable yielding `(inputs, y1)` pairs.
        niters: Number of repetitions per item.
        key: JAX PRNG key; split once per item.

    Returns:
        Tuple `(losses, times, conditions)` of NumPy arrays with shapes
        `(len(loader), niters)`, `(len(loader),)` and `(len(loader), 2, 4)`.
    """
    n = len(loader)
    # NaN-filled so that any row never written is easy to spot.
    times = np.full(n, np.nan)
    conditions = np.full([n, 2, 4], np.nan)
    losses = np.full([n, niters], np.nan)

    @eqx.filter_jit
    def batch_stepper(model, inputs, y1, loss_fn, keys):
        # Split the key that was passed in for THIS item. Splitting the
        # enclosing `key` instead would freeze its value into the compiled
        # function at trace time, so every item would reuse the same subkeys.
        subkeys = jrandom.split(keys, niters)
        # Vectorise over the subkeys only; all other arguments are shared.
        return jax.vmap(validation_step, (None, None, None, None, 0))(
            model, inputs, y1, loss_fn, subkeys
        )

    for i, (inputs, y1) in enumerate(tqdm.tqdm(loader)):
        # First batch entry of the first and last input components
        # (presumably the start time and the signal parameters — TODO confirm).
        times[i] = inputs[0][0]
        conditions[i, :] = inputs[-1][0]
        key, subkey = jrandom.split(key, 2)
        losses[i, :] = batch_stepper(model, inputs, y1, loss_fn, subkey)

    return losses, times, conditions


In [None]:
# Run the repeated-evaluation trial on `faster_model` with 10 repetitions per
# datapoint. `results` is the tuple (losses, times, conditions).
key, subkey = jrandom.split(key, 2)
results = run_trial(faster_model, loader, 10, subkey)

In [None]:
# First element of the `run_trial` result tuple: the per-datapoint loss array.
losses = results[0]

# Loss of each datapoint averaged over its repetitions (axis 1).
plt.plot(losses.mean(axis=1));

In [None]:
NSAMPLES = 1
NITERS_PER_DATUM = 10  # Repetitions per datapoint passed to `run_trial`
SCAN_DT0 = [0.1, 0.05, 0.01, 0.005]  # Values of the model's `dt0` to scan over

# Scan results, keyed by dt0 value.
TRAIN_RESULTS = {}

def get_loader(
        valid_or_test, *,
        length_multiplier=1,
        ncells_sample=ncells_sample,
        seed=None
):
    """Build unshuffled, batch-size-1 dataloaders and return one of them.

    Args:
        valid_or_test: Either 'valid' or 'test'; any other value raises
            KeyError.
        length_multiplier: Forwarded to `get_dataloaders`.
        ncells_sample: Forwarded to `get_dataloaders`. The default is the
            value of the module-level `ncells_sample` at definition time.
        seed: Forwarded to `get_dataloaders`.

    Returns:
        The validation or testing dataloader, as selected.
    """
    _, loader_valid, loader_test, _, _, _ = get_dataloaders(
        datdir_train, datdir_valid, nsims_train, nsims_valid,
        include_test_data=True,
        return_datasets=True,
        datdir_test=datdir_test,
        nsims_test=nsims_test,
        shuffle_train=False,
        shuffle_valid=False,
        shuffle_test=False,
        batch_size_train=1,
        batch_size_valid=1,
        batch_size_test=1,  # Needs to be 1 otherwise loss is averaged
        ncells_sample=ncells_sample,
        length_multiplier=length_multiplier,
        seed=seed,
    )
    loaders_by_name = {'valid': loader_valid, 'test': loader_test}
    return loaders_by_name[valid_or_test]


In [None]:
# For every dt0 in the scan, run the repeated-evaluation trial on the
# validation loader and store the result under that dt0.
for dt0 in SCAN_DT0:
    TRAIN_RESULTS[dt0] = []
    # NOTE: this rebinds the module-level `model` to a copy with `dt0` replaced.
    model = eqx.tree_at(lambda m: m.dt0, model, dt0)

    # `seed` is keyword-only in `get_loader`; passing it positionally raises
    # a TypeError.
    loader = get_loader('valid', seed=rng.integers(2**32))
    key, subkey = jrandom.split(key, 2)
    res = run_trial(
        model, loader, NITERS_PER_DATUM, subkey
    )
    TRAIN_RESULTS[dt0].append(res)


In [None]:
# Bare expression: a notebook displays the dict; it has no effect in a script.
TRAIN_RESULTS

In [None]:
import pickle
# Persist the dt0 scan results in the model's output directory.
# NOTE(review): the ".okl" extension looks like a typo for ".pkl" — confirm
# before renaming, since any code reading this file must use the same name.
with open(f"{OUTDIR}/saved_train_results.okl", 'wb') as f:
    pickle.dump(TRAIN_RESULTS, f)


In [None]:
def run_trial2(
        model,
        ncells_sample,
        train_valid_test,
        n_resamp,
        n_reps,
        batch_size,
        key,
        rng=None,
):
    """Evaluate a whole loader in vectorised batches, `n_reps` times per item.

    A loader is built with `get_loader`, all of its items are stacked into
    arrays, and the loss is then computed for slices of `batch_size` items at
    a time. Every item in a slice is evaluated `n_reps` times with independent
    PRNG subkeys. The loss function is read from the module-level `loss_fn`.

    Args:
        model: Model passed through to `validation_step`.
        ncells_sample: Forwarded to `get_loader`.
        train_valid_test: Loader selector forwarded to `get_loader`.
        n_resamp: Passed to `get_loader` as `length_multiplier`.
        n_reps: Number of repetitions per item.
        batch_size: Number of items evaluated per vectorised call.
        key: JAX PRNG key; split once per batch.
        rng: NumPy generator used to seed the loader. A fresh unseeded
            generator is created when None.

    Returns:
        Tuple `(losses, times, conditions)` of NumPy arrays with shapes
        `(n, n_reps)`, `(n,)` and `(n, 2, 4)`, where `n = len(loader)`.
    """
    if rng is None:
        rng = np.random.default_rng()

    loader = get_loader(
        train_valid_test,
        length_multiplier=n_resamp,
        ncells_sample=ncells_sample,
        seed=rng.integers(2**32)
    )
    n = len(loader)

    # NaN-filled so that any row never written is easy to spot.
    times = np.full(n, np.nan)
    conditions = np.full([n, 2, 4], np.nan)
    losses = np.full([n, n_reps], np.nan)

    inputs_array = []
    y1_array = []

    # Materialise the loader so its items can be stacked and sliced below.
    for i, (inputs, y1) in enumerate(loader):
        # First batch entry of the first and last input components
        # (presumably the start time and the signal parameters — TODO confirm).
        times[i] = inputs[0][0]
        conditions[i, :] = inputs[-1][0]
        inputs_array.append(inputs)
        y1_array.append(y1)

    # Stack matching leaves of all items along a new leading axis.
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    inputs_array = jax.tree_util.tree_map(
        lambda *x: jnp.stack(x), *inputs_array
    )
    y1_array = jnp.array(y1_array)

    @eqx.filter_jit
    def validation_step_ntimes(n_reps, model, inputs, y1, loss_fn, key):
        # One item, `n_reps` independent evaluations.
        subkeys = jrandom.split(key, n_reps)
        return jax.vmap(validation_step, (None, None, None, None, 0))(
            model, inputs, y1, loss_fn, subkeys
        )

    @eqx.filter_jit
    def step_ntimes_vectorized(
            n_reps, model, inputs_arr, y1_arr, loss_fn, key,
    ):
        # One subkey per item in the slice; map over items and their subkeys.
        subkeys = jrandom.split(key, len(y1_arr))
        return jax.vmap(validation_step_ntimes, (None, None, 0, 0, None, 0))(
            n_reps, model, inputs_arr, y1_arr, loss_fn, subkeys
        )

    # Slices of at most `batch_size` items; the last one may be shorter.
    for idx0 in tqdm.tqdm(range(0, n, batch_size)):
        time0 = time.time()
        idx1 = min(idx0 + batch_size, n)
        key, subkey = jrandom.split(key, 2)

        partial_inputs_array = [arr[idx0:idx1] for arr in inputs_array]
        partial_y1_array = y1_array[idx0:idx1]

        # Consume the fresh `subkey`. Passing `key` here would reuse the very
        # key that is split again on the next iteration.
        losses[idx0:idx1] = step_ntimes_vectorized(
            n_reps, model, partial_inputs_array, partial_y1_array, loss_fn,
            subkey
        )
        print(f"  time: {time.time() - time0} ")

    return losses, times, conditions

In [None]:
# Build the unshuffled testing loader (no length multiplication) and report
# how many items it yields.
loader = get_loader('test', length_multiplier=1, seed=rng.integers(2**32))
print("dataloader length:", len(loader))

In [None]:
NRESAMP = 20  # Passed to `run_trial2` as `n_resamp`
NREPS = 10    # Passed to `run_trial2` as `n_reps`

# `results` is the tuple (losses, times, conditions) returned by `run_trial2`.
key, subkey = jrandom.split(key, 2)
results = run_trial2(
    faster_model, ncells_sample, "test",
    n_resamp=NRESAMP,
    n_reps=NREPS,
    batch_size=40,
    key=subkey,
    rng=rng
)

In [None]:
# `results` is the (losses, times, conditions) tuple returned by `run_trial2`,
# which has no `.shape`; inspect the losses array instead.
results[0].shape

In [None]:
NTIMES = 6
NRESAMP = 20
NREPS = 10

# `results` is the (losses, times, conditions) tuple returned by `run_trial2`.
# Take the losses array: `len(results)` would be 3, and a tuple cannot be
# reshaped.
losses = results[0]
NCONDS = int(len(losses) // NTIMES // NRESAMP)

# Assumes rows are ordered resample-major, then condition, then time
# — TODO confirm against the dataset's ordering.
losses = losses.reshape([NRESAMP, NCONDS, NTIMES, NREPS])
losses = losses.transpose(1, 2, 0, 3)


print("(NCONDS, NTIMES, NRESAMP, NREPS):", losses.shape)

In [None]:
# One curve per resample for the first two conditions; each curve is the loss
# at every timepoint averaged over the repetitions.
fig, [ax1, ax2] = plt.subplots(2, 1)

for sampidx in range(NRESAMP):
    ax1.plot(losses[0, :, sampidx, :].mean(axis=1))
    ax2.plot(losses[1, :, sampidx, :].mean(axis=1))
    ax1.set_title("Condition 1")
    ax2.set_title("Condition 2")
plt.tight_layout()


# Per condition: mean loss across resamples with error bars of two standard
# deviations, drawn at six fixed x-positions (2.25, 2.75, ..., 4.75).
fig, ax = plt.subplots(1, 1)

timepoints = np.arange(2, 5, 0.5) + 0.25

for condidx in range(NCONDS):
    # Average over the last axis (repetitions) first, then take statistics
    # over the remaining last axis (resamples).
    rep_means = losses[condidx].mean(axis=-1)
    samp_mean = rep_means.mean(axis=-1)
    samp_std = rep_means.std(axis=-1)
    print(samp_std)

    ax.errorbar(
        timepoints,
        samp_mean,
        yerr=2 * samp_std,
        capsize=3,
        linestyle="--",
        label=f"Cond {condidx + 1}",
    )

ax.set_xlim(2, 5)
ax.legend()
ax.set_xlabel("timepoint")
plt.show()
            