## Legacy

In [None]:


def EncLayer(self,Ich,Och,k_size=3,p=1,p_m='reflect',s=2,relu=True,drop=None,norm=True):
    '''
    @Description
      Legacy helper that builds a "dk" encoder (downsampling) layer:
      Conv2d -> optional InstanceNorm2d -> ReLU/LeakyReLU(0.2) -> optional Dropout.
      NOTE(review): `self` is accepted but never used, and this free function
      is shadowed later in the file by the `EncLayer` nn.Module class.

    @Inputs
      self: unused (kept for signature compatibility).
      Ich: input channels, int.
      Och: output channels, int.
      k_size: kernel size, default 3.
      p: padding, default 1.
      p_m: padding mode passed to nn.Conv2d, default 'reflect'.
      s: stride, default 2.
      relu: True -> nn.ReLU(); False -> nn.LeakyReLU(0.2).
      drop: dropout probability; a falsy value (None/0) adds no dropout layer.
      norm: True adds nn.InstanceNorm2d(Och) after the convolution.

    @Outputs
      nn.Sequential with submodules (in order) "conv1", "Instancenorm"
      (optional), "activation", "Dropout" (optional).
    '''
    m = nn.Sequential()
    m.add_module("conv1", nn.Conv2d(Ich, Och, kernel_size=k_size, padding=p,
                                    stride=s, padding_mode=p_m))
    if norm:
        m.add_module("Instancenorm", nn.InstanceNorm2d(Och))
    # Pick the activation first and register it once; the former
    # `a(...) if relu else b(...)` statement hid the branch in an expression.
    activation = nn.ReLU() if relu else nn.LeakyReLU(0.2)
    m.add_module("activation", activation)
    if drop:
        m.add_module("Dropout", nn.Dropout(drop))
    return m


def DecLayer(self,Ich,Och,k_size=3,p=1,s=2,relu=True,drop=None,norm=True,upsampling=None,op=0):
    '''
    @Description
      Legacy helper that builds a "uk" decoder (upsampling) layer:
      optional Upsample -> ConvTranspose2d -> optional InstanceNorm2d ->
      ReLU/LeakyReLU(0.2) -> optional Dropout.
      NOTE(review): `self` is accepted but never used, and this free function
      is shadowed later in the file by the `DecLayer` nn.Module class.

    @Inputs
      self: unused (kept for signature compatibility).
      Ich: input channels, int.
      Och: output channels, int.
      k_size: kernel size, default 3.
      p: padding, default 1.
      s: stride, default 2.
      relu: True -> nn.ReLU(); False -> nn.LeakyReLU(0.2).
      drop: dropout probability; a falsy value (None/0) adds no dropout layer.
      norm: True adds nn.InstanceNorm2d(Och) after the transposed convolution.
      upsampling: scale factor of an nn.Upsample placed before the transposed
                  convolution; a falsy value (None/0) adds none.
      op: output_padding of the transposed convolution, default 0 (the
          previous behaviour). With k_size=3, s=2, p=1 the output side is
          2H-1 for op=0 and exactly 2H for op=1.

    @Outputs
      nn.Sequential with submodules (in order) "upsampling" (optional),
      "dconv1", "Instancenorm" (optional), "activation", "Dropout" (optional).
    '''
    m = nn.Sequential()
    if upsampling:
        m.add_module("upsampling", nn.Upsample(scale_factor=upsampling))
    m.add_module("dconv1", nn.ConvTranspose2d(Ich, Och, kernel_size=k_size,
                                              padding=p, stride=s,
                                              output_padding=op))
    if norm:
        m.add_module("Instancenorm", nn.InstanceNorm2d(Och))
    # Select the activation, then register it once (no side-effect ternary).
    activation = nn.ReLU() if relu else nn.LeakyReLU(0.2)
    m.add_module("activation", activation)
    if drop:
        m.add_module("Dropout", nn.Dropout(drop))
    return m




  
class DecLayer(nn.Module):
    '''
    @Description
      "uk" decoder (upsampling) layer:
      ConvTranspose2d -> optional InstanceNorm2d -> activation.

    @Inputs
      Ich: input channels, int.
      Och: output channels, int.
      nor: True adds nn.InstanceNorm2d(Och) after the transposed convolution.
      k: kernel size, default 3.
      s: stride, default 2.
      p: padding, default 1.
      op: output_padding, default 1. With k=3, s=2, p=1, op=1 the output is
          exactly twice the input size: (H - 1)*2 - 2*1 + 3 + 1 = 2H.
      act: 'relu' (default) uses nn.ReLU(); any other value uses
           nn.LeakyReLU(0.2).

    @Outputs
      An nn.Module; forward(x) returns the activated feature map.
    '''
    def __init__(self,Ich,Och,nor=True,k=3, s=2, p=1,op=1,act='relu'):
        super(DecLayer, self).__init__()
        self.conv1 = nn.ConvTranspose2d(Ich, Och, kernel_size=k, stride=s,
                                        padding=p, output_padding=op)
        if nor:
            self.instancenorm = nn.InstanceNorm2d(Och)
        self.nor = nor
        self.activation = nn.ReLU() if act == 'relu' else nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.conv1(x)
        if self.nor:
            x = self.instancenorm(x)
        return self.activation(x)

class EncLayer(nn.Module):
    '''
    @Description
      "dk" encoder (downsampling) layer:
      Conv2d -> optional InstanceNorm2d -> activation.

    @Inputs
      Ich: input channels, int.
      Och: output channels, int.
      nor: True adds nn.InstanceNorm2d(Och) after the convolution.
      k: kernel size, default 3.
      p: padding, default 1.
      s: stride, default 2.
      act: 'relu' (default) uses nn.ReLU(); any other value uses
           nn.LeakyReLU(0.2).
      p_m: padding mode passed to nn.Conv2d, default 'reflect'.

    @Outputs
      An nn.Module; forward(x) returns the activated feature map.
    '''
    def __init__(self,Ich,Och,nor=True, k=3,p=1,s=2,act='relu',p_m='reflect'):
        super(EncLayer, self).__init__()
        self.conv1 = nn.Conv2d(Ich, Och, kernel_size=k, padding=p, stride=s,
                               padding_mode=p_m)
        self.activation = nn.ReLU() if act == 'relu' else nn.LeakyReLU(0.2)
        if nor:
            self.instancenorm = nn.InstanceNorm2d(Och)
        self.nor = nor

    def forward(self, x):
        x = self.conv1(x)
        if self.nor:
            x = self.instancenorm(x)
        return self.activation(x)



## Install and import packages

