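### KL divergence to measure the similarity between data distributions
Each image's FFT magnitude spectrum is normalized to sum to 1 and treated as a probability distribution over frequencies; the KL divergence between two such distributions is 0 when they are identical and grows as they differ.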

In [1]:
import random
import numpy as np
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import torch
import learn2learn as l2l
from tqdm import tqdm
import matplotlib.pyplot as plt
# The coarse reconstruction is the root-sum-of-squares (RSS) of the zero-filled
# multi-coil k-spaces after the inverse FT.
from functions.data.transforms import UnetDataTransform_norm, normalize
# Import a torch.utils.data.Dataset class that takes a list of data examples, a path to
# those examples, and a data transform, and outputs a torch dataset.
from functions.data.mri_dataset import SliceDataset
# Unet architecture as nn.Module
from functions.models.unet import Unet
# Function that returns a MaskFunc object for generating either random or equispaced masks
from functions.data.subsample import create_mask_for_mask_type
# Implementation of SSIMLoss
from functions.training.losses import SSIMLoss
from functions.helper import evaluate_loss_dataloader
from functions.fftc import fft2c_new as fft2c

# Prefer the GPU when one is visible to this process, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG this notebook may draw from (Python, NumPy, torch CPU and CUDA)
# so that Restart & Run All reproduces the same outputs.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [2]:
# Two knee subsets to compare (judging by the file names: PD contrast, Biograph vs. Skyra scanner).
path_test1 = '/cheng/metaMRI/metaMRI/data_dict/E-part1/P/knee_train_PD_Biograph_15-22.yaml'
path_test2 = '/cheng/metaMRI/metaMRI/data_dict/E-part1/P/knee_train_PD_Skyra_15-22.yaml'

# Mask function and data transform, shared by both datasets so that any difference
# between them comes from the data rather than from preprocessing.
mask_function = create_mask_for_mask_type(mask_type_str = 'random', self_sup = False, 
                    center_fraction = 0.08, acceleration = 4.0, acceleration_total = 3.0)
data_transform_test = UnetDataTransform_norm('multicoil', mask_func = mask_function, use_seed=True, mode='adapt')


def _make_test_loader(path):
    """Build a SliceDataset for ``path`` and an unshuffled, batch-size-1 DataLoader over it.

    Both use ``data_transform_test``. Returns ``(dataset, dataloader)``.
    """
    # dataset: num_sample_subset x 3
    dataset = SliceDataset(dataset = path, path_to_dataset='', path_to_sensmaps=None, provide_senmaps=False, 
                           challenge="multicoil", transform=data_transform_test, use_dataset_cache=True)
    # A dedicated seeded generator keeps the loader from drawing on the global torch RNG.
    loader = torch.utils.data.DataLoader(dataset = dataset, batch_size = 1, shuffle = False, 
                                         generator = torch.Generator().manual_seed(1), pin_memory = True)
    return dataset, loader


testset1, test_dataloader1 = _make_test_loader(path_test1)
testset2, test_dataloader2 = _make_test_loader(path_test2)

# Take the first slice of each dataset. Only the target images are used in the
# KL comparison; file name and slice index are kept per dataset for reference.
batch1 = next(iter(test_dataloader1))
_, target_image1, _, _, fname1, slice_num1 = batch1
batch2 = next(iter(test_dataloader2))
_, target_image2, _, _, fname2, slice_num2 = batch2
# Backward-compatible aliases: originally these names ended up bound to batch2's values.
fname, slice_num = fname2, slice_num2

In [3]:
def compute_magnitude_spectrum(image):
    """Return the element-wise magnitude of the 2-D discrete Fourier transform of ``image``.

    ``torch.fft.fft2`` transforms over the last two dimensions, so any leading
    (batch/channel) dimensions are carried through unchanged.
    """
    spectrum = torch.fft.fft2(image)
    return spectrum.abs()

def compute_probability_distribution(magnitude_spectrum):
    """Scale a non-negative magnitude spectrum so that its entries sum to 1.

    The total is taken over *all* elements of the tensor (batch dimension
    included), so a batch of several images yields one joint distribution.
    """
    total_mass = magnitude_spectrum.sum()
    return magnitude_spectrum / total_mass

def compute_kl_divergence(prob_dist_p, prob_dist_q):
    """Compute the Kullback-Leibler divergence KL(P || Q) = sum_i p_i * log(p_i / q_i).

    Args:
        prob_dist_p: tensor of probabilities for P (non-negative, summing to 1).
        prob_dist_q: tensor of probabilities for Q, same shape as ``prob_dist_p``
            (or broadcastable to it).

    Returns:
        0-dim tensor with the divergence in nats (natural log). It is 0 when
        P == Q and is not symmetric: KL(P || Q) != KL(Q || P) in general.

    Bins where p_i == 0 contribute 0 (the 0 * log 0 = 0 convention) instead of
    NaN; bins where p_i > 0 but q_i == 0 make the result +inf, as KL requires.
    """
    # xlogy(x, y) = x * log(y), defined as 0 when x == 0, so zero-probability
    # bins in P no longer turn the whole sum into NaN via 0 * -inf.
    return torch.sum(torch.xlogy(prob_dist_p, prob_dist_p) - torch.xlogy(prob_dist_p, prob_dist_q))

# Compare target_image1 against target_image2 in the frequency domain.
# NOTE(review): the original comment assumed shape [1, 1, 320, 320] — confirm
# against the data transform's output.

# Magnitude spectrum of each image.
magnitude_spectrum_1, magnitude_spectrum_2 = (
    compute_magnitude_spectrum(target_image1),
    compute_magnitude_spectrum(target_image2),
)

# Normalize each spectrum into a probability distribution over frequencies.
prob_dist_1, prob_dist_2 = (
    compute_probability_distribution(magnitude_spectrum_1),
    compute_probability_distribution(magnitude_spectrum_2),
)

# KL(P1 || P2): asymmetric, and 0 only when the two distributions coincide.
kl_divergence = compute_kl_divergence(prob_dist_1, prob_dist_2)
print("KL Divergence between the two images:", kl_divergence)


KL Divergence between the two images: tensor(0.3282)


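A KL divergence of 0 means the two spectral distributions are identical; larger values indicate greater dissimilarity. Note that KL divergence is asymmetric, so KL(P‖Q) generally differs from KL(Q‖P).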