In [1]:
%matplotlib inline
# -*- coding: utf-8 -*-
"""
Exemple de génération d'images par diffusion
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Define a simple convolutional encoder-decoder (named UNet, but without downsampling or skip connections)
class UNet(nn.Module):
    """Small convolutional encoder-decoder used as the noise predictor.

    Despite the name, there is no downsampling and there are no skip
    connections. Every layer uses a 3x3 kernel with padding 1 and the default
    stride of 1, so the output has the same spatial size as the input. The
    network takes a single-channel image batch and returns a single-channel
    batch.

    The last layer is linear. The training loop regresses this output onto
    standard-normal noise, which is not confined to [-1, 1], so a bounded
    activation such as Tanh would make part of the target unreachable.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 64 -> 128 channels, spatial size unchanged.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        # 128 -> 64 -> 1 channels. Layer indices 0 and 2 are kept so that the
        # parameter names in the state dict stay the same as before.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=3, padding=1)
        )

    def forward(self, x):
        """Run the encoder then the decoder on ``x`` and return the result."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Define the forward diffusion process
def forward_diffusion(x0, t, noise):
    """Blend a clean sample with noise at diffusion level ``t``.

    Computes ``sqrt(1 - t) * x0 + sqrt(t) * noise``. The three tensors are
    combined with ordinary broadcasting, so ``t`` may hold one level per
    sample as long as its shape broadcasts against ``x0`` and ``noise``.

    Args:
        x0: Clean sample tensor.
        t: Tensor of noise levels; ``0`` returns ``x0`` and ``1`` returns
            ``noise``.
        noise: Noise tensor to mix in.

    Returns:
        The noised tensor.
    """
    signal_weight = torch.sqrt(1 - t)
    noise_weight = torch.sqrt(t)
    kept_signal = signal_weight * x0
    added_noise = noise_weight * noise
    return kept_signal + added_noise

# Define the inverse diffusion process
def inverse_diffusion(model, xt, t, eps=1e-5):
    """Estimate the clean sample behind a noisy one.

    Inverts ``xt = sqrt(1 - t) * x0 + sqrt(t) * noise`` for ``x0``, using
    ``model(xt)`` as the noise estimate.

    Args:
        model: Callable mapping ``xt`` to predicted noise of the same shape.
        xt: Noisy sample tensor.
        t: Noise level, either a tensor that broadcasts against ``xt`` or a
            plain Python number.
        eps: Lower bound applied to ``1 - t`` before the square root. At
            ``t == 1`` the denominator would otherwise be zero and the result
            would be inf/NaN.

    Returns:
        Tensor with the estimated clean sample.
    """
    if not torch.is_tensor(t):
        # torch.sqrt only accepts tensors; lift plain numbers onto xt's device.
        t = torch.tensor(t, dtype=xt.dtype, device=xt.device)
    pred_noise = model(xt)
    signal_scale = torch.sqrt(torch.clamp(1 - t, min=eps))
    return (xt - torch.sqrt(t) * pred_noise) / signal_scale

# Load dataset: MNIST training split, converted to tensors and then normalized
# with mean 0.5 and std 0.5 on the single channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=mnist, batch_size=64, shuffle=True)

# Initialize model, optimizer, and loss function.
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = UNet().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Training loop: noise every batch at a random level and regress the model
# output onto the noise that was mixed in.
num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        images = images.to(device)

        # One noise level per image, shaped (batch, 1, 1, 1) so that it
        # broadcasts over the remaining image dimensions.
        noise = torch.randn_like(images).to(device)
        t = torch.rand(images.shape[0], 1, 1, 1).to(device)
        noisy_images = forward_diffusion(images, t, noise)

        # Clear old gradients, then do the forward and backward pass.
        optimizer.zero_grad()
        predicted_noise = model(noisy_images)
        loss = criterion(predicted_noise, noise)
        loss.backward()
        optimizer.step()

        # Report every 100th step (steps are shown 1-based).
        if i % 100 == 99:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Generate images using the trained model
def generate_images(model, num_images, steps):
    """Sample images by walking the noise level from 1 down to 0.

    Starts from Gaussian noise of shape ``(num_images, 1, 28, 28)`` and
    performs ``steps`` deterministic updates. Each update asks ``model`` for
    the noise in the current sample, forms a clean-image estimate by inverting
    ``x_t = sqrt(1 - t) * x0 + sqrt(t) * noise``, and then re-noises that
    estimate to the next, lower level using the same predicted noise.

    Args:
        model: ``nn.Module`` mapping a noisy batch to predicted noise of the
            same shape. It is switched to eval mode and left in that mode.
        num_images: Number of images to sample.
        steps: Number of model evaluations; ``0`` returns the initial noise.

    Returns:
        CPU tensor of shape ``(num_images, 1, 28, 28)``.
    """
    model.eval()
    # Sample on whatever device the model lives on; fall back to the CPU for
    # modules that have no parameters.
    first_param = next(model.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device('cpu')
    eps = 1e-4  # floor for 1 - t, which is exactly 0 at the first level (t = 1)
    with torch.no_grad():
        xt = torch.randn(num_images, 1, 28, 28, device=sample_device)
        # steps + 1 grid points give `steps` transitions (t -> t_next).
        times = torch.linspace(1.0, 0.0, steps + 1, device=sample_device).clamp(0.0, 1.0)
        for t, t_next in zip(times[:-1], times[1:]):
            pred_noise = model(xt)
            x0_hat = (xt - torch.sqrt(t) * pred_noise) / torch.sqrt(torch.clamp(1 - t, min=eps))
            # Keep the estimate in [-1, 1], the range this notebook's MNIST
            # transform produces; this also bounds the blow-up near t = 1.
            x0_hat = x0_hat.clamp(-1.0, 1.0)
            # Re-noise to the next level. Using x0_hat directly as the next
            # sample would skip the noise that level is supposed to contain.
            xt = torch.sqrt(1 - t_next) * x0_hat + torch.sqrt(t_next) * pred_noise
    return xt.cpu()

# Generate and visualize images: sample 16 digits and stack them top to bottom
# into a single grayscale strip.
generated_images = generate_images(model, 16, 100)
grid = torch.cat(tuple(generated_images[k] for k in range(16)), dim=1)
grid = grid.squeeze().numpy()
plt.imshow(grid, cmap='gray')
plt.show()


  warn(


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 38644488.17it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1492333.49it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4136034.00it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 693048.92it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






Epoch [1/10], Step [100/938], Loss: 0.2368


KeyboardInterrupt: 

In [None]:
# Generate images using the trained model
def generate_images(model, num_images, steps):
    """Sample images by walking the noise level from 1 down to 0.

    Starts from Gaussian noise of shape ``(num_images, 1, 28, 28)`` and
    performs ``steps`` deterministic updates. Each update asks ``model`` for
    the noise in the current sample, forms a clean-image estimate by inverting
    ``x_t = sqrt(1 - t) * x0 + sqrt(t) * noise``, and then re-noises that
    estimate to the next, lower level using the same predicted noise.

    Args:
        model: ``nn.Module`` mapping a noisy batch to predicted noise of the
            same shape. It is switched to eval mode and left in that mode.
        num_images: Number of images to sample.
        steps: Number of model evaluations; ``0`` returns the initial noise.

    Returns:
        CPU tensor of shape ``(num_images, 1, 28, 28)``.
    """
    model.eval()
    # Sample on whatever device the model lives on; fall back to the CPU for
    # modules that have no parameters.
    first_param = next(model.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device('cpu')
    eps = 1e-4  # floor for 1 - t, which is exactly 0 at the first level (t = 1)
    with torch.no_grad():
        xt = torch.randn(num_images, 1, 28, 28, device=sample_device)
        # steps + 1 grid points give `steps` transitions (t -> t_next).
        times = torch.linspace(1.0, 0.0, steps + 1, device=sample_device).clamp(0.0, 1.0)
        for t, t_next in zip(times[:-1], times[1:]):
            pred_noise = model(xt)
            x0_hat = (xt - torch.sqrt(t) * pred_noise) / torch.sqrt(torch.clamp(1 - t, min=eps))
            # Keep the estimate in [-1, 1], the range this notebook's MNIST
            # transform produces; this also bounds the blow-up near t = 1.
            x0_hat = x0_hat.clamp(-1.0, 1.0)
            # Re-noise to the next level. Using x0_hat directly as the next
            # sample would skip the noise that level is supposed to contain.
            xt = torch.sqrt(1 - t_next) * x0_hat + torch.sqrt(t_next) * pred_noise
    return xt.cpu()

# Generate and visualize images: sample 16 digits and stack them top to bottom
# into a single grayscale strip.
generated_images = generate_images(model, 16, 100)
grid = torch.cat(tuple(generated_images[k] for k in range(16)), dim=1)
grid = grid.squeeze().numpy()
plt.imshow(grid, cmap='gray')
plt.show()
