# Conditional Diffusion Training
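Train a transformer-based conditional diffusion model for time series. Given the first `history_len` steps of a sequence, the model learns to generate the following `predict_len` steps. Training data are synthetic sine sequences.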

In [1]:
import random
from pathlib import Path

import numpy as np
import torch

from model import ConditionalScoreNet
from data import generate_sine_sequence
from trainer import train_diffusion_model_conditional
from sde import VPSDE

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


Using device: cuda


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Hyperparameters
num_samples = 5000                       # number of synthetic sine sequences
history_len = 150                        # leading steps of each sequence used as history
predict_len = 50                         # steps following the history that the model predicts
total_len = history_len + predict_len    # full length of each generated sequence
input_dim = 1                            # channels per time step
batch_size = 128
num_epochs = 500
lr = 1e-4                                # Adam learning rate
num_diffusion_timesteps = 1000
checkpoint_path = "checkpoints/score_model.pth"

# Seed every RNG that data generation, weight initialisation and diffusion
# noise may draw from, so Restart & Run All reproduces the same data, weights
# and loss curve. torch.manual_seed seeds the CPU and all CUDA devices.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)


In [3]:
# Generate the synthetic training set: one sine sequence per sample, long
# enough to cover both the history window and the prediction window.
data = generate_sine_sequence(num_samples, total_len, input_dim=input_dim)
# Later cells assume the (num_samples, total_len, input_dim) layout shown in
# this cell's saved output; fail fast here rather than deep inside training.
assert data.shape == (num_samples, total_len, input_dim), (
    f"unexpected data shape {tuple(data.shape)}; "
    f"expected {(num_samples, total_len, input_dim)}"
)
print("Data shape:", data.shape)


Data shape: torch.Size([5000, 200, 1])


In [4]:
# Build the conditional score network for the configured history/prediction
# windows, place it on the selected device, and attach an Adam optimizer.
model = ConditionalScoreNet(
    input_dim=input_dim,
    history_len=history_len,
    predict_len=predict_len,
)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)


In [5]:
# Make sure the checkpoint directory exists before training starts, so the
# first checkpoint write cannot fail after a long stretch of training on a
# fresh checkout.
Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)

# Epochs between checkpoint writes; every write goes to the same
# checkpoint_path, so only the most recent checkpoint is kept.
checkpoint_every = 100

train_diffusion_model_conditional(
    data_x=data,
    score_net=model,
    optimizer=optimizer,
    num_diffusion_timesteps=num_diffusion_timesteps,
    batch_size=batch_size,
    num_epochs=num_epochs,
    history_len=history_len,
    predict_len=predict_len,
    device=device,
    checkpoint_path=checkpoint_path,
    save_every=checkpoint_every,
)


Training Progress:  20%|██        | 100/500 [28:30<2:01:18, 18.20s/it, loss=9.6064]

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  40%|████      | 200/500 [58:53<1:30:51, 18.17s/it, loss=10.2140]

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  60%|██████    | 300/500 [1:29:19<1:01:00, 18.30s/it, loss=9.7592] 

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  80%|████████  | 400/500 [1:58:22<30:22, 18.23s/it, loss=9.9139]   

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  81%|████████  | 403/500 [1:59:25<28:44, 17.78s/it, loss=9.3811]


KeyboardInterrupt: 