In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

In [2]:
# Set random seed for reproducibility
torch.manual_seed(42)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data preparation
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 49287311.48it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [3]:
# Model Definition
class DiffusionModel(nn.Module):
    def __init__(self, channels, img_size, hidden_dim):
        super(DiffusionModel, self).__init__()
        self.channels = channels
        self.img_size = img_size
        self.hidden_dim = hidden_dim
        
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, hidden_dim * 2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim * 2, hidden_dim * 4, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim * 4, hidden_dim * 8, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(hidden_dim * 8, hidden_dim * 4, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(hidden_dim, channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [4]:
# Hyperparameters
channels = 3
img_size = 32
hidden_dim = 64
num_epochs = 50
sample_interval = 200

# Create the diffusion model
model = DiffusionModel(channels, img_size, hidden_dim).to(device)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Create a directory to save generated images
os.makedirs("images", exist_ok=True)

In [5]:
# Training loop
total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Move images to the device
        images = images.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, images)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the progress
        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {loss.item():.4f}")

        # Save generated images at regular intervals
        batches_done = epoch * total_steps + i
        if batches_done % sample_interval == 0:
            # Generate sample images
            sample_imgs = model(images).detach()
            # Rescale images to the range [0, 1]
            sample_imgs = (sample_imgs + 1) / 2.0
            # Save the sample images
            save_image(sample_imgs, f"images/sample_{batches_done}.png", nrow=8, normalize=True)

    # Save the model checkpoint after each epoch
    torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch+1}.pth")

print("Training complete!")

Epoch [1/50], Step [100/782], Loss: 0.0281
Epoch [1/50], Step [200/782], Loss: 0.0227
Epoch [1/50], Step [300/782], Loss: 0.0206
Epoch [1/50], Step [400/782], Loss: 0.0151
Epoch [1/50], Step [500/782], Loss: 0.0126
Epoch [1/50], Step [600/782], Loss: 0.0118
Epoch [1/50], Step [700/782], Loss: 0.0100
Epoch [2/50], Step [100/782], Loss: 0.0107
Epoch [2/50], Step [200/782], Loss: 0.0085
Epoch [2/50], Step [300/782], Loss: 0.0089
Epoch [2/50], Step [400/782], Loss: 0.0086
Epoch [2/50], Step [500/782], Loss: 0.0075
Epoch [2/50], Step [600/782], Loss: 0.0079
Epoch [2/50], Step [700/782], Loss: 0.0067
Epoch [3/50], Step [100/782], Loss: 0.0070
Epoch [3/50], Step [200/782], Loss: 0.0063
Epoch [3/50], Step [300/782], Loss: 0.0064
Epoch [3/50], Step [400/782], Loss: 0.0058
Epoch [3/50], Step [500/782], Loss: 0.0063
Epoch [3/50], Step [600/782], Loss: 0.0061
Epoch [3/50], Step [700/782], Loss: 0.0053
Epoch [4/50], Step [100/782], Loss: 0.0058
Epoch [4/50], Step [200/782], Loss: 0.0055
Epoch [4/50

KeyboardInterrupt: ignored