In [None]:
#!/usr/bin/env python3
"""
DDPM training on CIFAR-10 (fixed UNet channel-mismatch bug — updated)

This script is a full, runnable DDPM training example on CIFAR-10. It
includes a corrected U-Net implementation with explicit skip-channel
bookkeeping that guarantees the up-path ResNet blocks are constructed with
the exact channel counts they will receive during forward().

Key fixes applied:
- During the down path we record the actual "skip" channel counts.
- During up path construction we use those exact skip channel counts in
  reverse order to set the input channel size for the first ResNet block
  (the block that consumes the concatenated tensor).
- This removes any mismatch between concatenated tensor channels and
  ResNet block conv weight shapes.

Usage:
    python ddpm_cifar10_fixed_full_running.py

"""

import os
import math
from pathlib import Path
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
from tqdm import tqdm
import random

# Optional FID metric
try:
    from torchmetrics.image.fid import FrechetInceptionDistance
except Exception:
    FrechetInceptionDistance = None


# ---------------------------
# Config
# ---------------------------
class CFG:
    """Global hyper-parameters for the CIFAR-10 DDPM run.

    Everything is a class attribute, so the rest of the script reads values
    as ``CFG.<name>`` and a notebook user can override them in place
    (e.g. ``CFG.epochs = 2``) before calling ``train()``.
    """

    # ---- data ---------------------------------------------------------
    image_size = 32
    channels = 3

    # ---- optimisation -------------------------------------------------
    batch_size = 16
    lr = 2e-4
    epochs = 100
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # ---- diffusion schedule (linear betas) ----------------------------
    diffusion_steps = 400
    beta_start = 1e-4
    beta_end = 0.02

    # ---- U-Net architecture -------------------------------------------
    base_channels = 128
    channel_mult = (1, 2, 2, 2)  # per-level widths: 128, 256, 256, 256
    attn_resolutions = (16,)     # add self-attention on 16x16 feature maps
    num_res_blocks = 2
    dropout = 0.1

    # ---- outputs ------------------------------------------------------
    out_dir = "./ddpm_original_runs"
    save_every = 400             # optimizer steps between sample grids
    sample_batch = 16


# ---------------------------
# Beta schedule & helpers
# ---------------------------

def linear_beta_schedule(timesteps, beta_start, beta_end):
    """Return a 1-D tensor of ``timesteps`` betas spaced evenly from ``beta_start`` to ``beta_end`` (inclusive)."""
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule


def make_diffusion_series(T, beta_start, beta_end, device):
    """Precompute the fixed quantities of the forward (noising) process.

    Args:
        T: number of diffusion steps.
        beta_start, beta_end: endpoints of the linear beta schedule.
        device: device on which every returned tensor is placed.

    Returns:
        dict of length-``T`` 1-D tensors:
        ``"betas"``, ``"alphas"`` (1 - beta_t),
        ``"alphas_cumprod"`` (abar_t, the running product of alphas),
        ``"sqrt_alphas_cumprod"``, ``"sqrt_one_minus_alphas_cumprod"`` and
        ``"posterior_variance"`` = beta_t * (1 - abar_{t-1}) / (1 - abar_t),
        using abar_{-1} = 1.
    """
    betas = linear_beta_schedule(T, beta_start, beta_end).to(device)
    alphas = 1.0 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    # abar shifted right by one step, padded with abar_{-1} := 1
    alpha_bar_prev = torch.cat((torch.ones(1, device=device), alpha_bar[:-1]), dim=0)

    series = dict(
        betas=betas,
        alphas=alphas,
        alphas_cumprod=alpha_bar,
        sqrt_alphas_cumprod=alpha_bar.sqrt(),
        sqrt_one_minus_alphas_cumprod=(1.0 - alpha_bar).sqrt(),
        posterior_variance=betas * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar),
    )
    return series


# ---------------------------
# Sinusoidal time embedding
# ---------------------------
def sinusoidal_positional_embedding(timesteps: torch.Tensor, dim: int):
    """Transformer-style sinusoidal embedding of integer diffusion timesteps.

    Args:
        timesteps: 1-D tensor of timesteps, one per batch element.
        dim: embedding width. The first ``dim // 2`` columns are sines and the
            next ``dim // 2`` are cosines. An odd ``dim`` gets one trailing
            zero column.

    Returns:
        Float tensor of shape ``(len(timesteps), dim)``.

    Raises:
        ValueError: if ``timesteps`` is not 1-D.
    """
    if timesteps.dim() != 1:
        raise ValueError(f"timesteps must be 1-D, got shape {tuple(timesteps.shape)}")
    half = dim // 2
    # Frequencies are geometrically spaced from 1 down to 1/10000. The original
    # divided by (half - 1), which is 0 for dim in {2, 3} and made every
    # embedding NaN; clamp the denominator so tiny dims degrade gracefully
    # (results are unchanged for dim >= 4).
    denom = max(half - 1, 1)
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / denom)
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb


