<a href="https://colab.research.google.com/github/Rukkya/.py/blob/main/cifar10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [2]:
# Pre-processing pipeline: ToTensor followed by a per-channel Normalize with
# mean 0.5 and std 0.5. The plotting cells undo this with (x + 1) / 2.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split, stored under ./data (downloaded when requested).
train_data = datasets.CIFAR10(
    root='data',
    train=True,
    transform=transform,
    download=True,
)
# Shuffled mini-batches of 64 images.
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 57.8MB/s]


Extracting data/cifar-10-python.tar.gz to data


In [5]:
# Number of diffusion steps in the forward (noising) process.
T = 1000
# Linear beta schedule: T evenly spaced values from 1e-4 to 0.02.
betas = torch.linspace(start=0.0001, end=0.02, steps=T)
# alpha_t = 1 - beta_t
alphas = 1.0 - betas
# alpha_bar_t = prod_{s <= t} alpha_s (running product along the step axis)
alpha_bars = alphas.cumprod(dim=0)

def add_noise(x0, t):
    """Apply the forward diffusion step to clean samples.

    Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise`` with
    ``noise`` drawn from a standard normal of the same shape as ``x0``.

    Args:
        x0: Clean samples with the batch on dimension 0. Any number of
            trailing dimensions is supported.
        t: Timestep indices into the module-level ``alpha_bars`` schedule.
            Either one index per sample (shape ``(batch,)``) or a single
            int / one-element tensor that is broadcast over the batch.

    Returns:
        Tuple ``(noisy, noise)``, both shaped like ``x0``.
    """
    noise = torch.randn_like(x0)
    # Index the schedule on the schedule's own device so that `t` may live on
    # a different device than `alpha_bars` (or be a plain Python int).
    steps = torch.as_tensor(t, dtype=torch.long, device=alpha_bars.device).reshape(-1)
    alpha_bar_t = alpha_bars[steps].to(x0.device)
    # One coefficient per sample, broadcast over every non-batch dimension.
    # A fixed [:, None, None, None] would silently mis-broadcast non-4D input.
    alpha_bar_t = alpha_bar_t.reshape(-1, *([1] * (x0.dim() - 1)))
    noisy = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
    return noisy, noise



