# Minimal DDIM (CIFAR-10) with Adam

In [1]:
# Imports
from pathlib import Path
import random
import sys, os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, utils

sys.path.append(os.path.abspath(os.path.join('..')))

# Import implementation
from implementations.stf_smoothing import (
    DiffusionSchedule, UNet, sample_loop, stf_train_ddim)

# Reproducibility: set SEED to an int (e.g. 0) to seed every RNG before the model
# is built; leave it as None for a fresh random run (the previous default).
SEED = None
# Deterministic cuDNN kernels give exact re-runs but can make training noticeably slower.
DETERMINISTIC = False


def seed_everything(seed, deterministic=False):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed applied to every RNG.
        deterministic: if True, also force deterministic cuDNN kernels and disable
            cuDNN autotuning (`benchmark`), trading speed for exact reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


if SEED is not None:
    seed_everything(SEED, deterministic=DETERMINISTIC)

print('torch:', torch.__version__)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device ->', device)

torch: 2.9.0+cu126
device -> cuda
torch: 2.9.0+cu126
device -> cuda


In [2]:
# CIFAR-10 data loaders
def get_dataloaders(batch_size=128, img_size=32, num_workers=4, root='./data'):
    """Build CIFAR-10 train/test DataLoaders with images scaled to [-1, 1].

    Args:
        batch_size: samples per batch for both loaders.
        img_size: output side length; images are resized to (img_size, img_size).
        num_workers: DataLoader worker processes (0 = load in the main process).
        root: directory where torchvision stores / downloads CIFAR-10.

    Returns:
        (train_loader, test_loader): the train loader shuffles, the test loader does not.
    """
    transform = transforms.Compose([
        # An (h, w) tuple forces a square output; an int would only match the shorter edge.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),                                   # -> float tensor in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])
    train_ds = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    test_ds = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    # Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it
    # just triggers a DataLoader warning, so tie it to the same check used for `device`.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
        pin_memory=pin,
        # The train loader is iterated once per epoch; keep its workers alive between epochs.
        persistent_workers=num_workers > 0,
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
        pin_memory=pin,
    )
    return train_loader, test_loader

train_loader, test_loader = get_dataloaders(batch_size=128, img_size=32, num_workers=2)

In [3]:
# Helper to show saved sample grid
def show_image(path, figsize=(6,6)):
    """Display an image file (e.g. a saved sample grid) inline, without axes.

    Args:
        path: path to an image file readable by PIL.
        figsize: matplotlib figure size in inches, as (width, height).
    """
    _, ax = plt.subplots(figsize=figsize)
    # Open with a context manager so the file handle is released; imshow converts
    # the PIL image to an array right away, so closing the file afterwards is safe.
    with Image.open(path) as img:
        ax.imshow(img)
    ax.axis('off')
    plt.show()


In [None]:
# WARNING: Training DDIM on CIFAR-10 is moderately heavy. Reduce epochs or timesteps for quick tests.
# --- Training configuration ---
# Length of the diffusion noise schedule. The model is trained against this schedule, and the
# cells below reuse the same `schedule` for sampling/diagnostics, so changing it means retraining;
# it is not an inference-only knob.
timesteps = 1000
EPOCHS = 100
LR = 2e-4
SAVE_DIR = './runs_ddim_test_stf_smoothing'

model = UNet(in_ch=3, base_ch=64, time_emb_dim=64)
schedule = DiffusionSchedule(timesteps=timesteps, device=device)
print(f'Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M')

stf_train_ddim(
    model, schedule, train_loader, device,
    epochs=EPOCHS,
    lr=LR,
    save_dir=SAVE_DIR,
)

Model parameters: 8.22M
Compiling model...


  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):
Epoch 1/100: 100%|██████████| 391/391 [00:55<00:00,  7.06it/s, loss=0.0889, lr=0.00020]


End epoch 1, avg loss 0.1480


DDIM Sampling: 100%|██████████| 100/100 [00:02<00:00, 38.78it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_1.png


  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=False):
