In [1]:
# Producing structure preserving MRI colorizations with biologically plausible color distributions using a CycleGAN
# Oscar Moonen
# Version: 05-2022

### !!This Model is made to run with Google Colab!! ###

# Current GPU-Type
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-73330d7f-d4f7-8c9c-f406-1557b523ada1)


In [2]:
#### Settings ####

## Training schedule ##
numEpochs: int = 100          # Number of epochs to run
batch: int = 6                # Batch size (same for the MRI and the color loader)
num_blocks: int = 10          # Number of ResNet blocks in each generator
learningRate: float = 0.0002  # Initial learning rate (generators and discriminators)
UpdateAfterStep: int = 1      # StepLR step size: decay the generator LR every N epochs
UpdateBy: float = 0.95        # StepLR factor: LR <- LR * UpdateBy
b1: float = 0.5               # Adam beta1 (momentum term)
b2: float = 0.999             # Adam beta2

## Generator loss weights ##
lambdaAdversarial: float = 1  # Weight of the adversarial loss
lambdaCycle: float = 1        # Weight of the cycle-consistency (L1) loss
lambdaSSIM: float = 1         # Weight of the structure-preserving SSIM loss

## Architecture / augmentation / logging ##
transpose2D: bool = False     # True: ConvTranspose2d upsampling; False (default): x2 upsample followed by a conv
augmentPolicy: str = "All"    # DiffAugment policy: "All", or a comma-separated mix of color, translation, cutout
saveAfter: int = 500          # Plot intermediate results / save the model every N batches


### Folders ###
prefixTest: str = "/content/drive/MyDrive/"                     # Folder holding the MRI slices 0.png..4.png used for intermediate plots
prefixInter: str = "/content/drive/MyDrive/"                    # Output folder for intermediate plots
prefixModel: str = "/content/drive/MyDrive/"                    # Output folder for the MRI->Color generator checkpoints
cryoDataFrame: str = "/content/drive/MyDrive/Data/Cryodf.pkl"   # Pickled DataFrame with the color (cryo) images
MRIDataFrame: str = "/content/drive/MyDrive/Data/MRIdf.pkl"     # Pickled DataFrame with the MRI images

prefixSave: str = "MyTestRun"                                   # Prefix prepended to every output file name

In [3]:
# Imports Statements #
import numpy as np
import pandas as pd
import random 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as TorchTransforms 
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import random
import IPython
from torch.autograd import Variable
import os , itertools
from torch.optim.lr_scheduler import  StepLR
import imgaug as ia
import imgaug.augmenters as iaa
from math import log10, sqrt, exp
import warnings
warnings.filterwarnings('ignore')

## Colab compatibility installs ##
!pip install piqa
from piqa import SSIM
!pip3 install pickle5
import pickle5 as pickle

## Mount Google Drive ##
from google.colab import drive
drive.mount('/content/drive')
import numpy as np

## Load Datasets ##
# NOTE(review): unpickling runs arbitrary code stored in the file - only load pickles you created yourself.
def _loadPickledFrame(path):
    """Open `path` in binary mode and return the unpickled object (expected: a pandas DataFrame)."""
    with open(path, "rb") as handle:
        return pickle.load(handle)

Cryodf = _loadPickledFrame(cryoDataFrame)
MRIdf = _loadPickledFrame(MRIDataFrame)

print()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
dataPointsCryo, dataPointsMRI = Cryodf.shape[0], MRIdf.shape[0]   # number of rows (images) per dataset
print("dataPoints Cryo:", dataPointsCryo)
print("dataPoints MRI:", dataPointsMRI)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Device: cuda
dataPoints Cryo: 2698
dataPoints MRI: 5881


In [4]:
# Create tensors of the MRI scans shown in the intermediate plots.
# Put the images in prefixTest, named 0.png, 1.png, 2.png, 3.png and 4.png (should be 256x256).

