# Evaluation of Super-Resolution Models with and without Swin Transformer Integration

This notebook is designed to evaluate the performance of super-resolution (SR) models applied to medical images, specifically MRI slices. The main objective is to compare the effectiveness of SR models enhanced with Swin Transformer blocks against traditional interpolation methods. The evaluation process involves upscaling low-resolution images (64x64) to high-resolution images (256x256) using a cascaded super-resolution approach. The models tested include a direct application of DDPM for upscaling, as well as models that incorporate Swin Transformer blocks within a UNet architecture.

Key metrics for comparison include Peak Signal-to-Noise Ratio (PSNR), Structural Similarity Index Measure (SSIM), Mean Squared Error (MSE), Feature Similarity Index (FSIM), Visual Information Fidelity (VIF), Learned Perceptual Image Patch Similarity (LPIPS), and Average Gradient (AG). These metrics provide a comprehensive assessment of the models' ability to enhance image quality while preserving critical medical image features.


##### Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms, models
import numpy as np
import math
from sklearn.metrics import pairwise_distances
from scipy.stats import entropy
from scipy.linalg import sqrtm
from tqdm import tqdm


##### Implementing Swin Transformer and Helper Functions

In [2]:
"""
Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>
https://github.com/microsoft/Swin-Transformer
"""

# DropPath (Stochastic Depth) module to implement drop path regularization
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        output = x.div(keep_prob) * random_tensor
        return output

# Helper functions to handle tuple and truncation
def to_2tuple(x):
    if isinstance(x, (tuple, list)):
        return x
    return (x, x)

def trunc_normal_(tensor, mean=0., std=1.):
    with torch.no_grad():
        size = tensor.shape
        tmp = tensor.new_empty(size + (4,)).normal_()
        valid = (tmp < 2) & (tmp > -2)
        ind = valid.max(-1, keepdim=True)[1]
        tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
        tensor.data.mul_(std).add_(mean)
        return tensor

# MLP module used within Swin Transformer
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

# Functions to partition and reverse windows in the Swin Transformer
def window_partition(x, window_size):
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, -1, H, W)
    return x

# Window-based multi-head self-attention (W-MSA) module
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias table for all windows
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        # Get relative position index for each window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
        attn = attn + relative_position_bias.unsqueeze(0).to(attn.dtype)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

# Swin Transformer block implementing the shifted window-based attention mechanism
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * self.mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size).view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        B, L, C = x.shape
        H, W = self.input_resolution

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        if self.attn_mask is not None:
            attn_windows = self.attn(x_windows, mask=self.attn_mask.to(x.dtype))
        else:
            attn_windows = self.attn(x_windows)

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


##### Implementing Self, Cross Attention and Sinusoidal Positional Embedding Classes

In [3]:
# Self-attention block for feature refinement
class SelfAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttentionBlock, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

