# Quantitative analysis

This notebook contains the code for the quantitative analysis of the synthetic images. Several image sets are loaded, containing either synthetic or real data. When synthetic data is compared to real data, the validation data of the model in question is always used as the reference. The code for calculating the FID, MS-SSIM and SSIM metrics is based on the following tutorial from the MONAI Generative Models framework: https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/2d_ddpm/2d_ddpm_tutorial_ignite.ipynb


In [None]:
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import seaborn as sns
import torch
from generative.metrics import FIDMetric, MultiScaleSSIMMetric, SSIMMetric
from monai.data import DataLoader, Dataset
from monai.metrics import PSNRMetric
#print_config()

# Images to analyze

Comment out the paths that are not in use.

### Real images

In [None]:
# Folder holding the real (reference) images used throughout the notebook.
# NOTE(review): the author flagged this choice with "OBS" (attention) —
# confirm it is the intended folder before running the analysis.
real_images_path = "Real_images/Real_validation_data_bs16_for_epoch250_lacie"

# Alternative folders (assign one of these instead to switch data set):
#   "Real_images/Real_validation_data_bs32_for_epoch250"
#   "Real_images/Real_training_data_bs16_epoch"
#   "Real_images/Real_validation_data_bs8_8nov"
#   "Real_images/Real_training_data_bs8_8nov"
#   "Real_images/Real_validation_data_500timesteps"

### Synthetic images

In [None]:
# Folder holding the images that are compared against the real images.
# NOTE(review): the active value points at a folder of *real* training data,
# not at a "Synthetic_images/..." folder — confirm this is intentional.
gen_images_path = "Real_images/Real_training_data_bs16_epoch"

# Alternative folders (assign one of these instead to switch data set):
#   "Synthetic_images/bs16_150epochs_3nov"
#   "Synthetic_images/bs16_150epochs_22_nov_larger_dataset"
#   "Synthetic_images/bs16_125epochs_3nov"
#   "Synthetic_images/bs32_epoch249av250_pretrained_1nov"
#   "Synthetic_images/bs32_epoch249av250_pretrained_2nov"
#   "Synthetic_images/bs32_epoch149av250_pretrained_1nov"
#   "Synthetic_images/bs32_epoch99av250_pretrained_1nov"
#   "Synthetic_images/bs32_epoch124av250_pretrained_1nov"
#   "Synthetic_images/bs32_epoch174av250_pretrained_1nov"
#   "Real_images/Real_validation_data_bs32_for_epoch250"
#   "Synthetic_images/bs16_150epochs_timestep500"
#   "Synthetic_images/bs8_150epochs_8nov"

### Load images from path

In [None]:
def load_images(images_folder):
    """Load every image file found in ``images_folder``.

    Entries are visited in sorted filename order so that the returned list is
    the same on every run and platform (``os.listdir`` gives no ordering
    guarantee). Hidden entries whose name starts with "." (for example
    ``.DS_Store``, ``._*`` resource forks or ``.ipynb_checkpoints``) are
    skipped because they are not image files and would make ``nib.load`` fail.

    Parameters
    ----------
    images_folder : str
        Path of the folder that contains the image files.

    Returns
    -------
    list
        One array per file, as returned by ``nib.load(...).get_fdata()``.
    """
    images = []

    for element in sorted(os.listdir(images_folder)):
        if element.startswith("."):
            continue
        image = nib.load(os.path.join(images_folder, element)).get_fdata()
        images.append(image)

    return images

In [None]:
def load_images_rekkefølge(images_folder, n_images=30):
    """Load images in numeric filename order ("rekkefølge" is Norwegian for "order").

    Reads ``nifti_file_0.nii``, ``nifti_file_1.nii``, ... up to
    ``nifti_file_<n_images - 1>.nii`` from ``images_folder``.

    Parameters
    ----------
    images_folder : str
        Path of the folder that contains the ``nifti_file_<i>.nii`` files.
    n_images : int, optional
        Number of consecutive files to read, starting at index 0. Defaults to
        30, the count that was previously hard-coded.

    Returns
    -------
    list
        One array per file, as returned by ``nib.load(...).get_fdata()``,
        in index order.
    """
    images = []

    for i in range(n_images):
        file_name = "nifti_file_" + str(i) + ".nii"
        image = nib.load(os.path.join(images_folder, file_name)).get_fdata()
        images.append(image)

    return images
    