In [3]:
#!pip install --q light-the-torch && ltt install torch
# Extra packages for this notebook, installed quietly (-q) via IPython's `!`
# shell escape: lightning (training loop), split-folders, fastkde,
# torchmetrics (SSIM/PSNR), pytorch-fid (FID score) and tensorboard (logger).
!pip install lightning -q
!pip install split-folders -q
!pip install fastkde -q
!pip install torchmetrics -q
!pip install pytorch-fid -q
!pip install tensorboard -q


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m68.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.8/57.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.7/129.7 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.4/66.4 kB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.9/66.9 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 kB[0m [31m46.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 kB[0m [31m51.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
#Packages to manage data
import os
import random
from pathlib import Path
import gzip
import glob
import os.path
import nibabel as nib
import numpy as np
import pandas as pd
#from glob import glob
import gzip
import shutil
import splitfolders
import re

#Packages to manage Image and compute histograms
import fnmatch
import matplotlib.pyplot as plt
from scipy import stats
from fastkde.fastKDE import pdf
from PIL import Image
import cv2
import nibabel as nib
plt.rcParams["figure.figsize"] = (10, 10)


#Package for Lightning
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping,ModelCheckpoint
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.loggers import TensorBoardLogger


#Package for torch
import torch
from torch import nn
import torchvision.models as models
from pytorch_fid import fid_score
import torchvision.utils as vutils
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics import PeakSignalNoiseRatio as PSNR
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F

#Other packages
from skimage import color
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from time import time
import multiprocessing as mp
from collections import OrderedDict




# Important variables and constants
seed = 478  # global RNG seed, later passed to seed_everything()

In [5]:
# Colab only: mount Google Drive at /content/drive so that the
# /content/drive/MyDrive/TFM/... data paths used below are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Pre-Cleaning - local


Dataset1: https://openneuro.org/datasets/ds002330/versions/1.1.0   
Dataset2: https://openneuro.org/datasets/ds002382/versions/1.0.1   
Dataset3: https://www.kaggle.com/datasets/awsaf49/brats20-dataset-training-validation

In [None]:
local = False


def up_one_dir(path):
    '''
    Move the file or directory at `path` into its grandparent directory,
    Path(path).parents[1].

    Example: .../sub-01/anat/sub-01_T1w.nii.gz -> .../sub-01/sub-01_T1w.nii.gz
    A path without a grandparent is left where it is.
    '''
    try:
        parent_dir = Path(path).parents[1]
    except IndexError:
        # No upper directory to move into.
        return
    shutil.move(path, parent_dir)


def Unzip(source_dir, dest_dir):
    '''
    Decompress every "*.gz" file directly inside `source_dir` into `dest_dir`,
    dropping the ".gz" extension (e.g. x_T1w.nii.gz -> x_T1w.nii).
    '''
    for src_name in glob.glob(os.path.join(source_dir, '*.gz')):
        dest_name = os.path.join(dest_dir, os.path.basename(src_name)[:-3])
        with gzip.open(src_name, 'rb') as infile, open(dest_name, 'wb') as outfile:
            # Copy in fixed-size chunks; iterating a binary stream "by line"
            # splits on arbitrary b"\n" bytes.
            shutil.copyfileobj(infile, outfile)


if local:
    # Raw strings: these Windows paths only worked because "\D", "\s", "\B"
    # and "\p" happen to be invalid escapes (a SyntaxWarning on 3.12+).
    def _n_subdirs(root):
        # Subjects are assumed to be numbered 1..N with N = number of sub-dirs.
        return sum(1 for entry in os.scandir(root) if entry.is_dir())

    ######################### Move Directory
    # Flatten sub-XX/anat/<file> into sub-XX/<file>.
    for i in range(1, _n_subdirs(r"D:\Dataset1") + 1):
        up_one_dir(r"D:\Dataset1\sub-%02d\anat\sub-%02d_T1w.nii.gz" % (i, i))
        up_one_dir(r"D:\Dataset1\sub-%02d\anat\sub-%02d_T2w.nii.gz" % (i, i))

    for i in range(1, _n_subdirs(r"D:\Dataset2") + 1):
        up_one_dir(r"D:\Dataset2\sub-%02d\anat\sub-%02d_T1w.nii.gz" % (i, i))
        up_one_dir(r"D:\Dataset2\sub-%02d\anat\sub-%02d_T2w.nii.gz" % (i, i))
        up_one_dir(r"D:\Dataset2\sub-%02d\anat\sub-%02d_T1w.json" % (i, i))
        up_one_dir(r"D:\Dataset2\sub-%02d\anat\sub-%02d_T2w.json" % (i, i))

    ######################### unzip
    for i in range(1, _n_subdirs(r"D:\Dataset1") + 1):
        Unzip(r"D:\Dataset1\sub-%02d" % i, r"D:\Dataset1\sub-%02d" % i)

    for i in range(1, _n_subdirs(r"D:\Dataset2") + 1):
        Unzip(r"D:\Dataset2\sub-%02d" % i, r"D:\Dataset2\sub-%02d" % i)

    ######################### rename Dataset3 (BraTS) to the sub-XXX layout
    n_d3 = _n_subdirs(r"D:\Dataset3")
    for i in range(1, n_d3 + 1):
        os.rename(r"D:\Dataset3\BraTS20_Training_%03d" % i, r"D:\Dataset3\sub-%03d" % i)

    for i in range(1, n_d3 + 1):
        os.rename(r"D:\Dataset3\sub-%03d\BraTS20_Training_%03d_t1.nii" % (i, i),
                  r"D:\Dataset3\sub-%03d\sub-%03d_T1w.nii" % (i, i))
        os.rename(r"D:\Dataset3\sub-%03d\BraTS20_Training_%03d_t2.nii" % (i, i),
                  r"D:\Dataset3\sub-%03d\sub-%03d_T2w.nii" % (i, i))

    ######################### participant metadata
    dataset1 = pd.read_csv(r"D:\Dataset1\participants.tsv", sep='\t',
                           usecols=["participant_id", "age", "sex"])
    dataset1["Dataset"] = "D1"
    d = {"female": "F", "male": "M", "Other": "NaN"}
    dataset1.replace({"sex": d}, inplace=True)
    dataset1["age"] = dataset1["age"].astype("int")

    dataset2 = pd.read_csv(r"D:\Dataset2\participants.tsv", sep='\t',
                           usecols=["participant_id", "age", "sex"])
    dataset2["Dataset"] = "D2"
    dataset2["age"] = dataset2["age"].astype("int")

    dataset3 = pd.read_csv(r"D:\Dataset3\survival_info.csv", usecols=["Brats20ID", "Age"])
    dataset3.rename(columns={"Brats20ID": "participant_id", "Age": "age"}, inplace=True)
    dataset3["Dataset"] = "D3"
    dataset3["sex"] = "NaN"
    # "BraTS20_Training_001" -> "sub-001", matching the renamed folders.
    dataset3["participant_id"] = ["sub-" + pid.split("_")[2] for pid in dataset3.participant_id]
    dataset3["age"] = dataset3["age"].astype("int")

    # Previously [dataset1, dataset3, dataset3]: D3 was duplicated and D2 lost.
    dataset = pd.concat([dataset1, dataset2, dataset3])
    dataset.to_csv(r"D:\Dataset.csv", index=False)

## Get data

In [None]:
getdata=False

# (start, stop) slice indices to export, keyed by source dataset id (1, 2, 3).
SLICE_BOUNDS = {1: (135, 309), 2: (110, 260), 3: (30, 140)}


def volume_slicer(dataset_id, i):
    '''
    Return the index tuple selecting 2-D slice `i` of a volume from dataset
    `dataset_id`: dataset 3 is sliced along the last axis, datasets 1 and 2
    along axis 1 (as in the original per-dataset export branches).
    '''
    if dataset_id == 3:
        return (slice(None), slice(None), i)
    return (slice(None), i, slice(None))


if getdata:
    dataset_ids = range(3, 4)
    T1 = ["/content/drive/MyDrive/TFM/Dataset%d/**/*_T1w.nii" % (i) for i in dataset_ids]
    T2 = ["/content/drive/MyDrive/TFM/Dataset%d/**/*_T2w.nii" % (i) for i in dataset_ids]
    # Look bounds up by dataset id. Zipping the fixed 3-element list against a
    # partial id range paired Dataset3 with Dataset1's (135, 309).
    bound = [SLICE_BOUNDS[i] for i in dataset_ids]

    Root_T1 = "/content/drive/MyDrive/TFM/T1"
    Root_T2 = "/content/drive/MyDrive/TFM/T2"

    for d_id, t1, t2, b in zip(dataset_ids, T1, T2, bound):
        # glob order is arbitrary; sort so the k-th T1 pairs with the k-th T2.
        T1_dir = sorted(glob.glob(t1, recursive=True))
        T2_dir = sorted(glob.glob(t2, recursive=True))

        for t1_dir, t2_dir in zip(T1_dir, T2_dir):
            T1_load = nib.load(t1_dir)
            T2_load = nib.load(t2_dir)
            # Subject number = first digit run of the subject folder name.
            s = int(re.search(r'\d+', Path(t1_dir).parent.name).group(0))
            for i in range(b[0], b[1]):
                idx = volume_slicer(d_id, i)
                np.save(os.path.join(Root_T1, "D%d_%03d_%03d_T1" % (d_id, s, i)),
                        T1_load.dataobj[idx].astype(float))
                np.save(os.path.join(Root_T2, "D%d_%03d_%03d_T2" % (d_id, s, i)),
                        T2_load.dataobj[idx].astype(float))
            print("Subject {}".format(s))

## Functions

In [6]:
def HistFunc(axis,filePath,color=None,alpha=0.05):
    '''
    Plot a kernel-density estimate of the voxel intensities of one volume.

    About 60% of the slices along the last axis are drawn at random (with
    replacement). Each slice is resized to 100x100 with nearest-neighbour
    interpolation, and scipy's gaussian_kde of the pooled values is drawn on
    `axis` at 100 evenly spaced points between their min and max.

    @Inputs
      axis: matplotlib Axes to draw on.
      filePath: path to a 3-D volume readable by nibabel.
      color: line colour; None means black.
      alpha: line transparency, default 0.05.
    '''
    data = nib.load(filePath).get_fdata()
    IdxSample = np.random.randint(data.shape[-1], size=int(data.shape[-1] * 0.6))
    # Move the sampled-slice axis to the front so iteration yields real 2-D
    # slices. The previous reshape(shape[::-1]) only reinterpreted the buffer
    # and mixed pixels from different slices into each "slice".
    slices = np.ascontiguousarray(np.moveaxis(data[:, :, IdxSample], -1, 0))

    def resize(slice_2d):
        return cv2.resize(slice_2d, dsize=(100, 100), interpolation=cv2.INTER_NEAREST)

    values = np.ravel(np.array([resize(s) for s in slices]))

    kernel = stats.gaussian_kde(values)
    positions = np.linspace(values.min(), values.max(), num=100)
    histogram = kernel(positions)

    kwargs = dict(linewidth=1, color='black' if color is None else color, alpha=alpha)
    axis.plot(positions, histogram, **kwargs)


def PlotHist(directoty,axis,xix=True):
    '''
    Draw the intensity KDE of one volume via HistFunc.

    When `xix` equals True, the curve colour comes from the first keyword
    found in the path: 'HH' -> red, 'Guys' -> green, 'IOP' -> blue.
    Otherwise, or when no keyword matches, HistFunc's default colour is used.
    '''
    site_colors = (('HH', 'red'), ('Guys', 'green'), ('IOP', 'blue'))
    chosen = None
    if xix == True:  # noqa: E712 - equality with True is the original test
        chosen = next((c for key, c in site_colors if key in directoty), None)
    HistFunc(axis, directoty, color=chosen)
    
def Plotpdf(paths,axis,xix=True):
    '''
    Overlay intensity KDE curves for a random ~60% of `paths` on `axis`.

    Indices are drawn with replacement, so a path may be plotted twice.

    @Inputs
      paths: indexable sequence of volume file paths; empty -> nothing drawn.
      axis: matplotlib Axes to draw on.
      xix: forwarded to PlotHist (True enables colouring by site keyword).
    '''
    if len(paths) == 0:
        # np.random.randint(0, ...) would raise ValueError.
        return
    IdxSample = np.random.randint(len(paths), size=int(len(paths) * 0.6))
    Down_paths = [paths[i] for i in IdxSample]
    for path in tqdm(Down_paths):
        # Forward the caller's flag; it was previously hard-coded to True.
        PlotHist(path, axis, xix=xix)
        
def seed_everything(seed):
    '''
    Seed every RNG the notebook uses so that runs are reproducible.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU and all CUDA
    devices), records PYTHONHASHSEED, and puts cuDNN in deterministic mode.
    NOTE: PYTHONHASHSEED only affects interpreters started after this call,
    not the current process.
    '''
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing, which can differ
    # between runs and defeats the deterministic flag set above.
    torch.backends.cudnn.benchmark = False

def make_square(im):
    '''
    Resample a 2-D array through PIL to a square whose side is
    min(height, width), returning a float64 NumPy array.

    The aspect ratio is not preserved: the longer side is squeezed.
    '''
    side = min(im.shape[:2])
    squared = Image.fromarray(im).resize((side, side))
    return np.array(squared, dtype=float)


def preparedata(Paths,balance=True):
    '''
    Build the paired T1/T2 file table and split it 80/10/10 into train/val/test.

    @Inputs
      Paths: dict with "T1" and "T2" directory paths (passed to getDir).
      balance: if True, randomly drop D3 rows so that D3 has at most
               D1 + D2 rows.

    @Outputs
      (TrainDir, ValtDir, TestDir): DataFrames, each re-indexed 0..n-1.
    '''
    DataObject = getDir(Paths)
    df = DataObject.df

    if balance:
        # Clamp at 0: np.random.choice raises on a negative sample size when
        # D3 is already no larger than D1 + D2.
        n_drop = max(0, DataObject.D3 - (DataObject.D1 + DataObject.D2))
        drop_indices = np.random.choice(df[df["Group"] == "D3"].index, n_drop, replace=False)
        df_subset = df.drop(drop_indices).reset_index(drop=True)
        print("subset is {}".format(df_subset.shape))
    else:
        df_subset = df.copy()

    # 80% train; the remaining 20% is halved into test and validation.
    TrainDir, tmp = train_test_split(df_subset, test_size=0.20, shuffle=True)
    TestDir, ValtDir = train_test_split(tmp, test_size=0.50, shuffle=True)

    TrainDir = TrainDir.reset_index(drop=True)
    ValtDir = ValtDir.reset_index(drop=True)
    TestDir = TestDir.reset_index(drop=True)
    print("Train is {}".format(TrainDir.shape))
    return TrainDir, ValtDir, TestDir

seed_everything(seed=seed)  # seed all RNGs before any sampling or splitting


## Define Classes

In [7]:
class getDir():
    '''
    Index the paired T1/T2 slice files of two directories.

    @Inputs
      Paths: dict with keys "T1" and "T2", each a directory path.

    Attributes set on construction:
      T1Root, T2Root: the two directories.
      T1files, T2files: sorted paths matched by "<dir>/*" in each directory.
      df: DataFrame with columns "Group" (file-name prefix before the first
          "_", e.g. "D1"), "T1" and "T2" (one row per T1/T2 pair).
      size: number of rows; D1, D2, D3: number of rows in each group.
    '''
    def __init__(self, Paths):
        self.T1Root = Paths["T1"]
        self.T2Root = Paths["T2"]

        self.T1files = glob.glob(Paths["T1"] + "/*", recursive=True)
        self.T2files = glob.glob(Paths["T2"] + "/*", recursive=True)
        self.df = self.getFile()

    @staticmethod
    def _pair_key(path):
        # "D3_001_030_T1.npy" -> "D3_001_030": the name minus the modality suffix.
        return os.path.basename(path).rsplit("_", 1)[0]

    def getFile(self):
        '''
        Pair T1 and T2 files by sorted order and build the index DataFrame.

        Raises ValueError when the directories hold different numbers of
        files, or when a pair's names differ before their final "_..." part.
        Previously zip() silently truncated or mis-paired such inputs.
        '''
        self.T1files.sort()
        self.T2files.sort()

        if len(self.T1files) != len(self.T2files):
            raise ValueError(
                "T1/T2 file count mismatch: {} T1 files vs {} T2 files".format(
                    len(self.T1files), len(self.T2files)))
        for t1, t2 in zip(self.T1files, self.T2files):
            if self._pair_key(t1) != self._pair_key(t2):
                raise ValueError("Unpaired files: {} vs {}".format(t1, t2))

        list2df = [(os.path.basename(t1).split("_")[0], t1, t2)
                   for t1, t2 in zip(self.T1files, self.T2files)]
        df = pd.DataFrame(list2df, columns=["Group", "T1", "T2"])
        self.size = df.shape[0]
        self.D1 = df[df.Group == 'D1'].shape[0]
        self.D2 = df[df.Group == 'D2'].shape[0]
        self.D3 = df[df.Group == 'D3'].shape[0]

        return df

################################################################################
    
    
class ImageTransform:
    '''
    Per-phase preprocessing for single-channel 2-D slices.

    Both 'train' and 'test' apply ToTensor -> Resize((img_size, img_size))
    -> Normalize(mean=[0.5], std=[0.5]).
    NOTE(review): ToTensor only rescales uint8 input to [0, 1]; float slices
    keep their raw range, so Normalize maps them to [-1, 1] only if they are
    already in [0, 1]. Confirm the intensity range of the saved .npy data.

    @Inputs
      img_size: output height and width, default 256. This was previously
                ignored and hard-coded to 256.
    '''
    def __init__(self, img_size=256):
        self.img_size = img_size
        self.transform = {phase: self._pipeline(img_size) for phase in ('train', 'test')}

    @staticmethod
    def _pipeline(img_size):
        # A fresh Compose per phase so the phases can diverge independently.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((img_size, img_size)),
            transforms.Normalize(mean=[0.5], std=[0.5])])

    def __call__(self, img, phase='train'):
        '''Apply the `phase` pipeline ('train' or 'test') to `img`.'''
        return self.transform[phase](img)
    
    
################################################################################

class MRIDataset(Dataset):
    '''
    Paired T1/T2 slice dataset backed by .npy files.

    @Inputs
      dirc: DataFrame with "T1" and "T2" path columns, indexed 0..n-1.
      transform: callable(img, phase) returning a tensor (e.g. ImageTransform).
      phase: phase key forwarded to `transform`, default 'train'.
      factor: fraction of the rows exposed through __len__, default 1.

    Items are (T1, T2) tensors cast to float32. When the T1 slice is not
    square, both slices are squashed to a square with make_square first.
    '''
    def __init__(self, dirc, transform, phase='train',factor=1):
        self.dirc = dirc
        self.transform = transform
        self.phase = phase
        self.factor = factor

    def __len__(self):
        return int(len(self.dirc) * self.factor)

    def __getitem__(self, idx):
        t1 = np.load(self.dirc.T1[idx])
        t2 = np.load(self.dirc.T2[idx])

        # The squareness test looks at the T1 slice only.
        if t1.shape[0] != t1.shape[1]:
            t1 = make_square(t1)
            t2 = make_square(t2)

        pair = (self.transform(t1, self.phase), self.transform(t2, self.phase))
        return tuple(img.type(torch.float32) for img in pair)
    
################################################################################
    
class MRIDatamodule(pl.LightningDataModule):
    '''
    LightningDataModule serving paired T1/T2 slices.

    @Inputs
      Paths: dict with "T1" and "T2" directory paths (see getDir).
      im_size: side length passed to ImageTransform, default 256.
      batch_size: DataLoader batch size, default 75.
      factor: fraction of the train/val splits exposed per epoch, default 1.

    The 80/10/10 train/val/test split is computed once, in __init__.
    '''
    def __init__(self,Paths,im_size=256,batch_size=75,factor=1):
        super(MRIDatamodule, self).__init__()
        self.batch_size = batch_size
        # Was ImageTransform(img_size=img_size): `img_size` is not a parameter
        # of this method, so construction raised NameError.
        self.transform = ImageTransform(img_size=im_size)
        self.factor = factor
        self.TrainDir, self.ValtDir, self.TestDir = self.preparedata(Paths)
        self.prepare_data_per_node = False

    def preparedata(self,Paths,balance=True):
        '''
        Build the paired file table (optionally balancing D3 against D1 + D2)
        and split it 80/10/10 into (train, val, test) DataFrames.
        '''
        DataObject = getDir(Paths)
        df = DataObject.df

        if balance:
            # Clamp at 0: np.random.choice raises on a negative sample size
            # when D3 is already no larger than D1 + D2.
            n_drop = max(0, DataObject.D3 - (DataObject.D1 + DataObject.D2))
            drop_indices = np.random.choice(df[df["Group"] == "D3"].index, n_drop, replace=False)
            df_subset = df.drop(drop_indices).reset_index(drop=True)
            print("subset is {}".format(df_subset.shape))
        else:
            df_subset = df.copy()

        TrainDir, tmp = train_test_split(df_subset, test_size=0.20, shuffle=True)
        TestDir, ValtDir = train_test_split(tmp, test_size=0.50, shuffle=True)

        TrainDir = TrainDir.reset_index(drop=True)
        ValtDir = ValtDir.reset_index(drop=True)
        TestDir = TestDir.reset_index(drop=True)
        print("Train is {}".format(TrainDir.shape))
        return TrainDir, ValtDir, TestDir

    def prepare_data(self):
        """
        Empty prepare_data method left in intentionally.
        https://pytorch-lightning.readthedocs.io/en/latest/data/datamodule.html#prepare-data
        """
        pass

    def setup(self, stage=None):
        # stage=None means "all stages"; "validate" needs only the val set.
        if stage in ("fit", None):
            self.Train_dataset = MRIDataset(self.TrainDir, self.transform, factor=self.factor)
        if stage in ("fit", "validate", None):
            self.Val_dataset = MRIDataset(self.ValtDir, self.transform, factor=self.factor)
        if stage in ("test", None):
            self.Test_dataset = MRIDataset(self.TestDir, self.transform, phase="test")

    def train_dataloader(self):
        return DataLoader(self.Train_dataset, shuffle=True, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.Val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.Test_dataset, batch_size=self.batch_size)
    

## Modelo


In [8]:
class ResBlock(nn.Module):
    '''
    ResBlock Class:
    @Based on the paper:
    - [1] Jun-Yan Zhu*, Taesung Park*, Phillip Isola, and Alexei A. Efros.
      "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks",
      in IEEE International Conference on Computer Vision (ICCV), 2017
    - [2] Dar, S. U., Yurt, M., Karacan, L., Erdem, A., Erdem, E., & Cukur, T. (2019).
      Image synthesis in multi-contrast MRI with conditional generative adversarial networks.
      IEEE transactions on medical imaging, 38(10), 2375-2388.
    - [3] https://towardsdatascience.com/residual-network-implementing-resnet-a7da63c7b278

    @Description
      Residual block used by Generator. It applies conv -> InstanceNorm ->
      ReLU [-> Dropout] -> conv -> InstanceNorm and adds the block input to
      the result. Channels are preserved, and with the default k_size=3, p=1
      so is the spatial size.

    @Inputs
      Ich: number of input (= output) channels.
      k_size: kernel size, default 3 as in [1].
      p: padding, default 1 as in [1].
      p_m: padding mode, default 'reflect'.
      dropOut: dropout probability after the first ReLU; None/0 disables it.

    @Outputs
      forward(x) returns x + residual(x).
    '''

    def __init__(self,Ich,k_size=3,p=1,p_m='reflect',dropOut=None):
        super(ResBlock, self).__init__()

        # Submodule names are part of the state_dict keys; keep them stable.
        self.Resblock = nn.Sequential()
        self.Resblock.add_module("conv1", nn.Conv2d(Ich, Ich, kernel_size=k_size, padding=p, padding_mode=p_m))
        self.Resblock.add_module("Inst_1", nn.InstanceNorm2d(Ich))
        self.Resblock.add_module("Relu_1", nn.ReLU())
        if dropOut:
            self.Resblock.add_module("Drop", nn.Dropout(dropOut))
        self.Resblock.add_module("conv2", nn.Conv2d(Ich, Ich, kernel_size=k_size, padding=p, padding_mode=p_m))
        self.Resblock.add_module("Inst_2", nn.InstanceNorm2d(Ich))

    def forward(self, x):
        '''
        x: image tensor of shape (batch size, channels, height, width)
        '''
        # No layer in the branch modifies `x` in place (ReLU/Dropout use the
        # default inplace=False), so the previous defensive clone() only cost
        # an extra copy of the activation.
        return x + self.Resblock(x)




class Generator(nn.Module):
    '''
    Generator Class:
    @Based on the paper:
    - [1] Jun-Yan Zhu*, Taesung Park*, Phillip Isola, and Alexei A. Efros.
      "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks",
      in IEEE International Conference on Computer Vision (ICCV), 2017
    - [2] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2020).
      Generative adversarial networks. Communications of the ACM, 63(11), 139-144.
    - [3] Dar, S. U., Yurt, M., Karacan, L., Erdem, A., Erdem, E., & Cukur, T. (2019).
      Image synthesis in multi-contrast MRI with conditional generative adversarial networks.
      IEEE transactions on medical imaging, 38(10), 2375-2388.
    - [4] https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html

    @Description
      ResNet-style generator following [1,2]. The layers, in order, are:
      - a 7x7 stem;
      - `lvl_aut` stride-2 encoder layers;
      - `lvl_resnt` residual blocks;
      - `lvl_aut` transposed-conv decoder layers;
      - a 7x7 convolution back to `in_f` channels, then Tanh.
      With the defaults (out_f=64, lvl_aut=2, lvl_resnt=9) this is the
      standard 256x256 CycleGAN generator:
      c7s1-64,d128,d256,R256 x9,u128,u64,c7s1-<in_f>

      Notation:
      - c7s1-k: 7x7 convolution with k filters and stride 1 (the stem has no
        InstanceNorm and uses ReLU; the last one is followed by Tanh).
      - dk: 3x3 Convolution-InstanceNorm-ReLU with k filters and stride 2.
      - Rk: residual block with two 3x3 convolutions of k filters.
      - uk: 3x3 fractional-strided Convolution-InstanceNorm-ReLU with k
        filters and stride 1/2.

      The submodule names in `self.gen`, and therefore the state_dict keys,
      are c7s1_<out_f>, d<ch>..., R<ch>_<i>..., u<ch>..., c7s1_3 and Tanh.

    @Inputs
      in_f: number of image channels (input and output).
      out_f: filters in the first layer, default 64; doubled per encoder level.
      lvl_aut: encoder levels (mirrored by the decoder), default 2.
      lvl_resnt: residual blocks in the neck, default 9.
      upsam: unused; kept for backward compatibility.

    @Outputs
      forward maps (N, in_f, H, W) to (N, in_f, H, W) with values in [-1, 1],
      provided H and W are divisible by 2**lvl_aut.
    '''

    def __init__(self,in_f,out_f=64,lvl_aut=2,lvl_resnt=9,upsam=None):
        super(Generator, self).__init__()

        ######################  Define constants and variables ######################
        f_deep = [2**i for i in range(1, lvl_aut + 1)]  # per-level channel multipliers
        rInput = out_f * (2**lvl_aut)  # channels entering the residual neck
        self.gen = nn.Sequential()

        ######################  Define the encoder ######################
        self.gen.add_module(f"c7s1_{out_f}", self.EncLayer(in_f, out_f, nor=False, k=7, p=3, s=1, act='relu'))  # c7s1-64

        # d128,d256
        Ich = out_f
        for f in f_deep:
            self.gen.add_module(f"d{out_f*f}", self.EncLayer(Ich, out_f * f, nor=True, k=3, p=1, s=2, act='relu'))
            Ich = out_f * f

        ######################  Define the residual neck ######################
        for l in range(lvl_resnt):
            self.gen.add_module(f"R{rInput}_{l}", ResBlock(rInput, k_size=3, p=1, p_m='reflect', dropOut=None))  # R256 x 9

        ######################  Define the Decoder ######################
        # u128,u64
        Ich = rInput
        for f in f_deep:
            self.gen.add_module(f"u{rInput//f}", self.DecLayer(Ich, rInput // f, nor=True, k=3, s=2, p=1, op=1, act="relu"))
            Ich = rInput // f

        ######################  Last layer ######################
        # After the decoder Ich == out_f again (rInput // 2**lvl_aut).
        self.gen.add_module("c7s1_3", nn.Conv2d(Ich,
                                                in_f,
                                                kernel_size=7,
                                                padding=3,
                                                stride=1,
                                                padding_mode='reflect'))

        self.gen.add_module("Tanh", nn.Tanh())

    #--------------------------------------- Methods ---------------------------------------#
    def EncLayer(self,Ich,Och,nor=True, k=3,p=1,s=2,act='relu'):
        '''
        @Description
          Build a "dk" encoder layer: Conv2d (reflect padding) -> optional
          InstanceNorm2d -> activation.

        @Inputs
          Ich: input channels, int.
          Och: output channels, int.
          nor: True adds nn.InstanceNorm2d(Och).
          k: kernel size, default 3.
          p: padding, default 1.
          s: stride, default 2.
          act: 'relu' (or True) selects nn.ReLU(); anything else selects
               nn.LeakyReLU(0.2).

        @Outputs
          nn.Sequential with submodules "conv1", "Instancenorm" (optional),
          "activation".
        '''
        m = nn.Sequential()
        m.add_module("conv1", nn.Conv2d(Ich, Och, kernel_size=k, padding=p, stride=s, padding_mode='reflect'))
        if nor:
            m.add_module("Instancenorm", nn.InstanceNorm2d(Och))
        # The old `if act` test was true for every non-empty string, so
        # act='leaky' silently produced ReLU.
        use_relu = act is True or act == 'relu'
        m.add_module("activation", nn.ReLU() if use_relu else nn.LeakyReLU(0.2))
        return m

    def DecLayer(self,Ich,Och,nor=True,k=3,s=2,p=1,op=1,act="relu"):
        '''
        @Description
          Build a "uk" decoder layer: ConvTranspose2d -> optional
          InstanceNorm2d -> activation. With k=3, s=2, p=1, op=1 the spatial
          size doubles.

        @Inputs
          Ich: input channels, int.
          Och: output channels, int.
          nor: True adds nn.InstanceNorm2d(Och).
          k: kernel size, default 3.
          s: stride, default 2.
          p: padding, default 1.
          op: output_padding, default 1.
          act: 'relu' (or True) selects nn.ReLU(); anything else selects
               nn.LeakyReLU(0.2).

        @Outputs
          nn.Sequential with submodules "dconv1", "Instancenorm" (optional),
          "activation".
        '''
        m = nn.Sequential()
        m.add_module("dconv1", nn.ConvTranspose2d(Ich, Och, kernel_size=k, stride=s, padding=p, output_padding=op))
        if nor:
            m.add_module("Instancenorm", nn.InstanceNorm2d(Och))
        use_relu = act is True or act == 'relu'
        m.add_module("activation", nn.ReLU() if use_relu else nn.LeakyReLU(0.2))
        return m

    #---------------------------------------  Call funtion ---------------------------------------#
    def forward(self, x):
        return self.gen(x)




class Discriminator(nn.Module):
    '''
    Discriminator Class:
    @Based on the paper:
     - [1] Jun-Yan Zhu*, Taesung Park*, Phillip Isola, and Alexei A. Efros.
       "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks",
       in IEEE International Conference on Computer Vision (ICCV), 2017
     - [2] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2020).
       Generative adversarial networks. Communications of the ACM, 63(11), 139-144.
     - [3] https://arxiv.org/abs/1703.10593
     - [4] Dar, S. U., Yurt, M., Karacan, L., Erdem, A., Erdem, E., & Cukur, T. (2019).
       Image synthesis in multi-contrast MRI with conditional generative adversarial networks.
       IEEE transactions on medical imaging, 38(10), 2375-2388.

    @Description
      PatchGAN-style discriminator following [1,2,3]. Instead of one
      real/fake score it outputs a map whose entries score patches of the
      input. For n=3 (default) the layout is C64-C128-C256-C512-C1, where:
      - Ck is a 4x4 Convolution-InstanceNorm-LeakyReLU(0.2) with k filters
        and stride 2 (the first C64 has no InstanceNorm);
      - C1 is a final 4x4, stride-1 convolution to a single channel.
      NOTE(review): the 70x70 PatchGAN of [1] uses stride 1 for its C512;
      here every Ck uses stride 2.

    @Inputs
      ICh: number of image input channels.
      HCh: filters of the first layer, default 64; doubled at each further layer.
      n: number of normalised Ck layers after the first one, default 3.

    @Outputs
      nn.Module; forward returns a (N, 1, h, w) map of patch scores.
    '''

    def __init__(self,ICh,HCh=64,n=3):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential()

        # First block: plain conv + LeakyReLU, no normalisation.
        self.disc.add_module(f"C{HCh}", self.lcreation(ICh, HCh, k_size=4, p=1, s=2, drop=None, relu=False, norm=False))

        # n normalised blocks, doubling the filter count each time.
        in_ch = HCh
        for level in range(1, n + 1):
            out_ch = HCh * 2 ** level
            self.disc.add_module(f"C{out_ch}", self.lcreation(in_ch, out_ch, k_size=4, p=1, s=2, drop=None, relu=False, norm=True))
            in_ch = out_ch

        # Single-channel patch-score head.
        self.disc.add_module("C1", nn.Conv2d(HCh * 2 ** n, 1, kernel_size=4, padding=1))

    #---------------------------------------  Methods ---------------------------------------#
    def lcreation(self,Ich,Och,k_size=4,p=1,s=2,drop=None,relu=True,norm=True):
        '''
        @Description
          Build one Conv2d -> [InstanceNorm2d] -> ReLU/LeakyReLU(0.2) ->
          [Dropout] block as described in [1].

        @Inputs
          Ich: input channels.
          Och: filters of the convolution.
          k_size: kernel size, default 4.
          p: padding, default 1.
          s: stride, default 2.
          drop: dropout probability; None/0 adds no dropout.
          relu: True -> nn.ReLU(); False -> nn.LeakyReLU(0.2).
          norm: True adds nn.InstanceNorm2d(Och).

        @Outputs
          nn.Sequential with submodules "conv1", "Instancenorm" (optional),
          "activation", "Dropout" (optional).
        '''
        parts = [("conv1", nn.Conv2d(Ich, Och, kernel_size=k_size, padding=p, stride=s))]
        if norm:
            parts.append(("Instancenorm", nn.InstanceNorm2d(Och)))
        parts.append(("activation", nn.LeakyReLU(0.2) if not relu else nn.ReLU()))
        if drop:
            parts.append(("Dropout", nn.Dropout(drop)))

        block = nn.Sequential()
        for name, module in parts:
            block.add_module(name, module)
        return block

    #---------------------------------------  Call funtion ---------------------------------------#
    def forward(self, x):
        return self.disc(x)

In [9]:
class CycleGAN(pl.LightningModule):
  '''
  CycleGAN Class: 
  @Based on the paper: 
   - [1] Jun-Yan Zhu*, Taesung Park*, Phillip Isola, and Alexei A. Efros.
    "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks",
    in IEEE International Conference on Computer Vision (ICCV), 2017
   - [2] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2020).
    Generative adversarial networks. Communications of the ACM, 63(11), 139-144.
   - [3] https://arxiv.org/abs/1703.10593
   - [4] Dar, S. U., Yurt, M., Karacan, L., Erdem, A., Erdem, E., & Cukur, T. (2019). 
    Image synthesis in multi-contrast MRI with conditional generative adversarial networks.
    IEEE transactions on medical imaging, 38(10), 2375-2388.
   - [5] https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html
   - [6] https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/

  @Description
    Translates between T1 and T2 images with two generators (G_T1_T2: T1 -> T2,
    G_T2_T1: T2 -> T1) and two discriminators (D_T1, D_T2). Optimisation is manual
    (automatic_optimization = False): every training step updates D_T1, then D_T2,
    then both generators jointly through a single optimizer.

  @Inputs
    input. Number of image channels, forwarded to the Generator and Discriminator constructors.
    params. Dict with the keys "lr", "b1", "b2" (Adam settings), "lbc_T1", "lbc_T2"
      (cycle-consistency weights), "lbi" (identity weight), "batch_size" and "target_shape".
    features. Base number of filters of the networks, by default 64.

  @Outputs
    forward(x) returns G_T1_T2(x), the T1 -> T2 translation of x.
  '''
  def __init__(self,
               input,
               params,
               features=64):
    super(CycleGAN,self).__init__()

    ############# Customize class #############
    self.save_hyperparameters(params)
    # The three optimizers (generators, D_T1, D_T2) are stepped by hand in training_step.
    self.automatic_optimization = False
    self.lr = params["lr"]
    self.b1 = params["b1"]
    self.b2 = params["b2"]
    self.lbc_T1 = params["lbc_T1"]   # weight of the T1 -> T2 -> T1 cycle term
    self.lbc_T2 = params["lbc_T2"]   # weight of the T2 -> T1 -> T2 cycle term
    self.btch_size = params["batch_size"]
    self.target_shape = params["target_shape"]
    self.lbi = params["lbi"]         # weight of the identity term

    ############# Define components #############
    self.G_T1_T2=Generator(input,out_f=features,lvl_aut=2,lvl_resnt=9)
    self.D_T1=Discriminator(input,HCh=features,n=3)

    self.G_T2_T1=Generator(input,out_f=features,lvl_aut=2,lvl_resnt=9)
    self.D_T2=Discriminator(input,HCh=features,n=3)

    # Module.apply works in place (and returns the module), so no reassignment is needed.
    for net in (self.G_T1_T2, self.D_T1, self.G_T2_T1, self.D_T2):
      net.apply(self.weights_init)

    ############# Define loss #############
    self.identity_loss = torch.nn.L1Loss()
    # Least-squares adversarial loss: tracks how well the generators fool the
    # discriminators and how well the discriminators catch the generators.
    self.adv_loss = torch.nn.MSELoss()
    self.cycle_loss = torch.nn.L1Loss()

  def forward(self, x):
    '''Translate x from T1 to T2 with G_T1_T2.'''
    return self.G_T1_T2(x)

  def training_step(self, batch):
    '''
      @Description:
      - [1] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
      One manual-optimization step following the original CycleGAN losses:
      update D_T1, then D_T2, then both generators with a single optimizer.

      @Inputs
      batch. Pair (real_T1, real_T2) of image tensors.

      @Outputs
      Dict with the generator and discriminator losses and the generator loss terms.
    '''
    ############# Initialization #############
    real_T1, real_T2 = batch
    # Use the optimizers Lightning built from configure_optimizers(). Calling
    # configure_optimizers() here would create fresh Adam instances on every step,
    # throwing away Adam's running moment estimates and bypassing the
    # Lightning-wrapped optimizers.
    Gopt, Dopt_T1, Dopt_T2 = self.optimizers()

    ############# update discriminators #############
    #### Discriminator T1
    self.toggle_optimizer(Dopt_T1)
    f_T1 = self.G_T2_T1(real_T2)
    T1Loss_Dics = self.DiscLoss(real_T1, f_T1, disc="T1")
    # DiscLoss detaches the fake image, so this graph only spans D_T1 and
    # does not need to be retained after the backward pass.
    self.manual_backward(T1Loss_Dics)
    Dopt_T1.step()
    Dopt_T1.zero_grad()
    self.untoggle_optimizer(Dopt_T1)

    #### Discriminator T2
    self.toggle_optimizer(Dopt_T2)
    f_T2 = self.G_T1_T2(real_T1)
    T2Loss_Dics = self.DiscLoss(real_T2, f_T2, disc="T2")
    self.manual_backward(T2Loss_Dics)
    Dopt_T2.step()
    Dopt_T2.zero_grad()
    self.untoggle_optimizer(Dopt_T2)

    ############# update Generators #############
    self.toggle_optimizer(Gopt)
    gen_loss, f_T1, f_T2, Iden_term, Cycle_term, Adv_term = self.GenLoss(real_T1, real_T2)
    self.manual_backward(gen_loss)
    Gopt.step()
    Gopt.zero_grad()
    self.untoggle_optimizer(Gopt)

    ########### Loggers ###########
    self.log("D_loss_T1", T1Loss_Dics, prog_bar=True)
    self.log("D_loss_T2", T2Loss_Dics, prog_bar=True)
    self.log("G_loss", gen_loss, prog_bar=True)

    loss = {'G_loss': gen_loss,
            'D_loss_T2': T2Loss_Dics,
            'D_loss_T1': T1Loss_Dics,
            'identity': Iden_term,
            'Cycle_term': Cycle_term,
            "Adver_term": Adv_term}
    return loss

  def validation_step(self, batch, batch_idx):
    '''
    @Description
    Evaluate the discriminator and generator losses on a validation batch
    (no parameter updates) and log PSNR / SSIM of both translations.

    @Inputs
    batch. Pair (real_T1, real_T2) of image tensors.
    batch_idx. Index of the batch (unused).

    @Outputs
    Dict with the validation losses and metrics.
    '''
    real_T1, real_T2 = batch

    #### Discriminator losses
    f_T1 = self.G_T2_T1(real_T2)
    T1Loss_Dics = self.DiscLoss(real_T1, f_T1, disc="T1")
    f_T2 = self.G_T1_T2(real_T1)
    T2Loss_Dics = self.DiscLoss(real_T2, f_T2, disc="T2")

    #### Generator losses (GenLoss also returns the translated images used for the metrics)
    gen_loss, f_T1, f_T2, Iden_term, Cycle_term, Adv_term = self.GenLoss(real_T1, real_T2)

    #### Image-quality metrics
    G_psnr_T2, G_ssim_T2, G_psnr_T1, G_ssim_T1, Vfid_T1, Vfid_T2 = self.ComputeMetrics(f_T2, real_T2, f_T1, real_T1)

    ########### Loggers ###########
    self.log("Dval_loss_T1", T1Loss_Dics, prog_bar=True)
    self.log("Dval_loss_T2", T2Loss_Dics, prog_bar=True)
    self.log("Gval_loss", gen_loss, prog_bar=True)
    self.log("Gval_psnr_T2", G_psnr_T2)
    self.log("Gval_ssim_T2", G_ssim_T2)
    self.log("Gval_psnr_T1", G_psnr_T1)
    self.log("Gval_ssim_T1", G_ssim_T1)
    self.log("Vfid_T1", Vfid_T1)
    self.log("Vfid_T2", Vfid_T2)

    loss = {'Gval_loss': gen_loss,
            'Dval_loss_T2': T2Loss_Dics,
            'Dval_loss_T1': T1Loss_Dics,
            'Val_identity': Iden_term,
            'Val_Cycle_term': Cycle_term,
            "Val_Adver_term": Adv_term,
            "Gval_psnr_T2": G_psnr_T2,
            "Gval_ssim_T2": G_ssim_T2,
            "Gval_psnr_T1": G_psnr_T1,
            "Gval_ssim_T1": G_ssim_T1,
            "Vfid_T1": Vfid_T1,
            "Vfid_T2": Vfid_T2}
    return loss

  def configure_optimizers(self):
    '''
    @Description:
    - [1] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

    Create the Adam optimizers used in the original CycleGAN implementation [1]:
    one per discriminator and a single one shared by both generators, all with
    lr = self.lr and betas = (self.b1, self.b2).

    @Outputs
    Tuple (Gopt, Dopt_T1, Dopt_T2); training_step unpacks self.optimizers() in this order.
    '''
    lr = self.lr
    betas = (self.b1, self.b2)

    Dopt_T1 = torch.optim.Adam(self.D_T1.parameters(), lr=lr, betas=betas)
    Dopt_T2 = torch.optim.Adam(self.D_T2.parameters(), lr=lr, betas=betas)
    # Both generators share one optimizer because the cycle term couples their losses.
    Gopt = torch.optim.Adam(list(self.G_T1_T2.parameters()) + list(self.G_T2_T1.parameters()),
                            lr=lr, betas=betas)

    return Gopt, Dopt_T1, Dopt_T2

  def DiscLoss(self,real,fake,disc="T1"):
    '''
    @Description
    Discriminator loss with the least-squares (MSE) adversarial criterion:
    real images are pushed towards label 1 and fake images towards label 0.

      discriminator loss = (adv_fake + adv_real) / 2

    The fake image is detached, so no gradient flows back into the generator.

    @Inputs
    real. Tensor, real image.
    fake. Tensor, fake image.
    disc. "T1" selects D_T1; any other value selects D_T2.

    @Outputs
    Discriminator loss
    '''
    if disc == "T1":
      disc_fake_hat = self.D_T1(fake.detach())
      disc_real_hat = self.D_T1(real)
    else:
      disc_fake_hat = self.D_T2(fake.detach())
      disc_real_hat = self.D_T2(real)

    fake_loss = self.adv_loss(disc_fake_hat, torch.zeros_like(disc_fake_hat))
    real_loss = self.adv_loss(disc_real_hat, torch.ones_like(disc_real_hat))

    return (fake_loss + real_loss) / 2

  def GenLoss(self, real_T1, real_T2):
    '''
    @Description
    Generator objective:
      gen_loss   = lbi * Iden_term + Cycle_term + Adv_term
      Adv_term   = MSE(D_T1(G_T2_T1(T2)), 1) + MSE(D_T2(G_T1_T2(T1)), 1)
      Cycle_term = lbc_T1 * L1(G_T2_T1(G_T1_T2(T1)), T1) + lbc_T2 * L1(G_T1_T2(G_T2_T1(T2)), T2)
      Iden_term  = L1(G_T2_T1(T1), T1) + L1(G_T1_T2(T2), T2)
    NOTE: unlike the reference implementation, the identity weight is not
    multiplied by the cycle weights.

    @Inputs
    real_T1, real_T2. Real image tensors of each modality.

    @Outputs
    (gen_loss, f_T1, f_T2, Iden_term, Cycle_term, Adv_term), where f_T1 / f_T2 are
    the translated images G_T2_T1(real_T2) / G_T1_T2(real_T1).
    '''
    # Translations
    f_T1 = self.G_T2_T1(real_T2)
    f_T2 = self.G_T1_T2(real_T1)

    # Adversarial term: generators want the discriminators to output "real" (1)
    dic_f_T1_hat = self.D_T1(f_T1)
    dic_f_T2_hat = self.D_T2(f_T2)
    Adv_term = (self.adv_loss(dic_f_T1_hat, torch.ones_like(dic_f_T1_hat))
                + self.adv_loss(dic_f_T2_hat, torch.ones_like(dic_f_T2_hat)))

    # Cycle consistency: translating back must recover the input
    C_T1 = self.G_T2_T1(f_T2)
    C_T2 = self.G_T1_T2(f_T1)
    Cycle_term = self.lbc_T1 * self.cycle_loss(C_T1, real_T1) + self.lbc_T2 * self.cycle_loss(C_T2, real_T2)

    # Identity: a generator fed an image already in its target domain should return it unchanged
    identity_T1 = self.G_T2_T1(real_T1)
    identity_T2 = self.G_T1_T2(real_T2)
    Iden_term = self.identity_loss(identity_T1, real_T1) + self.identity_loss(identity_T2, real_T2)

    gen_loss = self.lbi * Iden_term + Cycle_term + Adv_term

    return gen_loss, f_T1, f_T2, Iden_term, Cycle_term, Adv_term

  def weights_init(self,m):
    '''
    Initialise a module as in the reference CycleGAN implementation [1]
    (meant to be used with nn.Module.apply):
      - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0
      - affine BatchNorm2d: weight ~ N(1, 0.02), bias = 0
    Other modules are left untouched.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
      torch.nn.init.normal_(m.weight, 0.0, 0.02)
      if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.BatchNorm2d) and m.affine:
      # Centre the scale at 1: a scale drawn around 0 would nearly zero the normalised activations.
      torch.nn.init.normal_(m.weight, 1.0, 0.02)
      torch.nn.init.constant_(m.bias, 0.0)

  def ComputeMetrics(self,f_T2, real_T2,f_T1, real_T1):
    '''
    @Description
    PSNR and SSIM of both translations against the corresponding real images,
    using the PSNR() / SSIM() metric classes defined or imported elsewhere in the
    notebook (presumably torchmetrics — confirm). FID is not computed yet; both
    FID values are returned as 0.0 placeholders.

    @Outputs
    (G_psnr_T2, G_ssim_T2, G_psnr_T1, G_ssim_T1, fid_T1, fid_T2)
    '''
    psnr_metric = PSNR().to(self.device)
    ssim_metric = SSIM().to(self.device)

    G_psnr_T2 = psnr_metric(f_T2, real_T2)
    G_ssim_T2 = ssim_metric(f_T2, real_T2)

    G_psnr_T1 = psnr_metric(f_T1, real_T1)
    G_ssim_T1 = ssim_metric(f_T1, real_T1)

    # TODO: compute FID (e.g. with an InceptionV3 feature extractor); placeholders for now.
    fid_T1 = 0.0
    fid_T2 = 0.0

    return G_psnr_T2, G_ssim_T2, G_psnr_T1, G_ssim_T1, fid_T1, fid_T2







## Callbacks

In [19]:
# Higher SSIM means better translations, so the T1 -> T2 validation SSIM is maximised.

chk_pth="/content/drive/MyDrive/TFM/Checkpoints"
n_train_steps=1000
patient=800

# Stop when Gval_ssim_T2 has not improved for `patient` validation checks.
# NOTE(review): this callback only takes effect if it is passed to the Trainer.
early_stop_callback = EarlyStopping(
   monitor='Gval_ssim_T2',
   patience=patient,
   verbose=False,
   mode='max'
)

# Keep only the best checkpoint by Gval_ssim_T2, checked every `n_train_steps` training steps.
# The filename template can only use metrics logged with self.log (a metric that was
# never logged is rendered as 0), so it uses the logged Gval_loss and Gval_ssim_T2.
check = ModelCheckpoint(
    save_top_k=1,
    monitor="Gval_ssim_T2",
    mode="max",
    dirpath=chk_pth,
    every_n_train_steps=n_train_steps,
    filename="Model-{step:06d}-{epoch:03d}-{Gval_loss:.2f}-{Gval_ssim_T2:.3f}"
)


class VisualizeFakeImages(pl.Callback):
  '''
  Lightning callback that periodically translates the current training batch from
  T1 to T2 and shows the real T2 images stacked above the generated T2 images,
  both in TensorBoard (when a logger is attached) and inline with matplotlib.

  @Inputs
    dataloader. Stored on the instance; not used by the callback itself.
    every_n_steps. Visualise when (trainer.global_step + 1) is a multiple of this value.
  '''
  def __init__(self, dataloader, every_n_steps=100):
    super().__init__()
    self.dataloader = dataloader
    self.every_n_steps = every_n_steps

  def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    # Only act on the configured step interval
    if (trainer.global_step + 1) % self.every_n_steps != 0:
      return

    # Generate fake T2 images: the T1 -> T2 translator is G_T1_T2
    # (G_T2_T1 maps T2 -> T1 and would produce fake T1 images instead).
    T1, T2 = batch
    with torch.no_grad():
      f_T2 = pl_module.G_T1_T2(T1)

    # One column: all real T2 images followed by the corresponding fakes
    grid = vutils.make_grid(
        torch.cat([T2, f_T2], dim=0),
        normalize=True,
        scale_each=True,
        nrow=1
    )

    logger = trainer.logger
    if logger is not None:
      # make_grid returns a (C, H, W) tensor, which is add_image's default layout.
      logger.experiment.add_image(
          f"Image_Epoch{trainer.current_epoch}_Batch{batch_idx}",
          grid.cpu(),
          global_step=trainer.global_step
      )

    # Display inline (matplotlib expects H, W, C) and release the figure afterwards
    fig = plt.figure(figsize=(8,8))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.show()
    plt.close(fig)



## Cycle

In [13]:
#Parameters
# batch_size = 1 as in the reference CycleGAN training setup. The discriminator layers
# built by lcreation use nn.InstanceNorm2d explicitly, so their normalisation is
# per-instance regardless of the batch size.
batch_size = 1
img_size = 256
Paths = {"T1":"/content/drive/MyDrive/TFM/T1","T2":"/content/drive/MyDrive/TFM/T2"}
lr=0.0002
lbc_T1 = 10   # cycle-consistency weight, T1 -> T2 -> T1
lbc_T2 = 10   # cycle-consistency weight, T2 -> T1 -> T2
lbi = 0.1     # identity-loss weight
b1 = 0.5      # Adam beta1
b2 = 0.999    # Adam beta2
target_shape = 256
epochs = 4
input=1       # number of image channels (NOTE: shadows the builtin input())

#Get Image Dataloader
data_mri=MRIDatamodule(Paths,im_size=img_size,batch_size=batch_size,factor=0.1)

# Each hyper-parameter appears exactly once: a repeated key in a dict literal
# silently overrides the earlier entry.
params = {'lr': 1e-05,  # NOTE(review): differs from `lr` above (2e-4); this is the learning rate actually used - confirm
          'lbc_T1': lbc_T1,
          'lbc_T2': lbc_T2,
          'lbi': lbi,
          'batch_size': batch_size,
          'b1': b1,
          'b2': b2,
          'target_shape': target_shape}

subset is (41268, 3)
Train is (33014, 3)


In [None]:
#Instance Model
model= CycleGAN(input,params)

#Instance Callbacks
tb_logger = TensorBoardLogger("/content/drive/MyDrive/TFM/loggers/", name="cycleGAN")
vis= VisualizeFakeImages(data_mri, every_n_steps=n_train_steps)

# NOTE(review): early_stop_callback is defined with the other callbacks but is not
# listed here, so early stopping is disabled; add it to `callbacks` if intended.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="gpu",
    val_check_interval=0.5,   # validate twice per training epoch
    callbacks=[check,vis],
    logger=tb_logger)

# To resume training, pass the checkpoint path to fit(); the Trainer argument
# `resume_from_checkpoint` was removed in Lightning 2.x:
# trainer.fit(model, data_mri, ckpt_path="some/path/to/my_checkpoint.ckpt")
trainer.fit(model, data_mri)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name          | Type          | Params
------------------------------------------------
0 | G_T1_T2       | Generator     | 11.4 M
1 | D_T1          | Discriminator | 2.8 M 
2 | G_T2_T1       | Generator     | 11.4 M
3 | D_T2          | Discriminator | 2.8 M 
4 | identity_loss | L1Loss        | 0     
5 | adv_loss      | MSELoss       | 0     
6 | cycle_lo

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

DataLoader:   
[1] https://stackoverflow.com/questions/73191999/when-to-use-prepare-data-vs-setup-in-pytorch-lightning

[2] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/213
[3] https://scikit-image.org/docs/dev/auto_examples/transform/plot_ssim.html