In [7]:
class UNet(nn.Module):
    """Small two-level U-Net that predicts the noise in a noisy RGB image.

    Layout for an input of spatial size ``H x W``:

    * ``enc1``       -> 64 channels at ``H x W``
    * ``enc2``       -> 128 channels at ``H/2 x W/2`` (max-pooled)
    * ``bottleneck`` -> 256 channels at ``H/4 x W/4`` (max-pooled)
    * ``upconv1``    -> 128 channels at ``H/2 x W/2``, added to ``enc2`` output
    * ``upconv2``    -> 64 channels at ``H x W``, added to ``enc1`` output
    * ``out_conv``   -> 3 channels at ``H x W``

    Two poolings are matched by two stride-2 transposed convolutions, so the
    output has the same shape as the input. The timestep is injected as a
    sinusoidal embedding, projected to 256 channels and added to the
    bottleneck features.
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Width of the sinusoidal timestep embedding fed to `time_mlp`.
        self.time_dim = 64
        self.enc1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        # Second down-sampling stage; without it the two stride-2 up-convolutions
        # below would return an image twice the input resolution.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        # Projects the timestep embedding to one bias per bottleneck channel.
        self.time_mlp = nn.Sequential(
            nn.Linear(self.time_dim, 256), nn.ReLU(), nn.Linear(256, 256))
        self.upconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=1))
        self.dec1 = nn.Sequential(nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.ReLU())
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.ReLU())
        self.out_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)

    def _time_embedding(self, t, device):
        """Return sinusoidal embeddings of shape ``(len(t), time_dim)``."""
        steps = torch.as_tensor(t, device=device).reshape(-1).float()
        half = self.time_dim // 2
        exponents = -torch.arange(half, device=device, dtype=torch.float32) / half
        freqs = torch.pow(10000.0, exponents)
        angles = steps[:, None] * freqs[None, :]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    def forward(self, x, t):
        """Predict the noise contained in ``x`` at timestep(s) ``t``.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)`` with ``H`` and ``W``
                divisible by 4.
            t: Timesteps, either one per sample (shape ``(batch,)``) or a
                single value that is broadcast over the batch.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``H`` or ``W`` is not divisible by 4.
        """
        if x.shape[-2] % 4 != 0 or x.shape[-1] % 4 != 0:
            raise ValueError(
                f"UNet expects height and width divisible by 4, got {tuple(x.shape[-2:])}")

        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.bottleneck(x2)

        # Condition on the timestep: one learned bias per bottleneck channel.
        t_emb = self.time_mlp(self._time_embedding(t, x.device).to(x.dtype))
        x3 = x3 + t_emb[:, :, None, None]

        # Additive skip connections from the encoder stage of equal resolution.
        x4 = self.dec1(self.upconv1(x3) + x2)
        x5 = self.dec2(self.upconv2(x4) + x1)

        out = self.out_conv(x5)
        return out

In [None]:
# Loss between the predicted and the actual noise.
criterion = nn.MSELoss()
# Noise-prediction network and an Adam optimiser over all of its parameters.
model = UNet()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

def train(model, dataloader, optimizer, criterion, epochs=10):
    """Train ``model`` to predict the noise added by ``add_noise``.

    For every batch a random timestep in ``[0, T)`` is drawn per sample, the
    batch is noised with the module-level ``add_noise``, and one optimiser
    step is taken on ``criterion(model(noisy, t), noise)``. The mean batch
    loss is printed once per epoch.

    Args:
        model: Module called as ``model(noisy_images, t)``.
        dataloader: Iterable yielding ``(images, labels)`` pairs; labels are
            ignored.
        optimizer: Optimiser over ``model``'s parameters.
        criterion: Loss called as ``criterion(pred_noise, noise)``.
        epochs: Number of passes over ``dataloader``.

    Raises:
        ValueError: If the model output shape differs from the noise shape.
    """
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for x0, _ in dataloader:
            t = torch.randint(0, T, (x0.size(0),), device=x0.device)
            noisy_images, noise = add_noise(x0, t)
            pred_noise = model(noisy_images, t)

            # Fail loudly: a broadcastable mismatch would otherwise train on a
            # meaningless loss.
            if pred_noise.shape != noise.shape:
                raise ValueError(
                    "Shape mismatch between prediction and target noise: "
                    f"pred_noise - {tuple(pred_noise.shape)}, noise - {tuple(noise.shape)}")

            # The update must run once per batch, inside both loops.
            loss = criterion(pred_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # An empty dataloader yields no batches; report NaN instead of failing.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")

# Run training on the CIFAR-10 loader with the default number of epochs.
train(model=model, dataloader=train_loader, optimizer=optimizer, criterion=criterion)


In [None]:
def sample(model, T, shape=(1, 3, 32, 32)):
    """Generate samples by running the reverse diffusion loop.

    Starts from standard normal noise and, for ``t = T-1, ..., 0``, applies

        x <- (x - beta_t * eps / sqrt(1 - alpha_bar_t)) / sqrt(1 - beta_t)

    where ``eps = model(x, t)``, then adds ``sqrt(beta_t) * z`` with fresh
    standard normal ``z`` on every step except the last (``t == 0``).
    ``beta_t`` and ``alpha_bar_t`` come from the module-level ``betas`` and
    ``alpha_bars``, so ``T`` must not exceed their length.

    Args:
        model: Callable ``model(x, t)`` returning predicted noise shaped like
            ``x``; ``t`` is an int64 tensor with one entry per sample.
        T: Number of reverse steps to run.
        shape: Shape of the generated batch, ``(batch, channels, H, W)``.
            The default yields a single 3x32x32 image.

    Returns:
        Tensor of the requested ``shape``.
    """
    with torch.no_grad():
        x = torch.randn(*shape)
        for t in reversed(range(T)):
            # One timestep entry per sample so per-sample conditioning works
            # for any batch size.
            t_batch = torch.full((x.size(0),), t, dtype=torch.long)
            pred_noise = model(x, t_batch)
            beta = betas[t]
            alpha_bar = alpha_bars[t]
            # Mean of the reverse step; shared by the noisy and final branches.
            x = (1 / torch.sqrt(1 - beta)) * (x - beta * pred_noise / torch.sqrt(1 - alpha_bar))
            if t > 0:
                # No noise is added on the last step.
                x = x + torch.sqrt(beta) * torch.randn_like(x)
    return x

# Draw one sample, take the first image of the batch, move channels last and
# shift/scale by (x + 1) / 2 before displaying it.
sample_image = sample(model, T)
plt.imshow(sample_image[0].permute(1, 2, 0).add(1).div(2))
plt.show()


In [None]:
def plot_images(images, title):
    """Show a row of images side by side under a common title.

    Args:
        images: Non-empty sequence of torch tensors, each either
            ``(C, H, W)`` or ``(1, C, H, W)``. Values are mapped with
            ``(x + 1) / 2`` and clamped to ``[0, 1]`` for display.
        title: Figure title.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("plot_images needs at least one image")

    # squeeze=False keeps `axs` two-dimensional even for a single image, where
    # plt.subplots would otherwise return a bare Axes that cannot be indexed.
    fig, axs = plt.subplots(1, len(images), figsize=(12, 4), squeeze=False)
    for ax, img in zip(axs[0], images):
        # Accept a leading batch dimension of size 1.
        if img.dim() == 4 and img.size(0) == 1:
            img = img[0]
        pixels = ((img.detach().cpu().permute(1, 2, 0) + 1) / 2).clamp(0, 1)
        ax.imshow(pixels)
        ax.axis('off')
    fig.suptitle(title)
    plt.show()

# sample() returns a batch of shape (1, 3, 32, 32), while plot_images permutes
# each entry as a 3-D (C, H, W) tensor, so take the single image out of the batch.
generated_images = [sample(model, T)[0] for _ in range(5)]
plot_images(generated_images, "Generated Images")
