In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [5]:
# Define the U-Net model (simplified version for demonstration purposes)
class UNet(nn.Module):
    """Minimal convolutional noise-prediction network for DDPM-style training.

    Despite the name, this is not a real U-Net: there is no down/up-sampling
    and there are no skip connections. Every layer is a stride-1 3x3
    (transposed) convolution with ``padding=1``, so the output has the same
    spatial size as the input.

    Args:
        in_channels: Number of image channels (1 for MNIST). The output has the
            same number of channels because the network predicts the noise that
            was added to its input.

    Note:
        The last layer is linear (no activation). The regression target is
        standard Gaussian noise, which is unbounded and zero-centred; squashing
        the output (e.g. with a Sigmoid into [0, 1]) would make every negative
        target unreachable.
    """

    def __init__(self, in_channels=1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            # No output activation: predicted noise must be able to take any real value.
            nn.ConvTranspose2d(64, in_channels, kernel_size=3, padding=1),
        )

    def forward(self, x, t):
        """Predict the noise component of ``x``.

        Args:
            x: Noisy images, assumed (batch, in_channels, H, W).
            t: Per-sample diffusion timesteps. Currently ignored; a full
                implementation would condition on ``t`` (e.g. through a
                timestep embedding). TODO: add timestep conditioning.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.decoder(self.encoder(x))

# Parameters
# Length of the forward (noising) chain and its per-step variances
# beta_1 .. beta_T, spaced linearly from 1e-4 up to 2e-2.
num_timesteps = 1000
beta = torch.linspace(1e-4, 2e-2, steps=num_timesteps)

# Forward diffusion (noising) process
def forward_diffusion(x0, t, noise=None, betas=None):
    """Sample x_t ~ q(x_t | x_0) in closed form.

    Computes ``x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise``
    where ``alpha_bar_t`` is the cumulative product of ``(1 - beta)`` up to and
    including index ``t``.

    Args:
        x0: Clean inputs with the batch on dim 0 (any rank, e.g. (B, C, H, W)).
        t: Timestep index per sample, as an integer tensor of shape (B,), or a
            scalar int / 0-d tensor applied to the whole batch.
        noise: Optional noise tensor shaped like ``x0``. Defaults to fresh
            standard Gaussian noise; pass it explicitly for reproducible output.
        betas: Optional 1-D variance schedule. Defaults to the module-level ``beta``.

    Returns:
        Tuple ``(x_t, noise)``: the noised inputs and the noise that was added.
    """
    schedule = beta if betas is None else betas
    if noise is None:
        noise = torch.randn_like(x0)
    # alpha_bar[i] = prod_{s<=i} (1 - beta_s). The schedule may have been built
    # on the CPU, so move it to x0's device before indexing/multiplying.
    alpha_bar = torch.cumprod(1 - schedule, dim=0).to(x0.device)
    t = torch.as_tensor(t, device=x0.device)
    # Reshape to (B, 1, ..., 1) so it broadcasts over every non-batch dim of x0.
    alpha_bar_t = alpha_bar[t].reshape(-1, *([1] * (x0.dim() - 1)))
    return torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise, noise

# Define the training loop
def train(model, dataloader, optimizer, num_epochs=10):
    """Train ``model`` to predict the noise added by ``forward_diffusion``.

    For each batch, a timestep is drawn uniformly from ``[0, num_timesteps)``
    per image, the images are noised in closed form, and the model is updated
    with the MSE between predicted and true noise.

    Args:
        model: Module called as ``model(x_t, t)``; its output must match
            ``x_t``'s shape.
        dataloader: Iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of full passes over ``dataloader``.

    Returns:
        List with the mean training loss of each epoch (``nan`` for an epoch in
        which the dataloader yielded no batches).
    """
    model.train()  # Set the model to training mode
    # Inputs must live on the same device as the model's weights.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))
    criterion = nn.MSELoss()  # Built once, reused for every batch
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for x0, _ in dataloader:
            # Random timestep per image, drawn on the device the batch was loaded on.
            t = torch.randint(0, num_timesteps, (x0.size(0),), device=x0.device)
            xt, noise = forward_diffusion(x0, t)
            # Move everything onto the model's device before the forward pass.
            xt, noise, t = xt.to(device), noise.to(device), t.to(device)
            noise_pred = model(xt, t)
            loss = criterion(noise_pred, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Mean over the epoch's batches (nan if the dataloader yielded nothing).
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss}")
    return epoch_losses

# Prepare the dataset and dataloader
SEED = 0
# Seeds the global torch RNG: weight init, DataLoader shuffling,
# timestep sampling and noise draws all become reproducible.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)  # Download and prepare the MNIST dataset
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)  # Batch and shuffle the data

# Initialize the model and optimizer
model = UNet().to(device)  # Move the model to GPU if available
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Use the Adam optimizer

# Train the model; assigning the result keeps the cell output to the training log.
epoch_losses = train(model, dataloader, optimizer)


KeyboardInterrupt: 