def makeTestTensors(count=5, prefix=None, suffix=".png"):
    """Load the fixed MRI slices used for the intermediate plots as model-ready tensors.

    Reads ``{prefix}{k}{suffix}`` for k = 0 .. count-1.

    Args:
        count: number of images to load (default 5: 0.png .. 4.png).
        prefix: folder/file prefix; defaults to the global ``prefixTest``.
        suffix: file extension (default ".png").

    Returns:
        List of float32 tensors shaped (1, 3, W, H) with values in [-1, 1].
        The ``.T`` transpose swaps H and W. The training datasets apply the same
        transpose, so the model always sees images in this layout.
    """
    if prefix is None:
        prefix = prefixTest
    tensorList = []
    for k in range(count):
        path = prefix + str(k) + suffix
        # `with` closes the file handle (PIL opens files lazily and otherwise keeps them open).
        # convert("RGB") makes grayscale (L), palette or RGBA PNGs yield the 3 channels the generator expects.
        with Image.open(path) as img:
            testImage = np.asarray(img.convert("RGB"))
        testImageScaled = (testImage / 127.5) - 1   # [0, 255] -> [-1, 1]
        # (H, W, C) -> .T -> (C, W, H) -> unsqueeze -> (1, C, W, H)
        testTensor = torch.tensor(testImageScaled.astype(float).T).float().unsqueeze(0)
        tensorList.append(testTensor)
    return tensorList

In [5]:
# Saves the output of the intermediate MRI images to Drive folder

def saveTestImages(genList, reconList, epoch, i):
    """Save a grid of input / colorized / reconstructed MRI slices.

    Row k shows ``{prefixTest}{k}.png`` next to genList[k] and reconList[k]. The figure is
    written to ``{prefixInter}{prefixSave}-{epoch}.png``.

    Args:
        genList: colorizations as int arrays (displayed after clipping to [0, 255]).
        reconList: reconstructions as int arrays (clipped to [0, 255]).
        epoch: epoch number used in the output file name.
        i: unused; kept so existing calls keep working.
    """
    pairs = list(zip(genList, reconList))
    # One row per image instead of a hard-coded 5. squeeze=False keeps `ax` 2-D even for a single row.
    fig, ax = plt.subplots(max(len(pairs), 1), 3, squeeze=False)
    fig.set_size_inches(15, 15)
    for row, (gen, recon) in enumerate(pairs):
        path = prefixTest + str(row) + ".png"   # the original MRI input for this row
        with Image.open(path) as img:           # close the file handle after reading
            original = np.asarray(img)
        panels = (
            (original, 'Input'),
            (np.clip(gen, 0, 255), 'Generated'),
            (np.clip(recon, 0, 255), 'Reconstructed'),
        )
        for col, (image, title) in enumerate(panels):
            ax[row][col].imshow(image)
            ax[row][col].axis("off")
            ax[row][col].title.set_text(title)
    fig.savefig(prefixInter + prefixSave + "-" + str(epoch) + ".png")
    plt.close(fig)

In [6]:
# Saves the intermediate loss plots
# Smooths curves by averaging non-overlapping windows of 10 logged values

def _smooth(values, window=10):
    """Average `values` over consecutive, non-overlapping windows of `window` entries (trailing remainder dropped)."""
    return [sum(chunk) / window for chunk in zip(*([iter(values)] * window))]


def saveErrorPlots(Generators_Loss_List, Discriminator_M_Loss_List, Discriminator_C_Loss_List, adversarial_loss_list , cycle_loss_list, ssim_loss_list, Discriminator_C_FakeDecision_List, SSIMtrack, epoch, i):
    """Save a 4-panel figure of the training curves to ``{prefixInter}{prefixSave}-{epoch}-Errors.png``.

    Panels, top to bottom:
      0. generator losses (total, adversarial, cycle, SSIM), smoothed over 10 values;
      1. raw D_M loss;
      2. smoothed D_M and D_C losses;
      3. raw SSIM track.

    `Discriminator_C_FakeDecision_List` and `i` are accepted for call compatibility but are not used.
    """
    generatorCurves = [
        ("G_Total", Generators_Loss_List),
        ("Adversarial_loss", adversarial_loss_list),
        ("Cycle_loss", cycle_loss_list),
        ("SSIM_loss", ssim_loss_list),
    ]
    discriminatorCurves = [
        ("D_M", Discriminator_M_Loss_List),
        ("D_C", Discriminator_C_Loss_List),
    ]

    fig, ax = plt.subplots(4, 1)

    ax[0].set_title('G Losses', size = 24)
    for name, values in generatorCurves:
        ax[0].plot(_smooth(values), label=name)
    ax[0].legend(loc="upper right")
    ax[0].set_ylim([0, 3])

    ax[1].set_title('D_M Loss', size = 24)
    ax[1].plot(Discriminator_M_Loss_List, label='D_M Loss')
    ax[1].set_ylim([0, 1])

    # Fixed: this title used to be set on ax[0], overwriting 'G Losses'.
    ax[2].set_title('D_C Loss', size = 24)
    for name, values in discriminatorCurves:
        ax[2].plot(_smooth(values), label=name)
    ax[2].legend(loc="upper right")
    ax[2].set_ylim([0, 1])

    ax[3].set_title('SSIM score', size = 24)
    ax[3].plot(SSIMtrack, label="SSIM")
    ax[3].set_ylim([0, 1])
    ax[3].legend(loc="upper right")

    fig.set_size_inches(30, 20)
    fig.savefig(prefixInter + prefixSave + "-" + str(epoch) + '-Errors.png')
    plt.close(fig)   # close this figure explicitly; figures accumulate during long runs otherwise

