## TimeMCL example on toy datasets

This notebook contains an example for training TimeMCL on synthetic datasets.

### Setup

First, create a conda environment with `conda create -n synth_env -y python=3.10.15`. Then activate it with `conda activate synth_env` (or `source activate synth_env` on older conda versions). Finally, install the required packages with `cd toy ; pip install -r requirements.txt` before running the next cells.

LaTeX can optionally be used for plot rendering. It can be installed with: `sudo apt-get install -y dvipng texlive-latex-extra texlive-fonts-recommended cm-super`.

In [None]:
import os
import sys
import numpy as np
import rootutils
import torch 
import matplotlib.pyplot as plt
rootutils.setup_root(
    search_from=".",
    indicator=".project-root",
    pythonpath=True,
)

# Make the parent of the project root and the project's `toy` directory importable,
# so that the `from toy import ...` line below resolves.
sys.path.extend(
    [
        os.path.dirname(os.environ["PROJECT_ROOT"]),
        os.path.join(os.environ["PROJECT_ROOT"], "toy"),
    ]
)
from toy import tMCL, train_tMCL, plot_brownian_bridge,plot_brownien,plot_ARp_quantization, is_usetex_available
import torch 


### Training

Training is performed by running the following cells.

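Three synthetic datasets are supported: Brownian motion (`"brownian_motion"`), AR(p) (`"ARp"`), and Brownian bridge (`"brownian_bridge"`). Select one by setting `dataset_name` in the next cell.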

In [None]:
dataset_name = "brownian_bridge"  # one of "brownian_bridge", "brownian_motion" or "ARp"

# Fail fast on typos: later cells fall back to AR(p)-style settings (e.g. cond_dim = 5)
# for any name that is not "brownian_bridge" or "brownian_motion".
if dataset_name not in ("brownian_bridge", "brownian_motion", "ARp"):
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        "expected 'brownian_bridge', 'brownian_motion' or 'ARp'."
    )

In [None]:
# Training and model parameters
batch_size = 4096  # Training batch size
device = "cuda" if torch.cuda.is_available() else "cpu"  # Device to use for training
wta_mode = "relaxed_wta"  # WTA mode to use for training
n_hypotheses = 10  # Number of hypotheses to use for training
num_steps = 500  # Number of training steps
learning_rate = 0.001  # Learning rate for training
# Input dimension of the model: 2 for the Brownian bridge, 1 for Brownian motion,
# 5 otherwise (AR(p), whose order p is 5 below).
cond_dim = 2 if dataset_name == "brownian_bridge" else 1 if dataset_name == "brownian_motion" else 5

# Seed the RNGs so that Restart & Run All reproduces the same model and plots.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Dataset parameters
nb_discretization_points = 500  # Total number of discretization points
p = 5 if dataset_name == "ARp" else None  # AR(p) model order
# Number of simulated steps: 250 - p for AR(p), 250 otherwise.
nb_step_simulation = (250 - p) if dataset_name == "ARp" else 250
# Length of the interval to simulate. It is defined here because the training cell
# passes it to train_tMCL; previously it only existed after the plotting parameters ran.
interval_length = nb_step_simulation
sigma = 0.06 if dataset_name == "ARp" else None  # Noise level for AR(p)
coefficients = [0.4, 0.2, 0.2, 0.1, 0.1] if dataset_name == "ARp" else None  # Coefficients for AR(p)
init_values = None  # Initial values for AR(p)

In [None]:
# Build the TimeMCL model from the hyperparameters defined in the parameters cell.
model = tMCL(
    n_hypotheses=n_hypotheses,
    cond_dim=cond_dim,
    nb_step_simulation=nb_step_simulation,
    loss_type=wta_mode,  # WTA mode selected for training
    device=device,
)

In [None]:
# p, coefficients and sigma are None for the Brownian datasets, so the AR(p)-specific
# simulation parameters are only passed when training on AR(p).
# (The previous condition `if "ARp"` was always true, since a non-empty string is truthy.)
additional_params = (
    {"p": p, "coefficients": coefficients, "sigma": sigma, "init_values": init_values}
    if dataset_name == "ARp"
    else {}
)

trained_model = train_tMCL(
    model=model,
    process_type=dataset_name,
    num_steps=num_steps,
    batch_size=batch_size,
    nb_discretization_points=nb_discretization_points,
    interval_length=interval_length,
    device=device,
    learning_rate=learning_rate,
    additional_params=additional_params,
)

### Plotting

In [None]:
# Plotting parameters
interval_length = nb_step_simulation  # Length of the interval to simulate
# m and N_levels are only used by the Brownian bridge / Brownian motion plots.
m = 2  # parameter m of the K-L (Karhunen-Loeve) decomposition used for quantization
# Quantization levels; presumably one entry per eigenfunction (len(N_levels) == m) -- TODO confirm
N_levels = [5, 2]
a = 0  # Starting point of the brownian bridge
b = 1  # Ending point of the brownian bridge
pred_length = nb_step_simulation  # Length of the prediction
# Conditioning time passed to every plotting helper: 100 for AR(p), 0.5 for the Brownian datasets.
t_condition = 100 if dataset_name == "ARp" else 0.5

In [None]:
from matplotlib import rc

# Use LaTeX text rendering only when a LaTeX installation is detected.
rc("text", usetex=bool(is_usetex_available()))
rc("font", family="serif")

fig, ax = plt.subplots(figsize=(10, 6))

if dataset_name == "brownian_bridge":
    result = plot_brownian_bridge(
        interval_length=interval_length,
        nb_discretization_points=nb_discretization_points,
        m=m,
        N_levels=N_levels,
        a=a,
        b=b,
        t_condition=t_condition,
        trained_model=trained_model,
        ax=ax,
    )
    ax.set_title("Brownian Bridge", fontsize=28)

elif dataset_name == "brownian_motion":
    # NOTE(review): `num_steps` is the number of *training* steps from the parameters
    # cell; confirm plot_brownien expects that value (and not a discretization count).
    result = plot_brownien(
        T=1,
        t_condition=t_condition,
        pred_length=pred_length,
        num_steps=num_steps,
        m=m,
        N_levels=N_levels,
        trained_model=trained_model,
        ax=ax,
    )
    ax.set_title("Brownian Motion", fontsize=28)

elif dataset_name == "ARp":
    result = plot_ARp_quantization(
        batch_size=batch_size,
        nb_discretization_points=nb_discretization_points,
        interval_length=interval_length,
        coefficients=coefficients,
        sigma=sigma,
        t_condition=t_condition,
        trained_model=trained_model,
        ax=ax,
    )
    # The title previously contained the mis-encoded sequence "â€“" instead of an en dash.
    ax.set_title(f"AR(p) – p = {len(coefficients)}", fontsize=28)

else:
    # Without this branch an unknown name silently produced an empty plot.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        "expected 'brownian_bridge', 'brownian_motion' or 'ARp'."
    )

ax.tick_params(axis="both", labelsize=20)
ax.set_xlabel("Time", fontsize=25)
ax.grid()

handles, labels = ax.get_legend_handles_labels()
plt.tight_layout()

# Deduplicate legend entries by label: the dict keeps first-seen label order,
# and a repeated label keeps its last handle.
unique_handles_labels = dict(zip(labels, handles))
unique_labels = list(unique_handles_labels.keys())
unique_handles = list(unique_handles_labels.values())

# Common legend above the figure. bbox_to_anchor y = 1.1 lies above the figure's top
# edge, so it is only fully visible when the figure is rendered with a tight bounding box.
fig.legend(
    unique_handles,
    unique_labels,
    loc="upper center",
    bbox_to_anchor=(0.5, 1.1),
    ncol=3,
    fontsize=20,
)
plt.show()