- import necessary libraries

In [1]:
!pip install torch-fidelity


Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->torch-fidelity)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->torch-fidelity)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->torch-fidelity)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->torch-fidelity)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->torch-fidelity)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->torch-fidel

In [1]:
import os
import math
from pathlib import Path
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
from tqdm import tqdm
import random
import torchmetrics
import torch_fidelity


- FID metric

In [2]:

try:
    from torchmetrics.image.fid import FrechetInceptionDistance
except Exception:
    FrechetInceptionDistance = None

- Configuration

In [3]:
class CFG:
    image_size = 32
    channels = 3
    batch_size = 32
    lr = 2e-4
    epochs = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    diffusion_steps = 400
    beta_start = 1e-4
    beta_end = 0.02

    base_channels = 128
    channel_mult = (1, 2, 2, 2)  # -> [128,256,256,256]
    attn_resolutions = (16,)     # attention at 16x16
    num_res_blocks = 2
    dropout = 0.1

    out_dir = "./ddpm_original_runs_1"
    save_every = 400
    sample_batch = 16

- Beta schedule & helpers

In [6]:
def linear_beta_schedule(timesteps, beta_start, beta_end):
    return torch.linspace(beta_start, beta_end, timesteps)


def make_diffusion_series(T, beta_start, beta_end, device):
    betas = linear_beta_schedule(T, beta_start, beta_end).to(device)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=device), alphas_cumprod[:-1]], dim=0)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    return {
        "betas": betas,
        "alphas": alphas,
        "alphas_cumprod": alphas_cumprod,
        "sqrt_alphas_cumprod": sqrt_alphas_cumprod,
        "sqrt_one_minus_alphas_cumprod": sqrt_one_minus_alphas_cumprod,
        "posterior_variance": posterior_variance
    }

- Sinusoidal time embedding

In [7]:
def sinusoidal_positional_embedding(timesteps: torch.Tensor, dim: int):
    assert len(timesteps.shape) == 1
    half = dim // 2
    freqs = torch.exp(- math.log(10000) * torch.arange(half, dtype=torch.float32, device=timesteps.device) / (half - 1))
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb

- ResNet block

In [8]:
class ResnetBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, dropout):
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)

        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.SiLU()
        self.res_conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

        # Lazy GroupNorms (created on first forward pass to match actual channels)
        self.norm1 = None
        self.norm2 = None

    def forward(self, x, t_emb=None):
        if self.norm1 is None or self.norm1.num_channels != x.size(1):
            self.norm1 = nn.GroupNorm(8, x.size(1)).to(x.device)
        if self.norm2 is None or self.norm2.num_channels != self.conv1.out_channels:
            self.norm2 = nn.GroupNorm(8, self.conv1.out_channels).to(x.device)

        h = self.norm1(x)
        h = self.act(h)
        h = self.conv1(h)

        if t_emb is not None:
            h = h + self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)

        h = self.norm2(h)
        h = self.act(h)
        h = self.dropout(h)
        h = self.conv2(h)

        return h + self.res_conv(x)


- Attention block

