# Evaluation of Cascaded Super-Resolution Models Against Direct 256x256 DDPM Model Using FID and Inception Score

This notebook assesses the performance of cascaded super-resolution (SR) models, both with and without the Swin Transformer, against a DDPM model trained directly on 256x256 images. The evaluation employs two key metrics: Fréchet Inception Distance (FID) and Inception Score (IS). FID measures the similarity between the distribution of generated images and the distribution of real images; a lower FID indicates higher quality and greater similarity. Inception Score, on the other hand, evaluates the diversity and quality of the generated images on their own (no real images are needed); a higher score indicates better performance.

The code compares the SR models' performance by computing the FID score between real and generated images and the Inception Score for the generated images. This comparison helps in understanding the effectiveness of SR models with and without Swin Transformer integration, relative to a direct DDPM model trained on high-resolution images.


### IMPORTING LIBRARIES

In [1]:
import torch
import numpy as np
import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms, models
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

### FID Score

In [None]:


def to_uint8(tensor):
    """Scale ``tensor`` by 255 and cast the result to ``torch.uint8``.

    Intended for tensors whose values lie in [0, 1] (e.g. the output of
    ``transforms.ToTensor``). The cast does not round: 0.5 becomes 127.
    """
    scaled = tensor.mul(255)
    return scaled.to(torch.uint8)

class ImageDataset(Dataset):
    """Images from a single directory, prepared for Inception-based metrics.

    Every item is loaded as RGB, resized to 299x299, converted with
    ``transforms.ToTensor`` and then turned into a uint8 tensor by
    ``to_uint8`` before being moved to ``device``.

    Args:
        image_dir: Directory that is scanned (non-recursively) for images.
        device: Device every returned tensor is moved to. Because the
            transfer happens per item inside ``__getitem__``, use this
            dataset with ``num_workers=0`` when ``device`` is a CUDA device.
    """

    # Accepted file extensions, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, device):
        self.image_dir = image_dir
        self.image_files = self._list_image_files(image_dir)
        self.device = device
        self.transform = transforms.Compose([
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Lambda(to_uint8)  # Convert to uint8
        ])

    @classmethod
    def _list_image_files(cls, image_dir):
        """Return the sorted paths of all image files directly in ``image_dir``.

        Sorting makes the index -> file mapping reproducible, since
        ``os.listdir`` returns entries in arbitrary order. The extension
        check ignores case so that e.g. ``scan.PNG`` is not silently skipped.
        """
        return sorted(
            os.path.join(image_dir, name)
            for name in os.listdir(image_dir)
            if name.lower().endswith(cls.IMAGE_EXTENSIONS)
        )

    def __len__(self) -> int:
        """Return the number of image files found in the directory."""
        return len(self.image_files)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Load, transform and return the image at position ``idx``."""
        # The context manager closes the underlying file as soon as the
        # pixel data has been copied out by ``convert``.
        with Image.open(self.image_files[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb).to(self.device)

# Directories containing the images
real_images_dir = "Raw_Images/valid_slices_raw"
generated_images_dir = "generated_images"
if __name__ == '__main__':
    # Move to the appropriate device (GPU or CPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create Dataset and DataLoader for real and generated images
    real_dataset = ImageDataset(real_images_dir, device)
    generated_dataset = ImageDataset(generated_images_dir, device)

    # FID is a statistic over the whole image set, so shuffling adds nothing;
    # a fixed order keeps the accumulation reproducible between runs.
    real_dataloader = DataLoader(real_dataset, batch_size=16, shuffle=False, num_workers=0)
    generated_dataloader = DataLoader(generated_dataset, batch_size=16, shuffle=False, num_workers=0)

    # Initialize FID metric.
    # ImageDataset already yields uint8 tensors in [0, 255], which is what the
    # metric expects with normalize=False. With normalize=True the metric
    # multiplies its input by 255 and casts to uint8 again, which would wrap
    # around on uint8 input and corrupt every image.
    fid = FrechetInceptionDistance(feature=2048, normalize=False)
    fid = fid.to(device)

    def compute_fid(dataloader, real=True):
        """Feed every batch of ``dataloader`` into the FID metric.

        Args:
            dataloader: Yields uint8 image batches already on ``device``.
            real: True to add the batches to the real-image statistics,
                False to add them to the generated-image statistics.
        """
        for batch in dataloader:
            fid.update(batch, real=real)

    # Compute FID score
    fid.reset()
    compute_fid(generated_dataloader, real=False)
    compute_fid(real_dataloader, real=True)

    fid_value = fid.compute()
    print('FID:', fid_value)

### Inception Score (IS)

In [None]:
# Directories containing the generated images
generated_images_dir = "generated_images"

if __name__ == '__main__':
    # Move to the appropriate device (GPU or CPU)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create Dataset and DataLoader for generated images
    generated_dataset = ImageDataset(generated_images_dir, device)
    generated_dataloader = DataLoader(generated_dataset, batch_size=16, shuffle=False, num_workers=0)

    # Check the number of generated images
    num_generated_images = len(generated_dataset)
    print(f'Number of generated images: {num_generated_images}')

    # # Display a few generated images
    # fig, axes = plt.subplots(1, 5, figsize=(15, 5))
    # for i, ax in enumerate(axes.flat):
    #     img = generated_dataset[i].cpu().permute(1, 2, 0).numpy().astype('uint8')
    #     ax.imshow(img)
    #     ax.axis('off')
    # plt.show()

    # Initialize Inception Score metric. The default normalize=False expects
    # uint8 images in [0, 255], which is what ImageDataset yields.
    inception_score = InceptionScore().to(device)

    # Feed the images batch by batch. The metric keeps the features of all
    # images until compute(), so wrapping the dataset in per-chunk Subsets
    # saves no memory; iterating the DataLoader directly is equivalent.
    for batch in generated_dataloader:
        inception_score.update(batch)

    # Get the Inception Score
    is_mean, is_std = inception_score.compute()
    print(f'Inception Score: Mean = {is_mean}, Std = {is_std}')
