In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import random
import math
import io
from PIL import Image
from copy import deepcopy
from IPython.display import HTML
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Run on the GPU when one is visible, otherwise on the CPU. Device strings are
# case-sensitive: it must be "cpu" ("CPU" is rejected by torch.device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device='cpu'
from torch.utils.data import Dataset, DataLoader
import random
import time
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
import numpy as np
from matplotlib.colors import Normalize
import random
from torch.utils.data import DataLoader, random_split, Subset
import cv2
print(str(device))  # report which device the notebook will use

In [None]:
# Debugging aid: make CUDA kernel launches synchronous so errors are reported at
# the call that caused them (this slows GPU work down).
# NOTE(review): this is presumably only read when CUDA is initialised, so setting
# it after `import torch` / `torch.cuda.is_available()` may have no effect —
# consider moving it above the torch import, or removing it for real training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")

In [None]:
# Each data file is a single image holding the input (left half) and the target
# (right half) side by side; the test transform below resizes it to 256 x 512 (H x W).
path = '/home/scai/mtech/aib222677/scratch/data'
# Data root used below. Set the BRAIN_DATA_DIR environment variable to point at a
# different location instead of editing this machine-specific path.
path1 = os.environ.get('BRAIN_DATA_DIR', '/home/scai/mtech/aib222683/scratch/Task2/data')
# data_path_Train = os.path.join(path,'train') #Enter the train folder directory
# data_path_Test = os.path.join(path,'test') #Enter the test folder directory
data_path_Train = os.path.join(path1, 'train')
data_path_Test = os.path.join(path1, 'test')

batch_size = 30
num_workers = 2

