In [None]:
## importing libraries
import torch
import os
import sys
import gc 

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from pathlib import Path

## exposing path 
sys.path.insert(0, "../")

from int_filt.experiments import create_experiment

from int_filt.utils import (
    configuration, 
    ensure_reproducibility, 
    dump_config,
    dump_tensors, 
    move_batch_to_device
)

In [None]:
## defining macros: lookup tables translating configuration strings into torch objects
## supported activation modules (instantiated once)
ACTIVATIONS = dict(relu = torch.nn.ReLU())

## supported optimizer classes, instantiated later with parameters and learning rate
OPTIMIZERS = {
    "sgd": torch.optim.SGD,
    "adam": torch.optim.Adam,
    "adam-w": torch.optim.AdamW,
}

## supported scheduler classes, "none" maps to no scheduler at all
SCHEDULERS = {
    "none": None,
    "cosine-annealing": torch.optim.lr_scheduler.CosineAnnealingLR,
}

## supported torch devices, keyed by their own name
DEVICES = {name: torch.device(name) for name in ("cpu", "cuda")}

In [None]:
## defining configurations
## DEBUG selects the small value of every `... if DEBUG else ...` entry below
DEBUG = True
experiment_config = dict(
    ## simulation config
    experiment = "nlg-controlled",
    controlled = True,
    non_linearity = "sin",
    num_dims = 1,
    num_sims = 1_000,
    num_iters = 10_000 if DEBUG else 1_000_000,
    num_burn_in_steps = 0,
    step_size = 1e-3,
    sigma_x = 1e-2,
    sigma_y = 1.0,
    beta = 1.0,
    ## interpolant config
    epsilon = 1.8e-2,
    interpolant_method = "pffp_v0",
    ## model config
    b_net_amortized = False,
    b_net_activation = "relu",
    b_net_hidden_dims = [1028],
    b_net_activate_final = False,
    c_net_activation = "relu",
    c_net_hidden_dims = [1028],
    c_net_activate_final = False,
    ## training config
    num_grad_steps = 3 if DEBUG else 100,
    b_net_lr = 1e-3,
    b_net_scheduler = "none",
    b_net_optimizer = "adam-w",
    c_net_lr = 1e-3,
    c_net_scheduler = "none",
    c_net_optimizer = "adam-w",
    ## mc estimation config
    num_mc_samples = 50 if DEBUG else 750,
    ## preprocessing config
    preprocessing = "sim",
    ## sampling config
    num_samples = 10 if DEBUG else 450,
    num_time_steps = 10 if DEBUG else 1_000,
    num_ar_steps = 10 if DEBUG else 100,
    full_out = False,
    initial_time_step = 0,
    ar_sample_train = False,
    ## logging config
    log_results = False,
    ## memory config
    clear_memory = True,
)

In [None]:
## building the base configuration from an empty argument list and exposing it as a dictionary
config = vars(configuration(args=[]))

## overriding the base entries with the notebook configuration
config.update(experiment_config)
## resolving the activation names of both networks into torch modules
for activation_key in ("b_net_activation", "c_net_activation"):
    config[activation_key] = ACTIVATIONS[config[activation_key]]
## resolving the device name; "device" is not set by the notebook, so it comes from the base configuration
config["device"] = DEVICES[config["device"]]
## adding mc configuration
config["mc_config"] = dict(num_mc_samples = config["num_mc_samples"])

In [None]:
## running simulation
## seeding first so that everything done while creating the experiment is reproducible
random_seed = config["random_seed"]
ensure_reproducibility(random_seed)
## creating the experiment from the assembled configuration
experiment = create_experiment(config)

In [None]:
## standardization
def _print_batch_stats(tensors):
    """Print mean, standard deviation and shape of every tensor in a batch dictionary."""
    for name, tensor in tensors.items():
        print(name, "-> mean: ", tensor.mean(), ", std: ", tensor.std(), "shape: ", tensor.shape)