Epoch 2/100: 100%|██████████| 391/391 [00:27<00:00, 14.03it/s, loss=0.0859, lr=0.00020]


End epoch 2, avg loss 0.0810


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.85it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_2.png


  with torch.cuda.amp.autocast(enabled=False):
Epoch 3/100: 100%|██████████| 391/391 [00:27<00:00, 14.07it/s, loss=0.0731, lr=0.00020]


End epoch 3, avg loss 0.0733


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.26it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_3.png


Epoch 4/100: 100%|██████████| 391/391 [00:27<00:00, 14.03it/s, loss=0.0556, lr=0.00020]


End epoch 4, avg loss 0.0689


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.61it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_4.png


Epoch 5/100: 100%|██████████| 391/391 [00:27<00:00, 14.04it/s, loss=0.0629, lr=0.00020]


End epoch 5, avg loss 0.0662


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.90it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_5.png


Epoch 6/100: 100%|██████████| 391/391 [00:28<00:00, 13.91it/s, loss=0.0656, lr=0.00020]


End epoch 6, avg loss 0.0646


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 162.34it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_6.png


Epoch 7/100: 100%|██████████| 391/391 [00:28<00:00, 13.93it/s, loss=0.0701, lr=0.00020]


End epoch 7, avg loss 0.0643


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.31it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_7.png


Epoch 8/100: 100%|██████████| 391/391 [00:28<00:00, 13.93it/s, loss=0.0662, lr=0.00020]


End epoch 8, avg loss 0.0633


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.10it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_8.png


Epoch 9/100: 100%|██████████| 391/391 [00:28<00:00, 13.78it/s, loss=0.0551, lr=0.00020]


End epoch 9, avg loss 0.0624


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 170.59it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_9.png


Epoch 10/100: 100%|██████████| 391/391 [00:28<00:00, 13.95it/s, loss=0.0640, lr=0.00020]


End epoch 10, avg loss 0.0611


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.22it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_10.png


Epoch 11/100: 100%|██████████| 391/391 [00:28<00:00, 13.73it/s, loss=0.0781, lr=0.00019]


End epoch 11, avg loss 0.0615


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.32it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_11.png


Epoch 12/100: 100%|██████████| 391/391 [00:28<00:00, 13.73it/s, loss=0.0610, lr=0.00019]


End epoch 12, avg loss 0.0616


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.70it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_12.png


Epoch 13/100: 100%|██████████| 391/391 [00:28<00:00, 13.70it/s, loss=0.0784, lr=0.00019]


End epoch 13, avg loss 0.0604


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.88it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_13.png


Epoch 14/100: 100%|██████████| 391/391 [00:28<00:00, 13.69it/s, loss=0.0582, lr=0.00019]


End epoch 14, avg loss 0.0601


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.68it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_14.png


Epoch 15/100: 100%|██████████| 391/391 [00:28<00:00, 13.84it/s, loss=0.0551, lr=0.00019]


End epoch 15, avg loss 0.0603


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.67it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_15.png


Epoch 16/100: 100%|██████████| 391/391 [00:28<00:00, 13.84it/s, loss=0.0634, lr=0.00019]


End epoch 16, avg loss 0.0604


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 163.94it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_16.png


Epoch 17/100: 100%|██████████| 391/391 [00:28<00:00, 13.77it/s, loss=0.0545, lr=0.00019]


End epoch 17, avg loss 0.0597


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.39it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_17.png


Epoch 18/100: 100%|██████████| 391/391 [00:28<00:00, 13.80it/s, loss=0.0669, lr=0.00018]


End epoch 18, avg loss 0.0595


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.59it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_18.png


Epoch 19/100: 100%|██████████| 391/391 [00:28<00:00, 13.84it/s, loss=0.0637, lr=0.00018]


End epoch 19, avg loss 0.0591


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.73it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_19.png


Epoch 20/100: 100%|██████████| 391/391 [00:28<00:00, 13.83it/s, loss=0.0527, lr=0.00018]