# ---------------------------
# ResNet block
# ---------------------------
class ResnetBlock(nn.Module):
    """Pre-activation residual block with additive timestep conditioning.

    Main path:
    GroupNorm -> SiLU -> Conv3x3 -> (+ projected time embedding)
    -> GroupNorm -> SiLU -> Dropout -> Conv3x3.

    The result is added to the input, which first goes through a 1x1 conv
    when ``in_ch != out_ch``.

    Args:
        in_ch: channels of the input feature map (must be divisible by 8 for
            the GroupNorm).
        out_ch: channels of the output feature map (must be divisible by 8).
        time_emb_dim: width of the timestep embedding passed to ``forward``.
        dropout: dropout probability before the second conv.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, dropout):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.SiLU()
        self.res_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

        # Build the norms eagerly. They used to be created lazily in forward(),
        # i.e. after the optimizer had already collected model.parameters(), so
        # their affine weights were never trained, were missing from the
        # parameter count, and checkpoints could not be loaded into a fresh
        # model. conv1 already fixes the accepted input width to in_ch, so
        # laziness bought nothing.
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)

    def forward(self, x, t_emb=None):
        """Apply the block to ``x`` (N, in_ch, H, W), optionally conditioned on ``t_emb`` (N, time_emb_dim)."""
        h = self.conv1(self.act(self.norm1(x)))

        if t_emb is not None:
            # Nonlinearity before the projection, as in the original DDPM;
            # without it this Linear directly follows the Linear that ends
            # the model-level time MLP, and the two collapse into one.
            h = h + self.time_mlp(self.act(t_emb))[:, :, None, None]

        h = self.conv2(self.dropout(self.act(self.norm2(h))))
        return h + self.res_conv(x)


# ---------------------------
# Attention block
# ---------------------------
class AttentionBlock(nn.Module):
    """Residual multi-head self-attention over the spatial positions of a feature map.

    Steps:
    1. Group-normalise the input.
    2. Project it to queries/keys/values with 1x1 convolutions.
    3. Attend over all H*W positions independently for each head.
    4. Project back with another 1x1 convolution and add the input.

    Args:
        ch: channel count; must be divisible by ``num_heads`` (and by 8 for
            the GroupNorm).
        num_heads: number of attention heads.
    """

    def __init__(self, ch, num_heads=4):
        super().__init__()
        assert ch % num_heads == 0
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, ch)
        # Generator unpacking creates (and registers) q, k, v in that order.
        self.q, self.k, self.v = (nn.Conv2d(ch, ch, 1) for _ in range(3))
        self.proj_out = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        batch, channels, height, width = x.shape
        head_dim = channels // self.num_heads
        tokens = height * width
        normed = self.norm(x)

        def split_heads(proj):
            # (B, C, H, W) -> (B, heads, head_dim, H*W)
            return proj(normed).view(batch, self.num_heads, head_dim, tokens)

        query, key, value = split_heads(self.q), split_heads(self.k), split_heads(self.v)

        logits = torch.einsum('bhdn,bhdm->bhnm', query, key) * (1.0 / math.sqrt(head_dim))
        weights = logits.softmax(dim=-1)
        mixed = torch.einsum('bhnm,bhdm->bhdn', weights, value)

        merged = mixed.contiguous().view(batch, channels, height, width)
        return x + self.proj_out(merged)


# ---------------------------
# Corrected U-Net
# ---------------------------
class OriginalDDPMUNet(nn.Module):
    """U-Net noise predictor ``eps_theta(x_t, t)`` in the style of the original DDPM.

    Level ``i`` has width ``chs[i] = base_ch * channel_mult[i]``.

    * Down path, per level: ``num_res_blocks`` ResnetBlocks, then optional
      attention. The level output is stored as a skip and, except at the
      last level, downsampled with a stride-2 conv.
    * Middle: ResnetBlock -> AttentionBlock -> ResnetBlock.
    * Up path, levels in reverse: concatenate the matching skip *at the skip's
      own resolution*, then ``num_res_blocks + 1`` ResnetBlocks (the first
      consumes ``curr_ch + skip_ch`` channels), then optional attention.
      Except at the last level, finish with a stride-2 transposed-conv
      upsample.

    Attention is added wherever the feature-map side length is in
    ``attn_resolutions``. All channel counts must be divisible by 8 (GroupNorm).

    Args:
        in_ch: image channels (input and output).
        base_ch: channels produced by the stem conv.
        channel_mult: per-level channel multipliers.
        attn_resolutions: side lengths at which to insert an AttentionBlock.
        num_res_blocks: ResnetBlocks per down level (up levels use one more).
        dropout: dropout probability inside the ResnetBlocks.
        time_emb_dim: width of the timestep embedding.
        image_size: input side length, used only to decide where attention
            goes; defaults to ``CFG.image_size``.
    """

    def __init__(self, in_ch=3, base_ch=128, channel_mult=(1,2,2,2),
                 attn_resolutions=(16,), num_res_blocks=2, dropout=0.1, time_emb_dim=256,
                 image_size=None):
        super().__init__()
        if image_size is None:
            image_size = CFG.image_size
        self.in_ch = in_ch
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        self.init_conv = nn.Conv2d(in_ch, base_ch, kernel_size=3, padding=1)

        # channel width per level, e.g. [128, 256, 256, 256]
        chs = [base_ch * m for m in channel_mult]
        in_out = list(zip([base_ch] + chs[:-1], chs))

        # ---- down path: record each skip's channel count and resolution ----
        self.down_blocks = nn.ModuleList()
        self.skip_channels = []
        skip_resolutions = []
        curr_res = image_size
        for i, (lvl_in, lvl_out) in enumerate(in_out):
            block_layers = nn.ModuleList()
            attn_layers = nn.ModuleList()
            for j in range(num_res_blocks):
                in_ch_block = lvl_in if j == 0 else lvl_out
                block_layers.append(ResnetBlock(in_ch_block, lvl_out, time_emb_dim, dropout))
            if curr_res in attn_resolutions:
                attn_layers.append(AttentionBlock(lvl_out))
            is_last = i == len(in_out) - 1
            down_sample = nn.Identity() if is_last else nn.Conv2d(lvl_out, lvl_out, kernel_size=3, stride=2, padding=1)
            self.down_blocks.append(nn.ModuleList([block_layers, attn_layers, down_sample]))
            self.skip_channels.append(lvl_out)
            skip_resolutions.append(curr_res)
            if not is_last:
                # a 3x3 / stride-2 / pad-1 conv maps n -> ceil(n / 2)
                curr_res = (curr_res + 1) // 2

        # ---- middle ----
        mid_ch = chs[-1]
        self.mid_block1 = ResnetBlock(mid_ch, mid_ch, time_emb_dim, dropout)
        self.mid_attn = AttentionBlock(mid_ch)
        self.mid_block2 = ResnetBlock(mid_ch, mid_ch, time_emb_dim, dropout)

        # ---- up path: consume skips in reverse, each at its own resolution ----
        self.up_blocks = nn.ModuleList()
        curr_ch = mid_ch
        rev_skips = list(zip(reversed(self.skip_channels), reversed(skip_resolutions)))
        for i, (skip_ch, level_res) in enumerate(rev_skips):
            block_layers = nn.ModuleList()
            attn_layers = nn.ModuleList()

            # first block consumes the concatenation: curr_ch + skip_ch -> skip_ch
            block_layers.append(ResnetBlock(curr_ch + skip_ch, skip_ch, time_emb_dim, dropout))
            for _ in range(num_res_blocks):
                block_layers.append(ResnetBlock(skip_ch, skip_ch, time_emb_dim, dropout))

            # this level runs at the skip's resolution
            if level_res in attn_resolutions:
                attn_layers.append(AttentionBlock(skip_ch))

            # Upsample at the END of the level, so the next level concatenates
            # its skip at the matching resolution. (Upsampling first made every
            # concat happen at twice the skip's size and forced a
            # nearest-neighbour interpolation of the skip.)
            if i < len(rev_skips) - 1:
                up_sample = nn.ConvTranspose2d(skip_ch, skip_ch, kernel_size=4, stride=2, padding=1)
            else:
                up_sample = nn.Identity()

            self.up_blocks.append(nn.ModuleList([block_layers, attn_layers, up_sample]))
            curr_ch = skip_ch

        # ---- output head: GroupNorm -> SiLU -> Conv back to image channels ----
        # After the last up level h has chs[0] = base_ch * channel_mult[0]
        # channels, which equals base_ch only when channel_mult[0] == 1.
        self.final_norm = nn.GroupNorm(8, curr_ch)
        self.final_act = nn.SiLU()
        self.final_conv = nn.Conv2d(curr_ch, in_ch, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Predict the noise in ``x`` (N, in_ch, H, W) at integer timesteps ``t`` (N,)."""
        t_emb = sinusoidal_positional_embedding(t, self.time_mlp[0].in_features)
        t_emb = self.time_mlp(t_emb)

        h = self.init_conv(x)
        skips: List[torch.Tensor] = []

        for block_layers, attn_layers, down_sample in self.down_blocks:
            for block in block_layers:
                h = block(h, t_emb)
            for attn in attn_layers:
                h = attn(h)
            skips.append(h)
            h = down_sample(h)

        h = self.mid_block1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, t_emb)

        # up path: pop skips in reverse order
        for block_layers, attn_layers, up_sample in self.up_blocks:
            if len(skips) == 0:
                raise RuntimeError("Skip stack empty — mismatch between down and up blocks.")
            skip = skips.pop()
            if h.shape[2:] != skip.shape[2:]:
                # Only reachable for odd intermediate sizes (stride-2 conv rounds
                # up, transposed conv doubles exactly): resize h, never the skip.
                h = F.interpolate(h, size=skip.shape[2:], mode='nearest')
            h = torch.cat([h, skip], dim=1)  # channels == curr_ch + skip_ch
            for block in block_layers:
                h = block(h, t_emb)
            for attn in attn_layers:
                h = attn(h)
            h = up_sample(h)

        h = self.final_norm(h)
        h = self.final_act(h)
        return self.final_conv(h)


