In [None]:
## importing packages
import torch
import os
import sys
import joypy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
from tqdm import tqdm

## exposing path 
sys.path.insert(0, "../")

from int_filt.experiments import create_experiment
from int_filt.utils.config import configuration
from int_filt.utils.utils import ensure_reproducibility, move_batch_to_device

In [None]:
## lookup tables translating configuration strings into torch objects
ACTIVATIONS = dict(relu=torch.nn.ReLU())

OPTIMIZERS = dict([
    ("adam", torch.optim.Adam),
    ("adam-w", torch.optim.AdamW),
])

SCHEDULERS = dict(none=None)

DEVICES = {name: torch.device(name) for name in ("cpu", "cuda")}

## simulation settings (forwarded to the experiment configuration)
num_dims, num_iters = 1, 1000
sigma_x = 1e-2
sigma_y = 1e-2

## plotting settings (upper bounds on the rows / columns kept for plotting)
num_observations_to_plot, num_iters_to_plot = 1000, 200

## model settings
b_net_amortized = True
b_net_lr, b_net_num_grad_steps = 1e-3, 1000

## sampling settings (passed to the sampling configuration)
num_samples, num_time_steps = 500, 300

Exponential non-linearity

In [None]:
## notebook-level overrides layered on top of the default arguments
experiment_config = dict(
    experiment="nlg",
    non_linearity="exp",
    num_dims=num_dims,
    num_iters=num_iters,
    sigma_x=sigma_x,
    sigma_y=sigma_y,
    b_net_amortized=b_net_amortized,
    b_net_lr=b_net_lr,
    b_net_num_grad_steps=b_net_num_grad_steps,
)

## parsing the default arguments and exposing them as a plain dictionary
args = vars(configuration(args=[]))
## swapping the configured activation / device names for the matching torch objects
args["b_net_activation"] = ACTIVATIONS[args["b_net_activation"]]
args["device"] = DEVICES[args["device"]]
## applying the notebook overrides
args.update(experiment_config)

## mc configuration: the sample count comes from the parsed arguments,
## not from the notebook-level `num_samples` setting
args["mc_config"] = dict(num_samples=args["num_samples"])

## values needed later on to train the drift network
b_net_num_grad_step, b_net_optimizer, b_net_scheduler, b_net_lr = (
    args[key]
    for key in ("b_net_num_grad_steps", "b_net_optimizer", "b_net_scheduler", "b_net_lr")
)

## reproducibility
random_seed = args["random_seed"]
ensure_reproducibility(random_seed)

## dump dir (created together with any missing parents)
dump_dir = args["dump_dir"]
path = Path(dump_dir)
path.mkdir(parents=True, exist_ok=True)

## displaying current arguments
print(args)

## creating experiment
experiment_nlg_exp = create_experiment(args)

## joyplot data
## retrieving the simulated trajectories as transposed numpy arrays;
## detach/cpu keep the conversion working for tensors that live on the GPU
## or track gradients, and are no-ops otherwise
latent_states_nlg_exp = torch.squeeze(experiment_nlg_exp.ssm.sim["latent_states"]).detach().cpu().numpy().T
## NOTE(review): this reads "latent_states" again, which looks like a copy-paste slip;
## the key holding the observations is not visible from this notebook -- confirm and fix
observations_nlg_exp = torch.squeeze(experiment_nlg_exp.ssm.sim["latent_states"]).detach().cpu().numpy().T
## constructing data frame from the slice that is going to be plotted
latent_states_to_plot = latent_states_nlg_exp[:num_observations_to_plot, :num_iters_to_plot]
## labelling columns by the number of columns actually kept, so a simulation with fewer
## than `num_iters_to_plot` columns no longer breaks the DataFrame constructor
observation_indices = np.arange(latent_states_to_plot.shape[1])
latent_states_nlg_exp = pd.DataFrame(latent_states_to_plot, columns = observation_indices)
## ridge plot (disabled; uncomment to render)
#fig, axes = joypy.joyplot(latent_states_nlg_exp, ylabels=False)

In [None]:
def _print_batch_stats(header, entries):
    """Print `header`, then one line with the mean and std of every entry in `entries`."""
    print(header)
    for name, values in entries.items():
        print(name, "-> mean: ", values.mean(), ", std: ", values.std())


## comparing the batch statistics before and after standardization
batch = experiment_nlg_exp.get_batch()
_print_batch_stats("BEFORE STANDARDIZATION\n", batch)
batch = experiment_nlg_exp.standardize(batch)
_print_batch_stats("\nAFTER STANDARDIZATION\n", batch)

In [None]:
## initializing optimizer
## the optimizer class is looked up from the configured name stored in `args`, not from
## `b_net_optimizer`: this cell rebinds `b_net_optimizer` to the optimizer instance, so
## reading the name back from that variable made a second run of the cell raise a KeyError
b_net_optimizer = OPTIMIZERS[args["b_net_optimizer"]](
    experiment_nlg_exp.b_net.backbone.parameters(), lr = b_net_lr
)

## constructing optimization config dictionary
b_net_optim_config = {
    "num_grad_steps": b_net_num_grad_step,
    "optimizer": b_net_optimizer,
    ## NOTE(review): the configured scheduler value is passed through unchanged and the
    ## SCHEDULERS table is never consulted -- confirm what `train` expects here
    "scheduler": b_net_scheduler
}

## training b_net (only the backbone parameters are handed to the optimizer)
experiment_nlg_exp.train(b_net_optim_config)
## saving the weights of the whole b_net into the dump directory
torch.save(experiment_nlg_exp.b_net.state_dict(), os.path.join(dump_dir, "b_net_exp.pt"))

In [None]:
## constructing sampling config dictionary
sample_config = {
    "num_time_steps": num_time_steps,
    "num_samples": num_samples
}

## getting a fresh sample batch, standardized and moved to the experiment device
batch = experiment_nlg_exp.get_batch()
batch = experiment_nlg_exp.standardize(batch)
batch = move_batch_to_device(batch, experiment_nlg_exp.device)

## sampling from model
samples = experiment_nlg_exp.sample(batch, sample_config = sample_config)

## sampling from gt state transition: one call per sample, all starting from the same x0,
## stored along the first axis of a buffer shaped like the model samples
x = batch["x0"]
samples_gt = torch.zeros_like(samples)
for sample_id in tqdm(range(num_samples)):
    samples_gt[sample_id] = experiment_nlg_exp.ssm.state_transition(x)
## both shapes are labelled now (the second one used to be printed without its name)
print(f"{samples.shape=}, {samples_gt.shape=}")

## plotting example density estimate
## index 0 along the second and third axes, kept across the whole sample axis
## (presumably first batch element / first state dimension -- TODO confirm)
x1 = samples_gt[:, 0, 0]
x1_hat = samples[:, 0, 0]

fig, ax = plt.subplots()
## seaborn gets host-side numpy arrays: detach/cpu are no-ops for plain CPU tensors and
## keep the plot working when the samples live on the GPU or track gradients
sns.kdeplot(x1.detach().cpu().numpy(), label = "ground truth", ax = ax)
sns.kdeplot(x1_hat.detach().cpu().numpy(), label = "prediction", ax = ax)
ax.set_xlabel("x1")
ax.legend();