In [1]:
import sys
import os

# Make the project root (the parent of the notebook's directory) importable so
# that `models` and `utils` below resolve to this project's packages.
# NOTE(review): assumes the kernel's working directory is the notebook's folder.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

# Prepend rather than append so generic names such as `utils` are not shadowed
# by same-named installed packages, and skip the insert when the entry is
# already present so re-running this cell does not grow sys.path.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Importing libraries

In [2]:
import inspect
import math

import torch

from models.diffusion_ts_model import DiffusionTransformer
from utils.visualisation import visualize_decomposition

Load the model from a checkpoint.

If no checkpoint is available, a randomly initialised model is used instead, so that the usage of the functions can still be demonstrated.


In [6]:
# Load the model from a checkpoint if one exists; otherwise fall back to a
# randomly initialised model so the rest of the notebook still runs.
CHECKPOINT_PATH = "../Sample_Genderations/exponential_decay/best_model.pt"

try:
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
except FileNotFoundError:
    print("Checkpoint file not found. Please check the path.")
    checkpoint = None  # explicit sentinel: avoids a NameError / stale kernel state below

config = checkpoint["config"] if checkpoint is not None else {}

# Keep only the config entries that DiffusionTransformer.__init__ accepts.
sig = inspect.signature(DiffusionTransformer.__init__)
allowed_keys = set(sig.parameters.keys()) - {"self"}
filtered_config = {k: v for k, v in config.items() if k in allowed_keys}

# The demo data below is univariate, so force a single input channel.
filtered_config["input_dim"] = 1

# NOTE(review): without a checkpoint this relies on every other constructor
# argument having a default value — confirm against DiffusionTransformer.
model = DiffusionTransformer(**filtered_config)

# Load the weights, if a checkpoint was found.
if checkpoint is not None:
    try:
        # strict=False tolerates missing/unexpected keys and reports them.
        incompatible = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    except RuntimeError as e:
        # Raised even with strict=False, e.g. on tensor shape mismatches;
        # the model then keeps its random initialisation.
        print(f"Incompatible model weights, keeping random initialisation: {e}")
    else:
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                "Weights loaded with strict=False; "
                f"missing keys: {incompatible.missing_keys}, "
                f"unexpected keys: {incompatible.unexpected_keys}"
            )

# Inference only: switch layers such as dropout / batch norm to eval behaviour.
model.eval()



Checkpoint file not found. Please check the path.


In [7]:
# Synthetic demo input: a sine wave with additive Gaussian noise.
seq_len = 400
period = 64       # samples per sine cycle
noise_std = 0.1   # standard deviation of the additive Gaussian noise

# Seed so the noise (and everything derived from it) is reproducible.
torch.manual_seed(0)

t = torch.arange(seq_len).float().unsqueeze(0)  # shape: (1, seq_len)
x = torch.sin(2 * math.pi * t / period) + noise_std * torch.randn_like(t)
# Add a trailing feature axis -> (1, seq_len, 1); presumably
# (batch, seq_len, input_dim) as expected by the model — confirm.
x = x.unsqueeze(-1)
x.shape

torch.Size([1, 400, 1])

In [8]:
# Run one forward pass of the model at a fixed diffusion step.
diffusion_step = 10
# One step per batch element (equals torch.tensor([10]) for a batch of 1).
# NOTE(review): assumes the model expects a (batch,) long tensor — confirm.
time_step = torch.full((x.shape[0],), diffusion_step, dtype=torch.long)

# Gradients are not needed to inspect the outputs, so skip building the
# autograd graph (saves memory and time).
with torch.no_grad():
    output = model(x, time_step)

# The output dict contains:
# - output['trend']: smooth trend component
# - output['seasonality']: periodic patterns
# - output['output']: full reconstruction
trend = output['trend'].detach().cpu().numpy()
seasonality = output['seasonality'].detach().cpu().numpy()
reconstructed = output['output'].detach().cpu().numpy()


In [15]:
# Plot the decomposition of the first series in the batch and save it to disk.
x_nobatch = x.select(0, 0)  # drop the leading batch axis (same as x[0])
visualize_decomposition(model, x_nobatch, save_path="decomposition.png")
x_nobatch.shape

Decomposition saved to decomposition.png


torch.Size([400, 1])