In [7]:
# Save the Generator that creates colorizations every intermediate period
 
def saveModel(epoch, Generator_MC):
    """Checkpoint the MRI->color generator.

    Writes ``{prefixModel}{prefixSave}{epoch}model.pt`` containing a dict with the single
    key ``'GMC_state_dict'`` (the generator's ``state_dict()``). The reverse generator,
    the discriminators and the optimizer state are not saved.
    """
    checkpoint = {'GMC_state_dict': Generator_MC.state_dict()}
    torch.save(checkpoint, f"{prefixModel}{prefixSave}{epoch}model.pt")

In [8]:
# Soft data augmentation pipeline

# Tutorial: https://imgaug.readthedocs.io/en/latest/source/examples_basics.html#a-standard-use-case

_softAugmenters = [
    iaa.Fliplr(0.5),                                          # mirror left-right with probability 0.5
    iaa.Crop(percent=(0, 0.1)),                               # crop a random 0-10% from the borders
    iaa.Sometimes(0.5, iaa.GaussianBlur(sigma=(0, 0.5))),     # light blur for half of the images
    iaa.AdditiveGaussianNoise(                                # pixel noise, std up to 5% of 255,
        loc=0, scale=(0.0, 0.05*255), per_channel=0.5),       # sampled per channel for ~half the images
    iaa.Affine(                                               # random zoom / shift / rotation / shear
        scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
        translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
        rotate=(-25, 25),
        shear=(-8, 8)),
]
# random_order=True: the augmenters above are applied in a random order rather than as listed.
seq = iaa.Sequential(_softAugmenters, random_order=True)

In [9]:
# Data Loaders

# Tutorial: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html 

class _DataFrameImageDataset(Dataset):
    """Serve soft-augmented, [-1, 1]-normalized image tensors from the rows of a DataFrame.

    The frame needs a 'fileName' column and a 'file' column with one image array per row.
    The array is passed to the imgaug pipeline as ``file[None, :, :, :]``, so it is 3-D;
    presumably (H, W, C) with values in [0, 255] - TODO confirm.
    """
    def __init__(self, frame):
        self.file_names = frame['fileName'].tolist()
        # Keep direct references to the arrays in row order. The old __getitem__ re-scanned
        # the whole 'fileName' column for every sample (O(n)) and crashed via .item() on
        # duplicate names. It also called df.sample(frac=1) and discarded the result,
        # which was a no-op; shuffling is done by the DataLoader (shuffle=True).
        self.images = frame['file'].tolist()

    def __getitem__(self, index):
        image = seq(images=self.images[index][None, :, :, :])[0]   # soft augmentation pipeline
        image = (image / 127.5) - 1                                # [0, 255] -> [-1, 1]
        # .T turns (H, W, C) into (C, W, H). The test tensors/plots use the same transpose,
        # so the model consistently sees this layout.
        # No requires_grad: nothing reads gradients w.r.t. the input images.
        return torch.tensor(image.astype(float).T).float()

    def __len__(self):
        return len(self.file_names)

class CryoDataset(_DataFrameImageDataset):
    """Color (cryo) training images from the global ``Cryodf``."""
    def __init__(self):
        super(CryoDataset, self).__init__(Cryodf)

class MRIDataset(_DataFrameImageDataset):
    """MRI training images from the global ``MRIdf``."""
    def __init__(self):
        super(MRIDataset, self).__init__(MRIdf)

CryoLoader = torch.utils.data.DataLoader(CryoDataset(), batch_size = batch, shuffle = True)
MRILoader = torch.utils.data.DataLoader(MRIDataset(), batch_size = batch, shuffle = True)

