In [26]:
import sys

import torch
import torch.optim as optim
import torchvision
from tqdm import tqdm

from ddpm.config import cifar10_config
from ddpm.data import get_cifar10_dataloaders
from ddpm.diffusion_model import DiffusionModel

In [27]:
# Turn off the module-level DEBUG flag on ddpm.config for this kernel session.
# NOTE(review): presumably this silences debug output/checks inside ddpm —
# confirm against ddpm/config.py. The flag is set after the ddpm imports in
# the first cell have already run, so anything that read DEBUG at import time
# would not see this change.
from ddpm import config as _config
_config.DEBUG = False

In [28]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Bare last expression so the notebook displays the chosen device.
device

device(type='cuda')

In [29]:
# Override `initial_pad` on the shared CIFAR-10 config object before the model
# is constructed from it in a later cell. This mutates the imported config in
# place, so the override persists for the rest of the kernel session.
# NOTE(review): the meaning of `initial_pad` is defined in ddpm's ResNet config;
# 0 presumably disables input padding — confirm.
cifar10_config.res_net_config.initial_pad = 0

In [30]:
# Build the CIFAR-10 train/test loaders using the batch size from the config.
# Only `train_loader` is consumed by the cells shown in this notebook.
train_loader, test_loader = get_cifar10_dataloaders(batch_size=cifar10_config.batch_size)

In [31]:
# Instantiate the diffusion model from the CIFAR-10 config, then move it to
# the device selected earlier.
model = DiffusionModel(cifar10_config)
model = model.to(device)

In [32]:
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# ---------------------------------------------------------------------------
# Training loop
#
# Relies on names created by earlier cells: `model`, `optimizer`,
# `train_loader`, `device`.
# Leaves behind: a trained `model` (in eval mode) and `samples`, the preview
# batch drawn after the most recent sampling step.
# ---------------------------------------------------------------------------

# Tunables for this cell (values identical to the literals previously inlined).
NUM_EPOCHS = 100
NUM_CLASSES = 10            # class ids passed to the sampler are drawn from 0..NUM_CLASSES-1
NUM_PREVIEW_SAMPLES = 16    # number of images generated per preview
SAMPLE_EVERY = 1            # draw a preview every N epochs (1 = after every epoch)

for epoch in tqdm(range(NUM_EPOCHS), desc="Training Progress", leave=True):
    model.train()
    batch_progress = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False)

    # Per-epoch bookkeeping. Resetting these each epoch guarantees the epoch
    # summary can never report a loss left over from a previous epoch.
    last_loss = None
    last_error = None
    skipped_batches = 0

    for batch_idx, (images, labels) in enumerate(batch_progress):
        images = images.to(device)
        labels = labels.to(device)

        # Shape sanity check, first batch of each epoch only. `tqdm.write`
        # is used instead of `print` so the message does not break the
        # active progress bars.
        if batch_idx == 0:
            tqdm.write(f"[DEBUG] Epoch {epoch}, first batch shapes: images={images.shape}, labels={labels.shape}")

        # Forward pass: calling the model on (images, labels) yields the loss.
        # Best-effort: a failing batch is reported and skipped, but counted so
        # that failures are visible in the epoch summary.
        try:
            loss = model(images, labels)
        except Exception as e:
            tqdm.write(f"Error during model forward pass: {e}")
            skipped_batches += 1
            last_error = e
            continue

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        last_loss = loss.item()
        batch_progress.set_postfix(loss=last_loss)
        sys.stdout.flush()

    # If nothing trained this epoch (empty loader, or every forward pass
    # failed) stop here with a clear error instead of hitting a NameError or
    # silently reporting a stale loss.
    if last_loss is None:
        raise RuntimeError(
            f"Epoch {epoch}: no batch was trained successfully "
            f"({skipped_batches} skipped due to forward-pass errors)"
        ) from last_error

    # Loss of the last successfully trained batch of this epoch.
    tqdm.write(f"Epoch {epoch}, loss={last_loss:.4f}")
    if skipped_batches:
        tqdm.write(f"Epoch {epoch}: skipped {skipped_batches} batch(es) due to forward-pass errors")

    # ---------------------
    # Periodically sample
    # ---------------------
    # eval() is applied every epoch (as before) so the model is always left in
    # eval mode when the loop finishes; train() is re-enabled at the top of the
    # next epoch.
    model.eval()

    if (epoch + 1) % SAMPLE_EVERY == 0:
        # One random class label per preview image.
        labels_for_sampling = torch.randint(
            low=0,
            high=NUM_CLASSES,
            size=(NUM_PREVIEW_SAMPLES,),
            dtype=torch.long,
            device=device,
        )

        # Class-conditional sampling; no gradients are needed.
        # NOTE(review): `samples` is overwritten on each sampling step and is
        # not saved or displayed inside this cell.
        with torch.no_grad():
            samples = model.sample(
                shape=(NUM_PREVIEW_SAMPLES, 3, 32, 32),
                device=device,
                y=labels_for_sampling,
            )


