In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import random
import math
import io
from PIL import Image
from copy import deepcopy
from IPython.display import HTML
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Torch device-type strings are lower-case ("cpu", "cuda"); torch.device("CPU")
# raises, so the CPU fallback must be spelled "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU
from torch.utils.data import Dataset, DataLoader
import random
import time
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
import numpy as np
from matplotlib.colors import Normalize
from torch.utils.data import DataLoader, random_split, Subset
import cv2
import math
print(str(device))  # report which device the notebook will use

In [None]:
# Debugging aid: request synchronous CUDA kernel launches.
# NOTE(review): this slows GPU training; it presumably also has to be set before
# CUDA is initialised to take effect — confirm, and drop it for real runs.
os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [None]:
# Dataset root. Override it with the MRI2CT_DATA_DIR environment variable
# instead of editing the notebook; the default is the original cluster path.
path = os.environ.get('MRI2CT_DATA_DIR', '/home/scai/mtech/aib222683/scratch/Task2/data')
data_path_Train = os.path.join(path, 'train')  # folder of training image pairs
data_path_Test = os.path.join(path, 'test')    # folder of test image pairs
batch_size = 16
num_workers = 2


from torchvision.transforms import Lambda

class MinMaxScale:
    """Affinely map values from [min_val, max_val] onto [0, 1].

    Works with anything that supports subtraction/division by Python numbers
    (floats, numpy arrays, torch tensors).
    """

    def __init__(self, min_val=0, max_val=1):
        self.min = min_val
        self.max = max_val

    def __call__(self, img):
        offset = img - self.min
        span = self.max - self.min
        return offset / span

class MinMaxScaleInverse:
    """Undo MinMaxScale: map values from [0, 1] back onto [min_val, max_val]."""

    def __init__(self, min_val=0, max_val=1):
        self.min = min_val
        self.max = max_val

    def __call__(self, img):
        span = self.max - self.min
        stretched = img * span
        return stretched + self.min
         
# Scaling constants for the Fourier channels. MRIImage applies log(x + 4) to the
# raw magnitude/phase values. The raw bounds were recorded from the data as
# min tensor(-3.1416) and max tensor(3677465.). Passing them through the same
# log gives the range that MinMaxScale maps onto [0, 1].
LOG_OFFSET = 4
RAW_MIN = -3.1416
RAW_MAX = 3677465
LOG_MIN = math.log(RAW_MIN + LOG_OFFSET)
LOG_MAX = math.log(RAW_MAX + LOG_OFFSET)

# Forward scaling for the input (x) and target (y) halves of each pair.
transform_train_x = transforms.Compose([Lambda(MinMaxScale(min_val=LOG_MIN, max_val=LOG_MAX))])

transform_train_y = transforms.Compose([Lambda(MinMaxScale(min_val=LOG_MIN, max_val=LOG_MAX))])

# Inverse scaling: maps [0, 1] back to the log(x + 4) range.
transform_test = transforms.Compose([Lambda(MinMaxScaleInverse(min_val=LOG_MIN, max_val=LOG_MAX))])


In [None]:
#function for fourier
def FCT(images):
    """Forward Fourier transform used by the dataset.

    Parameters
    ----------
    images : array-like
        Either a channels-last colour image (or stack of them), ``(..., H, W, C)``,
        which is averaged over the channel axis to grayscale, or a single 2-D
        grayscale image ``(H, W)``, which is used as-is.

    Returns
    -------
    ndarray of shape ``(..., H, W, 2)``: channel 0 is the FFT magnitude and
    channel 1 its phase (``np.angle``, radians).
    """
    images = np.asarray(images)
    if images.ndim == 2:
        # Already single-channel. Averaging over the last axis would collapse
        # the width and leave fft2 with a 1-D array.
        grayscale_images = images.astype(np.float64)
    else:
        # Convert RGB images to grayscale.
        grayscale_images = np.mean(images, axis=-1)

    # fft2 transforms the last two axes, i.e. each image separately.
    data = np.fft.fft2(grayscale_images)
    return np.stack([np.abs(data), np.angle(data)], axis=-1)

