In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset, ConcatDataset
from torchvision import transforms, datasets
import torchvision.utils as vutils
from torchmetrics import Metric
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image import InceptionScore
from skimage.metrics import structural_similarity as ssim
from scipy.ndimage import gaussian_filter

# Evaluation runs on the CPU by default; set USE_CUDA = True to use a GPU when one is available.
USE_CUDA = False
device = torch.device('cuda' if USE_CUDA and torch.cuda.is_available() else 'cpu')

In [None]:
# Configuration constants shared by the cells below.
image_size = (299, 299)  # (height, width) that every image is resized to
batch_size = 8           # number of images per DataLoader batch

# Define a custom transformation for Gaussian blur
class GaussianBlur:
    """Blur a PIL image with a Gaussian kernel applied over the spatial axes only.

    Args:
        sigma: Standard deviation of the Gaussian kernel, in pixels.
    """

    def __init__(self, sigma=1):
        self.sigma = sigma

    @staticmethod
    def _blur_array(image_np, sigma):
        """Blur an (H, W) or (H, W, C) array without mixing the channels."""
        if image_np.ndim == 3:
            # A scalar sigma would also smooth along the channel axis and bleed
            # colours into each other; sigma 0 on that axis keeps channels separate.
            sigma = (sigma, sigma, 0)
        # gaussian_filter returns an array of the input dtype by default.
        return gaussian_filter(image_np, sigma=sigma)

    def __call__(self, image):
        """Return a blurred copy of ``image`` as a PIL Image."""
        image_np = np.array(image)
        blurred_image_np = self._blur_array(image_np, self.sigma)
        return transforms.ToPILImage()(blurred_image_np)
    

# Transformations that resize images and convert them to float tensors.
# There is deliberately no Normalize step: the metric cells turn these tensors back
# into pixels with convert_to_uint8 (images * 255 -> uint8), which needs values in
# [0, 1]. Mean/std normalisation pushes values outside that range, and the uint8 cast
# then corrupts the pixels fed to FID, IS and SSIM. The plotting cell rescales each
# image itself, so it does not need normalised input either.
transform_real = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

# Same pipeline for generated images. To blur them before evaluation,
# insert GaussianBlur(sigma=1) before ToTensor().
transform_gen = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

In [None]:
def filter_dataloader_by_class(dataloader, target_idx, batch_size=None, shuffle=True):
    """Return a new DataLoader holding only the samples whose label equals ``target_idx``.

    Every matching sample is materialised in memory as a single TensorDataset.

    Args:
        dataloader: DataLoader yielding ``(images, labels)`` batches.
        target_idx: Class index to keep.
        batch_size: Batch size of the returned loader. Defaults to
            ``dataloader.batch_size`` (or 1 if the source loader has none).
        shuffle: Whether the returned loader shuffles. It is treated as False when
            nothing matches, because a RandomSampler rejects an empty dataset.

    Returns:
        A DataLoader over ``(image, label)`` pairs of the selected class. Its dataset
        is empty when no sample matches.
    """
    if batch_size is None:
        batch_size = dataloader.batch_size or 1

    chunks = [images[labels == target_idx] for images, labels in dataloader]
    if chunks:
        filtered_images = torch.cat(chunks)
    else:
        # The source loader yielded no batches, so there is nothing to concatenate.
        filtered_images = torch.empty(0)
    filtered_labels = torch.full((len(filtered_images),), target_idx, dtype=torch.long)

    dataset = TensorDataset(filtered_images, filtered_labels)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle and len(dataset) > 0)

In [None]:
# Convert float images to uint8 pixel values.
def convert_to_uint8(images):
    """Convert float images in ``[0, 1]`` to uint8 values in ``[0, 255]``.

    Values are clamped to ``[0, 1]`` first, so out-of-range inputs saturate at 0 or
    255 instead of producing wrapped/undefined values when cast to uint8. Scaled
    values are rounded rather than truncated, so an input of ``k / 255`` maps back
    to exactly ``k`` despite floating-point error.

    Args:
        images: Float tensor of any shape.

    Returns:
        Tensor of the same shape with dtype ``torch.uint8``.
    """
    return images.clamp(0, 1).mul(255).round().byte()

# Paths

In [None]:
real_images_dir = 'data-students/TRAIN'

# Generated datasets are stored as GEN_DATASETS/<MODEL>_<SIZE>; the available
# combinations are MODEL in {DCGAN, DDPM, VAE} and SIZE in {50, 100, 500}.
GEN_MODEL = 'DCGAN'
GEN_SIZE = 50
generated_images_dir = f'GEN_DATASETS/{GEN_MODEL}_{GEN_SIZE}'

# Load images