In [9]:
class AttentionBlock(nn.Module):
    def __init__(self, ch, num_heads=4):
        super().__init__()
        assert ch % num_heads == 0
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(8, ch)
        self.q = nn.Conv2d(ch, ch, 1)
        self.k = nn.Conv2d(ch, ch, 1)
        self.v = nn.Conv2d(ch, ch, 1)
        self.proj_out = nn.Conv2d(ch, ch, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        q = self.q(h).view(B, self.num_heads, C // self.num_heads, H * W)
        k = self.k(h).view(B, self.num_heads, C // self.num_heads, H * W)
        v = self.v(h).view(B, self.num_heads, C // self.num_heads, H * W)
        scale = 1.0 / math.sqrt(C // self.num_heads)
        attn = torch.einsum('bhdn,bhdm->bhnm', q, k) * scale
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum('bhnm,bhdm->bhdn', attn, v)
        out = out.contiguous().view(B, C, H, W)
        out = self.proj_out(out)
        return x + out

- UNet model architecture

In [10]:
class OriginalDDPMUNet(nn.Module):
    def __init__(self, in_ch=3, base_ch=128, channel_mult=(1,2,2,2),
                 attn_resolutions=(16,), num_res_blocks=2, dropout=0.1, time_emb_dim=256):
        super().__init__()
        self.in_ch = in_ch
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )

        self.init_conv = nn.Conv2d(in_ch, base_ch, kernel_size=3, padding=1)

        # compute channel list per level
        chs = [base_ch * m for m in channel_mult]  # e.g. [128,256,256,256]
        in_out = list(zip([base_ch] + chs[:-1], chs))

        # Down path (we record skip channel sizes)
        self.down_blocks = nn.ModuleList()
        self.skip_channels = []
        curr_res = CFG.image_size
        for i, (lvl_in, lvl_out) in enumerate(in_out):
            block_layers = nn.ModuleList()
            attn_layers = nn.ModuleList()
            for j in range(num_res_blocks):
                in_ch_block = lvl_in if j == 0 else lvl_out
                block_layers.append(ResnetBlock(in_ch_block, lvl_out, time_emb_dim, dropout))
            if curr_res in attn_resolutions:
                attn_layers.append(AttentionBlock(lvl_out))
            down_sample = nn.Conv2d(lvl_out, lvl_out, kernel_size=3, stride=2, padding=1) if i < len(in_out)-1 else nn.Identity()
            self.down_blocks.append(nn.ModuleList([block_layers, attn_layers, down_sample]))
            self.skip_channels.append(lvl_out)
            curr_res //= 2

        # Middle
        mid_ch = chs[-1]
        self.mid_block1 = ResnetBlock(mid_ch, mid_ch, time_emb_dim, dropout)
        self.mid_attn = AttentionBlock(mid_ch)
        self.mid_block2 = ResnetBlock(mid_ch, mid_ch, time_emb_dim, dropout)

        # Up path: build using recorded skip_channels in reverse
        self.up_blocks = nn.ModuleList()
        curr_ch = mid_ch
        # spatial resolution at the bottleneck
        curr_res = max(1, CFG.image_size // (2 ** (len(chs)-1)))

        # Use reversed skip_channels so we pop them in the same order in forward
        rev_skip_chs = list(reversed(self.skip_channels))  # e.g. [lvl3, lvl2, lvl1, lvl0]
        for i, skip_ch in enumerate(rev_skip_chs):
            block_layers = nn.ModuleList()
            attn_layers = nn.ModuleList()

            # first block consumes concatenated channels: curr_ch + skip_ch -> skip_ch
            block_layers.append(ResnetBlock(curr_ch + skip_ch, skip_ch, time_emb_dim, dropout))
            # subsequent blocks (if any) are skip_ch -> skip_ch
            for _ in range(num_res_blocks):
                block_layers.append(ResnetBlock(skip_ch, skip_ch, time_emb_dim, dropout))

            if curr_res in attn_resolutions:
                attn_layers.append(AttentionBlock(skip_ch))

            # upsample that preserves channel count of "h" (curr_ch)
            # For the last iteration (when we've reached the highest resolution), we DON'T upsample
            if i < len(rev_skip_chs) - 1:
                up_sample = nn.ConvTranspose2d(curr_ch, curr_ch, kernel_size=4, stride=2, padding=1)
            else:
                up_sample = nn.Identity()

            self.up_blocks.append(nn.ModuleList([block_layers, attn_layers, up_sample]))

            # after this level, h will have channels = skip_ch
            curr_ch = skip_ch
            curr_res *= 2

        # final layers (GroupNorm -> SiLU -> Conv)
        # After the final up-block the number of channels equals `base_ch`,
        # so final_norm should use base_ch groups/channels and final_conv maps
        # back to the input image channels (e.g. 3 for RGB).
        self.final_norm = nn.GroupNorm(8, base_ch)
        self.final_act = nn.SiLU()
        self.final_conv = nn.Conv2d(base_ch, in_ch, kernel_size=3, padding=1)

    def forward(self, x, t):
        t_emb = sinusoidal_positional_embedding(t, self.time_mlp[0].in_features)
        t_emb = self.time_mlp(t_emb)

        h = self.init_conv(x)
        skips: List[torch.Tensor] = []

        for block_layers, attn_layers, down_sample in self.down_blocks:
            for block in block_layers:
                h = block(h, t_emb)
            for attn in attn_layers:
                h = attn(h)
            skips.append(h)
            h = down_sample(h)

        h = self.mid_block1(h, t_emb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, t_emb)

        # up path: pop skips in reverse order
        for block_layers, attn_layers, up_sample in self.up_blocks:
            h = up_sample(h)
            if len(skips) == 0:
                raise RuntimeError("Skip stack empty — mismatch between down and up blocks.")
            skip = skips.pop()
            if h.shape[2:] != skip.shape[2:]:
                skip = F.interpolate(skip, size=h.shape[2:], mode='nearest')
            h = torch.cat([h, skip], dim=1)  # concatenated channels match first ResnetBlock constructor
            for block in block_layers:
                h = block(h, t_emb)
            for attn in attn_layers:
                h = attn(h)

        h = self.final_norm(h)
        h = self.final_act(h)
        out = self.final_conv(h)
        return out

- Dataloader ( I have not used any data augmentation which has been used in the paper e.g. horizontal flipping )

In [11]:
def get_dataloader(batch_size, image_size, train=True):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])
    ds = torchvision.datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=train, num_workers=2, pin_memory=True)
    return loader

- q_sample & p_sample_loop

In [12]:
def q_sample(x0: torch.Tensor, t: torch.LongTensor, noise: torch.Tensor, series):
    sqrt_ac = series['sqrt_alphas_cumprod'][t].view(-1,1,1,1)
    sqrt_om = series['sqrt_one_minus_alphas_cumprod'][t].view(-1,1,1,1)
    return sqrt_ac * x0 + sqrt_om * noise


@torch.no_grad()
def p_sample_loop(model: nn.Module, shape: Tuple[int,int,int,int], series, device, progress=False):
    model.eval()
    B = shape[0]
    x = torch.randn(shape, device=device)
    T = series['betas'].shape[0]
    rng = range(T-1, -1, -1)
    if progress:
        rng = tqdm(rng, desc='sampling')
    for t in rng:
        t_tensor = torch.full((B,), t, dtype=torch.long, device=device)
        eps_pred = model(x, t_tensor)

        beta_t = series['betas'][t]
        alpha_t = series['alphas'][t]
        alpha_cumprod_t = series['alphas_cumprod'][t]

        x0_pred = (x - torch.sqrt(1 - alpha_cumprod_t) * eps_pred) / torch.sqrt(alpha_cumprod_t)

        if t > 0:
            posterior_var = series['posterior_variance'][t]
            mean = ((beta_t * torch.sqrt(series['alphas_cumprod'][t-1]) / (1.0 - alpha_cumprod_t)) * x0_pred
                    + ((1.0 - series['alphas_cumprod'][t-1]) * torch.sqrt(alpha_t) / (1.0 - alpha_cumprod_t)) * x)
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(posterior_var) * noise
        else:
            mean = ((beta_t * torch.sqrt(series['alphas_cumprod'][t-1]) / (1.0 - alpha_cumprod_t)) * x0_pred
                    + ((1.0 - series['alphas_cumprod'][t-1]) * torch.sqrt(alpha_t) / (1.0 - alpha_cumprod_t)) * x)
            x = mean

    x = torch.clamp(x, -1.0, 1.0)
    return x


- FID 

In [13]:
def evaluate_fid(model: nn.Module, series, device, num_gen=5000, batch_size=128):
    if FrechetInceptionDistance is None:
        print('torchmetrics FID not available; skipping FID.')
        return None
    print(f'Computing FID with {num_gen} generated images (batch_size {batch_size})...')
    fid = FrechetInceptionDistance(feature=2048).to(device)
    real_loader = get_dataloader(batch_size=batch_size, image_size=CFG.image_size, train=True)
    real_count = 0
    for x_real, _ in tqdm(real_loader, desc='Updating FID with real images'):
        x_real = x_real.to(device)
        imgs_uint8 = ((x_real.clamp(-1,1)+1.0)/2.0*255.0).to(torch.uint8)
        fid.update(imgs_uint8, real=True)
        real_count += imgs_uint8.shape[0]
        if real_count >= num_gen:
            break
    gen_count = 0
    gen_bs = min(batch_size, 64)
    while gen_count < num_gen:
        to_gen = min(gen_bs, num_gen - gen_count)
        samples = p_sample_loop(model, (to_gen, CFG.channels, CFG.image_size, CFG.image_size), series, device, progress=False)
        imgs_uint8 = ((samples.clamp(-1,1)+1.0)/2.0*255.0).to(torch.uint8)
        fid.update(imgs_uint8, real=False)
        gen_count += imgs_uint8.shape[0]
        print(f'Generated {gen_count}/{num_gen} for FID', end='')
    result = fid.compute().item()
    print(f'FID: {result:.4f}')
    return result

- Training loop ( There is two training loops 1st i have used for saving the models after each epoch and the 2nd only after 10 epochs and generate the images(and compute the FID))

In [None]:
# def train():
#     os.makedirs(CFG.out_dir, exist_ok=True)
#     loader = get_dataloader(CFG.batch_size, CFG.image_size, train=True)

#     model = OriginalDDPMUNet(
#         in_ch=CFG.channels,
#         base_ch=CFG.base_channels,
#         channel_mult=CFG.channel_mult,
#         attn_resolutions=CFG.attn_resolutions,
#         num_res_blocks=CFG.num_res_blocks,
#         dropout=CFG.dropout
#     ).to(CFG.device)

#     opt = optim.Adam(model.parameters(), lr=CFG.lr)
#     series = make_diffusion_series(CFG.diffusion_steps, CFG.beta_start, CFG.beta_end, CFG.device)

#     global_step = 0
#     print('Training on', CFG.device)
#     print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')

#     for epoch in range(CFG.epochs):
#         model.train()
#         pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{CFG.epochs}')
#         for x, _ in pbar:
#             x = x.to(CFG.device)
#             b = x.shape[0]
#             t = torch.randint(0, CFG.diffusion_steps, (b,), device=CFG.device, dtype=torch.long)
#             noise = torch.randn_like(x)
#             x_t = q_sample(x, t, noise, series)

#             eps_pred = model(x_t, t)
#             loss = F.mse_loss(eps_pred, noise)

#             opt.zero_grad()
#             loss.backward()
#             opt.step()

#             global_step += 1
#             pbar.set_postfix({'loss': float(loss.item()), 'step': global_step})

#             if global_step % CFG.save_every == 0:
#                 model.eval()
#                 with torch.no_grad():
#                     samples = p_sample_loop(model, (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size), series, CFG.device, progress=False)
#                     grid = (samples + 1.0) / 2.0
#                     save_path = Path(CFG.out_dir) / f'samples_step_{global_step}.png'
#                     save_image(grid, str(save_path), nrow=4)
#                     print(f'Saved samples to {save_path}')
#                 model.train()

#         # epoch checkpoint & sample
#         model.eval()
#         with torch.no_grad():
#             samples = p_sample_loop(model, (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size), series, CFG.device, progress=False)
#             grid = (samples + 1.0) / 2.0
#             save_path = Path(CFG.out_dir) / f'samples_epoch_{epoch+1}.png'
#             save_image(grid, str(save_path), nrow=4)
#             print(f'Saved epoch samples to {save_path}')

#         ckpt = Path(CFG.out_dir) / f'ddpm_original_epoch_{epoch+1}.pt'
#         torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'epoch': epoch+1}, ckpt)
#         print(f'Saved checkpoint {ckpt}')
#         model.train()

#     print('Training finished.')
#     return model, series

In [14]:
def train():
    os.makedirs(CFG.out_dir, exist_ok=True)
    loader = get_dataloader(CFG.batch_size, CFG.image_size, train=True)

    model = OriginalDDPMUNet(
        in_ch=CFG.channels,
        base_ch=CFG.base_channels,
        channel_mult=CFG.channel_mult,
        attn_resolutions=CFG.attn_resolutions,
        num_res_blocks=CFG.num_res_blocks,
        dropout=CFG.dropout
    ).to(CFG.device)

    opt = optim.Adam(model.parameters(), lr=CFG.lr)
    series = make_diffusion_series(CFG.diffusion_steps, CFG.beta_start, CFG.beta_end, CFG.device)

    global_step = 0
    print('Training on', CFG.device)
    print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')

    for epoch in range(CFG.epochs):
        model.train()
        pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{CFG.epochs}')
        for x, _ in pbar:
            x = x.to(CFG.device)
            b = x.shape[0]
            t = torch.randint(0, CFG.diffusion_steps, (b,), device=CFG.device, dtype=torch.long)
            noise = torch.randn_like(x)
            x_t = q_sample(x, t, noise, series)

            eps_pred = model(x_t, t)
            loss = F.mse_loss(eps_pred, noise)

            opt.zero_grad()
            loss.backward()
            opt.step()

            global_step += 1
            pbar.set_postfix({'loss': float(loss.item()), 'step': global_step})

        # -------------------------
        # Save & evaluate every 10 epochs
        # -------------------------
        if (epoch + 1) % 20 == 0:
            model.eval()
            with torch.no_grad():
                # Generate samples
                samples = p_sample_loop(
                    model, (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size),
                    series, CFG.device, progress=False
                )
                grid = (samples + 1.0) / 2.0

                # Compute FID
                fid_score = evaluate_fid(model, series, CFG.device, num_gen=2000, batch_size=128)

                # Save image with FID in filename
                save_path = Path(CFG.out_dir) / f'samples_epoch_{epoch+1}_fid{fid_score:.2f}.png'
                save_image(grid, str(save_path), nrow=4)
                print(f'Saved epoch {epoch+1} samples with FID={fid_score:.2f} -> {save_path}')

            # Save checkpoint
            ckpt = Path(CFG.out_dir) / f'ddpm_original_epoch_{epoch+1}.pt'
            torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'epoch': epoch+1}, ckpt)
            print(f'Saved checkpoint {ckpt}')
            model.train()

    print('Training finished.')
    return model, series