End epoch 20, avg loss 0.0596


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.41it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_20.png


Epoch 21/100: 100%|██████████| 391/391 [00:28<00:00, 13.73it/s, loss=0.0563, lr=0.00018]


End epoch 21, avg loss 0.0596


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.89it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_21.png


Epoch 22/100: 100%|██████████| 391/391 [00:28<00:00, 13.66it/s, loss=0.0645, lr=0.00018]


End epoch 22, avg loss 0.0587


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 161.45it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_22.png


Epoch 23/100: 100%|██████████| 391/391 [00:28<00:00, 13.68it/s, loss=0.0786, lr=0.00018]


End epoch 23, avg loss 0.0587


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 170.88it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_23.png


Epoch 24/100: 100%|██████████| 391/391 [00:28<00:00, 13.85it/s, loss=0.0511, lr=0.00017]


End epoch 24, avg loss 0.0585


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 172.24it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_24.png


Epoch 25/100: 100%|██████████| 391/391 [00:28<00:00, 13.86it/s, loss=0.0579, lr=0.00017]


End epoch 25, avg loss 0.0584


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.58it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_25.png


Epoch 26/100: 100%|██████████| 391/391 [00:28<00:00, 13.82it/s, loss=0.0712, lr=0.00017]


End epoch 26, avg loss 0.0584


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 173.61it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_26.png


Epoch 27/100: 100%|██████████| 391/391 [00:28<00:00, 13.87it/s, loss=0.0599, lr=0.00017]


End epoch 27, avg loss 0.0586


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 173.93it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_27.png


Epoch 28/100: 100%|██████████| 391/391 [00:28<00:00, 13.83it/s, loss=0.0547, lr=0.00016]


End epoch 28, avg loss 0.0584


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 162.79it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_28.png


Epoch 29/100: 100%|██████████| 391/391 [00:28<00:00, 13.82it/s, loss=0.0498, lr=0.00016]


End epoch 29, avg loss 0.0576


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.96it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_29.png


Epoch 30/100: 100%|██████████| 391/391 [00:28<00:00, 13.80it/s, loss=0.0538, lr=0.00016]


End epoch 30, avg loss 0.0581


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.15it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_30.png


Epoch 31/100: 100%|██████████| 391/391 [00:28<00:00, 13.70it/s, loss=0.0621, lr=0.00016]


End epoch 31, avg loss 0.0575


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.68it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_31.png


Epoch 32/100: 100%|██████████| 391/391 [00:28<00:00, 13.51it/s, loss=0.0623, lr=0.00015]


End epoch 32, avg loss 0.0576


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.90it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_32.png


Epoch 33/100: 100%|██████████| 391/391 [00:28<00:00, 13.57it/s, loss=0.0614, lr=0.00015]


End epoch 33, avg loss 0.0580


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 172.63it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_33.png


Epoch 34/100: 100%|██████████| 391/391 [00:28<00:00, 13.68it/s, loss=0.0584, lr=0.00015]


End epoch 34, avg loss 0.0578


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.76it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_34.png


Epoch 35/100: 100%|██████████| 391/391 [00:28<00:00, 13.61it/s, loss=0.0553, lr=0.00015]


End epoch 35, avg loss 0.0577


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 172.75it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_35.png


Epoch 36/100: 100%|██████████| 391/391 [00:28<00:00, 13.67it/s, loss=0.0756, lr=0.00014]


End epoch 36, avg loss 0.0584


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 172.20it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_36.png


Epoch 37/100: 100%|██████████| 391/391 [00:28<00:00, 13.61it/s, loss=0.0683, lr=0.00014]


End epoch 37, avg loss 0.0576


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 172.61it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_37.png


Epoch 38/100: 100%|██████████| 391/391 [00:28<00:00, 13.56it/s, loss=0.0573, lr=0.00014]


End epoch 38, avg loss 0.0570


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.79it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_38.png


