# CSDI for Time Series Forecasting

This notebook demonstrates how to use CSDI (Conditional Score-based Diffusion Models) for time series forecasting, specifically on electricity consumption data.

The notebook is organized as follows:

1. [Library Imports and Model Setup]()


2. [Configuration and Argument Parsing]()


3. [Electricity Dataset]()

	3.1. [Download Dataset]()
    
	3.2. [Data Loaders]()
    
	
4. [Forecasting using CSDI]()

	4.1. [Experiment Setup and Execution]()

	4.2. [Model Training]()

	4.3. [Loading Pretrained Model]()

	4.4. [Model Evaluation]()

# Library Imports and Model Setup

In this section, we import the libraries and modules needed for the CSDI model: general-purpose utilities such as `numpy`, `yaml`, and `gdown`, the deep learning library `torch`, and the project-specific components for building, training, and evaluating the model.

In [1]:
import datetime
import json
import os
import shutil
import sys

import gdown
import numpy as np
import torch
import yaml


# Add the current working directory to sys.path to import local modules
sys.path.append(os.path.abspath(os.path.join(".")))

from dataset.dataset_forecasting import get_dataloader
from main_model import CSDI_base
from util.utils import evaluate, train

# Configuration and Argument Parsing
The experiment's settings are loaded from a YAML configuration file, so model and training parameters can be adjusted in one place. In this notebook the configuration is read directly from `config/base_forecasting.yaml`; no command-line arguments are parsed. To change a setting, edit the YAML file or modify the `config` dictionary after loading it.

In [4]:
"""{
    "train": {
        "epochs": 100,  # Number of full passes over the training set.
        "batch_size": 8,  # Number of samples per gradient update.
        "lr": 0.001,  # Learning rate (optimizer step size).
        "itr_per_epoch": 100000000.0  # Presumably a cap on iterations per epoch; large enough that it does not shorten an epoch here.
    },
    "diffusion": {
        "layers": 4,  # Number of residual layers in the denoising network.
        "channels": 64,  # Number of channels in each residual layer.
        "nheads": 8,  # Number of attention heads in the transformer layers.
        "diffusion_embedding_dim": 128,  # Dimension of the diffusion-step embedding.
        "beta_start": 0.0001,  # Noise variance at the first diffusion step.
        "beta_end": 0.5,  # Noise variance at the last diffusion step.
        "num_steps": 50,  # Number of diffusion steps.
        "schedule": "quad",  # Shape of the beta schedule; 'quad' is a quadratic progression from beta_start to beta_end.
        "is_linear": True  # In the reference CSDI code this flag appears to select linear-attention transformer layers; it does not describe the beta schedule.
    },
    "model": {
        "is_unconditional": 0,  # 0 means the model is conditioned on observed values.
        "timeemb": 128,  # Dimension of the time embedding.
        "featureemb": 16,  # Dimension of the per-feature embedding.
        "target_strategy": "test",  # How conditioning masks are chosen during training; 'test' reuses the test-time pattern (history observed, horizon hidden).
        "num_sample_features": 64  # Number of features sampled per training step when the dataset has more features than this (370 here).
    }
}
"""


def load_config(config_path="config/base_forecasting.yaml"):
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config


config = load_config()
print(json.dumps(config, indent=4))

{
    "train": {
        "epochs": 100,
        "batch_size": 8,
        "lr": 0.001,
        "itr_per_epoch": 100000000.0
    },
    "diffusion": {
        "layers": 4,
        "channels": 64,
        "nheads": 8,
        "diffusion_embedding_dim": 128,
        "beta_start": 0.0001,
        "beta_end": 0.5,
        "num_steps": 50,
        "schedule": "quad",
        "is_linear": true
    },
    "model": {
        "is_unconditional": 0,
        "timeemb": 128,
        "featureemb": 16,
        "target_strategy": "test",
        "num_sample_features": 64
    }
}
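To make the diffusion settings more concrete, the sketch below builds the noise schedule implied by `beta_start`, `beta_end`, `num_steps`, and `schedule: quad`. It assumes the quadratic schedule interpolates linearly in the square root of beta and then squares the result, which is how the reference CSDI code appears to define it. The variable names are illustrative and are not taken from the notebook's modules.

```python
import numpy as np

# Values copied from the printed configuration above.
beta_start, beta_end, num_steps = 0.0001, 0.5, 50

# Assumed "quad" schedule: linear in sqrt(beta), then squared.
beta = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_steps) ** 2

# alpha_bar[t] is the fraction of signal variance left after t + 1 noising steps.
alpha = 1.0 - beta
alpha_bar = np.cumprod(alpha)

print(f"beta[0]={beta[0]:.4f}, beta[-1]={beta[-1]:.4f}")
print(f"alpha_bar[0]={alpha_bar[0]:.4f}, alpha_bar[-1]={alpha_bar[-1]:.2e}")
```

Under this assumption, `alpha_bar` falls from nearly 1 to nearly 0 over the 50 steps, so the final step is close to pure noise.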


### Device Configuration
First, we determine the appropriate computation environment. If a GPU is available, the model will utilize it for faster computation; otherwise, it defaults to using the CPU. This ensures that the setup is optimized for performance regardless of the hardware available.

In [3]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Electricity Dataset

In this notebook, we will use the Electricity dataset for time series forecasting. This dataset is sourced from the UCI Machine Learning Repository and contains the electricity consumption (measured in kWh) of 370 customers over time. The dataset includes electricity consumption measurements taken at regular intervals, making it suitable for various time series analysis and forecasting tasks.

Let's get started by loading and exploring the dataset!


## Download Dataset

Run this cell to download the dataset if you haven't already.

In [2]:
def download_data(folder_url, output_path="data/electricity_nips"):
    """
    Download all files from a Google Drive folder and save them to the specified output path.
    Checks if data already exists in the output directory before downloading.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    # Check if the directory is not empty, implying data might already be downloaded
    if os.listdir(output_path):
        print(f"Data already exists in {output_path}. Skipping download.")
        return

    # Extract the folder ID from the URL
    folder_id = folder_url.split("/")[-1]

    # List files in the folder
    url = f"https://drive.google.com/drive/folders/{folder_id}"
    temp_output_path = os.path.join(output_path, "temp")
    os.makedirs(temp_output_path, exist_ok=True)
    output = gdown.download_folder(
        url, output=temp_output_path, quiet=False, use_cookies=False
    )

    for root, dirs, files in os.walk(temp_output_path):
        for file in files:
            shutil.move(os.path.join(root, file), output_path)

    shutil.rmtree(temp_output_path)
    print(f"Downloaded files to {output_path}")


# Example usage with a specified folder URL
folder_url = "https://drive.google.com/drive/folders/1krZQofLdeQrzunuKkLXy8L_kMzQrVFI_"
download_data(folder_url)

Data already exists in data/electricity_nips. Skipping download.


## Data Loaders
Data loaders are set up for the training, validation, and test splits. They deliver the data in batches; each batch carries the observed values, the observation mask, the time points, and a ground-truth mask (`gt_mask`) that marks which entries the model may condition on. `get_dataloader` also returns `scaler` and `mean_scaler`, which are passed to `evaluate` later in the notebook.

In [5]:
# Set up dataloaders
datatype = "electricity"
target_dim = 370  # for electricity dataset

train_loader, valid_loader, test_loader, scaler, mean_scaler = get_dataloader(
    datatype=datatype, device=device, batch_size=config["train"]["batch_size"]
)

# Forecasting using CSDI
CSDI was introduced for imputation, but the same model can be used for forecasting, that is, predicting future values of a time series. When forecasting, CSDI uses conditional score-based diffusion to model the distribution of future values conditioned on the observed past.

### Key Differences

- **Objective**: Forecasting predicts future, unknown values, while imputation estimates missing, yet contemporaneous values within a dataset.
- **Conditional Modeling**: Both processes utilize the observed values, but forecasting uses them to predict beyond the existing sequence, whereas imputation fills gaps within the sequence.
- **Data Utilization**: In forecasting, all observed data points contribute to predicting future states; in imputation, observed data points are used to reconstruct the missing values between them.

The `CSDI_base` class serves as the shared foundation for both forecasting and imputation, which illustrates how forecasting is effectively an extension of imputation. It provides the conditional inputs and configurable diffusion parameters, and the task is determined by which entries are hidden: gaps within a sequence (imputation) or the values after it (forecasting). Both tasks share the same core: observed data conditions the generation of new data points, whether these points lie within existing gaps or beyond the known sequence.
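The difference between the two tasks can be sketched with conditioning masks, where 1 marks a value the model sees and 0 marks a value it must generate. The sizes and the 30% missing rate below are hypothetical and chosen only to keep the printout small.

```python
import numpy as np

rng = np.random.default_rng(0)

num_features, history_length, pred_length = 3, 6, 2  # hypothetical sizes
seq_length = history_length + pred_length

# Imputation: hidden entries are scattered across the whole window.
imputation_mask = (rng.random((num_features, seq_length)) > 0.3).astype(int)

# Forecasting: the full history is visible and the trailing horizon is hidden.
forecast_mask = np.ones((num_features, seq_length), dtype=int)
forecast_mask[:, -pred_length:] = 0

print("imputation mask:\n", imputation_mask)
print("forecasting mask:\n", forecast_mask)
```

In both cases the model receives the same kind of input; only the pattern of zeros changes.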

In [6]:
class CSDI_Forecasting(CSDI_base):
    """
    class CSDI_Forecasting Initializes the forecasting model derived from the CSDI_base class:
      - Sets up the model with specific configurations and assigns it to a computational device.
      - Initializes key attributes including the dimensionality of the target data and the number of sampling features based on the model configuration.
    """

    def __init__(self, config, device, target_dim):
        super(CSDI_Forecasting, self).__init__(target_dim, config, device)
        self.target_dim_base = target_dim
        self.num_sample_features = config["model"]["num_sample_features"]

    """
    def process_data
    - Processes the input batch for the model, formatting and preparing data for subsequent operations.
    - `batch`: The input batch containing observed data, masks, and timepoints.
    - Transfers all relevant data to the computation device and converts it to floating-point format.
    - Reorders dimensions of observed data and masks to match the expected input format of the model.
    - Initializes a tensor for cut lengths to dynamically manage sequence processing.
    - Generates a tensor of feature identifiers for subsequent data handling processes.
    """

    def process_data(self, batch):
        observed_data = batch["observed_data"].to(self.device).float()
        observed_mask = batch["observed_mask"].to(self.device).float()
        observed_tp = batch["timepoints"].to(self.device).float()
        gt_mask = batch["gt_mask"].to(self.device).float()

        observed_data = observed_data.permute(0, 2, 1)
        observed_mask = observed_mask.permute(0, 2, 1)
        gt_mask = gt_mask.permute(0, 2, 1)

        cut_length = torch.zeros(len(observed_data)).long().to(self.device)
        for_pattern_mask = observed_mask

        feature_id = (
            torch.arange(self.target_dim_base)
            .unsqueeze(0)
            .expand(observed_data.shape[0], -1)
            .to(self.device)
        )

        return (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            for_pattern_mask,
            cut_length,
            feature_id,
        )

    """
    def sample_features
    - Randomly samples a subset of features from the observed data for model processing.
    - `observed_data`: Tensor containing the observed data points.
    - `observed_mask`: Tensor indicating the presence of observed data points.
    - `feature_id`: Tensor of feature identifiers corresponding to observed data.
    - `gt_mask`: Tensor of ground truth masks indicating valid data points for evaluation.
    - Samples a specified number of features (defined by `num_sample_features`) from each observation in the batch.
    - Reorganizes the data, masks, and feature identifiers based on the sampled features to prepare for model input.
    - Returns the newly formatted data, masks, feature identifiers, and ground truth masks.
    """

    def sample_features(self, observed_data, observed_mask, feature_id, gt_mask):
        size = self.num_sample_features
        self.target_dim = size
        extracted_data = []
        extracted_mask = []
        extracted_feature_id = []
        extracted_gt_mask = []

        for k in range(len(observed_data)):
            ind = np.arange(self.target_dim_base)
            np.random.shuffle(ind)
            extracted_data.append(observed_data[k, ind[:size]])
            extracted_mask.append(observed_mask[k, ind[:size]])
            extracted_feature_id.append(feature_id[k, ind[:size]])
            extracted_gt_mask.append(gt_mask[k, ind[:size]])
        extracted_data = torch.stack(extracted_data, 0)
        extracted_mask = torch.stack(extracted_mask, 0)
        extracted_feature_id = torch.stack(extracted_feature_id, 0)
        extracted_gt_mask = torch.stack(extracted_gt_mask, 0)
        return extracted_data, extracted_mask, extracted_feature_id, extracted_gt_mask

    """
    def get_side_info
    - Generates side information combining time embeddings and feature embeddings based on provided masks.
    - `observed_tp`: Tensor containing the time points of observations.
    - `cond_mask`: Conditional mask tensor that specifies which data points are used for conditioning.
    - `feature_id`: Optional tensor of feature identifiers; used if model's target dimension differs from the base.
    - Constructs time embeddings for each time point and expands these across the target dimensions.
    - Depending on the model's configuration, either generates a static feature embedding for all features or specific embeddings based on the provided `feature_id`.
    - Concatenates time and feature embeddings to form a comprehensive side information tensor.
    - Reorders dimensions to match the expected input structure for further model processes.
    - If the model conditions on the input (is not unconditional), includes the conditional mask in the side information.
    - Returns the compiled side information for use in the model’s prediction or forecasting tasks.
    """

    def get_side_info(self, observed_tp, cond_mask, feature_id=None):
        B, K, L = cond_mask.shape

        time_embed = self.time_embedding(observed_tp, self.emb_time_dim)  # (B,L,emb)
        time_embed = time_embed.unsqueeze(2).expand(-1, -1, self.target_dim, -1)

        if self.target_dim == self.target_dim_base:
            feature_embed = self.embed_layer(
                torch.arange(self.target_dim).to(self.device)
            )  # (K,emb)
            feature_embed = feature_embed.unsqueeze(0).unsqueeze(0).expand(B, L, -1, -1)
        else:
            feature_embed = (
                self.embed_layer(feature_id).unsqueeze(1).expand(-1, L, -1, -1)
            )
        side_info = torch.cat([time_embed, feature_embed], dim=-1)  # (B,L,K,*)
        side_info = side_info.permute(0, 3, 2, 1)  # (B,*,K,L)

        if not self.is_unconditional:
            side_mask = cond_mask.unsqueeze(1)  # (B,1,K,L)
            side_info = torch.cat([side_info, side_mask], dim=1)

        return side_info

    """
    def forward
    - Defines the forward pass for the model, processing input data and executing training or evaluation steps.
    - `batch`: Input batch containing observed data, masks, and additional information.
    - `is_train`: Indicator of whether the model is in training mode (1) or evaluation mode (0).
    - Processes the input batch to format and prepare data structures for model operations.
    - Conditionally samples a subset of features from the observed data if in training mode and feature sampling is necessary due to dimensionality constraints.
    - Sets the target dimensionality based on the model's base or sampled features.
    - Determines the conditioning mask based on the mode of operation, using ground truth masks for evaluation or generating test patterns for training.
    - Retrieves side information incorporating time and feature embeddings tailored to the conditioning context.
    - Selects the appropriate loss calculation function based on the training or evaluation context.
    - Computes and returns the loss by comparing model predictions with actual observations and side information, adjusted for the specified training or evaluation mode.
    """

    def forward(self, batch, is_train=1):
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            _,
            _,
            feature_id,
        ) = self.process_data(batch)
        if is_train == 1 and (self.target_dim_base > self.num_sample_features):
            observed_data, observed_mask, feature_id, gt_mask = self.sample_features(
                observed_data, observed_mask, feature_id, gt_mask
            )
        else:
            self.target_dim = self.target_dim_base
            feature_id = None

        if is_train == 0:
            cond_mask = gt_mask
        else:  # test pattern
            cond_mask = self.get_test_pattern_mask(observed_mask, gt_mask)

        side_info = self.get_side_info(observed_tp, cond_mask, feature_id)

        loss_func = self.calc_loss if is_train == 1 else self.calc_loss_valid

        return loss_func(observed_data, cond_mask, observed_mask, side_info, is_train)

    """
    def evaluate
    - Generates forecast samples for a batch so that metrics can be computed afterwards.
    - `batch`: Input batch containing observed data and masks.
    - `n_samples`: Number of samples to generate for each target point.
    - Processes the batch to extract the data, masks, and time points.
    - Runs under `torch.no_grad()` so that no gradients are tracked during sampling.
    - Uses the ground truth mask as the conditional mask; the target mask covers the observed points outside it.
    - Retrieves side information based on the observed time points and the conditional mask.
    - Generates the samples with the base class's `impute` method, which here fills in the forecast horizon.
    - Returns the samples along with the observed data, target mask, observed mask, and time points.
    """

    def evaluate(self, batch, n_samples):
        (
            observed_data,
            observed_mask,
            observed_tp,
            gt_mask,
            _,
            _,
            feature_id,
        ) = self.process_data(batch)

        with torch.no_grad():
            cond_mask = gt_mask
            target_mask = observed_mask * (1 - gt_mask)

            side_info = self.get_side_info(observed_tp, cond_mask)

            samples = self.impute(observed_data, cond_mask, side_info, n_samples)

        return samples, observed_data, target_mask, observed_mask, observed_tp
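The tensor shapes handled by `sample_features` and `get_side_info` are easier to follow in isolation. The sketch below reproduces only the shape logic of the sampled-feature branch. It uses a random tensor in place of the real time embedding and a fresh `nn.Embedding` in place of the model's `embed_layer`, so the values are meaningless and only the shapes matter. The batch size and window length are hypothetical.

```python
import torch
import torch.nn as nn

B, L = 2, 8                   # hypothetical batch size and window length
K_base, K_sampled = 370, 64   # all features vs. features sampled per step
time_dim, feat_dim = 128, 16  # timeemb and featureemb from the config

# Stand-ins for the real components: random time embedding, fresh embedding table.
time_embed = torch.randn(B, L, time_dim)
embed_layer = nn.Embedding(K_base, feat_dim)

# Sample a different feature subset for each item in the batch: (B, K)
feature_id = torch.stack([torch.randperm(K_base)[:K_sampled] for _ in range(B)])

# Broadcast time embeddings over features and feature embeddings over time.
time_part = time_embed.unsqueeze(2).expand(-1, -1, K_sampled, -1)       # (B, L, K, 128)
feat_part = embed_layer(feature_id).unsqueeze(1).expand(-1, L, -1, -1)  # (B, L, K, 16)

side_info = torch.cat([time_part, feat_part], dim=-1)  # (B, L, K, 144)
side_info = side_info.permute(0, 3, 2, 1)              # (B, 144, K, L)

# In the conditional setting the mask is appended as one extra channel.
cond_mask = torch.ones(B, K_sampled, L)
side_info = torch.cat([side_info, cond_mask.unsqueeze(1)], dim=1)  # (B, 145, K, L)

print(side_info.shape)
```

Sampling 64 of the 370 features per step keeps the feature axis small during training, while evaluation uses all 370 features.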

## Experiment Setup and Execution

This section sets up and runs CSDI for forecasting the electricity data.

### Model Initialization
The model, `CSDI_Forecasting`, is initialized from the loaded configuration, the selected device, and the number of target features (370 for the electricity dataset). It inherits the diffusion network and the sampling routines from `CSDI_base`.

In [7]:
# Set up model
model = CSDI_Forecasting(config, device, target_dim).to(device)

## Model Training

### Output Folder Setup
Before training the model, an output directory is created to store training artifacts such as model checkpoints, logs, and output files. The directory name includes a timestamp to ensure uniqueness and to help track experiments based on the date and time they were performed.

- **Directory Naming**: The folder is named using the current date and time, which helps in organizing and retrieving model training sessions based on when they were conducted.
- **Creation**: The directory is created on the file system, ensuring it exists before any training outputs are written to it. This prevents errors related to file writing during model training.

### Model Training Process
The model is trained using the specified configurations, data loaders, and the path to the output directory. The training function is designed to handle both the training and validation phases within each epoch, allowing for a comprehensive assessment of model performance over time.

- **Training Function**: Takes the model, training configurations, and data loaders as inputs. Additionally, it accepts the path to the output folder where the training results are stored.
- **Validation Data**: Optionally, a validation loader can be passed to periodically evaluate model performance on a separate validation set during the training process.

### Execution
Upon execution, the training process iteratively updates the model weights based on the loss computed from the training data. It also evaluates the model on the validation set, if provided, to monitor its performance on unseen data. Results and model states are saved in the designated output directory, facilitating post-training evaluations and model deployment.

### Loss Function
The loss function employed in the CSDI model is designed to optimize the model's ability to denoise data:
- **Denoising Loss**: During training, the model calculates the loss as the squared difference between the actual noise added to the data in the forward process and the noise predicted by the model during the reverse diffusion process. This loss function is key to training the model to accurately reverse the noise addition, effectively reconstructing the original data from its noisy version.
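As a rough illustration of the denoising loss, the sketch below noises a toy batch at one diffusion step and computes the masked squared error between the true and predicted noise. The `toy_denoiser` function, the sizes, and the value of `alpha_bar_t` are placeholders for illustration; the actual loss is computed inside `CSDI_base` and may differ in detail.

```python
import torch

torch.manual_seed(0)

B, K, L = 2, 4, 8                  # hypothetical batch, feature, and time sizes
x0 = torch.randn(B, K, L)          # clean (normalized) data
target_mask = torch.zeros(B, K, L)
target_mask[:, :, -2:] = 1.0       # positions the model must generate

alpha_bar_t = torch.tensor(0.6)    # cumulative signal level at the chosen step
noise = torch.randn_like(x0)       # noise added in the forward process
x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * noise


def toy_denoiser(noisy):
    """Placeholder for the trained network; a real model predicts the noise."""
    return torch.zeros_like(noisy)


predicted = toy_denoiser(x_t)
residual = (noise - predicted) * target_mask  # only target positions count
num_eval = target_mask.sum()
loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)

print(loss.item())
```

Masking the residual restricts the loss to the positions being generated, so the conditioning values do not contribute to the gradient.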

### Metrics
1. **Mean Absolute Error (MAE)**: This metric measures the average magnitude of errors in a set of predictions, without considering their direction. It's a linear score that averages the absolute differences between predicted and actual values, providing a straightforward interpretation of prediction accuracy.
2. **Continuous Ranked Probability Score (CRPS)**: CRPS assesses the accuracy of probabilistic predictions. It integrates the squared difference between the predicted cumulative distribution function and the step function located at the observed value, so it rewards forecasts that are both well centered and appropriately narrow. For a forecast that puts all its mass on a single point, CRPS reduces to the absolute error. Lower values are better.
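The two metrics can be compared on toy data. The sketch below computes MAE from the sample median and estimates CRPS from samples with the standard estimator E|X - y| - 0.5 E|X - X'|. The function names and the synthetic data are illustrative. The notebook's `evaluate` utility may use a different approximation of CRPS, so the numbers are not directly comparable.

```python
import numpy as np


def mae_from_samples(samples, target):
    """MAE between the per-point sample median and the target."""
    return np.mean(np.abs(np.median(samples, axis=0) - target))


def crps_from_samples(samples, target):
    """Sample-based CRPS estimate, averaged over all points."""
    term_obs = np.mean(np.abs(samples - target), axis=0)
    pairwise = np.abs(samples[:, None] - samples[None, :])
    term_spread = np.mean(pairwise, axis=(0, 1))
    return np.mean(term_obs - 0.5 * term_spread)


rng = np.random.default_rng(0)
target = np.sin(np.linspace(0, 3, 24))                 # toy ground truth, 24 steps
sharp = target + 0.1 * rng.standard_normal((100, 24))  # centered and narrow
vague = target + 1.0 * rng.standard_normal((100, 24))  # centered but very wide

print("MAE :", mae_from_samples(sharp, target), mae_from_samples(vague, target))
print("CRPS:", crps_from_samples(sharp, target), crps_from_samples(vague, target))
```

Both forecasts are centered on the truth, but the wide one receives a much larger CRPS, which shows how the score penalizes excess uncertainty as well as bias.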

In [8]:
# Set up output folder
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
foldername = f"./save/forecasting_{datatype}_{current_time}/"
os.makedirs(foldername, exist_ok=True)

# Save config
with open(foldername + "config.json", "w") as f:
    json.dump(config, f, indent=4)

In [11]:
# Train the model
train(
    model,
    config["train"],
    train_loader,
    valid_loader=valid_loader,
    foldername=foldername,
)

100%|██████████| 691/691 [04:28<00:00,  2.58it/s, avg_epoch_loss=0.294, epoch=0]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.219, epoch=1]
100%|██████████| 691/691 [04:37<00:00,  2.49it/s, avg_epoch_loss=0.205, epoch=2]
100%|██████████| 691/691 [04:37<00:00,  2.49it/s, avg_epoch_loss=0.199, epoch=3]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.194, epoch=4]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.191, epoch=5]
100%|██████████| 691/691 [04:38<00:00,  2.48it/s, avg_epoch_loss=0.183, epoch=6]
100%|██████████| 691/691 [04:41<00:00,  2.45it/s, avg_epoch_loss=0.182, epoch=7]
100%|██████████| 691/691 [04:37<00:00,  2.49it/s, avg_epoch_loss=0.182, epoch=8]
100%|██████████| 691/691 [04:42<00:00,  2.44it/s, avg_epoch_loss=0.177, epoch=9]
100%|██████████| 691/691 [04:37<00:00,  2.49it/s, avg_epoch_loss=0.175, epoch=10]
100%|██████████| 691/691 [04:38<00:00,  2.48it/s, avg_epoch_loss=0.182, epoch=11]
...


 best loss is updated to  0.16422410309314728 at 19


100%|██████████| 691/691 [04:40<00:00,  2.46it/s, avg_epoch_loss=0.169, epoch=20]
100%|██████████| 691/691 [04:40<00:00,  2.46it/s, avg_epoch_loss=0.166, epoch=21]
100%|██████████| 691/691 [04:40<00:00,  2.46it/s, avg_epoch_loss=0.168, epoch=22]
100%|██████████| 691/691 [04:40<00:00,  2.46it/s, avg_epoch_loss=0.163, epoch=23]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.163, epoch=24]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.163, epoch=25]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.163, epoch=26]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.163, epoch=27]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.163, epoch=28]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.163, epoch=29]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.16, epoch=30] 
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.16, epoch=31] 
...


 best loss is updated to  0.15855088829994202 at 39


100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.157, epoch=40]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.161, epoch=41]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.153, epoch=42]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.15, epoch=43] 
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.161, epoch=44]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.15, epoch=45] 
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.155, epoch=46]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.159, epoch=47]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.156, epoch=48]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.152, epoch=49]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.157, epoch=50]
100%|██████████| 691/691 [04:40<00:00,  2.47it/s, avg_epoch_loss=0.155, epoch=51]
...


 best loss is updated to  0.15426921844482422 at 59


100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.155, epoch=60]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.156, epoch=61]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.155, epoch=62]
100%|██████████| 691/691 [04:39<00:00,  2.47it/s, avg_epoch_loss=0.152, epoch=63]
100%|██████████| 691/691 [04:36<00:00,  2.50it/s, avg_epoch_loss=0.145, epoch=64]
100%|██████████| 691/691 [04:34<00:00,  2.52it/s, avg_epoch_loss=0.153, epoch=65]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.158, epoch=66]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.156, epoch=67]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.153, epoch=68]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.153, epoch=69]
100%|██████████| 691/691 [04:34<00:00,  2.52it/s, avg_epoch_loss=0.155, epoch=70]
100%|██████████| 691/691 [04:34<00:00,  2.51it/s, avg_epoch_loss=0.156, epoch=71]
...


 best loss is updated to  0.1498965471982956 at 79


100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.153, epoch=80]
100%|██████████| 691/691 [04:36<00:00,  2.50it/s, avg_epoch_loss=0.148, epoch=81]
100%|██████████| 691/691 [04:36<00:00,  2.50it/s, avg_epoch_loss=0.15, epoch=82] 
100%|██████████| 691/691 [04:37<00:00,  2.49it/s, avg_epoch_loss=0.145, epoch=83]
100%|██████████| 691/691 [04:37<00:00,  2.49it/s, avg_epoch_loss=0.147, epoch=84]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.147, epoch=85]
100%|██████████| 691/691 [04:34<00:00,  2.51it/s, avg_epoch_loss=0.145, epoch=86]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.151, epoch=87]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.148, epoch=88]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.146, epoch=89]
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.15, epoch=90] 
100%|██████████| 691/691 [04:35<00:00,  2.51it/s, avg_epoch_loss=0.147, epoch=91]
...


 best loss is updated to  0.14829400181770325 at 99





## Loading a Pretrained Model

### Function for Loading Model State
To skip or shorten training, or simply to evaluate, you can start from a model that has already been trained. The function `load_pretrained_model` loads saved weights into an existing model instance.

- **Parameters**:
  - `model_path`: Path to the saved state dictionary file (for example, a `model.pth` file inside a training output folder).
  - `model`: The model instance into which the pretrained weights will be loaded.
  - `device`: The computing device (CPU or GPU) onto which the weights are mapped when loading, so that a checkpoint saved on a GPU can also be loaded on a CPU-only machine.

### Execution
The function reads the state dictionary from `model_path`. If the keys begin with `module.` (the prefix added when a model is saved while wrapped in `torch.nn.DataParallel`), the prefix is removed so that the keys match an unwrapped model. The state dictionary is then loaded into the model, and the same model instance is returned.

- **Model Compatibility**:
  - The architecture of the model receiving the weights must match the architecture that was saved. A mismatch raises an error during loading.

### Usage
Pass the checkpoint path and an initialized model to `load_pretrained_model`. Because the weights are loaded in place, the returned object is the same instance that was passed in.

In [11]:
def load_pretrained_model(model_path, model, device=device):
    state_dict = torch.load(model_path, map_location=device)

    # Check if the state_dict keys start with 'module.'
    if list(state_dict.keys())[0].startswith("module."):
        # Remove the 'module.' prefix
        state_dict = {k[7:]: v for k, v in state_dict.items()}

    model.load_state_dict(state_dict)
    return model
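The prefix handling in `load_pretrained_model` can be checked on a small stand-in model. The sketch below simulates a checkpoint saved from a wrapped model by prefixing every key with `module.`, strips the prefix, and loads the result into a fresh layer. The `nn.Linear` layer is only a placeholder for the CSDI model.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

net = nn.Linear(3, 2)  # small stand-in for the CSDI model

# Simulate a checkpoint saved from a wrapped model by prefixing every key.
wrapped_state = OrderedDict(
    (f"module.{k}", v.clone()) for k, v in net.state_dict().items()
)

# Strip the prefix only from keys that actually carry it.
prefix = "module."
cleaned_state = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v)
    for k, v in wrapped_state.items()
)

fresh = nn.Linear(3, 2)
fresh.load_state_dict(cleaned_state)  # raises if any key is missing or unexpected

print(list(cleaned_state.keys()))
```

Checking each key individually is slightly more defensive than inspecting only the first key, at no extra cost.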

In [12]:
# Path to the pre-trained model
model_path = "/projects/diffusion_bootcamp/models/time-seris/csdi/save/forecasting_electricity_20240730_010005/model.pth"
# Load pre-trained model
pretrained_model = load_pretrained_model(model_path, model)
pretrained_model.target_dim = target_dim

## Model Evaluation

### Evaluation Parameters Configuration
Prior to evaluating the model, several key parameters are established:
- **`test_loader`**: The data loader for the test dataset.
- **`nsample`**: The number of samples to generate per forecast during evaluation, set to 100.
- **`scaler`** and **`mean_scaler`**: The scaling values returned by `get_dataloader`. They are passed to `evaluate` so that the metrics account for the normalization applied to the data.

### Data Loader Configuration
- The `test_loader` created in the Data Loaders section is used unchanged; its ground-truth mask determines which entries are conditioned on and which are forecast.

### Model Evaluation
- **Evaluated Weights**: `load_pretrained_model` loads the checkpoint into `model` in place, so `model` and `pretrained_model` refer to the same instance. If the loading cell was run, the evaluation below uses the pretrained weights; if it was skipped, it uses the weights from the training run above.

### Execution
The model is evaluated with the specified number of samples and scaling parameters. The results are stored in the output folder created earlier (`foldername`):

In [None]:
# Set evaluation parameters
nsample = 100  # number of samples for evaluation

# Evaluate the model
evaluate(
    model,
    test_loader,
    nsample=nsample,
    scaler=scaler,
    mean_scaler=mean_scaler,
    foldername=foldername,
)
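Once samples are available, a point forecast and a prediction interval can be derived from them. The sketch below assumes that `scaler` and `mean_scaler` act as a per-feature standard deviation and mean, so that normalized values map back to original units as `x * scaler + mean_scaler`. The arrays are synthetic stand-ins, and the shapes returned by the notebook's utilities may differ.

```python
import numpy as np

rng = np.random.default_rng(0)

n_samples, horizon, n_features = 100, 24, 3   # hypothetical sizes
scaler = np.array([50.0, 120.0, 8.0])         # assumed per-feature std
mean_scaler = np.array([200.0, 900.0, 30.0])  # assumed per-feature mean

# Stand-in for model output in normalized units: (n_samples, horizon, n_features).
samples_norm = rng.standard_normal((n_samples, horizon, n_features))

# Map back to original units, then summarize across the sample axis.
samples = samples_norm * scaler + mean_scaler
median = np.median(samples, axis=0)                        # point forecast
lower, upper = np.quantile(samples, [0.05, 0.95], axis=0)  # 90% interval

print(median.shape, lower.shape, upper.shape)
```

Generating 100 samples, as configured by `nsample`, gives a reasonably stable estimate of the median and of the 5% and 95% quantiles.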

### References
- Tashiro, Yusuke, et al. "CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation." *Advances in Neural Information Processing Systems*. 2021. [GitHub Repository](https://github.com/ermongroup/CSDI)