# Cross-attention block for the decoder to focus on relevant features
class CrossAttentionBlock(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(CrossAttentionBlock, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

# Sinusoidal positional embedding for timestep encoding in DDPM
class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, max_len=10000):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.max_len = max_len

    def forward(self, timesteps):
        half_dim = self.embedding_dim // 2
        emb = math.log(self.max_len) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
        return emb


##### 64x64 UNet

In [4]:
class AttentionUNet(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(AttentionUNet, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.self_attention1 = SelfAttentionBlock(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.self_attention2 = SelfAttentionBlock(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.self_attention3 = SelfAttentionBlock(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
        self.swin_block = SwinTransformerBlock(dim=512, input_resolution=(8, 8), num_heads=8, window_size=4) 

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.cross_attention3 = CrossAttentionBlock(256, 256, 128)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.cross_attention2 = CrossAttentionBlock(128, 128, 64)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.cross_attention1 = CrossAttentionBlock(64, 64, 32)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1 = self.self_attention1(enc1)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2 = self.self_attention2(enc2)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3 = self.self_attention3(enc3)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)
        B, C, H, W = bottleneck.shape
        bottleneck = bottleneck.view(B, H * W, C)
        bottleneck = self.swin_block(bottleneck)
        bottleneck = bottleneck.view(B, C, H, W)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        enc3 = self.cross_attention3(upconv3, enc3)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        enc2 = self.cross_attention2(upconv2, enc2)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        enc1 = self.cross_attention1(upconv1, enc1)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### 64x64 to 128x128 UNet

In [5]:

# Define the UNet model with attention
class SuperResUNet128(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(SuperResUNet128, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
        self.swin_block = SwinTransformerBlock(dim=512, input_resolution=(16, 16), num_heads=8, window_size=4)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.repeat(1, 1, x.size(2), x.size(3))

        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)
        B, C, H, W = bottleneck.shape
        bottleneck = bottleneck.view(B, H * W, C)
        bottleneck = self.swin_block(bottleneck)
        bottleneck = bottleneck.view(B, C, H, W)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### 128x128 to 256x256 UNet

In [6]:

class SuperResUNet256(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(SuperResUNet256, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
        self.swin_block = SwinTransformerBlock(dim=512, input_resolution=(32, 32), num_heads=8, window_size=4)
        
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)
        
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.repeat(1, 1, x.size(2), x.size(3))

        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)
        B, C, H, W = bottleneck.shape
        bottleneck = bottleneck.view(B, H * W, C)
        bottleneck = self.swin_block(bottleneck)
        bottleneck = bottleneck.view(B, C, H, W)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### Implementing DDPM and SuperRes DDPM

In [7]:
"""
DDPM implementation adapted from:
https://github.com/hojonathanho/diffusion/tree/master
"""

class SuperResDDPM(nn.Module):
    def __init__(self, model, num_timesteps, beta_start=0.00085, beta_end=0.0120):
        super(SuperResDDPM, self).__init__()
        self.model = model
        self.num_timesteps = num_timesteps

        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(1 - betas, dim=0))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - self.alphas_cumprod))

    def forward(self, z_t, t, low_res_image):
        # Concatenate low_res_image with z_t to condition the model
        low_res_upsampled = F.interpolate(low_res_image, scale_factor=2, mode='bicubic', align_corners=False)
        return self.model(torch.cat([z_t, low_res_upsampled], dim=1), t)

    def sample_timesteps(self, batch_size):
        return torch.randint(0, self.num_timesteps, (batch_size,)).to(device)

    def forward_diffusion(self, target_high_res_img, t, noise):
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        return sqrt_alphas_cumprod_t * target_high_res_img + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, input_low_res_img, target_high_res_img, t):
        sqrt_alpha_t = self.sqrt_alphas_cumprod[t].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        sqrt_one_minus_alpha_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1).unsqueeze(1).unsqueeze(1)

        #input_high_res_img = F.interpolate(input_low_res_img, scale_factor=2, mode='bilinear', align_corners=False)

        noise = torch.randn_like(target_high_res_img)
        z_t = self.forward_diffusion(target_high_res_img, t, noise)

        predicted_noise = self.forward(z_t, t, input_low_res_img)

        return nn.MSELoss()(predicted_noise, noise)

    def sample(self, low_res_image):
        z_t = torch.randn_like(F.interpolate(low_res_image, scale_factor=2, mode='bicubic', align_corners=False))

        for t in reversed(range(self.num_timesteps)):
            t_tensor = torch.tensor([t], device=z_t.device).long()
            alpha_t = self.alphas[t]
            sqrt_alpha_t = torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t = torch.sqrt(1 - self.alphas_cumprod[t])
            beta_t = self.betas[t]

            predicted_noise = self.forward(z_t, t_tensor, low_res_image)

            z_t = (z_t - (1 - self.alphas[t]) * predicted_noise / sqrt_one_minus_alpha_t) / sqrt_alpha_t

            if t > 0:
                z_t += torch.randn_like(z_t) * torch.sqrt(beta_t)

        return z_t

    def p_sample(self, z, t, low_res_image):
        alpha_t = self.alphas[t]
        sqrt_alpha_t = torch.sqrt(alpha_t)
        sqrt_one_minus_alpha_t = torch.sqrt(1 - self.alphas_cumprod[t])
        beta_t = self.betas[t]
        predicted_noise = self.forward(z, t, low_res_image)

        z = (z - beta_t / sqrt_one_minus_alpha_t * predicted_noise) / sqrt_alpha_t
        return z

In [8]:
"""
DDPM implementation adapted from:
https://github.com/hojonathanho/diffusion/tree/master
"""

class DDPM(nn.Module):
    def __init__(self, model, num_timesteps, latent_dim, beta_start=0.00085, beta_end=0.0120):
        super(DDPM, self).__init__()
        self.model = model
        self.num_timesteps = num_timesteps

        betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(1 - betas, dim=0))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1 - self.alphas_cumprod))

        self.latent_dim = latent_dim

    def forward(self, z_t, t):
        return self.model(z_t, t)

    def sample_timesteps(self, batch_size):
        return torch.randint(0, self.num_timesteps, (batch_size,)).to(device)

    def forward_diffusion(self, z_0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(z_0)
        sqrt_alphas_cumprod_t = self.sqrt_alphas_cumprod[t].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        sqrt_one_minus_alphas_cumprod_t = self.sqrt_one_minus_alphas_cumprod[t].unsqueeze(1).unsqueeze(1).unsqueeze(1)
        return sqrt_alphas_cumprod_t * z_0 + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, z_0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(z_0)
        z_t = self.forward_diffusion(z_0, t, noise)
        predicted_noise = self.forward(z_t, t)
        return nn.MSELoss()(noise, predicted_noise)

    def sample(self, shape):
        z_t = torch.randn(shape).to(device)
        for t in reversed(range(self.num_timesteps)):
            t_tensor = torch.tensor([t], device=z_t.device).long()
            alpha_t = self.alphas[t]
            sqrt_alpha_t = torch.sqrt(alpha_t)
            sqrt_one_minus_alpha_t = torch.sqrt(1 - self.alphas_cumprod[t])
            beta_t = self.betas[t]

            # Predict the noise
            predicted_noise = self.forward(z_t, t_tensor)

            # Remove the predicted noise
            z_t = (z_t - beta_t / sqrt_one_minus_alpha_t * predicted_noise) / sqrt_alpha_t

            # Add noise for non-final steps
            if t > 0:
                z_t += torch.randn_like(z_t) * torch.sqrt(beta_t)

        return z_t

    def p_sample(self, z, t):
        predicted_noise = self.forward(z, t)
        alpha_t = self.alphas[t]
        sqrt_alpha_t = torch.sqrt(alpha_t)
        sqrt_one_minus_alpha_t = torch.sqrt(1 - alpha_t)
        z = (z - predicted_noise * (1 - alpha_t) / sqrt_one_minus_alpha_t) / sqrt_alpha_t
        return z


### LOAD MODELS

In [26]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_model(ddpm, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    ddpm.load_state_dict(checkpoint['model_state_dict'])
    return ddpm

# Load the three models from checkpoints
def load_models():
    # Model 1: 64x64 Generation
    unet_64 = AttentionUNet(in_channels=1, out_channels=1, emb_dim=128).to(device)
    ddpm_64 = DDPM(unet_64, num_timesteps=2000, latent_dim=128).to(device)
    ddpm_64 = load_model(ddpm_64, "Model_Savepoints/swinddpm64RAW_checkpoint.pth")

    # Model 2: 64x64 to 128x128 Super-resolution
    unet_128 = SuperResUNet128(in_channels=2, out_channels=1, emb_dim=128).to(device)
    ddpm_128 = SuperResDDPM(unet_128, num_timesteps=1000).to(device)
    ddpm_128 = load_model(ddpm_128, "Model_Savepoints/cascadedddpm64200epochs_checkpoint.pth")

    # Model 3: 128x128 to 256x256 Super-resolution
    unet_256 = SuperResUNet256(in_channels=2, out_channels=1, emb_dim=128).to(device)
    ddpm_256 = SuperResDDPM(unet_256, num_timesteps=1000).to(device)
    ddpm_256 = load_model(ddpm_256, "Model_Savepoints/cascadedddpm128(225epoch)_checkpoint.pth")

    return ddpm_64, ddpm_128, ddpm_256

##### SSIM, FSIM, DSIM, MSE, PSNR, Visual Fidelity, LPIPS and Average Gradient

In [None]:
#pip install image-similarity-measures
#pip install torchmetrics
#pip install pyfftw

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from PIL import Image
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.image import VisualInformationFidelity
from image_similarity_measures.quality_metrics import fsim
from scipy.ndimage import sobel

In [28]:


class MRNetUpscaleDataset(Dataset):
    def __init__(self, slice_dir, label_files, transform_64=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_256 = transform_256

        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: '0' * (4 - len(str(i))) + str(i))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # List all slice files
        self.slice_files = [os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')]

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        image = Image.open(slice_path)
        
        if self.transform_64:
            image_64 = self.transform_64(image)
        if self.transform_256:
            image_256 = self.transform_256(image)
        
        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = self.labels_dict[slice_id]
            label = torch.FloatTensor([label])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0]) 

        return {'data_64': image_64, 'data_256': image_256, 'label': label, 'id': slice_id}

# Define transformations for 64x64 and 256x256
transform_64 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((64, 64)),  # Resize to 64x64
    transforms.ToTensor()  # Convert to PyTorch tensor
])