Epoch 39/100: 100%|██████████| 391/391 [00:29<00:00, 13.39it/s, loss=0.0604, lr=0.00013]


End epoch 39, avg loss 0.0573


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 171.90it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_39.png


Epoch 40/100: 100%|██████████| 391/391 [00:28<00:00, 13.52it/s, loss=0.0668, lr=0.00013]


End epoch 40, avg loss 0.0576


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.81it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_40.png


Epoch 41/100: 100%|██████████| 391/391 [00:29<00:00, 13.47it/s, loss=0.0565, lr=0.00013]


End epoch 41, avg loss 0.0575


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.92it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_41.png


Epoch 42/100: 100%|██████████| 391/391 [00:28<00:00, 13.49it/s, loss=0.0489, lr=0.00013]


End epoch 42, avg loss 0.0570


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 171.61it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_42.png


Epoch 43/100: 100%|██████████| 391/391 [00:28<00:00, 13.59it/s, loss=0.0659, lr=0.00012]


End epoch 43, avg loss 0.0575


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 172.71it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_43.png


Epoch 44/100: 100%|██████████| 391/391 [00:28<00:00, 13.57it/s, loss=0.0568, lr=0.00012]


End epoch 44, avg loss 0.0574


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.06it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_44.png


Epoch 45/100: 100%|██████████| 391/391 [00:29<00:00, 13.43it/s, loss=0.0440, lr=0.00012]


End epoch 45, avg loss 0.0566


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.87it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_45.png


Epoch 46/100: 100%|██████████| 391/391 [00:29<00:00, 13.44it/s, loss=0.0517, lr=0.00011]


End epoch 46, avg loss 0.0576


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.31it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_46.png


Epoch 47/100: 100%|██████████| 391/391 [00:29<00:00, 13.47it/s, loss=0.0602, lr=0.00011]


End epoch 47, avg loss 0.0571


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 163.00it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_47.png


Epoch 48/100: 100%|██████████| 391/391 [00:29<00:00, 13.34it/s, loss=0.0647, lr=0.00011]


End epoch 48, avg loss 0.0567


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.40it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_48.png


Epoch 49/100: 100%|██████████| 391/391 [00:29<00:00, 13.37it/s, loss=0.0507, lr=0.00010]


End epoch 49, avg loss 0.0565


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.27it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_49.png


Epoch 50/100: 100%|██████████| 391/391 [00:29<00:00, 13.37it/s, loss=0.0692, lr=0.00010]


End epoch 50, avg loss 0.0567


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.32it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_50.png


Epoch 51/100: 100%|██████████| 391/391 [00:28<00:00, 13.53it/s, loss=0.0522, lr=0.00010]


End epoch 51, avg loss 0.0570


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.80it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_51.png


Epoch 52/100: 100%|██████████| 391/391 [00:29<00:00, 13.45it/s, loss=0.0502, lr=0.00009]


End epoch 52, avg loss 0.0564


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.59it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_52.png


Epoch 53/100: 100%|██████████| 391/391 [00:29<00:00, 13.46it/s, loss=0.0558, lr=0.00009]


End epoch 53, avg loss 0.0568


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.36it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_53.png


Epoch 54/100: 100%|██████████| 391/391 [00:28<00:00, 13.56it/s, loss=0.0744, lr=0.00009]


End epoch 54, avg loss 0.0569


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.45it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_54.png


Epoch 55/100: 100%|██████████| 391/391 [00:28<00:00, 13.49it/s, loss=0.0585, lr=0.00008]


End epoch 55, avg loss 0.0566


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.23it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_55.png


Epoch 56/100: 100%|██████████| 391/391 [00:29<00:00, 13.41it/s, loss=0.0560, lr=0.00008]


End epoch 56, avg loss 0.0558


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.22it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_56.png


Epoch 57/100: 100%|██████████| 391/391 [00:28<00:00, 13.53it/s, loss=0.0574, lr=0.00008]


End epoch 57, avg loss 0.0561


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.66it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_57.png


