In [1]:
import torch
import deepinv
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from pathlib import Path
import time



In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
device


device(type='cuda', index=0)

In [3]:
batch_size = 32
image_size = 32

transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.0,), (1.0,))
    ]
)

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root="./data", train=True, transform=transform, download=True),
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
lr = 1e-4
epochs = 100
start_epoch = 0

# pretrained_value = Path("./checkpoints/ddpm_epoch_25.pth")
pretrained_value = None
model = deepinv.models.DiffUNet(in_channels=3, out_channels=1, large_model=False, pretrained=pretrained_value).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
mse = deepinv.loss.MSE()

beta_start = 1e-4
beta_end = 0.02
timesteps = 1000

betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

In [5]:
for epoch in range(start_epoch, epochs):
    model.train()
    epoch_loss = 0.0  # Accumulate loss for the epoch
    start_time = time.time()  # Start timing the epoch
    num_batches = len(train_loader)

    for batch_idx, (data, _) in tqdm(enumerate(train_loader), desc="Training", total=num_batches):
        imgs = data.to(device)
        noise = torch.randn_like(imgs)
        t = torch.randint(0, timesteps, (imgs.size(0),), device=device)

        noised_imgs = (
            sqrt_alphas_cumprod[t, None, None, None] * imgs +
            sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
        )

        optimizer.zero_grad()
        estimated_noise = model(noised_imgs, t, type_t="timestep")
        loss = mse(estimated_noise, noise)
        loss = loss.sum()  # Ensure loss is a scalar for backward()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()  # Accumulate batch loss

    avg_loss = epoch_loss / num_batches  # Average loss for the epoch
    epoch_time = time.time() - start_time  # Calculate epoch time
    print(f"Epoch [{epoch+1}/{epochs}], Average Loss: {avg_loss:.4f}, Epoch Time: {epoch_time:.2f} seconds")
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), f"./checkpoints/cifar10/ddpm_epoch_{epoch+1}.pth")

Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [1/100], Average Loss: 21.9345, Epoch Time: 153.52 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [2/100], Average Loss: 21.8109, Epoch Time: 151.36 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [3/100], Average Loss: 21.8013, Epoch Time: 155.40 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [4/100], Average Loss: 21.8052, Epoch Time: 151.04 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [5/100], Average Loss: 21.7949, Epoch Time: 152.52 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [6/100], Average Loss: 21.7807, Epoch Time: 163.79 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [7/100], Average Loss: 21.7785, Epoch Time: 159.51 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [8/100], Average Loss: 21.7891, Epoch Time: 160.03 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [9/100], Average Loss: 21.7775, Epoch Time: 157.00 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [10/100], Average Loss: 21.7811, Epoch Time: 151.65 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [11/100], Average Loss: 21.7805, Epoch Time: 152.03 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [12/100], Average Loss: 21.7850, Epoch Time: 154.30 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [13/100], Average Loss: 21.7783, Epoch Time: 152.26 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [14/100], Average Loss: 21.7848, Epoch Time: 154.64 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [15/100], Average Loss: 21.7659, Epoch Time: 149.82 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [16/100], Average Loss: 21.7786, Epoch Time: 149.70 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [17/100], Average Loss: 21.7691, Epoch Time: 164.77 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [18/100], Average Loss: 21.7712, Epoch Time: 167.80 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [19/100], Average Loss: 21.7856, Epoch Time: 156.89 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [20/100], Average Loss: 21.7681, Epoch Time: 159.18 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [21/100], Average Loss: 21.7770, Epoch Time: 163.70 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [22/100], Average Loss: 21.7739, Epoch Time: 159.43 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [23/100], Average Loss: 21.7745, Epoch Time: 155.09 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [24/100], Average Loss: 21.7706, Epoch Time: 164.11 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [25/100], Average Loss: 21.7743, Epoch Time: 168.06 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [26/100], Average Loss: 21.7615, Epoch Time: 150.24 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [27/100], Average Loss: 21.7564, Epoch Time: 150.05 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [28/100], Average Loss: 21.7676, Epoch Time: 152.18 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [29/100], Average Loss: 21.7623, Epoch Time: 150.44 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [30/100], Average Loss: 21.7616, Epoch Time: 151.26 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [31/100], Average Loss: 21.7743, Epoch Time: 153.49 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [32/100], Average Loss: 21.7710, Epoch Time: 161.39 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [33/100], Average Loss: 21.7643, Epoch Time: 157.97 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [34/100], Average Loss: 21.7772, Epoch Time: 149.04 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [35/100], Average Loss: 21.7728, Epoch Time: 148.75 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [36/100], Average Loss: 21.7709, Epoch Time: 154.70 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [37/100], Average Loss: 21.7675, Epoch Time: 162.10 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [38/100], Average Loss: 21.7564, Epoch Time: 166.84 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [39/100], Average Loss: 21.7693, Epoch Time: 162.13 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [40/100], Average Loss: 21.7633, Epoch Time: 152.40 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [41/100], Average Loss: 21.7711, Epoch Time: 163.12 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [42/100], Average Loss: 21.7633, Epoch Time: 155.80 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [43/100], Average Loss: 21.7650, Epoch Time: 159.92 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [44/100], Average Loss: 21.7701, Epoch Time: 157.95 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [45/100], Average Loss: 21.7734, Epoch Time: 159.54 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [46/100], Average Loss: 21.7728, Epoch Time: 157.75 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [47/100], Average Loss: 21.7641, Epoch Time: 153.59 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [48/100], Average Loss: 21.7640, Epoch Time: 152.04 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [49/100], Average Loss: 21.7601, Epoch Time: 153.82 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

Epoch [50/100], Average Loss: 21.7639, Epoch Time: 154.79 seconds


Training:   0%|          | 0/1563 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "trained_ddpm.pth")