Рустам Шамсутдинов БВТ2201

In [12]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt
import numpy as np

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
lr = 2e-4
epochs = 50
img_size = 28
channels = 1
T = 200   # число шагов диффузии
beta_start = 1e-4
beta_end = 0.02
save_dir = "./ddpm_out"
os.makedirs(save_dir, exist_ok=True)
torch.manual_seed(42)

<torch._C.Generator at 0x7df5432c57b0>

In [14]:
betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32, device=device)
alphas = 1.0 - betas
alpha_cumprod = torch.cumprod(alphas, dim=0)
alpha_cumprod_prev = torch.cat([torch.tensor([1.], device=device), alpha_cumprod[:-1]])
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - alpha_cumprod)
posterior_variance = betas * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)

In [15]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_ds = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

In [16]:
def extract(a, t, x_shape):

    batch_size = t.shape[0]
    out = a[t].reshape(batch_size, *((1,) * (len(x_shape) - 1)))
    return out

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alpha_cumprod_t = extract(sqrt_alpha_cumprod, t, x_start.shape)
    sqrt_one_minus_alpha_cumprod_t = extract(sqrt_one_minus_alpha_cumprod, t, x_start.shape)
    return sqrt_alpha_cumprod_t * x_start + sqrt_one_minus_alpha_cumprod_t * noise

In [17]:
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        device = t.device
        half_dim = self.dim // 2
        emb = torch.exp(-math.log(10000) * torch.arange(0, half_dim, device=device) / half_dim)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return emb

In [18]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

    def forward(self, x, t_emb):
        h = self.conv1(x)
        h = self.norm1(h)
        h = F.silu(h)
        time_emb = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + time_emb
        h = self.conv2(h)
        h = self.norm2(h)
        h = F.silu(h)
        return h

class SimpleUNet(nn.Module):
    def __init__(self, in_ch=1, base_ch=64, time_emb_dim=128):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # encoder
        self.down1 = Block(in_ch, base_ch, time_emb_dim)
        self.down2 = Block(base_ch, base_ch*2, time_emb_dim)
        self.down3 = Block(base_ch*2, base_ch*4, time_emb_dim)
        # decoder
        self.up1 = Block(base_ch*4 + base_ch*2, base_ch*2, time_emb_dim)
        self.up2 = Block(base_ch*2 + base_ch, base_ch, time_emb_dim)
        self.final_conv = nn.Sequential(
            nn.Conv2d(base_ch, base_ch, 3, padding=1),
            nn.GroupNorm(8, base_ch),
            nn.SiLU(),
            nn.Conv2d(base_ch, in_ch, 1),
        )
        self.pool = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, t):
        t = t.float()
        t_emb = self.time_mlp(t)
        h1 = self.down1(x, t_emb)
        h2 = self.down2(self.pool(h1), t_emb)
        h3 = self.down3(self.pool(h2), t_emb)
        u = self.upsample(h3)
        u = torch.cat([u, h2], dim=1)
        u = self.up1(u, t_emb)
        u = self.upsample(u)
        u = torch.cat([u, h1], dim=1)
        u = self.up2(u, t_emb)
        out = self.final_conv(u)
        return out

