# Conditional Diffusion Training
Train a transformer-based conditional diffusion model for time series.

In [1]:
import torch
from model import ConditionalScoreNet
from data import generate_sine_sequence
from trainer import train_diffusion_model_conditional
from sde import VPSDE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters
num_samples = 10000
history_len = 50
predict_len = 1
total_len = history_len + predict_len
input_dim = 1
batch_size = 128
num_epochs = 500
lr = 1e-4
num_diffusion_timesteps = 1000
checkpoint_path = "checkpoints/score_model.pth"


In [3]:
# Generate sine sequences
data = generate_sine_sequence(num_samples, total_len, input_dim=input_dim)
print("Data shape:", data.shape)


Data shape: torch.Size([10000, 51, 1])


In [4]:
model = ConditionalScoreNet(input_dim=input_dim,
                            history_len=history_len,
                            predict_len=predict_len).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


In [5]:
train_diffusion_model_conditional(
    data_x=data,
    score_net=model,
    optimizer=optimizer,
    num_diffusion_timesteps=num_diffusion_timesteps,
    batch_size=batch_size,
    num_epochs=num_epochs,
    history_len=history_len,
    predict_len=predict_len,
    device=device,
    checkpoint_path=checkpoint_path,
    save_every=100
)


Training Progress:  20%|██        | 100/500 [22:08<1:28:11, 13.23s/it, loss=0.1816]

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  40%|████      | 200/500 [41:46<51:05, 10.22s/it, loss=0.1828]  

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  60%|██████    | 300/500 [1:00:00<38:50, 11.65s/it, loss=0.1889]

Checkpoint saved to checkpoints/score_model.pth


Training Progress:  80%|████████  | 400/500 [1:18:13<17:25, 10.45s/it, loss=0.1833]

Checkpoint saved to checkpoints/score_model.pth


Training Progress: 100%|██████████| 500/500 [1:35:25<00:00, 11.45s/it, loss=0.1874]

Checkpoint saved to checkpoints/score_model.pth





[np.float64(0.5304012581517424),
 np.float64(0.20387378419879115),
 np.float64(0.17744316335154486),
 np.float64(0.18067738550561893),
 np.float64(0.18830542643613454),
 np.float64(0.17767878093674214),
 np.float64(0.18899238713179964),
 np.float64(0.18628960017916524),
 np.float64(0.19182178500709654),
 np.float64(0.17895029202292237),
 np.float64(0.1837109212256685),
 np.float64(0.17817866792784462),
 np.float64(0.18686180788127682),
 np.float64(0.19356732570295093),
 np.float64(0.18961992803253705),
 np.float64(0.19090588447413867),
 np.float64(0.19661833732565748),
 np.float64(0.17620808923546272),
 np.float64(0.1980922374921509),
 np.float64(0.18356646257865278),
 np.float64(0.18782928373806085),
 np.float64(0.1818169222413739),
 np.float64(0.18542384742936002),
 np.float64(0.18558674319824087),
 np.float64(0.19137655622974226),
 np.float64(0.18801050454000884),
 np.float64(0.18469172867038583),
 np.float64(0.18786422285852553),
 np.float64(0.19883738656209993),
 np.float64(0.1849