# Conditional Score-based Diffusion Models for Time Series Imputation

This notebook investigates the application of Conditional Score-based Diffusion Models (CSDI) to time series imputation. Developed by researchers at Stanford University, CSDI adapts diffusion model principles, primarily used in image and audio synthesis, to effectively handle missing data in time series.

## Background
Time series data, especially in healthcare and finance, often contain gaps resulting from sensor failures or incomplete data capture. Traditional imputation methods typically overlook the complex temporal dependencies inherent in time series data. CSDI addresses these challenges by using a conditional diffusion process to impute missing values, leveraging observed data to guide the imputation process.

## Objectives
- Introduce the theoretical framework and operation of CSDI.
- Implement the CSDI model using PyTorch to showcase its application on real-world datasets.
- Assess CSDI's performance in imputing missing values compared to conventional imputation techniques.

Through the exploration of CSDI's implementation and its imputation efficacy, this notebook aims to highlight its potential in improving the accuracy and reliability of time series analysis.

## Library Imports and Model Setup

In this section, we import necessary libraries and modules required for the implementation of the CSDI model. This includes standard data handling libraries like `numpy` and `pandas`, deep learning libraries from `torch`, and specific components for building and training the CSDI model.

In [1]:
import pickle  # To load or save Python objects
import os  # To interact with the operating system
import re  # To use regular expressions
import numpy as np  # For numerical operations
import pandas as pd  # For data manipulation
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset  # To handle datasets in PyTorch
from torch.optim import Adam

import torch  # Main PyTorch library
import torch.nn as nn  # Neural network module
import torch.nn.functional as F  # Functional interface
import math  # Provides access to mathematical functions
from linear_attention_transformer import (
    LinearAttentionTransformer,
)  # For attention mechanisms

import argparse  # For command-line option and argument parsing
import datetime  # For handling date and time
import json  # To work with JSON data
import yaml  # To handle YAML files

from main_model import CSDI_Physio  # The main model for CSDI on PhysioNet data
from dataset_physio import get_dataloader  # Helper to get data loader for the dataset
from utils import train, evaluate  # Utility functions for training and evaluation

## Experiment Setup and Execution for CSDI Model 

This section covers setting up and running CSDI to impute physiological time series. Its components are summarized below:

### **Imports and Module Loading**
- **Libraries and Modules**: The script begins by importing necessary Python libraries for handling configurations, data manipulations, neural network operations, and file system interactions. Custom modules specific to the CSDI model such as data loaders, training, and evaluation functions are also loaded.

### **Configuration and Argument Parsing**
- **Command Line Arguments**: An `argparse.ArgumentParser` is set up to facilitate input of various experiment parameters through the command line, enhancing the script's flexibility and usability across different experimental conditions without altering the codebase.
- **Configuration File**: The experiment's settings are loaded from a YAML configuration file, allowing easy adjustments to the model and training parameters. Modifications to these settings via command line arguments are directly reflected in the configuration, ensuring that each experiment can be finely tuned.
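
The override step described above can be sketched with the standard library alone. This is a minimal illustration, not the notebook's full parser: the `config` dictionary is a hypothetical stand-in for what `yaml.safe_load` would return, and its values are assumptions chosen for the example.

```python
import argparse
import json

# Hypothetical stand-in for the dictionary that yaml.safe_load would return.
config = {
    "train": {"epochs": 200, "batch_size": 16, "lr": 1.0e-3},
    "model": {"is_unconditional": False, "test_missing_ratio": 0.1},
}

parser = argparse.ArgumentParser(description="CSDI (sketch)")
parser.add_argument("--unconditional", action="store_true")
parser.add_argument("--testmissingratio", type=float, default=0.1)

# An explicit argument list keeps the parser away from sys.argv, which in a
# notebook holds the kernel's own arguments.
args = parser.parse_args(["--testmissingratio", "0.5"])

# Command-line values overwrite the corresponding configuration entries.
config["model"]["is_unconditional"] = args.unconditional
config["model"]["test_missing_ratio"] = args.testmissingratio

print(json.dumps(config["model"], indent=4))
```

Keeping the defaults in the file and the per-run changes on the command line means the saved configuration always reflects the values actually used.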

### **Output Directory Setup**
- **Directory Creation**: The script automatically generates a directory based on the current timestamp and fold number for cross-validation. This structured approach to output management keeps results organized and separated based on experimental conditions, crucial for analysis and comparison of results.
- **Configuration Saving**: The final configuration used for the experiment is saved in the created directory, promoting reproducibility and detailed documentation of experimental conditions.
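
As a rough sketch of this output layout, the snippet below builds a timestamped folder name and stores a configuration next to it. A temporary directory stands in for `./save`, and the one-entry `config` is a hypothetical placeholder.

```python
import datetime
import json
import os
import tempfile

config = {"model": {"test_missing_ratio": 0.1}}  # hypothetical minimal config
nfold = 0

# A temporary directory stands in for "./save" so the sketch leaves nothing behind.
with tempfile.TemporaryDirectory() as root:
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    folder = os.path.join(root, f"physio_fold{nfold}_{stamp}")
    os.makedirs(folder, exist_ok=True)

    config_path = os.path.join(folder, "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4)

    # Reading the file back recovers the exact settings of the run.
    with open(config_path) as f:
        restored = json.load(f)

print(restored == config)
```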

### **Data Preparation**
- **Data Loaders**: Structured data loaders are prepared for the training, validation, and testing phases, considering factors like batch size and the ratio of missing data, ensuring consistent and correct data handling throughout the experiment.

### **Model Handling**
- **Model Initialization**: The CSDI model is initialized and configured based on the loaded settings and is prepared for deployment on the specified compute device.
- **Conditional Training**: If `--modelfolder` is empty, the model is trained from scratch. Otherwise, the weights saved in that folder are loaded and training is skipped, so a previously trained model can be evaluated directly.

### **Model Training and Evaluation**
- **Training**: The model undergoes training using the specified settings, with progress and outputs managed through the structured data loaders.
- **Evaluation**: Independently of training, the model is evaluated to assess its performance, particularly focusing on its ability to impute missing values in physiological time series data, with results documented and stored in the designated output directory.



In [None]:
parser = argparse.ArgumentParser(description="CSDI")
parser.add_argument("--config", type=str, default="base.yaml")
parser.add_argument("--device", default="cuda:0", help="Device to run the model on")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--testmissingratio", type=float, default=0.1)
parser.add_argument(
    "--nfold", type=int, default=0, help="for 5fold test (valid value:[0-4])"
)
parser.add_argument("--unconditional", action="store_true")
parser.add_argument("--modelfolder", type=str, default="")
parser.add_argument("--nsample", type=int, default=100)

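# Pass an explicit (empty) list: inside a notebook, sys.argv holds the kernel's
# own arguments, which parse_args() would otherwise reject.
args = parser.parse_args(args=[])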
print(args)

path = "config/" + args.config
with open(path, "r") as f:
    config = yaml.safe_load(f)

config["model"]["is_unconditional"] = args.unconditional
config["model"]["test_missing_ratio"] = args.testmissingratio

print(json.dumps(config, indent=4))

current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = "./save/physio_fold" + str(args.nfold) + "_" + current_time + "/"
print("model folder:", foldername)
os.makedirs(foldername, exist_ok=True)
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)

train_loader, valid_loader, test_loader = get_dataloader(
    seed=args.seed,
    nfold=args.nfold,
    batch_size=config["train"]["batch_size"],
    missing_ratio=config["model"]["test_missing_ratio"],
)

model = CSDI_Physio(config, args.device).to(args.device)

if args.modelfolder == "":
    train(
        model,
        config["train"],
        train_loader,
        valid_loader=valid_loader,
        foldername=foldername,
    )
else:
    model.load_state_dict(torch.load("./save/" + args.modelfolder + "/model.pth"))

evaluate(model, test_loader, nsample=args.nsample, scaler=1, foldername=foldername)

## Model and Data Setup

Next, we set up the data loaders and initialize the model. In this notebook, the CSDI model is customized for physiological data from the PhysioNet dataset and is prepared along with the configuration needed for training and evaluation.

The dataset uses the following 35 attributes, each of which has enough recorded (non-missing) values:

In [None]:
# The 35 attributes with enough recorded (non-missing) values to be used for imputation
attributes = [
    "DiasABP",
    "HR",
    "Na",
    "Lactate",
    "NIDiasABP",
    "PaO2",
    "WBC",
    "pH",
    "Albumin",
    "ALT",
    "Glucose",
    "SaO2",
    "Temp",
    "AST",
    "Bilirubin",
    "HCO3",
    "BUN",
    "RespRate",
    "Mg",
    "HCT",
    "SysABP",
    "FiO2",
    "K",
    "GCS",
    "Cholesterol",
    "NISysABP",
    "TroponinT",
    "MAP",
    "TroponinI",
    "PaCO2",
    "Platelets",
    "Urine",
    "NIMAP",
    "Creatinine",
    "ALP",
]

## Data Handling and Preprocessing 

This section outlines the procedures and components involved in processing physiological time series data for the purpose of time series imputation:

- **extract_hour**: Converts a timestamp string of the form `HH:MM` into its integer hour. Each record spans 48 hours, so this value is an hour offset within the record (0 to 47) rather than an hour of the day, and it is used to assign measurements to hourly bins.

- **parse_data**: Processes individual data frames by extracting the last recorded values for a predefined list of attributes. This function fills missing entries with `NaN`, facilitating the creation of a comprehensive feature matrix for each patient over the designated timeframe.

- **parse_id**: Builds the arrays for a single patient: a 48 x 35 matrix of observed values, an `observed_masks` array marking which entries were actually recorded, and a `gt_masks` array in which a random fraction (`missing_ratio`) of the recorded entries is additionally hidden. The hidden entries have known true values, so they can serve as ground truth for scoring imputations.

- **get_idlist**: Retrieves a list of patient IDs by scanning a specified directory for data files, ensuring all available data is accounted for and prepared for further processing.

- **Physio_Dataset**: Implements a custom dataset class for physiological data, which handles tasks such as loading, normalizing, and batching the data efficiently. It allows for the dynamic creation of datasets from raw data or loading from preprocessed files, supporting robust data handling within the PyTorch framework.

- **get_dataloader**: Sets up DataLoaders for training, validation, and testing. For 5-fold cross-validation, fold `nfold` holds out 20% of the shuffled records as the test set. Of the remainder, a number of records equal to 70% of the full dataset is used for training and the rest (about 10%) for validation. Only the training loader shuffles its batches.

These components work together to ensure that the data is accurately prepared and readily available for implementing and training the imputation model, optimizing the workflow from raw data handling to model application.
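
The masking step performed by `parse_id` can be illustrated on a toy record. This is a minimal sketch under stated assumptions: the 4 x 3 array and the 0.5 ratio are made up for the example, and `np.random.default_rng` is used here in place of the notebook's `np.random.seed` and `np.random.choice`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy record: 4 time steps x 3 attributes, NaN marks values never measured.
values = np.array(
    [
        [1.0, np.nan, 3.0],
        [np.nan, 5.0, 6.0],
        [7.0, 8.0, np.nan],
        [10.0, 11.0, 12.0],
    ]
)
missing_ratio = 0.5

observed_mask = ~np.isnan(values)            # True where a value was measured
obs_indices = np.flatnonzero(observed_mask)  # flat indices of measured entries

# Hide a fraction of the measured entries; their true values stay known.
n_hidden = int(len(obs_indices) * missing_ratio)
hidden = rng.choice(obs_indices, n_hidden, replace=False)

gt_mask = observed_mask.reshape(-1).copy()
gt_mask[hidden] = False
gt_mask = gt_mask.reshape(observed_mask.shape)

target_mask = observed_mask & ~gt_mask       # entries hidden on purpose
print(observed_mask.sum(), gt_mask.sum(), target_mask.sum())
```

Of the 9 measured entries, 4 are hidden and 5 remain visible. Entries that were never measured are excluded from both masks, so they are never used as targets.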


In [None]:
def extract_hour(x):
    # Extracts the hour from a timestamp
    h, _ = map(int, x.split(":"))
    return h


def parse_data(x):
    # Extract the last recorded value for each attribute
    x = x.set_index("Parameter").to_dict()["Value"]
    values = []
    for attr in attributes:
        # dict.get returns NaN for attributes that were not recorded in this hour
        values.append(x.get(attr, np.nan))
    return values


def parse_id(id_, missing_ratio=0.1):
    # Parses the data for a single patient ID, handling missing data by the specified ratio
    data = pd.read_csv("./data/physio/set-a/{}.txt".format(id_))
    # set hour
    data["Time"] = data["Time"].apply(extract_hour)

    # Create a matrix for 48 hours x 35 attributes
    observed_values = []
    for h in range(48):
        observed_values.append(parse_data(data[data["Time"] == h]))
    observed_values = np.array(observed_values)
    observed_masks = ~np.isnan(observed_values)

    # Randomly set some entries as missing based on the missing_ratio
    masks = observed_masks.reshape(-1).copy()
    obs_indices = np.where(masks)[0].tolist()
    miss_indices = np.random.choice(
        obs_indices, int(len(obs_indices) * missing_ratio), replace=False
    )
    masks[miss_indices] = False
    gt_masks = masks.reshape(observed_masks.shape)
    observed_values = np.nan_to_num(observed_values)
    observed_masks = observed_masks.astype("float32")
    gt_masks = gt_masks.astype("float32")
    return observed_values, observed_masks, gt_masks


def get_idlist():
    # Retrieves a list of patient IDs from filenames in a directory
    patient_id = []
    for filename in os.listdir("./data/physio/set-a"):
        match = re.search(r"\d{6}", filename)  # raw string avoids an invalid escape
        if match:
            patient_id.append(match.group())
    patient_id = np.sort(patient_id)
    return patient_id


class Physio_Dataset(Dataset):
    # A custom PyTorch Dataset for handling physiological data
    def __init__(self, eval_length=48, use_index_list=None, missing_ratio=0.0, seed=0):
        self.eval_length = eval_length
        np.random.seed(seed)  # Set seed for reproducibility in missing data simulation
        self.observed_values = []
        self.observed_masks = []
        self.gt_masks = []
        # Attempt to load preprocessed data from a pickle file or generate if not available
        path = (
            "./data/physio_missing" + str(missing_ratio) + "_seed" + str(seed) + ".pk"
        )
        if not os.path.isfile(path):  # no cached dataset file yet, so build it
            idlist = get_idlist()
            for id_ in idlist:
                try:
                    observed_values, observed_masks, gt_masks = parse_id(
                        id_, missing_ratio
                    )
                    self.observed_values.append(observed_values)
                    self.observed_masks.append(observed_masks)
                    self.gt_masks.append(gt_masks)
                except Exception as e:
                    print(id_, e)
                    continue
            self.observed_values = np.array(self.observed_values)
            self.observed_masks = np.array(self.observed_masks)
            self.gt_masks = np.array(self.gt_masks)

            # calc mean and std and normalize values
            # (it is the same normalization as Cao et al. (2018) (https://github.com/caow13/BRITS))
            tmp_values = self.observed_values.reshape(-1, 35)
            tmp_masks = self.observed_masks.reshape(-1, 35)
            mean = np.zeros(35)
            std = np.zeros(35)
            for k in range(35):
                c_data = tmp_values[:, k][tmp_masks[:, k] == 1]
                mean[k] = c_data.mean()
                std[k] = c_data.std()
            self.observed_values = (
                (self.observed_values - mean) / std * self.observed_masks
            )
            # Save processed data
            with open(path, "wb") as f:
                pickle.dump(
                    [self.observed_values, self.observed_masks, self.gt_masks], f
                )
        else:  # load datasetfile
            with open(path, "rb") as f:
                self.observed_values, self.observed_masks, self.gt_masks = pickle.load(
                    f
                )
        if use_index_list is None:
            self.use_index_list = np.arange(len(self.observed_values))
        else:
            self.use_index_list = use_index_list

    def __getitem__(self, org_index):
        index = self.use_index_list[org_index]
        s = {
            "observed_data": self.observed_values[index],
            "observed_mask": self.observed_masks[index],
            "gt_mask": self.gt_masks[index],
            "timepoints": np.arange(self.eval_length),
        }
        return s

    def __len__(self):
        return len(self.use_index_list)


def get_dataloader(seed=1, nfold=None, batch_size=16, missing_ratio=0.1):
    # only to obtain total length of dataset
    dataset = Physio_Dataset(missing_ratio=missing_ratio, seed=seed)
    indlist = np.arange(len(dataset))
    np.random.seed(seed)
    np.random.shuffle(indlist)

    # 5-fold test: fold `nfold` takes one contiguous 20% slice of the shuffled indices
    start = int(nfold * 0.2 * len(dataset))
    end = int((nfold + 1) * 0.2 * len(dataset))
    test_index = indlist[start:end]
    remain_index = np.delete(indlist, np.arange(start, end))
    np.random.seed(seed)
    np.random.shuffle(remain_index)
    num_train = int(len(dataset) * 0.7)  # 70% of all records for training
    train_index = remain_index[:num_train]
    valid_index = remain_index[num_train:]  # the remaining records for validation
    dataset = Physio_Dataset(
        use_index_list=train_index, missing_ratio=missing_ratio, seed=seed
    )
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    valid_dataset = Physio_Dataset(
        use_index_list=valid_index, missing_ratio=missing_ratio, seed=seed
    )
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_dataset = Physio_Dataset(
        use_index_list=test_index, missing_ratio=missing_ratio, seed=seed
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader, test_loader
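
The index arithmetic in `get_dataloader` can be checked in isolation. The helper `split_indices` below is hypothetical and only mirrors the split logic. It uses `np.random.default_rng` instead of the global seed, so the resulting permutation differs from the notebook's, but the subset sizes follow the same rule.

```python
import numpy as np


def split_indices(n, nfold, seed=1):
    """Return (train, valid, test) index arrays for one of five folds."""
    rng = np.random.default_rng(seed)
    indlist = rng.permutation(n)

    # Fold `nfold` takes a contiguous 20% slice of the shuffled indices.
    start = int(nfold * 0.2 * n)
    end = int((nfold + 1) * 0.2 * n)
    test_index = indlist[start:end]

    # The remainder is reshuffled, then cut into training and validation parts.
    remain_index = np.delete(indlist, np.arange(start, end))
    remain_index = rng.permutation(remain_index)
    num_train = int(n * 0.7)
    return remain_index[:num_train], remain_index[num_train:], test_index


train_index, valid_index, test_index = split_indices(n=100, nfold=0)
print(len(train_index), len(valid_index), len(test_index))
```

With 100 records this gives 70 training, 10 validation, and 20 test indices, and the three subsets do not overlap.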

## Building the Diffusion Model

This section describes the neural network components that form the backbone of CSDI for time series imputation, that is, the network that parameterizes the diffusion model.

### Transformer Layers for Long-range Dependencies
The `get_torch_trans` and `get_linear_trans` functions are designed to construct transformer layers, which are important for CSDI's ability to capture long-range dependencies within data sequences. `get_torch_trans` creates a standard transformer encoder that utilizes multiple attention heads to process various segments of the input data concurrently. Conversely, `get_linear_trans` sets up a transformer using linear attention mechanisms to manage longer sequences efficiently by reducing computational overhead, which is beneficial for scalability and performance in large datasets.

### Customized Convolutional Layer Initialization
`Conv1d_with_init` creates a 1D convolutional layer and initializes its weights with Kaiming (He) normal initialization, which keeps the variance of activations roughly constant across layers and promotes stable training. Every call in this model uses a kernel size of 1, so these layers act as per-position channel projections that complement the transformer layers rather than extracting features from a local window.

### Embedding Temporal Dynamics with Diffusion Embedding
The `DiffusionEmbedding` class provides a method to embed the diffusion steps used in CSDI into a high-dimensional space using sinusoidal functions. This embedding is integral to the model, as it enables the precise modeling of how data evolves through the diffusion process, which is central to the score-based generative approach of CSDI.
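
The sinusoidal table behind this embedding can be reproduced outside PyTorch. The sketch below is a NumPy re-expression of `_build_embedding` for illustration only, and the choice of 50 steps is an assumption for the example.

```python
import numpy as np


def build_embedding(num_steps, dim=64):
    """Sinusoidal table of shape (num_steps, 2 * dim)."""
    steps = np.arange(num_steps)[:, None]                      # (T, 1)
    # Frequencies grow geometrically from 10^0 to 10^4.
    frequencies = 10.0 ** (np.arange(dim) / (dim - 1) * 4.0)   # (dim,)
    table = steps * frequencies[None, :]                       # (T, dim)
    return np.concatenate([np.sin(table), np.cos(table)], axis=1)


table = build_embedding(num_steps=50, dim=64)
print(table.shape)
```

Each diffusion step indexes one row of this table. Step 0 maps to zeros in the sine half and ones in the cosine half, and later steps produce distinct patterns across the frequency range, which the two linear projections then transform.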

### Model Architecture
The `diff_CSDI` class encapsulates the complete model structure, integrating diffusion embeddings with transformer and convolutional layers. This class demonstrates the application of advanced neural network techniques to build a robust architecture for time series imputation. It is specifically tailored to exploit both the temporal and feature-wise dependencies within data, leveraging the unique properties of diffusion models for effective imputation.

### Residual Blocks for Enhanced Learning
Within `diff_CSDI`, each `ResidualBlock` adds the projected diffusion embedding to its input, applies a transformer along the time dimension and then one along the feature dimension, and combines the result with the conditioning information through a gated activation (`sigmoid(gate) * tanh(filter)`). The block returns a residual output, which feeds the next block, and a skip output, which is summed across blocks. Attending over time and features in sequence lets the model capture both temporal and cross-feature dependencies in multivariate time series.
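
The reshaping that lets one transformer attend over time and another over features can be traced with plain arrays. This is a sketch in NumPy, where `transpose` plays the role of `permute`. The tensor sizes are arbitrary and the transformer calls are left out, so it only shows that each path regroups the data and restores the original layout afterwards.

```python
import numpy as np

B, C, K, L = 2, 4, 3, 5  # batch, channels, features, time steps
y = np.arange(B * C * K * L, dtype=np.float32).reshape(B, C, K * L)

# Time path: one sequence of length L per (batch, feature) pair.
y_time = y.reshape(B, C, K, L).transpose(0, 2, 1, 3).reshape(B * K, C, L)
# ... a temporal transformer would act on y_time here ...
back_t = y_time.reshape(B, K, C, L).transpose(0, 2, 1, 3).reshape(B, C, K * L)

# Feature path: one sequence of length K per (batch, time step) pair.
y_feat = y.reshape(B, C, K, L).transpose(0, 3, 1, 2).reshape(B * L, C, K)
# ... a feature transformer would act on y_feat here ...
back_f = y_feat.reshape(B, L, C, K).transpose(0, 2, 3, 1).reshape(B, C, K * L)

print(y_time.shape, y_feat.shape)
print(np.array_equal(back_t, y), np.array_equal(back_f, y))
```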

These components collectively form a sophisticated framework designed for the CSDI model, showcasing how theoretical advancements in machine learning can be practically applied to solve challenges in time series imputation with high efficiency and adaptability.


In [None]:
def get_torch_trans(heads=8, layers=1, channels=64):
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=channels, nhead=heads, dim_feedforward=64, activation="gelu"
    )
    return nn.TransformerEncoder(encoder_layer, num_layers=layers)


def get_linear_trans(heads=8, layers=1, channels=64, localheads=0, localwindow=0):
    return LinearAttentionTransformer(
        dim=channels,
        depth=layers,
        heads=heads,
        max_seq_len=256,
        n_local_attn_heads=localheads,  # heads that use local attention (0 disables it)
        local_attn_window_size=localwindow,  # window size for those local heads
    )


def Conv1d_with_init(in_channels, out_channels, kernel_size):
    layer = nn.Conv1d(in_channels, out_channels, kernel_size)
    nn.init.kaiming_normal_(layer.weight)
    return layer


class DiffusionEmbedding(nn.Module):
    def __init__(self, num_steps, embedding_dim=128, projection_dim=None):
        super().__init__()
        if projection_dim is None:
            projection_dim = embedding_dim
        self.register_buffer(
            "embedding",
            self._build_embedding(num_steps, embedding_dim / 2),
            persistent=False,
        )
        self.projection1 = nn.Linear(embedding_dim, projection_dim)
        self.projection2 = nn.Linear(projection_dim, projection_dim)

    def forward(self, diffusion_step):
        x = self.embedding[diffusion_step]
        x = self.projection1(x)
        x = F.silu(x)
        x = self.projection2(x)
        x = F.silu(x)
        return x

    def _build_embedding(self, num_steps, dim=64):
        steps = torch.arange(num_steps).unsqueeze(1)  # (T,1)
        frequencies = 10.0 ** (torch.arange(dim) / (dim - 1) * 4.0).unsqueeze(
            0
        )  # (1,dim)
        table = steps * frequencies  # (T,dim)
        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)  # (T,dim*2)
        return table


class diff_CSDI(nn.Module):
    # Denoising network: input projection, stacked residual blocks, output projection
    def __init__(self, config, inputdim=2):
        super().__init__()
        self.channels = config["channels"]

        self.diffusion_embedding = DiffusionEmbedding(
            num_steps=config["num_steps"],
            embedding_dim=config["diffusion_embedding_dim"],
        )

        self.input_projection = Conv1d_with_init(inputdim, self.channels, 1)
        self.output_projection1 = Conv1d_with_init(self.channels, self.channels, 1)
        self.output_projection2 = Conv1d_with_init(self.channels, 1, 1)
        nn.init.zeros_(self.output_projection2.weight)

        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    side_dim=config["side_dim"],
                    channels=self.channels,
                    diffusion_embedding_dim=config["diffusion_embedding_dim"],
                    nheads=config["nheads"],
                    is_linear=config["is_linear"],
                )
                for _ in range(config["layers"])
            ]
        )

    def forward(self, x, cond_info, diffusion_step):
        B, inputdim, K, L = x.shape

        x = x.reshape(B, inputdim, K * L)
        x = self.input_projection(x)
        x = F.relu(x)
        x = x.reshape(B, self.channels, K, L)

        diffusion_emb = self.diffusion_embedding(diffusion_step)

        # Collect the skip output of every residual block
        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, cond_info, diffusion_emb)
            skip.append(skip_connection)

        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = x.reshape(B, self.channels, K * L)
        x = self.output_projection1(x)  # (B,channel,K*L)
        x = F.relu(x)
        x = self.output_projection2(x)  # (B,1,K*L)
        x = x.reshape(B, K, L)
        return x


class ResidualBlock(nn.Module):
    def __init__(
        self, side_dim, channels, diffusion_embedding_dim, nheads, is_linear=False
    ):
        super().__init__()
        self.diffusion_projection = nn.Linear(diffusion_embedding_dim, channels)
        self.cond_projection = Conv1d_with_init(side_dim, 2 * channels, 1)
        self.mid_projection = Conv1d_with_init(channels, 2 * channels, 1)
        self.output_projection = Conv1d_with_init(channels, 2 * channels, 1)

        self.is_linear = is_linear
        if is_linear:
            self.time_layer = get_linear_trans(
                heads=nheads, layers=1, channels=channels
            )
            self.feature_layer = get_linear_trans(
                heads=nheads, layers=1, channels=channels
            )
        else:
            self.time_layer = get_torch_trans(heads=nheads, layers=1, channels=channels)
            self.feature_layer = get_torch_trans(
                heads=nheads, layers=1, channels=channels
            )

    def forward_time(self, y, base_shape):
        B, channel, K, L = base_shape
        if L == 1:
            return y
        y = y.reshape(B, channel, K, L).permute(0, 2, 1, 3).reshape(B * K, channel, L)

        if self.is_linear:
            y = self.time_layer(y.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            y = self.time_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
        y = y.reshape(B, K, channel, L).permute(0, 2, 1, 3).reshape(B, channel, K * L)
        return y

    def forward_feature(self, y, base_shape):
        B, channel, K, L = base_shape
        if K == 1:
            return y
        y = y.reshape(B, channel, K, L).permute(0, 3, 1, 2).reshape(B * L, channel, K)
        if self.is_linear:
            y = self.feature_layer(y.permute(0, 2, 1)).permute(0, 2, 1)
        else:
            y = self.feature_layer(y.permute(2, 0, 1)).permute(1, 2, 0)
        y = y.reshape(B, L, channel, K).permute(0, 2, 3, 1).reshape(B, channel, K * L)
        return y

    def forward(self, x, cond_info, diffusion_emb):
        B, channel, K, L = x.shape
        base_shape = x.shape
        x = x.reshape(B, channel, K * L)

        diffusion_emb = self.diffusion_projection(diffusion_emb).unsqueeze(
            -1
        )  # (B,channel,1)
        y = x + diffusion_emb

        y = self.forward_time(y, base_shape)
        y = self.forward_feature(y, base_shape)  # (B,channel,K*L)
        y = self.mid_projection(y)  # (B,2*channel,K*L)

        _, cond_dim, _, _ = cond_info.shape
        cond_info = cond_info.reshape(B, cond_dim, K * L)
        cond_info = self.cond_projection(cond_info)  # (B,2*channel,K*L)
        y = y + cond_info

        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)  # (B,channel,K*L)
        y = self.output_projection(y)

        residual, skip = torch.chunk(y, 2, dim=1)
        x = x.reshape(base_shape)
        residual = residual.reshape(base_shape)
        skip = skip.reshape(base_shape)
        return (x + residual) / math.sqrt(2.0), skip

## Training of the CSDI Model

This section describes the training process of the Conditional Score-based Diffusion Model (CSDI) and the evaluation metrics used to measure its performance in time series imputation tasks.

### Model Training
The `train` function orchestrates the model training over multiple epochs, handling both the training and optional validation phases to optimize and evaluate the model's performance iteratively:

- **Optimizer and Scheduler**: An Adam optimizer is initialized with the learning rate from the configuration and a fixed weight decay of `1e-6`. A `MultiStepLR` scheduler multiplies the learning rate by 0.1 at 75% and again at 90% of the total epochs, which stabilizes training as it nears completion.
- **Training Loop**: The model is trained for a specified number of epochs, processing batches of data loaded through `train_loader`. For each batch, a forward pass computes the loss, a backward pass computes the gradients, and the optimizer updates the model weights. An epoch ends early once `itr_per_epoch` batches have been processed. Progress and the running average loss for each epoch are displayed using the tqdm progress bar.
- **Validation**: If a `valid_loader` is provided, the model is evaluated on the validation set every `valid_epoch_interval` epochs (20 by default). This phase computes the loss over the validation data under `torch.no_grad()`, without backpropagation or weight updates, providing an estimate of the model's performance on unseen data.
- **Model Saving**: The best validation loss is tracked and printed whenever it improves. The model's parameters are written once, after the final epoch, to `model.pth` in the output folder, so the saved checkpoint is the last model and not necessarily the one with the lowest validation loss.
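
The effect of the two milestones on the learning rate can be worked out by hand. The helper `lr_at_epoch` below is hypothetical. It reproduces the step rule of a multi-step schedule in plain Python, assuming the scheduler is stepped once per epoch, and the values of 200 epochs and a base rate of `1e-3` are assumptions for the example.

```python
def lr_at_epoch(epoch, base_lr, epochs, gamma=0.1):
    """Learning rate in effect during `epoch` (0-based), milestones at 75% and 90%."""
    milestones = [int(0.75 * epochs), int(0.9 * epochs)]
    # Each milestone already reached multiplies the rate by gamma once.
    passed = sum(epoch >= m for m in milestones)
    return base_lr * gamma**passed


epochs, base_lr = 200, 1e-3  # hypothetical values
for epoch in (0, 149, 150, 179, 180, 199):
    print(epoch, lr_at_epoch(epoch, base_lr, epochs))
```

Under these assumptions the rate stays at the base value through epoch 149, drops by a factor of 10 from epoch 150, and by another factor of 10 from epoch 180.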

### Quantile Loss and CRPS Calculation
- **Quantile Loss**: This function measures how well a predicted quantile matches the actual data. In this notebook it is used for evaluation, inside the CRPS computation. It penalizes overestimation and underestimation asymmetrically, with weights determined by the quantile level `q`, and only counts the entries selected by `eval_points`.
- **Continuous Ranked Probability Score (CRPS)**: CRPS assesses the quality of probabilistic predictions. Here it is approximated by averaging the quantile loss over 19 quantile levels (0.05 to 0.95 in steps of 0.05), with each loss normalized by the sum of absolute target values over the evaluated points. The result is a single score that summarises accuracy across the predicted distribution.
- **CRPS Sum**: In scenarios involving aggregated or cumulative data predictions, CRPS sum is calculated to evaluate the model's performance on summed predictions. This is particularly useful in applications like financial forecasting where cumulative figures are more relevant than individual predictions.

These components ensure the model is not only trained effectively but also evaluated using metrics that provide deep insights into its probabilistic forecasting abilities. The use of CRPS and quantile-based evaluations aligns well with the needs of applications requiring reliable uncertainty estimates in their predictions.
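
A small numerical example makes the asymmetry of the quantile loss and its role in CRPS concrete. This is a simplified NumPy sketch, not the notebook's implementation: the function `quantile_crps` is hypothetical, every point is evaluated (no `eval_points` mask), no rescaling is applied, and the sample data are synthetic.

```python
import numpy as np


def quantile_loss(target, forecast, q):
    """Twice the pinball loss, summed over all entries."""
    indicator = (target <= forecast).astype(float)
    return 2.0 * np.sum(np.abs((forecast - target) * (indicator - q)))


def quantile_crps(target, samples):
    """Approximate normalized CRPS from samples of shape (n_samples, n_points)."""
    quantiles = np.arange(0.05, 1.0, 0.05)
    denom = np.sum(np.abs(target))
    total = 0.0
    for q in quantiles:
        q_pred = np.quantile(samples, q, axis=0)  # empirical q-quantile per point
        total += quantile_loss(target, q_pred, q) / denom
    return total / len(quantiles)


target = np.array([1.0, 2.0, 4.0])
rng = np.random.default_rng(0)
sharp = target + 0.1 * rng.standard_normal((100, 3))  # samples close to the target
wide = target + 2.0 * rng.standard_normal((100, 3))   # much more dispersed samples

print(quantile_loss(np.array([1.0]), np.array([3.0]), 0.9))  # over-prediction by 2
print(quantile_loss(np.array([3.0]), np.array([1.0]), 0.9))  # under-prediction by 2
print(quantile_crps(target, sharp) < quantile_crps(target, wide))
```

At `q = 0.9`, the same error of 2 costs about 0.4 when the forecast is too high and about 3.6 when it is too low, which is the asymmetry described above. The sharper sample set also receives the lower (better) CRPS.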


In [None]:
def train(
    model,
    config,
    train_loader,
    valid_loader=None,
    valid_epoch_interval=20,
    foldername="",
):
    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=1e-6)
    if foldername != "":
        output_path = foldername + "/model.pth"

    p1 = int(0.75 * config["epochs"])
    p2 = int(0.9 * config["epochs"])
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[p1, p2], gamma=0.1
    )

    best_valid_loss = 1e10
    for epoch_no in range(config["epochs"]):
        avg_loss = 0
        model.train()
        with tqdm(train_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, train_batch in enumerate(it, start=1):
                optimizer.zero_grad()

                loss = model(train_batch)
                loss.backward()
                avg_loss += loss.item()
                optimizer.step()
                it.set_postfix(
                    ordered_dict={
                        "avg_epoch_loss": avg_loss / batch_no,
                        "epoch": epoch_no,
                    },
                    refresh=False,
                )
                if batch_no >= config["itr_per_epoch"]:
                    break

            lr_scheduler.step()
        if valid_loader is not None and (epoch_no + 1) % valid_epoch_interval == 0:
            model.eval()
            avg_loss_valid = 0
            with torch.no_grad():
                with tqdm(valid_loader, mininterval=5.0, maxinterval=50.0) as it:
                    for batch_no, valid_batch in enumerate(it, start=1):
                        loss = model(valid_batch, is_train=0)
                        avg_loss_valid += loss.item()
                        it.set_postfix(
                            ordered_dict={
                                "valid_avg_epoch_loss": avg_loss_valid / batch_no,
                                "epoch": epoch_no,
                            },
                            refresh=False,
                        )
            if best_valid_loss > avg_loss_valid:
                best_valid_loss = avg_loss_valid
                print(
                    "\n best loss is updated to ",
                    avg_loss_valid / batch_no,
                    "at",
                    epoch_no,
                )

    if foldername != "":
        torch.save(model.state_dict(), output_path)


def quantile_loss(target, forecast, q: float, eval_points) -> float:
    return 2 * torch.sum(
        torch.abs((forecast - target) * eval_points * ((target <= forecast) * 1.0 - q))
    )


def calc_denominator(target, eval_points):
    return torch.sum(torch.abs(target * eval_points))


def calc_quantile_CRPS(target, forecast, eval_points, mean_scaler, scaler):
    target = target * scaler + mean_scaler
    forecast = forecast * scaler + mean_scaler

    quantiles = np.arange(0.05, 1.0, 0.05)
    denom = calc_denominator(target, eval_points)
    CRPS = 0
    for i in range(len(quantiles)):
        q_pred = []
        for j in range(len(forecast)):
            q_pred.append(torch.quantile(forecast[j : j + 1], quantiles[i], dim=1))
        q_pred = torch.cat(q_pred, 0)
        q_loss = quantile_loss(target, q_pred, quantiles[i], eval_points)
        CRPS += q_loss / denom
    return CRPS.item() / len(quantiles)


def calc_quantile_CRPS_sum(target, forecast, eval_points, mean_scaler, scaler):
    eval_points = eval_points.mean(-1)
    target = target * scaler + mean_scaler
    target = target.sum(-1)
    forecast = forecast * scaler + mean_scaler

    quantiles = np.arange(0.05, 1.0, 0.05)
    denom = calc_denominator(target, eval_points)
    CRPS = 0
    for i in range(len(quantiles)):
        q_pred = torch.quantile(forecast.sum(-1), quantiles[i], dim=1)
        q_loss = quantile_loss(target, q_pred, quantiles[i], eval_points)
        CRPS += q_loss / denom
    return CRPS.item() / len(quantiles)

## Evaluating the CSDI Model

The `evaluate` function measures the CSDI model's performance on the batches supplied by `test_loader`. For each test input it draws `nsample` imputations (100 by default), compares them with the held-out target values, and reports both point-estimate errors and a probabilistic score.

### Evaluation Procedure
- **Model State and Setup**: The model is switched to evaluation mode with `model.eval()`, which disables training-only behavior such as dropout, and gradient tracking is turned off with `torch.no_grad()`. The function accumulates squared and absolute errors, reported as root mean squared error (RMSE) and mean absolute error (MAE), and computes the Continuous Ranked Probability Score (CRPS).

- **Batch Processing**: Test data is processed batch by batch. For each batch:
  - The model generates `nsample` samples for each input sequence to approximate the distribution of plausible imputations, which is what a probabilistic model like CSDI is meant to provide.
  - The samples are summarized and compared to the target values, counting only the positions where `eval_points` is 1:
    - **Median Prediction**: The per-position median over the samples serves as the point estimate. Its squared and absolute errors are rescaled by `scaler` to the original data units and accumulated for RMSE and MAE.
    - **Quantile Calculation**: Once all batches have been processed, the samples are also used to compute the CRPS, which scores the whole predictive distribution rather than only its median or mean.
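
As a minimal sketch of the median-based error computation described above, the snippet below recomputes RMSE and MAE on a toy tensor. The helper name `masked_median_errors` and the toy values are illustrative assumptions, not part of the original code; the shapes follow the `(B, nsample, L, K)` layout used in `evaluate`.

```python
import torch


def masked_median_errors(samples, target, eval_mask, scaler=1.0):
    """Hypothetical helper mirroring the per-batch error terms in `evaluate`.

    samples:   (B, nsample, L, K) generated imputations
    target:    (B, L, K) ground-truth values
    eval_mask: (B, L, K) with 1 at positions that count toward the metrics
    """
    median = samples.median(dim=1).values      # (B, L, K) point estimate
    err = (median - target) * eval_mask        # zero outside evaluation points
    n_eval = eval_mask.sum().item()            # assumed to be > 0 here
    rmse = ((err ** 2).sum().item() * scaler ** 2 / n_eval) ** 0.5
    mae = err.abs().sum().item() * scaler / n_eval
    return rmse, mae


# Toy batch: B=1, nsample=3, L=2, K=1.
samples = torch.tensor([[[[1.0], [10.0]], [[2.0], [20.0]], [[3.0], [30.0]]]])
target = torch.tensor([[[4.0], [0.0]]])
eval_mask = torch.tensor([[[1.0], [0.0]]])    # only the first time step is scored

rmse, mae = masked_median_errors(samples, target, eval_mask)
print(rmse, mae)
```

Only the first time step is scored, where the sample median is 2.0 against a target of 4.0, so both errors come out as 2.0. Note that for an even number of samples `torch.median` returns the lower of the two middle values rather than their average.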

### Aggregating Results
- **Data Aggregation**: Error totals and the number of evaluation points are accumulated over the entire test set, and the per-batch targets, evaluation masks, observation masks, time stamps, and generated samples are concatenated along the batch dimension.
- **Persistence**: The raw generated samples (together with targets, masks, and scaling constants) are pickled to `generated_outputs_nsample{nsample}.pk`, and the summary metrics to `result_nsample{nsample}.pk`, both inside `foldername`. Keeping these files allows later analysis without rerunning the sampler and supports reproducibility.
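
The saved outputs can be reloaded for later analysis. The sketch below assumes the seven-element list layout that `evaluate` writes; the helper name `load_generated_outputs` and the dictionary keys are hypothetical labels chosen for readability. The stand-in file uses plain lists instead of tensors so the example runs on its own.

```python
import os
import pickle
import tempfile


def load_generated_outputs(foldername, nsample):
    """Reload the list pickled by `evaluate` and label its entries (hypothetical helper)."""
    path = os.path.join(foldername, f"generated_outputs_nsample{nsample}.pk")
    with open(path, "rb") as f:
        contents = pickle.load(f)
    # Order follows the list passed to pickle.dump in `evaluate`.
    keys = [
        "samples", "target", "eval_points", "observed_points",
        "observed_time", "scaler", "mean_scaler",
    ]
    return dict(zip(keys, contents))


# Stand-in file with toy values, only to exercise the helper.
with tempfile.TemporaryDirectory() as tmp:
    stand_in = [[[0.1, 0.2]], [0.15], [1], [0], [0.0], 1, 0]
    with open(os.path.join(tmp, "generated_outputs_nsample2.pk"), "wb") as f:
        pickle.dump(stand_in, f)
    loaded = load_generated_outputs(tmp, nsample=2)

print(sorted(loaded))
```

Because `pickle.load` can execute arbitrary code, it should only be used on files you produced yourself.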

### Metrics Computation
- **CRPS and Sum**: CRPS is approximated by averaging the quantile loss over the quantile levels 0.05 to 0.95 in steps of 0.05, normalized by the sum of absolute target values at the evaluation points. `CRPS_sum` applies the same score after summing the targets and samples over the feature dimension, which is informative when the aggregate matters, such as in financial forecasting or total rainfall estimation.
- **Output**: RMSE, MAE, CRPS, and CRPS sum are printed. RMSE, MAE, and CRPS are also saved to the results file; CRPS sum is printed only.
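
To make the quantile-based CRPS approximation concrete, here is a small sketch on toy data. It follows the same arithmetic as `quantile_loss` and `calc_quantile_CRPS`, but the function names `toy_pinball_loss` and `toy_quantile_crps` and the toy values are assumptions for illustration, and the rescaling by `scaler` and `mean_scaler` is omitted.

```python
import torch


def toy_pinball_loss(target, forecast_q, q, eval_mask):
    """Quantile (pinball) loss at level q, summed over evaluation points."""
    indicator = (target <= forecast_q).float()
    return 2 * torch.sum(torch.abs((forecast_q - target) * eval_mask * (indicator - q)))


def toy_quantile_crps(target, samples, eval_mask):
    """Average normalized pinball loss over levels 0.05, 0.10, ..., 0.95.

    target, eval_mask: (B, L, K); samples: (B, nsample, L, K)
    """
    quantiles = [0.05 * k for k in range(1, 20)]
    denom = torch.sum(torch.abs(target * eval_mask))
    total = 0.0
    for q in quantiles:
        forecast_q = torch.quantile(samples, q, dim=1)  # (B, L, K)
        total += (toy_pinball_loss(target, forecast_q, q, eval_mask) / denom).item()
    return total / len(quantiles)


target = torch.tensor([[[2.0]]])                 # B=1, L=1, K=1
eval_mask = torch.ones_like(target)

perfect = torch.full((1, 5, 1, 1), 2.0)          # every sample equals the target
off_by_one = torch.full((1, 5, 1, 1), 3.0)       # every sample misses by 1.0

crps_perfect = toy_quantile_crps(target, perfect, eval_mask)
crps_off = toy_quantile_crps(target, off_by_one, eval_mask)
print(crps_perfect, crps_off)
```

When all samples coincide, the forecast is a point mass and the score reduces to the absolute error divided by the absolute target: 0 for the perfect forecast and about 1.0 / 2.0 = 0.5 for the forecast that misses by one.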


In [None]:
def evaluate(model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername=""):
    with torch.no_grad():
        model.eval()
        mse_total = 0
        mae_total = 0
        evalpoints_total = 0

        all_target = []
        all_observed_point = []
        all_observed_time = []
        all_evalpoint = []
        all_generated_samples = []
        with tqdm(test_loader, mininterval=5.0, maxinterval=50.0) as it:
            for batch_no, test_batch in enumerate(it, start=1):
                output = model.evaluate(test_batch, nsample)

                samples, c_target, eval_points, observed_points, observed_time = output
                samples = samples.permute(0, 1, 3, 2)  # (B,nsample,L,K)
                c_target = c_target.permute(0, 2, 1)  # (B,L,K)
                eval_points = eval_points.permute(0, 2, 1)
                observed_points = observed_points.permute(0, 2, 1)

                samples_median = samples.median(dim=1)
                all_target.append(c_target)
                all_evalpoint.append(eval_points)
                all_observed_point.append(observed_points)
                all_observed_time.append(observed_time)
                all_generated_samples.append(samples)

                mse_current = (
                    ((samples_median.values - c_target) * eval_points) ** 2
                ) * (scaler**2)
                mae_current = (
                    torch.abs((samples_median.values - c_target) * eval_points)
                ) * scaler

                mse_total += mse_current.sum().item()
                mae_total += mae_current.sum().item()
                evalpoints_total += eval_points.sum().item()

                it.set_postfix(
                    ordered_dict={
                        "rmse_total": np.sqrt(mse_total / evalpoints_total),
                        "mae_total": mae_total / evalpoints_total,
                        "batch_no": batch_no,
                    },
                    refresh=True,
                )

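            # os.path.join keeps the default foldername="" pointing at the
            # working directory instead of the filesystem root.
            with open(
                os.path.join(foldername, f"generated_outputs_nsample{nsample}.pk"), "wb"
            ) as f: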
                all_target = torch.cat(all_target, dim=0)
                all_evalpoint = torch.cat(all_evalpoint, dim=0)
                all_observed_point = torch.cat(all_observed_point, dim=0)
                all_observed_time = torch.cat(all_observed_time, dim=0)
                all_generated_samples = torch.cat(all_generated_samples, dim=0)

                pickle.dump(
                    [
                        all_generated_samples,
                        all_target,
                        all_evalpoint,
                        all_observed_point,
                        all_observed_time,
                        scaler,
                        mean_scaler,
                    ],
                    f,
                )

            CRPS = calc_quantile_CRPS(
                all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
            )
            CRPS_sum = calc_quantile_CRPS_sum(
                all_target, all_generated_samples, all_evalpoint, mean_scaler, scaler
            )

            # os.path.join avoids writing to "/" when foldername is left empty.
            with open(os.path.join(foldername, f"result_nsample{nsample}.pk"), "wb") as f:
                pickle.dump(
                    [
                        np.sqrt(mse_total / evalpoints_total),
                        mae_total / evalpoints_total,
                        CRPS,
                    ],
                    f,
                )
                print("RMSE:", np.sqrt(mse_total / evalpoints_total))
                print("MAE:", mae_total / evalpoints_total)
                print("CRPS:", CRPS)
                print("CRPS_sum:", CRPS_sum)

### References
- Tashiro, Yusuke, et al. "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation." *Advances in Neural Information Processing Systems*. 2021. [GitHub Repository](https://github.com/ermongroup/CSDI)