# Mean = std = 0.5 per channel maps [0, 1] tensors to [-1, 1].
transform_train_x = transforms.Compose([transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # input half
transform_train_y = transforms.Compose([transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])  # target half

# Test images: PIL image -> resize to (256, 512) -> tensor -> normalise to [-1, 1].
transform = transforms.Compose([transforms.Resize((256, 512)),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

In [None]:
class BrainImage(Dataset):
    """Paired training dataset built from side-by-side input|target image files.

    Each image file in ``data_dir`` is loaded as RGB and converted to a tensor.
    The tensor is split at column 256 into an input half (left) and a target
    half (right). ``transform1`` / ``transform2`` are applied to the halves,
    which are then concatenated back along the width axis.

    Args:
        data_dir: directory containing the image files.
        transform1: optional callable applied to the input (left) half tensor.
        transform2: optional callable applied to the target (right) half tensor.
    """

    def __init__(self, data_dir, transform1=None, transform2=None):
        self.data_dir = data_dir
        self.transform1 = transform1
        self.transform2 = transform2
        # Skip sub-directories and hidden files (e.g. .ipynb_checkpoints, .DS_Store),
        # as BrainImageTest does; sort so index -> file mapping is reproducible.
        self.image_paths = sorted(
            os.path.join(data_dir, filename)
            for filename in os.listdir(data_dir)
            if not filename.startswith('.') and os.path.isfile(os.path.join(data_dir, filename))
        )
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale / RGBA / palette files give 3 channels, as the
        # 3-value Normalize transforms and 3-channel models require. The `with`
        # block closes the underlying file handle.
        with Image.open(self.image_paths[idx]) as pil_image:
            image = self._to_tensor(pil_image.convert('RGB'))

        img1 = image[:, :, :256]  # input half
        img2 = image[:, :, 256:]  # target half

        if self.transform1:
            img1 = self.transform1(img1)
        if self.transform2:
            img2 = self.transform2(img2)

        return torch.cat((img1, img2), dim=2)

In [None]:
class BrainImageTest(Dataset):
    """Test dataset returning each merged input|target image, optionally transformed.

    Only regular, non-hidden files in ``data_dir`` are used, in sorted order.

    Args:
        data_dir: directory containing the image files.
        transform: optional callable applied to the RGB PIL image.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # Sorted so the test order (and hence which images get visualised) is reproducible.
        self.image_paths = sorted(
            os.path.join(data_dir, filename)
            for filename in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, filename)) and not filename.startswith('.')
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale / RGBA / palette files give the 3 channels the
        # 3-value Normalize transform expects; `with` closes the file handle.
        with Image.open(self.image_paths[idx]) as pil_image:
            image = pil_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image

In [None]:

from torch.utils.data import DataLoader, random_split, Subset

# Load the entire training dataset (input and target halves normalised separately).
full_dataset = BrainImage(data_dir=data_path_Train, transform1=transform_train_x, transform2=transform_train_y)
train_dataset = full_dataset
# For 128x128 patches the generator's innermost BatchNorm sees a 1x1 feature map.
# A final batch holding a single sample would therefore make BatchNorm raise in
# training mode ("Expected more than 1 value per channel"). Drop that lone
# leftover sample only in that case; every other dataset size is unaffected.
drop_lone_sample = len(train_dataset) % batch_size == 1
# Create a data loader for training
load_Train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, drop_last=drop_lone_sample)

In [None]:
# Test split: every image in the test folder, resized and normalised by `transform`.
full_test_dataset = BrainImageTest(data_dir=data_path_Test, transform=transform)
test_dataset = full_test_dataset

# File order is kept (shuffle=False) so test batches are reproducible.
load_Test = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)


In [None]:
def split(img):
    """Split a merged (N, C, H, W) batch into its input and target halves.

    Columns [0, 256) are the input and columns [256, W) the target.

    Returns:
        (input_half, target_half) as views into ``img``.
    """
    boundary = 256
    input_half = img[:, :, :, :boundary]
    target_half = img[:, :, :, boundary:]
    return input_half, target_half

In [None]:
# Use InstanceNorm only when training with single-image batches; BatchNorm otherwise.
inst_norm = bool(batch_size == 1)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Bare nn.Conv2d (no normalisation, no activation)."""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding,
    )


def _norm_layer(num_features, inst_norm):
    """Normalisation layer shared by conv_n / tconv_n.

    Returns InstanceNorm2d when ``inst_norm == True``, otherwise BatchNorm2d;
    both use momentum=0.1 and eps=1e-5.
    """
    # `== True` (rather than truthiness) mirrors the original selection test exactly.
    norm_cls = nn.InstanceNorm2d if inst_norm == True else nn.BatchNorm2d  # noqa: E712
    return norm_cls(num_features, momentum=0.1, eps=1e-5)


def conv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, inst_norm=False):
    """nn.Conv2d followed by a normalisation layer, wrapped in nn.Sequential."""
    layer = conv(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
    return nn.Sequential(layer, _norm_layer(out_channels, inst_norm))


def tconv(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0,):
    """Bare nn.ConvTranspose2d (no normalisation, no activation)."""
    return nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding, output_padding=output_padding,
    )


def tconv_n(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, inst_norm=False):
    """nn.ConvTranspose2d followed by a normalisation layer, wrapped in nn.Sequential."""
    layer = tconv(in_channels, out_channels, kernel_size,
                  stride=stride, padding=padding, output_padding=output_padding)
    return nn.Sequential(layer, _norm_layer(out_channels, inst_norm))

In [None]:
dim_c = 3   # image channels
dim_g = 64  # base generator width

# U-Net generator for 128x128 patches, i.e. inputs shaped (N, 3, 128, 128).
# Seven stride-2 convolutions reduce 128 -> 1. Seven stride-2 transposed
# convolutions bring it back to 128, concatenating the matching encoder feature
# map (skip connection) after each of the first six.
class Gen(nn.Module):
    def __init__(self, inst_norm=False):
        super(Gen, self).__init__()
        # Encoder (no normalisation on the first layer).
        self.n1 = conv(dim_c, dim_g, 4, 2, 1)
        self.n2 = conv_n(dim_g, dim_g*2, 4, 2, 1, inst_norm=inst_norm)
        self.n3 = conv_n(dim_g*2, dim_g*4, 4, 2, 1, inst_norm=inst_norm)
        self.n4 = conv_n(dim_g*4, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n5 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n6 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.n7 = conv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        # self.n8 = conv(dim_g*8, dim_g*8, 4, 2, 1)

        # Decoder. Input channels = previous decoder output + concatenated skip.
        # The skip widths are derived from dim_g (they were hard-coded as
        # 256/128/64, which only holds for dim_g == 64).
        self.m1 = tconv_n(dim_g*8, dim_g*8, 4, 2, 1, inst_norm=inst_norm)
        self.m2 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)          # m1 out + n6
        self.m3 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)          # m2 out + n5
        self.m4 = tconv_n(dim_g*8*2, dim_g*8, 4, 2, 1, inst_norm=inst_norm)          # m3 out + n4
        self.m5 = tconv_n(dim_g*8 + dim_g*4, dim_g*4, 4, 2, 1, inst_norm=inst_norm)  # m4 out + n3
        self.m6 = tconv_n(dim_g*4 + dim_g*2, dim_g*2, 4, 2, 1, inst_norm=inst_norm)  # m5 out + n2
        # NOTE(review): the output layer is normalised before tanh (tconv_n, not
        # tconv). It is kept as-is so saved state_dicts keep loading.
        self.m7 = tconv_n(dim_g*2 + dim_g, dim_c, 4, 2, 1, inst_norm=inst_norm)      # m6 out + n1
        # self.m8 = tconv(dim_g*1*2, dim_c, 4, 2, 1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map an input patch batch to a generated batch in [-1, 1] (tanh output).

        Dropout on the first three decoder layers uses ``training=True``, so it
        stays active even after ``.eval()`` / under ``no_grad``. The generator
        is therefore stochastic at inference time (presumably intentional, as
        noise injection).
        """
        n1 = self.n1(x)
        n2 = self.n2(F.leaky_relu(n1, 0.2))
        n3 = self.n3(F.leaky_relu(n2, 0.2))
        n4 = self.n4(F.leaky_relu(n3, 0.2))
        n5 = self.n5(F.leaky_relu(n4, 0.2))
        n6 = self.n6(F.leaky_relu(n5, 0.2))
        n7 = self.n7(F.leaky_relu(n6, 0.2))

        m1 = torch.cat([F.dropout(self.m1(F.relu(n7)), 0.5, training=True), n6], 1)
        m2 = torch.cat([F.dropout(self.m2(F.relu(m1)), 0.5, training=True), n5], 1)
        m3 = torch.cat([F.dropout(self.m3(F.relu(m2)), 0.5, training=True), n4], 1)
        m4 = torch.cat([self.m4(F.relu(m3)), n3], 1)
        m5 = torch.cat([self.m5(F.relu(m4)), n2], 1)
        m6 = torch.cat([self.m6(F.relu(m5)), n1], 1)
        m7 = self.m7(F.relu(m6))

        return self.tanh(m7)



In [None]:
dim_d = 64  # base discriminator width

# PatchGAN-style discriminator. It scores the channel-wise concatenation of two
# dim_c-channel images and returns a map of per-patch sigmoid probabilities.
# For 128x128 inputs the map is 30x30: the two stride-2 convs give
# 128 -> 64 -> 32, and the two 4x4 stride-1 pad-1 convs give 32 -> 31 -> 30.
class Disc(nn.Module):
    def __init__(self, inst_norm=False):
        super(Disc, self).__init__()
        self.c1 = conv(dim_c * 2, dim_d, 4, 2, 1)                          # no norm on the first layer
        self.c2 = conv_n(dim_d, dim_d * 2, 4, 2, 1, inst_norm=inst_norm)
        self.c4 = conv_n(dim_d * 2, dim_d * 4, 4, 1, 1, inst_norm=inst_norm)
        self.c5 = conv(dim_d * 4, 1, 4, 1, 1)                              # 1-channel score map
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        """Score the pair (x, y), concatenated along the channel axis."""
        hidden = torch.cat([x, y], dim=1)
        for layer in (self.c1, self.c2, self.c4):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return self.sigmoid(self.c5(hidden))

def weights_init(z):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    * Conv* / Linear modules: weight ~ N(0, 0.02), bias = 0.
    * BatchNorm* modules: weight ~ N(1, 0.02), bias = 0.
    * Anything else is left untouched.

    Parameters that do not exist are skipped instead of raising AttributeError
    on ``None.data``. Examples: a conv built with ``bias=False``, or BatchNorm
    with ``affine=False``.
    """
    cls_name = z.__class__.__name__
    if 'Conv' in cls_name or 'Linear' in cls_name:
        weight_mean = 0.0
    elif 'BatchNorm' in cls_name:
        weight_mean = 1.0
    else:
        return

    weight = getattr(z, 'weight', None)
    bias = getattr(z, 'bias', None)
    if weight is not None:
        nn.init.normal_(weight.data, weight_mean, 0.02)
    if bias is not None:
        nn.init.constant_(bias.data, 0)

In [None]:
BCE = nn.BCELoss()  # binary cross-entropy
L1 = nn.L1Loss()    # L1 (mean absolute error) loss
L2 = nn.MSELoss()   # L2 (mean squared error) loss

# `inst_norm` selects InstanceNorm2d (True) or BatchNorm2d (False) inside the models.
# The assignment below rebinds the name `Disc` from the class to an instance.
# Recover the class first so that re-running this cell builds a fresh model
# instead of calling the previous instance (which raised TypeError).
_disc_cls = Disc if isinstance(Disc, type) else type(Disc)
Gen_model = Gen(inst_norm).to(device)
Disc = _disc_cls(inst_norm).to(device)
generator = Gen(inst_norm).to(device)  # separate generator instance (not trained in this cell)

# optimizers
# Gen_optim = optim.Adam(Gen.parameters(), lr=2e-4, betas=(0.5, 0.999), weight_decay=0.35)
Gen_optim = optim.Adam(Gen_model.parameters(), lr=2e-4, betas=(0.5, 0.999))   #lr=1e-4
Disc_optim = optim.Adam(Disc.parameters(), lr=2e-4, betas=(0.5, 0.999))   # lr=5e-5

In [None]:
# Train the generator / discriminator on 128x128 patches cropped from each
# 256x256 input/target pair. Every 5 epochs, show results on a few training samples.

iter_per_plot = 10   # print losses every this many batches
epochs = 10
L1_lambda = 100.0    # weight of the L1 reconstruction term in the generator loss

# (row_slice, col_slice) windows cropped from each 256x256 half: the four
# quadrants, then a centred window. When outputs are stitched back together the
# centred patch is pasted last, so it overwrites the seams between quadrants.
PATCH_WINDOWS = [
    (slice(None, 128), slice(None, 128)),
    (slice(128, None), slice(None, 128)),
    (slice(None, 128), slice(128, None)),
    (slice(128, None), slice(128, None)),
    (slice(64, 192), slice(64, 192)),
]


def _crop_patches(x_img, y_img):
    """Return one (input_patch, target_patch) batch per entry of PATCH_WINDOWS.

    Slicing the whole (N, C, H, W) batch at once gives the same values as
    stacking the per-image crops.
    """
    return [(x_img[:, :, rows, cols], y_img[:, :, rows, cols]) for rows, cols in PATCH_WINDOWS]


def _stitch_patches(patch_outputs, like):
    """Paste per-window outputs into zeros shaped like `like` (same shape, dtype, device)."""
    stitched = torch.zeros_like(like)
    for (rows, cols), patch in zip(PATCH_WINDOWS, patch_outputs):
        stitched[:, :, rows, cols] = patch
    return stitched


def _show_training_samples(x_img, fake_img, y_img, every=10):
    """Plot input / generated / ground truth side by side for every `every`-th sample."""
    for j in range(0, fake_img.shape[0], every):
        plt.figure(figsize=(10, 10))
        panels = (("input image", x_img[j]), ("generated image", fake_img[j]), ("ground truth", y_img[j]))
        for k, (title, img) in enumerate(panels, start=1):
            plt.subplot(1, 3, k)
            plt.axis("off")
            plt.title(title)
            plt.imshow(np.transpose(vutils.make_grid(img, nrow=1, padding=5,
                                                     normalize=True).cpu(), (1, 2, 0)))
        plt.show()  # render now and release the figure instead of accumulating them


for ep in range(epochs):
    for i, data in enumerate(load_Train):
        size = data.shape[0]
        x_img, y_img = split(data.to(device))
        # Targets for the discriminator's 30x30 output map on a 128x128 patch.
        r_masks = torch.ones(size, 1, 30, 30, device=device)
        f_masks = torch.zeros(size, 1, 30, 30, device=device)

        for x, y in _crop_patches(x_img, y_img):
            # Discriminator step: real pairs (y, x) -> 1, generated pairs (fake, x) -> 0.
            Disc.zero_grad()
            r_patch = Disc(y, x)
            r_disc_loss = L2(r_patch, r_masks)

            fake = Gen_model(x)
            f_patch = Disc(fake.detach(), x)  # detach: this step must not update the generator
            f_disc_loss = L2(f_patch, f_masks)

            Disc_loss = r_disc_loss + f_disc_loss
            Disc_loss.backward()
            Disc_optim.step()

            # Generator step: push Disc(fake, x) towards "real" and stay close to y in L1.
            Gen_model.zero_grad()
            f_patch = Disc(fake, x)
            f_gan_loss = L2(f_patch, r_masks)
            L1_loss = L1(fake, y)
            Gen_loss = f_gan_loss + L1_lambda * L1_loss
            Gen_loss.backward()
            Gen_optim.step()

        if (i + 1) % iter_per_plot == 0:
            # Reported values are the losses of the last patch of this batch.
            print('Epoch [{}/{}], Step [{}/{}], disc_loss: {:.4f}, gen_loss: {:.4f}, '
                  'disc_loss_real: {:.4f}, disc_loss_fake: {:.4f}, gen_loss_gan: {:.4f}, gen_loss_L1: {:.4f}'
                  .format(ep + 1, epochs, i + 1, len(load_Train), Disc_loss.item(), Gen_loss.item(),
                          r_disc_loss.item(), f_disc_loss.item(), f_gan_loss.item(), L1_loss.item()))

    if (ep + 1) % 5 == 0:
        # Show results on the first (shuffled) training batch. The canvas is sized
        # from the batch itself, so batches smaller than 30 work too.
        data = next(iter(load_Train), None)
        if data is not None:
            Gen_model.eval()
            try:
                with torch.no_grad():
                    x_img, y_img = split(data.to(device))
                    fake_img = _stitch_patches(
                        [Gen_model(x) for x, _ in _crop_patches(x_img, y_img)], x_img)
                    _show_training_samples(x_img, fake_img, y_img)
            finally:
                Gen_model.train()  # restore training mode even if plotting fails

    if ep + 1 == epochs:
        torch.save(Gen_model.state_dict(), 'patchmodel.pth')  # for saving the model after training
            


In [None]:
#### Visualising generated images and per-pixel error maps on the test set

batch_size = 1  # one image per iteration; the display code below indexes sample [0]
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
c = 0  # number of test images displayed so far

generator.load_state_dict(torch.load('patchmodel.pth', map_location=device))  # load the saved generator weights


def _vis_panel(img):
    """CHW tensor -> min/max-normalised HWC image rotated 90 degrees clockwise, for imshow."""
    grid = vutils.make_grid(img, nrow=1, padding=5, normalize=True).cpu()
    return np.rot90(np.transpose(grid, (1, 2, 0)), k=-1)


generator.eval()  # BatchNorm uses running stats; Gen.forward keeps its dropout active regardless
with torch.no_grad():
    for i, data in enumerate(test_loader):
        if i % 10 != 0:  # show every 10th test image
            continue
        c += 1
        # BrainImageTest yields the merged input|target tensor directly (not a
        # dict), so the original `data['image1']` lookup failed.
        t_batch = data['image1'] if isinstance(data, dict) else data
        x_img, y_img = split(t_batch.to(device))

        # Generate the four 128x128 quadrants and the centred 128x128 window, then
        # stitch them back together; the centre is pasted last, covering the seams.
        output = [generator(patch) for patch in (
            x_img[:, :, :128, :128],
            x_img[:, :, 128:, :128],
            x_img[:, :, :128, 128:],
            x_img[:, :, 128:, 128:],
            x_img[:, :, 64:192, 64:192],
        )]
        fake_img = torch.zeros_like(y_img)
        fake_img[:, :, :128, :128] = output[0]
        fake_img[:, :, 128:, :128] = output[1]
        fake_img[:, :, :128, 128:] = output[2]
        fake_img[:, :, 128:, 128:] = output[3]
        fake_img[:, :, 64:192, 64:192] = output[4]

        plt.figure(figsize=(10, 10))
        panels = (("input image", x_img[0]), ("generated image", fake_img[0]), ("ground truth", y_img[0]))
        for k, (title, img) in enumerate(panels, start=1):
            plt.subplot(1, 4, k)
            plt.axis("off")
            plt.title(title)
            plt.imshow(_vis_panel(img))

        # Error map: |target - generated|, min/max-normalised, averaged over channels.
        plt.subplot(1, 4, 4)
        plt.axis("off")
        plt.title("error map")
        data_eg = np.transpose(np.abs((y_img[0] - fake_img[0]).cpu().numpy()), (1, 2, 0))
        normalized_data = Normalize()(data_eg)
        normalized_data = np.mean(normalized_data, axis=2, keepdims=True)
        plt.imshow(np.rot90(normalized_data, k=-1), cmap='jet')
        plt.show()

        if c == 20:
            break

In [None]:

def calculate_psnr(true_image, fake_image, data_range):
    """Peak signal-to-noise ratio in dB.

    PSNR = 20 * log10(data_range / sqrt(MSE)), where ``data_range`` is the span
    of possible pixel values. Returns +inf when the two images are identical.
    """
    mean_sq_err = np.mean((true_image - fake_image) ** 2)
    return float('inf') if mean_sq_err == 0 else 20 * np.log10(data_range / np.sqrt(mean_sq_err))

import inspect

# Per-image RMSE, SSIM and PSNR of the stitched generator output vs. the target.
error_list = []
structural_similarity_list = []
psnr_list = []

# scikit-image >= 0.19 takes `channel_axis`; older releases only accept `multichannel`
# (deprecated, and since removed).
_ssim_channel_kw = (
    {'channel_axis': -1}
    if 'channel_axis' in inspect.signature(ssim).parameters
    else {'multichannel': True}
)
# Pixels are normalised to [-1, 1], so the dynamic range is 2.0 for both SSIM and PSNR
# (PSNR previously used 255, which inflated it by about 42 dB).
PIXEL_RANGE = 2.0

generator.eval()
for i, data in enumerate(test_loader):
    # BrainImageTest yields the merged input|target tensor directly, not a dict.
    t_batch = data['image1'] if isinstance(data, dict) else data
    with torch.no_grad():
        x_img, y_img = split(t_batch.to(device))

        # Four 128x128 quadrants plus a centred window; the centre is pasted last.
        outputs = [generator(patch) for patch in (
            x_img[:, :, :128, :128],
            x_img[:, :, 128:, :128],
            x_img[:, :, :128, 128:],
            x_img[:, :, 128:, 128:],
            x_img[:, :, 64:192, 64:192],
        )]

        # Canvas shaped like the target half. zeros_like(t_batch) was the full
        # 512-wide merged image and could not be compared with y_img.
        fake_img = torch.zeros_like(y_img)
        fake_img[:, :, :128, :128] = outputs[0]
        fake_img[:, :, 128:, :128] = outputs[1]
        fake_img[:, :, :128, 128:] = outputs[2]
        fake_img[:, :, 128:, 128:] = outputs[3]
        fake_img[:, :, 64:192, 64:192] = outputs[4]

    t_y_numpy = y_img.cpu().numpy()
    fk_batch_numpy = fake_img.cpu().numpy()

    # One entry per image, so the metrics stay correct for any loader batch size.
    for true_chw, fake_chw in zip(t_y_numpy, fk_batch_numpy):
        error_list.append(np.sqrt(np.mean((true_chw - fake_chw) ** 2)))  # RMSE

        true_hwc = np.transpose(true_chw, (1, 2, 0))
        fake_hwc = np.transpose(fake_chw, (1, 2, 0))
        structural_similarity_list.append(
            ssim(true_hwc, fake_hwc, win_size=3, data_range=PIXEL_RANGE, **_ssim_channel_kw))
        psnr_list.append(calculate_psnr(true_hwc, fake_hwc, data_range=PIXEL_RANGE))

# Convert lists to numpy arrays for statistics calculation
error_list = np.array(error_list)
structural_similarity_np = np.array(structural_similarity_list)
psnr_np = np.array(psnr_list)

# Print statistics
print('Mean RMSE error:', np.mean(error_list))
print('Max RMSE error:', np.max(error_list))
print('Min RMSE error:', np.min(error_list))
print('Mean SSIM:', np.mean(structural_similarity_np))
print('Min SSIM:', np.min(structural_similarity_np))
print('Max SSIM:', np.max(structural_similarity_np))
print('Mean PSNR:', np.mean(psnr_np))
print('Min PSNR:', np.min(psnr_np))
print('Max PSNR:', np.max(psnr_np))