Epoch 58/100: 100%|██████████| 391/391 [00:28<00:00, 13.59it/s, loss=0.0602, lr=0.00008]


End epoch 58, avg loss 0.0559


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.57it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_58.png


Epoch 59/100: 100%|██████████| 391/391 [00:28<00:00, 13.65it/s, loss=0.0636, lr=0.00007]


End epoch 59, avg loss 0.0561


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 163.75it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_59.png


Epoch 60/100: 100%|██████████| 391/391 [00:28<00:00, 13.57it/s, loss=0.0422, lr=0.00007]


End epoch 60, avg loss 0.0562


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.81it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_60.png


Epoch 61/100: 100%|██████████| 391/391 [00:28<00:00, 13.60it/s, loss=0.0451, lr=0.00007]


End epoch 61, avg loss 0.0562


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.97it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_61.png


Epoch 62/100: 100%|██████████| 391/391 [00:28<00:00, 13.56it/s, loss=0.0547, lr=0.00006]


End epoch 62, avg loss 0.0557


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 170.17it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_62.png


Epoch 63/100: 100%|██████████| 391/391 [00:28<00:00, 13.63it/s, loss=0.0469, lr=0.00006]


End epoch 63, avg loss 0.0560


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.44it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_63.png


Epoch 64/100: 100%|██████████| 391/391 [00:28<00:00, 13.76it/s, loss=0.0566, lr=0.00006]


End epoch 64, avg loss 0.0565


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.55it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_64.png


Epoch 65/100: 100%|██████████| 391/391 [00:28<00:00, 13.78it/s, loss=0.0484, lr=0.00005]


End epoch 65, avg loss 0.0566


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 170.27it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_65.png


Epoch 66/100: 100%|██████████| 391/391 [00:28<00:00, 13.77it/s, loss=0.0637, lr=0.00005]


End epoch 66, avg loss 0.0560


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.31it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_66.png


Epoch 67/100: 100%|██████████| 391/391 [00:28<00:00, 13.67it/s, loss=0.0491, lr=0.00005]


End epoch 67, avg loss 0.0560


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.93it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_67.png


Epoch 68/100: 100%|██████████| 391/391 [00:28<00:00, 13.64it/s, loss=0.0577, lr=0.00005]


End epoch 68, avg loss 0.0558


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.33it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_68.png


Epoch 69/100: 100%|██████████| 391/391 [00:28<00:00, 13.63it/s, loss=0.0632, lr=0.00004]


End epoch 69, avg loss 0.0560


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.99it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_69.png


Epoch 70/100: 100%|██████████| 391/391 [00:28<00:00, 13.63it/s, loss=0.0581, lr=0.00004]


End epoch 70, avg loss 0.0567


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.02it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_70.png


Epoch 71/100: 100%|██████████| 391/391 [00:28<00:00, 13.54it/s, loss=0.0487, lr=0.00004]


End epoch 71, avg loss 0.0564


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.47it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_71.png


Epoch 72/100: 100%|██████████| 391/391 [00:28<00:00, 13.53it/s, loss=0.0650, lr=0.00004]


End epoch 72, avg loss 0.0558


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.62it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_72.png


Epoch 73/100: 100%|██████████| 391/391 [00:28<00:00, 13.61it/s, loss=0.0631, lr=0.00003]


End epoch 73, avg loss 0.0557


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.24it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_73.png


Epoch 74/100: 100%|██████████| 391/391 [00:28<00:00, 13.56it/s, loss=0.0467, lr=0.00003]


End epoch 74, avg loss 0.0560


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.81it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_74.png


Epoch 75/100: 100%|██████████| 391/391 [00:28<00:00, 13.60it/s, loss=0.0505, lr=0.00003]


End epoch 75, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.12it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_75.png


Epoch 76/100: 100%|██████████| 391/391 [00:28<00:00, 13.59it/s, loss=0.0554, lr=0.00003]


End epoch 76, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.21it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_76.png