## retrieving a batch and displaying the preprocessing parameters
batch = experiment.get_batch()
print(f"STANDARDIZATION: {experiment.preprocessing.params}")
## statistics of the raw batch
print("BEFORE STANDARDIZATION\n")
_print_batch_stats(batch)
## statistics after applying the preprocessing
batch = experiment.preprocessing(batch)
print("\nAFTER PREPROCESSING\n")
_print_batch_stats(batch)
## statistics after undoing the preprocessing
batch = experiment.preprocessing.unstandardize(batch)
print("\nAFTER POSTPROCESSING\n")
_print_batch_stats(batch)

In [None]:
## training
## initializing optimizer and scheduler for the b_net backbone
b_net_optimizer = OPTIMIZERS[config["b_net_optimizer"]](experiment.b_net.backbone.parameters(), lr = config["b_net_lr"])
b_net_scheduler = SCHEDULERS[config["b_net_scheduler"]]
if b_net_scheduler is not None:
    ## the scheduler horizon has to be the number of gradient steps actually run:
    ## the same "num_grad_steps" entry used by the c_net scheduler and passed to training below
    ## (the notebook configuration defines no "b_net_num_grad_steps" entry)
    b_net_scheduler = b_net_scheduler(b_net_optimizer, config["num_grad_steps"])
## initializing optimizer and scheduler for optional control term
c_net_optimizer = None
c_net_scheduler = None
if config["controlled"]:
    c_net_optimizer = OPTIMIZERS[config["c_net_optimizer"]](experiment.c_net.model.backbone.parameters(), lr = config["c_net_lr"])
    c_net_scheduler = SCHEDULERS[config["c_net_scheduler"]]
    if c_net_scheduler is not None:
        c_net_scheduler = c_net_scheduler(c_net_optimizer, config["num_grad_steps"])
## constructing optimization config dictionary
optim_config = {
    "b_net_optimizer": b_net_optimizer,
    "b_net_scheduler": b_net_scheduler,
    "c_net_optimizer": c_net_optimizer,
    "c_net_scheduler": c_net_scheduler,
    "num_grad_steps": config["num_grad_steps"],
}
## training: the controlled routine is used when a control term is configured, the drift-only routine otherwise
if config["controlled"]:
    train_dict = experiment.train_controlled(optim_config)
else:
    train_dict = experiment.train_drift(optim_config)
## optional logging
if config["log_results"]:
    ## saving the model weights into the dump directory
    torch.save(experiment.b_net.state_dict(), os.path.join(config["dump_dir"], "b_net.pt"))
    if config["controlled"]:
        torch.save(experiment.c_net.state_dict(), os.path.join(config["dump_dir"], "c_net.pt"))

In [None]:
## plotting training history
fig, axes = plt.subplots()
## one curve per loss history stored in the training dictionary
for history_key, curve_label in (("b_loss_history", "drift loss"), ("c_loss_history", "control loss")):
    axes.plot(train_dict[history_key], label = curve_label)
## decorating the figure
axes.set(title = "Loss History", xlabel = "Gradient Step", ylabel = "Loss Value")
legend = axes.legend()

In [None]:
## constructing sampling config dictionary
sde_config = dict(num_time_steps = config["num_time_steps"])

## getting sample batch (train = False) and moving it to the experiment device
batch = move_batch_to_device(experiment.get_batch(train = False), experiment.device)

## simulating sde on sample batch
sde_dict = experiment.simulate_sde(batch, config = sde_config)
## displaying results shape
for key, tensor in sde_dict.items():
    print(key, f": {tensor.shape}")