In [10]:
# Creates an image pool of recently generated images

# Citation:
# Yun A. (2019). CycleGAN-pytorch. Kaggle. https://www.kaggle.com/code/pipiking/cyclegan-pytorch/script

class ImagePool():
    """History buffer of generated images used when training the discriminators.

    While the buffer holds fewer than `pool_size` images, every incoming image is stored
    and returned unchanged. After that, each incoming image is either returned as-is
    (probability 0.5), or it replaces a random stored image and that older image is returned.
    With pool_size == 0 the input batch is returned untouched.
    """
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of the same size as `images`, mixing new and buffered samples."""
        if self.pool_size == 0:
            return images
        picked = []
        for single in images.data:
            single = torch.unsqueeze(single, 0)   # keep a leading batch dimension of 1
            if self.num_imgs >= self.pool_size:
                if random.uniform(0, 1) > 0.5:
                    slot = random.randint(0, self.pool_size-1)
                    previous = self.images[slot].clone()
                    self.images[slot] = single
                    picked.append(previous)
                else:
                    picked.append(single)
                continue
            # Buffer not full yet: store and pass through.
            self.num_imgs += 1
            self.images.append(single)
            picked.append(single)
        return Variable(torch.cat(picked, 0))

In [11]:
# Differentiable Augmentations

# Citation: 
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# (Github: https://github.com/mit-han-lab/data-efficient-gans)

def DiffAugment(x, policy='', channels_first=True):
    """Apply the augmentations named in `policy` to a batch using differentiable tensor ops.

    Args:
        x: image batch, (N, C, H, W) when channels_first else (N, H, W, C).
        policy: comma-separated keys of AUGMENT_FNS (e.g. 'color,translation').
            An empty policy returns `x` itself, untouched.
        channels_first: set False for channels-last input. Such input is permuted to
            channels-first for the augmentations and permuted back afterwards.

    Returns:
        A contiguous tensor in the input layout.
    """
    if not policy:
        return x
    if not channels_first:
        x = x.permute(0, 3, 1, 2)
    for name in policy.split(','):
        for augment in AUGMENT_FNS[name]:
            x = augment(x)
    if not channels_first:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()

def rand_brightness(x):
    """Add one random offset in [-0.5, 0.5) per image to every pixel of that image."""
    offset = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5
    return x + offset

def rand_saturation(x):
    """Scale each pixel's deviation from its mean over channels (dim 1) by a per-image factor in [0, 2)."""
    channelMean = x.mean(dim=1, keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
    return (x - channelMean) * factor + channelMean

def rand_contrast(x):
    """Scale each image's deviation from its overall mean (dims 1-3) by a per-image factor in [0.5, 1.5)."""
    imageMean = x.mean(dim=[1, 2, 3], keepdim=True)
    factor = torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
    return (x - imageMean) * factor + imageMean

def rand_translation(x, ratio=0.125):
    """Shift each image by a random integer offset of up to `ratio` of its size, filling with zeros.

    Args:
        x: batch shaped (N, C, H, W).
        ratio: maximum shift as a fraction of H (rows) and of W (columns).

    Returns:
        Tensor shaped like `x`; pixels shifted in from outside the image are 0.
    """
    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' is the historical default. Passing it explicitly silences the torch
    # deprecation warning and keeps the behaviour fixed if the default ever changes.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(x.size(2), dtype=torch.long, device=x.device),
        torch.arange(x.size(3), dtype=torch.long, device=x.device),
        indexing='ij',
    )
    # +1 accounts for the one-pixel zero border added below. Clamping to that border makes
    # out-of-range source positions read zeros.
    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
    return x

def rand_cutout(x, ratio=0.5):
    """Zero out one random rectangle per image (Cutout).

    The rectangle measures round(ratio*H) by round(ratio*W) and is centred at a random
    position. Rows/columns that would fall outside the image are clamped onto the border,
    so rectangles near an edge are truncated.

    Args:
        x: batch shaped (N, C, H, W).
        ratio: rectangle size as a fraction of H and W.

    Returns:
        `x` multiplied by a 0/1 mask that is shared across channels.
    """
    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
    # indexing='ij' matches the historical default and avoids the torch deprecation warning.
    grid_batch, grid_x, grid_y = torch.meshgrid(
        torch.arange(x.size(0), dtype=torch.long, device=x.device),
        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
        indexing='ij',
    )
    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
    mask[grid_batch, grid_x, grid_y] = 0
    x = x * mask.unsqueeze(1)   # broadcast the (N, H, W) mask over channels
    return x