In [None]:
# Load both image sets up front so that every later cell (visualisation,
# statistics, histograms, Fourier analysis) works on a fresh kernel.
# Previously only `real_images` was loaded here, so the first use of
# `synthetic_images` raised a NameError under "Restart & Run All".
real_images = load_images(real_images_path)
synthetic_images = load_images(gen_images_path)

In [None]:
# Visualise: print the stacked shape of the synthetic set, then show the
# third image in grayscale together with its intensity colour bar.
print(np.array(synthetic_images).shape)
sample_fig, sample_ax = plt.subplots()
sample_plot = sample_ax.imshow(synthetic_images[2], cmap="gray")
sample_fig.colorbar(sample_plot, ax=sample_ax)
plt.show()

In [None]:
# Preview the first 18 real images in filename-index order.
# The preview uses its own variable: re-assigning `real_images` here would
# silently shrink the data set that all later statistics are computed on.
real_images_ordered = load_images_rekkefølge(real_images_path)[0:18]
for i, ordered_image in enumerate(real_images_ordered):
    plt.figure()
    plt.imshow(ordered_image, cmap="bone")
    plt.title(str(i))
    plt.show()

# Mean / Variance / Max / Min

In [None]:
def find_metrics(images):
    """Compute per-image intensity statistics.

    Parameters
    ----------
    images : iterable of numpy arrays
        Each element must support ``.mean()`` and the NumPy reductions.

    Returns
    -------
    tuple of lists
        ``(means, stds, maxs, mins, medians)`` — note that maxima come before
        minima. Each list holds one value per image, in input order.
    """
    per_image = [
        (img.mean(), np.std(img), np.amin(img), np.amax(img), np.median(img))
        for img in images
    ]

    means = [row[0] for row in per_image]
    stds = [row[1] for row in per_image]
    mins = [row[2] for row in per_image]
    maxs = [row[3] for row in per_image]
    medians = [row[4] for row in per_image]

    return means, stds, maxs, mins, medians

In [None]:
# Per-image statistics for both image sets.
real_means, real_stds, real_maxs, real_mins, real_medians = find_metrics(real_images)
syn_means, syn_stds, syn_maxs, syn_mins, syn_medians = find_metrics(synthetic_images)

# Each row: label, how the per-image values are aggregated across the set,
# the real values and the synthetic values.
intensity_summary = [
    ("Mean: (Real / synthetic ) ", np.mean, real_means, syn_means),
    ("Mean (stddev): (Real / synthetic ) ", np.std, real_means, syn_means),
    ("Median: (Real / synthetic ) ", np.mean, real_medians, syn_medians),
    ("Median (stddev): (Real / synthetic ) ", np.std, real_medians, syn_medians),
    ("Standard deviation: (Real / synthetic) ", np.mean, real_stds, syn_stds),
    ("Variation in Standard deviation: (Real / synthetic) ", np.std, real_stds, syn_stds),
    ("Max pixel value: (Real / synthetic) ", np.mean, real_maxs, syn_maxs),
    ("Min pixel value: (Real / synthetic) ", np.mean, real_mins, syn_mins),
]

for summary_label, aggregate, real_values, syn_values in intensity_summary:
    print(
        summary_label,
        np.around(aggregate(real_values), 3),
        " / ",
        np.around(aggregate(syn_values), 3),
    )

# Histogram analysis