#function for inverse fourier
def inverseFCT(img):
    """Rebuild a spatial image from a channels-first ``(2, H, W)`` Fourier array.

    Channel 0 holds magnitudes and channel 1 phases. The complex spectrum
    ``magnitude * exp(i * phase)`` is inverted with ``ifft2``, and the real
    part is returned as an ``(H, W)`` array.
    """
    channels_last = np.transpose(img, (1, 2, 0))
    spectrum = channels_last[..., 0] * np.exp(1j * channels_last[..., 1])
    return np.fft.ifft2(spectrum).real


def split(img, width=256):
    """Split side-by-side pairs into ``(input, target)`` along the last axis.

    ``img[..., :width]`` is the input and ``img[..., width:]`` the target. The
    default of 256 matches MRIImage, which concatenates two 256-pixel-wide
    halves. Any number of leading dimensions is accepted, e.g. a single
    ``(C, H, W)`` sample or an ``(N, C, H, W)`` batch.
    """
    return img[..., :width], img[..., width:]

In [None]:
class MRIImage(Dataset):
    """Paired MRI dataset served in the Fourier domain.

    Each file in ``data_dir`` holds two images side by side: columns ``[:256]``
    are the input and ``[256:]`` the target. Each half is turned into a
    ``(2, H, W)`` float tensor (FFT magnitude and phase, then ``log(x + 4)``).
    The optional ``transform1`` / ``transform2`` are applied to the input and
    target halves respectively. The halves are then concatenated along the
    width, so a 256+256-wide file yields a ``(2, H, 512)`` tensor.

    Hidden files (names starting with '.') and sub-directories are skipped.
    """

    def __init__(self, data_dir, transform1=None,transform2=None):
        self.data_dir = data_dir
        self.transform1 = transform1
        self.transform2 = transform2
        # Sorted: os.listdir order is arbitrary, which made the (unshuffled)
        # test order differ between machines/runs.
        self.image_paths = sorted(
            os.path.join(data_dir, filename)
            for filename in os.listdir(data_dir)
            if not filename.startswith('.')
            and os.path.isfile(os.path.join(data_dir, filename))
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Image.open keeps the file open lazily; read the pixels inside a
        # context manager so the handle is released immediately.
        with Image.open(image_path) as pil_image:
            image = np.asarray(pil_image)

        # `[:, :256]` (rather than `[:, :256, :]`) also accepts 2-D grayscale arrays.
        img1 = self._to_fourier(image[:, :256])
        img2 = self._to_fourier(image[:, 256:])

        if self.transform1:
            img1 = self.transform1(img1)
        if self.transform2:
            img2 = self.transform2(img2)

        return torch.cat((img1, img2), dim=2)

    @staticmethod
    def _to_fourier(half):
        """FCT -> channels-first float tensor -> log(x + 4)."""
        spectrum = np.transpose(FCT(half), (2, 0, 1))
        return torch.log(torch.from_numpy(spectrum).float() + 4)

In [None]:
class MRIImageTest(Dataset):
    """Plain image dataset: each non-hidden file in ``data_dir`` as a PIL image.

    If ``transform`` is given, it is applied to the image before it is returned.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # Sorted for a reproducible order (os.listdir order is arbitrary).
        self.image_paths = sorted(
            os.path.join(data_dir, filename)
            for filename in os.listdir(data_dir)
            if not filename.startswith('.')
            and os.path.isfile(os.path.join(data_dir, filename))
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # copy() loads the pixels into memory so the file handle opened by
        # Image.open can be closed right away instead of leaking.
        with Image.open(image_path) as pil_image:
            image = pil_image.copy()

        if self.transform:
            image = self.transform(image)

        return image




In [None]:
from tqdm import tqdm

# Training data: every file in the train folder (no validation split is held out).
full_dataset = MRIImage(
    data_dir=data_path_Train,
    transform1=transform_train_x,
    transform2=transform_train_y,
)
train_dataset = full_dataset

# Shuffled mini-batches for training.
load_Train = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)


In [None]:
# Load the entire test dataset
# Test data goes through the same forward scaling as the training data.
full_test_dataset = MRIImage(
    data_dir=data_path_Test,
    transform1=transform_train_x,
    transform2=transform_train_y,
)
test_dataset = full_test_dataset

# Fixed order for evaluation.
load_Test = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)


In [None]:
# Use InstanceNorm only when training one image per batch, otherwise BatchNorm.
# (Presumably because batch statistics are degenerate for a batch of one.)
inst_norm = bool(batch_size == 1)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Plain ``nn.Conv2d`` (with bias) and no normalisation layer."""
    layer = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    return layer


def conv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, inst_norm=False):
    """``Conv2d`` followed by a normalisation layer, wrapped in ``nn.Sequential``.

    The normalisation is ``InstanceNorm2d`` when ``inst_norm == True`` and
    ``BatchNorm2d`` otherwise (both with momentum 0.1, eps 1e-5).
    """
    convolution = nn.Conv2d(in_channels, out_channels, kernel_size,
                            stride=stride, padding=padding)
    use_instance = (inst_norm == True)
    norm_cls = nn.InstanceNorm2d if use_instance else nn.BatchNorm2d
    return nn.Sequential(convolution, norm_cls(out_channels, momentum=0.1, eps=1e-5))