Epoch 77/100: 100%|██████████| 391/391 [00:28<00:00, 13.71it/s, loss=0.0558, lr=0.00003]


End epoch 77, avg loss 0.0561


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.67it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_77.png


Epoch 78/100: 100%|██████████| 391/391 [00:28<00:00, 13.77it/s, loss=0.0605, lr=0.00002]


End epoch 78, avg loss 0.0551


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.54it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_78.png


Epoch 79/100: 100%|██████████| 391/391 [00:28<00:00, 13.77it/s, loss=0.0489, lr=0.00002]


End epoch 79, avg loss 0.0561


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.67it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_79.png


Epoch 80/100: 100%|██████████| 391/391 [00:28<00:00, 13.79it/s, loss=0.0464, lr=0.00002]


End epoch 80, avg loss 0.0558


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 168.59it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_80.png


Epoch 81/100: 100%|██████████| 391/391 [00:28<00:00, 13.69it/s, loss=0.0661, lr=0.00002]


End epoch 81, avg loss 0.0556


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.61it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_81.png


Epoch 82/100: 100%|██████████| 391/391 [00:28<00:00, 13.60it/s, loss=0.0413, lr=0.00002]


End epoch 82, avg loss 0.0555


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.12it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_82.png


Epoch 83/100: 100%|██████████| 391/391 [00:28<00:00, 13.59it/s, loss=0.0645, lr=0.00001]


End epoch 83, avg loss 0.0557


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.13it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_83.png


Epoch 84/100: 100%|██████████| 391/391 [00:28<00:00, 13.59it/s, loss=0.0612, lr=0.00001]


End epoch 84, avg loss 0.0557


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 161.89it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_84.png


Epoch 85/100: 100%|██████████| 391/391 [00:28<00:00, 13.53it/s, loss=0.0573, lr=0.00001]


End epoch 85, avg loss 0.0557


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.43it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_85.png


Epoch 86/100: 100%|██████████| 391/391 [00:28<00:00, 13.66it/s, loss=0.0652, lr=0.00001]


End epoch 86, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.33it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_86.png


Epoch 87/100: 100%|██████████| 391/391 [00:28<00:00, 13.59it/s, loss=0.0517, lr=0.00001]


End epoch 87, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 169.41it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_87.png


Epoch 88/100: 100%|██████████| 391/391 [00:28<00:00, 13.62it/s, loss=0.0548, lr=0.00001]


End epoch 88, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.67it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_88.png


Epoch 89/100: 100%|██████████| 391/391 [00:28<00:00, 13.53it/s, loss=0.0606, lr=0.00001]


End epoch 89, avg loss 0.0552


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.21it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_89.png


Epoch 90/100: 100%|██████████| 391/391 [00:28<00:00, 13.55it/s, loss=0.0541, lr=0.00001]


End epoch 90, avg loss 0.0556


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 162.95it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_90.png


Epoch 91/100: 100%|██████████| 391/391 [00:28<00:00, 13.52it/s, loss=0.0465, lr=0.00000]


End epoch 91, avg loss 0.0556


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.91it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_91.png


Epoch 92/100: 100%|██████████| 391/391 [00:28<00:00, 13.61it/s, loss=0.0447, lr=0.00000]


End epoch 92, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 164.07it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_92.png


Epoch 93/100: 100%|██████████| 391/391 [00:28<00:00, 13.52it/s, loss=0.0506, lr=0.00000]


End epoch 93, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.38it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_93.png


Epoch 94/100: 100%|██████████| 391/391 [00:29<00:00, 13.46it/s, loss=0.0607, lr=0.00000]


End epoch 94, avg loss 0.0554


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.65it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_94.png


Epoch 95/100: 100%|██████████| 391/391 [00:28<00:00, 13.57it/s, loss=0.0459, lr=0.00000]


End epoch 95, avg loss 0.0560


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 167.87it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_95.png


Epoch 96/100: 100%|██████████| 391/391 [00:28<00:00, 13.56it/s, loss=0.0619, lr=0.00000]