transform_256 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((256, 256)),  # Resize to 256x256
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Initialize the dataset
root_dir = r"/rds/homes/a/avv306/Raw_Images/Raw_Images"
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

valid_label_files = [
    os.path.join(root_dir, "valid-acl.csv"),
    os.path.join(root_dir, "valid-abnormal.csv"),
    os.path.join(root_dir, "valid-meniscus.csv")
]

# Create a dataset for 64x64 and 256x256
valid_dataset_64_256 = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_64=transform_64,
    transform_256=transform_256
)

# Create Subsets of the dataset for testing
subset_indices = list(range(100))  # Use the first 160 images for example
valid_subset_64_256 = Subset(valid_dataset_64_256, subset_indices)

# Create DataLoader for the subset
valid_loader_64_256 = DataLoader(valid_subset_64_256, batch_size=20, shuffle=True)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize PyTorch metric objects on the same device
psnr_metric = PeakSignalNoiseRatio().to(device)
ssim_metric = StructuralSimilarityIndexMeasure().to(device)
vif_metric = VisualInformationFidelity().to(device)
lpips_metric = LPIPS(net_type='vgg').to(device)  

# Function to prepare images for LPIPS
def prepare_for_lpips(img):
    # Convert grayscale to 3 channels by repeating the channel 3 times
    img_3c = img.repeat(1, 3, 1, 1)  # Repeat across the channel dimension
    # Normalize the image to [-1, 1]
    img_3c = (img_3c * 2) - 1
    # Clip the values to ensure they stay within the expected range
    img_3c = torch.clamp(img_3c, min=-1.0, max=1.0)
    return img_3c

# Function to compute PSNR using PyTorch
def compute_psnr(img1, img2):
    return psnr_metric(img1, img2).item()

# Function to compute SSIM and DSSIM using PyTorch
def compute_ssim_dssim(img1, img2):
    ssim_value = ssim_metric(img1, img2).item()
    dssim_value = (1 - ssim_value) / 2
    return ssim_value, dssim_value

# Function to compute MSE using PyTorch
def compute_mse(img1, img2):
    return torch.mean((img1 - img2) ** 2).item()

# Function to compute FSIM using image_similarity_measures
def compute_fsim(img1, img2):
    # Convert PyTorch tensors to NumPy arrays and scale to 8-bit integer
    img1_np = img1.cpu().numpy().squeeze() * 255
    img2_np = img2.cpu().numpy().squeeze() * 255

    # Ensure the images are 3D (add a channel dimension if they're 2D)
    if img1_np.ndim == 2:
        img1_np = np.expand_dims(img1_np, axis=-1)
    if img2_np.ndim == 2:
        img2_np = np.expand_dims(img2_np, axis=-1)
    
    img1_uint8 = img1_np.astype(np.uint8)
    img2_uint8 = img2_np.astype(np.uint8)

    # Compute FSIM using the image_similarity_measures library
    return fsim(img1_uint8, img2_uint8)

# Function to compute VIF using PyTorch
def compute_vif(img1, img2):
    return vif_metric(img1, img2).item()

# Function to compute LPIPS using PyTorch
def compute_lpips(img1, img2):
    img1 = prepare_for_lpips(img1)
    img2 = prepare_for_lpips(img2)
    return lpips_metric(img1, img2).item()

# Function to compute AG (Average Gradient)
def compute_ag(img):
    img_np = img.cpu().numpy().squeeze()
    gx = sobel(img_np, axis=0)
    gy = sobel(img_np, axis=1)
    grad_magnitude = np.sqrt(gx**2 + gy**2)
    return np.mean(grad_magnitude)

# Function to evaluate super-resolution models
def evaluate_super_resolution_cascaded(ddpm_128, ddpm_256, dataloader, device):
    psnr_values = []
    ssim_values = []
    dssim_values = []
    mse_values = []
    fsim_values = []
    vif_values = []
    lpips_values = []
    ag_values = []

    ddpm_128.eval()
    ddpm_256.eval()

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
            low_res = batch['data_64'].to(device)
            high_res = batch['data_256'].to(device)

            for i in range(low_res.size(0)):
                input_image = low_res[i].unsqueeze(0)
                target_image = high_res[i].unsqueeze(0)

                # Upscale from 64x64 to 256x256
                super_res_image = ddpm_128.sample(input_image)
                super_res_image = ddpm_256.sample(super_res_image)
                
                # Convert tensors back to images for evaluation
                high_res_img = target_image
                super_res_img = super_res_image

                # Compute metrics
                psnr_values.append(compute_psnr(high_res_img, super_res_img))
                ssim_value, dssim_value = compute_ssim_dssim(high_res_img, super_res_img)
                ssim_values.append(ssim_value)
                dssim_values.append(dssim_value)
                mse_values.append(compute_mse(high_res_img, super_res_img))
                fsim_values.append(compute_fsim(high_res_img, super_res_img))
                vif_values.append(compute_vif(high_res_img, super_res_img))
                lpips_values.append(compute_lpips(high_res_img, super_res_img))
                ag_values.append(compute_ag(super_res_img))
                
    return (np.mean(psnr_values), np.mean(ssim_values), np.mean(dssim_values), np.mean(mse_values), np.mean(fsim_values),
            np.mean(vif_values), np.mean(lpips_values), np.mean(ag_values))

# Load the models (ensure load_models is implemented as per your previous context)
ddpm_64, ddpm_128, ddpm_256 = load_models()

# Evaluate the cascaded super-resolution model from 64x64 to 256x256
(psnr_256, ssim_256, dssim_256, mse_256, fsim_256, 
 vif_256, lpips_256, ag_256) = evaluate_super_resolution_cascaded(ddpm_128, ddpm_256, valid_loader_64_256, device)

print(f"Cascaded Super-resolution from 64x64 to 256x256:")
print(f"PSNR={psnr_256}, SSIM={ssim_256}, DSSIM={dssim_256}, MSE={mse_256}, FSIM={fsim_256}")
print(f"VIF={vif_256}, LPIPS={lpips_256}, AG={ag_256}")


Processing batches: 100%|██████████| 5/5 [13:37<00:00, 163.46s/it]

Cascaded Super-resolution from 64x64 to 256x256:
PSNR=27.174150257110597, SSIM=0.7387099087238311, DSSIM=0.1306450456380844, MSE=0.0035601869842503218, FSIM=0.5452157156335029
VIF=0.8344831836223602, LPIPS=0.17526448905467987, AG=0.26342782378196716





### NOSWIN

##### 64x64 UNet

