# Conditional Diffusion Training
Train a transformer-based conditional diffusion model for time series.

In [1]:
import torch
from model import ConditionalScoreNet
from data import generate_sine_sequence
from trainer import train_diffusion_model_conditional
from sde import VPSDE

# Fix the RNG seed up front so that Restart Kernel -> Run All is repeatable.
# torch.manual_seed seeds the default generator for CPU and CUDA devices.
# NOTE(review): if the helpers imported from data/trainer draw from `random`
# or NumPy instead of torch, those generators need seeding too - confirm.
SEED = 42
torch.manual_seed(SEED)

# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [2]:
# Experiment configuration: every tunable value used by the cells below.

# -- Dataset ---------------------------------------------------------------
num_samples = 1_000
input_dim = 1
history_len, predict_len = 150, 20
# Each generated sequence spans the history window plus the prediction window.
total_len = history_len + predict_len

# -- Optimisation ----------------------------------------------------------
batch_size = 128
num_epochs = 500
lr = 4e-4

# -- Diffusion / checkpointing ---------------------------------------------
num_diffusion_timesteps = 1_000
checkpoint_path = "checkpoints/score_model.pth"


In [3]:
# Generate sine sequences
data = generate_sine_sequence(num_samples, total_len, input_dim=input_dim)

# Fail fast if the generator's output layout ever drifts: the training cell
# passes history_len / predict_len alongside this data, so the sequence axis
# is expected to be exactly total_len = history_len + predict_len long.
expected_shape = (num_samples, total_len, input_dim)
assert tuple(data.shape) == expected_shape, (
    f"Unexpected data shape {tuple(data.shape)}; expected {expected_shape}"
)
print("Data shape:", data.shape)


Data shape: torch.Size([1000, 170, 1])


In [4]:
# Build the score network from the configured window sizes, then place it on
# the selected device.
model = ConditionalScoreNet(
    input_dim=input_dim,
    history_len=history_len,
    predict_len=predict_len,
)
model = model.to(device)

# Adam over all of the network's parameters, using the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)


In [5]:
# Collect every argument of the training run in one mapping so the full
# configuration can be inspected before the (long-running) call is launched.
training_kwargs = dict(
    data_x=data,
    score_net=model,
    optimizer=optimizer,
    num_diffusion_timesteps=num_diffusion_timesteps,
    batch_size=batch_size,
    num_epochs=num_epochs,
    history_len=history_len,
    predict_len=predict_len,
    device=device,
    checkpoint_path=checkpoint_path,
    # NOTE(review): presumably a periodic checkpoint interval in epochs;
    # verify against trainer.train_diffusion_model_conditional.
    save_every=100,
)
train_diffusion_model_conditional(**training_kwargs)


Training Progress:   0%|          | 1/500 [00:04<35:55,  4.32s/it, loss=40.4342]

Best model checkpoint saved at epoch 1 to checkpoints/score_model.pth


Training Progress:   0%|          | 2/500 [00:07<30:34,  3.68s/it, loss=12.3428]

Best model checkpoint saved at epoch 2 to checkpoints/score_model.pth


Training Progress:   1%|          | 3/500 [00:10<28:41,  3.46s/it, loss=7.1490] 

Best model checkpoint saved at epoch 3 to checkpoints/score_model.pth


Training Progress:   1%|          | 4/500 [00:13<27:45,  3.36s/it, loss=5.9149]

Best model checkpoint saved at epoch 4 to checkpoints/score_model.pth


Training Progress:   1%|          | 5/500 [00:17<27:19,  3.31s/it, loss=5.1089]

Best model checkpoint saved at epoch 5 to checkpoints/score_model.pth


Training Progress:   1%|          | 6/500 [00:20<27:03,  3.29s/it, loss=4.9036]

Best model checkpoint saved at epoch 6 to checkpoints/score_model.pth


Training Progress:   1%|▏         | 7/500 [00:23<26:54,  3.27s/it, loss=4.5751]

Best model checkpoint saved at epoch 7 to checkpoints/score_model.pth


Training Progress:   2%|▏         | 9/500 [00:30<26:32,  3.24s/it, loss=4.3948]

Best model checkpoint saved at epoch 9 to checkpoints/score_model.pth


Training Progress:   3%|▎         | 13/500 [00:42<25:48,  3.18s/it, loss=4.1506]

Best model checkpoint saved at epoch 13 to checkpoints/score_model.pth


Training Progress:   3%|▎         | 17/500 [00:55<25:24,  3.16s/it, loss=4.1044]

Best model checkpoint saved at epoch 17 to checkpoints/score_model.pth


Training Progress:   4%|▎         | 18/500 [00:58<25:27,  3.17s/it, loss=3.7038]

Best model checkpoint saved at epoch 18 to checkpoints/score_model.pth


Training Progress:   4%|▍         | 21/500 [01:16<39:21,  4.93s/it, loss=3.4524]

Best model checkpoint saved at epoch 21 to checkpoints/score_model.pth


Training Progress:   5%|▍         | 23/500 [01:22<32:11,  4.05s/it, loss=3.2846]

Best model checkpoint saved at epoch 23 to checkpoints/score_model.pth


Training Progress:   5%|▍         | 24/500 [01:25<30:15,  3.81s/it, loss=3.2760]

Best model checkpoint saved at epoch 24 to checkpoints/score_model.pth


Training Progress:   5%|▌         | 25/500 [01:29<28:52,  3.65s/it, loss=3.2690]

Best model checkpoint saved at epoch 25 to checkpoints/score_model.pth


Training Progress:   5%|▌         | 26/500 [01:32<27:53,  3.53s/it, loss=3.1535]

Best model checkpoint saved at epoch 26 to checkpoints/score_model.pth


Training Progress:   6%|▌         | 30/500 [01:45<25:39,  3.28s/it, loss=3.0028]

Best model checkpoint saved at epoch 30 to checkpoints/score_model.pth


Training Progress:   6%|▌         | 31/500 [01:48<25:32,  3.27s/it, loss=2.9680]

Best model checkpoint saved at epoch 31 to checkpoints/score_model.pth


Training Progress:   7%|▋         | 34/500 [01:57<25:05,  3.23s/it, loss=2.8887]

Best model checkpoint saved at epoch 34 to checkpoints/score_model.pth


Training Progress:   8%|▊         | 38/500 [02:10<24:43,  3.21s/it, loss=2.7633]

Best model checkpoint saved at epoch 38 to checkpoints/score_model.pth


Training Progress:   8%|▊         | 39/500 [02:16<24:32,  3.19s/it, loss=2.6115]

Best model checkpoint saved at epoch 40 to checkpoints/score_model.pth


Training Progress:  10%|▉         | 49/500 [02:48<25:46,  3.43s/it, loss=2.8346]


KeyboardInterrupt: 