In [8]:
# Load dataset of real images.
# drop_last=False: nothing downstream needs equal-sized batches, and dropping the
# final partial batch would silently remove images from every metric.
real_dataset = datasets.ImageFolder(real_images_dir, transform=transform_real)
real_dataloader = DataLoader(real_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
classes = real_dataset.classes
num_classes = len(classes)

# Load dataset of generated images
generated_dataset = datasets.ImageFolder(generated_images_dir, transform=transform_gen)
generated_dataloader = DataLoader(generated_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# Per-class metrics compare label index i of both datasets and report it as classes[i],
# which is only meaningful when both folders expose the same class list.
assert generated_dataset.classes == classes, (
    f'class folders differ: real={classes}, generated={generated_dataset.classes}'
)

# Plot

In [None]:
def plot_real_and_generated_images_by_class(real_dataloader, generated_dataloader, device, num_classes=10, n_generated=4):
    """For each class, show one real image next to up to ``n_generated`` generated images.

    Args:
        real_dataloader: DataLoader yielding ``(images, labels)`` batches of real images.
        generated_dataloader: DataLoader yielding ``(images, labels)`` batches of generated images.
        device: Kept for backward compatibility; images are plotted from the CPU.
        num_classes: Number of class indices (0 .. num_classes - 1) to plot.
        n_generated: Maximum number of generated images shown per class.
    """
    class_names = getattr(real_dataloader.dataset, "classes", None)

    for class_idx in range(num_classes):
        filtered_dataloader_gen = filter_dataloader_by_class(generated_dataloader, class_idx)
        filtered_dataloader_real = filter_dataloader_by_class(real_dataloader, class_idx)

        # With no images on either side there is nothing to compare; skip the class
        # instead of failing on next(iter(...)).
        if len(filtered_dataloader_real.dataset) == 0 or len(filtered_dataloader_gen.dataset) == 0:
            print(f'Skipping class {class_idx}: no real or generated images.')
            continue

        # Take the first batch of real and of generated images.
        real_images, _ = next(iter(filtered_dataloader_real))
        generated_images, _ = next(iter(filtered_dataloader_gen))

        # Small classes may have fewer than n_generated images in that batch.
        n_shown = min(n_generated, len(generated_images))
        panels = [("Real Image", real_images[0])]
        panels += [(f"Generated {j + 1}", generated_images[j]) for j in range(n_shown)]

        fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 3), squeeze=False)
        for ax, (title, image) in zip(axes[0], panels):
            # make_grid(normalize=True) min-max rescales the image for display.
            grid = vutils.make_grid(image.cpu(), padding=2, normalize=True)
            ax.imshow(grid.permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis("off")

        if class_names is not None and class_idx < len(class_names):
            fig.suptitle(f"Class {class_names[class_idx]}")
        else:
            fig.suptitle(f"Class {class_idx}")
        fig.tight_layout()
        plt.show()

# Plot sample images for every class in the dataset (the function's default is 10 classes).
plot_real_and_generated_images_by_class(real_dataloader, generated_dataloader, device, num_classes=num_classes)

# Inception Score

In [None]:
def calculate_inception_scores_for_classes(generated_dataloader, num_classes=num_classes, device=device, splits=10, classes=classes):
    """Compute the Inception Score of the generated images separately for each class.

    Args:
        generated_dataloader: DataLoader yielding ``(images, labels)`` batches of generated images.
        num_classes: Number of class indices (0 .. num_classes - 1) to evaluate.
        device: Device the metric and images are moved to.
        splits: Number of splits passed to ``InceptionScore``.
        classes: Class names, indexed by label, used in the printed report.

    Returns:
        List of ``(class_idx, is_mean, is_std)`` tuples for the classes that have images.
    """
    results = []

    # Build the metric once and reset it between classes instead of rebuilding it each time.
    # normalize=False: images are converted to uint8 before update(); with normalize=True
    # the metric would scale them by 255 again and overflow the uint8 values.
    inception_score = InceptionScore(feature='logits_unbiased', splits=splits, normalize=False).to(device)

    for i in range(num_classes):
        filtered_dataloader_gen = filter_dataloader_by_class(generated_dataloader, i)
        if len(filtered_dataloader_gen.dataset) == 0:
            print(f'Skipping class {classes[i]}: no generated images.')
            continue

        for images, _ in filtered_dataloader_gen:
            inception_score.update(convert_to_uint8(images).to(device))

        is_mean, is_std = inception_score.compute()
        results.append((i, is_mean.item(), is_std.item()))
        print(f'Inception Score for class {classes[i]}: {is_mean} ± {is_std}')

        # Clear accumulated features before the next class.
        inception_score.reset()

    return results

In [None]:
def calculate_inception_scores_for_one_class(generated_dataloader, num_classes=num_classes, device=device, splits=10, classes=classes):
    """Compute a single Inception Score over every image in ``generated_dataloader``.

    Despite the name, no class filtering is done: all classes are scored together.
    ``num_classes`` and ``classes`` are unused and kept for signature compatibility.

    Args:
        generated_dataloader: DataLoader yielding ``(images, labels)`` batches of generated images.
        device: Device the metric and images are moved to.
        splits: Number of splits passed to ``InceptionScore``.

    Returns:
        A one-element list ``[(is_mean, is_std)]``.
    """
    # normalize=False: images are converted to uint8 before update(); with normalize=True
    # the metric would scale them by 255 again and overflow the uint8 values.
    inception_score = InceptionScore(feature='logits_unbiased', splits=splits, normalize=False).to(device)

    for images, _ in generated_dataloader:
        inception_score.update(convert_to_uint8(images).to(device))

    is_mean, is_std = inception_score.compute()
    print(f'Inception Score for all classes together: {is_mean} ± {is_std}')

    return [(is_mean.item(), is_std.item())]

# FID

In [None]:
def calculate_fid_for_classes(generated_dataloader, real_dataloader,num_classes=num_classes,classes=classes):
    """Compute the Frechet Inception Distance between real and generated images, per class.

    Each class is scored independently: the metric state is reset before every class.

    Args:
        generated_dataloader: DataLoader yielding ``(images, labels)`` batches of generated images.
        real_dataloader: DataLoader yielding ``(images, labels)`` batches of real images.
        num_classes: Number of class indices (0 .. num_classes - 1) to evaluate.
        classes: Class names, indexed by label, used in the printed report.

    Returns:
        List of ``(class_idx, fid_value)`` tuples for the classes that could be scored.
    """
    results = []
    fid = FrechetInceptionDistance(feature=768)

    for i in range(num_classes):
        # Without a reset, the real and generated statistics of every earlier class
        # remain in the metric state and leak into this class's FID.
        fid.reset()

        filtered_dataloader_gen = filter_dataloader_by_class(generated_dataloader, i)
        filtered_dataloader_real = filter_dataloader_by_class(real_dataloader, i)

        # FID estimates a covariance for each set, so each side needs at least two images.
        if len(filtered_dataloader_real.dataset) < 2 or len(filtered_dataloader_gen.dataset) < 2:
            print(f'Skipping class {classes[i]}: FID needs at least two real and two generated images.')
            continue

        for images, _ in filtered_dataloader_real:
            fid.update(convert_to_uint8(images), real=True)

        for images, _ in filtered_dataloader_gen:
            fid.update(convert_to_uint8(images), real=False)

        fid_value = fid.compute()
        results.append((i, fid_value.item()))
        print(f'Frechet Inception Distance for class {classes[i]}: {fid_value}')

    return results

# SSIM

In [None]:
def calculate_ssim_for_class(generated_dataloader, real_dataloader, num_classes=num_classes, classes=classes):
    """Compute the mean and std of SSIM between real and generated images, per class.

    Within a class, every real batch is compared with every generated batch, pairing
    images by their position in the two batches. The filtered loaders may be
    shuffled, so the pairing can differ between runs.

    Args:
        generated_dataloader: DataLoader yielding ``(images, labels)`` batches of generated images.
        real_dataloader: DataLoader yielding ``(images, labels)`` batches of real images.
        num_classes: Number of class indices (0 .. num_classes - 1) to evaluate.
        classes: Class names, indexed by label, used in the printed report.

    Returns:
        List of ``(class_idx, mean_ssim, std_ssim)`` tuples for classes with at least one pair.
    """
    results = []

    for class_idx in range(num_classes):
        filtered_dataloader_gen = filter_dataloader_by_class(generated_dataloader, class_idx)
        filtered_dataloader_real = filter_dataloader_by_class(real_dataloader, class_idx)

        # Collect scores per class; a list shared across classes would fold every
        # earlier class into this class's mean and std.
        ssim_scores = []
        for batch_real, _ in filtered_dataloader_real:
            for batch_gen, _ in filtered_dataloader_gen:
                for img_real, img_gen in zip(batch_real, batch_gen):
                    img_real = convert_to_uint8(img_real).numpy()
                    img_gen = convert_to_uint8(img_gen).numpy()
                    # data_range is the full uint8 range, not one image's min-max, which is
                    # 0 for a flat image and varies from pair to pair.
                    # channel_axis=0: images are (C, H, W), so SSIM is computed per channel
                    # in 2-D and averaged, rather than treating channels as a third spatial axis.
                    ssim_scores.append(ssim(img_real, img_gen, data_range=255, win_size=3, channel_axis=0))

        if not ssim_scores:
            print(f' Structural Similarity for class {classes[class_idx]}: no image pairs')
            continue

        mean_ssim = np.mean(ssim_scores)
        std_ssim = np.std(ssim_scores)
        results.append((class_idx, mean_ssim, std_ssim))

        print(f' Structural Similarity for class {classes[class_idx]}: {mean_ssim} ± {std_ssim}')

    return results

# Evaluation

In [None]:
# Inception Score over all generated images at once.
calculate_inception_scores_for_one_class(generated_dataloader=generated_dataloader)

In [None]:
# Inception Score computed separately for each class.
calculate_inception_scores_for_classes(generated_dataloader=generated_dataloader)

In [None]:
# Frechet Inception Distance per class.
calculate_fid_for_classes(generated_dataloader=generated_dataloader, real_dataloader=real_dataloader)

In [None]:
# Structural similarity per class.
calculate_ssim_for_class(
    generated_dataloader,
    real_dataloader,
    num_classes=num_classes,
    classes=classes,
)