In [31]:
import torch
from models.diffusion_ts_model import DiffusionTransformer
from utils.visualisation import visualize_decomposition

In [42]:
# Load the trained DiffusionTransformer from its checkpoint and switch it to eval mode.
import inspect
import warnings

CHECKPOINT_PATH = "best_model.pt"
# The synthetic demo series has a single feature (last dim of size 1), so the model is
# rebuilt with one input feature regardless of what the checkpoint config says.
# NOTE(review): the saved repr of this cell shows `input_proj` with in_features=8 even though
# input_dim is forced to 1 — confirm whether input_proj depends on input_dim or the output is stale.
INPUT_DIM = 1

# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
config = checkpoint["config"]

# Parameters accepted by the constructor (excluding `self`).
sig = inspect.signature(DiffusionTransformer.__init__)
allowed_keys = set(sig.parameters.keys()) - {"self"}

# Keep only the config entries the constructor accepts; any other keys are dropped.
filtered_config = {k: v for k, v in config.items() if k in allowed_keys}

# Overriding input_dim silently would hide a mismatch with the trained weights; surface it.
saved_input_dim = config.get("input_dim")
if saved_input_dim is not None and saved_input_dim != INPUT_DIM:
    warnings.warn(
        f"Overriding checkpoint input_dim={saved_input_dim} with {INPUT_DIM}; "
        "load_state_dict may fail or the weights may not match the data."
    )
filtered_config["input_dim"] = INPUT_DIM

model = DiffusionTransformer(**filtered_config)

# Load the trained weights (strict by default: missing/unexpected keys raise).
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()


DiffusionTransformer(
  (input_proj): Linear(in_features=8, out_features=256, bias=True)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (time_mlp): Sequential(
    (0): TimeEmbedding()
    (1): Linear(in_features=256, out_features=1024, bias=True)
    (2): SiLU()
    (3): Linear(in_features=1024, out_features=256, bias=True)
  )
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inpla

In [52]:
# Synthetic sine wave with additive Gaussian noise, used to demonstrate the decomposition.
import math

seq_len = 512
PERIOD = 64       # samples per sine cycle
NOISE_STD = 0.1   # standard deviation of the additive Gaussian noise
SEED = 0          # fixed seed so the noisy signal is identical on Restart & Run All

t = torch.arange(seq_len).float().unsqueeze(0)  # shape: (1, seq_len)
# Exact pi: an approximate literal (e.g. 3.1415) makes the phase drift along the sequence.
clean_signal = torch.sin(2 * math.pi * t / PERIOD)
# A dedicated generator keeps the noise reproducible without touching torch's global RNG.
rng = torch.Generator().manual_seed(SEED)
noise = torch.randn(t.shape, generator=rng)
x = (clean_signal + NOISE_STD * noise).unsqueeze(-1)  # shape: (1, seq_len, 1)
x.shape

torch.Size([1, 512, 1])

In [53]:
# Run one forward pass at a fixed diffusion step and extract the decomposed components.
DIFFUSION_STEP = 10  # arbitrary step chosen for the demonstration

# One time step per series in the batch, so this also works for batch sizes > 1.
time_step = torch.full((x.shape[0],), DIFFUSION_STEP, dtype=torch.long)

# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    output = model(x, time_step)

# Per the author's notes, the output dict contains:
# - output['trend']: smooth trend component
# - output['seasonality']: periodic patterns
# - output['output']: full reconstruction
trend = output['trend'].detach().cpu().numpy()
seasonality = output['seasonality'].detach().cpu().numpy()
reconstructed = output['output'].detach().cpu().numpy()


In [54]:
# Visualize the decomposition of the first series in the batch and save the figure to disk.
x_nobatch = x.select(0, 0)  # drop the leading batch axis -> (seq_len, n_features)
visualize_decomposition(model, x_nobatch, save_path="decomposition.png")
x_nobatch.shape

Decomposition saved to decomposition.png


torch.Size([512, 1])