In [9]:
class AttentionUNet(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(AttentionUNet, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.self_attention1 = SelfAttentionBlock(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.self_attention2 = SelfAttentionBlock(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.self_attention3 = SelfAttentionBlock(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
       
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.cross_attention3 = CrossAttentionBlock(256, 256, 128)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.cross_attention2 = CrossAttentionBlock(128, 128, 64)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.cross_attention1 = CrossAttentionBlock(64, 64, 32)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.repeat(1, 1, x.size(2), x.size(3))
        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1 = self.self_attention1(enc1)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2 = self.self_attention2(enc2)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3 = self.self_attention3(enc3)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        enc3 = self.cross_attention3(upconv3, enc3)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        enc2 = self.cross_attention2(upconv2, enc2)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        enc1 = self.cross_attention1(upconv1, enc1)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output

##### 64x64 to 128x128 UNet

In [10]:
class SuperResUNet128(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(SuperResUNet128, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)

        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.repeat(1, 1, x.size(2), x.size(3))

        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### 128x128 to 256x256 UNet

In [11]:
class SuperResUNet256(nn.Module):
    def __init__(self, in_channels, out_channels, emb_dim=128):
        super(SuperResUNet256, self).__init__()

        self.encoder1 = self.conv_block(in_channels + emb_dim, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.encoder2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.encoder3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.bottleneck = self.conv_block(256, 512)
        
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self.conv_block(512, 256)

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(256, 128)
        
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.timestep_embedding_layer = SinusoidalPositionalEmbedding(emb_dim)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, t):
        t_embed = self.timestep_embedding_layer(t)
        t_embed = t_embed.view(t.size(0), -1, 1, 1)
        t_embed = t_embed.repeat(1, 1, x.size(2), x.size(3))

        x = torch.cat((x, t_embed), dim=1)

        # Encoder
        enc1 = self.encoder1(x)
        enc1_pooled = self.pool1(enc1)

        enc2 = self.encoder2(enc1_pooled)
        enc2_pooled = self.pool2(enc2)

        enc3 = self.encoder3(enc2_pooled)
        enc3_pooled = self.pool3(enc3)

        # Bottleneck
        bottleneck = self.bottleneck(enc3_pooled)

        # Decoder
        upconv3 = self.upconv3(bottleneck)
        dec3 = torch.cat((upconv3, enc3), dim=1)
        dec3 = self.decoder3(dec3)

        upconv2 = self.upconv2(dec3)
        dec2 = torch.cat((upconv2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        upconv1 = self.upconv1(dec2)
        dec1 = torch.cat((upconv1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        final_output = self.final_conv(dec1)
        return final_output


##### Load Models

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_model(ddpm, checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    ddpm.load_state_dict(checkpoint['model_state_dict'])
    return ddpm

# Load the three models from checkpoints
def load_models():
    # Model 1: 64x64 Generation
    unet_64 = AttentionUNet(in_channels=1, out_channels=1, emb_dim=128).to(device)
    ddpm_64 = DDPM(unet_64, num_timesteps=2000, latent_dim=128).to(device)
    ddpm_64 = load_model(ddpm_64, "Model_Savepoints/ddpm64NOSWIN_checkpoint.pth")

    # Model 2: 64x64 to 128x128 Super-resolution
    unet_128 = SuperResUNet128(in_channels=2, out_channels=1, emb_dim=128).to(device)
    ddpm_128 = SuperResDDPM(unet_128, num_timesteps=1000).to(device)
    ddpm_128 = load_model(ddpm_128, "Model_Savepoints/cascadedddpm64NOSWIN(100epoch)_checkpoint.pth")

    # Model 3: 128x128 to 256x256 Super-resolution
    unet_256 = SuperResUNet256(in_channels=2, out_channels=1, emb_dim=128).to(device)
    ddpm_256 = SuperResDDPM(unet_256, num_timesteps=1000).to(device)
    ddpm_256 = load_model(ddpm_256, "Model_Savepoints/cascadedddpm128NOSWIN_checkpoint.pth")

    return ddpm_64, ddpm_128, ddpm_256

##### SSIM, FSIM, DSIM, MSE, PSNR, Visual Fidelity, LPIPS and Average Gradient

In [47]:
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
from image_similarity_measures.quality_metrics import fsim  # Import FSIM from image_similarity_measures


class MRNetUpscaleDataset(Dataset):
    def __init__(self, slice_dir, label_files, transform_64=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_256 = transform_256

        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: '0' * (4 - len(str(i))) + str(i))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # List all slice files
        self.slice_files = [os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')]

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        image = Image.open(slice_path)
        
        if self.transform_64:
            image_64 = self.transform_64(image)
        if self.transform_256:
            image_256 = self.transform_256(image)
        
        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = self.labels_dict[slice_id]
            label = torch.FloatTensor([label])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0])  

        return {'data_64': image_64, 'data_256': image_256, 'label': label, 'id': slice_id}

# Define transformations for 64x64 and 256x256
transform_64 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((64, 64)),  # Resize to 64x64
    transforms.ToTensor()  # Convert to PyTorch tensor
])

transform_256 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((256, 256)),  # Resize to 256x256
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Initialize the dataset
root_dir = r"/rds/homes/a/avv306/Raw_Images/Raw_Images"
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

valid_label_files = [
    os.path.join(root_dir, "valid-acl.csv"),
    os.path.join(root_dir, "valid-abnormal.csv"),
    os.path.join(root_dir, "valid-meniscus.csv")
]

# Create a dataset for 64x64 and 256x256
valid_dataset_64_256 = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_64=transform_64,
    transform_256=transform_256
)

# Create Subsets of the dataset for testing
subset_indices = list(range(160))  # Use the first 160 images for example
valid_subset_64_256 = Subset(valid_dataset_64_256, subset_indices)

# Create DataLoader for the subset
valid_loader_64_256 = DataLoader(valid_subset_64_256, batch_size=16, shuffle=True)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize PyTorch metric objects on the same device
psnr_metric = PeakSignalNoiseRatio().to(device)
ssim_metric = StructuralSimilarityIndexMeasure().to(device)

# Function to compute PSNR using PyTorch
def compute_psnr(img1, img2):
    return psnr_metric(img1, img2).item()

# Function to compute SSIM and DSSIM using PyTorch
def compute_ssim_dssim(img1, img2):
    ssim_value = ssim_metric(img1, img2).item()
    dssim_value = (1 - ssim_value) / 2
    return ssim_value, dssim_value

# Function to compute MSE using PyTorch
def compute_mse(img1, img2):
    return torch.mean((img1 - img2) ** 2).item()

# Function to compute FSIM using image_similarity_measures
def compute_fsim(img1, img2):
    # Convert PyTorch tensors to NumPy arrays and scale to 8-bit integer
    img1_np = img1.cpu().numpy().squeeze() * 255
    img2_np = img2.cpu().numpy().squeeze() * 255

    # Ensure the images are 3D (add a channel dimension if they're 2D)
    if img1_np.ndim == 2:
        img1_np = np.expand_dims(img1_np, axis=-1)
    if img2_np.ndim == 2:
        img2_np = np.expand_dims(img2_np, axis=-1)
    
    img1_uint8 = img1_np.astype(np.uint8)
    img2_uint8 = img2_np.astype(np.uint8)

    # Compute FSIM using the image_similarity_measures library
    return fsim(img1_uint8, img2_uint8)


# Function to evaluate super-resolution models
def evaluate_super_resolution_cascaded(ddpm_128, ddpm_256, dataloader, device):
    psnr_values = []
    ssim_values = []
    dssim_values = []
    mse_values = []
    fsim_values = []

    ddpm_128.eval()
    ddpm_256.eval()

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
            low_res = batch['data_64'].to(device)
            high_res = batch['data_256'].to(device)

            for i in range(low_res.size(0)):
                input_image = low_res[i].unsqueeze(0)
                #print(input_image.shape)
                target_image = high_res[i].unsqueeze(0)
                #print(target_image.shape)
                # Upscale from 64x64 to 256x256
                super_res_image = ddpm_128.sample(input_image)
                super_res_image = ddpm_256.sample(super_res_image)
                #print(super_res_image.shape)
                # Convert tensors back to images for evaluation
                high_res_img = target_image
                super_res_img = super_res_image

                psnr_values.append(compute_psnr(high_res_img, super_res_img))
                ssim_value, dssim_value = compute_ssim_dssim(high_res_img, super_res_img)
                ssim_values.append(ssim_value)
                dssim_values.append(dssim_value)
                mse_values.append(compute_mse(high_res_img, super_res_img))
                fsim_values.append(compute_fsim(high_res_img, super_res_img))
                
    return np.mean(psnr_values), np.mean(ssim_values), np.mean(dssim_values), np.mean(mse_values), np.mean(fsim_values)

# Load the models 
ddpm_64, ddpm_128, ddpm_256 = load_models()

# Evaluate the cascaded super-resolution model from 64x64 to 256x256
psnr_256, ssim_256, dssim_256, mse_256, fsim_256 = evaluate_super_resolution_cascaded(ddpm_128, ddpm_256, valid_loader_64_256, device)

print(f"Cascaded NO SWIN Super-resolution from 64x64 to 256x256: PSNR={psnr_256}, SSIM={ssim_256}, DSSIM={dssim_256}, MSE={mse_256}, FSIM={fsim_256}")


Processing batches: 100%|██████████| 10/10 [10:26<00:00, 62.67s/it]

Cascaded NO SWIN Super-resolution from 64x64 to 256x256: PSNR=27.686824488639832, SSIM=0.6121361188590526, DSSIM=0.19393194057047367, MSE=0.004092157847480848, FSIM=0.45221029153800674





In [48]:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.image import VisualInformationFidelity
from scipy.ndimage import sobel


class MRNetUpscaleDataset(Dataset):
    def __init__(self, slice_dir, label_files, transform_64=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_256 = transform_256

        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: '0' * (4 - len(str(i))) + str(i))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # List all slice files
        self.slice_files = [os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')]

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        image = Image.open(slice_path)
        
        if self.transform_64:
            image_64 = self.transform_64(image)
        if self.transform_256:
            image_256 = self.transform_256(image)
        
        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = self.labels_dict[slice_id]
            label = torch.FloatTensor([label])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0])  

        return {'data_64': image_64, 'data_256': image_256, 'label': label, 'id': slice_id}

# Define transformations for 64x64 and 256x256
transform_64 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((64, 64)),  # Resize to 64x64
    transforms.ToTensor()  # Convert to PyTorch tensor
])

transform_256 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((256, 256)),  # Resize to 256x256
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Initialize the dataset
root_dir = r"/rds/homes/a/avv306/Raw_Images/Raw_Images"
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

valid_label_files = [
    os.path.join(root_dir, "valid-acl.csv"),
    os.path.join(root_dir, "valid-abnormal.csv"),
    os.path.join(root_dir, "valid-meniscus.csv")
]

# Create a dataset for 64x64 and 256x256
valid_dataset_64_256 = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_64=transform_64,
    transform_256=transform_256
)

# Create Subsets of the dataset for testing
subset_indices = list(range(160))  # Use the first 160 images for example
valid_subset_64_256 = Subset(valid_dataset_64_256, subset_indices)

# Create DataLoader for the subset
valid_loader_64_256 = DataLoader(valid_subset_64_256, batch_size=16, shuffle=True)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize PyTorch metric objects on the same device
vif_metric = VisualInformationFidelity().to(device)
lpips_metric = LPIPS(net_type='vgg').to(device) 

# Function to prepare images for LPIPS
def prepare_for_lpips(img):
    # Convert grayscale to 3 channels by repeating the channel 3 times
    img_3c = img.repeat(1, 3, 1, 1)  # Repeat across the channel dimension
    # Normalize the image to [-1, 1]
    img_3c = (img_3c * 2) - 1
    # Clip the values to ensure they stay within the expected range
    img_3c = torch.clamp(img_3c, min=-1.0, max=1.0)
    return img_3c


# Function to compute VIF using PyTorch
def compute_vif(img1, img2):
    return vif_metric(img1, img2).item()

# Function to compute LPIPS using PyTorch
def compute_lpips(img1, img2):
    img1 = prepare_for_lpips(img1)
    img2 = prepare_for_lpips(img2)
    return lpips_metric(img1, img2).item()

# Function to compute AG (Average Gradient)
def compute_ag(img):
    img_np = img.cpu().numpy().squeeze()
    gx = sobel(img_np, axis=0)
    gy = sobel(img_np, axis=1)
    grad_magnitude = np.sqrt(gx**2 + gy**2)
    return np.mean(grad_magnitude)

# Function to evaluate super-resolution models
def evaluate_super_resolution_cascaded(ddpm_128, ddpm_256, dataloader, device):
    vif_values = []
    lpips_values = []
    ag_values = []

    ddpm_128.eval()
    ddpm_256.eval()

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing batches")):
            low_res = batch['data_64'].to(device)
            high_res = batch['data_256'].to(device)

            for i in range(low_res.size(0)):
                input_image = low_res[i].unsqueeze(0)
                target_image = high_res[i].unsqueeze(0)
                # Upscale from 64x64 to 256x256
                super_res_image = ddpm_128.sample(input_image)
                super_res_image = ddpm_256.sample(super_res_image)
                
                # Convert tensors back to images for evaluation
                high_res_img = target_image
                super_res_img = super_res_image

                vif_values.append(compute_vif(high_res_img, super_res_img))
                lpips_values.append(compute_lpips(high_res_img, super_res_img))
                ag_values.append(compute_ag(super_res_img))
                
    return np.mean(vif_values), np.mean(lpips_values), np.mean(ag_values)

# Load the models 
ddpm_64, ddpm_128, ddpm_256 = load_models()

# Evaluate the cascaded super-resolution model from 64x64 to 256x256
vif_256, lpips_256, ag_256 = evaluate_super_resolution_cascaded(ddpm_128, ddpm_256, valid_loader_64_256, device)

print(f"Cascaded NO SWIN Super-resolution from 64x64 to 256x256: VIF={vif_256}, LPIPS={lpips_256}, AG={ag_256}")


Processing batches: 100%|██████████| 10/10 [09:57<00:00, 59.75s/it]

Cascaded NO SWIN Super-resolution from 64x64 to 256x256: VIF=0.8160529281944037, LPIPS=0.36456783562898637, AG=0.3209536373615265





### INTERPOLATION

In [20]:



class MRNetUpscaleDataset(Dataset):
    def __init__(self, slice_dir, label_files, transform_64=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_256 = transform_256

        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: '0' * (4 - len(str(i))) + str(i))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # List all slice files
        self.slice_files = [os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')]

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        image = Image.open(slice_path)
        
        if self.transform_64:
            image_64 = self.transform_64(image)
        if self.transform_256:
            image_256 = self.transform_256(image)
        
        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = self.labels_dict[slice_id]
            label = torch.FloatTensor([label])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0])  

        return {'data_64': image_64, 'data_256': image_256, 'label': label, 'id': slice_id}

# Define transformations for 64x64 and 256x256
transform_64 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((64, 64)),  # Resize to 64x64
    transforms.ToTensor()  # Convert to PyTorch tensor
])

