In [1]:
# Install required packages
!pip install torch torchvision torchmetrics torch-fidelity numpy matplotlib scikit-image
!pip install pytorch-fid  # For FID calculation
!pip install tqdm
# Install required packages
!pip install -q kornia tqdm einops

Collecting torchmetrics
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)


In [2]:
import math
import copy
import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from functools import partial

from torch.utils import data
from pathlib import Path
from torch.optim import Adam
from torchvision import datasets, transforms, utils

import numpy as np
from tqdm import tqdm
from einops import rearrange

import matplotlib.pyplot as plt
from PIL import Image
import os
import shutil
import torchvision
from torchvision.datasets import MNIST, CIFAR10
import pytorch_fid

from skimage.metrics import structural_similarity as ssim
from pytorch_fid import fid_score

# colab setup
from google.colab import drive
drive.mount('/content/drive')
res_dir = '/content/drive/MyDrive/Project'

# helper funcs
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def cycle(dl):
    while True:
        for data in dl:
            yield data

def num_to_groups(num, divisor):
    grps = num // divisor
    leftover = num % divisor
    arr = [divisor] * grps
    if leftover > 0:
        arr.append(leftover)
    return arr

# small helpers
class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old, new = ma_params.data, current_params.data
            ma_params.data = self.update_average(old, new)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dev = x.device
        half_d = self.dim // 2
        emb = math.log(10000) / (half_d - 1)
        emb = torch.exp(torch.arange(half_d, device=dev) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

class LayerNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class ConvNextBlock(nn.Module):
    # from that paper
    def __init__(self, dim, dim_out, *, time_emb_dim = None, mult = 2, norm = True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(time_emb_dim, dim)
        ) if exists(time_emb_dim) else None

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)

        self.net = nn.Sequential(
            LayerNorm(dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding = 1)
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        h = self.ds_conv(x)

        if exists(self.mlp):
            assert exists(time_emb), 'time emb must be passed in'
            cond = self.mlp(time_emb)
            h = h + rearrange(cond, 'b c -> b c 1 1')

        h = self.net(h)
        return h + self.res_conv(x)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv)
        q = q * self.scale

        k = k.softmax(dim = -1)
        ctx = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', ctx, q)
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        out_dim = None,
        dim_mults=(1, 2, 4, 8),
        channels = 3,
        with_time_emb = True,
        residual = False
    ):
        super().__init__()
        self.channels = channels
        self.residual = residual
        print("Is Time embed used ? ", with_time_emb)

        dims = [channels, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        if with_time_emb:
            time_dim = dim
            self.time_mlp = nn.Sequential(
                SinusoidalPosEmb(dim),
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                nn.Linear(dim * 4, dim)
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, time_emb_dim = time_dim, norm = ind != 0),
                ConvNextBlock(dim_out, dim_out, time_emb_dim = time_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim)))
        self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, time_emb_dim = time_dim),
                ConvNextBlock(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Upsample(dim_in) if not is_last else nn.Identity()
            ]))

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            ConvNextBlock(dim, dim),
            nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        orig_x = x
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for convnext, convnext2, attn, downsample in self.downs:
            x = convnext(x, t)
            x = convnext2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for convnext, convnext2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = convnext(x, t)
            x = convnext2(x, t)
            x = attn(x)
            x = upsample(x)

        if self.residual:
            return self.final_conv(x) + orig_x

        return self.final_conv(x)