## plotting the simulated trajectories, only when the full output is requested
if config["full_out"]:
    ## computing statistics over the mc samples
    mean_trajectory = sde_dict["trajectory"].mean(dim = 1)
    std_trajectory = sde_dict["trajectory"].std(dim = 1)
    ## defining figure and axis
    fig, axes = plt.subplots()
    ## individual trajectories of the first component, then mean and one-std band
    axes.plot(sde_dict["trajectory"][:, :, 0])
    axes.plot(mean_trajectory, label = "mean")
    axes.plot(mean_trajectory + std_trajectory, color = "red", label = "+- std")
    axes.plot(mean_trajectory - std_trajectory, color = "red")
    axes.set(title = "Simulated trajectories", xlabel = "Time Step", ylabel = "Simulated Trajectory")
    legend = axes.legend()

In [None]:
## constructing sampling config dictionary
sde_config = {
    "num_time_steps": config["num_time_steps"],
}

## getting sample batch
batch = experiment.get_batch(train = False)
batch = move_batch_to_device(batch, experiment.device)

## simulating controlled sde on sample batch
sde_dict = experiment.simulate_controlled_sde(batch, config = sde_config)
## displaying results shape
for key, tensor in sde_dict.items():
    print(key, f": {tensor.shape}")
## plotting state and value trajectories, only when the full output is requested
if config["full_out"]:
    ## defining figure and axes: trajectories on the left, value trajectories on the right
    fig, axes = plt.subplots(1, 2, figsize = (10, 5))
    ## detaching is a no-op for tensors without gradient history and keeps them plottable otherwise
    trajectory = sde_dict["trajectory"].detach()
    ## computing statistics over the mc samples
    mean_trajectory = torch.mean(trajectory, dim = 1)
    std_trajectory = torch.std(trajectory, dim = 1)
    axes[0].set_title("Simulated trajectories")
    axes[0].plot(trajectory[:, :, 0])
    axes[0].plot(mean_trajectory, label = "mean")
    axes[0].plot(mean_trajectory + std_trajectory, color = "red", label = "+- std")
    axes[0].plot(mean_trajectory - std_trajectory, color = "red")
    ## same axis labels as the uncontrolled simulation plot: this panel shows trajectories, not a drift
    axes[0].set_xlabel("Time Step")
    axes[0].set_ylabel("Simulated Trajectory")
    axes[0].legend()
    value_trajectory = sde_dict["value_trajectory"].detach()
    ## computing statistics over the mc samples, kept under their own names
    ## so that the trajectory statistics above are not overwritten
    mean_value_trajectory = torch.mean(value_trajectory, dim = 1)
    std_value_trajectory = torch.std(value_trajectory, dim = 1)
    axes[1].set_title("Simulated Value trajectories")
    axes[1].plot(value_trajectory[:, :, 0])
    axes[1].plot(mean_value_trajectory, label = "mean")
    axes[1].plot(mean_value_trajectory + std_value_trajectory, color = "red", label = "+- std")
    axes[1].plot(mean_value_trajectory - std_value_trajectory, color = "red")
    axes[1].set_xlabel("Time Step")
    axes[1].set_ylabel("Value")
    axes[1].legend()

In [None]:
## constructing apf-filtering config dictionary
apf_config = {
    "num_time_steps": 100,
    "num_particles": 2**9,
    "num_obs": 100
}

## filter from initial state and a sequence of observations
observation_idx = 0
## single source of truth: particle and observation counts are read from the
## config handed to the filter, so batch and filter cannot disagree
num_obs = apf_config["num_obs"]
num_particles = apf_config["num_particles"]
## retrieving initial state for target sequence and replicating it once per particle
x0 = experiment.ssm.test_sim["latent_states"][0, observation_idx, :]
x0 = torch.unsqueeze(x0, dim = 0)
x0 = x0.repeat((num_particles, 1))
## retrieving the first num_obs observations of the target sequence
y = experiment.ssm.test_sim["observations"][:num_obs, observation_idx, :]
## retrieving the matching ground truth latent states
x_gt = experiment.ssm.test_sim["latent_states"][:num_obs, observation_idx, :]
## constructing the batch ("x0" and "xc" both start from the replicated initial state)
batch = {
    "x0": x0,
    "xc": x0,
    "y": y
}
batch = move_batch_to_device(batch, experiment.device)
for key, tensor in batch.items():
    print(key, tensor.shape)