transform_256 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((256, 256)),  # Resize to 256x256
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Initialize the dataset
root_dir = r"/rds/homes/a/avv306/Raw_Images/Raw_Images"
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

valid_label_files = [
    os.path.join(root_dir, "valid-acl.csv"),
    os.path.join(root_dir, "valid-abnormal.csv"),
    os.path.join(root_dir, "valid-meniscus.csv")
]

# Create a dataset for 64x64 and 256x256
valid_dataset_64_256 = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_64=transform_64,
    transform_256=transform_256
)

# Create Subsets of the dataset for testing
subset_indices = list(range(160))  # Use the first 160 images for example
valid_subset_64_256 = Subset(valid_dataset_64_256, subset_indices)

# Create DataLoader for the subset
valid_loader_64_256 = DataLoader(valid_subset_64_256, batch_size=16, shuffle=True)

# Initialize the torchmetrics objects
psnr_metric = PeakSignalNoiseRatio().to(device)
ssim_metric = StructuralSimilarityIndexMeasure().to(device)

# Function to compute PSNR using torchmetrics
def compute_psnr(img1, img2):
    return psnr_metric(img1, img2).item()

# Function to compute SSIM and DSSIM using torchmetrics
def compute_ssim_dssim(img1, img2):
    ssim_value = ssim_metric(img1, img2).item()
    dssim_value = (1 - ssim_value) / 2
    return ssim_value, dssim_value