def tconv(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0,):
    """Plain ``nn.ConvTranspose2d`` (with bias) and no normalisation layer."""
    layer = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
    )
    return layer

def tconv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, inst_norm=False):
    """``ConvTranspose2d`` followed by a normalisation layer, in ``nn.Sequential``.

    The normalisation is ``InstanceNorm2d`` when ``inst_norm == True`` and
    ``BatchNorm2d`` otherwise (both with momentum 0.1, eps 1e-5).
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                  stride=stride, padding=padding,
                                  output_padding=output_padding)
    norm_cls = nn.InstanceNorm2d if (inst_norm == True) else nn.BatchNorm2d
    return nn.Sequential(upsample, norm_cls(out_channels, momentum=0.1, eps=1e-5))

In [None]:
dim_c = 2   # channels per image: the two Fourier channels (magnitude, phase) built by MRIImage
dim_g = 64  # base number of generator filters

# Generator
# U-Net style encoder/decoder. Eight stride-2, 4x4 convolutions (n1..n8) each
# halve the spatial size. Eight stride-2 transposed convolutions (m1..m8)
# double it back. Decoder outputs m1..m7 are concatenated with the mirrored
# encoder maps n7..n1 (skip connections), which is why m2..m8 take `*2` inputs.
# NOTE(review): a 256x256 input is reduced to 1x1 at n8. Other sizes should
# presumably be multiples of 256 for the skip connections to line up; confirm
# before reuse.
class Gen(nn.Module):
    def __init__(self, inst_norm=False):
        # inst_norm is forwarded to conv_n/tconv_n: True -> InstanceNorm2d,
        # otherwise BatchNorm2d after each normalised layer. The outermost
        # layers (n1, n8, m8) use conv/tconv and have no normalisation.
        super(Gen,self).__init__()
        self.n1 = conv(dim_c, dim_g, 4, 2, 1)
        self.n2 = conv_n(dim_g, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.n3 = conv_n(dim_g*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.n4 = conv_n(dim_g*4, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n5 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n6 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n7 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n8 = conv(dim_g*8, dim_g*8, 4, 2, 1)

        self.m1 = tconv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m2 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m3 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m4 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m5 = tconv_n(dim_g*8*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.m6 = tconv_n(dim_g*4*2, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.m7 = tconv_n(dim_g*2*2, dim_g*1, 4, 2, 1, inst_norm=inst_norm)
        self.m8 = tconv(dim_g*1*2, dim_c, 4, 2, 1)
        self.tanh = nn.Tanh()

    def forward(self,x):
        # Encoder: LeakyReLU(0.2) precedes every convolution except n1.
        n1 = self.n1(x)
        n2 = self.n2(F.leaky_relu(n1, 0.2))
        n3 = self.n3(F.leaky_relu(n2, 0.2))
        n4 = self.n4(F.leaky_relu(n3, 0.2))
        n5 = self.n5(F.leaky_relu(n4, 0.2))
        n6 = self.n6(F.leaky_relu(n5, 0.2))
        n7 = self.n7(F.leaky_relu(n6, 0.2))
        n8 = self.n8(F.leaky_relu(n7, 0.2))
        # Decoder: ReLU -> transposed conv, then channel-concat with the skip.
        # Dropout(0.5) on m1..m3 is called with training=True, so it stays
        # active even after .eval(). Inference outputs are therefore stochastic.
        # NOTE(review): presumably intentional (pix2pix-style dropout noise).
        m1 = torch.cat([F.dropout(self.m1(F.relu(n8)), 0.5, training=True), n7], 1)
        m2 = torch.cat([F.dropout(self.m2(F.relu(m1)), 0.5, training=True), n6], 1)
        m3 = torch.cat([F.dropout(self.m3(F.relu(m2)), 0.5, training=True), n5], 1)
        m4 = torch.cat([self.m4(F.relu(m3)), n4], 1)
        m5 = torch.cat([self.m5(F.relu(m4)), n3], 1)
        m6 = torch.cat([self.m6(F.relu(m5)), n2], 1)
        m7 = torch.cat([self.m7(F.relu(m6)), n1], 1)
        m8 = self.m8(F.relu(m7))

        # NOTE(review): tanh bounds outputs to (-1, 1), while MinMaxScale maps
        # the stated data bounds onto [0, 1].
        return self.tanh(m8)

In [None]:
dim_d = 64  # base number of discriminator filters


# Discriminator
class Disc(nn.Module):
    """Conditional patch discriminator.

    ``forward(x, y)`` concatenates its two inputs along the channel axis and
    returns a sigmoid map of per-patch scores in (0, 1). Three stride-2 stages
    are followed by two stride-1 stages, so a 256x256 input yields a 30x30 map.
    """

    def __init__(self, inst_norm=False):
        super().__init__()
        # Downsampling stages (stride 2).
        self.c1 = conv(dim_c * 2, dim_d, 4, 2, 1)
        self.c2 = conv_n(dim_d, dim_d * 2, 4, 2, 1, inst_norm=inst_norm)
        self.c3 = conv_n(dim_d * 2, dim_d * 4, 4, 2, 1, inst_norm=inst_norm)
        # Stride-1 stages producing the patch map.
        self.c4 = conv_n(dim_d * 4, dim_d * 8, 4, 1, 1, inst_norm=inst_norm)
        self.c5 = conv(dim_d * 8, 1, 4, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        features = torch.cat([x, y], dim=1)
        for stage in (self.c1, self.c2, self.c3, self.c4):
            features = F.leaky_relu(stage(features), 0.2)
        logits = self.c5(features)
        return self.sigmoid(logits)

def weights_init(z):
    """DCGAN-style initialisation, intended for ``model.apply(weights_init)``.

    * Conv*/Linear layers: weight ~ N(0, 0.02), bias = 0.
    * BatchNorm* layers:   weight ~ N(1, 0.02), bias = 0.
    * Any other module is left untouched.

    Missing parameters (``bias=False`` convolutions, ``affine=False`` norms)
    are skipped instead of raising ``AttributeError`` on ``None.data``.
    NOTE(review): this function is not applied to any model in this notebook.
    """
    cls_name = z.__class__.__name__
    if 'Conv' in cls_name or 'Linear' in cls_name:
        weight_mean = 0.0
    elif 'BatchNorm' in cls_name:
        weight_mean = 1.0
    else:
        return

    weight = getattr(z, 'weight', None)
    if weight is not None:
        nn.init.normal_(weight.data, weight_mean, 0.02)
    bias = getattr(z, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias.data, 0)

In [None]:
BCE = nn.BCELoss()  # binary cross-entropy (not used by the training loop below)
L1 = nn.L1Loss()    # reconstruction term of the generator loss
L2 = nn.MSELoss()   # least-squares adversarial loss used for both networks

# Generator trained below (InstanceNorm vs BatchNorm chosen by `inst_norm`).
Gen_model = Gen(inst_norm).to(device)

# The training cell refers to the discriminator *instance* as `Disc`, which
# shadows the class of the same name. On a re-run of this cell `Disc` is already
# an instance; calling it would run forward() and raise. Restore the class first.
# nn.Module.to() returns the module itself, so type() of the instance is the class.
if not isinstance(Disc, type):
    Disc = type(Disc)
Disc = Disc(inst_norm).to(device)

# Separate generator used by the evaluation cells to load the saved weights.
generator = Gen(inst_norm).to(device)

# optimizers (previously noted alternatives: G lr=2e-3, D lr=2e-6)
Gen_optim = optim.Adam(Gen_model.parameters(), lr=1e-4, betas=(0.5, 0.999))
Disc_optim = optim.Adam(Disc.parameters(), lr=5e-5, betas=(0.5, 0.999))

In [None]:
# # img_list = []
# Disc_losses = Gen_losses = Gen_GAN_losses = Gen_L1_losses = []


# Bounds of the log(x + 4) space that MinMaxScale mapped onto [0, 1]. They are
# used after the last epoch to undo the scaling (v * (max - min) + min) and the
# log (exp(v) - 4) before inverting the Fourier transform.
min_value=math.log(-3.1416+4)
max_value=math.log(3677465+4)
# min_value=(-3.1416)
# max_value=(4239579)

iter_per_plot = 20   # print losses every `iter_per_plot` iterations
epochs = 30
L1_lambda = 50.0     # weight of the L1 reconstruction term in the generator loss


for ep in range(epochs):
    for i, data in enumerate(load_Train):
        size = data.shape[0]

        # Each sample stores [input | target] side by side along the width.
        x, y = split(data.to(device))

        # Real/fake targets for the patch discriminator. 30x30 is the size of
        # Disc's output map for 256x256 inputs.
        # NOTE(review): must change if the input size does.
        r_masks = torch.ones(size,1,30,30).to(device)
        f_masks = torch.zeros(size,1,30,30).to(device)

        fake=Gen_model(x)
        # disc
        # The discriminator is only updated on odd-numbered epochs (1, 3, ...);
        # on even epochs only the generator is trained.
        if (ep+1)%2!=0:
            Disc.zero_grad()
            # Disc takes (candidate, condition): the real pair is (target, input).
            r_patch=Disc(y,x)

            # print(r_patch.mean())

            # print(r_patch.mean())
            r_disc_loss=L2(r_patch,r_masks)
            # print(r_gan_loss)

            #fake_patch
            # detach() so the discriminator loss does not backprop into Gen_model.
            f_patch = Disc(fake.detach(),x)
            f_disc_loss=L2(f_patch,f_masks)
            Disc_loss = r_disc_loss + f_disc_loss

            Disc_loss.backward()
            Disc_optim.step()

        # gen
        # fake=Gen_model(x)
        Gen_model.zero_grad()
        # The generator wants its fakes scored as real (MSE against ones).
        # NOTE: Gen_loss.backward() also accumulates gradients into Disc's
        # parameters; Disc.zero_grad() clears them before the next Disc_optim.step().
        f_patch = Disc(fake,x)
        f_gan_loss=L2(f_patch,r_masks)

        L1_loss = L1(fake,y)
        Gen_loss = f_gan_loss + L1_lambda*L1_loss
        # print(Gen_loss.item())
        Gen_loss.backward()

        Gen_optim.step()


        # Periodic logging. NOTE(review): "Disc(real)"/"Disc(fake)" print the
        # real/fake discriminator MSE losses, not mean discriminator outputs.
        # On even epochs no discriminator step ran, so those values print as 0.
        if (i+1)%iter_per_plot == 0 :
            if (ep+1)%2!=0:
                print('Epoch [{}/{}], Step [{}/{}], disc_loss: {:.4f}, gen_loss: {:.4f},Disc(real): {:.2f}, Disc(fake):{:.2f}, gen_loss_gan:{:.4f}, gen_loss_L1:{:.4f}'.format(ep+1, epochs, i+1, len(load_Train), Disc_loss.item(), Gen_loss.item(), r_disc_loss.item(), f_disc_loss.item(), f_gan_loss.item(), L1_loss.item()))
            else:
                print('Epoch [{}/{}], Step [{}/{}], disc_loss: {:.4f}, gen_loss: {:.4f},Disc(real): {:.2f}, Disc(fake):{:.2f}, gen_loss_gan:{:.4f}, gen_loss_L1:{:.4f}'.format(ep+1, epochs, i+1, len(load_Train), 0, Gen_loss.item(), 0, 0, f_gan_loss.item(), L1_loss.item()))



    # After the final epoch, plot (input, generated, ground truth) for every
    # 10th sample of the FIRST training batch only (see `break` below). Each is
    # mapped back to image space first.
    # NOTE(review): Gen.forward keeps dropout active even in eval mode, and
    # these figures are neither shown nor closed explicitly.
    if (ep+1)==epochs:

        with torch.no_grad():
            Gen_model.eval()
            for data in load_Train:
                # print(data.shape)
                x, y = split(data.to(device))
                fake = Gen_model(x)
                for j in range(fake.shape[0]):
                    if j%10==0:
                        t_y=y[j].cpu().detach().numpy()
                        fk_batch=fake[j].cpu().detach().numpy()
                        t_x=x[j].cpu().detach().numpy()

                        # undo MinMaxScale
                        t_x=t_x*(max_value-min_value)+min_value
                        t_y=t_y*(max_value-min_value)+min_value
                        fk_batch=fk_batch*(max_value-min_value)+min_value

                        # undo log(x + 4)
                        t_x=np.exp(t_x)-4
                        t_y=np.exp(t_y)-4
                        fk_batch=np.exp(fk_batch)-4

                        # back to the spatial domain
                        t_x=inverseFCT(t_x)
                        t_y=inverseFCT(t_y)
                        fk_batch=inverseFCT(fk_batch)

                        plt.figure(figsize=(10, 4))

                        # Plotting the first array
                        plt.subplot(1, 3, 1)
                        plt.imshow(t_x)
                        plt.title('input')

                        # Plotting the second array
                        plt.subplot(1, 3, 2)
                        plt.imshow(fk_batch)
                        plt.title('generated image')

                        # Plotting the third array
                        plt.subplot(1, 3, 3)
                        plt.imshow(t_y)
                        plt.title('ground truth')

                # only the first batch is visualised
                break
            Gen_model.train()

    # Save the generator weights once training has finished.
    if ep+1==epochs:
        torch.save(Gen_model.state_dict(), 'MRI2CT_fourier_30epoch.pth')
            


In [None]:
# t_batch =  next(iter(load_Test))
# # print(t_batch.shape)
# t_x, t_y = split(t_batch)
# Test-set evaluation of the saved generator in image space.
# Bounds of the log(x + 4) space that MinMaxScale mapped onto [0, 1]. They are
# defined here so this cell does not rely on the training cell having been run.
min_value = math.log(-3.1416 + 4)
max_value = math.log(3677465 + 4)

error_list = []
generator.load_state_dict(torch.load("MRI2CT_fourier_30epoch.pth", map_location=device))
# NOTE(review): Gen.forward keeps dropout active (training=True) even in eval
# mode, so these errors vary from run to run.
generator.eval()

# One image per iteration. A dedicated name is used instead of reassigning the
# global `batch_size`: the model cell derives `inst_norm` from it, so setting it
# to 1 would switch the model to InstanceNorm if that cell is re-run.
eval_batch_size = 1
test_loader = DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False)

for data in tqdm(test_loader):
    t_x, t_y = split(data)

    with torch.no_grad():
        fk_batch = generator(t_x.to(device))

    # Undo MinMaxScale, then the log(x + 4) applied in MRIImage, then the FFT.
    t_y = t_y.numpy() * (max_value - min_value) + min_value
    fk_batch = fk_batch.cpu().numpy() * (max_value - min_value) + min_value
    t_y = inverseFCT(np.exp(t_y.squeeze()) - 4)
    fk_batch = inverseFCT(np.exp(fk_batch.squeeze()) - 4)

    # Root-MEAN-square error over this image's pixels. The previous code took
    # the root of the SUM (the L2 norm) while labelling it RMSE.
    error_list.append(np.sqrt(np.mean((t_y - fk_batch) ** 2)))

error_list = np.array(error_list)

print('mean rmse error :', np.mean(error_list))
print('max rmse error:', np.max(error_list))
print('min rmse error:', np.min(error_list))




In [None]:
def fft_visual(img):
    """Return the normalised log-magnitude spectrum of a ``(2, H, W)`` Fourier array.

    Parameters
    ----------
    img : array-like of shape (2, H, W)
        Channel 0 is the magnitude and channel 1 the phase (the channels-first
        layout built from ``FCT``). Only the magnitude is used.

    Returns
    -------
    ``log1p(magnitude)`` rescaled to [0, 1] by ``matplotlib.colors.Normalize``
    (a masked array), ready for ``imshow``.
    """
    magnitude = np.asarray(img)[0]
    # True magnitudes are >= 0, but de-normalised network outputs can fall below
    # zero. log1p of values < -1 is NaN, which would blank the whole plot.
    magnitude_log = np.log1p(np.clip(magnitude, 0, None))
    return Normalize()(magnitude_log)

In [None]:
# Visual check in the Fourier domain: log-magnitude spectra of the input, the
# generated image and the ground truth for the first `n_preview` test images.
min_value = math.log(-3.1416 + 4)
max_value = math.log(3677465 + 4)
n_preview = 10

# NOTE(review): Gen.forward keeps dropout on even in eval mode, so outputs vary.
generator.eval()
titles = ('log magnitude(input)', 'log magnitude(generated image)', 'log magnitude(ground truth)')

c = 0
for data in test_loader:
    if c == n_preview:
        break
    t_x, t_y = split(data)

    with torch.no_grad():
        fk_batch = generator(t_x.to(device))

    # Undo MinMaxScale, then the log(x + 4) from MRIImage. The offset must be 4
    # to invert that forward transform; this cell previously used 4.5, which
    # shifted every magnitude by -0.5 relative to the other cells.
    panels = []
    for arr in (t_x.cpu().numpy(), fk_batch.cpu().numpy(), t_y.cpu().numpy()):
        logged = arr * (max_value - min_value) + min_value
        panels.append(fft_visual(np.exp(logged.squeeze()) - 4))

    fig, axes = plt.subplots(1, 3, figsize=(10, 4))
    for ax, panel, title in zip(axes, panels, titles):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations
    c += 1

In [None]:
# Visual check in image space: input, generated image and ground truth for the
# first `n_preview` test images, each min-max normalised for display.
# The bounds are defined here so this cell does not depend on an earlier cell.
min_value = math.log(-3.1416 + 4)
max_value = math.log(3677465 + 4)
n_preview = 10

# NOTE(review): Gen.forward keeps dropout on even in eval mode, so outputs vary.
generator.eval()
titles = ('input', 'generated image', 'ground truth')

c = 0
for data in test_loader:
    if c == n_preview:
        break
    t_x, t_y = split(data)

    with torch.no_grad():
        fk_batch = generator(t_x.to(device))

    # Undo MinMaxScale and log(x + 4), then invert the Fourier transform.
    panels = []
    for arr in (t_x.cpu().numpy(), fk_batch.cpu().numpy(), t_y.cpu().numpy()):
        logged = arr * (max_value - min_value) + min_value
        spatial = inverseFCT(np.exp(logged.squeeze()) - 4)
        panels.append(Normalize()(spatial))

    fig, axes = plt.subplots(1, 3, figsize=(10, 4))
    for ax, panel, title in zip(axes, panels, titles):
        ax.imshow(panel, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations
    c += 1

In [None]:
# t_batch =  next(iter(load_Test))
# # print(t_batch.shape)
# t_x, t_y = split(t_batch)

# # generator.load_state_dict(torch.load('t1_to_t2_model_gaussian.pth', map_location=device))

# with torch.no_grad():
#     generator.eval()
#     fk=generator(t_x.to(device))
#     t_y=t_y.to(device)
#     # print(fk.shape)
# # compare_batches(t_x, fk_batch, t_y,"input images", "predicted images",  "ground truth")
# for j in range(fk.shape[0]):
#     if j%3==0:
#         figs=plt.figure(figsize=(12,8))
#         plt.subplot(1,2,1)
#         fake_np_rgb=np.transpose((fk[j].cpu().numpy()),(1,2,0))
#         fake_np_rgb=Normalize()(fake_np_rgb)
#         fake_np_gray=np.mean(fake_np_rgb, axis=2)
#         # fake_np_gray= (fake_np_gray*255).astype(np.uint8)
#         # print(fake_np_gray.shape)
#         # pixel_values = fake_np_gray.flatten()
#         # print(pixel_values.shape
# # Plot the histogram
#         plt.hist(fake_np_gray.flatten(), bins=256, range=[0, 1], density=True, color='red', alpha=0.5)
#         plt.yscale('log')
#         plt.xlabel('Intensity Value')
#         plt.ylabel('Frequency')
#         plt.title('Histogram of Grayscale Generated Image (Intensity [0, 1])')
#         # plt.show()
#         # figs=plt.figure(figsize=(10,10))
#         plt.subplot(1,2,2)
#         true_np_rgb=np.transpose((t_y[j].cpu().numpy()),(1,2,0))
#         true_np_rgb=Normalize()(true_np_rgb)
#         true_np_gray=np.mean(true_np_rgb, axis=2)
# # Plot the histogram
#         plt.hist(true_np_gray.flatten(), bins=256, range=[0, 1], density=True, color='gray', alpha=0.5)
#         plt.yscale('log')
#         plt.xlabel('Intensity Value')
#         plt.ylabel('Frequency')
#         plt.title('Histogram of Grayscale True Image (Intensity [0, 1])')
#         plt.show()
        # plt.axis("off")
        # plt.title("generated image histogram")

In [None]:
# # Assuming you have a dataset called 'test_dataset'
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# # Now you can iterate over it with random samples
# for _ in range(20):
#     t_batch, _ = next(iter(test_loader))
#     t_x, t_y = split(t_batch)

#     with torch.no_grad():
#         Gen.eval()
#         fk_batch = Gen(t_x.to(device))

#     compare_batches(t_x, fk_batch, t_y, "input images", "predicted images", "ground truth")
#     # compare_batches(t_x, fk_batch, "input images", "predicted images", t_y, "ground truth")



In [None]:
# # Assuming you have a dataset called 'test_dataset'
# batch_size = 1  # Set the batch size to 1 to get a single image in each iteration
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# error_list=[]
# structural_similarity=[]
# # Now you can iterate over it with random samples
# for i,data in enumerate (test_loader):
#     t_batch=data
#     t_x, t_y = split(t_batch)
#     with torch.no_grad():
#         generator.eval()
#         fk_batch = generator(t_x.to(device))
#     # if i%10==0:
#     #     compare_batches(t_x, fk_batch, t_y, "input images", "predicted images","ground truth")
#     t_y=t_y.numpy()
#     fk_batch=fk_batch.cpu().detach().numpy()
#     # print(t_y.shape)
#     # print(fk_batch.shape)
#     error_list.append(np.sqrt(np.sum((t_y-fk_batch)**2)))
#     t_y=np.transpose(t_y.squeeze(), (1, 2, 0))
#     fk_batch=np.transpose(fk_batch.squeeze(), (1, 2, 0))
#     # print(np.min(t_y))
#     # print(np.max(t_y))
#     # # print(fk_batch.shape)
#     ssi=ssim(t_y,fk_batch,multichannel=True,win_size=3,data_range=2.0)
#     structural_similarity.append(ssi)
#     # print(error_list)
# error_list=np.array(error_list)
# structural_similarity_np=np.array(structural_similarity)

# print('mean rmse error :',np.mean(error_list))
# print('max rmse error:',np.max(error_list))
# print('min rmse error:',np.min(error_list))
# print('mean ssim :',np.mean(structural_similarity_np))
# print('min ssim :',np.min(structural_similarity_np))
# print('max ssim :',np.max(structural_similarity_np))


In [None]:
# for i,data in enumerate (test_loader):
#     t_batch=data
#     t_x, t_y = split(t_batch)
#     with torch.no_grad():
#         generator.eval()
#         fk_batch = generator(t_x.to(device))
#     if structural_similarity[i]>np.mean(structural_similarity_np):
#         # print(i)
#         compare_batches(t_x, fk_batch, t_y, "input images", "predicted images","ground truth")
        