# ---------------------------
# Dataloader
# ---------------------------
def get_dataloader(batch_size, image_size, train=True, *, root='./data', num_workers=2, random_flip=False):
    """Build a CIFAR-10 DataLoader with images scaled to [-1, 1].

    Args:
        batch_size: images per batch.
        image_size: passed to ``transforms.Resize``. CIFAR-10 images are
            square, so this gives ``image_size x image_size``.
        train: use the training split (and shuffle) if True, else the test split.
        root: dataset directory; the data is downloaded there if missing.
        num_workers: DataLoader worker processes.
        random_flip: add a random horizontal flip for the training split
            (the augmentation used in the DDPM paper). Off by default to
            keep existing behaviour.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    steps = [transforms.Resize(image_size)]
    if train and random_flip:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    ds = torchvision.datasets.CIFAR10(root=root, train=train, download=True, transform=transforms.Compose(steps))
    # Pinned memory only helps host->GPU copies; on CPU-only machines it just warns.
    return DataLoader(ds, batch_size=batch_size, shuffle=train, num_workers=num_workers,
                      pin_memory=torch.cuda.is_available())


# ---------------------------
# q_sample & p_sample_loop
# ---------------------------

def q_sample(x0: torch.Tensor, t: torch.LongTensor, noise: torch.Tensor, series):
    """Noise clean images to timestep ``t``.

    Computes x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise. The per-sample
    coefficients are gathered from ``series`` at ``t`` and reshaped to
    (N, 1, 1, 1) so they broadcast over the image dimensions.
    """
    def coeff(name):
        return series[name][t].view(-1, 1, 1, 1)

    signal_scale = coeff('sqrt_alphas_cumprod')
    noise_scale = coeff('sqrt_one_minus_alphas_cumprod')
    return signal_scale * x0 + noise_scale * noise


@torch.no_grad()
def p_sample_loop(model: nn.Module, shape: Tuple[int,int,int,int], series, device, progress=False,
                  clip_denoised: bool = True):
    """Generate samples by running the DDPM reverse process from Gaussian noise.

    Each step predicts the noise, forms ``x0_pred``, then samples from the
    posterior q(x_{t-1} | x_t, x0_pred). At t == 0 the posterior mean is
    returned without noise.

    Args:
        model: noise predictor called as ``model(x_t, t)``; put in eval mode.
        shape: shape of the batch to generate, e.g. (N, C, H, W).
        series: dict with ``'betas'``, ``'alphas'``, ``'alphas_cumprod'`` and
            ``'posterior_variance'`` tensors indexed by timestep.
        device: device on which to generate.
        progress: wrap the timestep loop in a tqdm progress bar.
        clip_denoised: clamp ``x0_pred`` to [-1, 1] before forming the
            posterior mean, as in the original DDPM sampler.

    Returns:
        Samples clamped to [-1, 1].
    """
    model.eval()
    B = shape[0]
    x = torch.randn(shape, device=device)
    betas = series['betas']
    alphas = series['alphas']
    alphas_cumprod = series['alphas_cumprod']
    T = betas.shape[0]
    rng = range(T-1, -1, -1)
    if progress:
        rng = tqdm(rng, desc='sampling')
    for t in rng:
        t_tensor = torch.full((B,), t, dtype=torch.long, device=device)
        eps_pred = model(x, t_tensor)

        beta_t = betas[t]
        alpha_t = alphas[t]
        alpha_cumprod_t = alphas_cumprod[t]
        # abar_{t-1}, with abar_{-1} := 1. The previous code used
        # alphas_cumprod[t-1] at t == 0, which wraps to the LAST element and
        # scaled x by ~(1 - abar_{T-1}) / beta_0, saturating every sample.
        alpha_cumprod_prev = alphas_cumprod[t-1] if t > 0 else torch.ones_like(alpha_cumprod_t)

        x0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * eps_pred) / torch.sqrt(alpha_cumprod_t)
        if clip_denoised:
            x0_pred = torch.clamp(x0_pred, -1.0, 1.0)

        # posterior mean of q(x_{t-1} | x_t, x0)
        mean = ((beta_t * torch.sqrt(alpha_cumprod_prev) / (1.0 - alpha_cumprod_t)) * x0_pred
                + ((1.0 - alpha_cumprod_prev) * torch.sqrt(alpha_t) / (1.0 - alpha_cumprod_t)) * x)

        if t > 0:
            posterior_var = series['posterior_variance'][t]
            x = mean + torch.sqrt(posterior_var) * torch.randn_like(x)
        else:
            x = mean

    return torch.clamp(x, -1.0, 1.0)


# ---------------------------
# FID (optional)
# ---------------------------
def evaluate_fid(model: nn.Module, series, device, num_gen=5000, batch_size=128):
    """Compute FID between ``num_gen`` model samples and ``num_gen`` CIFAR-10 training images.

    Args:
        model: noise predictor sampled with ``p_sample_loop``.
        series: diffusion schedule dict from ``make_diffusion_series``.
        device: device for both sampling and the FID network.
        num_gen: number of real and of generated images to compare.
        batch_size: real-image batch size (generation uses at most 64).

    Returns:
        The FID as a float, or ``None`` if torchmetrics' FID is unavailable.

    Raises:
        ValueError: if ``num_gen`` is not positive.
    """
    if FrechetInceptionDistance is None:
        print('torchmetrics FID not available; skipping FID.')
        return None
    if num_gen <= 0:
        raise ValueError(f'num_gen must be positive, got {num_gen}')

    def to_uint8(imgs):
        # Map [-1, 1] -> [0, 255]. Round before casting: a bare .to(uint8)
        # truncates, biasing every pixel down by about half a level.
        return ((imgs.clamp(-1, 1) + 1.0) / 2.0 * 255.0).round().to(torch.uint8)

    print(f'Computing FID with {num_gen} generated images (batch_size {batch_size})...')
    fid = FrechetInceptionDistance(feature=2048).to(device)

    real_loader = get_dataloader(batch_size=batch_size, image_size=CFG.image_size, train=True)
    real_count = 0
    for x_real, _ in tqdm(real_loader, desc='Updating FID with real images'):
        # trim the last batch so exactly num_gen real images are used
        x_real = x_real[: num_gen - real_count].to(device)
        fid.update(to_uint8(x_real), real=True)
        real_count += x_real.shape[0]
        if real_count >= num_gen:
            break

    gen_count = 0
    gen_bs = min(batch_size, 64)
    while gen_count < num_gen:
        to_gen = min(gen_bs, num_gen - gen_count)
        samples = p_sample_loop(model, (to_gen, CFG.channels, CFG.image_size, CFG.image_size), series, device, progress=False)
        fid.update(to_uint8(samples), real=False)
        gen_count += samples.shape[0]
        # '\r' overwrites the progress line instead of concatenating messages
        print(f'\rGenerated {gen_count}/{num_gen} for FID', end='', flush=True)
    print()

    result = fid.compute().item()
    print(f'FID: {result:.4f}')
    return result


# ---------------------------
# Training loop
# ---------------------------
def _save_sample_grid(model: nn.Module, series, save_path: Path) -> None:
    """Draw ``CFG.sample_batch`` samples from ``model`` and save them as a 4-column PNG grid.

    ``p_sample_loop`` switches the model to eval mode and already runs under
    ``torch.no_grad``; the caller must switch back to train mode afterwards.
    """
    shape = (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size)
    samples = p_sample_loop(model, shape, series, CFG.device, progress=False)
    # p_sample_loop returns values in [-1, 1]; save_image expects [0, 1]
    save_image((samples + 1.0) / 2.0, str(save_path), nrow=4)


def train():
    """Train an ``OriginalDDPMUNet`` on CIFAR-10 with the epsilon-prediction MSE loss.

    Outputs written to ``CFG.out_dir``:
    * a sample grid every ``CFG.save_every`` optimizer steps;
    * a sample grid at the end of every epoch;
    * a checkpoint at the end of every epoch, holding the model state,
      optimizer state, epoch and global step.

    Returns:
        ``(model, series)``: the trained model and the diffusion schedule dict.
    """
    out_dir = Path(CFG.out_dir)
    os.makedirs(out_dir, exist_ok=True)
    loader = get_dataloader(CFG.batch_size, CFG.image_size, train=True)

    model = OriginalDDPMUNet(
        in_ch=CFG.channels,
        base_ch=CFG.base_channels,
        channel_mult=CFG.channel_mult,
        attn_resolutions=CFG.attn_resolutions,
        num_res_blocks=CFG.num_res_blocks,
        dropout=CFG.dropout
    ).to(CFG.device)

    opt = optim.Adam(model.parameters(), lr=CFG.lr)
    series = make_diffusion_series(CFG.diffusion_steps, CFG.beta_start, CFG.beta_end, CFG.device)

    global_step = 0
    print('Training on', CFG.device)
    print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')

    for epoch in range(CFG.epochs):
        model.train()
        pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{CFG.epochs}')
        for x, _ in pbar:
            x = x.to(CFG.device)
            t = torch.randint(0, CFG.diffusion_steps, (x.shape[0],), device=CFG.device, dtype=torch.long)
            noise = torch.randn_like(x)
            x_t = q_sample(x, t, noise, series)

            # simple DDPM objective: predict the injected noise
            loss = F.mse_loss(model(x_t, t), noise)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            global_step += 1
            pbar.set_postfix({'loss': loss.item(), 'step': global_step})

            if global_step % CFG.save_every == 0:
                save_path = out_dir / f'samples_step_{global_step}.png'
                _save_sample_grid(model, series, save_path)
                print(f'Saved samples to {save_path}')
                model.train()

        # end-of-epoch samples and checkpoint
        save_path = out_dir / f'samples_epoch_{epoch+1}.png'
        _save_sample_grid(model, series, save_path)
        print(f'Saved epoch samples to {save_path}')

        ckpt = out_dir / f'ddpm_original_epoch_{epoch+1}.pt'
        torch.save({
            'model': model.state_dict(),
            'opt': opt.state_dict(),
            'epoch': epoch+1,
            'global_step': global_step,  # lets a resumed run continue the save_every cadence
        }, ckpt)
        print(f'Saved checkpoint {ckpt}')
        model.train()

    print('Training finished.')
    return model, series


# if __name__ == '__main__':
#     seed = 42
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)

#     # Quick debug: reduce these to test quickly
#     # CFG.epochs = 2
#     # CFG.batch_size = 32

#     model, series = train()

#     with torch.no_grad():
#         samples = p_sample_loop(model, (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size), series, CFG.device, progress=True)
#         grid = (samples + 1.0) / 2.0
#         save_image(grid, str(Path(CFG.out_dir) / 'final_samples.png'), nrow=4)
#         print('Saved final samples.')



In [None]:
# Guarded so that importing this file (e.g. as ddpm_cifar10_fixed_full_running.py)
# does not start a multi-hour training run. Inside Jupyter, __name__ is
# '__main__', so the cell behaves exactly as before.
if __name__ == '__main__':
    # Seed everything for repeatability while debugging
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Quick hint: if you want to quickly sanity-check, reduce epochs and batch_size
    # CFG.epochs = 2
    # CFG.batch_size = 32

    model, series = train()

    # Generate and save a final grid (p_sample_loop already runs under torch.no_grad()).
    samples = p_sample_loop(model, (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size), series, CFG.device, progress=True)
    grid = (samples + 1.0) / 2.0
    save_image(grid, str(Path(CFG.out_dir) / "final_samples.png"), nrow=4)
    print("Saved final samples.")

Training on cuda
Model parameters: 30,471,427


Epoch 1/100:  13%|█▎        | 401/3125 [01:50<4:29:37,  5.94s/it, loss=0.0454, step=401]

Saved samples to ddpm_original_runs/samples_step_400.png


Epoch 1/100:  26%|██▌       | 800/3125 [03:39<5:26:08,  8.42s/it, loss=0.0343, step=800]

Saved samples to ddpm_original_runs/samples_step_800.png


Epoch 1/100:  38%|███▊      | 1200/3125 [05:29<4:30:57,  8.45s/it, loss=0.0683, step=1200]

Saved samples to ddpm_original_runs/samples_step_1200.png


Epoch 1/100:  51%|█████     | 1600/3125 [07:18<3:33:39,  8.41s/it, loss=0.0501, step=1600]

Saved samples to ddpm_original_runs/samples_step_1600.png


Epoch 1/100:  64%|██████▍   | 2000/3125 [09:07<2:37:41,  8.41s/it, loss=0.0342, step=2000]

Saved samples to ddpm_original_runs/samples_step_2000.png


Epoch 1/100:  77%|███████▋  | 2400/3125 [10:55<1:41:52,  8.43s/it, loss=0.0509, step=2400]

Saved samples to ddpm_original_runs/samples_step_2400.png


Epoch 1/100:  90%|████████▉ | 2800/3125 [12:44<45:36,  8.42s/it, loss=0.0588, step=2800]

Saved samples to ddpm_original_runs/samples_step_2800.png


Epoch 1/100: 100%|██████████| 3125/3125 [13:51<00:00,  3.76it/s, loss=0.0281, step=3125]


Saved epoch samples to ddpm_original_runs/samples_epoch_1.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_1.pt


Epoch 2/100:   2%|▏         | 75/3125 [00:43<7:17:14,  8.60s/it, loss=0.0724, step=3200]

Saved samples to ddpm_original_runs/samples_step_3200.png


Epoch 2/100:  15%|█▌        | 475/3125 [02:32<6:12:40,  8.44s/it, loss=0.0566, step=3600]

Saved samples to ddpm_original_runs/samples_step_3600.png


Epoch 2/100:  28%|██▊       | 875/3125 [04:21<5:17:22,  8.46s/it, loss=0.0315, step=4000]

Saved samples to ddpm_original_runs/samples_step_4000.png


Epoch 2/100:  41%|████      | 1275/3125 [06:10<4:19:27,  8.41s/it, loss=0.0534, step=4400]

Saved samples to ddpm_original_runs/samples_step_4400.png


Epoch 2/100:  54%|█████▎    | 1675/3125 [07:59<3:24:06,  8.45s/it, loss=0.0557, step=4800]

Saved samples to ddpm_original_runs/samples_step_4800.png


Epoch 2/100:  66%|██████▋   | 2075/3125 [09:49<2:28:05,  8.46s/it, loss=0.0462, step=5200]

Saved samples to ddpm_original_runs/samples_step_5200.png


Epoch 2/100:  79%|███████▉  | 2475/3125 [11:38<1:31:30,  8.45s/it, loss=0.0354, step=5600]

Saved samples to ddpm_original_runs/samples_step_5600.png


Epoch 2/100:  92%|█████████▏| 2875/3125 [13:27<35:11,  8.44s/it, loss=0.0569, step=6000]

Saved samples to ddpm_original_runs/samples_step_6000.png


Epoch 2/100: 100%|██████████| 3125/3125 [14:18<00:00,  3.64it/s, loss=0.0851, step=6250]


Saved epoch samples to ddpm_original_runs/samples_epoch_2.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_2.pt


Epoch 3/100:   5%|▍         | 150/3125 [00:58<7:01:13,  8.50s/it, loss=0.0807, step=6400]

Saved samples to ddpm_original_runs/samples_step_6400.png


Epoch 3/100:  18%|█▊        | 551/3125 [02:47<4:15:05,  5.95s/it, loss=0.0693, step=6801]

Saved samples to ddpm_original_runs/samples_step_6800.png


Epoch 3/100:  30%|███       | 951/3125 [04:37<3:37:23,  6.00s/it, loss=0.0698, step=7201]

Saved samples to ddpm_original_runs/samples_step_7200.png


Epoch 3/100:  43%|████▎     | 1350/3125 [06:26<4:10:18,  8.46s/it, loss=0.0238, step=7601]

Saved samples to ddpm_original_runs/samples_step_7600.png


Epoch 3/100:  56%|█████▌    | 1750/3125 [08:15<3:13:51,  8.46s/it, loss=0.0488, step=8000]

Saved samples to ddpm_original_runs/samples_step_8000.png


Epoch 3/100:  69%|██████▉   | 2150/3125 [10:04<2:17:14,  8.45s/it, loss=0.0285, step=8400]

Saved samples to ddpm_original_runs/samples_step_8400.png


Epoch 3/100:  82%|████████▏ | 2550/3125 [11:53<1:20:58,  8.45s/it, loss=0.0796, step=8800]

Saved samples to ddpm_original_runs/samples_step_8800.png


Epoch 3/100:  94%|█████████▍| 2950/3125 [13:43<24:37,  8.44s/it, loss=0.0456, step=9200]

Saved samples to ddpm_original_runs/samples_step_9200.png


Epoch 3/100: 100%|██████████| 3125/3125 [14:18<00:00,  3.64it/s, loss=0.0389, step=9375]


Saved epoch samples to ddpm_original_runs/samples_epoch_3.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_3.pt


Epoch 4/100:   7%|▋         | 225/3125 [01:13<6:44:05,  8.36s/it, loss=0.045, step=9601] 

Saved samples to ddpm_original_runs/samples_step_9600.png


Epoch 4/100:  20%|██        | 625/3125 [03:02<5:49:50,  8.40s/it, loss=0.027, step=1e+4]

Saved samples to ddpm_original_runs/samples_step_10000.png


Epoch 4/100:  33%|███▎      | 1025/3125 [04:51<4:55:21,  8.44s/it, loss=0.0394, step=10400]

Saved samples to ddpm_original_runs/samples_step_10400.png


Epoch 4/100:  46%|████▌     | 1425/3125 [06:40<3:59:36,  8.46s/it, loss=0.0531, step=10800]

Saved samples to ddpm_original_runs/samples_step_10800.png


Epoch 4/100:  58%|█████▊    | 1825/3125 [08:30<3:03:09,  8.45s/it, loss=0.0513, step=11200]

Saved samples to ddpm_original_runs/samples_step_11200.png


Epoch 4/100:  71%|███████   | 2225/3125 [10:19<2:06:10,  8.41s/it, loss=0.0394, step=11601]

Saved samples to ddpm_original_runs/samples_step_11600.png


Epoch 4/100:  84%|████████▍ | 2625/3125 [12:08<1:10:33,  8.47s/it, loss=0.0354, step=12000]

Saved samples to ddpm_original_runs/samples_step_12000.png


Epoch 4/100:  97%|█████████▋| 3025/3125 [13:57<14:01,  8.41s/it, loss=0.0523, step=12400]

Saved samples to ddpm_original_runs/samples_step_12400.png


Epoch 4/100: 100%|██████████| 3125/3125 [14:17<00:00,  3.64it/s, loss=0.0377, step=12500]


Saved epoch samples to ddpm_original_runs/samples_epoch_4.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_4.pt


Epoch 5/100:  10%|▉         | 300/3125 [01:28<6:36:33,  8.42s/it, loss=0.033, step=12800]

Saved samples to ddpm_original_runs/samples_step_12800.png


Epoch 5/100:  22%|██▏       | 700/3125 [03:18<5:41:17,  8.44s/it, loss=0.0166, step=13200]

Saved samples to ddpm_original_runs/samples_step_13200.png


Epoch 5/100:  35%|███▌      | 1100/3125 [05:07<4:45:47,  8.47s/it, loss=0.0572, step=13600]

Saved samples to ddpm_original_runs/samples_step_13600.png


Epoch 5/100:  48%|████▊     | 1500/3125 [06:56<3:47:20,  8.39s/it, loss=0.0561, step=14000]

Saved samples to ddpm_original_runs/samples_step_14000.png


Epoch 5/100:  61%|██████    | 1900/3125 [08:44<2:52:07,  8.43s/it, loss=0.138, step=14400]

Saved samples to ddpm_original_runs/samples_step_14400.png


Epoch 5/100:  74%|███████▎  | 2300/3125 [10:33<1:55:56,  8.43s/it, loss=0.0236, step=14800]

Saved samples to ddpm_original_runs/samples_step_14800.png


Epoch 5/100:  86%|████████▋ | 2700/3125 [12:23<59:55,  8.46s/it, loss=0.0245, step=15200]

Saved samples to ddpm_original_runs/samples_step_15200.png


Epoch 5/100:  99%|█████████▉| 3100/3125 [14:12<03:31,  8.45s/it, loss=0.0411, step=15600]

Saved samples to ddpm_original_runs/samples_step_15600.png


Epoch 5/100: 100%|██████████| 3125/3125 [14:17<00:00,  3.64it/s, loss=0.0352, step=15625]


Saved epoch samples to ddpm_original_runs/samples_epoch_5.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_5.pt


Epoch 6/100:  12%|█▏        | 375/3125 [01:44<6:28:05,  8.47s/it, loss=0.0403, step=16000]

Saved samples to ddpm_original_runs/samples_step_16000.png


Epoch 6/100:  25%|██▍       | 775/3125 [03:33<5:30:55,  8.45s/it, loss=0.0616, step=16400]

Saved samples to ddpm_original_runs/samples_step_16400.png


Epoch 6/100:  38%|███▊      | 1175/3125 [05:22<4:33:36,  8.42s/it, loss=0.0561, step=16800]

Saved samples to ddpm_original_runs/samples_step_16800.png


Epoch 6/100:  50%|█████     | 1575/3125 [07:11<3:37:07,  8.40s/it, loss=0.0544, step=17200]

Saved samples to ddpm_original_runs/samples_step_17200.png


Epoch 6/100:  63%|██████▎   | 1975/3125 [09:01<2:41:57,  8.45s/it, loss=0.0436, step=17600]

Saved samples to ddpm_original_runs/samples_step_17600.png


Epoch 6/100:  76%|███████▌  | 2375/3125 [10:50<1:45:48,  8.46s/it, loss=0.0588, step=18000]

Saved samples to ddpm_original_runs/samples_step_18000.png


Epoch 6/100:  89%|████████▉ | 2775/3125 [12:39<49:06,  8.42s/it, loss=0.0389, step=18400]

Saved samples to ddpm_original_runs/samples_step_18400.png


Epoch 6/100: 100%|██████████| 3125/3125 [13:51<00:00,  3.76it/s, loss=0.092, step=18750]


Saved epoch samples to ddpm_original_runs/samples_epoch_6.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_6.pt


Epoch 7/100:   2%|▏         | 50/3125 [00:38<7:20:56,  8.60s/it, loss=0.0322, step=18800]

Saved samples to ddpm_original_runs/samples_step_18800.png


Epoch 7/100:  14%|█▍        | 450/3125 [02:27<6:16:36,  8.45s/it, loss=0.0416, step=19200]

Saved samples to ddpm_original_runs/samples_step_19200.png


Epoch 7/100:  27%|██▋       | 850/3125 [04:16<5:20:48,  8.46s/it, loss=0.0467, step=19601]

Saved samples to ddpm_original_runs/samples_step_19600.png


Epoch 7/100:  40%|████      | 1250/3125 [06:05<4:24:01,  8.45s/it, loss=0.0298, step=2e+4]

Saved samples to ddpm_original_runs/samples_step_20000.png


Epoch 7/100:  53%|█████▎    | 1650/3125 [07:54<3:26:53,  8.42s/it, loss=0.0375, step=20400]

Saved samples to ddpm_original_runs/samples_step_20400.png


Epoch 7/100:  66%|██████▌   | 2050/3125 [09:43<2:30:41,  8.41s/it, loss=0.0444, step=20800]

Saved samples to ddpm_original_runs/samples_step_20800.png


Epoch 7/100:  78%|███████▊  | 2450/3125 [11:33<1:35:07,  8.46s/it, loss=0.0769, step=21200]

Saved samples to ddpm_original_runs/samples_step_21200.png


Epoch 7/100:  91%|█████████ | 2850/3125 [13:22<38:54,  8.49s/it, loss=0.06, step=21600]

Saved samples to ddpm_original_runs/samples_step_21600.png


Epoch 7/100: 100%|██████████| 3125/3125 [14:19<00:00,  3.64it/s, loss=0.0785, step=21875]


Saved epoch samples to ddpm_original_runs/samples_epoch_7.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_7.pt


Epoch 8/100:   4%|▍         | 125/3125 [00:53<7:04:23,  8.49s/it, loss=0.113, step=22000]

Saved samples to ddpm_original_runs/samples_step_22000.png


Epoch 8/100:  17%|█▋        | 525/3125 [02:42<6:04:23,  8.41s/it, loss=0.0297, step=22400]

Saved samples to ddpm_original_runs/samples_step_22400.png


Epoch 8/100:  30%|██▉       | 925/3125 [04:31<5:09:13,  8.43s/it, loss=0.0307, step=22800]

Saved samples to ddpm_original_runs/samples_step_22800.png


Epoch 8/100:  42%|████▏     | 1325/3125 [06:20<4:14:08,  8.47s/it, loss=0.0423, step=23200]

Saved samples to ddpm_original_runs/samples_step_23200.png


Epoch 8/100:  55%|█████▌    | 1725/3125 [08:10<3:17:41,  8.47s/it, loss=0.0409, step=23600]

Saved samples to ddpm_original_runs/samples_step_23600.png


Epoch 8/100:  68%|██████▊   | 2125/3125 [09:59<2:20:19,  8.42s/it, loss=0.0594, step=24000]

Saved samples to ddpm_original_runs/samples_step_24000.png


Epoch 8/100:  81%|████████  | 2525/3125 [11:48<1:24:36,  8.46s/it, loss=0.0421, step=24400]

Saved samples to ddpm_original_runs/samples_step_24400.png


Epoch 8/100:  94%|█████████▎| 2925/3125 [13:38<28:14,  8.47s/it, loss=0.0703, step=24800]

Saved samples to ddpm_original_runs/samples_step_24800.png


Epoch 8/100: 100%|██████████| 3125/3125 [14:19<00:00,  3.64it/s, loss=0.0287, step=25000]


Saved epoch samples to ddpm_original_runs/samples_epoch_8.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_8.pt


Epoch 9/100:   6%|▋         | 200/3125 [01:08<6:49:52,  8.41s/it, loss=0.0462, step=25200]

Saved samples to ddpm_original_runs/samples_step_25200.png


Epoch 9/100:  19%|█▉        | 600/3125 [02:58<5:54:25,  8.42s/it, loss=0.0466, step=25600]

Saved samples to ddpm_original_runs/samples_step_25600.png


Epoch 9/100:  32%|███▏      | 1000/3125 [04:47<5:00:02,  8.47s/it, loss=0.0468, step=26000]

Saved samples to ddpm_original_runs/samples_step_26000.png


Epoch 9/100:  45%|████▍     | 1400/3125 [06:36<4:02:51,  8.45s/it, loss=0.0717, step=26400]

Saved samples to ddpm_original_runs/samples_step_26400.png


Epoch 9/100:  58%|█████▊    | 1800/3125 [08:26<3:06:24,  8.44s/it, loss=0.0423, step=26801]

Saved samples to ddpm_original_runs/samples_step_26800.png


Epoch 9/100:  70%|███████   | 2200/3125 [10:15<2:10:39,  8.47s/it, loss=0.0527, step=27200]

Saved samples to ddpm_original_runs/samples_step_27200.png


Epoch 9/100:  83%|████████▎ | 2600/3125 [12:05<1:14:15,  8.49s/it, loss=0.0425, step=27600]

Saved samples to ddpm_original_runs/samples_step_27600.png


Epoch 9/100:  96%|█████████▌| 3000/3125 [13:54<17:35,  8.44s/it, loss=0.0381, step=28000]

Saved samples to ddpm_original_runs/samples_step_28000.png


Epoch 9/100: 100%|██████████| 3125/3125 [14:20<00:00,  3.63it/s, loss=0.0429, step=28125]


Saved epoch samples to ddpm_original_runs/samples_epoch_9.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_9.pt


Epoch 10/100:   9%|▉         | 275/3125 [01:23<6:35:48,  8.33s/it, loss=0.0737, step=28400]

Saved samples to ddpm_original_runs/samples_step_28400.png


Epoch 10/100:  22%|██▏       | 675/3125 [03:13<5:43:41,  8.42s/it, loss=0.0495, step=28800]

Saved samples to ddpm_original_runs/samples_step_28800.png


Epoch 10/100:  34%|███▍      | 1075/3125 [05:02<4:50:13,  8.49s/it, loss=0.044, step=29200]

Saved samples to ddpm_original_runs/samples_step_29200.png


Epoch 10/100:  47%|████▋     | 1475/3125 [06:52<3:51:24,  8.41s/it, loss=0.0657, step=29600]

Saved samples to ddpm_original_runs/samples_step_29600.png


Epoch 10/100:  60%|██████    | 1875/3125 [08:41<2:56:57,  8.49s/it, loss=0.0375, step=3e+4]

Saved samples to ddpm_original_runs/samples_step_30000.png


Epoch 10/100:  73%|███████▎  | 2275/3125 [10:31<1:59:27,  8.43s/it, loss=0.0525, step=30400]

Saved samples to ddpm_original_runs/samples_step_30400.png


Epoch 10/100:  86%|████████▌ | 2675/3125 [12:20<1:03:42,  8.49s/it, loss=0.0616, step=30800]

Saved samples to ddpm_original_runs/samples_step_30800.png


Epoch 10/100:  98%|█████████▊| 3075/3125 [14:09<07:03,  8.46s/it, loss=0.0639, step=31200]

Saved samples to ddpm_original_runs/samples_step_31200.png


Epoch 10/100: 100%|██████████| 3125/3125 [14:20<00:00,  3.63it/s, loss=0.0375, step=31250]


Saved epoch samples to ddpm_original_runs/samples_epoch_10.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_10.pt


Epoch 11/100:  11%|█         | 350/3125 [01:39<6:28:34,  8.40s/it, loss=0.029, step=31600]

Saved samples to ddpm_original_runs/samples_step_31600.png


Epoch 11/100:  24%|██▍       | 750/3125 [03:29<5:34:57,  8.46s/it, loss=0.0568, step=32000]

Saved samples to ddpm_original_runs/samples_step_32000.png


Epoch 11/100:  37%|███▋      | 1150/3125 [05:18<4:39:27,  8.49s/it, loss=0.0465, step=32400]

Saved samples to ddpm_original_runs/samples_step_32400.png


Epoch 11/100:  50%|████▉     | 1550/3125 [07:08<3:42:40,  8.48s/it, loss=0.0315, step=32800]

Saved samples to ddpm_original_runs/samples_step_32800.png


Epoch 11/100:  62%|██████▏   | 1950/3125 [08:58<2:46:04,  8.48s/it, loss=0.0274, step=33200]

Saved samples to ddpm_original_runs/samples_step_33200.png


Epoch 11/100:  75%|███████▌  | 2350/3125 [10:47<1:48:56,  8.43s/it, loss=0.0472, step=33600]

Saved samples to ddpm_original_runs/samples_step_33600.png


Epoch 11/100:  88%|████████▊ | 2750/3125 [12:36<52:49,  8.45s/it, loss=0.0488, step=34000]

Saved samples to ddpm_original_runs/samples_step_34000.png


Epoch 11/100: 100%|██████████| 3125/3125 [13:53<00:00,  3.75it/s, loss=0.0544, step=34375]


Saved epoch samples to ddpm_original_runs/samples_epoch_11.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_11.pt


Epoch 12/100:   1%|          | 25/3125 [00:33<7:22:31,  8.56s/it, loss=0.0351, step=34400]

Saved samples to ddpm_original_runs/samples_step_34400.png


Epoch 12/100:  14%|█▎        | 425/3125 [02:22<6:22:34,  8.50s/it, loss=0.069, step=34800]

Saved samples to ddpm_original_runs/samples_step_34800.png


Epoch 12/100:  26%|██▋       | 825/3125 [04:11<5:22:55,  8.42s/it, loss=0.0526, step=35201]

Saved samples to ddpm_original_runs/samples_step_35200.png


Epoch 12/100:  39%|███▉      | 1225/3125 [06:01<4:27:57,  8.46s/it, loss=0.102, step=35600]

Saved samples to ddpm_original_runs/samples_step_35600.png


Epoch 12/100:  52%|█████▏    | 1625/3125 [07:50<3:31:43,  8.47s/it, loss=0.0317, step=36000]

Saved samples to ddpm_original_runs/samples_step_36000.png


Epoch 12/100:  65%|██████▍   | 2025/3125 [09:40<2:34:43,  8.44s/it, loss=0.0609, step=36400]

Saved samples to ddpm_original_runs/samples_step_36400.png


Epoch 12/100:  78%|███████▊  | 2425/3125 [11:29<1:38:23,  8.43s/it, loss=0.0652, step=36800]

Saved samples to ddpm_original_runs/samples_step_36800.png


Epoch 12/100:  90%|█████████ | 2825/3125 [13:19<42:22,  8.47s/it, loss=0.0582, step=37200]

Saved samples to ddpm_original_runs/samples_step_37200.png


Epoch 12/100: 100%|██████████| 3125/3125 [14:20<00:00,  3.63it/s, loss=0.025, step=37500]


Saved epoch samples to ddpm_original_runs/samples_epoch_12.png
Saved checkpoint ddpm_original_runs/ddpm_original_epoch_12.pt


Epoch 13/100:   3%|▎         | 100/3125 [00:47<7:04:35,  8.42s/it, loss=0.0654, step=37600]

Saved samples to ddpm_original_runs/samples_step_37600.png


Epoch 13/100:  16%|█▌        | 500/3125 [02:37<6:10:09,  8.46s/it, loss=0.0555, step=38000]

Saved samples to ddpm_original_runs/samples_step_38000.png


Epoch 13/100:  23%|██▎       | 718/3125 [03:22<08:06,  4.94it/s, loss=0.071, step=38218]