# Function to compute MSE using PyTorch
def compute_mse(img1, img2):
    return torch.mean((img1 - img2) ** 2).item()

# Function to compute FSIM using image_similarity_measures
def compute_fsim(img1, img2):
    # Convert PyTorch tensors to NumPy arrays and scale to 8-bit integer
    img1_np = img1.cpu().numpy().squeeze() * 255
    img2_np = img2.cpu().numpy().squeeze() * 255

    # Ensure the images are 3D (add a channel dimension if they're 2D)
    if img1_np.ndim == 2:
        img1_np = np.expand_dims(img1_np, axis=-1)
    if img2_np.ndim == 2:
        img2_np = np.expand_dims(img2_np, axis=-1)
    
    img1_uint8 = img1_np.astype(np.uint8)
    img2_uint8 = img2_np.astype(np.uint8)

    # Compute FSIM using the image_similarity_measures library
    return fsim(img1_uint8, img2_uint8)

# Function to perform interpolation from 64x64 directly to 256x256 and evaluate metrics
def evaluate_direct_interpolation(interpolation_method, dataloader, device):
    psnr_values = []
    ssim_values = []
    dssim_values = []
    mse_values = []
    fsim_values = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {interpolation_method} interpolation"):
            low_res = batch['data_64'].to(device)
            high_res = batch['data_256'].to(device)

            for i in range(low_res.size(0)):
                low_res_img = low_res[i].unsqueeze(0)  # Add batch dimension
                high_res_img = high_res[i].unsqueeze(0)  # Add batch dimension

                # Apply interpolation directly from 64x64 to 256x256
                low_res_img_np = low_res_img.cpu().numpy().squeeze()
                if interpolation_method == 'bilinear':
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_LINEAR)
                elif interpolation_method == 'bicubic':
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_CUBIC)
                elif interpolation_method == 'lanczos':
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                else:  # default to nearest-neighbor
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_NEAREST)

                # Convert back to torch tensor
                upsampled_img = torch.tensor(upsampled_img_np, device=device).unsqueeze(0).unsqueeze(0)

                # Evaluate metrics
                psnr_values.append(compute_psnr(high_res_img, upsampled_img))
                ssim_value, dssim_value = compute_ssim_dssim(high_res_img, upsampled_img)
                ssim_values.append(ssim_value)
                dssim_values.append(dssim_value)
                mse_values.append(compute_mse(high_res_img, upsampled_img))
                fsim_values.append(compute_fsim(high_res_img, upsampled_img))

    return np.mean(psnr_values), np.mean(ssim_values), np.mean(dssim_values), np.mean(mse_values), np.mean(fsim_values)



# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

# List of interpolation methods
interpolation_methods = ['nearest', 'bilinear', 'bicubic', 'lanczos']

# Evaluate all interpolation techniques directly from 64x64 to 256x256
print("\nEvaluating Direct Interpolation Techniques for 64x64 to 256x256:")
for method in interpolation_methods:
    psnr256, ssim256, dssim256, mse256, fsim256 = evaluate_direct_interpolation(method, valid_loader_64_256, device)
    print(f"{method.capitalize()} Interpolation from 64x64 to 256x256: PSNR={psnr256}, SSIM={ssim256}, DSSIM={dssim256}, MSE={mse256}, FSIM={fsim256}")


Using GPU: NVIDIA A100-SXM4-40GB

Evaluating Direct Interpolation Techniques for 64x64 to 256x256:


Evaluating nearest interpolation: 100%|██████████| 10/10 [00:30<00:00,  3.10s/it]