# Color-only augmentations, shared by the 'color' and 'All' policies.
_COLOR_AUGMENTS = [rand_brightness, rand_saturation, rand_contrast]

# Policy name -> augmentation functions, applied in list order by DiffAugment.
AUGMENT_FNS = {
    'All': _COLOR_AUGMENTS + [rand_cutout, rand_translation],
    'color': list(_COLOR_AUGMENTS),
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}

In [12]:
# Generator Architecture

# Written with guidance from:
# Ashwath B. (2020). CycleGAN Translating Apples->Oranges [PyTorch]. Kaggle. https://www.kaggle.com/code/balraj98/cyclegan-translating-apples-oranges-pytorch/notebook
# Yun A. (2019). CycleGAN-pytorch. Kaggle. https://www.kaggle.com/code/pipiking/cyclegan-pytorch/script

# Original Paper:
# Zhu J., Park T., Isola P., and Efros A.. "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks", in IEEE International Conference on Computer Vision (ICCV), 2017.

class ResidualBlock(nn.Module):
    """Residual block used by the generator: ``x + F(x)``.

    F = reflect-pad(1) -> 3x3 conv -> InstanceNorm -> ReLU -> reflect-pad(1) -> 3x3 conv -> InstanceNorm.
    Both the channel count and the spatial size of ``x`` are preserved.

    Args:
        channels: number of input/output channels (default 256, as used by Generator).
    """
    def __init__(self, channels=256):
        super(ResidualBlock, self).__init__()
        # Two independent convolutions. Previously a single `convRes` was applied twice,
        # which silently tied the weights of both convolutions in the block.
        # NOTE: this adds `convRes2.*` to the state_dict; older checkpoints need strict=False.
        self.convRes = nn.Conv2d(channels, channels, 3)
        self.convRes2 = nn.Conv2d(channels, channels, 3)
        self.reflect1 = nn.ReflectionPad2d(1)
        self.norm256 = nn.InstanceNorm2d(channels)   # no affine parameters by default, so reusing it is fine
        self.ReLu = nn.ReLU(inplace=True)
    def block(self, x):
        """Residual branch F(x)."""
        x = self.reflect1(x)
        x = self.ReLu(self.norm256(self.convRes(x)))
        x = self.reflect1(x)
        x = self.norm256(self.convRes2(x))
        return x
    def forward(self, x):
        return x + self.block(x)
  
class Generator(nn.Module):
    """ResNet CycleGAN generator (3-channel image in, 3-channel image in [-1, 1] out).

    Layout: reflect-pad + 7x7 conv (64) -> two stride-2 3x3 convs (128, 256) ->
    `num_blocks` residual blocks at 256 channels -> two x2 upsampling stages (128, 64) ->
    reflect-pad + 7x7 conv (3) -> tanh. When H and W are divisible by 4, the output has
    the same spatial size as the input.

    The globals `num_blocks` and `transpose2D` are read once, at construction.
    """
    def __init__(self):
        super(Generator, self).__init__()
        # Fixed here so forward() always matches the layers that were actually built.
        self.transpose2D = transpose2D

        self.conv1 = nn.Conv2d(3, 64, 7, stride = 1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride = 2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, stride = 2, padding=1)
        # One independent ResidualBlock per stage. Previously a single block instance was
        # applied num_blocks times, so all "blocks" shared one set of weights.
        # NOTE: this changes the state_dict keys (block.* -> blocks.<k>.*).
        self.blocks = nn.ModuleList([ResidualBlock() for _ in range(num_blocks)])
        self.conv6 = nn.Conv2d(64, 3, 7, stride = 1, padding=0)

        self.reflect3 = nn.ReflectionPad2d(3)
        self.ReLu = nn.ReLU(inplace=True)
        self.tanH = nn.Tanh()
        if self.transpose2D:
            self.upsample1 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1)
            self.upsample2 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1)
        else:
            # Upsample-then-convolve path (the convs are only needed here).
            self.upsample = nn.Upsample(scale_factor=2)
            self.conv4 = nn.Conv2d(256, 128, 3, stride = 1, padding=1)
            self.conv5 = nn.Conv2d(128, 64, 3, stride = 1, padding=1)
        self.norm256 = nn.InstanceNorm2d(256)
        self.norm128 = nn.InstanceNorm2d(128)
        self.norm64 = nn.InstanceNorm2d(64)

    def forward(self, x):
        x = self.reflect3(x)
        x = self.ReLu(self.norm64 (self.conv1(x)))   # 64 ch, full resolution
        x = self.ReLu(self.norm128(self.conv2(x)))   # 128 ch, 1/2 resolution
        x = self.ReLu(self.norm256(self.conv3(x)))   # 256 ch, 1/4 resolution

        for resBlock in self.blocks:
            x = resBlock(x)
        if self.transpose2D:
            x = self.ReLu(self.norm128(self.upsample1(x)))             # 128 ch, 1/2 resolution
            x = self.ReLu(self.norm64(self.upsample2(x)))              # 64 ch, full resolution
        else:
            x = self.ReLu(self.norm128(self.conv4(self.upsample(x))))  # 128 ch, 1/2 resolution
            x = self.ReLu(self.norm64(self.conv5(self.upsample(x))))   # 64 ch, full resolution
        x = self.reflect3(x)
        return self.tanH(self.conv6(x))

