## **Single image super-resolution with diffusion probabilistic models (SRDiff)**

Paper: [SRDiff: Single Image Super-Resolution with Diffusion Probabilistic Models](https://arxiv.org/abs/2104.14951)

Helpful Resources:
- [SRDiff's github repo](https://github.com/LeiaLi/SRDiff/tree/main)

In [1]:
# # This Python 3 environment comes with many helpful analytics libraries installed
# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# # For example, here's several helpful packages to load

# import numpy as np # linear algebra
# import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# # Input data files are available in the read-only "../input/" directory
# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [20]:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch import einsum
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torchinfo import summary
from torch import GradScaler, autocast

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

from PIL import Image
import os
import math
from functools import partial
from tqdm import tqdm
from datetime import datetime
import pytz
import copy
import time
import gc

print("imports done!")

imports done!


In [3]:
def get_torch_version():
    torch_version = torch.__version__.split("+")[0]
    torch_number = torch_version.split(".")[:2]
    torch_number_float = torch_number[0] + "." + torch_number[1]
    torch_number_float = float(torch_number_float)
    return torch_number_float


def set_seed(seed=42):
    """
    Seeds basic parameters for reproducibility of results
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # if get_torch_version() <= 1.7:
        #     torch.set_deterministic(True)
        # else:
        #     torch.use_deterministic_algorithms(True)
    print(f"seed {seed} set!")
    

def compute_accuracy(y_pred, y):
    assert len(y_pred)==len(y), "length of y_pred and y must be equal"
    acc = torch.eq(y_pred, y).sum().item()
    acc = acc/len(y_pred)
    return acc


def train_validation_split(train_dataset):
    X_train, X_valid, y_train, y_valid = train_test_split(train_dataset.data, train_dataset.targets, 
                                                          test_size=0.2, random_state=42, shuffle=True, 
                                                          stratify=train_dataset.targets)
    X_train = torch.tensor(X_train, dtype=torch.float64).permute(0, 3, 1, 2)
    X_valid = torch.tensor(X_valid, dtype=torch.float64).permute(0, 3, 1, 2)
    y_train = torch.tensor(y_train, dtype=torch.int64)
    y_valid = torch.tensor(y_valid, dtype=torch.int64)
    return X_train, X_valid, y_train, y_valid
    

def predict(model, img_path, device):
    img = cv2.imread(img_path)
    if img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    img2 = img.copy()
    img = torch.tensor(img)
    img = img.permute(1,2,0)
    img = img.unsqueeze(dim=0)
    img = img.to(device)
    model.eval()
    with torch.inference_mode():
        logit = model(img)
    pred_prob = torch.softmax(logit, dim=1)
    pred_label = pred_prob.argmax(dim=1)
    plt.imshow(img2)
    plt.axis("off")
    plt.label(f"Prediction: {classes[pred_label]}\t\tProbability: {round(pred_prob)}")
    plt.show()


def set_scheduler(scheduler, results, scheduler_on):
    """Makes the neccessary updates to the scheduler."""
    if scheduler_on == "valid_acc":
        scheduler.step(results["valid_acc"][-1])
    elif scheduler_on == "valid_loss":
        scheduler.step(results["valid_loss"][-1])
    elif scheduler_on == "train_acc":
        scheduler.step(results["train_acc"][-1])
    elif scheduler_on == "train_loss":
        scheduler.step(results["train_loss"][-1])
    else:
        raise ValueError("Invalid `scheduler_on` choice.")
    return scheduler


def visualize_results(results, plot_name=None):
    """Plot the training and validation loss and accuracy, given the results dictionary"""
    train_loss, train_acc = results["train_loss"], results["train_acc"]
    val_loss, val_acc = results["valid_loss"], results["valid_acc"]
    cls = ["no", "vort", "sphere"]
    x = np.arange(len(train_loss))  # this is the number of epochs
    fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12,12))
    # ax[0,0].set_title("Loss")
    ax[0,0].set_xlabel("Epochs")
    ax[0,0].set_ylabel("Loss")
    ax[0,0].plot(x, train_loss, label="train_loss", color="orange")
    ax[0,0].plot(x, val_loss, label="valid_loss", color="blue")
    ax[0,0].legend()
    # ax[0,1].set_title("Accuracy")
    ax[0,1].set_xlabel("Epochs")
    ax[0,1].set_ylabel("Accuracy")
    ax[0,1].plot(x, train_acc, label="train_acc", color="orange")
    ax[0,1].plot(x, val_acc, label="valid_acc", color="blue")
    ax[0,1].legend()
    # ax[1,0].set_title("Train ROC AUC Plot")
    ax[1,0].set_xlabel("Epochs")
    ax[1,0].set_ylabel("Train ROC AUC Score")
    ax[1,0].plot(x, results["train_roc_auc_0"], label=cls[0])
    ax[1,0].plot(x, results["train_roc_auc_1"], label=cls[1])
    ax[1,0].plot(x, results["train_roc_auc_2"], label=cls[2])
    ax[1,0].legend()
    # ax[1,1].set_title("Valid ROC AUC Plot")
    ax[1,1].set_xlabel("Epochs")
    ax[1,1].set_ylabel("Valid ROC AUC Score")
    ax[1,1].plot(x, results["valid_roc_auc_0"], label=cls[0])
    ax[1,1].plot(x, results["valid_roc_auc_1"], label=cls[1])
    ax[1,1].plot(x, results["valid_roc_auc_2"], label=cls[2])
    ax[1,1].legend()
    if plot_name is not None:
        plt.savefig(plot_name)
    plt.show()
    

def train_step(model, loss_fn, optimizer, dataloader, device, scaler=None):
    model.train()
    train_loss = 0
    train_acc = 0
    all_labels = []
    all_preds = []
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        if scaler is not None:           # do automatic mixed precision training
            with autocast(device):       # mixed precision forward pass
                logit = model(X)
                pred_prob = torch.softmax(logit, dim=1)
                pred_label = pred_prob.argmax(dim=1)
                # note: first put logit and then y in the loss_fn
                # otherwise, if you put y first and then logit, then it will raise an error
                loss = loss_fn(logit, y)
            scaler.scale(loss).backward()      # mixed precision backward pass
            scaler.step(optimizer)             # updating optimizer
            scaler.update()                    # updating weights
        else:                     # don't do any mixed precision training
            logit = model(X)
            pred_prob = torch.softmax(logit, dim=1)
            pred_label = pred_prob.argmax(dim=1)
            loss = loss_fn(logit, y)
            loss.backward()
            optimizer.step()
        all_labels.extend(y.detach().cpu().numpy())
        all_preds.extend(pred_prob.detach().cpu().numpy())
        train_loss += loss.item()
        acc = compute_accuracy(pred_label, y)
        train_acc += acc
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    return train_loss, train_acc, all_labels, all_preds
        

def valid_step(model, loss_fn, dataloader, device):
    model.eval()
    valid_loss = 0
    valid_acc = 0
    all_labels = []
    all_preds = []
    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            logit = model(X)
            pred_prob = torch.softmax(logit, dim=1)
            pred_label = pred_prob.argmax(dim=1)
            # note: first put logit and then y in the loss_fn
            # otherwise, if you put y first and then logit, then it will raise an error
            loss = loss_fn(logit, y)
            valid_loss += loss.item()
            acc = compute_accuracy(pred_label, y)
            valid_acc += acc
            all_labels.extend(y.detach().cpu().numpy())
            all_preds.extend(pred_prob.detach().cpu().numpy())
    valid_loss = valid_loss / len(dataloader)
    valid_acc = valid_acc / len(dataloader)
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    return valid_loss, valid_acc, all_labels, all_preds


def training_fn1(model, loss_fn, optimizer, train_dataloader, valid_dataloader, device, 
                 epochs, scheduler=None, scheduler_on="val_acc", verbose=False, scaler=None,
                 save_best_model=False, path=None, model_name=None, optimizer_name=None, 
                 scheduler_name=None):
    """
    Does model training and validation for one fold in a k-fold cross validation setting.
    """
    results = {
        "train_loss": [],
        "train_acc": [],
        "valid_loss": [],
        "valid_acc": [],
        "train_roc_auc_0": [],
        "valid_roc_auc_0": [],
        "train_roc_auc_1": [],
        "valid_roc_auc_1": [],
        "train_roc_auc_2": [],
        "valid_roc_auc_2": [],
    }
    best_valid_roc_auc = 0.0
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc, train_labels, train_preds = train_step(model, loss_fn, optimizer, 
                                                                      train_dataloader, device, scaler)
        valid_loss, valid_acc, valid_labels, valid_preds = valid_step(model, loss_fn, valid_dataloader, 
                                                                      device)
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["valid_loss"].append(valid_loss)
        results["valid_acc"].append(valid_acc)
        bin_train_labels = label_binarize(train_labels, classes=[0,1,2])
        bin_valid_labels = label_binarize(valid_labels, classes=[0,1,2])
        mean_train_roc_auc = np.mean([results["train_roc_auc_0"], results["train_roc_auc_1"],
                                          results["train_roc_auc_2"]])
        mean_valid_roc_auc = np.mean([results["valid_roc_auc_0"], results["valid_roc_auc_1"],
                                          results["valid_roc_auc_2"]])
        for i in range(3):
            try:
                train_roc_auc = roc_auc_score(bin_train_labels[:, i], train_preds[:, i])
                valid_roc_auc = roc_auc_score(bin_valid_labels[:, i], valid_preds[:, i])
                results[f"train_roc_auc_{i}"].append(train_roc_auc)
                results[f"valid_roc_auc_{i}"].append(valid_roc_auc)
            except ValueError:
                print(f"Warning: AUC computation failed for class {i}")
        if verbose:
            print(
                    f"Epoch: {epoch+1} | Train_loss: {train_loss:.5f} | "
                    f"Train_acc: {train_acc:.5f} | Val_loss: {valid_loss:.5f} | "
                    f"Val_acc: {valid_acc:.5f} | Train_roc_auc: {mean_train_roc_auc:.5f} | "
                    f"Val_roc_auc: {mean_valid_roc_auc:.5f}"
                )
        if scheduler is not None:
            scheduler = set_scheduler(scheduler, results, scheduler_on)
            if mean_valid_roc_auc > best_valid_roc_auc:
                best_valid_roc_auc = mean_valid_roc_auc
                plot_name = path + "/" + model_name[:-3] + "_" + optimizer_name[:-3] + "_" + scheduler_name[:-3] + ".pdf"
                save_model_info(path, device, model, model_name, optimizer, optimizer_name, 
                    scheduler, scheduler_name)
        else:
            if mean_valid_roc_auc > best_valid_roc_auc:
                best_valid_roc_auc = mean_valid_roc_auc
                plot_name = path + "/" + model_name[:-3] + "_" + optimizer_name[:-3] + ".pdf"
                save_model_info(path, device, model, model_name, optimizer, optimizer_name) 
    visualize_results(results, plot_name)


def training_fn2(model, loss_fn, optimizer, train_dataset, device, epochs, 
                 scheduler=None, scheduler_on="val_acc", verbose=False, n_splits=5, scaler=None):
    """
    Does the training and validation for all the folds in a k-fold cross validation setting.
    """
    kf = StratifiedKFold(n_splits=n_splits, random_state=42, shuffle=True)
    MODELS = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(X=train_dataset.data, y=train_dataset.targets)):
        train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, 
                                      sampler=SubsetRandomSampler(train_idx))
        valid_dataloader = DataLoader(dataset=train_dataset, batch_size=64, 
                                      sampler=SubsetRandomSampler(val_idx))
        results = training_fn1(model, loss_fn, optimizer, train_dataloader, valid_dataloader, device, 
                                epochs, scheduler=scheduler, scheduler_on=scheduler_on, verbose=verbose)
        train_loss = np.mean(results["train_loss"])
        valid_loss = np.mean(results["valid_loss"])
        train_acc = np.mean(results["train_acc"])
        valid_acc = np.mean(results["valid_acc"])
        print(
                f"Fold: {fold+1} | Train_loss: {train_loss:.5f} | "
                f"Train_acc: {train_acc:.5f} | Val_loss: {valid_loss:.5f} | "
                f"Val_acc: {valid_acc:.5f}"
            )
        visualize_results(results)
        MODELS.append(model)
    return MODELS


def training_function(model, loss_fn, optimizer, train_dataset, device, epochs, scheduler=None, 
                      scheduler_on="val_acc", verbose=False, validation_strategy="train test split",
                      n_splits=5, scaler=None):
    """
    validation_strategy: choose one of the following: 
        - "train test split"
        - "k-fold cross validation"
    """
    if validation_strategy == "train test split":
        X_train, X_valid, y_train, y_valid = train_validation_split(train_dataset)
        train_dataset = CustomDataset(features=X_train, targets=y_train)
        valid_dataset = CustomDataset(features=X_valid, targets=y_valid)
        train_dataloader = DataLoader(dataset=train_dataset, batch_size=CONFIG["batchsize"], shuffle=True)
        valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=CONFIG["batchsize"], shuffle=False)
        training_fn1(model, loss_fn, optimizer, train_dataloader, valid_dataloader, device, epochs, 
                     scheduler=scheduler, scheduler_on=scheduler_on, verbose=verbose)
    elif validation_strategy == "k-fold cross validation":
        training_fn2(model, loss_fn, optimizer, train_dataset, device, epochs, scheduler=scheduler, 
                     scheduler_on=scheduler_on, verbose=verbose, n_splits=n_splits)
    else:
        raise ValueError("Invalid validation strategy.\nChoose either \"train test split\" \
        or \"k-fold cross validation\"")
    

def save_model_info(path: str, device, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    model.to(device)
    torch.save(model.state_dict(), os.path.join(path,model_name))
    torch.save(optimizer.state_dict(), os.path.join(path,optimizer_name))
    if scheduler is not None:
        torch.save(scheduler.state_dict(), os.path.join(path,scheduler_name))    
    print("Model info saved!")
    
    
def load_model_info(PATH, device, model, model_name, optimizer, optimizer_name, 
                    scheduler=None, scheduler_name=""):
    model.load_state_dict(torch.load(os.path.join(path,model_name)))
    model.to(device)
    optimizer.load_state_dict(torch.load(os.path.join(path,optimizer_name)))
    if scheduler is not None:
        scheduler.load_state_dict(torch.load(os.path.join(path,scheduler_name)))
    print("Model info loaded!")
    
    
def get_current_time():
    """Returns the current time in Toronto."""
    now = datetime.now(pytz.timezone('Canada/Eastern'))
    current_time = now.strftime("%d_%m_%Y__%H_%M_%S")
    return current_time


print("Utility functions created!")

Utility functions created!


In [4]:
get_torch_version()

2.5

In [5]:
set_seed(42)

seed 42 set!


In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cpu'

In [18]:
path = "../input/"
train_path = path + "DIV2K_train_HR/DIV2K_train_HR/"
valid_path = path+"DIV2K_valid_HR/DIV2K_valid_HR/"
print("No. of images in the training dataset:", len(os.listdir(train_path)))
print("No. of images in the validation dataset:", len(os.listdir(valid_path)))

No. of images in the training dataset: 800
No. of images in the validation dataset: 100


In [None]:
def fn():
    train_img_name = random.choice(os.listdir(train_path))
    valid_img_name = random.choice(os.listdir(valid_path))
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12,12))
    ax[0].imshow(plt.imread(train_path+train_img_name))
    ax[0].axis("off")
    ax[0].set_title("Train Image")
    ax[1].imshow(plt.imread(valid_path+valid_img_name))
    ax[1].axis("off")
    ax[1].set_title("Validation Image")
    plt.show()

fn()

### **Conditional Network (Feature Extractor)**

In [9]:
# Residual Dense Block (RDB)
class ResidualDenseBlock(nn.Module):
    def __init__(self, in_channels, growth_channels=32, num_layers=5):
        super(ResidualDenseBlock, self).__init__()
        self.num_layers = num_layers
        self.growth_channels = growth_channels
        
        # Create convolutional layers for dense connections
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                nn.Conv2d(in_channels + i * growth_channels, growth_channels, kernel_size=3, padding=1)
            )
        # Local feature fusion layer to combine features from all layers
        self.lff = nn.Conv2d(in_channels + num_layers * growth_channels, in_channels, kernel_size=1)
    
    def forward(self, x):
        features = [x]
        for layer in self.layers:
            # Concatenate previous features along the channel dimension
            concat_features = torch.cat(features, dim=1)
            out = F.relu(layer(concat_features))
            features.append(out)
        # Fuse all concatenated features
        concat_features = torch.cat(features, dim=1)
        fused = self.lff(concat_features)
        # Apply residual connection with scaling to stabilize training
        return fused * 0.2 + x

# Conditional Network using multiple Residual Dense Blocks
class ConditionalNet(nn.Module):
    def __init__(self, in_channels=3, num_features=64, num_blocks=5):
        super(ConditionalNet, self).__init__()
        # Initial convolution layer to extract basic features
        self.conv_first = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)
        
        # Sequence of ResidualDenseBlocks
        self.rdb_blocks = nn.Sequential(
            *[ResidualDenseBlock(num_features) for _ in range(num_blocks)]
        )
        
        # Final convolution to produce the conditioned feature map
        self.conv_last = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
    
    def forward(self, x):
        out = self.conv_first(x)
        out = self.rdb_blocks(out)
        out = self.conv_last(out)
        return out

# Example usage
if __name__ == "__main__":
    # Create a dummy low-resolution image tensor (batch size=1, channels=3, 64x64 image)
    lr_image = torch.randn(1, 3, 64, 64)
    model = ConditionalNet()
    features = model(lr_image)
    print("Extracted features shape:", features.shape)


Extracted features shape: torch.Size([1, 64, 64, 64])


### **Diffusion Model**

In [13]:
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Create sinusoidal embeddings for the given timesteps.
    timesteps: a tensor of shape (N,)
    embedding_dim: dimension of the embedding
    """
    half_dim = embedding_dim // 2
    emb_factor = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb_factor)
    emb = timesteps.float().unsqueeze(1) * emb.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb

class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super(DownBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
    
    def forward(self, x, t_emb):
        h = self.conv(x)
        # Incorporate time embedding: broadcast to spatial dimensions
        time_emb = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + time_emb
        h_down = self.downsample(h)
        return h_down, h  # Return downsampled feature and skip connection

class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super(UpBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
    
    def forward(self, x, skip, t_emb):
        h = self.upconv(x)
        # Concatenate skip connection from down path
        h = torch.cat([h, skip], dim=1)
        h = self.conv(h)
        time_emb = self.time_mlp(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = h + time_emb
        return h

class UNetDiffusion(nn.Module):
    def __init__(self, in_channels=3, base_channels=64, time_emb_dim=128):
        super(UNetDiffusion, self).__init__()
        self.time_emb_dim = time_emb_dim
        
        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # Initial convolution to process the noisy input
        self.init_conv = nn.Conv2d(in_channels, base_channels, kernel_size=3, padding=1)
        
        # Downsampling path
        self.down1 = DownBlock(base_channels, base_channels*2, time_emb_dim)
        self.down2 = DownBlock(base_channels*2, base_channels*4, time_emb_dim)
        
        # Bottleneck layer
        self.bottleneck = nn.Sequential(
            nn.Conv2d(base_channels*4, base_channels*4, kernel_size=3, padding=1),
            nn.ReLU()
        )
        
        # Upsampling path
        self.up1 = UpBlock(base_channels*4, base_channels*2, time_emb_dim)
        self.up2 = UpBlock(base_channels*2, base_channels, time_emb_dim)
        
        # Final convolution to produce the denoised output image
        self.final_conv = nn.Conv2d(base_channels, in_channels, kernel_size=3, padding=1)
    
    def forward(self, x, t, cond_features=None):
        """
        x: Noisy image tensor of shape (B, C, H, W)
        t: Tensor containing the time step (e.g., shape (B,))
        cond_features: Optional conditional features from the Conditional Network
        """
        # Generate and process time embeddings
        t_emb = get_timestep_embedding(t, self.time_emb_dim)
        t_emb = self.time_mlp(t_emb)
        
        # Initial feature extraction
        h0 = self.init_conv(x)
        
        # Downsampling with skip connections
        h1, skip1 = self.down1(h0, t_emb)
        h2, skip2 = self.down2(h1, t_emb)
        
        # Bottleneck processing
        h_mid = self.bottleneck(h2)
        
        # Upsampling and merging with skip connections
        h_up1 = self.up1(h_mid, skip2, t_emb)
        h_up2 = self.up2(h_up1, skip1, t_emb)
        
        # Optionally add conditional features from the LR image
        if cond_features is not None:
            h_up2 = h_up2 + cond_features
        
        # Final convolution to produce the output image
        out = self.final_conv(h_up2)
        return out

# Example usage:
if __name__ == "__main__":
    # Create a dummy noisy image tensor (batch size=1, channels=3, 64x64 image)
    noisy_img = torch.randn(1, 3, 64, 64)
    # Create a dummy time step tensor (e.g., a single diffusion step)
    t = torch.tensor([10])
    
    model = UNetDiffusion()
    denoised = model(noisy_img, t)
    print("Denoised image shape:", denoised.shape)


### **Noise Schedule and Forward Diffusion Process**

In [14]:
class DiffusionSchedule:
    def __init__(self, num_timesteps, beta_start=1e-4, beta_end=0.02):
        """
        Initializes the diffusion schedule.
        num_timesteps: Total number of diffusion steps.
        beta_start, beta_end: Defines the linear schedule for beta.
        """
        self.num_timesteps = num_timesteps
        # Create a linear schedule for beta
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        # Calculate alphas: αₜ = 1 - βₜ
        self.alphas = 1.0 - self.betas
        # Compute cumulative product: \bar{α}_t = ∏_{s=1}^t α_s
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def get_alpha_bar(self, t):
        """
        Retrieve \bar{α}_t for a given time step t.
        t: A tensor of time steps (shape: [batch_size])
        Returns: Tensor of corresponding alpha_bar values (shape: [batch_size, 1, 1, 1])
        """
        # Indexing into alpha_bars for each t and reshape to broadcast over image dimensions
        alpha_bar = self.alpha_bars[t].view(-1, 1, 1, 1)
        return alpha_bar

def forward_diffusion_sample(x0, t, schedule):
    """
    Perform the forward diffusion process.
    x0: Original clean image tensor of shape (B, C, H, W)
    t: Tensor containing the diffusion time steps (shape: [B])
    schedule: Instance of DiffusionSchedule
    Returns: Noisy image x_t and the sampled noise epsilon.
    """
    # Get corresponding alpha_bar for each t
    alpha_bar = schedule.get_alpha_bar(t)
    # Sample random Gaussian noise
    noise = torch.randn_like(x0)
    # Compute noisy image using the forward process equation
    xt = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise
    return xt, noise

# Example usage:
if __name__ == "__main__":
    # Define parameters for the noise schedule
    num_timesteps = 1000
    schedule = DiffusionSchedule(num_timesteps=num_timesteps)
    
    # Create a dummy high-resolution image tensor (batch size=1, channels=3, 64x64 image)
    hr_image = torch.randn(1, 3, 64, 64)
    # Create a tensor for time step, e.g., t=10 for the current batch (ensure type is long for indexing)
    t = torch.tensor([10], dtype=torch.long)
    
    # Generate the noisy image
    noisy_image, noise = forward_diffusion_sample(hr_image, t, schedule)
    print("Noisy image shape:", noisy_image.shape)


Noisy image shape: torch.Size([1, 3, 64, 64])


### **Residual Prediction Module**

In [15]:
# Define a simple Residual Block used in the Residual Prediction Module.
class ResidualBlock(nn.Module):
    def __init__(self, num_features):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
    
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out + identity

# Residual Prediction Module.
class ResidualPredictionNet(nn.Module):
    def __init__(self, in_channels=3, num_features=64, num_residual_blocks=5, scale_factor=4):
        """
        in_channels: Number of channels in the input image (e.g., 3 for RGB).
        num_features: Number of feature maps used in intermediate layers.
        num_residual_blocks: How many residual blocks to use.
        scale_factor: Upscaling factor to reach high-resolution.
        """
        super(ResidualPredictionNet, self).__init__()
        # Upsampling layer using bilinear interpolation.
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
        # An initial convolution layer to process the upsampled image.
        self.entry_conv = nn.Conv2d(in_channels, num_features, kernel_size=3, padding=1)
        # Residual blocks to learn the missing details.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(num_features) for _ in range(num_residual_blocks)]
        )
        # A final convolution layer to predict the residual image.
        self.exit_conv = nn.Conv2d(num_features, in_channels, kernel_size=3, padding=1)
    
    def forward(self, lr_image):
        # Upsample the low-resolution image to the desired high-resolution size.
        upsampled = self.upsample(lr_image)
        x = self.entry_conv(upsampled)
        x = self.residual_blocks(x)
        residual = self.exit_conv(x)
        # The final prediction is the upsampled image plus the predicted residual.
        hr_pred = upsampled + residual
        return hr_pred

# Example usage:
if __name__ == "__main__":
    # Create a dummy low-resolution image tensor
    # Let's say the LR image has shape (batch_size=1, channels=3, height=16, width=16)
    # With a scale factor of 4, the HR image will have size 64x64.
    lr_image = torch.randn(1, 3, 16, 16)
    model = ResidualPredictionNet()
    hr_image = model(lr_image)
    print("Predicted HR image shape:", hr_image.shape)


Predicted HR image shape: torch.Size([1, 3, 64, 64])


### **Data Preprocessing and Dataset Pipeline**

In [21]:
class DIV2KDataset(Dataset):
    def __init__(self, hr_dir, scale_factor=4, transform=None):
        """
        hr_dir: Path to the directory containing high resolution images.
        scale_factor: Factor by which to downscale the HR image to generate the LR image.
        transform: Optional additional transformations to be applied on the HR image.
        """
        self.hr_dir = hr_dir
        # Collect all image paths with common image extensions.
        self.hr_image_paths = [
            os.path.join(hr_dir, f) for f in os.listdir(hr_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        self.scale_factor = scale_factor
        self.transform = transform
        self.to_tensor = transforms.ToTensor()
    
    def __len__(self):
        return len(self.hr_image_paths)
    
    def __getitem__(self, idx):
        # Load HR image and ensure it's in RGB format.
        hr_path = self.hr_image_paths[idx]
        hr_image = Image.open(hr_path).convert("RGB")
        
        # Optionally apply additional transformations.
        if self.transform:
            hr_image = self.transform(hr_image)
        
        # Convert the HR image to a tensor.
        hr_tensor = self.to_tensor(hr_image)
        
        # Generate the corresponding LR image by downsampling using bicubic interpolation.
        w, h = hr_image.size
        lr_w, lr_h = w // self.scale_factor, h // self.scale_factor
        lr_image = hr_image.resize((lr_w, lr_h), Image.BICUBIC)
        lr_tensor = self.to_tensor(lr_image)
        
        return lr_tensor, hr_tensor

# Example usage:
if __name__ == "__main__":
    # Define the directory containing DIV2K HR images.
    hr_directory = train_path
    # Initialize the dataset with a 4x downscaling factor.
    dataset = DIV2KDataset(hr_directory, scale_factor=4)
    # Create a DataLoader for batching and shuffling.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    
    # Iterate over the dataset.
    for lr_batch, hr_batch in dataloader:
        print("Low-resolution batch shape:", lr_batch.shape)
        print("High-resolution batch shape:", hr_batch.shape)
        break


RuntimeError: stack expects each tensor to be equal size, but got [3, 339, 510] at entry 0 and [3, 261, 510] at entry 2

### **Training and Evaluation Pipeline**

In [None]:
# Assume the following components have been defined:
# - ConditionalNet (from Component 1)
# - UNetDiffusion (from Component 2)
# - DiffusionSchedule and forward_diffusion_sample (from Component 3)
# - DIV2KDataset (from Component 4)

def train(diffusion_model, cond_net, diffusion_schedule, dataloader, num_epochs=10, device='cuda'):
    """
    Training loop for the diffusion model.
    
    Args:
      diffusion_model: The U-Net diffusion model.
      cond_net: The conditional network (pre-trained and fixed).
      diffusion_schedule: The noise schedule instance.
      dataloader: DataLoader providing (LR, HR) image pairs.
      num_epochs: Number of training epochs.
      device: 'cuda' or 'cpu'.
    """
    diffusion_model.train()
    # Set conditional network to evaluation mode if pre-trained
    cond_net.eval()  
    optimizer = optim.Adam(diffusion_model.parameters(), lr=1e-4)
    mse_loss = torch.nn.MSELoss()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch_idx, (lr, hr) in enumerate(dataloader):
            lr = lr.to(device)
            hr = hr.to(device)
            
            # Extract conditional features from LR images
            with torch.no_grad():
                cond_features = cond_net(lr)
            
            # Sample a random diffusion time step for each image in the batch
            batch_size = hr.size(0)
            t = torch.randint(0, diffusion_schedule.num_timesteps, (batch_size,), device=device).long()
            
            # Generate a noisy version of the HR image using the forward diffusion process
            xt, noise = forward_diffusion_sample(hr, t, diffusion_schedule)
            xt = xt.to(device)
            noise = noise.to(device)
            
            # Predict the noise from the diffusion model
            pred_noise = diffusion_model(xt, t, cond_features)
            
            # Compute loss: how well did the model predict the added noise?
            loss = mse_loss(pred_noise, noise)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            
            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item():.4f}")
        
        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}] Average Loss: {avg_loss:.4f}")

def evaluate(diffusion_model, cond_net, diffusion_schedule, lr_image, num_steps=1000, device='cuda'):
    """
    A sketch of the evaluation/inference procedure.
    
    Args:
      diffusion_model: The trained diffusion model.
      cond_net: The conditional network.
      diffusion_schedule: The noise schedule.
      lr_image: A low-resolution image tensor.
      num_steps: Total diffusion steps (should match training schedule).
      device: 'cuda' or 'cpu'.
    
    Returns:
      hr_pred: The generated high-resolution image.
    """
    diffusion_model.eval()
    cond_net.eval()
    
    with torch.no_grad():
        # Extract conditional features from the LR image
        cond_features = cond_net(lr_image.to(device))
        
        # Start from pure noise
        hr_pred = torch.randn(lr_image.size(0), 3, lr_image.size(2) * 4, lr_image.size(3) * 4, device=device)
        
        # Here we would iteratively apply the reverse diffusion process.
        # For brevity, this is a simplified loop.
        for t in reversed(range(num_steps)):
            t_tensor = torch.full((lr_image.size(0),), t, device=device, dtype=torch.long)
            # One step of the reverse diffusion process:
            hr_pred = diffusion_model(hr_pred, t_tensor, cond_features)
            # Additional noise correction and scaling would be applied here in a full implementation.
    
    return hr_pred

# Example usage:
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Initialize the models
    cond_net = ConditionalNet().to(device)
    diffusion_model = UNetDiffusion().to(device)
    # Assume the conditional network is pre-trained; here, we keep it fixed.
    
    # Create the diffusion schedule (e.g., 1000 timesteps)
    diffusion_schedule = DiffusionSchedule(num_timesteps=1000)
    
    # Initialize the DIV2K dataset and DataLoader
    dataset = DIV2KDataset(hr_dir="data/raw/DIV2K", scale_factor=4)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    
    # Train the model
    train(diffusion_model, cond_net, diffusion_schedule, dataloader, num_epochs=5, device=device)
    
    # For evaluation, assume we take one low-resolution image from the dataset
    lr_sample, _ = dataset[0]
    lr_sample = lr_sample.unsqueeze(0)  # Add batch dimension
    hr_generated = evaluate(diffusion_model, cond_net, diffusion_schedule, lr_sample, num_steps=1000, device=device)
    print("Generated HR image shape:", hr_generated.shape)


### **Loss Functions and Optimization Strategy**

In [None]:
# Define a custom loss module for the diffusion model.
class DiffusionLoss(nn.Module):
    def __init__(self):
        super(DiffusionLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, pred_noise, true_noise):
        """
        Computes the mean squared error between the predicted noise and the actual noise.
        
        Args:
          pred_noise: The noise predicted by the diffusion model.
          true_noise: The actual noise that was added during the forward process.
        
        Returns:
          loss: A scalar loss value.
        """
        loss = self.mse_loss(pred_noise, true_noise)
        return loss

# Example setup of the optimizer and usage of the loss in a training step.
def training_step(diffusion_model, cond_net, diffusion_schedule, lr_batch, hr_batch, device):
    # Set models to proper modes.
    diffusion_model.train()
    cond_net.eval()  # Assuming the conditional network is pre-trained.
    
    # Extract conditional features.
    with torch.no_grad():
        cond_features = cond_net(lr_batch.to(device))
    
    # Sample random time steps for each image in the batch.
    batch_size = hr_batch.size(0)
    t = torch.randint(0, diffusion_schedule.num_timesteps, (batch_size,), device=device).long()
    
    # Generate the noisy image and corresponding noise using the forward diffusion process.
    xt, noise = forward_diffusion_sample(hr_batch.to(device), t, diffusion_schedule)
    
    # Predict the noise using the diffusion model.
    pred_noise = diffusion_model(xt, t, cond_features)
    
    # Compute the loss.
    loss_fn = DiffusionLoss().to(device)
    loss = loss_fn(pred_noise, noise)
    
    # Set up the optimizer (here we use Adam).
    optimizer = optim.Adam(diffusion_model.parameters(), lr=1e-4)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()

# Example usage (assuming required components are already defined and instantiated):
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Dummy instantiation of models (in practice, these are your defined models).
    cond_net = ConditionalNet().to(device)
    diffusion_model = UNetDiffusion().to(device)
    
    # Create a diffusion schedule (e.g., 1000 timesteps).
    diffusion_schedule = DiffusionSchedule(num_timesteps=1000)
    
    # Create dummy batches for low-resolution and high-resolution images.
    # Here, we assume lr_batch shape: (batch_size, 3, H, W) and hr_batch shape accordingly.
    lr_batch = torch.randn(4, 3, 64, 64)  # e.g., batch size=4, LR images 64x64.
    hr_batch = torch.randn(4, 3, 256, 256)  # HR images corresponding to 4x upscaling.
    
    # Perform one training step.
    loss_val = training_step(diffusion_model, cond_net, diffusion_schedule, lr_batch, hr_batch, device)
    print("Training step loss:", loss_val)