In [19]:
model = SimpleUNet(in_ch=channels, base_ch=32, time_emb_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

class EMA:
    def __init__(self, model, beta=0.9999):
        self.beta = beta
        self.model_shadow = self._clone_model(model)
    def _clone_model(self, model):
        shadow = type(model)(in_ch=channels, base_ch=32, time_emb_dim=128).to(device)
        shadow.load_state_dict(model.state_dict())
        for p in shadow.parameters():
            p.requires_grad_(False)
        return shadow
    def update(self, model):
        ms = model.state_dict()
        for k, v in self.model_shadow.state_dict().items():
            v.copy_(v * self.beta + (1. - self.beta) * ms[k].to(v.dtype))
    def to(self, device):
        self.model_shadow.to(device)

ema = EMA(model)

mse = nn.MSELoss()

In [None]:
@torch.no_grad()
def p_sample(x_t, t):
    B = x_t.shape[0]
    t_tensor = torch.full((B,), t, device=device, dtype=torch.long)
    eps_theta = model(x_t, t_tensor)
    pred_x0 = (x_t - extract(sqrt_one_minus_alpha_cumprod, t_tensor, x_t.shape) * eps_theta) / extract(sqrt_alpha_cumprod, t_tensor, x_t.shape)
    coef1 = extract(betas, t_tensor, x_t.shape) * torch.sqrt(alpha_cumprod_prev[t]) / (1.0 - alpha_cumprod[t])
    coef2 = (1.0 - alpha_cumprod_prev[t]) * torch.sqrt(alphas[t]) / (1.0 - alpha_cumprod[t])
    posterior_mean = coef1 * pred_x0 + coef2 * x_t
    posterior_var = posterior_variance[t]
    if t > 0:
        noise = torch.randn_like(x_t)
        return posterior_mean + torch.sqrt(posterior_var) * noise
    else:
        return posterior_mean

@torch.no_grad()
def sample(n_samples=64, use_ema=False):
    model.eval()
    if use_ema:
        model_shadow = ema.model_shadow
    else:
        model_shadow = model
    x = torch.randn(n_samples, channels, img_size, img_size, device=device) 
    for t in reversed(range(T)):
        B = x.size(0)
        tt = torch.full((B,), t, device=device, dtype=torch.long)
        eps_theta = model_shadow(x, tt)
        posterior_mean = (1.0 / torch.sqrt(alphas[t])) * (x - (betas[t] / torch.sqrt(1.0 - alpha_cumprod[t])) * eps_theta)
        if t > 0:
            var = posterior_variance[t]
            noise = torch.randn_like(x)
            x = posterior_mean + torch.sqrt(var) * noise
        else:
            x = posterior_mean
    model.train()
    return x

@torch.no_grad()
def sample_and_save(epoch, n_samples=64, use_ema=True):
    imgs = sample(n_samples=n_samples, use_ema=use_ema)
    imgs = (imgs + 1.0) / 2.0
    imgs = imgs.clamp(0., 1.)
    grid = utils.make_grid(imgs.cpu(), nrow=int(math.sqrt(n_samples)), padding=2)
    plt.figure(figsize=(6,6))
    plt.axis('off')
    plt.imshow(grid.permute(1,2,0).squeeze(), cmap='gray')
    out_path = os.path.join(save_dir, f"ddpm_samples_epoch_{epoch}.png")
    plt.savefig(out_path, bbox_inches='tight')
    plt.close()
    print(f"Saved samples to {out_path}")

In [21]:
def train():
    global_step = 0
    model.train()
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        for xb, _ in train_loader:
            xb = xb.to(device)
            b = xb.size(0)
            t = torch.randint(0, T, (b,), device=device, dtype=torch.long)
            noise = torch.randn_like(xb)
            x_noisy = q_sample(xb, t, noise)
            predicted_noise = model(x_noisy, t)
            loss = mse(predicted_noise, noise)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ema.update(model)

            running_loss += loss.item()
            global_step += 1
            if global_step % 200 == 0:
                avg = running_loss / 200
                print(f"Epoch {epoch} step {global_step} avg_loss={avg:.6f}")
                running_loss = 0.0

        sample_and_save(epoch, use_ema=True)
    print("Training finished.")

In [22]:
train()

Epoch 1 step 200 avg_loss=0.219642
Epoch 1 step 400 avg_loss=0.091356
Saved samples to ./ddpm_out/ddpm_samples_epoch_1.png
Epoch 2 step 600 avg_loss=0.049638
Epoch 2 step 800 avg_loss=0.071786
Saved samples to ./ddpm_out/ddpm_samples_epoch_2.png
Epoch 3 step 1000 avg_loss=0.020613
Epoch 3 step 1200 avg_loss=0.064662
Epoch 3 step 1400 avg_loss=0.062236
Saved samples to ./ddpm_out/ddpm_samples_epoch_3.png
Epoch 4 step 1600 avg_loss=0.059030
Epoch 4 step 1800 avg_loss=0.059692
Saved samples to ./ddpm_out/ddpm_samples_epoch_4.png
Epoch 5 step 2000 avg_loss=0.036793
Epoch 5 step 2200 avg_loss=0.057546
Saved samples to ./ddpm_out/ddpm_samples_epoch_5.png
Epoch 6 step 2400 avg_loss=0.015870
Epoch 6 step 2600 avg_loss=0.056220
Epoch 6 step 2800 avg_loss=0.055970
Saved samples to ./ddpm_out/ddpm_samples_epoch_6.png
Epoch 7 step 3000 avg_loss=0.051695
Epoch 7 step 3200 avg_loss=0.054349
Saved samples to ./ddpm_out/ddpm_samples_epoch_7.png
Epoch 8 step 3400 avg_loss=0.031977
Epoch 8 step 3600 avg