Training Progress:   0%|          | 0/10 [00:00<?, ?it/s]

[DEBUG] Epoch 0, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:   0%|          | 0/10 [02:22<?, ?it/s]

Epoch 0, loss=0.0316


Training Progress:  10%|█         | 1/10 [02:37<23:40, 157.80s/it]

[DEBUG] Epoch 1, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  10%|█         | 1/10 [04:54<23:40, 157.80s/it]

Epoch 1, loss=0.0619


Training Progress:  20%|██        | 2/10 [05:10<20:40, 155.02s/it]

[DEBUG] Epoch 2, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  20%|██        | 2/10 [07:33<20:40, 155.02s/it]

Epoch 2, loss=0.0337


Training Progress:  30%|███       | 3/10 [07:49<18:15, 156.51s/it]

[DEBUG] Epoch 3, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  30%|███       | 3/10 [10:06<18:15, 156.51s/it]

Epoch 3, loss=0.0254


Training Progress:  40%|████      | 4/10 [10:21<15:30, 155.00s/it]

[DEBUG] Epoch 4, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  40%|████      | 4/10 [12:39<15:30, 155.00s/it]

Epoch 4, loss=0.0274


Training Progress:  50%|█████     | 5/10 [12:54<12:51, 154.28s/it]

[DEBUG] Epoch 5, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  50%|█████     | 5/10 [15:12<12:51, 154.28s/it]

Epoch 5, loss=0.0147


Training Progress:  60%|██████    | 6/10 [15:29<10:17, 154.39s/it]

[DEBUG] Epoch 6, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  60%|██████    | 6/10 [17:48<10:17, 154.39s/it]

Epoch 6, loss=0.0449


Training Progress:  70%|███████   | 7/10 [18:04<07:43, 154.50s/it]

[DEBUG] Epoch 7, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  70%|███████   | 7/10 [20:21<07:43, 154.50s/it]

Epoch 7, loss=0.0400


Training Progress:  80%|████████  | 8/10 [20:37<05:08, 154.16s/it]

[DEBUG] Epoch 8, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  80%|████████  | 8/10 [22:54<05:08, 154.16s/it]

Epoch 8, loss=0.0682


Training Progress:  90%|█████████ | 9/10 [23:10<02:33, 153.74s/it]

[DEBUG] Epoch 9, first batch shapes: images=torch.Size([64, 3, 32, 32]), labels=torch.Size([64])


Training Progress:  90%|█████████ | 9/10 [25:27<02:33, 153.74s/it]

Epoch 9, loss=0.0181


Training Progress: 100%|██████████| 10/10 [25:43<00:00, 154.34s/it]


In [34]:
# Persist only the model weights (state dict) to the current working directory.
torch.save(obj=model.state_dict(), f="diffusion_model.pth")


In [None]:
# Draw 16 class-conditional samples from the trained model and write them to
# disk as a 4-per-row image grid. `torchvision` is imported in the notebook's
# import cell, so a missing dependency fails before training rather than after.
model.eval()
with torch.no_grad():
    # One random class id in [0, 10) per generated image.
    labels_for_sampling = torch.randint(0, 10, (16,), device=device)
    samples = model.sample((16, 3, 32, 32), device=device, y=labels_for_sampling)

# Clamp to [-1, 1], then rescale to [0, 1] before saving.
# NOTE(review): assumes model.sample returns images nominally in [-1, 1] —
# confirm against the normalisation used by the data loaders.
samples = (samples.clamp(-1, 1) + 1) / 2

torchvision.utils.save_image(samples, 'generated_samples.png', nrow=4)