- Seed everything for repeatability

In [None]:
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

model, series = train()

    # Final sampling after training
with torch.no_grad():
    samples = p_sample_loop(
        model,
        (CFG.sample_batch, CFG.channels, CFG.image_size, CFG.image_size),
        series,
        CFG.device,
        progress=True
        )
    grid = (samples + 1.0) / 2.0
    save_image(grid, str(Path(CFG.out_dir) / "final_samples_128.png"), nrow=4)
    print("Saved final samples.")

Training on cuda
Model parameters: 30,471,427


Epoch 1/100: 100%|██████████| 1563/1563 [10:00<00:00,  2.60it/s, loss=0.0427, step=1563]
Epoch 2/100: 100%|██████████| 1563/1563 [10:10<00:00,  2.56it/s, loss=0.122, step=3126] 
Epoch 3/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.026, step=4689] 
Epoch 4/100: 100%|██████████| 1563/1563 [10:11<00:00,  2.56it/s, loss=0.0706, step=6252]
Epoch 5/100: 100%|██████████| 1563/1563 [10:12<00:00,  2.55it/s, loss=0.0397, step=7815]
Epoch 6/100: 100%|██████████| 1563/1563 [10:10<00:00,  2.56it/s, loss=0.0622, step=9378]
Epoch 7/100: 100%|██████████| 1563/1563 [10:11<00:00,  2.56it/s, loss=0.0627, step=10941]
Epoch 8/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0585, step=12504]
Epoch 9/100: 100%|██████████| 1563/1563 [10:12<00:00,  2.55it/s, loss=0.0594, step=14067]
Epoch 10/100: 100%|██████████| 1563/1563 [10:12<00:00,  2.55it/s, loss=0.0429, step=15630]
Epoch 11/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.058, step=17193] 
Epoch 12/100: 