Nearest Interpolation from 64x64 to 256x256: PSNR=23.112764418125153, SSIM=0.6773493640124798, DSSIM=0.1613253179937601, MSE=0.004943166526209098, FSIM=0.5011008700408051


Evaluating bilinear interpolation: 100%|██████████| 10/10 [00:30<00:00,  3.10s/it]


Bilinear Interpolation from 64x64 to 256x256: PSNR=23.694603943824767, SSIM=0.700542651861906, DSSIM=0.14972867406904697, MSE=0.004306619871204021, FSIM=0.6232317244863224


Evaluating bicubic interpolation: 100%|██████████| 10/10 [00:30<00:00,  3.09s/it]


Bicubic Interpolation from 64x64 to 256x256: PSNR=25.035907018184663, SSIM=0.7310318142175675, DSSIM=0.13448409289121627, MSE=0.0035574150926549917, FSIM=0.6323735310671865


Evaluating lanczos interpolation: 100%|██████████| 10/10 [00:30<00:00,  3.10s/it]

Lanczos Interpolation from 64x64 to 256x256: PSNR=25.097058725357055, SSIM=0.7344014767557383, DSSIM=0.13279926162213088, MSE=0.003449835801438894, FSIM=0.6401947585409061





In [22]:
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.image import VisualInformationFidelity
from scipy.ndimage import sobel


class MRNetUpscaleDataset(Dataset):
    def __init__(self, slice_dir, label_files, transform_64=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_256 = transform_256

        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: '0' * (4 - len(str(i))) + str(i))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # List all slice files
        self.slice_files = [os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')]

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        image = Image.open(slice_path)
        
        if self.transform_64:
            image_64 = self.transform_64(image)
        if self.transform_256:
            image_256 = self.transform_256(image)
        
        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = self.labels_dict[slice_id]
            label = torch.FloatTensor([label])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0])  

        return {'data_64': image_64, 'data_256': image_256, 'label': label, 'id': slice_id}

# Define transformations for 64x64 and 256x256
transform_64 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((64, 64)),  # Resize to 64x64
    transforms.ToTensor()  # Convert to PyTorch tensor
])

transform_256 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((256, 256)),  # Resize to 256x256
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Initialize the dataset
root_dir = r"/rds/homes/a/avv306/Raw_Images/Raw_Images"
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

valid_label_files = [
    os.path.join(root_dir, "valid-acl.csv"),
    os.path.join(root_dir, "valid-abnormal.csv"),
    os.path.join(root_dir, "valid-meniscus.csv")
]

# Create a dataset for 64x64 and 256x256
valid_dataset_64_256 = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_64=transform_64,
    transform_256=transform_256
)

# Create Subsets of the dataset for testing
subset_indices = list(range(160))  # Use the first 160 images for example
valid_subset_64_256 = Subset(valid_dataset_64_256, subset_indices)

# Create DataLoader for the subset
valid_loader_64_256 = DataLoader(valid_subset_64_256, batch_size=16, shuffle=True)

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

# Initialize PyTorch metric objects on the same device
vif_metric = VisualInformationFidelity().to(device)
lpips_metric = LPIPS(net_type='vgg').to(device) 

# Function to prepare images for LPIPS
def prepare_for_lpips(img):
    # Convert grayscale to 3 channels by repeating the channel 3 times
    img_3c = img.repeat(1, 3, 1, 1)  # Repeat across the channel dimension
    # Normalize the image to [-1, 1]
    img_3c = (img_3c * 2) - 1
    # Clip the values to ensure they stay within the expected range
    img_3c = torch.clamp(img_3c, min=-1.0, max=1.0)
    return img_3c

# Function to compute VIF using PyTorch
def compute_vif(img1, img2):
    return vif_metric(img1, img2).item()

# Function to compute LPIPS using PyTorch
def compute_lpips(img1, img2):
    img1 = prepare_for_lpips(img1)
    img2 = prepare_for_lpips(img2)
    return lpips_metric(img1, img2).item()

# Function to compute AG (Average Gradient)
def compute_ag(img):
    img_np = img.cpu().numpy().squeeze()
    gx = sobel(img_np, axis=0)
    gy = sobel(img_np, axis=1)
    grad_magnitude = np.sqrt(gx**2 + gy**2)
    return np.mean(grad_magnitude)

# Function to perform interpolation from 64x64 directly to 256x256 and evaluate metrics
def evaluate_direct_interpolation(interpolation_method, dataloader, device):
    vif_values = []
    lpips_values = []
    ag_values = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Evaluating {interpolation_method} interpolation"):
            low_res = batch['data_64'].to(device)
            high_res = batch['data_256'].to(device)

            for i in range(low_res.size(0)):
                low_res_img = low_res[i].unsqueeze(0)  # Add batch dimension
                high_res_img = high_res[i].unsqueeze(0)  # Add batch dimension

                # Apply interpolation directly from 64x64 to 256x256
                low_res_img_np = low_res_img.cpu().numpy().squeeze()
                if interpolation_method == 'bilinear':
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_LINEAR)
                elif interpolation_method == 'bicubic':
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_CUBIC)
                elif interpolation_method == 'lanczos':
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                else:  # default to nearest-neighbor
                    upsampled_img_np = cv2.resize(low_res_img_np, (256, 256), interpolation=cv2.INTER_NEAREST)

                # Convert back to torch tensor
                upsampled_img = torch.tensor(upsampled_img_np, device=device).unsqueeze(0).unsqueeze(0)

                # Evaluate metrics
                vif_values.append(compute_vif(high_res_img, upsampled_img))
                lpips_values.append(compute_lpips(high_res_img, upsampled_img))
                ag_values.append(compute_ag(upsampled_img))

    return np.mean(vif_values), np.mean(lpips_values), np.mean(ag_values)

# List of interpolation methods
interpolation_methods = ['nearest', 'bilinear', 'bicubic', 'lanczos']

# Evaluate all interpolation techniques directly from 64x64 to 256x256
print("\nEvaluating Direct Interpolation Techniques for 64x64 to 256x256:")
for method in interpolation_methods:
    vif_256, lpips_256, ag_256 = evaluate_direct_interpolation(method, valid_loader_64_256, device)
    print(f"{method.capitalize()} Interpolation from 64x64 to 256x256: VIF={vif_256}, LPIPS={lpips_256}, AG={ag_256}")


Using GPU: NVIDIA A100-SXM4-40GB

Evaluating Direct Interpolation Techniques for 64x64 to 256x256:


Evaluating nearest interpolation: 100%|██████████| 10/10 [00:01<00:00,  6.58it/s]


Nearest Interpolation from 64x64 to 256x256: VIF=1.08458410538733, LPIPS=0.4279451141133904, AG=0.15417876839637756


