# **PART 1: CONTRASTIVE LEARNING USING CLIP**
In this section, we train a CLIP-style model with contrastive learning on our paired CT-MR images in order to obtain embeddings with a better semantic representation. The learned encoder weights are later used as the starting point for image-to-image translation between the bulk of unpaired CT and MR images.

Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torchvision.models import *
from torchvision.datasets import ImageFolder
from torchvision import transforms, utils
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
from itertools import product
from fastai.vision import *
from fastai.vision.models import *
import glob
import argparse
import itertools
import sys
import os
import random
import time
import datetime

In [2]:
# torch.manual_seed(42)

# Settings for the contrastive (CLIP) stage.
BATCH_SIZE=32
NUMBER_EPOCHS=5
IMG_SIZE=256
# Fall back to the CPU when no GPU is visible instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Defining the generator that will be later used in the Attention-based CycleGAN

In [3]:
class ResnetBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip connection.

    Args:
        dim: number of input and output channels (the block preserves shape).
        norm_layer: callable ``norm_layer(dim)`` returning a normalisation module.
        use_dropout: insert ``nn.Dropout(0.5)`` between the two convolutions.
    """

    def __init__(self, dim, norm_layer, use_dropout):
        super().__init__()
        self.conv_block = self.build_conv_block(dim, norm_layer, use_dropout)

    def build_conv_block(self, dim, norm_layer, use_dropout):
        """Return the residual branch as an ``nn.Sequential``."""

        def padded_conv():
            # Reflection padding of 1 keeps the spatial size through the 3x3 conv.
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(dim, dim, kernel_size=3),
                    norm_layer(dim)]

        layers = padded_conv()
        layers.append(nn.ReLU(True))
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.extend(padded_conv())
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        residual = self.conv_block(x)
        return x + residual

In [4]:
class Encoder(nn.Module):
    """Attention-guided ResNet generator used as the image encoder of the CLIP model.

    The network is a 7x7 stem convolution, two stride-2 downsampling
    convolutions, ``n_blocks`` residual blocks, two PixelShuffle upsampling
    stages and a 7x7 output convolution producing ``output_nc`` channels.
    Channel 0 of that output becomes an attention mask (after a sigmoid); the
    remaining ``output_nc - 1`` channels are the content image.

    Args:
        input_nc: channels of the input image.
        output_nc: 1 attention channel + content channels; ``output_nc - 1``
            must equal ``input_nc`` so content and input can be blended.
        ngf: number of filters after the stem convolution.
        norm_layer: callable ``norm_layer(channels)`` returning a norm module.
        use_dropout: forwarded to every ``ResnetBlock``.
        n_blocks: number of residual blocks (must be >= 0).
    """

    def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6):
        assert(n_blocks >= 0)
        super(Encoder, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # NOTE: the order of the layers fixes the state_dict keys ("model.<idx>..."),
        # which are later copied into Generator instances, so it must not change.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: each stage halves the resolution and doubles the channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout)]

        # Upsampling: a 3x3 conv halves the channels, a 1x1 conv expands them x4 and
        # PixelShuffle(2) trades those channels for a 2x larger feature map.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            half = ngf * mult // 2
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, half,
                                kernel_size=3, stride=1),
                      norm_layer(half),
                      nn.ReLU(True),
                      nn.Conv2d(half, half * 4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(half),
                      nn.ReLU(True),
                     ]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Return the attention-blended image ``content * mask + input * (1 - mask)``."""
        output = self.model(input)
        attention_mask = torch.sigmoid(output[:, :1])
        content_mask = output[:, 1:]
        # Repeat the single-channel mask once per content channel instead of a
        # hard-coded 3, so non-RGB configurations blend channel-for-channel.
        attention_mask = attention_mask.repeat(1, content_mask.size(1), 1, 1)
        result = content_mask * attention_mask + input * (1 - attention_mask)

        return result

Defining our customized CLIP network

In [5]:
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the pointwise convolution.
        kernel_size, stride, padding: geometry of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        # groups == in_channels gives every input channel its own spatial filter.
        depthwise_kwargs = dict(kernel_size=kernel_size, stride=stride,
                                padding=padding, groups=in_channels)
        self.depthwise_conv = nn.Conv2d(in_channels, in_channels, **depthwise_kwargs)
        # The 1x1 convolution mixes the per-channel responses into out_channels maps.
        self.pointwise_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise_conv(self.depthwise_conv(x))

In [6]:
class CLIP(nn.Module):
    """Two-branch image/image contrastive model.

    Each branch runs its own encoder, then two DepthwiseSeparableConv + ReLU +
    max-pool stages and a linear projection to a 1024-d embedding.

    Args:
        image_encoder_1: module applied to ``images_1`` (first modality).
        image_encoder_2: module applied to ``images_2`` (second modality).
    """

    def __init__(self, image_encoder_1, image_encoder_2):
        super(CLIP, self).__init__()
        self.image_encoder_1 = image_encoder_1
        self.image_encoder_2 = image_encoder_2

        self.conv1_1 = DepthwiseSeparableConv(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv1_2 = DepthwiseSeparableConv(3, 32, kernel_size=3, stride=2, padding=1)

        self.conv2_1 = DepthwiseSeparableConv(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv2_2 = DepthwiseSeparableConv(32, 64, kernel_size=3, stride=2, padding=1)

        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # 4096 = 64 channels x 8 x 8: the four stride-2 stages of a branch reduce a
        # 128x128 encoder output by a factor of 16 in each spatial dimension.
        self.fc_1 = nn.Linear(4096, 1024)
        self.fc_2 = nn.Linear(4096, 1024)

    def forward(self, images_1, images_2):
        """Embed both inputs.

        Args:
            images_1: batch fed to ``image_encoder_1``.
            images_2: batch fed to ``image_encoder_2``.

        Returns:
            Tuple ``(image_embeddings_1, image_embeddings_2)`` with 1024 features per row.
        """
        features_1 = self.image_encoder_1(images_1)
        features_1 = self.pool_1(F.relu(self.conv1_1(features_1)))
        features_1 = self.pool_1(F.relu(self.conv2_1(features_1)))
        features_1 = features_1.view(-1, 4096)
        image_embeddings_1 = F.relu(self.fc_1(features_1))

        # The second branch must use its own encoder on its own input; otherwise
        # images_2 is ignored and image_encoder_2 never receives a gradient.
        features_2 = self.image_encoder_2(images_2)
        features_2 = self.pool_2(F.relu(self.conv1_2(features_2)))
        features_2 = self.pool_2(F.relu(self.conv2_2(features_2)))
        features_2 = features_2.view(-1, 4096)
        image_embeddings_2 = F.relu(self.fc_2(features_2))

        return image_embeddings_1, image_embeddings_2

Helper functions

In [7]:
def imshow(img, text=None, should_save=False):
    """Display an image tensor with matplotlib.

    Args:
        img: tensor whose first axis is moved last before plotting
            (i.e. channel-first layout is expected).
        text: optional caption drawn at data coordinates (75, 8).
        should_save: accepted for backward compatibility; currently unused.
    """
    # detach()/cpu() make this work for tensors that track gradients or live on the GPU.
    npimg = img.detach().cpu().numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white', 'alpha':0.8, 'pad':10})
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_plot(iteration,loss):
    """Draw ``loss`` against ``iteration`` as a line plot and show the figure."""
    x_values, y_values = iteration, loss
    plt.plot(x_values, y_values)
    plt.show()

Constructing the data loader

In [8]:
class PairedImageDataset(Dataset):
    """All (image from ``root_dir_1``, image from ``root_dir_2``) combinations with a match label.

    Args:
        root_dir_1: directory with the first-modality images (file names containing 'CT').
        root_dir_2: directory with the second-modality images (file names containing 'MR').
        transform: optional callable applied to each PIL image separately.

    Each item is ``(image_1, image_2, label)`` where ``label`` is a 0-dim tensor
    holding 1 when the two file names are equal after removing 'CT' / 'MR',
    and 0 otherwise.
    """

    def __init__(self, root_dir_1, root_dir_2, transform=None):
        self.root_dir_1 = root_dir_1
        self.root_dir_2 = root_dir_2
        self.transform = transform
        self.image_list_1 = sorted(os.listdir(root_dir_1))
        self.image_list_2 = sorted(os.listdir(root_dir_2))
        # Cartesian product: len(list_1) * len(list_2) pairs, most of them negatives.
        self.all_pairs = list(product(self.image_list_1, self.image_list_2))

    def __len__(self):
        return len(self.all_pairs)

    def __getitem__(self, idx):
        img_name_1, img_name_2 = self.all_pairs[idx]
        img_path_1 = os.path.join(self.root_dir_1, img_name_1)
        img_path_2 = os.path.join(self.root_dir_2, img_name_2)

        # convert() returns a fully loaded copy, so the files can be closed right away.
        with Image.open(img_path_1) as raw_1:
            image_1 = raw_1.convert("RGB")
        with Image.open(img_path_2) as raw_2:
            image_2 = raw_2.convert("RGB")

        if self.transform:
            # Mirror both images together (one coin toss per pair) so that a
            # matching pair stays spatially aligned.
            if random.random() > 0.5:
                image_1 = transforms.functional.hflip(image_1)
                image_2 = transforms.functional.hflip(image_2)
            image_1 = self.transform(image_1)
            image_2 = self.transform(image_2)

        label = 1 if img_name_1.replace('CT','') == img_name_2.replace('MR','') else 0

        return image_1, image_2, torch.tensor(label)

In [9]:
# Deterministic preprocessing. Random flips are deliberately not part of this
# pipeline: it is applied to the two images of a pair separately (so a random
# flip could mirror one image and not the other), and the same `transform` is
# reused when embedding images for retrieval.
transform = transforms.Compose([
    # transforms.CenterCrop(200),
    # 128x128 inputs are what the 4096-wide linear layers of CLIP are sized for.
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

root_dir_1 = "/content/drive/MyDrive/image_CT"
root_dir_2 = "/content/drive/MyDrive/image_MR"

paired_dataset = PairedImageDataset(root_dir_1, root_dir_2, transform=transform)
data_loader = DataLoader(paired_dataset, batch_size=BATCH_SIZE, shuffle=True)

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/image_CT'

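A focal loss is defined here to address the class imbalance of the pair labels (far more non-matching than matching pairs). Note that the training cell below currently uses `nn.CosineEmbeddingLoss`; `FocalLoss` is kept as an alternative criterion.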

In [10]:
class FocalLoss(nn.Module):
    """Focal loss on top of cross-entropy: ``alpha * (1 - p_t) ** gamma * CE``.

    Args:
        alpha: constant scale applied to every sample's loss.
        gamma: focusing exponent; larger values down-weight well-classified samples.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against class indices ``targets``.

        Raises:
            ValueError: if ``self.reduction`` is not one of the supported options.
        """
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')

        # exp(-CE) recovers the probability the model assigned to the true class.
        prob_true = torch.exp(-per_sample_ce)
        modulating = (1 - prob_true) ** self.gamma
        per_sample = self.alpha * modulating * per_sample_ce

        if self.reduction == 'none':
            return per_sample
        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        raise ValueError("Invalid reduction option. Use 'mean', 'sum', or 'none'.")

Model training (no validation needed)

In [45]:
# One encoder per modality; both start from freshly initialised weights.
image_encoder_1 = Encoder() #ct
image_encoder_2 = Encoder() #mr

clip_model = CLIP(image_encoder_1, image_encoder_2).to(device)
# Active criterion is CosineEmbeddingLoss; the trailing comments list the alternatives that were tried.
criterion = nn.CosineEmbeddingLoss() #nn.CrossEntropyLoss() #FocalLoss(alpha=1, gamma=2, reduction='mean')
optimizer = optim.Adam(clip_model.parameters(), lr=3e-5)

# Bookkeeping for the training loop (loss_history is appended to once per batch).
counter = []
loss_history = []
iteration_number= 0
# Re-states the value already assigned in the configuration cell.
NUMBER_EPOCHS = 5

In [8]:
# Contrastive training of the two-branch CLIP model on (CT, MR, label) batches.
clip_model.train()

# nn.CosineEmbeddingLoss only recognises targets of 1 (similar) and -1 (dissimilar).
# The dataset labels non-matching pairs with 0, which that loss would silently
# ignore, so the labels are remapped {0, 1} -> {-1, 1} when it is the criterion.
uses_cosine_targets = isinstance(criterion, nn.CosineEmbeddingLoss)

for epoch in range(NUMBER_EPOCHS):
    print(f'Epoch {epoch + 1} starts.')
    for i, data in enumerate(data_loader, 0):
        img_ct, img_mr , labels = data
        img_ct, img_mr , labels = img_ct.to(device), img_mr.to(device) , labels.to(device)

        image_embedding, text_embedding = clip_model(img_ct, img_mr)

        targets = labels * 2 - 1 if uses_cosine_targets else labels
        loss = criterion(image_embedding, text_embedding, targets)

        optimizer.zero_grad()
        loss.backward()
        loss_history.append(loss.item())
        optimizer.step()

    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch + 1}/{NUMBER_EPOCHS}, Loss: {loss.item()}")

Epoch 1 starts.


NameError: name 'data_loader' is not defined

In [None]:
# Persist the trained CLIP weights together with the optimizer state in one file.
model_path = "/content/drive/MyDrive/dl_project/CLIP.pth"

torch.save({
    # read back later through the 'model_state_dict' key
    'model_state_dict': clip_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, model_path)

In [9]:
# Restore the CLIP weights from a local checkpoint file.
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
model_path = "CLIP.pth"
# map_location lets a checkpoint saved on a GPU be loaded on whatever `device` is in use.
clip_model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])

<All keys matched successfully>

In [10]:
# Retrieval sanity check: embed every MR image with the second CLIP branch, embed one
# CT query with the first branch, and report the MR image with the highest cosine similarity.
image_folder = '/content/drive/MyDrive/medvqa/juh_mr_ct/mr'
# Sorted so that an index always refers to the same file (os.listdir order is arbitrary).
image_paths = [f"{image_folder}/{img}" for img in sorted(os.listdir(image_folder))]

# Deterministic preprocessing for inference (no random augmentation).
eval_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

clip_model.eval()  # use BatchNorm running statistics rather than single-image batch statistics
with torch.no_grad():  # no autograd graphs are needed for retrieval
    image_embeddings = []
    for image_path in image_paths:
        image = eval_transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
        # forward() needs two image batches; only the second-branch embedding is kept here.
        image_embeddings.append(clip_model(image, image)[1])
    image_embeddings = torch.cat(image_embeddings)  # (num_images, 1024)

    query_image_path = "/content/drive/MyDrive/medvqa/juh_mr_ct/ct/10_CT_s1.png"
    query_image = eval_transform(Image.open(query_image_path).convert("RGB")).unsqueeze(0).to(device)
    # Pass a real image as the second argument as well; only the first-branch output is used.
    query_embedding, _ = clip_model(query_image, query_image)

# The (1, 1024) query broadcasts against all rows: one similarity score per MR image.
similarity_scores = nn.functional.cosine_similarity(query_embedding, image_embeddings, dim=1)

most_similar_index = torch.argmax(similarity_scores).item()
most_similar_image_path = image_paths[most_similar_index]

print(f"Most similar image: {most_similar_image_path}")

FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/medvqa/juh_mr_ct/mr'

# **PART 2: ATTENTIONGAN TO MAP UNPAIRED DOMAINS**
We use AttentionGAN-v1 to map unpaired images between two separate domains, i.e., CT -> MR and MR -> CT. We initialize the generators of both mappings with the encoder weights learned by the CLIP network.


Defining utilities

In [11]:
def print_network(net):
    """Print the module structure of ``net`` followed by its total parameter count."""
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)

In [12]:
def tensor2image(tensor):
    """Convert the first item of a batch scaled to [-1, 1] into a uint8 array.

    A single-channel image is tiled to three channels.
    """
    first = tensor[0].cpu().float().numpy()
    # Map [-1, 1] to [0, 255].
    image = 127.5 * (first + 1.0)
    is_single_channel = image.shape[0] == 1
    if is_single_channel:
        image = np.tile(image, (3, 1, 1))
    return image.astype(np.uint8)

In [13]:
class ReplayBuffer():
    """Fixed-size pool of previously generated samples.

    While the pool is not full every pushed sample is stored and returned
    unchanged. Once full, each pushed sample is, with probability 0.5, swapped
    with a random stored sample (the stored one is returned); otherwise it is
    returned directly and the pool is left untouched.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch and return a batch of the same size drawn as described above."""
        batch_out = []
        for sample in data.data:
            sample = torch.unsqueeze(sample, 0)
            if len(self.data) < self.max_size:
                # Still filling the pool: keep the sample and pass it through.
                self.data.append(sample)
                batch_out.append(sample)
                continue
            if random.uniform(0,1) > 0.5:
                slot = random.randint(0, self.max_size-1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)
        return Variable(torch.cat(batch_out))

In [14]:
class LambdaLR():
    """Learning-rate multiplier: 1.0 until ``decay_start_epoch``, then linear decay to 0 at ``n_epochs``.

    Args:
        n_epochs: total number of training epochs.
        offset: value added to the epoch passed to ``step``.
        decay_start_epoch: epoch at which the linear decay begins.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplier for ``epoch``."""
        decay_span = self.n_epochs - self.decay_start_epoch
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - epochs_into_decay/decay_span

In [15]:
def weights_init_normal(m):
    """Initialise a module in place; intended for use with ``net.apply(...)``.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm2d' get weights drawn from
    N(1, 0.02) and a zero bias. All other modules are left untouched.
    """
    classname = m.__class__.__name__
    # The trailing-underscore initialisers are the supported in-place API; the
    # un-suffixed names are deprecated aliases that emit a warning on every call.
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [16]:
class ImageDataset(Dataset):
    """CT/MR image dataset read from ``<root>/CT`` and ``<root>/MR``.

    Args:
        root: directory containing the ``CT`` and ``MR`` sub-directories.
        transforms_: list of transforms, composed and applied to every image.
        unaligned: if True, the MR image of an item is drawn at random;
            otherwise the MR image at the same index is used.
        mode: accepted for backward compatibility; currently unused.

    Each item is a dict ``{'A': <CT image>, 'B': <MR image>}``.
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
        self.transform = transforms.Compose(transforms_)
        self.unaligned = unaligned

        # Sort the listings: os.listdir order is arbitrary, and aligned mode
        # pairs files purely by their position in these lists.
        self.files_A = sorted(os.listdir(os.path.join(root, 'CT')))
        self.files_A = [os.path.join(root, 'CT', f) for f in self.files_A]
        self.files_B = sorted(os.listdir(os.path.join(root, 'MR')))
        self.files_B = [os.path.join(root, 'MR', f) for f in self.files_B]

    def _load(self, path):
        """Open ``path``, apply the transform pipeline and close the file."""
        with Image.open(path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        # Indices wrap around, so the shorter list is cycled.
        item_A = self._load(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            item_B = self._load(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            item_B = self._load(self.files_B[index % len(self.files_B)])

        return {'A': item_A, 'B': item_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [17]:
class Generator(nn.Module):
    """Attention-guided ResNet generator.

    The network is a 7x7 stem convolution, two stride-2 downsampling
    convolutions, ``n_blocks`` residual blocks, two PixelShuffle upsampling
    stages and a 7x7 output convolution producing ``output_nc`` channels.
    Channel 0 of that output becomes an attention mask (after a sigmoid); the
    remaining ``output_nc - 1`` channels are the content image.

    Args:
        input_nc: channels of the input image.
        output_nc: 1 attention channel + content channels; ``output_nc - 1``
            must equal ``input_nc`` so content and input can be blended.
        ngf: number of filters after the stem convolution.
        norm_layer: callable ``norm_layer(channels)`` returning a norm module.
        use_dropout: forwarded to every ``ResnetBlock``.
        n_blocks: number of residual blocks (must be >= 0).
    """

    def __init__(self, input_nc=3, output_nc=4, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6):
        assert(n_blocks >= 0)
        super(Generator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf

        # NOTE: the order of the layers fixes the state_dict keys ("model.<idx>..."),
        # which must stay compatible with the Encoder weights loaded into this class.
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Downsampling: each stage halves the resolution and doubles the channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, use_dropout=use_dropout)]

        # Upsampling: a 3x3 conv halves the channels, a 1x1 conv expands them x4 and
        # PixelShuffle(2) trades those channels for a 2x larger feature map.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            half = ngf * mult // 2
            model += [nn.ReflectionPad2d(1),
                      nn.Conv2d(ngf * mult, half,
                                kernel_size=3, stride=1),
                      norm_layer(half),
                      nn.ReLU(True),
                      nn.Conv2d(half, half * 4,
                                kernel_size=1, stride=1),
                      nn.PixelShuffle(2),
                      norm_layer(half),
                      nn.ReLU(True),
                     ]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Translate ``input``.

        Returns:
            Tuple ``(result, attention_mask, content_mask)`` where
            ``result = content_mask * attention_mask + input * (1 - attention_mask)``
            and ``attention_mask`` has been repeated to the content channel count.
        """
        output = self.model(input)
        attention_mask = torch.sigmoid(output[:, :1])
        content_mask = output[:, 1:]
        # Repeat the single-channel mask once per content channel instead of a
        # hard-coded 3, so non-RGB configurations blend channel-for-channel.
        attention_mask = attention_mask.repeat(1, content_mask.size(1), 1, 1)
        result = content_mask * attention_mask + input * (1 - attention_mask)

        return result, attention_mask, content_mask

In [18]:
class Discriminator(nn.Module):
    """Convolutional critic producing one score per image.

    Four stride-2 4x4 convolutions (each followed by BatchNorm and LeakyReLU),
    one stride-1 4x4 convolution, global average pooling and a final 1x1
    convolution. The output has shape ``(batch, 1, 1, 1)``.
    """

    def __init__(self):
        super().__init__()
        layers = []
        widths = (3, 64, 128, 256, 512)
        # Strided stages: consecutive entries of `widths` are (in, out) channel pairs.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, 4, 2),
                       nn.BatchNorm2d(c_out),
                       nn.LeakyReLU()]
        # Head: one more 4x4 conv, then collapse the spatial dims and score.
        layers += [nn.Conv2d(512, 512, 4),
                   nn.LeakyReLU(),
                   nn.AdaptiveAvgPool2d((1, 1)),
                   nn.Conv2d(512, 1, 1)]
        self.conv_tower = nn.Sequential(*layers)

    def forward(self, img):
        """Return the score tensor for the image batch ``img``."""
        return self.conv_tower(img)

In [19]:
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio, UniversalImageQualityIndex, VisualInformationFidelity

In [20]:
import wandb

In [46]:
# Start a Weights & Biases run; the wandb.log calls in the training loop report to it.
wandb.init(
    # set the wandb project where this run will be logged
    project='deep_learning_project',
    # account/team that owns the project
    entity='thedlproject'
)

0,1
CT Identity Loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
CT Pixel-Wise Reconstruction Loss,▃▅▁▁▃▄▃▃▃▅▂▅▃▅▂▅▄▄▆▃▄█▄▄▄▇▃▆▃█▃▄▄▄▃▃▃▄▆▅
CT Reconstruction Loss,▆▇▃▃▄▄▆▆▅▄▆▅▃▅▄▅▇▄▄▁▃█▃▂▄█▂▂▃▅▃▁▆▂▄▄▄▂█▂
CT to MR Loss,▂▁▂▃▆▃▃▇▃▄▃▅▁▅█▃▃▃▅▄▄▇▅▆█▅▄▄▇▆▅▇▅▃▆▅▄▄▃▂
Discriminator Loss,▄▆▅▄▂▅▆▂▆▃▃▁▅▁▅▄▄▂▄▃▇▁▂█▄▂█▄▂▂▂▅▆▅▁▂▁▁▂▂
Generator Loss,▄▂▃▆▄▃▄▅▃▅▃▆▁▄▅▃▅▄▆▅▅█▄▆▅▇▃▅▇█▅▇▃▄▇▆▅▄▄▃
MR Identity Loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MR Pixel-Wise Reconstruction Loss,▇▄▁▄▂▁▅▂▆▂▄▅▃▄▃▃▂▄▄▃▄▄▄█▅▆▄▅▅▅▇▅█▅▆█▆▅▅▇
MR Reconstruction Loss,▄█▂▆▁▃▅▇▇▄▇▅▃▅▄▄▃▅▆▃▂▄▅▄▇▆▄▆▇█▄▃▆▃▃▂▅▃▇▆
MR to CT Loss,▅▂▄▇▃▃▄▁▂▆▂▆▁▃▂▂▅▃▇▅▆█▄▅▂▇▁▅▆█▃▇▁▄█▇▅▂▃▄

0,1
CT Identity Loss,0.0
CT Pixel-Wise Reconstruction Loss,0.56545
CT Reconstruction Loss,0.55495
CT to MR Loss,0.47967
Discriminator Loss,0.02521
Generator Loss,0.5953
MR Identity Loss,0.0
MR Pixel-Wise Reconstruction Loss,0.44827
MR Reconstruction Loss,0.42146
MR to CT Loss,0.47124


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669731796719135, max=1.0…

In [47]:
# Hyper-parameters of the AttentionGAN stage, kept as plain module-level variables.
parser = argparse.ArgumentParser()  # created but not used; no arguments are parsed
epoch = 0 #starting epoch
n_epochs = 10 #number of epochs of training
batch_size = 1 #size of the batches
dataroot = 'train' # '/content/drive/MyDrive/DL Project/Data/unpaired line 7 /train' #root directory of the dataset
save_name = 'MedAttentionGAN_2_no_clip'
lr = 0.0001 #initial learning rate
decay_epoch = 8 #epoch to start linearly decaying the learning rate to 0
size = 256 #size of the data crop (squared assumed)
input_nc = 3 #number of channels of input data
output_nc = 3 #number of channels of output data
# A real boolean: the former placeholder string 'store_true' was always truthy, which
# disabled the check below and forced the .cuda() calls even without a GPU.
cuda = torch.cuda.is_available() #use GPU computation
n_cpu = 8 #number of cpu threads to use during batch generation
lambda_cycle = 10
lambda_identity = 0
lambda_pixel = 1
lambda_reg = 1e-6

# NOTE(review): with n_epochs = 10 and gan_curriculum = 10 the curriculum period covers
# the whole run, so default_rate is never reached - confirm this is intended.
gan_curriculum = 10  #Strong GAN loss for certain period at the beginning
starting_rate = 0.01 #Set the lambda weight between GAN loss and Recon loss during curriculum period at the beginning. We used the 0.01 weight.
default_rate = 0.5 #Set the lambda weight between GAN loss and Recon loss after curriculum period. We used the 0.5 weight.


if not cuda:
    print("No CUDA!")

In [36]:
# Reload the CLIP weights so the generators below can be initialised from them.
# NOTE: torch.load unpickles the file; only load checkpoints you produced yourself.
model_path = "CLIP.pth" #"/content/drive/MyDrive/DL Project/CLIP.pth"
# map_location lets a checkpoint saved on a GPU be loaded on whatever `device` is in use.
clip_model.load_state_dict(torch.load(model_path, map_location=device)['model_state_dict'])

<All keys matched successfully>

In [48]:
###### Definition of variables ######
# Networks
# A2B translates domain A (CT) to domain B (MR); B2A translates the other way.
netG_A2B = Generator()
netG_B2A = Generator()
# Encoder and Generator build the same `model` layer sequence, so the state_dict
# keys line up and the CLIP encoders can be loaded directly into the generators.
netG_A2B.load_state_dict(clip_model.image_encoder_1.state_dict()) # load from CLIP, ct encoder
netG_B2A.load_state_dict(clip_model.image_encoder_2.state_dict()) # load from CLIP, mr encoder

# One discriminator per domain.
netD_A = Discriminator()
netD_B = Discriminator()

print('---------- Networks initialized -------------')
print_network(netG_A2B)
print_network(netG_B2A)
print_network(netD_A)
print_network(netD_B)
print('-----------------------------------------------')

if cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()
    netD_A.cuda()
    netD_B.cuda()

# The generators keep the weights loaded from CLIP; only the discriminators are re-initialised.
# netG_A2B.apply(weights_init_normal)
# netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)

---------- Networks initialized -------------
Generator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d

  torch.nn.init.normal(m.weight.data, 0.0, 0.02)
  torch.nn.init.normal(m.weight.data, 1.0, 0.02)
  torch.nn.init.constant(m.bias.data, 0.0)


Discriminator(
  (conv_tower): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.01)
    (9): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2))
    (10): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): LeakyReLU(negative_slope=0.01)
    (12): Conv2d(512, 512, kernel_size=(4, 4), stride=(1, 1))
    (13): LeakyReLU(negative_slope=0.01)
    (14): AdaptiveAvgPool2d(output_size=(1, 1))
    (15): Conv2d(512, 1, kernel_size=(1, 1), 

In [49]:
# Losses
# Squared-error adversarial objective on the discriminator scores.
criterion_GAN = torch.nn.MSELoss()
# L1 distances; criterion_identity is also used for the pixel-wise terms in the training loop.
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Metrics
# torchmetrics objects created with default arguments; the training loop calls them on CPU tensors.
ssi =  StructuralSimilarityIndexMeasure()
psnr = PeakSignalNoiseRatio()
uqi = UniversalImageQualityIndex()
vif = VisualInformationFidelity()
# c_ssi = 0
# c_psnr = 0
# c_uqi = 0
# c_vif = 0

In [50]:
# Optimizers & LR schedulers
# One Adam optimizer drives both generators and another drives both discriminators.
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()),
                                lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(itertools.chain(netD_A.parameters(), netD_B.parameters()),
                                lr=lr, betas=(0.5, 0.999))

# The notebook's LambdaLR helper supplies the multiplier: constant until decay_epoch,
# then linearly down to zero at n_epochs. The schedulers are stepped once per epoch.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimizer_D, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step)

In [51]:
# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
# Pre-allocated input buffers; every batch is copied into them in the training loop.
input_A = Tensor(batch_size, input_nc, size, size)
input_B = Tensor(batch_size, output_nc, size, size)
# Targets are shaped (batch_size, 1, 1, 1) to match the Discriminator output exactly.
# A flat (batch_size,) target would be broadcast against it by MSELoss, which
# triggers a size-mismatch warning and gives a wrong loss for batch_size > 1.
target_real = Variable(Tensor(batch_size, 1, 1, 1).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(batch_size, 1, 1, 1).fill_(0.0), requires_grad=False)

# Pools of previously generated images shown to the discriminators.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

In [52]:
from PIL import Image

class GrayscaleToRGB:
    """Transform that returns the 'RGB'-mode version of an image."""

    def __call__(self, img):
        """
        Args:
            img (PIL Image): image to convert (typically grayscale).

        Returns:
            PIL Image: the result of ``img.convert('RGB')``.
        """
        rgb_image = img.convert('RGB')
        return rgb_image

In [53]:
# Dataset loader
# Pipeline: force three channels, random horizontal flip, tensor conversion, then
# per-channel normalisation with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
# NOTE(review): there is no Resize/Crop here although input_A/input_B are pre-allocated
# as size x size - presumably the images on disk already have that size; confirm.
transforms_ = [ GrayscaleToRGB(),
    transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])
                 ]

# unaligned=True: the MR image of every item is drawn at random, independent of the CT index.
dataloader = DataLoader(ImageDataset(dataroot, transforms_=transforms_, unaligned=True),
                        batch_size=batch_size, shuffle=True, num_workers=n_cpu)



In [54]:
###### Training ######
# Create the output directory once, up front. The per-epoch checkpoints are written to its
# parent directory `save_name`, which this call creates as well.
save_path='%s/%s' % (save_name, 'training')
os.makedirs(save_path, exist_ok=True)

for epoch in range(epoch, n_epochs):
    for i, batch in enumerate(dataloader):
        # Set model input (copied into the pre-allocated buffers)
        real_A = Variable(input_A.copy_(batch['A']))
        real_B = Variable(input_B.copy_(batch['B']))

        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed
        same_B, _, _ = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)*lambda_identity
        # G_B2A(A) should equal A if real A is fed
        same_A, _, _ = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)*lambda_identity

        # GAN loss
        fake_B, mask_B, temp_B = netG_A2B(real_A)
        recovered_A, _, _ = netG_B2A(fake_B)
        pred_fake_B = netD_B(fake_B)

        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)
        loss_GAN_A2B = criterion_GAN(pred_fake_B, target_real)
        loss_pix_A = criterion_identity(fake_B, real_A)

        fake_A, mask_A, temp_A = netG_B2A(real_B)
        recovered_B, _, _  = netG_A2B(fake_A)
        pred_fake_A = netD_A(fake_A)

        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)
        loss_GAN_B2A = criterion_GAN(pred_fake_A, target_real)
        loss_pix_B = criterion_identity(fake_A, real_B)

        # Total-variation style penalty: differences between horizontally and
        # vertically neighbouring attention-mask values.
        loss_reg_A = lambda_reg * (
                torch.sum(torch.abs(mask_A[:, :, :, :-1] - mask_A[:, :, :, 1:])) +
                torch.sum(torch.abs(mask_A[:, :, :-1, :] - mask_A[:, :, 1:, :])))

        loss_reg_B = lambda_reg * (
                torch.sum(torch.abs(mask_B[:, :, :, :-1] - mask_B[:, :, :, 1:])) +
                torch.sum(torch.abs(mask_B[:, :, :-1, :] - mask_B[:, :, 1:, :])))

        # Total loss: `rate` balances the adversarial group against the reconstruction group.
        if epoch < gan_curriculum:
            rate = starting_rate
            # print('using curriculum gan')
        else:
            rate = default_rate
            # print('using normal gan')

        loss_adv = (loss_GAN_A2B + loss_GAN_B2A)*0.5 + (loss_reg_A + loss_reg_B)
        loss_recon = (loss_cycle_ABA + loss_cycle_BAB)*lambda_cycle+(loss_pix_B+loss_pix_A)*lambda_pixel
        if lambda_identity > 0:
            # The identity terms only take part in the objective when they are enabled;
            # with lambda_identity == 0 the objective is exactly the one above.
            loss_recon = loss_recon + loss_identity_A + loss_identity_B
        loss_G = loss_adv* (1.-rate) + loss_recon* rate

        loss_G.backward()
        optimizer_G.step()
        ###################################

        optimizer_D.zero_grad()

        # Real loss
        pred_real_A = netD_A(real_A)
        loss_D_real_A = criterion_GAN(pred_real_A, target_real)

        # Fake loss: the discriminator sees a sample from the replay buffer. It is kept in
        # its own variable so that fake_A still refers to the current generator output
        # (matching mask_A / temp_A) when metrics and example images are logged below.
        buffered_fake_A = fake_A_buffer.push_and_pop(fake_A)
        pred_fake_A = netD_A(buffered_fake_A.detach())
        loss_D_fake_A = criterion_GAN(pred_fake_A, target_fake)

        # Real loss
        pred_real_B = netD_B(real_B)
        loss_D_real_B = criterion_GAN(pred_real_B, target_real)

        # Fake loss
        buffered_fake_B = fake_B_buffer.push_and_pop(fake_B)
        pred_fake_B = netD_B(buffered_fake_B.detach())
        loss_D_fake_B = criterion_GAN(pred_fake_B, target_fake)

        # Total loss
        loss_D = (loss_D_real_B + loss_D_fake_B + loss_D_real_A + loss_D_fake_A)*0.25

        loss_D.backward()
        optimizer_D.step()

        if i % 100 == 0:
            print('Epoch [%d/%d], Batch [%d/%d], loss_D: %.4f, loss_G: %.4f' % (epoch+1, n_epochs,i+1, len(dataloader), loss_D.item(), loss_G.item()))
            print('loss_GAN_A2B: %.4f, loss_GAN_B2A: %.4f, loss_cycle_ABA: %.4f, loss_cycle_BAB: %.4f, loss_identity_A: %.4f, loss_identity_B: %.4f, loss_pix_A: %.4f, loss_pix_B: %.4f' % (loss_GAN_A2B.item(),
                loss_GAN_B2A.item(), loss_cycle_ABA.item(), loss_cycle_BAB.item(), loss_identity_A.item(), loss_identity_B.item(), loss_pix_A.item(), loss_pix_B.item()))

        wandb.log(
            {
                'Discriminator Loss': loss_D.item(),
                'Generator Loss': loss_G.item(),
                'CT to MR Loss': loss_GAN_A2B.item(),
                'MR to CT Loss': loss_GAN_B2A.item(),
                'CT Reconstruction Loss': loss_cycle_ABA.item(),
                'MR Reconstruction Loss': loss_cycle_BAB.item(),
                'CT Identity Loss': loss_identity_A.item(),
                'MR Identity Loss': loss_identity_B.item(),
                'CT Pixel-Wise Reconstruction Loss': loss_pix_A.item(),
                'MR Pixel-Wise Reconstruction Loss': loss_pix_B.item(),
            }
        )

        # Move each tensor to the CPU once instead of once per metric.
        real_A_cpu = real_A.data.cpu()
        real_B_cpu = real_B.data.cpu()
        fake_A_cpu = fake_A.data.cpu()
        fake_B_cpu = fake_B.data.cpu()

        # NOTE(review): the "CT to MR" entries compare real_A with fake_A (the output of
        # netG_B2A) and the "MR to CT" entries compare real_B with fake_B (the output of
        # netG_A2B); with unaligned batches these come from different source images -
        # confirm that this pairing is the intended one.
        wandb.log(
            {
                'Structural Similarity Index Measure (SSIM) - CT to MR': ssi(real_A_cpu, fake_A_cpu),
                'Peak Signal-to-Noise Ratio (PSNR) - CT to MR': psnr(real_A_cpu, fake_A_cpu),
                'Universal Quality Index (UQI) - CT to MR': uqi(real_A_cpu, fake_A_cpu),
                'Visual Information Fidelity (VIF) - CT to MR': vif(real_A_cpu, fake_A_cpu),
                'Structural Similarity Index Measure (SSIM) - MR to CT': ssi(real_B_cpu, fake_B_cpu),
                'Peak Signal-to-Noise Ratio (PSNR) - MR to CT': psnr(real_B_cpu, fake_B_cpu),
                'Universal Quality Index (UQI) - MR to CT': uqi(real_B_cpu, fake_B_cpu),
                'Visual Information Fidelity (VIF) - MR to CT': vif(real_B_cpu, fake_B_cpu)
            }
        )

        if i % 100 == 0:
            # Side-by-side strip (concatenated along the width axis); `* 0.5 + 0.5`
            # undoes the [-1, 1] normalisation, the attention masks are already in [0, 1].
            image_array = torch.cat([
                real_A_cpu[0]* 0.5 + 0.5,
                mask_B.data.cpu()[0],
                fake_B_cpu[0]* 0.5+0.5,
                temp_B.data.cpu()[0]* 0.5+0.5,
                real_B_cpu[0]* 0.5 + 0.5,
                mask_A.data.cpu()[0],
                fake_A_cpu[0]* 0.5+0.5,
                temp_A.data.cpu()[0]* 0.5+0.5],
                2)
            wandb.log({
                "examples": wandb.Image(image_array, caption=f"Epoch {epoch} (Batch {i}) Real CT, Masked MR, Fake MR, Temp MR, Real MR, Masked CT, Fake CT, Temp CT")
                })

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D.step()

    # Overwrite the checkpoints at the end of every epoch.
    torch.save(netG_A2B.state_dict(), '%s/%s' % (save_name, 'netG_A2B.pth'))
    torch.save(netG_B2A.state_dict(), '%s/%s' % (save_name, 'netG_B2A.pth'))
    torch.save(netD_A.state_dict(), '%s/%s' % (save_name, 'netD_A.pth'))
    torch.save(netD_B.state_dict(), '%s/%s' % (save_name, 'netD_B.pth'))

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/10], Batch [1/1749], loss_D: 0.3500, loss_G: 0.8336
loss_GAN_A2B: 0.3765, loss_GAN_B2A: 0.9333, loss_cycle_ABA: 0.4317, loss_cycle_BAB: 0.6396, loss_identity_A: 0.0000, loss_identity_B: 0.0000, loss_pix_A: 0.3040, loss_pix_B: 0.4387
Epoch [1/10], Batch [101/1749], loss_D: 0.1091, loss_G: 0.5013
loss_GAN_A2B: 0.5851, loss_GAN_B2A: 0.2504, loss_cycle_ABA: 0.2159, loss_cycle_BAB: 0.3016, loss_identity_A: 0.0000, loss_identity_B: 0.0000, loss_pix_A: 0.1890, loss_pix_B: 0.3437
Epoch [1/10], Batch [201/1749], loss_D: 0.0557, loss_G: 0.7424
loss_GAN_A2B: 0.7384, loss_GAN_B2A: 0.5390, loss_cycle_ABA: 0.3617, loss_cycle_BAB: 0.4324, loss_identity_A: 0.0000, loss_identity_B: 0.0000, loss_pix_A: 0.3631, loss_pix_B: 0.4227
Epoch [1/10], Batch [301/1749], loss_D: 0.0898, loss_G: 0.5898
loss_GAN_A2B: 0.6392, loss_GAN_B2A: 0.3644, loss_cycle_ABA: 0.4206, loss_cycle_BAB: 0.2378, loss_identity_A: 0.0000, loss_identity_B: 0.0000, loss_pix_A: 0.4987, loss_pix_B: 0.4382
Epoch [1/10], Batch [401/1

Testing

In [100]:
# ---- Test-time configuration (plain variables standing in for CLI options) ----
batch_size = 1  # size of the batches
save_name = 'MedAttentionGAN_2_no_clip'  # directory the generator checkpoints are loaded from
dataroot = 'juh_mr_ct'  # root directory of the dataset
input_nc = 3  # number of channels of input data
output_nc = 3  # number of channels of output data
size = 256  # size of the data (squared assumed)
# This used to be the argparse leftover `'store_true',` -- a one-element tuple
# that is always truthy, so the GPU path was taken even without a GPU and the
# warning below could never fire. Use a real boolean derived from the runtime.
cuda = torch.cuda.is_available()  # use GPU computation when a device is present
n_cpu = 8  # number of cpu threads to use during batch generation

# Only reachable if `cuda` is manually overridden to False on a GPU machine.
if torch.cuda.is_available() and not cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

In [101]:
###### Definition of variables ######
# Networks: one generator per translation direction.
# NOTE(review): judging by the image captions used in the test loop, A is CT
# and B is MR (so A2B is CT -> MR) -- confirm against the dataset class.
netG_A2B = Generator()
netG_B2A = Generator()

if cuda:
    netG_A2B.cuda()
    netG_B2A.cuda()

# Load state dicts.
# `map_location` lets checkpoints that were saved from GPU tensors be
# deserialised when the models stay on the CPU; without it torch.load fails
# on a CPU-only machine.
_ckpt_device = torch.device('cuda' if cuda else 'cpu')
netG_A2B.load_state_dict(
    torch.load('%s/%s' % (save_name, 'netG_A2B.pth'), map_location=_ckpt_device))
netG_B2A.load_state_dict(
    torch.load('%s/%s' % (save_name, 'netG_B2A.pth'), map_location=_ckpt_device))

# Set model's test mode (fixes BatchNorm statistics, disables dropout).
netG_A2B.eval()
netG_B2A.eval()

Generator(
  (model): Sequential(
    (0): ReflectionPad2d((3, 3, 3, 3))
    (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1))
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ResnetBlock(
      (conv_block): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace=True)
        (4): ReflectionPad2d((1, 1, 1, 1))
        (5): Conv2d(256, 256, k

In [102]:
# Inputs & targets memory allocation: reusable buffers that each test batch is
# copied into (contents are uninitialised until the first copy).
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
input_A = Tensor(batch_size, input_nc, size, size)
input_B = Tensor(batch_size, output_nc, size, size)

# Dataset loader.
# The test pipeline is deterministic on purpose: the RandomHorizontalFlip that
# used to be here made the logged metrics non-reproducible between runs.
# NOTE(review): if ImageDataset applies the transform to A and B separately, a
# random flip could also mirror only one image of a pair -- confirm.
transforms_ = [
    GrayscaleToRGB(),
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    # Maps each channel from [0, 1] to [-1, 1]; undone later with `* 0.5 + 0.5`.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
dataloader = DataLoader(
    ImageDataset(dataroot, transforms_=transforms_, mode='test'),
    batch_size=batch_size,
    shuffle=False,
    num_workers=n_cpu,
)

In [103]:
# Start the Weights & Biases run that the evaluation loop below logs to.
# `project` selects the W&B project, `entity` the account/team that owns it.
wandb.init(entity='thedlproject', project='deep_learning_project')

VBox(children=(Label(value='20.101 MB of 20.101 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
Peak Signal-to-Noise Ratio (PSNR) - CT to MR,▆▅▇▇▆▆▇▄▅▆▅▃▅▅▄▆▇▃▅▇▅▅▄▆▇▅▆▇▇▅█▆▅▃▃▁█▅▅▄
Peak Signal-to-Noise Ratio (PSNR) - MR to CT,█▂▂▃▁▄▂▆▄▁▂▁▂▅▂▂▃▄▃▅▄▆▅▃▂▂▂▃▂▄▃▂▆▂▃▂▂▅▃▂
Structural Similarity Index Measure (SSIM) - CT to MR,▄▄▃▂▃▃▃▅▂▃▂▃▃▄▆▄▄▂▃▅▅▅▄▃▅▂▁▃▃▅█▃▁▄▅▁█▃▂▄
Structural Similarity Index Measure (SSIM) - MR to CT,▇▃▂▂▁▃▃█▂▃▁▅▂▅▆▃▄▅▄▅▆▆▃▄▄▄▂▃▂▆▇▂▃▃▇▄▇▄▄▄
Universal Quality Index (UQI) - CT to MR,▇▃▂▃▂▆▃▅▃▄▃▇▁▃▃▄▃▅▃▂▁▂▃▁▃█▁▁▃▄▆▂▅█▄▄▄▁▄▃
Universal Quality Index (UQI) - MR to CT,▅▁▅▂▄▇▃▅▁▅▅█▄▃▄▃▅▅▅▇▃▃▃▃▄▃▅▂▅▄▆▄▄▃▄▅▆▃▁▃
Visual Information Fidelity (VIF) - CT to MR,▃▂▂▂▃▇▂▅▂▂▁▅▁▂▄▂▂▅▂▃▅▂▂▁▃▄▁▂▁▂█▃▂▄▂▄▄▃▂▂
Visual Information Fidelity (VIF) - MR to CT,▄▂▂▂▄▄▃▂▂▆▃▂▃▂▃▂▂▂▂█▃▁▂▂▂▃▄▂▂▂▄▃▂▂▁▂▃▂▃▄

0,1
Peak Signal-to-Noise Ratio (PSNR) - CT to MR,14.47996
Peak Signal-to-Noise Ratio (PSNR) - MR to CT,15.17131
Structural Similarity Index Measure (SSIM) - CT to MR,0.38029
Structural Similarity Index Measure (SSIM) - MR to CT,0.40261
Universal Quality Index (UQI) - CT to MR,0.00157
Universal Quality Index (UQI) - MR to CT,0.00097
Visual Information Fidelity (VIF) - CT to MR,0.18996
Visual Information Fidelity (VIF) - MR to CT,0.02047


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669802639322977, max=1.0…

In [104]:
# Run both generators over the test set and log, for every batch, an image
# strip plus SSIM / PSNR / UQI / VIF scores to Weights & Biases.
# Inference only: no_grad() stops autograd from recording a graph for every
# forward pass, which otherwise wastes memory for no benefit.
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        # Set model input (copied into the pre-allocated buffers).
        real_A = input_A.copy_(batch['A'])
        real_B = input_B.copy_(batch['B'])

        # Generate output: each generator returns (final image, mask, temp image).
        fake_B, mask_B, temp_B = netG_A2B(real_A)
        fake_A, mask_A, temp_A = netG_B2A(real_B)
        # Optional de-normalised single-image copies for ad-hoc display:
        # fake_B_1 = 0.5*fake_B.data[0] + 0.5
        # fake_B_2 = 0.5*temp_B.data[0] + 0.5
        # fake_A_1 = 0.5*fake_A.data[0] + 0.5
        # fake_A_2 = 0.5*temp_A.data[0] + 0.5

        # Move everything to the CPU once instead of once per use.
        real_A_cpu, real_B_cpu = real_A.detach().cpu(), real_B.detach().cpu()
        fake_A_cpu, fake_B_cpu = fake_A.detach().cpu(), fake_B.detach().cpu()
        mask_A_cpu, mask_B_cpu = mask_A.detach().cpu(), mask_B.detach().cpu()
        temp_A_cpu, temp_B_cpu = temp_A.detach().cpu(), temp_B.detach().cpu()

        # First sample of the batch, concatenated along dim 2 into one strip.
        # `* 0.5 + 0.5` rescales the images; the masks are logged unscaled.
        image_array = torch.cat(
            [
                real_A_cpu[0] * 0.5 + 0.5,
                mask_B_cpu[0],
                fake_B_cpu[0] * 0.5 + 0.5,
                temp_B_cpu[0] * 0.5 + 0.5,
                real_B_cpu[0] * 0.5 + 0.5,
                mask_A_cpu[0],
                fake_A_cpu[0] * 0.5 + 0.5,
                temp_A_cpu[0] * 0.5 + 0.5,
            ],
            2,
        )
        wandb.log({
            "examples": wandb.Image(image_array, caption="TEST (paired): Real CT, Masked MR, Fake MR, Temp MR, Real MR, Masked CT, Fake CT, Temp CT")
            })

        # NOTE(review): the "CT to MR" entries compare real_A with fake_A, and
        # fake_A is produced from real_B; going by the caption above that is the
        # MR -> CT direction, so the two label groups look swapped. Keys are kept
        # unchanged so the charts stay comparable with the training-loop
        # logging -- confirm and fix both places together.
        wandb.log(
            {
                'Structural Similarity Index Measure (SSIM) - CT to MR': ssi(real_A_cpu, fake_A_cpu),
                'Peak Signal-to-Noise Ratio (PSNR) - CT to MR': psnr(real_A_cpu, fake_A_cpu),
                'Universal Quality Index (UQI) - CT to MR': uqi(real_A_cpu, fake_A_cpu),
                'Visual Information Fidelity (VIF) - CT to MR': vif(real_A_cpu, fake_A_cpu),
                'Structural Similarity Index Measure (SSIM) - MR to CT': ssi(real_B_cpu, fake_B_cpu),
                'Peak Signal-to-Noise Ratio (PSNR) - MR to CT': psnr(real_B_cpu, fake_B_cpu),
                'Universal Quality Index (UQI) - MR to CT': uqi(real_B_cpu, fake_B_cpu),
                'Visual Information Fidelity (VIF) - MR to CT': vif(real_B_cpu, fake_B_cpu),
            }
        )
    # # Show images
    # imshow(real_A.data.cpu()[0]*0.5+0.5, text='REAL CT', should_save=False)
    # imshow(real_B.data.cpu()[0]*0.5+0.5, text='REAL MR', should_save=False)
    # imshow(fake_A_1, text='FAKE CT (WITH MASKING)', should_save=False)
    # imshow(fake_B_1, text='FAKE MR (WITH MASKING)', should_save=False)
    # imshow(fake_A_2, text='FAKE CT (WITHOUT MASKING)', should_save=False)
    # imshow(fake_B_2, text='FAKE MR (WITHOUT MASKING)', should_save=False)
    # imshow(mask_A.data.cpu()[0], text='CT MASK)', should_save=False)
    # imshow(mask_B.data.cpu()[0], text='MR MASK)', should_save=False)