Computing FID with 2000 generated images (batch_size 128)...


Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:00<00:00, 287MB/s] 
Updating FID with real images:   4%|▍         | 15/391 [00:08<03:22,  1.86it/s]


Generated 64/2000 for FIDGenerated 128/2000 for FIDGenerated 192/2000 for FIDGenerated 256/2000 for FIDGenerated 320/2000 for FIDGenerated 384/2000 for FIDGenerated 448/2000 for FIDGenerated 512/2000 for FIDGenerated 576/2000 for FIDGenerated 640/2000 for FIDGenerated 704/2000 for FIDGenerated 768/2000 for FIDGenerated 832/2000 for FIDGenerated 896/2000 for FIDGenerated 960/2000 for FIDGenerated 1024/2000 for FIDGenerated 1088/2000 for FIDGenerated 1152/2000 for FIDGenerated 1216/2000 for FIDGenerated 1280/2000 for FIDGenerated 1344/2000 for FIDGenerated 1408/2000 for FIDGenerated 1472/2000 for FIDGenerated 1536/2000 for FIDGenerated 1600/2000 for FIDGenerated 1664/2000 for FIDGenerated 1728/2000 for FIDGenerated 1792/2000 for FIDGenerated 1856/2000 for FIDGenerated 1920/2000 for FIDGenerated 1984/2000 for FIDGenerated 2000/2000 for FIDFID: 214.4690
Saved epoch 20 samples with FID=214.47 -> ddpm_original_runs_1/samples_epoch_20_fid214.47.png
Saved checkpoint ddpm_original_runs_1/ddpm_o