In [None]:
def find_mean_histogram(images, n_bins=256):
    """Plot and return the average intensity histogram of a set of images.

    All images are binned on the *same* bin edges, spanning the smallest to
    the largest value found in the whole set. Averaging counts bin-by-bin is
    only meaningful when every histogram shares its edges; letting each image
    choose its own range would mix counts that belong to different intensity
    intervals.

    Parameters
    ----------
    images : sequence of numpy arrays
        Images to analyse. Must contain at least one image.
    n_bins : int, optional
        Number of histogram bins. Defaults to 256.

    Returns
    -------
    average_histogram : numpy.ndarray
        Mean count per bin across all images (length ``n_bins``).
    binss : list of numpy.ndarray
        The bin edges used for each image (one entry per image; all entries
        are equal because the edges are shared).

    Raises
    ------
    ValueError
        If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("find_mean_histogram needs at least one image")

    # Common intensity range for every image in the set.
    lowest = min(np.amin(image) for image in images)
    highest = max(np.amax(image) for image in images)

    histograms = []
    binss = []
    for image in images:
        counts, bins = np.histogram(image, bins=n_bins, range=(lowest, highest))
        histograms.append(counts)
        binss.append(bins)

    average_histogram = np.mean(histograms, axis=0)

    plt.stairs(average_histogram, binss[0], color="blue", fill=True)
    plt.grid()
    # The x-axis is fixed to [0, 1]; values outside that interval are not shown.
    plt.xlim((0, 1))
    plt.title("Average histogram")
    plt.ylabel("Frequency")
    plt.xlabel("pixel intensity")
    plt.show()

    return average_histogram, binss

In [None]:
# find_mean_histogram returns (average_histogram, bin_edges_per_image);
# unpack both results so that avg_hist_syn holds the histogram itself rather
# than the whole tuple.
avg_hist_real, real_binss = find_mean_histogram(real_images)
avg_hist_syn, syn_binss = find_mean_histogram(synthetic_images)

# Comparison of different models

In [None]:
# Folders with generated images, one per model; the position in this list
# defines the model number (1-7) used in the comparison plots.
model_names = [
    "Synthetic_images/bs32_epoch249av250_pretrained_1nov",
    "Synthetic_images/bs32_epoch149av250_pretrained_1nov",
    "Synthetic_images/bs32_epoch99av250_pretrained_1nov",
    "Synthetic_images/bs16_150epochs_3nov",
    "Synthetic_images/bs16_125epochs_3nov",
    "Synthetic_images/bs8_150epochs_8nov",
    "Synthetic_images/bs16_150epochs_timestep500",
]

In [None]:
# Load the images of the seven models, in the order given by `model_names`.
model1, model2, model3, model4, model5, model6, model7 = [
    load_images(model_names[index]) for index in range(7)
]

# Per-image statistics for each model; find_metrics returns
# (means, stds, maxs, mins, medians).
(
    (means1, stds1, maxs1, mins1, medians1),
    (means2, stds2, maxs2, mins2, medians2),
    (means3, stds3, maxs3, mins3, medians3),
    (means4, stds4, maxs4, mins4, medians4),
    (means5, stds5, maxs5, mins5, medians5),
    (means6, stds6, maxs6, mins6, medians6),
    (means7, stds7, maxs7, mins7, medians7),
) = [
    find_metrics(model_images)
    for model_images in (model1, model2, model3, model4, model5, model6, model7)
]



In [None]:
# One group per data set: the real images first, then models 1-7.
# Every statistic uses the same group layout so the plots can be compared
# position by position (the median plot previously left out the real data,
# which shifted all of its groups by one).
mean_box = [list(values) for values in (real_means, means1, means2, means3, means4, means5, means6, means7)]
median_box = [list(values) for values in (real_medians, medians1, medians2, medians3, medians4, medians5, medians6, medians7)]
maxs_box = [list(values) for values in (real_maxs, maxs1, maxs2, maxs3, maxs4, maxs5, maxs6, maxs7)]
mins_box = [list(values) for values in (real_mins, mins1, mins2, mins3, mins4, mins5, mins6, mins7)]
stds_box = [list(values) for values in (real_stds, stds1, stds2, stds3, stds4, stds5, stds6, stds7)]

group_labels = ["Real"] + [str(model_number) for model_number in range(1, 8)]

# (data, plot title, y-axis label) for each figure, drawn in this order.
box_plot_specs = [
    (mean_box, "Mean pixel value for models 1 - 7", "Mean pixel value"),
    (median_box, "Median pixel value for models 1 - 7", "Median pixel value"),
    (maxs_box, "Maximum pixel value for models 1 - 7", "Maximum pixel value"),
    (mins_box, "Minimum pixel value for models 1 - 7", "Minimum pixel value"),
    (stds_box, "Variation in pixel values for models 1 - 7", "Standard deviation"),
]

for box_data, box_title, box_ylabel in box_plot_specs:
    # Box plot without outlier markers, with every individual image overlaid
    # as a black point.
    box_ax = sns.boxplot(data=box_data, showfliers=False)
    sns.stripplot(data=box_data, color="black", ax=box_ax)
    box_ax.set_xticks(range(len(group_labels)))
    box_ax.set_xticklabels(group_labels)
    box_ax.set_title(box_title)
    box_ax.set_xlabel("Models")
    box_ax.set_ylabel(box_ylabel)
    plt.show()



# Frequency spectrum analysis

In [None]:
# Power spectrum of every image: squared magnitude of the 2-D FFT after
# shifting the zero-frequency component to the centre.
real_fourier_images = [
    np.abs(np.fft.fftshift(np.fft.fft2(image))) ** 2
    for image in real_images
]

generated_fourier_images = [
    np.abs(np.fft.fftshift(np.fft.fft2(image))) ** 2
    for image in synthetic_images
]

In [None]:
# Per-image statistics of the power spectra.
f_real_means, f_real_stds, f_real_maxs, f_real_mins, f_real_medians = find_metrics(real_fourier_images)
f_syn_means, f_syn_stds, f_syn_maxs, f_syn_mins, f_syn_medians = find_metrics(generated_fourier_images)

# Each row: label, aggregation across the set, real values, synthetic values.
# The median and standard-deviation summaries are computed above but are
# currently not printed.
fourier_summary = [
    ("Mean: (Real / synthetic ) ", np.mean, f_real_means, f_syn_means),
    ("Mean (stddev): (Real / synthetic ) ", np.std, f_real_means, f_syn_means),
    ("Max pixel value: (Real / synthetic) ", np.mean, f_real_maxs, f_syn_maxs),
    ("Min pixel value: (Real / synthetic) ", np.mean, f_real_mins, f_syn_mins),
]

for fourier_label, fourier_aggregate, f_real_values, f_syn_values in fourier_summary:
    print(
        fourier_label,
        np.around(fourier_aggregate(f_real_values), 3),
        " / ",
        np.around(fourier_aggregate(f_syn_values), 3),
    )

In [None]:
def _show_log_power_spectrum(spectrum, title):
    """Display log(1 + spectrum) as a grayscale image with a colour bar."""
    plt.imshow(np.log1p(spectrum), cmap='gray')
    plt.colorbar()
    plt.title(title)
    plt.show()


# First real image: 2-D FFT, centred, squared magnitude.
f_transform = np.fft.fft2(real_images[0])
print(f_transform.shape)
f_transform_shifted = np.fft.fftshift(f_transform)
power_spectrum = np.abs(f_transform_shifted) ** 2
_show_log_power_spectrum(power_spectrum, 'Fourier transform of real image')

# First synthetic image, processed the same way.
f_transform_gen = np.fft.fft2(synthetic_images[0])
print(f_transform_gen.shape)
f_transform_shifted_gen = np.fft.fftshift(f_transform_gen)
power_spectrum_gen = np.abs(f_transform_shifted_gen) ** 2
_show_log_power_spectrum(power_spectrum_gen, 'Fourier transform of synthetic image')

# Mean and maximum power: real image first, then synthetic image.
for example_spectrum in (power_spectrum, power_spectrum_gen):
    print(np.mean(example_spectrum))
    print(np.amax(example_spectrum))

In [None]:
# Histogram of the raw FFT coefficients, first for the real example image and
# then for the synthetic one. Each figure shows the same histogram three
# times: full range (log y), zoomed to [-250, 250] (log y), and zoomed to
# [-250, 250] with a linear y-axis capped at 200.
# NOTE(review): f_transform / f_transform_gen come from np.fft.fft2 and are
# complex-valued; np.histogram is being given complex input here — confirm
# whether the real part or the magnitude was meant to be histogrammed.
for coefficients in (f_transform, f_transform_gen):
    fig, axs = plt.subplots(1, 3)
    counts, bins = np.histogram(coefficients, bins=1000)

    for panel, colour in zip(axs, ("blue", "green", "black")):
        panel.stairs(counts, bins, color=colour, fill=True)

    for panel in axs[1:]:
        panel.set_xlim((-250, 250))
    axs[2].set_ylim((0, 200))

    for panel in axs[:2]:
        panel.set_yscale("log")

    plt.show()


# FID 

In [None]:
# All metric computations run on the CPU; automatic CUDA selection
# ("cuda" if torch.cuda.is_available() else "cpu") is deliberately disabled.
device = torch.device("cpu")
print(f"Using {device}")

In [None]:
def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    """Subtract a fixed per-channel offset from the first three channels.

    The tensor is modified in place and the same object is returned.

    Parameters
    ----------
    x : torch.Tensor
        4-D tensor indexed as ``x[:, channel, :, :]`` with at least three
        channels.

    Returns
    -------
    torch.Tensor
        ``x`` itself, after the subtraction.
    """
    channel_offsets = (0.406, 0.456, 0.485)
    for channel, offset in enumerate(channel_offsets):
        x[:, channel, :, :] -= offset
    return x


def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor:
    """Average a tensor over its axes 2 and 3.

    Parameters
    ----------
    x : torch.Tensor
        Tensor with at least four axes.
    keepdim : bool, optional
        If True (default) the averaged axes are kept with size 1; otherwise
        they are removed from the result.
    """
    return torch.mean(x, dim=(2, 3), keepdim=keepdim)


def get_features(image):
    """Extract spatially averaged features from the global ``radnet`` model.

    Parameters
    ----------
    image : torch.Tensor
        Either a batch ``(B, C, H, W)`` or a tensor with fewer axes, e.g. a
        single ``(C, H, W)`` image as obtained when iterating directly over a
        4-D tensor. ``C`` must be 1 or 3.

    Returns
    -------
    torch.Tensor
        Output of ``radnet`` averaged over axes 2 and 3 (one row per image).

    Raises
    ------
    ValueError
        If the channel axis does not have size 1 or 3.
    """
    # Add the missing leading axes explicitly. The previous check
    # (`if image.shape[1]:`) was true for every non-empty axis and only
    # produced the right shape because `repeat` prepends axes for tensors
    # with fewer than four dimensions.
    while image.dim() < 4:
        image = image.unsqueeze(0)

    # If input has just 1 channel, repeat channel to have 3 channels
    if image.shape[1] == 1:
        image = image.repeat(1, 3, 1, 1)
    if image.shape[1] != 3:
        raise ValueError(
            f"get_features expects 1 or 3 channels, got {image.shape[1]}"
        )

    # Change order from 'RGB' to 'BGR'. Indexing with a list creates a new
    # tensor, so the in-place subtraction below does not touch the caller's data.
    image = image[:, [2, 1, 0], ...]

    # Subtract mean used during training
    image = subtract_mean(image)

    # Get model outputs
    with torch.no_grad():
        feature_image = radnet(image)
        # flattens the image spatially
        feature_image = spatial_average(feature_image, keepdim=False)

    return feature_image

In [None]:
def preprocess_for_fid(images):
    """Stack images into a float32 tensor with a singleton axis at position 1.

    Parameters
    ----------
    images : sequence of equally shaped arrays
        Images to stack.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(len(images), 1, *image_shape)``. The shape
        is also printed.
    """
    stacked = np.array(images).astype('float32')
    batch = torch.tensor(stacked).unsqueeze(1)
    print(batch.shape)
    return batch

In [None]:
# Draw a random subset of at most N synthetic images and an equally sized
# random subset of real images for the FID / SSIM computations.
N = 100
# Fixed seed so the random subsets — and therefore the reported metrics —
# are the same on every run (given the same file listing order).
SEED = 0
rng = np.random.default_rng(SEED)

synthetic_images = load_images(gen_images_path)
rng.shuffle(synthetic_images)
synthetic_images = synthetic_images[0:N]

real_images = load_images(real_images_path)
rng.shuffle(real_images)
# Slicing cannot return more real images than exist; the two printed lengths
# below reveal a mismatch if the real set is smaller than the synthetic one.
real_images = real_images[0:len(synthetic_images)]
print(len(real_images))
print(len(synthetic_images))

real_images_for_fid = preprocess_for_fid(real_images)
synthetic_images_for_fid = preprocess_for_fid(synthetic_images)

In [None]:
# Load the RadImageNet ResNet-50 through torch.hub, move it to `device` and
# switch it to evaluation mode.
# NOTE(review): torch.hub.load fetches the model definition from the
# "Warvito/radimagenet-models" GitHub repository — consider pinning a
# specific revision for reproducibility.
radnet = torch.hub.load(
    "Warvito/radimagenet-models",
    model="radimagenet_resnet50",
    verbose=True,
)
radnet.to(device)
radnet.eval()

In [None]:
def calculate_fid(dataloader1, dataloader2):
    """Compute the FID between two image collections.

    Every element yielded by each argument is moved to ``device`` and passed
    through ``get_features``; the resulting feature rows are stacked and
    handed to ``FIDMetric``.

    Parameters
    ----------
    dataloader1, dataloader2 : iterable of torch.Tensor
        Anything that yields tensors when iterated. (In this notebook the
        4-D image tensors themselves are passed, so each element is a single
        image.)

    Returns
    -------
    float
        The FID value.
    """
    def _collect_features(batches):
        # One feature tensor per yielded element, in iteration order.
        return [get_features(batch.to(device)) for batch in batches]

    images1_features = _collect_features(dataloader1)
    images2_features = _collect_features(dataloader2)

    eval_features1 = torch.vstack(images1_features)
    eval_features2 = torch.vstack(images2_features)

    fid_metric = FIDMetric()
    return fid_metric(eval_features1, eval_features2).item()

In [None]:
# FID between the sampled real images and the sampled synthetic images.
fid = calculate_fid(
    dataloader1=real_images_for_fid,
    dataloader2=synthetic_images_for_fid,
)
print(fid)

# PSNR

In [None]:
# PSNR between one real image (index 50) and one synthetic image (index 0),
# assuming intensities lie in [0, 1] (max_val=1.0).
# PSNRMetric is imported in the import cell at the top of the notebook.
# NOTE(review): the indices are hard-coded, so this requires at least 51
# real images; the two images are not a matched pair — confirm that an
# unpaired PSNR is what is wanted here.
pnsr = PSNRMetric(max_val=1.0)
PNSR = pnsr(real_images_for_fid[50], synthetic_images_for_fid[0])
print(PNSR.item())

# MS-SSIM & SSIM

In [None]:
def calculate_ssim_msssim(images1, images2):
    """Compute SSIM and MS-SSIM for every pair (one image from each set).

    Parameters
    ----------
    images1, images2 : indexable collections of torch.Tensor
        Each element gets a leading axis added with ``torch.unsqueeze(..., 0)``
        before being compared.

    Returns
    -------
    ssim_scores : list
        SSIM result for every ``(images1[i], images2[j])`` pair.
    ms_ssim_scores : list
        MS-SSIM result for the same pairs, in the same order.

    The return order is ``(ssim, ms_ssim)``, matching the function name and
    ``calculate_ssim_msssim_self``.
    """
    ms_ssim_scores = []
    ssim_scores = []

    ms_ssim = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=4)
    ssim = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=4)

    for i in range(len(images1)):
        first = torch.unsqueeze(images1[i], 0)
        for j in range(len(images2)):
            second = torch.unsqueeze(images2[j], 0)
            ms_ssim_scores.append(ms_ssim(first, second))
            ssim_scores.append(ssim(first, second))

    return ssim_scores, ms_ssim_scores

def calculate_ssim_msssim_self(images1):
    """Compute SSIM and MS-SSIM between all distinct pairs within one set.

    Each unordered pair ``(i, j)`` with ``i < j`` is evaluated once; an image
    is never compared with itself.

    Parameters
    ----------
    images1 : indexable collection of torch.Tensor
        Each element gets a leading axis added with ``torch.unsqueeze(..., 0)``
        before being compared.

    Returns
    -------
    ssim_scores : list
        SSIM result for every pair.
    ms_ssim_scores : list
        MS-SSIM result for the same pairs, in the same order.
    """
    ms_ssim = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=4)
    ssim = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=4)

    ms_ssim_scores, ssim_scores = [], []
    image_count = len(images1)

    for i in range(image_count):
        # Starting at i + 1 skips the self-comparison and mirrored pairs.
        for j in range(i + 1, image_count):
            first = torch.unsqueeze(images1[i], 0)
            second = torch.unsqueeze(images1[j], 0)
            ms_ssim_scores.append(ms_ssim(first, second))
            ssim_scores.append(ssim(first, second))

    return ssim_scores, ms_ssim_scores

In [None]:
# Pairwise similarity within a single image set.
# NOTE(review): the results are named "*_syn" but the *real* images are
# passed in here — confirm which set is intended.
# (Cross-set alternative: calculate_ssim_msssim(real_images_for_fid, synthetic_images_for_fid).)
ssim_syn, msssim_syn = calculate_ssim_msssim_self(images1=real_images_for_fid)

In [None]:
# Concatenate the per-pair results and report mean +- standard deviation.
ms_ssim_scores = torch.cat(msssim_syn, dim=0)
ssim_scores = torch.cat(ssim_syn, dim=0)

ms_ssim_mean, ms_ssim_std = ms_ssim_scores.mean(), ms_ssim_scores.std()
ssim_mean, ssim_std = ssim_scores.mean(), ssim_scores.std()

print(f"MS-SSIM Metric: {ms_ssim_mean:.4f} +- {ms_ssim_std:.4f}")
print(f"SSIM Metric: {ssim_mean:.4f} +- {ssim_std:.4f}")

In [None]:
# Batched comparison: both tensors are passed whole, so image i of the real
# batch is compared with image i of the synthetic batch.
ms_ssim = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=4)
ssim = SSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=4)

ms_ssim_test = ms_ssim(real_images_for_fid, synthetic_images_for_fid)
ssim_test = ssim(real_images_for_fid, synthetic_images_for_fid)
print("MS-SSIM score:", ms_ssim_test.mean().item())
# This line reports the SSIM result; it was previously mislabelled "MS-SSIM".
print("SSIM score:", ssim_test.mean().item())