# gaussian diffusion trainer class - from the second code
class GaussianDiffusion(nn.Module):
    def __init__(
        self,
        denoise_fn,
        *,
        image_size,
        device_of_kernel,
        channels = 3,
        timesteps = 1000,
        loss_type = 'l1',
        kernel_std = 0.1,
        kernel_size = 3,
        blur_routine = 'Incremental',
        train_routine = 'Final',
        sampling_routine='default',
        discrete=False
    ):
        super().__init__()
        self.channels = channels
        self.image_size = image_size
        self.denoise_fn = denoise_fn
        self.device_of_kernel = device_of_kernel

        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type
        self.kernel_std = kernel_std
        self.kernel_size = kernel_size
        self.blur_routine = blur_routine

        # Create Gaussian kernels for each timestep
        self.gaussian_kernels = nn.ModuleList(self.get_kernels())
        self.train_routine = train_routine
        self.sampling_routine = sampling_routine
        self.discrete = discrete

    def blur(self, dims, std):
        """Creates a Gaussian blur kernel with the given dimensions and standard deviation"""
        kernel = self._create_gaussian_kernel(dims, std)
        return kernel

    def get_kernels(self):
        """Creates all Gaussian blur kernels for all timesteps"""
        kernels = []
        for i in range(self.num_timesteps):
            if self.blur_routine == 'Incremental':
                kstd = self.kernel_std*(i+1)
                kernels.append(self.get_conv(self.kernel_size, kstd))
            elif self.blur_routine == 'Constant':
                # For MNIST: fixed standard deviation
                kernels.append(self.get_conv(self.kernel_size, self.kernel_std))
            elif self.blur_routine == 'Exponential':
                kstd = np.exp(self.kernel_std * i)
                kernels.append(self.get_conv(self.kernel_size, kstd))
            elif self.blur_routine == 'CIFAR':
                # For CIFAR-10: 0.01*t + 0.35
                kstd = 0.01 * i + 0.35
                kernels.append(self.get_conv(self.kernel_size, kstd))
        return kernels

    def get_conv(self, kernel_size, std, mode='circular'):
        """Creates a convolution with Gaussian kernel"""
        kernel = self._create_gaussian_kernel(kernel_size, std)
        conv = nn.Conv2d(in_channels=self.channels, out_channels=self.channels,
                        kernel_size=kernel_size, padding=int((kernel_size-1)/2), padding_mode=mode,
                        bias=False, groups=self.channels)
        with torch.no_grad():
            # Ensure kernel has the right shape before repeating
            if kernel.dim() == 2:
                kernel = kernel.unsqueeze(0).unsqueeze(0)

            # Now repeat for each channel
            kernel = kernel.repeat(self.channels, 1, 1, 1)
            conv.weight = nn.Parameter(kernel)

        return conv

    def _create_gaussian_kernel(self, kernel_size, std):
        """Create a 2D Gaussian kernel manually"""
        x = torch.arange(-(kernel_size // 2), kernel_size // 2 + 1, 1).to(self.device_of_kernel)
        y = torch.arange(-(kernel_size // 2), kernel_size // 2 + 1, 1).to(self.device_of_kernel)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        kernel = torch.exp(-(xx**2 + yy**2) / (2 * std**2))
        kernel = kernel / kernel.sum()
        return kernel

    def q_sample(self, x_start, t):
        """Forward process: progressively blur the image based on timestep t"""
        max_iters = torch.max(t)
        all_blurs = []
        x = x_start
        for i in range(max_iters+1):
            with torch.no_grad():
                x = self.gaussian_kernels[i](x)
                if self.discrete:
                    if i == (self.num_timesteps-1):
                        x = torch.mean(x, [2, 3], keepdim=True)
                        x = x.expand(x_start.shape[0], x_start.shape[1], x_start.shape[2], x_start.shape[3])
                all_blurs.append(x)

        all_blurs = torch.stack(all_blurs)

        # Select the appropriate blur level for each sample in the batch
        choose_blur = []
        for step in range(t.shape[0]):
            if step != -1:
                choose_blur.append(all_blurs[t[step], step])
            else:
                choose_blur.append(x_start[step])

        choose_blur = torch.stack(choose_blur)
        return choose_blur

    def p_losses(self, x_start, t):
        """Calculate training loss"""
        b, c, h, w = x_start.shape
        if self.train_routine == 'Final':
            x_blur = self.q_sample(x_start=x_start, t=t)
            x_recon = self.denoise_fn(x_blur, t)
            if self.loss_type == 'l1':
                loss = (x_start - x_recon).abs().mean()
            elif self.loss_type == 'l2':
                loss = F.mse_loss(x_start, x_recon)
            else:
                raise NotImplementedError()
        return loss

    def forward(self, x, *args, **kwargs):
        """Forward pass during training"""
        b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
        return self.p_losses(x, t, *args, **kwargs)

    @torch.no_grad()
    def sample(self, batch_size=16, img=None, t=None):
        """Sampling algorithm - Algorithm 2 from the paper"""
        self.denoise_fn.eval()

        if t is None:
            t = self.num_timesteps

        # Apply forward process (blur)
        if self.blur_routine == 'Individual_Incremental':
            img = self.gaussian_kernels[t-1](img)
        else:
            for i in range(t):
                with torch.no_grad():
                    img = self.gaussian_kernels[i](img)

        # Store blurred input for later
        xt = img.clone()
        direct_recons = None

        # Iterative deblurring (reverse process)
        while(t):
            step = torch.full((batch_size,), t - 1, dtype=torch.long).to(img.device)
            x0_pred = self.denoise_fn(img, step)  # Predict clean image

            if self.train_routine == 'Final':
                if direct_recons is None:
                    direct_recons = x0_pred

                if self.sampling_routine == 'x0_step_down':
                    # Algorithm 2 from the paper
                    x_times = x0_pred.clone()
                    for i in range(t):
                        with torch.no_grad():
                            x_times = self.gaussian_kernels[i](x_times)
                            if self.discrete:
                                if i == (self.num_timesteps - 1):
                                    x_times = torch.mean(x_times, [2, 3], keepdim=True)
                                    x_times = x_times.expand(x_times.shape[0], x_times.shape[1],
                                                             x_times.shape[2], x_times.shape[3])

                    x_times_sub_1 = x0_pred.clone()
                    for i in range(t - 1):
                        with torch.no_grad():
                            x_times_sub_1 = self.gaussian_kernels[i](x_times_sub_1)

                    # Key step in Algorithm 2
                    img = img - x_times + x_times_sub_1

                elif self.sampling_routine == 'default':
                    # Algorithm 1 - for comparison
                    if self.blur_routine == 'Individual_Incremental':
                        img = self.gaussian_kernels[t - 2](x0_pred) if t > 1 else x0_pred
                    else:
                        img = x0_pred.clone()
                        for i in range(t-1):
                            with torch.no_grad():
                                img = self.gaussian_kernels[i](img)

            t = t - 1

        self.denoise_fn.train()
        return xt, direct_recons, img

class Dataset(data.Dataset):
    def __init__(self, folder, image_size, exts=['jpg', 'jpeg', 'png']):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]

        self.transform = transforms.Compose([
            transforms.Resize((int(image_size*1.12), int(image_size*1.12))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)  # Scale to [-1, 1]
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

def calculate_metrics(orig_imgs, recon_imgs):
    """metrics calc - ssim/rmse/fid"""
    # convert to numpy
    orig_np = (orig_imgs.cpu().permute(0, 2, 3, 1) + 1) / 2
    recon_np = (recon_imgs.cpu().permute(0, 2, 3, 1) + 1) / 2

    # ssim
    ssim_vals = []
    for i in range(orig_np.shape[0]):
        orig = orig_np[i].numpy()
        recon = recon_np[i].numpy()

        if orig.shape[2] == 1:
            orig = orig.squeeze(2)
            recon = recon.squeeze(2)
            ssim_val = ssim(orig, recon, data_range=1.0, win_size=5, channel_axis=None)
        else:
            ssim_val = ssim(orig, recon, data_range=1.0, win_size=5, channel_axis=2)

        ssim_vals.append(ssim_val)

    avg_ssim = np.mean(ssim_vals)

    # rmse
    mse = F.mse_loss(orig_imgs, recon_imgs).item()
    rmse = np.sqrt(mse)

    # fid stuff
    tmp_orig = 'temp_orig'
    tmp_recon = 'temp_recon'
    os.makedirs(tmp_orig, exist_ok=True)
    os.makedirs(tmp_recon, exist_ok=True)

    # save images for fid
    for i in range(min(orig_imgs.shape[0], 100)):
        utils.save_image((orig_imgs[i] + 1) / 2, os.path.join(tmp_orig, f'{i}.png'))
        utils.save_image((recon_imgs[i] + 1) / 2, os.path.join(tmp_recon, f'{i}.png'))

    # calc fid
    fid = fid_score.calculate_fid_given_paths([tmp_orig, tmp_recon],
                                   batch_size=50, device=orig_imgs.device,
                                   dims=2048)
    # cleanup
    shutil.rmtree(tmp_orig)
    shutil.rmtree(tmp_recon)

    return avg_ssim, rmse, fid

class Trainer(object):
    def __init__(
        self,
        diffusion_model,
        dataset_type,
        *,
        ema_decay = 0.995,
        image_size = 32,
        train_batch_size = 32,
        eval_batch_size = 16,
        train_lr = 2e-5,
        train_num_steps = 100000,
        gradient_accumulate_every = 2,
        step_start_ema = 2000,
        update_ema_every = 10,
        save_and_sample_every = 1000,
        results_folder = './results',
        load_path = None,
        eval_dataset = None,
        save_model_every = 10000,
        eval_every = 10000
    ):
        super().__init__()
        self.model = diffusion_model
        self.ema = EMA(ema_decay)
        self.ema_model = copy.deepcopy(self.model)
        self.update_ema_every = update_ema_every

        self.step_start_ema = step_start_ema
        self.save_and_sample_every = save_and_sample_every
        self.save_model_every = save_model_every
        self.eval_every = eval_every

        self.batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.image_size = image_size
        self.gradient_accumulate_every = gradient_accumulate_every
        self.train_num_steps = train_num_steps

        self.dataset_type = dataset_type
        self.setup_dataset(dataset_type)

        self.eval_dataset = eval_dataset
        if eval_dataset is None:
            # use training subset
            eval_size = min(1000, len(self.ds))
            eval_indices = np.random.choice(len(self.ds), eval_size, replace=False)
            self.eval_dataset = torch.utils.data.Subset(self.ds, eval_indices)

        self.eval_dl = data.DataLoader(self.eval_dataset, batch_size=eval_batch_size,
                                    shuffle=False, drop_last=False)

        self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
        self.step = 0

        self.results_folder = Path(results_folder)
        self.results_folder.mkdir(exist_ok=True, parents=True)

        # metrics file
        self.metrics_file = self.results_folder / 'metrics.csv'
        if not self.metrics_file.exists():
            with open(self.metrics_file, 'w') as f:
                f.write('step,ssim,rmse,fid\n')

        self.reset_parameters()

        if load_path is not None:
            self.load(load_path)

    def setup_dataset(self, dataset_type):
        if dataset_type == 'mnist':
            transform = transforms.Compose([
                transforms.Pad(2),
                transforms.ToTensor(),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])

            # get mnist
            self.ds = MNIST(root='./data', train=True, download=True, transform=transform)
            self.dl = cycle(data.DataLoader(self.ds, batch_size=self.batch_size, shuffle=True,
                                        pin_memory=True, num_workers=8, drop_last=True))

        elif dataset_type == 'cifar10':
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Lambda(lambda t: (t * 2) - 1)
            ])

            # get cifar
            self.ds = CIFAR10(root='./data', train=True, download=True, transform=transform)
            self.dl = cycle(data.DataLoader(self.ds, batch_size=self.batch_size, shuffle=True,
                                        pin_memory=True, num_workers=8, drop_last=True))

        else:
            raise ValueError(f"Unknown dataset type: {dataset_type}")

    def reset_parameters(self):
        self.ema_model.load_state_dict(self.model.state_dict())

    def step_ema(self):
        if self.step < self.step_start_ema:
            self.reset_parameters()
            return
        self.ema.update_model_average(self.ema_model, self.model)

    def save(self, milestone=None):
        data = {
            'step': self.step,
            'model': self.model.state_dict(),
            'ema': self.ema_model.state_dict()
        }
        if milestone is None:
            torch.save(data, str(self.results_folder / f'model.pt'))
        else:
            torch.save(data, str(self.results_folder / f'model_{milestone}.pt'))

    def load(self, load_path):
        print(f"Loading model from {load_path}")
        data = torch.load(load_path)

        self.step = data['step']
        self.model.load_state_dict(data['model'])
        self.ema_model.load_state_dict(data['ema'])

    def evaluate(self):
        """eval on test set"""
        self.ema_model.eval()

        all_orig_imgs = []
        all_blurred_imgs = []
        all_direct_recons = []
        all_deblurred_imgs = []

        with torch.no_grad():
            for batch in self.eval_dl:
                if isinstance(batch, (list, tuple)):
                    # handle (img, label)
                    orig_imgs = batch[0].to(self.model.device_of_kernel)
                else:
                    # handle just img
                    orig_imgs = batch.to(self.model.device_of_kernel)

                # use model for deblur
                blurred_imgs, direct_recons, deblurred_imgs = self.ema_model.sample(
                    batch_size=orig_imgs.shape[0], img=orig_imgs)

                all_orig_imgs.append(orig_imgs)
                all_blurred_imgs.append(blurred_imgs)
                all_direct_recons.append(direct_recons)
                all_deblurred_imgs.append(deblurred_imgs)

                # limit to 100 imgs
                if len(all_orig_imgs) * self.eval_batch_size >= 100:
                    break

        # cat everything
        orig_imgs = torch.cat(all_orig_imgs, dim=0)
        blurred_imgs = torch.cat(all_blurred_imgs, dim=0)
        direct_recons = torch.cat(all_direct_recons, dim=0)
        deblurred_imgs = torch.cat(all_deblurred_imgs, dim=0)

        # direct recon metrics
        direct_ssim, direct_rmse, direct_fid = calculate_metrics(orig_imgs, direct_recons)

        # alg 2 metrics
        deblurred_ssim, deblurred_rmse, deblurred_fid = calculate_metrics(orig_imgs, deblurred_imgs)

        # baseline metrics
        blurred_ssim, blurred_rmse, blurred_fid = calculate_metrics(orig_imgs, blurred_imgs)

        # results dict
        metrics = {
            'direct': {'ssim': direct_ssim, 'rmse': direct_rmse, 'fid': direct_fid},
            'deblurred': {'ssim': deblurred_ssim, 'rmse': deblurred_rmse, 'fid': deblurred_fid},
            'blurred': {'ssim': blurred_ssim, 'rmse': blurred_rmse, 'fid': blurred_fid}
        }

        # log to file
        with open(self.metrics_file, 'a') as f:
            f.write(f"{self.step},{deblurred_ssim},{deblurred_rmse},{deblurred_fid}\n")

        # make pictures
        vis_folder = self.results_folder / f"eval_{self.step}"
        vis_folder.mkdir(exist_ok=True)

        # save imgs
        n_samples = min(8, orig_imgs.shape[0])
        samples = torch.stack([
            orig_imgs[:n_samples],
            blurred_imgs[:n_samples],
            direct_recons[:n_samples],
            deblurred_imgs[:n_samples]
        ])
        samples = (samples + 1) / 2  # [0, 1] range

        # grid: diff samples in rows
        grid = utils.make_grid(samples.view(-1, *samples.shape[2:]), nrow=n_samples)
        utils.save_image(grid, vis_folder / 'samples.png')

        # plot metrics
        if self.step > 0:
            self.plot_metrics()

        print(f"Eval at step {self.step}:")
        print(f"Direct recon: SSIM={direct_ssim:.4f}, RMSE={direct_rmse:.4f}, FID={direct_fid:.2f}")
        print(f"Deblur (Alg 2): SSIM={deblurred_ssim:.4f}, RMSE={deblurred_rmse:.4f}, FID={deblurred_fid:.2f}")

        return metrics

    def plot_metrics(self):
        """plot stuff over time"""
        import pandas as pd
        import matplotlib.pyplot as plt

        df = pd.read_csv(self.metrics_file)

        fig, axs = plt.subplots(1, 3, figsize=(18, 5))

        # ssim plot
        axs[0].plot(df['step'], df['ssim'])
        axs[0].set_title('SSIM over time (higher=better)')
        axs[0].set_xlabel('Training steps')
        axs[0].set_ylabel('SSIM')

        # rmse plot
        axs[1].plot(df['step'], df['rmse'])
        axs[1].set_title('RMSE over time (lower=better)')
        axs[1].set_xlabel('Training steps')
        axs[1].set_ylabel('RMSE')

        # fid plot
        axs[2].plot(df['step'], df['fid'])
        axs[2].set_title('FID over time (lower=better)')
        axs[2].set_xlabel('Training steps')
        axs[2].set_ylabel('FID')

        plt.tight_layout()
        plt.savefig(self.results_folder / 'metrics_plot.png')
        plt.close()

    def train(self):
        """train the model"""
        device = self.model.device_of_kernel

        acc_loss = 0
        while self.step < self.train_num_steps:
            uloss = 0
            for i in range(self.gradient_accumulate_every):
                # get batch
                if self.dataset_type in ['mnist', 'cifar10']:
                    # handle (img, label)
                    data, _ = next(self.dl)
                    data = data.to(device)
                else:
                    # handle just imgs
                    data = next(self.dl).to(device)

                # forward
                loss = torch.mean(self.model(data))
                if self.step % 100 == 0:
                    print(f'Step {self.step}: Loss = {loss.item():.6f}')

                uloss += loss.item()

                # backward with grad accum
                loss = loss / self.gradient_accumulate_every
                loss.backward()

            # update
            self.opt.step()
            self.opt.zero_grad()

            # track avg loss
            acc_loss = acc_loss + (uloss / self.gradient_accumulate_every)

            # ema update
            if self.step % self.update_ema_every == 0:
                self.step_ema()

            # save stuff and metrics
            if self.step != 0 and self.step % self.save_and_sample_every == 0:
                milestone = self.step // self.save_and_sample_every
                print(f"\nSaving samples at step {self.step}")

                # get imgs for viz
                if self.dataset_type in ['mnist', 'cifar10']:
                    viz_batch, _ = next(self.dl)
                    viz_batch = viz_batch.to(device)
                else:
                    viz_batch = next(self.dl).to(device)

                # sample w/ ema
                with torch.no_grad():
                    blurred, direct_recons, deblurred = self.ema_model.sample(
                        batch_size=viz_batch.shape[0], img=viz_batch)

                # save imgs
                viz_folder = self.results_folder / f"samples_{self.step}"
                viz_folder.mkdir(exist_ok=True)

                # [0,1] range for viz
                viz_batch = (viz_batch + 1) * 0.5
                blurred = (blurred + 1) * 0.5
                direct_recons = (direct_recons + 1) * 0.5
                deblurred = (deblurred + 1) * 0.5

                # grid + save
                grid = torch.cat([
                    viz_batch[:8],
                    blurred[:8],
                    direct_recons[:8],
                    deblurred[:8]
                ])
                utils.save_image(grid, viz_folder / 'grid.png', nrow=8)

                # individual imgs
                utils.save_image(viz_batch[:8], viz_folder / 'original.png', nrow=4)
                utils.save_image(blurred[:8], viz_folder / 'blurred.png', nrow=4)
                utils.save_image(direct_recons[:8], viz_folder / 'direct_recons.png', nrow=4)
                utils.save_image(deblurred[:8], viz_folder / 'deblurred.png', nrow=4)

                # print avg loss
                avg_loss = acc_loss / (self.save_and_sample_every)
                print(f'Avg loss (last {self.save_and_sample_every} steps): {avg_loss:.6f}')
                acc_loss = 0

            # save checkpoint
            if self.step != 0 and self.step % self.save_model_every == 0:
                self.save(self.step)
                print(f"Model saved at step {self.step}")

            # eval model
            if self.step != 0 and self.step % self.eval_every == 0:
                print(f"\nEvaluating at step {self.step}")
                self.evaluate()

            self.step += 1

        # final stuff
        self.save()
        self.evaluate()
        print('Training done!')

# Run main
args = type('Args', (), {
    'dataset': 'cifar10',  # mnist or cifar10
    'batch_size': 32,
    'eval_batch_size': 16,
    'epochs': 50,
    'save_every': 10000,
    'eval_every': 10000,
    'sample_every': 1000,
    'results_dir': f"{res_dir}/blur_diffusion",
    'load_model': '/content/drive/MyDrive/Project/blur_diffusion/model_170000.pt',
    'train': True,
    'eval': True
})()

# device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# setup config
if args.dataset == 'mnist':
    channels = 1
    image_size = 32
    model_dim = 32
    timesteps = 40
    kernel_size = 11
    kernel_std = 7.0  # Fixed std=7 for MNIST
    blur_routine = 'Constant'  # Use constant for fixed std
elif args.dataset == 'cifar10':
    channels = 3
    image_size = 32
    model_dim = 64
    timesteps = 100
    kernel_size = 11
    kernel_std = 0.01  # For formula: 0.01*t + 0.35
    blur_routine = 'CIFAR'  # Custom routine for CIFAR-10

# make results dir
results_dir = Path(args.results_dir) / f"blur_{args.dataset}"
results_dir.mkdir(exist_ok=True, parents=True)

# make model
model = Unet(
    dim=model_dim,
    channels=channels,
    dim_mults=(1, 2, 4, 8),
    with_time_emb=True
).to(device)

# make diffusion
diffusion = GaussianDiffusion(
    model,
    image_size=image_size,
    device_of_kernel=device,
    channels=channels,
    timesteps=timesteps,
    loss_type='l1',
    kernel_std=kernel_std,
    kernel_size=kernel_size,
    blur_routine=blur_routine,
    train_routine='Final',
    sampling_routine='x0_step_down'  # Use Algorithm 2 from the paper
).to(device)

total_steps = 200000

# make trainer
trainer = Trainer(
    diffusion,
    args.dataset,
    image_size=image_size,
    train_batch_size=args.batch_size,
    eval_batch_size=args.eval_batch_size,
    train_lr=2e-5,
    train_num_steps=total_steps,
    save_and_sample_every=args.sample_every,
    save_model_every=args.save_every,
    eval_every=args.eval_every,
    results_folder=results_dir,
    load_path=args.load_model
)

if args.train:
    print(f"Starting training: {args.epochs} epochs ({total_steps} steps)...")
    trainer.train()

if args.eval:
    print("Evaluating model...")
    trainer.evaluate()

Mounted at /content/drive
Using device: cuda
Is Time embed used ?  True


100%|██████████| 170M/170M [00:12<00:00, 13.5MB/s]


Loading model from /content/drive/MyDrive/Project/cifar10/model_110000.pt
Starting training for 50 epochs (200000 steps)...
Step 110000: Loss = 0.055589
Step 110000: Loss = 0.051443

Saving samples at step 110000
Average loss over last 1000 steps: 0.000054
Model saved at step 110000

Evaluating model at step 110000


Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:02<00:00, 42.4MB/s]
100%|██████████| 2/2 [00:00<00:00,  3.42it/s]
100%|██████████| 2/2 [00:00<00:00,  4.24it/s]
100%|██████████| 2/2 [00:00<00:00,  3.59it/s]
100%|██████████| 2/2 [00:00<00:00,  3.93it/s]
100%|██████████| 2/2 [00:00<00:00,  3.49it/s]
100%|██████████| 2/2 [00:00<00:00,  3.82it/s]


Evaluation at step 110000:
Direct reconstruction: SSIM=0.8545, RMSE=0.0966, FID=184.41
Deblurred (Algorithm 2): SSIM=0.7787, RMSE=0.1327, FID=176.01
Step 110100: Loss = 0.040513
Step 110100: Loss = 0.050713
Step 110200: Loss = 0.053033
Step 110200: Loss = 0.054532
Step 110300: Loss = 0.052247
Step 110300: Loss = 0.058440
Step 110400: Loss = 0.047899
Step 110400: Loss = 0.052425
Step 110500: Loss = 0.056654
Step 110500: Loss = 0.057279
Step 110600: Loss = 0.053105
Step 110600: Loss = 0.051403
Step 110700: Loss = 0.055127
Step 110700: Loss = 0.053607




Step 110800: Loss = 0.051494
Step 110800: Loss = 0.049341
Step 110900: Loss = 0.050152
Step 110900: Loss = 0.058397
Step 111000: Loss = 0.056159
Step 111000: Loss = 0.052885

Saving samples at step 111000
Average loss over last 1000 steps: 0.053480
Step 111100: Loss = 0.048697
Step 111100: Loss = 0.053227
Step 111200: Loss = 0.052164
Step 111200: Loss = 0.050322
Step 111300: Loss = 0.054191
Step 111300: Loss = 0.050957
Step 111400: Loss = 0.054332
Step 111400: Loss = 0.049473
Step 111500: Loss = 0.047252
Step 111500: Loss = 0.051949
Step 111600: Loss = 0.042478
Step 111600: Loss = 0.060494
Step 111700: Loss = 0.054231
Step 111700: Loss = 0.054048
Step 111800: Loss = 0.046805
Step 111800: Loss = 0.052990
Step 111900: Loss = 0.049807
Step 111900: Loss = 0.051549
Step 112000: Loss = 0.060223
Step 112000: Loss = 0.052049

Saving samples at step 112000
Average loss over last 1000 steps: 0.053256
Step 112100: Loss = 0.052444
Step 112100: Loss = 0.057199
Step 112200: Loss = 0.057390
Step 1122

100%|██████████| 2/2 [00:00<00:00,  4.01it/s]
100%|██████████| 2/2 [00:00<00:00,  4.07it/s]
100%|██████████| 2/2 [00:00<00:00,  3.77it/s]
100%|██████████| 2/2 [00:00<00:00,  4.04it/s]
100%|██████████| 2/2 [00:00<00:00,  3.76it/s]
100%|██████████| 2/2 [00:00<00:00,  4.07it/s]


Evaluation at step 120000:
Direct reconstruction: SSIM=0.8695, RMSE=0.0906, FID=176.31
Deblurred (Algorithm 2): SSIM=0.7980, RMSE=0.1258, FID=170.33
Step 120100: Loss = 0.047140
Step 120100: Loss = 0.042617




Step 120200: Loss = 0.051119
Step 120200: Loss = 0.050028
Step 120300: Loss = 0.060766
Step 120300: Loss = 0.055664
Step 120400: Loss = 0.048456
Step 120400: Loss = 0.050442
Step 120500: Loss = 0.058009
Step 120500: Loss = 0.048300
Step 120600: Loss = 0.056807
Step 120600: Loss = 0.053776
Step 120700: Loss = 0.051912
Step 120700: Loss = 0.056474
Step 120800: Loss = 0.047658
Step 120800: Loss = 0.047239
Step 120900: Loss = 0.046451
Step 120900: Loss = 0.051546
Step 121000: Loss = 0.046599
Step 121000: Loss = 0.047136

Saving samples at step 121000
Average loss over last 1000 steps: 0.051100
Step 121100: Loss = 0.049037
Step 121100: Loss = 0.046721
Step 121200: Loss = 0.052056
Step 121200: Loss = 0.050697
Step 121300: Loss = 0.050662
Step 121300: Loss = 0.056075
Step 121400: Loss = 0.047638
Step 121400: Loss = 0.042601
Step 121500: Loss = 0.051191
Step 121500: Loss = 0.042312
Step 121600: Loss = 0.050919
Step 121600: Loss = 0.059549
Step 121700: Loss = 0.054360
Step 121700: Loss = 0.0480

100%|██████████| 2/2 [00:00<00:00,  3.97it/s]
100%|██████████| 2/2 [00:00<00:00,  3.97it/s]
100%|██████████| 2/2 [00:00<00:00,  3.58it/s]
100%|██████████| 2/2 [00:00<00:00,  3.95it/s]
100%|██████████| 2/2 [00:00<00:00,  3.43it/s]
100%|██████████| 2/2 [00:00<00:00,  3.75it/s]


Evaluation at step 130000:
Direct reconstruction: SSIM=0.8804, RMSE=0.0863, FID=172.75
Deblurred (Algorithm 2): SSIM=0.8096, RMSE=0.1223, FID=167.86
Step 130100: Loss = 0.046845
Step 130100: Loss = 0.051682
Step 130200: Loss = 0.048954
Step 130200: Loss = 0.048567




Step 130300: Loss = 0.054585
Step 130300: Loss = 0.049101
Step 130400: Loss = 0.046736
Step 130400: Loss = 0.047048
Step 130500: Loss = 0.046210
Step 130500: Loss = 0.051135
Step 130600: Loss = 0.052190
Step 130600: Loss = 0.048168
Step 130700: Loss = 0.047618
Step 130700: Loss = 0.051613
Step 130800: Loss = 0.045105
Step 130800: Loss = 0.042674
Step 130900: Loss = 0.047669
Step 130900: Loss = 0.046136
Step 131000: Loss = 0.049667
Step 131000: Loss = 0.042221

Saving samples at step 131000
Average loss over last 1000 steps: 0.049045
Step 131100: Loss = 0.050448
Step 131100: Loss = 0.044364
Step 131200: Loss = 0.045550
Step 131200: Loss = 0.046763
Step 131300: Loss = 0.046285
Step 131300: Loss = 0.046314
Step 131400: Loss = 0.048048
Step 131400: Loss = 0.048156
Step 131500: Loss = 0.052495
Step 131500: Loss = 0.046544
Step 131600: Loss = 0.045919
Step 131600: Loss = 0.052020
Step 131700: Loss = 0.043792
Step 131700: Loss = 0.047986
Step 131800: Loss = 0.050925
Step 131800: Loss = 0.0522

100%|██████████| 2/2 [00:00<00:00,  3.98it/s]
100%|██████████| 2/2 [00:00<00:00,  4.02it/s]
100%|██████████| 2/2 [00:00<00:00,  3.67it/s]
100%|██████████| 2/2 [00:00<00:00,  4.10it/s]
100%|██████████| 2/2 [00:00<00:00,  3.76it/s]
100%|██████████| 2/2 [00:00<00:00,  4.09it/s]


Evaluation at step 140000:
Direct reconstruction: SSIM=0.8894, RMSE=0.0824, FID=167.16
Deblurred (Algorithm 2): SSIM=0.8223, RMSE=0.1175, FID=161.72
Step 140100: Loss = 0.046706
Step 140100: Loss = 0.047970
Step 140200: Loss = 0.045365
Step 140200: Loss = 0.047231
Step 140300: Loss = 0.051599
Step 140300: Loss = 0.048840
Step 140400: Loss = 0.048932
Step 140400: Loss = 0.044949




Step 140500: Loss = 0.045501
Step 140500: Loss = 0.059164
Step 140600: Loss = 0.044780
Step 140600: Loss = 0.052061
Step 140700: Loss = 0.053378
Step 140700: Loss = 0.045849
Step 140800: Loss = 0.046075
Step 140800: Loss = 0.049140
Step 140900: Loss = 0.044312
Step 140900: Loss = 0.042193
Step 141000: Loss = 0.041250
Step 141000: Loss = 0.049316

Saving samples at step 141000
Average loss over last 1000 steps: 0.047162
Step 141100: Loss = 0.054058
Step 141100: Loss = 0.039757
Step 141200: Loss = 0.045897
Step 141200: Loss = 0.041824
Step 141300: Loss = 0.048781
Step 141300: Loss = 0.050944
Step 141400: Loss = 0.049455
Step 141400: Loss = 0.043074
Step 141500: Loss = 0.046852
Step 141500: Loss = 0.044051
Step 141600: Loss = 0.050910
Step 141600: Loss = 0.048987
Step 141700: Loss = 0.057186
Step 141700: Loss = 0.043217
Step 141800: Loss = 0.048123
Step 141800: Loss = 0.055546
Step 141900: Loss = 0.044772
Step 141900: Loss = 0.048426
Step 142000: Loss = 0.042239
Step 142000: Loss = 0.0419

100%|██████████| 2/2 [00:00<00:00,  4.03it/s]
100%|██████████| 2/2 [00:00<00:00,  4.03it/s]
100%|██████████| 2/2 [00:00<00:00,  3.80it/s]
100%|██████████| 2/2 [00:00<00:00,  3.94it/s]
100%|██████████| 2/2 [00:00<00:00,  3.37it/s]
100%|██████████| 2/2 [00:00<00:00,  3.65it/s]


Evaluation at step 150000:
Direct reconstruction: SSIM=0.8963, RMSE=0.0795, FID=160.59
Deblurred (Algorithm 2): SSIM=0.8332, RMSE=0.1131, FID=155.13
Step 150100: Loss = 0.043713
Step 150100: Loss = 0.046247
Step 150200: Loss = 0.044508
Step 150200: Loss = 0.048422
Step 150300: Loss = 0.046958
Step 150300: Loss = 0.044339
Step 150400: Loss = 0.052140
Step 150400: Loss = 0.047867
Step 150500: Loss = 0.041190
Step 150500: Loss = 0.048157




Step 150600: Loss = 0.042222
Step 150600: Loss = 0.043409
Step 150700: Loss = 0.047165
Step 150700: Loss = 0.040808
Step 150800: Loss = 0.050357
Step 150800: Loss = 0.040770
Step 150900: Loss = 0.045988
Step 150900: Loss = 0.042896
Step 151000: Loss = 0.041792
Step 151000: Loss = 0.051783

Saving samples at step 151000
Average loss over last 1000 steps: 0.045518
Step 151100: Loss = 0.037086
Step 151100: Loss = 0.045734
Step 151200: Loss = 0.044705
Step 151200: Loss = 0.043743
Step 151300: Loss = 0.052317
Step 151300: Loss = 0.042980
Step 151400: Loss = 0.044128
Step 151400: Loss = 0.050552
Step 151500: Loss = 0.044496
Step 151500: Loss = 0.045063
Step 151600: Loss = 0.048401
Step 151600: Loss = 0.043762
Step 151700: Loss = 0.048763
Step 151700: Loss = 0.050971
Step 151800: Loss = 0.041717
Step 151800: Loss = 0.041497
Step 151900: Loss = 0.046451
Step 151900: Loss = 0.046233
Step 152000: Loss = 0.048981
Step 152000: Loss = 0.039890

Saving samples at step 152000
Average loss over last 1

100%|██████████| 2/2 [00:00<00:00,  4.03it/s]
100%|██████████| 2/2 [00:00<00:00,  4.04it/s]
100%|██████████| 2/2 [00:00<00:00,  3.81it/s]
100%|██████████| 2/2 [00:00<00:00,  4.05it/s]
100%|██████████| 2/2 [00:00<00:00,  3.76it/s]
100%|██████████| 2/2 [00:00<00:00,  4.04it/s]


Evaluation at step 160000:
Direct reconstruction: SSIM=0.9031, RMSE=0.0763, FID=154.54
Deblurred (Algorithm 2): SSIM=0.8417, RMSE=0.1098, FID=151.16
Step 160100: Loss = 0.047385
Step 160100: Loss = 0.043758
Step 160200: Loss = 0.048927
Step 160200: Loss = 0.049143
Step 160300: Loss = 0.037205
Step 160300: Loss = 0.043998
Step 160400: Loss = 0.040812
Step 160400: Loss = 0.044180
Step 160500: Loss = 0.048836
Step 160500: Loss = 0.038799
Step 160600: Loss = 0.045133
Step 160600: Loss = 0.046070
Step 160700: Loss = 0.043212
Step 160700: Loss = 0.042303




Step 160800: Loss = 0.045549
Step 160800: Loss = 0.043007
Step 160900: Loss = 0.042268
Step 160900: Loss = 0.045394
Step 161000: Loss = 0.042596
Step 161000: Loss = 0.047540

Saving samples at step 161000
Average loss over last 1000 steps: 0.044124
Step 161100: Loss = 0.044040
Step 161100: Loss = 0.041808
Step 161200: Loss = 0.044820
Step 161200: Loss = 0.042469
Step 161300: Loss = 0.044699
Step 161300: Loss = 0.045813
Step 161400: Loss = 0.046349
Step 161400: Loss = 0.043504
Step 161500: Loss = 0.047051
Step 161500: Loss = 0.048008
Step 161600: Loss = 0.040829
Step 161600: Loss = 0.045857
Step 161700: Loss = 0.044005
Step 161700: Loss = 0.045147
Step 161800: Loss = 0.044409
Step 161800: Loss = 0.051474
Step 161900: Loss = 0.044887
Step 161900: Loss = 0.049844
Step 162000: Loss = 0.047906
Step 162000: Loss = 0.044737

Saving samples at step 162000
Average loss over last 1000 steps: 0.043939
Step 162100: Loss = 0.051308
Step 162100: Loss = 0.048866
Step 162200: Loss = 0.042198
Step 1622

100%|██████████| 2/2 [00:00<00:00,  3.94it/s]
100%|██████████| 2/2 [00:00<00:00,  3.91it/s]
100%|██████████| 2/2 [00:00<00:00,  3.72it/s]
100%|██████████| 2/2 [00:00<00:00,  3.98it/s]
100%|██████████| 2/2 [00:00<00:00,  3.39it/s]
100%|██████████| 2/2 [00:00<00:00,  3.66it/s]


Evaluation at step 170000:
Direct reconstruction: SSIM=0.9081, RMSE=0.0741, FID=151.13
Deblurred (Algorithm 2): SSIM=0.8500, RMSE=0.1065, FID=149.17
Step 170100: Loss = 0.042277
Step 170100: Loss = 0.042799




Step 170200: Loss = 0.037417
Step 170200: Loss = 0.047957
Step 170300: Loss = 0.043363
Step 170300: Loss = 0.045428
Step 170400: Loss = 0.039620
Step 170400: Loss = 0.050143
Step 170500: Loss = 0.044939
Step 170500: Loss = 0.041869
Step 170600: Loss = 0.049081
Step 170600: Loss = 0.046443
Step 170700: Loss = 0.040678
Step 170700: Loss = 0.047152
Step 170800: Loss = 0.044784
Step 170800: Loss = 0.041134
Step 170900: Loss = 0.046071
Step 170900: Loss = 0.036457
Step 171000: Loss = 0.041792
Step 171000: Loss = 0.046097

Saving samples at step 171000
Average loss over last 1000 steps: 0.042803
Step 171100: Loss = 0.039196
Step 171100: Loss = 0.041594
Step 171200: Loss = 0.036640
Step 171200: Loss = 0.043550
Step 171300: Loss = 0.045002
Step 171300: Loss = 0.046696
Step 171400: Loss = 0.039002
Step 171400: Loss = 0.043285
Step 171500: Loss = 0.042301
Step 171500: Loss = 0.041602
Step 171600: Loss = 0.041421
Step 171600: Loss = 0.043692
Step 171700: Loss = 0.047912
Step 171700: Loss = 0.0398

100%|██████████| 2/2 [00:00<00:00,  3.59it/s]
100%|██████████| 2/2 [00:00<00:00,  3.58it/s]
100%|██████████| 2/2 [00:00<00:00,  3.22it/s]
100%|██████████| 2/2 [00:00<00:00,  3.95it/s]
100%|██████████| 2/2 [00:00<00:00,  3.29it/s]
100%|██████████| 2/2 [00:00<00:00,  4.02it/s]


Evaluation at step 180000:
Direct reconstruction: SSIM=0.9133, RMSE=0.0717, FID=145.47
Deblurred (Algorithm 2): SSIM=0.8558, RMSE=0.1042, FID=147.21
Step 180100: Loss = 0.041246
Step 180100: Loss = 0.044864
Step 180200: Loss = 0.041006
Step 180200: Loss = 0.040412




Step 180300: Loss = 0.043467
Step 180300: Loss = 0.037406
Step 180400: Loss = 0.050416
Step 180400: Loss = 0.045024
Step 180500: Loss = 0.039613
Step 180500: Loss = 0.035748
Step 180600: Loss = 0.039618
Step 180600: Loss = 0.042142
Step 180700: Loss = 0.049574
Step 180700: Loss = 0.042445
Step 180800: Loss = 0.041276
Step 180800: Loss = 0.042486
Step 180900: Loss = 0.039332
Step 180900: Loss = 0.042858
Step 181000: Loss = 0.035779
Step 181000: Loss = 0.041534

Saving samples at step 181000
Average loss over last 1000 steps: 0.041582
Step 181100: Loss = 0.044255
Step 181100: Loss = 0.047648
Step 181200: Loss = 0.037442
Step 181200: Loss = 0.038199
Step 181300: Loss = 0.042162
Step 181300: Loss = 0.047838
Step 181400: Loss = 0.049963
Step 181400: Loss = 0.039337
Step 181500: Loss = 0.043777
Step 181500: Loss = 0.039052
Step 181600: Loss = 0.045141
Step 181600: Loss = 0.045497
Step 181700: Loss = 0.043059
Step 181700: Loss = 0.045304
Step 181800: Loss = 0.047608
Step 181800: Loss = 0.0428

100%|██████████| 2/2 [00:00<00:00,  3.61it/s]
100%|██████████| 2/2 [00:00<00:00,  3.44it/s]
100%|██████████| 2/2 [00:00<00:00,  3.43it/s]
100%|██████████| 2/2 [00:00<00:00,  3.65it/s]
100%|██████████| 2/2 [00:00<00:00,  3.40it/s]
100%|██████████| 2/2 [00:00<00:00,  3.63it/s]


Evaluation at step 190000:
Direct reconstruction: SSIM=0.9172, RMSE=0.0699, FID=138.36
Deblurred (Algorithm 2): SSIM=0.8622, RMSE=0.1017, FID=143.73
Step 190100: Loss = 0.040381
Step 190100: Loss = 0.031687
Step 190200: Loss = 0.041423
Step 190200: Loss = 0.042270
Step 190300: Loss = 0.041293
Step 190300: Loss = 0.037788
Step 190400: Loss = 0.042536
Step 190400: Loss = 0.039511




Step 190500: Loss = 0.040644
Step 190500: Loss = 0.042727
Step 190600: Loss = 0.037635
Step 190600: Loss = 0.040672
Step 190700: Loss = 0.037422
Step 190700: Loss = 0.047592
Step 190800: Loss = 0.038141
Step 190800: Loss = 0.038794
Step 190900: Loss = 0.041040
Step 190900: Loss = 0.039700
Step 191000: Loss = 0.042775
Step 191000: Loss = 0.040733

Saving samples at step 191000
Average loss over last 1000 steps: 0.040556
Step 191100: Loss = 0.037466
Step 191100: Loss = 0.035183
Step 191200: Loss = 0.043208
Step 191200: Loss = 0.037726
Step 191300: Loss = 0.037024
Step 191300: Loss = 0.038495
Step 191400: Loss = 0.039682
Step 191400: Loss = 0.038828
Step 191500: Loss = 0.046736
Step 191500: Loss = 0.040643
Step 191600: Loss = 0.044563
Step 191600: Loss = 0.044807
Step 191700: Loss = 0.036635
Step 191700: Loss = 0.045717
Step 191800: Loss = 0.040917
Step 191800: Loss = 0.043493
Step 191900: Loss = 0.041019
Step 191900: Loss = 0.040742
Step 192000: Loss = 0.041295
Step 192000: Loss = 0.0431

100%|██████████| 2/2 [00:00<00:00,  3.93it/s]
100%|██████████| 2/2 [00:00<00:00,  3.92it/s]
100%|██████████| 2/2 [00:00<00:00,  3.65it/s]
100%|██████████| 2/2 [00:00<00:00,  3.86it/s]
100%|██████████| 2/2 [00:00<00:00,  3.41it/s]
100%|██████████| 2/2 [00:00<00:00,  3.46it/s]


Evaluation at step 200000:
Direct reconstruction: SSIM=0.9211, RMSE=0.0678, FID=136.12
Deblurred (Algorithm 2): SSIM=0.8675, RMSE=0.0991, FID=144.40
Training completed!
Evaluating model...


100%|██████████| 2/2 [00:00<00:00,  3.76it/s]
100%|██████████| 2/2 [00:00<00:00,  3.54it/s]
100%|██████████| 2/2 [00:00<00:00,  3.32it/s]
100%|██████████| 2/2 [00:00<00:00,  3.56it/s]
100%|██████████| 2/2 [00:00<00:00,  3.36it/s]
100%|██████████| 2/2 [00:00<00:00,  3.58it/s]


Evaluation at step 200000:
Direct reconstruction: SSIM=0.9211, RMSE=0.0678, FID=136.12
Deblurred (Algorithm 2): SSIM=0.8675, RMSE=0.0991, FID=144.62