In [13]:
# Discriminator Architecture

# Written with guidance from:
# Ashwath B. (2020). CycleGAN Translating Apples->Oranges [PyTorch]. Kaggle. https://www.kaggle.com/code/balraj98/cyclegan-translating-apples-oranges-pytorch/notebook
# Yun A. (2019). CycleGAN-pytorch. Kaggle. https://www.kaggle.com/code/pipiking/cyclegan-pytorch/script

# Original Paper:
# Zhu J., Park T., Isola P., and Efros A.. "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks", in IEEE International Conference on Computer Vision (ICCV), 2017.

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a spatial grid of raw (non-sigmoid) real/fake scores.

    Three stride-2 4x4 convolutions (64, 128, 256 channels) and one stride-1 conv (512) are
    each followed by LeakyReLU(0.2); all but the first are instance-normalised. A final
    4x4 conv maps to one channel after padding the left/top by one pixel.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Convolutions, created in the same order as before (same parameter order / init sequence).
        self.conv1 = nn.Conv2d(3, 64, 4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, 4, stride=1, padding=1)
        self.conv5 = nn.Conv2d(512, 1, 4, stride=1, padding=1)

        # Parameter-free helpers.
        self.norm512 = nn.InstanceNorm2d(512)
        self.norm128 = nn.InstanceNorm2d(128)
        self.norm256 = nn.InstanceNorm2d(256)
        self.zeroPad = nn.ZeroPad2d((1, 0, 1, 0))     # one extra pixel on the left and top
        self.ReLu = nn.LeakyReLU(0.2, inplace=True)   # despite the name, a LeakyReLU

    def forward(self, x):
        h = self.ReLu(self.conv1(x))   # first layer is not normalised
        for conv, norm in ((self.conv2, self.norm128),
                           (self.conv3, self.norm256),
                           (self.conv4, self.norm512)):
            h = self.ReLu(norm(conv(h)))
        return self.conv5(self.zeroPad(h))


In [14]:

### Initialize Losses ###
adverLoss = torch.nn.MSELoss().to(device)   # least-squares adversarial loss (target 1 = real, 0 = fake)
cycleLoss = torch.nn.L1Loss().to(device)    # cycle-consistency loss
ssim = SSIM().to(device)                    # was SSIM().cuda(), which crashed when `device` fell back to CPU

### Initialize Models ###
Generator_MC = Generator().to(device)        # MC = MRI->Color
Generator_CM = Generator().to(device)        # CM = Color->MRI
Discriminator_M = Discriminator().to(device) # M = MRI
Discriminator_C = Discriminator().to(device) # C = Color

### Optimization ###
lr = learningRate
optimizer_Generator = torch.optim.Adam(itertools.chain(Generator_MC.parameters(), Generator_CM.parameters()), lr=lr, betas=(b1, b2))
optimizer_Discriminator_M = torch.optim.Adam(Discriminator_M.parameters(), lr=lr, betas=(b1, b2))
optimizer_Discriminator_C = torch.optim.Adam(Discriminator_C.parameters(), lr=lr, betas=(b1, b2))
# Only the generators' learning rate decays; the discriminators keep `lr`.
scheduler = StepLR(optimizer_Generator, step_size=UpdateAfterStep, gamma=UpdateBy)

