In [14]:
import torchvision.transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import numpy as np
import tqdm

In [15]:
epochs = 5
t_steps = 1000
batch_size = 64
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [16]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Нормализация в [-1, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [17]:
class SimpleResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        self.block2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
        self.residual_conv = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
    
    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        return h + self.residual_conv(x)

In [18]:
class SimpleUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        
        self.down1 = SimpleResidualBlock(in_channels, 32)
        self.down2 = SimpleResidualBlock(32, 64)
        self.down3 = SimpleResidualBlock(64, 128)
        self.down4 = SimpleResidualBlock(128, 256)

        self.mid = SimpleResidualBlock(256, 256)

        # self.up3 = SimpleResidualBlock(128 + 128, 64)
        # self.up2 = SimpleResidualBlock(64 + 64, 32)
        # self.up1 = SimpleResidualBlock(32 + 32, out_channels)

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.ConvTranspose2d(128, 128, 3, padding=1)

        self.time_proj = nn.Linear(1, 28*28)
    
    def forward(self, x, t):
        batch_size = x.shape[0]
        t_proj = self.time_proj(t.float().view(-1, 1)).view(batch_size, 1, 28, 28)
        x = x + t_proj

        x1 = self.pool(self.down1(x))
        x2 = self.pool(self.down2(x1))
        x3 = self.pool(self.down3(x2))
        x4 = self.pool(self.down3(x3))

        m = self.mid(x4)
        print("BEFORE DISATER", m.shape, x4.shape)

        m = self.upsample(m)
        print(m.shape, " ABABABAB ", x3.shape)
        m = torch.cat([m, x3], dim=1)
        m = self.up3(m)
        
        m = self.upsample(m)
        m = torch.cat([m, x2], dim=1)
        m = self.up2(m)
        
        m = self.upsample(m)
        m = torch.cat([m, x1], dim=1)
        m = self.up1(m)
        
        return m

In [19]:
def show_img_batch(images, images_count):
    plt.figure(figsize=(7, 5))
    for i in range(images.shape[0]):
        img = images[i].view(-1, 1, 28, 28)
        plt.subplot(1, images_count, i + 1)
        plt.imshow(img.cpu().detach().numpy().squeeze(), cmap="gray")
        plt.title(f'Image {i + 1}')
        plt.axis('off')

    plt.tight_layout()
    plt.show()

In [20]:
def simple_forward_diffusion(x0, t, T=1000):
    alpha = 1.0 - (t.float() / T)
    alpha = alpha.view(-1, 1, 1, 1).to(x0.device)
    
    noise = torch.randn_like(x0)
    noisy_images = alpha.sqrt() * x0 + (1 - alpha).sqrt() * noise
    
    return noisy_images, noise

In [21]:
class DDPMScheduler:
    def __init__(self, device, T=1000, beta_start=0.0001, beta_end=0.02):
        self.T = T
        self.betas = torch.linspace(beta_start, beta_end, T).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0).to(device)

    def add_noise(self, x_start, t):
        # Выбираем alpha_bar_t для каждого изображения в батче
        sqrt_alphas_bar = torch.sqrt(self.alphas_bar[t])[:, None, None, None]
        sqrt_one_minus_alphas_bar = torch.sqrt(1.0 - self.alphas_bar[t])[:, None, None, None]
        
        # Генерируем случайный Гауссовский шум (epsilon)
        noise = torch.randn_like(x_start).to(device)
        
        # Применяем формулу добавления шума
        x_t = sqrt_alphas_bar * x_start + sqrt_one_minus_alphas_bar * noise
        
        return x_t, noise


In [22]:
def train_one_epoch(model, scheduler, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        images = data.to(device)
        
        t = torch.randint(0, scheduler.T, (images.shape[0],), device=device).long()

        x_t, true_noise = scheduler.add_noise(images, t)
        
        # 3. Предсказание шума (обратная диффузия)
        predicted_noise = model(x_t, t)
        
        # 4. Расчет L2-потери (MSE) между предсказанным и истинным шумом
        loss = nn.functional.mse_loss(predicted_noise, true_noise)
        
        # 5. Оптимизация
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    return total_loss / len(dataloader)

In [23]:
@torch.no_grad()
def generate_images(model, scheduler, num_images, device):
    model.eval()
    
    x = torch.randn(num_images, 1, 28, 28, device=device)

    for t_step in tqdm(reversed(range(scheduler.T)), desc="Sampling", total=scheduler.T):
        t = torch.tensor([t_step] * num_images, device=device, dtype=torch.long)

        alpha_t = scheduler.alphas[t_step]
        alpha_bar_t = scheduler.alphas_bar[t_step]
        beta_t = scheduler.betas[t_step]

        predicted_noise = model(x, t)
        
        # Формула для оценки x_{t-1} из x_t (наивный вариант)
        # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t) * predicted_noise) + sigma_t * z
        
        # Коэффициент для "чистой" части изображения
        mean_coefficient = 1.0 / torch.sqrt(alpha_t)
        
        # Коэффициент для предсказанного шума
        noise_coefficient = (1.0 - alpha_t) / torch.sqrt(1.0 - alpha_bar_t)
        
        # Расчет "среднего" (mu_t) для перехода к x_{t-1}
        mu_t = mean_coefficient * (x - noise_coefficient * predicted_noise)
        
        if t_step > 0:
            noise = torch.randn_like(x)
            sigma_t = torch.sqrt(beta_t)
            x = mu_t + sigma_t * noise
        else:
            x = mu_t

    x = (x.clamp(-1, 1) + 1) / 2
    return x.cpu()


In [24]:
model = SimpleUNet().to(device)
scheduler = DDPMScheduler(T=t_steps, device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [25]:
losses = []

In [26]:
for epoch in range(epochs):
        avg_loss = train_one_epoch(model, scheduler, train_loader, optimizer, device)
        losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

RuntimeError: Given groups=1, weight of size [128, 64, 3, 3], expected input[64, 128, 3, 3] to have 64 channels, but got 128 channels instead