In [None]:
# importing libraries
import os 
import gc
import sys
import time
import torch
import random
import numpy as np
import matplotlib.pyplot as plt

from typing import Optional
from tqdm.notebook import tqdm
from torch.distributions import MultivariateNormal

## exposing path 
sys.path.insert(0, "../")

from int_filt.src import (
    create_interpolant,
    create_models,
    IdentityPreproc,
    StandardizeSim,
    SimSSM,
    DriftObjective,
)

from int_filt.utils import (
    InputData, 
    OutputData,
    ConfigData,
    configuration,
    ensure_reproducibility, 
    move_batch_to_device,
    dump_config,
    construct_time_discretization,
)

from int_filt.experiments import Experiment

In [None]:
# setting reproducibility: a fixed seed when `reproducible`, otherwise a time-based one
reproducible = True
if reproducible:
    SEED = 1024
else:
    SEED = int(time.time())
ensure_reproducibility(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
data_folder = "../archive/data/multi_modal_jump_diffusion"
## globals: lookup tables that map configuration strings to library objects
PREPROCESSING = dict(none=IdentityPreproc, sim=StandardizeSim)

ACTIVATIONS = dict(relu=torch.nn.ReLU())

OPTIMIZERS = {
    "sgd": torch.optim.SGD,
    "adam": torch.optim.Adam,
    "adam-w": torch.optim.AdamW,
}

SCHEDULERS = {
    "none": None,
    "cosine-annealing": torch.optim.lr_scheduler.CosineAnnealingLR,
}

DEVICES = {name: torch.device(name) for name in ("cpu", "cuda")}

In [None]:
## objects and functions for handling experiment
## defining sampling function
def pair_lagged_observations(observation_store, lag):
    """
    Pairs every observation with the observation ``lag`` steps ahead of it.

    Args:
        observation_store: tensor whose first dimension indexes time.
        lag: non-negative integer offset between the paired observations.

    Returns:
        Tuple ``(current_observation, next_observation)`` where
        ``next_observation[i]`` equals ``observation_store[i + lag]``. Both hold
        ``max(observation_store.shape[0] - lag, 0)`` rows.

    Raises:
        ValueError: if ``lag`` is negative.
    """
    if lag < 0:
        raise ValueError(f"lag must be non-negative, got {lag}")
    num_observations = observation_store.shape[0]
    ## clamping at zero yields empty pairs instead of an arange error when lag >= num_observations
    indexes = torch.arange(max(num_observations - lag, 0))
    current_observation = observation_store[indexes]
    ## offsetting by `lag` (the offset was previously hard-coded to 1, ignoring the argument)
    next_observation = observation_store[indexes + lag]
    return current_observation, next_observation

## defining class for multimodal jump diffusion ssm
class LoadStoredSSM(SimSSM):
    """
    State space model backed by simulations that were computed and stored beforehand.

    ``config`` must provide the ``train_sim`` and ``test_sim`` entries, which are
    stored unchanged on the instance.
    NOTE(review): ``SimSSM.__init__`` is not called, so any state it would set up
    is absent — confirm the inherited methods used downstream do not rely on it.
    """

    def __init__(self, config: ConfigData):
        train_sim = config["train_sim"]
        test_sim = config["test_sim"]
        self.train_sim, self.test_sim = train_sim, test_sim

# class for handling the paired lagged datasets
class LaggedDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over pre-paired (current, next) tensors.

    Each item is moved to ``device`` when it is fetched. It is returned as a
    dictionary in which ``xc`` is the same tensor as ``x0`` and ``y`` is the
    same tensor as ``x1``.
    """

    def __init__(self, current_states, next_states, device):
        super().__init__()
        self.current_states, self.next_states = current_states, next_states
        self.device = device

    def __len__(self):
        # the number of pairs is read from the "next" tensor's leading dimension
        return self.next_states.size(0)

    def __getitem__(self, idx):
        current = self.current_states[idx].to(self.device)
        upcoming = self.next_states[idx].to(self.device)
        return dict(x0=current, xc=current, x1=upcoming, y=upcoming)

## defining class for handling experiments
class MultimodalJumpDiffusion(Experiment):
    """
    Experiment that trains the drift network ``b_net`` on lagged observation pairs.

    NOTE(review): the attributes read in ``train_drift`` (``b_net``,
    ``interpolant``, ``mc_config``, ``preprocessing``, ``device``,
    ``clear_memory``) are presumably set by ``Experiment.__init__`` from
    ``config`` — confirm against ``int_filt.experiments.Experiment``.
    """

    def __init__(self, config):
        super(MultimodalJumpDiffusion, self).__init__(config)
        # NOTE(review): hard-coded sample count exposed as ``self.N``; kept for
        # callers that read it
        self.N = 1000

    def train_drift(self, config: ConfigData) -> OutputData:
        """
        Trains the $b$ model.

        Args:
            config: dictionary with the keys
                - ``data_loader``: iterable of batch dictionaries, with ``len``,
                - ``b_net_optimizer``: optimizer over the drift network parameters,
                - ``b_net_scheduler``: learning-rate scheduler or ``None``
                  (stepped once per gradient step),
                - ``num_epochs``: number of passes over ``data_loader``,
                - ``num_grad_steps``: expected total number of gradient steps.

        Returns:
            Dictionary with ``loss_history`` and ``lr_history`` tensors holding one
            entry per gradient step. Their length is ``num_grad_steps``, or the
            number of steps actually taken if that is larger.
        """
        ## retrieving data loader, optimizer and scheduler
        data_loader = config["data_loader"]
        optimizer = config["b_net_optimizer"]
        scheduler = config["b_net_scheduler"]
        num_epochs = config["num_epochs"]
        steps_per_epoch = len(data_loader)
        ## the history buffers must fit every step actually taken; sizing them by
        ## num_grad_steps alone raised IndexError when that value was too small
        num_steps = max(config["num_grad_steps"], num_epochs * steps_per_epoch)
        ## initializing objective function
        objective_config = {
            "b_net": self.b_net,
            "interpolant": self.interpolant,
            "mc_config": self.mc_config,
            "preprocessing": self.preprocessing,
        }
        objective = DriftObjective(objective_config)
        ## allocating memory for storing loss and lr
        loss_history = torch.zeros(num_steps)
        lr_history = torch.zeros(num_steps)
        ## starting optimization
        for epoch in range(num_epochs):
            ## one progress bar per epoch; the context manager closes it when the epoch ends
            with tqdm(total=steps_per_epoch, desc=f"Epoch {epoch + 1}/{num_epochs}") as progress:
                for idx, batch in enumerate(data_loader):
                    ## preparing batch
                    batch = move_batch_to_device(batch, self.device)
                    ## estimating loss
                    loss_dict = objective.forward(batch)
                    loss = loss_dict["loss"]
                    loss_value = loss.item()
                    ## optimization step
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    ## scheduler step
                    if scheduler is not None:
                        scheduler.step()
                    current_lr = optimizer.param_groups[0]["lr"]
                    ## progress bar (postfix keeps the epoch label visible)
                    progress.set_postfix_str(f"MSELoss: {loss_value}, Learning Rate {current_lr}")
                    progress.update()
                    ## storing loss and lr
                    current_step = epoch * steps_per_epoch + idx
                    loss_history[current_step] = loss_value
                    lr_history[current_step] = current_lr
                    ## cleaning up memory
                    if self.clear_memory:
                        del batch, loss_dict, loss
                        gc.collect()
                        torch.cuda.empty_cache()
        ## constructing output dictionary
        train_dict = {"loss_history": loss_history, "lr_history": lr_history}
        return train_dict

In [None]:
# loading training states and observations
# NOTE(review): torch.load unpickles the file; only load data from a trusted source
train_state_store = torch.load(os.path.join(data_folder, "train_states.pt"))
train_observation_store = torch.load(os.path.join(data_folder, "train_observations.pt"))
# loading testing states and observations
test_state_store = torch.load(os.path.join(data_folder, "test_states.pt"))
test_observation_store = torch.load(os.path.join(data_folder, "test_observations.pt"))
# defining simulation dictionaries: the observations themselves act as latent states,
# so paired samples differ only by the observation lag (no added perturbation)
# NOTE(review): train_state_store / test_state_store are not used elsewhere in this notebook
train_sim = {"latent_states": train_observation_store, "observations": train_observation_store}
# the test split uses its own observations as latent states (previously the training ones)
test_sim = {"latent_states": test_observation_store, "observations": test_observation_store}

In [None]:
## defining experiment config: overrides applied on top of the package defaults
experiment_config = {
    ## interpolant config
    "epsilon": 1.0,
    "interpolant_method": "pffp_v0",
    ## model config
    "b_net_amortized": False,
    "b_net_activation": "relu",
    "b_net_hidden_dims": [500, 500, 500, 500, 500],
    "b_net_activate_final": False,
    ## training config
    "b_net_lr": 1e-3,
    "num_grad_steps": 1000,
    "b_net_scheduler": "cosine-annealing",
    "b_net_optimizer": "adam-w",
    ## mc estimation config
    "num_mc_samples": 300,
    ## preprocessing config
    "preprocessing": "none",
    "postprocessing": False,
    ## ssm config
    "num_dims": 2,
    ## memory management config
    "clear_memory": True,
}

## parsing default arguments and exposing them as a mutable dictionary
config = vars(configuration(args=[]))

## applying the experiment overrides (dict.update assigns the keys in order)
config.update(experiment_config)

## displaying current settings
print(config)

## resolving string identifiers into the objects they name
config["b_net_activation"] = ACTIVATIONS[config["b_net_activation"]]
config["device"] = DEVICES[config["device"]]

## adding mc configuration
config["mc_config"] = dict(num_mc_samples=config["num_mc_samples"])


In [None]:
## initializing interpolant from the configured method and epsilon
interpolant_config = dict(method=config["interpolant_method"], epsilon=config["epsilon"])
interpolant = create_interpolant(interpolant_config)
## initializing models (`num_dims` is forwarded as `spatial_dims`)
models_config = dict(
    backbone=config["backbone"],
    spatial_dims=config["num_dims"],
    b_net_hidden_dims=config["b_net_hidden_dims"],
    b_net_activation=config["b_net_activation"],
    b_net_activate_final=config["b_net_activate_final"],
    b_net_amortized=config["b_net_amortized"],
    device=config["device"],
)
models = create_models(models_config)
b_net = models["b_net"]
print(b_net)

In [None]:
## wrapping the stored train/test simulations in a state space model
ssm_config = dict(train_sim=train_sim, test_sim=test_sim)
ssm = LoadStoredSSM(ssm_config)

In [None]:
## initializing preprocessing: the config string selects the class, built from the ssm
preprocessing_config = dict(ssm=ssm)
preprocessing = PREPROCESSING[config["preprocessing"]](preprocessing_config)

## logging: no writer is attached in this notebook
writer = None

In [None]:
## initializing experiment from the components built above and the parsed config
experiment_config = dict(
    interpolant=interpolant,
    b_net=b_net,
    ssm=ssm,
    preprocessing=preprocessing,
    writer=writer,
    postprocessing=config["postprocessing"],
    log_results=config["log_results"],
    logging_step=config["logging_step"],
    mc_config=config["mc_config"],
    device=config["device"],
    full_out=config["full_out"],
    clear_memory=config["clear_memory"],
)
experiment = MultimodalJumpDiffusion(experiment_config)

In [None]:
## prepare for training
## pairing each training observation with the next one (lag of 1)
X_train, Y_train = pair_lagged_observations(train_observation_store, 1)
train_dataset = LaggedDataset(X_train, Y_train, device = device)
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size = 1000, shuffle = True)

## initializing optimizer and scheduler
b_net_optimizer = OPTIMIZERS[config["b_net_optimizer"]](experiment.b_net.backbone.parameters(), lr = config["b_net_lr"])
b_net_scheduler = SCHEDULERS[config["b_net_scheduler"]]
if b_net_scheduler is not None:
    ## the scheduler is stepped once per gradient step, so its horizon (T_max) must be
    ## the number of steps actually taken: 10 epochs x batches per epoch, as in the
    ## training cell. config["num_grad_steps"] does not depend on the dataset size, so
    ## using it made the cosine schedule not match the real training length.
    b_net_total_steps = 10 * len(data_loader)
    b_net_scheduler = b_net_scheduler(b_net_optimizer, b_net_total_steps)

In [None]:
## getting an example batch (unshuffled, so the same samples are drawn each run)
data_loader_ = torch.utils.data.DataLoader(train_dataset, batch_size = 1000, shuffle = False)
batch = next(iter(data_loader_))
## moving the batch to the model's device, as train_drift does before the forward pass
batch = move_batch_to_device(batch, config["device"])
## summarising shapes instead of dumping every tensor
for key, value in batch.items():
    print(key, tuple(value.shape))
batch["xt"] = batch["x0"]
## one time value per sample; sized from the batch because the loaded batch can hold
## fewer than experiment.N samples when the dataset is small
batch["t"] = torch.zeros(batch["x0"].shape[0], device = batch["x0"].device)
out = b_net(batch)
print(out[:5])
print(f"{out.shape}")

In [None]:
## standardization round-trip check on the example batch
def describe_batch(batch_dict):
    """Print mean, std and shape of every tensor in ``batch_dict``."""
    for k, v in batch_dict.items():
        print(k, "-> mean: ", v.mean(), ", std: ", v.std(), "shape: ", v.shape)

## an `is not None` check avoids truth-testing params, which raises if it is a
## multi-element tensor (NOTE(review): the type of `params` is not visible here)
standardization_params = getattr(experiment.preprocessing, "params", None)
if standardization_params is not None:
    print(f"STANDARDIZATION: {standardization_params}")
print("BEFORE STANDARDIZATION\n")
describe_batch(batch)
batch = experiment.preprocessing.standardize(batch)
print("\nAFTER PREPROCESSING\n")
describe_batch(batch)
batch = experiment.preprocessing.unstandardize(batch)
print("\nAFTER POSTPROCESSING\n")
describe_batch(batch)

In [None]:
## training
## number of passes over data_loader; the total step count is derived from it
## NOTE(review): the LR scheduler's horizon is set in the optimizer cell; keep it
## equal to num_epochs * len(data_loader)
num_epochs = 10
b_net_optim_config = {
    "data_loader": data_loader,
    "num_epochs": num_epochs,
    "b_net_optimizer": b_net_optimizer,
    "b_net_scheduler": b_net_scheduler,
    "num_grad_steps": num_epochs*len(data_loader),
}

## training b_net
train_dict = experiment.train_drift(b_net_optim_config)
loss_history = train_dict["loss_history"]
lr_history = train_dict["lr_history"]
## optional logging
if config["log_results"]:
    ## saving the model, creating the dump directory first if it does not exist
    os.makedirs(config["dump_dir"], exist_ok=True)
    torch.save(experiment.b_net.state_dict(), os.path.join(config["dump_dir"], "b_net.pt"))
## displaying the shape of the results
print(f"{loss_history.shape=}, {lr_history.shape=}")
## plotting loss and lr history side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].plot(loss_history)
axes[0].set(title="Loss History", xlabel="Gradient Step", ylabel="Loss Value")
axes[1].plot(lr_history)
axes[1].set(title="Learning Rate History", xlabel="Gradient Step", ylabel="Learning Rate")
fig.tight_layout()
plt.show()