### Tracking Losses ###
Generators_Loss_List = []
Discriminator_M_Loss_List = []
Discriminator_C_Loss_List = []
adversarial_loss_list, cycle_loss_list, ssim_loss_list, SSIMtrack, Discriminator_C_FakeDecision_List = [], [], [], [], []

### Initialize Image Pool ###
num_pool = 50
fake_MRI_pool = ImagePool(num_pool)
fake_Cryo_pool = ImagePool(num_pool)

### Display Progress ###
trackSave = 0
batches = ((dataPointsCryo - 1) // batch)   # index of the last batch in an epoch, based on the Cryo set size
progressWindow = 2 * batch                  # number of most recent logged values averaged in the display
updateProgress = display(IPython.display.Pretty('Starting'), display_id=True)
updateLR = display(IPython.display.Pretty('Starting'), display_id=True)
updateD = display(IPython.display.Pretty('Starting'), display_id=True)
updateSSIM = display(IPython.display.Pretty('Starting'), display_id=True)

def _recentMean(values, window):
    """Mean of the last `window` entries of `values`; 0.0 before anything has been logged."""
    recent = values[-window:]
    return sum(recent) / len(recent) if recent else 0.0

### Train Model ###
for epoch in range(numEpochs):
    if epoch != 0:
        scheduler.step()   # decay the generator learning rate once per epoch

    for i, (cryoImage, MRIImage) in enumerate(zip(CryoLoader, MRILoader)):
        # Update progress display
        trackSave += 1
        updateProgress.update(IPython.display.Pretty(str(epoch) + " | " + str(i) + "/" + str(batches)))
        updateLR.update(IPython.display.Pretty("Generator LR: " + str(optimizer_Generator.param_groups[0]['lr'])))
        updateD.update(IPython.display.Pretty("D_fake cryo: " + str(round(_recentMean(Discriminator_C_FakeDecision_List, progressWindow), 2))))
        updateSSIM.update(IPython.display.Pretty("SSIM:" + str(round(_recentMean(SSIMtrack, progressWindow), 2))))

        # Data to device
        Real_C = cryoImage.to(device)
        Real_M = MRIImage.to(device)

        if i % 2 == 0:  # Generators are updated every other batch
            # G_MC: MRI -> color
            Gen_C = Generator_MC(Real_M)
            Gen_C_Aug = torch.clip(DiffAugment(Gen_C, policy=augmentPolicy), min=-1, max=1)
            Discriminator_C_FakeDecision = Discriminator_C(Gen_C_Aug)
            # Mean D_C score on the generated color images (tracked for the display/plots)
            Discriminator_C_FakeDecision_List.append(Discriminator_C_FakeDecision.detach().mean().item())
            Generator_MC_Loss = adverLoss(Discriminator_C_FakeDecision, torch.ones_like(Discriminator_C_FakeDecision))
            Reconstructed_M = Generator_CM(Gen_C)
            Forward_Cycle_Loss = cycleLoss(Reconstructed_M, Real_M)
            Gen_C_RGB = (Gen_C + 1) / 2          # [-1, 1] -> [0, 1], the range expected by SSIM
            Real_M_RGB = (Real_M + 1) / 2
            ssim_C = ssim(Real_M_RGB, Gen_C_RGB)  # computed once; used for both the loss and tracking
            ssim_out_C = 1 - ssim_C
            SSIMtrack.append(ssim_C.detach().item())

            # G_CM: color -> MRI
            Gen_M = Generator_CM(Real_C)
            Gen_M_Aug = torch.clip(DiffAugment(Gen_M, policy=augmentPolicy), min=-1, max=1)
            Discriminator_M_FakeDecision = Discriminator_M(Gen_M_Aug)
            Generator_CM_Loss = adverLoss(Discriminator_M_FakeDecision, torch.ones_like(Discriminator_M_FakeDecision))
            Reconstructed_C = Generator_MC(Gen_M)
            Backward_Cycle_Loss = cycleLoss(Reconstructed_C, Real_C)
            Gen_M_RGB = (Gen_M + 1) / 2
            Real_C_RGB = (Real_C + 1) / 2
            ssim_out_M = 1 - ssim(Real_C_RGB, Gen_M_RGB)

            # G optimization
            ssim_out = (ssim_out_M + ssim_out_C) * 0.5
            adversarial_loss_list.append(lambdaAdversarial * (Generator_MC_Loss.detach() + Generator_CM_Loss.detach()).item())
            cycle_loss_list.append(lambdaCycle * (Forward_Cycle_Loss.detach() + Backward_Cycle_Loss.detach()).item())
            ssim_loss_list.append(lambdaSSIM * ssim_out.detach().item())
            Generators_Loss = lambdaSSIM * ssim_out + lambdaAdversarial * (Generator_MC_Loss + Generator_CM_Loss) + lambdaCycle * (Forward_Cycle_Loss + Backward_Cycle_Loss)
            Generators_Loss_List.append(Generators_Loss.item())
            optimizer_Generator.zero_grad()
            Generators_Loss.backward()
            optimizer_Generator.step()

        # D_M (every batch; on odd batches it sees the fakes generated at the previous batch)
        # Inputs are detached and need no gradients: only D's parameters are trained here.
        Real_M_Aug = DiffAugment(Real_M.detach(), policy=augmentPolicy)
        Discriminator_M_RealDecision = Discriminator_M(Real_M_Aug)
        Discriminator_M_RealLoss = adverLoss(Discriminator_M_RealDecision, torch.ones_like(Discriminator_M_RealDecision))
        Gen_M_Pooled = fake_MRI_pool.query(Gen_M.detach())
        Gen_M_Pooled_Aug = DiffAugment(Gen_M_Pooled, policy=augmentPolicy)
        Discriminator_M_FakeDecision = Discriminator_M(Gen_M_Pooled_Aug)
        Discriminator_M_FakeLoss = adverLoss(Discriminator_M_FakeDecision, torch.zeros_like(Discriminator_M_FakeDecision))
        # D_M optimization
        Discriminator_M_Loss = (Discriminator_M_RealLoss + Discriminator_M_FakeLoss) * 0.5
        Discriminator_M_Loss_List.append(Discriminator_M_Loss.item())
        optimizer_Discriminator_M.zero_grad()
        Discriminator_M_Loss.backward()
        optimizer_Discriminator_M.step()

        # D_C
        Real_C_Aug = DiffAugment(Real_C.detach(), policy=augmentPolicy)
        Discriminator_C_RealDecision = Discriminator_C(Real_C_Aug)
        Discriminator_C_RealLoss = adverLoss(Discriminator_C_RealDecision, torch.ones_like(Discriminator_C_RealDecision))
        Gen_C_Pooled = fake_Cryo_pool.query(Gen_C.detach())
        Gen_C_Pooled_Aug = DiffAugment(Gen_C_Pooled, policy=augmentPolicy)
        Discriminator_C_FakeDecision = Discriminator_C(Gen_C_Pooled_Aug)
        Discriminator_C_FakeLoss = adverLoss(Discriminator_C_FakeDecision, torch.zeros_like(Discriminator_C_FakeDecision))
        # D_C optimization
        Discriminator_C_Loss = (Discriminator_C_RealLoss + Discriminator_C_FakeLoss) * 0.5
        Discriminator_C_Loss_List.append(Discriminator_C_Loss.item())
        optimizer_Discriminator_C.zero_grad()
        Discriminator_C_Loss.backward()
        optimizer_Discriminator_C.step()

        # Intermediate tracking/plotting
        if trackSave >= saveAfter:
            trackSave = 0
            genList, reconList = [], []
            with torch.no_grad():   # inference only: don't build autograd graphs
                for tensor in makeTestTensors():
                    outGenerated = Generator_MC(tensor.to(device))
                    # Reconstruct from the unclipped colorization
                    outRecon = Generator_CM(outGenerated)[0].cpu().numpy().T
                    outGenerated = torch.clip(outGenerated[0], -1, 1)
                    genList.append(((outGenerated.cpu().numpy().T + 1) * 127.5).astype(int))
                    reconList.append(((outRecon + 1) * 127.5).astype(int))
            # Save intermediate results
            saveTestImages(genList, reconList, epoch, i)
            saveErrorPlots(Generators_Loss_List, Discriminator_M_Loss_List, Discriminator_C_Loss_List, adversarial_loss_list, cycle_loss_list, ssim_loss_list, Discriminator_C_FakeDecision_List, SSIMtrack, epoch, i)
            saveModel(epoch, Generator_MC)

0 | 189/449

Generator LR: 0.0002

D_fake cryo: 0.49

SSIM:0.7

KeyboardInterrupt: ignored