End epoch 96, avg loss 0.0552


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 165.66it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_96.png


Epoch 97/100: 100%|██████████| 391/391 [00:28<00:00, 13.49it/s, loss=0.0440, lr=0.00000]


End epoch 97, avg loss 0.0562


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 170.59it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_97.png


Epoch 98/100: 100%|██████████| 391/391 [00:28<00:00, 13.52it/s, loss=0.0512, lr=0.00000]


End epoch 98, avg loss 0.0551


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 162.81it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_98.png


Epoch 99/100: 100%|██████████| 391/391 [00:28<00:00, 13.52it/s, loss=0.0501, lr=0.00000]


End epoch 99, avg loss 0.0559


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 166.30it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_99.png


Epoch 100/100: 100%|██████████| 391/391 [00:28<00:00, 13.50it/s, loss=0.0596, lr=0.00000]


End epoch 100, avg loss 0.0558


DDIM Sampling: 100%|██████████| 100/100 [00:00<00:00, 163.60it/s]


Saved samples to runs_ddim_test_stf_smoothing/samples_epoch_100.png
Notebook set up. To begin training, call train_ddim(...) as shown above.


In [None]:
# Load a checkpoint and generate samples (example)
ckpt_path = './runs_ddim_test_stf_smoothing/checkpoint_epoch_100.pt'  # adjust to your path


def strip_compile_prefix(state_dict, prefix='_orig_mod.'):
    """Return a new state dict with a leading torch.compile wrapper prefix removed.

    A model wrapped by `torch.compile` saves parameters under keys like
    '_orig_mod.conv.weight', while the plain module expects 'conv.weight'.
    Keys without the prefix are copied unchanged; the input dict is not modified.
    """
    return {k.removeprefix(prefix): v for k, v in state_dict.items()}


if Path(ckpt_path).exists():
    # weights_only=True (the default since torch 2.6) refuses to unpickle arbitrary objects.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(strip_compile_prefix(ckpt['model_state']))
    model.to(device).eval()
    # Sampling is pure inference; don't build an autograd graph across the sampling loop.
    with torch.no_grad():
        samples = sample_loop(model, schedule, (16,3,32,32), device=device)
    grid = (samples.clamp(-1,1) + 1) / 2.0  # [-1, 1] -> [0, 1] for saving
    utils.save_image(grid, 'sample_from_ckpt.png', nrow=4)
    show_image('sample_from_ckpt.png')
else:
    print('No checkpoint found at', ckpt_path)


In [None]:
# Short diagnostic: take a test batch, add noise at random t, and visualize noisy vs. denoised (single-step prediction)
# Saved grid: top row = noisy inputs x_t, bottom row = one-shot x0 predictions.
n_show = 8
x, _ = next(iter(test_loader))
x = x[:n_show].to(device)
t = torch.randint(0, schedule.timesteps, (x.shape[0],), device=device)
noise = torch.randn_like(x)
x_noisy = schedule.q_sample(x, t, noise=noise)

# Put the network in inference mode. If this cell runs right after training, the model may
# still be in train mode, which would keep dropout/normalisation layers (if any) in training behaviour.
model.to(device).eval()
with torch.no_grad():
    pred_noise = model(x_noisy, t)
    # Solve x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps for x0 using the predicted eps.
    # NOTE(review): assumes q_sample implements exactly this forward process; confirm.
    sqrt_ab = schedule.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_ab = schedule.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
    x0_pred = ((x_noisy - sqrt_one_minus_ab * pred_noise) / sqrt_ab).clamp(-1, 1)

# x_t can leave [-1, 1]; clamp it explicitly before mapping to [0, 1] for display.
pairs = torch.cat([x_noisy.clamp(-1, 1), x0_pred], dim=0)
grid = (pairs + 1) / 2.0
utils.save_image(grid, 'diagnostic_pairs.png', nrow=n_show)
show_image('diagnostic_pairs.png', figsize=(12,4))