Evaluating bilinear interpolation: 100%|██████████| 10/10 [00:01<00:00,  6.65it/s]


Bilinear Interpolation from 64x64 to 256x256: VIF=1.5611980214715004, LPIPS=0.49902590084820986, AG=0.13009855151176453


Evaluating bicubic interpolation: 100%|██████████| 10/10 [00:01<00:00,  6.62it/s]


Bicubic Interpolation from 64x64 to 256x256: VIF=1.3168386608362197, LPIPS=0.47253904230892657, AG=0.14681777358055115


Evaluating lanczos interpolation: 100%|██████████| 10/10 [00:01<00:00,  6.56it/s]

Lanczos Interpolation from 64x64 to 256x256: VIF=1.3331476621329785, LPIPS=0.4727118736132979, AG=0.147621750831604





In [21]:
import os
import cv2
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
from PIL import Image
import pandas as pd
from tqdm import tqdm




class MRNetUpscaleDataset(Dataset):
    def __init__(self, slice_dir, label_files, transform_64=None, transform_256=None):
        super().__init__()
        self.slice_dir = slice_dir
        self.transform_64 = transform_64
        self.transform_256 = transform_256

        self.labels_dict = {}
        for label_file in label_files:
            records = pd.read_csv(label_file, header=None, names=['id', 'label'])
            records['id'] = records['id'].map(lambda i: '0' * (4 - len(str(i))) + str(i))
            self.labels_dict.update(dict(zip(records['id'], records['label'])))

        # List all slice files
        self.slice_files = [os.path.join(slice_dir, fname) for fname in os.listdir(slice_dir) if fname.endswith('.png')]

    def __len__(self):
        return len(self.slice_files)

    def __getitem__(self, index):
        slice_path = self.slice_files[index]
        image = Image.open(slice_path)
        
        if self.transform_64:
            image_64 = self.transform_64(image)
        if self.transform_256:
            image_256 = self.transform_256(image)
        
        # Extract ID from filename to match with label
        slice_id = os.path.basename(slice_path).split('_')[1]

        if slice_id in self.labels_dict:
            label = self.labels_dict[slice_id]
            label = torch.FloatTensor([label])
        else:
            print(f"Label for ID {slice_id} not found in the CSV file.")
            label = torch.FloatTensor([0])  

        return {'data_64': image_64, 'data_256': image_256, 'label': label, 'id': slice_id}

# Define transformations for 64x64 and 256x256
transform_64 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((64, 64)),  # Resize to 64x64
    transforms.ToTensor()  # Convert to PyTorch tensor
])

transform_256 = transforms.Compose([
    transforms.Grayscale(),  # Ensure images are single-channel
    transforms.Resize((256, 256)),  # Resize to 256x256
    transforms.ToTensor()  # Convert to PyTorch tensor
])

# Initialize the dataset
root_dir = r"C:\Users\ASUS\Documents\Uobd\project\Development\Datasets\MRNet-v1.0\Raw_Images"
valid_slice_dir = os.path.join(root_dir, "valid_slices_raw")

valid_label_files = [
    os.path.join(root_dir, "valid-acl.csv"),
    os.path.join(root_dir, "valid-abnormal.csv"),
    os.path.join(root_dir, "valid-meniscus.csv")
]

# Create a dataset for 64x64 and 256x256
valid_dataset_64_256 = MRNetUpscaleDataset(
    slice_dir=valid_slice_dir,
    label_files=valid_label_files,
    transform_64=transform_64,
    transform_256=transform_256
)

# Create Subsets of the dataset for testing
subset_indices = list(range(100))  # Use the first 100 images for example
valid_subset_64_256 = Subset(valid_dataset_64_256, subset_indices)

# Create DataLoader for the subset
valid_loader_64_256 = DataLoader(valid_subset_64_256, batch_size=20, shuffle=True)

# Function to perform interpolation from 64x64 directly to 256x256, generate super-res image, and save comparison as an image
def evaluate_and_save_comparison_with_sr(dataloader, ddpm_128, ddpm_256, device, save_dir='trial'):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    # We'll process only one batch for display
    batch = next(iter(dataloader))
    
    low_res = batch['data_64'].to(device)
    high_res = batch['data_256'].to(device)

    interpolation_methods = {
        'Nearest': cv2.INTER_NEAREST,
        'Bilinear': cv2.INTER_LINEAR,
        'Bicubic': cv2.INTER_CUBIC,
        'Lanczos': cv2.INTER_LANCZOS4
    }

    for i in range(low_res.size(0)):
        low_res_img = low_res[i].cpu().numpy().squeeze()
        high_res_img = high_res[i].cpu().numpy().squeeze()

        # Create a figure to hold the comparison table
        fig, axs = plt.subplots(1, len(interpolation_methods) + 3, figsize=(25, 5))
        axs[0].imshow(low_res_img, cmap='gray')
        axs[0].set_title('64x64 Low Res')
        axs[0].axis('off')

        # Perform and display interpolation methods
        for idx, (name, method) in enumerate(interpolation_methods.items(), start=1):
            upsampled_img = cv2.resize(low_res_img, (256, 256), interpolation=method)
            axs[idx].imshow(upsampled_img, cmap='gray')
            axs[idx].set_title(f'{name} 256x256')
            axs[idx].axis('off')
            plt.imsave(os.path.join(save_dir, f'{name}_upsampled_{i}.png'), upsampled_img, cmap='gray')

        # Perform super-resolution with the model
        with torch.no_grad():
            input_tensor = low_res[i].unsqueeze(0).to(device)
            intermediate_image = ddpm_128.sample(input_tensor)
            super_res_image = ddpm_256.sample(intermediate_image).cpu().numpy().squeeze()

        axs[-2].imshow(super_res_image, cmap='gray')
        axs[-2].set_title('NO SWIN Super-Resolution Model')
        axs[-2].axis('off')

        axs[-1].imshow(high_res_img, cmap='gray')
        axs[-1].set_title('256x256 High Res')
        axs[-1].axis('off')

        # Save the comparison table as an image
        comparison_image_path = os.path.join(save_dir, f'comparison_with_sr_{i}.png')
        plt.savefig(comparison_image_path)
        plt.close(fig)

        print(f"Saved comparison image with SR at {comparison_image_path}")
        break  # Process only one image for comparison

# Load the super-resolution models
ddpm_64, ddpm_128, ddpm_256 = load_models()  # Ensure you load the models correctly

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

# Save and display the comparison for one image with SR and interpolation methods
evaluate_and_save_comparison_with_sr(valid_loader_64_256, ddpm_128, ddpm_256, device)


Using GPU: NVIDIA A100-SXM4-40GB
Saved comparison image with SR at comparison_images/comparison_with_sr_0.png