Epoch 21/100: 100%|██████████| 1563/1563 [10:11<00:00,  2.56it/s, loss=0.0429, step=32823]
Epoch 22/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0731, step=34386]
Epoch 23/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0653, step=35949]
Epoch 24/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0353, step=37512]
Epoch 25/100: 100%|██████████| 1563/1563 [10:15<00:00,  2.54it/s, loss=0.0742, step=39075]
Epoch 26/100: 100%|██████████| 1563/1563 [10:12<00:00,  2.55it/s, loss=0.0279, step=40638]
Epoch 27/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0341, step=42201]
Epoch 28/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0542, step=43764]
Epoch 29/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.0465, step=45327]
Epoch 30/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.0608, step=46890]
Epoch 31/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.0657, step=48453]

Computing FID with 2000 generated images (batch_size 128)...


Updating FID with real images:   4%|▍         | 15/391 [00:08<03:20,  1.87it/s]


Generated 64/2000 for FIDGenerated 128/2000 for FIDGenerated 192/2000 for FIDGenerated 256/2000 for FIDGenerated 320/2000 for FIDGenerated 384/2000 for FIDGenerated 448/2000 for FIDGenerated 512/2000 for FIDGenerated 576/2000 for FIDGenerated 640/2000 for FIDGenerated 704/2000 for FIDGenerated 768/2000 for FIDGenerated 832/2000 for FIDGenerated 896/2000 for FIDGenerated 960/2000 for FIDGenerated 1024/2000 for FIDGenerated 1088/2000 for FIDGenerated 1152/2000 for FIDGenerated 1216/2000 for FIDGenerated 1280/2000 for FIDGenerated 1344/2000 for FIDGenerated 1408/2000 for FIDGenerated 1472/2000 for FIDGenerated 1536/2000 for FIDGenerated 1600/2000 for FIDGenerated 1664/2000 for FIDGenerated 1728/2000 for FIDGenerated 1792/2000 for FIDGenerated 1856/2000 for FIDGenerated 1920/2000 for FIDGenerated 1984/2000 for FIDGenerated 2000/2000 for FIDFID: 240.9081
Saved epoch 40 samples with FID=240.91 -> ddpm_original_runs_1/samples_epoch_40_fid240.91.png
Saved checkpoint ddpm_original_runs_1/ddpm_o

Epoch 41/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0457, step=64083]
Epoch 42/100: 100%|██████████| 1563/1563 [10:13<00:00,  2.55it/s, loss=0.0346, step=65646]
Epoch 43/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.55it/s, loss=0.0268, step=67209]
Epoch 44/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.073, step=68772] 
Epoch 45/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.034, step=70335] 
Epoch 46/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.0449, step=71898]
Epoch 47/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.0596, step=73461]
Epoch 48/100: 100%|██████████| 1563/1563 [10:14<00:00,  2.54it/s, loss=0.0371, step=75024]
Epoch 49/100:  65%|██████▌   | 1021/1563 [06:41<03:33,  2.54it/s, loss=0.0314, step=76046]