## running filtering
apf_dict = experiment.FA_APF(batch, apf_config)
## displaying filtering metrics: mean of the "ess" entry and last "log_norm_const" entry
print(torch.mean(apf_dict["ess"]))
print(apf_dict["log_norm_const"][-1])

In [None]:
## plotting log normalizing constant and effective sample size returned by the filter
fig, axes = plt.subplots(1, 2, figsize = (10, 5))
axes[0].plot(apf_dict["log_norm_const"])
axes[0].set_title("Log Observation Likelihood")
axes[1].plot(apf_dict["ess"])
## ESS stands for effective (not expected) sample size
axes[1].set_title("Effective Sample Size")

In [None]:
## plotting filtered states against the ground truth
filtered_states = apf_dict["states"]
fig, axes = plt.subplots()
## ground truth latent states of the filtered sequence
axes.plot(x_gt, color = "red", label = "GT latent states")
## first state component averaged over dim 0 (presumably the particle dimension, TODO confirm)
axes.plot(filtered_states[:, :, 0].mean(dim = 0), color = "purple", label = "filtered latent states")
axes.set_title("Filtered vs GT States")
legend = axes.legend()

In [None]:
## autoregressive sampling, switch off to skip the sampling
AR_SAMPLING = True
if AR_SAMPLING:
    ## constructing autoregressive sampling config dictionary
    ## NOTE(review): num_ar_steps is fixed to 100 here instead of config["num_ar_steps"], confirm this is intended
    ar_sample_config = dict(
        num_time_steps = config["num_time_steps"],
        num_ar_steps = 100,
        initial_time_step = config["initial_time_step"],
        ar_sample_train = config["ar_sample_train"],
    )
    ## getting the sample batch at the initial time step and moving it to the experiment device
    batch = move_batch_to_device(
        experiment.get_batch(train = False, idx = config["initial_time_step"]),
        experiment.device,
    )
    ## sampling from model
    sample_dict = experiment.ar_sample(batch, config = ar_sample_config)
    ## parsing samples dict
    ar_samples = sample_dict["ar_samples"]
    ## displaying the shape of the results; the line is completed below only for the full output
    print(f"{ar_samples.shape=}", end = "")
    if config["full_out"]:
        trajectory, drift, diffusion = (sample_dict[name] for name in ("trajectory", "drift", "diffusion"))
        ## displaying the shape of the results
        print(f", {trajectory.shape=}, {drift.shape=}, {diffusion.shape=}")
    ## computing statistics jointly over dims 0 and 2 of the samples
    mean_ar_samples = ar_samples.mean(dim = (0, 2))
    std_ar_samples = ar_samples.std(dim = (0, 2))

In [None]:
## plotting example trajectories: autoregressive samples against the ground truth
## ground truth latent states restricted to the number of autoregressive steps sampled
gt_trajectory = experiment.ssm.test_sim["latent_states"][:ar_sample_config["num_ar_steps"],:,:]
## drawing ten random simulation ids; the number of simulations is stored in `config`
## (no `args` object is defined anywhere in this notebook)
## NOTE(review): assumes ar_samples and gt_trajectory both hold num_sims entries along dim 1, confirm
observation_ids = np.random.randint(config["num_sims"], size = (10))
fig, axes = plt.subplots(2, 5)
fig.suptitle(f"AR vs GT States Simulation id: {observation_ids}")
for idx in range(10):
    ## mapping the flat index onto the 2 x 5 grid of panels
    r_id, c_id = divmod(idx, 5)
    axes[r_id, c_id].plot(ar_samples[:, observation_ids[idx], ], label = "Ar samples", color = "purple")
    axes[r_id, c_id].plot(gt_trajectory[:,observation_ids[idx], ], label = "GT trajectory", color = "red")