In [16]:
import torch 
import torch.nn as nn 
import matplotlib.pyplot as plt 
import torch.nn.functional as F 
import torch.optim as optim 
torch.backends.cudnn.benchmark = True
import albumentations as A
from torch.utils.data import Dataset, DataLoader  
import os 
from PIL import Image 
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image
import tqdm
import cv2 
import numpy as np 

from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio

In [2]:
# Root directory of the reference T1 slices. The evaluation cells further down list its
# sub-folders and read the image files inside each one.
t1_actual = '../input/ixi-t1/image slice-T1'
# Root directory of the generated T1 slices. The evaluation cells read
# <t1_generated>/<sub_folder>/<file> using the sub-folder and file names found under
# `t1_actual`, so it must mirror that layout.
t1_generated = '../input/t1-generated-22-must/t1_generated_22_MUST'

# Discriminator

In [3]:
class Discriminator(nn.Module):
    """Conditional convolutional discriminator that scores (input, target) pairs.

    The two images are concatenated along the channel axis and pushed through a
    stack of 4x4 convolutions. The result is a one-channel spatial map of raw
    scores (no sigmoid is applied here).

    Args:
        in_channels: Number of channels of EACH of the two images; the first
            convolution therefore receives ``in_channels * 2`` channels.
        features: Output channel counts of the successive convolution stages.
            The first entry is used by the initial convolution, the remaining
            ones by ``block`` stages. Any sequence is accepted.
    """

    def __init__(self, in_channels = 1, features = (32, 64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        # Work on a private list so the (immutable) default and any caller-owned
        # sequence are never shared or mutated.
        features = list(features)

        # NOTE(review): no activation follows this first convolution, so it
        # composes linearly with the convolution of the next stage - confirm
        # this is intended before changing it (it affects trained weights' meaning).
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            )
        )

        layers = []
        in_channel = features[0]
        last_index = len(features) - 1
        for index, out_channel in enumerate(features[1:], start=1):
            # Only the final stage keeps the resolution (stride 1). The choice is
            # made by position, not by comparing channel counts, so repeated
            # values such as [8, 16, 16] still downsample in the earlier stages.
            stride = 1 if index == last_index else 2
            layers.append(self.block(in_channel, out_channel, stride=stride))
            in_channel = out_channel

        # Final projection to a single-channel score map.
        layers.append(
            nn.Conv2d(
                in_channel, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

        self.model = nn.Sequential(*layers)

    def block(self, in_channel, out_channel, stride = 1):
        """Return a Conv(4x4, reflect padding) -> BatchNorm -> LeakyReLU(0.2) stage."""
        block_layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 4, stride, 1, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2),
        )
        return block_layer

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        Args:
            x: Conditioning image batch with ``in_channels`` channels.
            y: Real or generated image batch with ``in_channels`` channels.

        Returns:
            A tensor of raw scores with a single channel.
        """
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        x = self.model(x)
        return x

# Many to one & one to one 

In [4]:
# Residual block
# 3x3 convolution
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_options)

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive skip connection.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input before it is added to
            the main path. It is needed whenever the main path changes the
            tensor shape (``stride != 1`` or ``in_channels != out_channels``);
            without it the addition in ``forward`` fails on mismatched shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + skip(x))``."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Compare against None explicitly: the truth value of a module is not a
        # reliable "was one supplied" test (containers that define __len__, such
        # as an empty nn.Sequential, evaluate to False).
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out



class Streams(nn.Module):
    """Single synthesis stream: strided encoder -> 9 residual blocks -> decoder.

    The encoder is three stride-2 3x3 convolutions producing 32 channels, the
    residual trunk is split into a first half (5 blocks) and a second half
    (4 blocks), and the decoder is three stride-2 transposed convolutions that
    end in a single output channel. With a 256x256 input the intermediate
    feature maps are 32x32 and the output is 256x256.

    Args:
        input_channel: Number of channels of the input tensor.
    """

    def __init__(self, input_channel = 50):
        super(Streams, self).__init__()

        # Modules are created in the same order as their attribute assignment so
        # parameter initialisation order stays stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channel, 32, kernel_size=3, stride=2, padding=4, bias=False),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0, bias=False),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=0, bias=False),
        )

        self.residual_network_first_half = nn.Sequential(
            *[ResidualBlock(32, 32, 1) for _ in range(5)]
        )
        self.residual_network_second_half = nn.Sequential(
            *[ResidualBlock(32, 32, 1) for _ in range(4)]
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=0, bias=False),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=0, bias=False),
        )

    def forward(self, x, fuse = False):
        """Run the stream.

        Args:
            x: Input batch with ``input_channel`` channels.
            fuse: When true, also return the intermediate feature maps.

        Returns:
            The decoded single-channel output, or, when ``fuse`` is true, the
            tuple ``(output, encoder_out, res_out_first, res_out_second)``.
        """
        encoder_out = self.encoder(x)
        res_out_first = self.residual_network_first_half(encoder_out)
        res_out_second = self.residual_network_second_half(res_out_first)
        decoded = self.decoder(res_out_second)

        if fuse:
            return decoded, encoder_out, res_out_first, res_out_second
        return decoded

# MustGAN 

In [5]:
class MustGAN(nn.Module):
    """Generator that fuses one-to-one and many-to-one stream features.

    A shared one-to-one ``Streams`` network is applied to every input slice and
    a many-to-one ``Streams`` network is applied to the whole slice stack. The
    intermediate feature maps of all ``num_slice + 1`` stream passes are
    concatenated and decoded into ``num_slice`` output channels.

    Args:
        num_slice: Number of slices (channels) in the input stack. All layer
            widths are derived from it; the default of 50 yields the
            1632 / 2144 channel counts this model was originally written with.
    """

    def __init__(self, num_slice = 50):
        super(MustGAN, self).__init__()

        self.num_slice = num_slice

        stream_channels = 32   # width of every feature map returned by Streams
        trunk_channels = 512   # width of the fused trunk
        # One feature map per slice plus one from the many-to-one stream.
        fused_channels = (num_slice + 1) * stream_channels      # 1632 for 50 slices
        skip_channels = trunk_channels + fused_channels         # 2144 for 50 slices

        self.one_one = Streams(input_channel=1)
        # Previously fixed at 50 input channels regardless of num_slice.
        self.many_one = Streams(input_channel=num_slice)

        encoder_layers = [
            nn.Conv2d(fused_channels, trunk_channels, kernel_size=3, stride=1, padding=1, bias=False)
        ]
        encoder_layers += [ResidualBlock(trunk_channels, trunk_channels, 1) for _ in range(4)]
        self.encoder = nn.Sequential(*encoder_layers)

        residual_layers = [
            nn.Conv2d(skip_channels, trunk_channels, kernel_size=3, stride=1, padding=1, bias=False)
        ]
        residual_layers += [ResidualBlock(trunk_channels, trunk_channels, 1) for _ in range(3)]
        self.residual = nn.Sequential(*residual_layers)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(skip_channels, fused_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.ConvTranspose2d(fused_channels, 256, kernel_size=3, stride=2, padding=0, bias=False),
            nn.ConvTranspose2d(256, num_slice, kernel_size=4, stride=2, padding=0, bias=False),
        )

    def forward(self, x_bulk):
        """Translate a slice stack.

        Args:
            x_bulk: Batch of slice stacks with ``num_slice`` channels
                (channel axis is dimension 1).

        Returns:
            A tensor with ``num_slice`` channels.

        Raises:
            ValueError: If ``x_bulk`` does not have ``num_slice`` channels.
        """
        if x_bulk.shape[1] != self.num_slice:
            raise ValueError(
                f"expected {self.num_slice} input channels, got {x_bulk.shape[1]}"
            )

        fuse1 = []  # encoder outputs
        fuse2 = []  # outputs of the first residual half
        fuse3 = []  # outputs of the second residual half

        # The one-to-one stream is deliberately run slice by slice: batching the
        # slices together would change the statistics seen by its BatchNorm layers.
        for i in range(self.num_slice):
            _, encoder_out, res_out_first, res_out_second = self.one_one(
                x_bulk[:, i:i + 1, :, :], fuse=True
            )
            fuse1.append(encoder_out)
            fuse2.append(res_out_first)
            fuse3.append(res_out_second)

        _, encoder_out, res_out_first, res_out_second = self.many_one(x_bulk, fuse=True)
        fuse1.append(encoder_out)
        fuse2.append(res_out_first)
        fuse3.append(res_out_second)

        fuse1 = torch.cat(fuse1, 1)
        fuse2 = torch.cat(fuse2, 1)
        fuse3 = torch.cat(fuse3, 1)

        x = self.encoder(fuse1)
        x = torch.cat([x, fuse2], 1)
        x = self.residual(x)

        x = torch.cat([x, fuse3], 1)
        x = self.decoder(x)
        return x
        
        
            

# Training MustGAN

In [6]:
# ---- Runtime ----
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 0

# ---- Data locations ----
TRAIN_DIR = "dataset/train"
VAL_DIR = "dataset/val"
T1_PATH = '../input/ixi-t1'
T2_PATH = '../input/ixit2-slices'

# ---- Optimisation hyper-parameters ----
LEARNING_RATE = 2e-4
BATCH_SIZE = 4
NUM_EPOCHS = 500
L1_LAMBDA = 100   # weight of the L1 term in the generator loss
LAMBDA_GP = 10

# ---- Image geometry ----
IMAGE_SIZE = 256
CHANNELS_IMG = 3

# ---- Checkpointing ----
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = "must_dis.pth.tar"
CHECKPOINT_GEN = "must_gan.pth.tar"


def _make_normalize():
    """Return a fresh albumentations Normalize with mean 0.5 / std 0.5 per channel."""
    return A.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        max_pixel_value=255.0,
    )


# Resize applied jointly to an image and its paired "image0" target.
both_transform = A.Compose(
    [A.Resize(width=256, height=256)],
    additional_targets={"image0": "image"},
)

# Augmentation + normalisation pipeline for the input image only.
transform_only_input = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.2),
        _make_normalize(),
        ToTensorV2(),
    ]
)

# Normalisation-only pipeline for the target image.
transform_only_mask = A.Compose(
    [
        _make_normalize(),
        ToTensorV2(),
    ]
)

In [7]:
import os 
import shutil


def _pairing_key(folder_name):
    """Return the part of a folder name before the first '-', used to pair T1 with T2."""
    return folder_name.split("-")[0]


# os.listdir returns entries in arbitrary order; sort so the pairing (and therefore
# the dataset order) is the same on every run.
t1_list = sorted(os.listdir("../input/ixi-t1/image slice-T1"))
t2_list = sorted(os.listdir("../input/ixit2-slices/image slice-T2"))

print(len(t1_list))
print(len(t2_list))

# Index the T2 folders once instead of rescanning the whole list for every T1 folder.
# Later entries overwrite earlier ones, so when several T2 folders share a key the
# last one wins - the same rule the nested scan applied.
t2_by_key = {_pairing_key(name): name for name in t2_list}

t1_sampled = []
t2_sampled = []

for file1 in t1_list:
    match_folder = t2_by_key.get(_pairing_key(file1))
    if match_folder is None:
        print("no matched")
        continue
    t1_sampled.append(file1)
    t2_sampled.append(match_folder)

In [8]:
# Number of paired T1 and T2 folders, one count per line (the two should be equal).
print(len(t1_sampled), len(t2_sampled), sep="\n")

In [9]:
class MapDataset(Dataset):
    """Paired T1/T2 volumes, each loaded from a folder of slice images.

    Item ``i`` is the pair ``(t1_volume, t2_volume)`` built from
    ``root_dir_T1/t1_sampled[i]`` and ``root_dir_T2/t2_sampled[i]``. Every
    volume is a float32 array of shape ``(num_slices, image_size, image_size)``
    holding the raw 8-bit grayscale intensities (0-255, no normalisation).
    Folders with fewer than ``num_slices`` images are zero-padded at the end.

    Args:
        root_dir_T1: Directory containing the T1 slice folders.
        root_dir_T2: Directory containing the T2 slice folders.
        t1_sampled: T1 folder names, index-aligned with ``t2_sampled``.
        t2_sampled: T2 folder names, index-aligned with ``t1_sampled``.
        num_slices: Number of slices (channels) per volume.
        image_size: Side length every slice is resized to.

    Raises:
        ValueError: If the two folder lists differ in length.
    """

    def __init__(self, root_dir_T1, root_dir_T2, t1_sampled, t2_sampled,
                 num_slices=50, image_size=256):
        super(MapDataset, self).__init__()
        if len(t1_sampled) != len(t2_sampled):
            raise ValueError(
                f"t1_sampled ({len(t1_sampled)}) and t2_sampled ({len(t2_sampled)}) "
                "must have the same length"
            )
        self.list_files_t1 = t1_sampled
        self.list_files_t2 = t2_sampled

        self.root_dir_T1 = root_dir_T1
        self.root_dir_T2 = root_dir_T2

        self.num_slices = num_slices
        self.image_size = image_size

    def __len__(self):
        return len(self.list_files_t1)

    @staticmethod
    def _natural_key(name):
        """Sort key that orders embedded numbers numerically ('2' before '10')."""
        import re

        # re.split with a capturing group alternates text / digits / text ..., so
        # odd positions are always the digit runs.
        return [int(tok) if i % 2 else tok for i, tok in enumerate(re.split(r"(\d+)", name))]

    def _list_slices(self, folder):
        """Return the slice file names of ``folder`` in a deterministic order.

        os.listdir gives no ordering guarantee; without sorting, the channel
        order would be arbitrary and T1/T2 channels need not correspond.
        """
        names = sorted(os.listdir(folder), key=self._natural_key)
        if len(names) > self.num_slices:
            raise ValueError(
                f"{folder} contains {len(names)} files, expected at most {self.num_slices}"
            )
        return names

    def _load_volume(self, folder):
        """Read every slice of ``folder`` into a (num_slices, H, W) float32 array."""
        size = self.image_size
        volume = np.zeros((self.num_slices, size, size), dtype=np.float32)
        for i, name in enumerate(self._list_slices(folder)):
            path_file = os.path.join(folder, name)
            img = cv2.imread(path_file, cv2.IMREAD_GRAYSCALE)
            if img is None:
                # cv2.imread signals failure by returning None rather than raising.
                raise ValueError(f"could not read image: {path_file}")
            volume[i] = cv2.resize(img, (size, size))
        return volume

    def __getitem__(self, idex):
        t1_volume = self._load_volume(os.path.join(self.root_dir_T1, self.list_files_t1[idex]))
        t2_volume = self._load_volume(os.path.join(self.root_dir_T2, self.list_files_t2[idex]))
        return t1_volume, t2_volume

In [10]:
# Roots of the T1 / T2 slice folders; the folder names inside them come from the
# pairing step (t1_sampled / t2_sampled).
root_dir_T1 = '../input/ixi-t1/image slice-T1'
root_dir_T2 = '../input/ixit2-slices/image slice-T2'
dataset = MapDataset(
    root_dir_T1=root_dir_T1,
    root_dir_T2=root_dir_T2,
    t1_sampled=t1_sampled,
    t2_sampled=t2_sampled,
)

In [11]:
# Reuse the values from the configuration cell instead of re-deriving the device and
# repeating the learning rate as a literal. `device` is kept as an alias because
# train_fn reads it.
device = DEVICE

# Both networks work on 50-slice stacks (MustGAN's default num_slice is 50).
disc = Discriminator(in_channels=50).to(device)
gen = MustGAN().to(device)

opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# Adversarial loss on raw discriminator scores, and pixel-wise reconstruction loss.
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

In [12]:
# Shuffled mini-batches for training.
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# NOTE: the preview loader iterates over the same `dataset` object as training;
# there is no held-out split here.
val_loader = DataLoader(dataset, shuffle=False, batch_size=1)

# One gradient scaler per optimiser for mixed-precision training.
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

In [13]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler):
    """Run one pass over ``loader``, alternating discriminator and generator updates.

    Args:
        disc: Discriminator, called as ``disc(input, target)``.
        gen: Generator, called as ``gen(input)``.
        loader: Iterable of ``(input, target)`` batches.
        opt_disc: Optimiser of the discriminator.
        opt_gen: Optimiser of the generator.
        l1_loss: Reconstruction loss, weighted by the global ``L1_LAMBDA``.
        bce: Adversarial loss applied to raw discriminator scores.
        g_scaler: Gradient scaler used for the generator step.
        d_scaler: Gradient scaler used for the discriminator step.

    Batches are moved to the module-level ``device``. Nothing is returned; the mean
    sigmoid of the discriminator scores is shown on the progress bar every 10 steps.
    """
    progress = tqdm.tqdm(loader, leave=True)
    for step, (source, target) in enumerate(progress):
        source = source.to(device)
        target = target.to(device)

        # ---- discriminator step ----
        with torch.cuda.amp.autocast():
            generated = gen(source)

            real_scores = disc(source, target)
            real_loss = bce(real_scores, torch.ones_like(real_scores))

            # detach: the discriminator loss must not back-propagate into the generator
            fake_scores = disc(source, generated.detach())
            fake_loss = bce(fake_scores, torch.zeros_like(fake_scores))

            disc_loss = (fake_loss + real_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- generator step ----
        with torch.cuda.amp.autocast():
            fake_scores = disc(source, generated)
            adversarial_loss = bce(fake_scores, torch.ones_like(fake_scores))
            reconstruction_loss = l1_loss(generated, target) * L1_LAMBDA
            gen_loss = adversarial_loss + reconstruction_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if step % 10 == 0:
            progress.set_postfix(
                D_real=torch.sigmoid(real_scores).mean().item(),
                D_fake=torch.sigmoid(fake_scores).mean().item(),
            )
        

In [14]:
def save_some_examples(gen, val_loader, epoch, folder):
    """Save a preview of the generator output for the first batch of ``val_loader``.

    Writes three PNG files into ``folder`` (created if missing), each showing
    channel 0 of the batch: ``y_gen_<epoch>.png`` (generated),
    ``input_<epoch>.png`` and ``label_<epoch>.png``.

    Args:
        gen: Generator; switched to eval mode for the forward pass and put back
            into train mode afterwards, even if saving fails.
        val_loader: Loader yielding ``(input, target)`` batches.
        epoch: Value embedded in the output file names.
        folder: Output directory.
    """
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    os.makedirs(folder, exist_ok=True)

    def first_channel_unit_range(batch):
        # MapDataset yields raw 0-255 intensities (no mean/std normalisation), and
        # the generator is trained against those targets. save_image expects values
        # in [0, 1], so rescale by 255; the former `* 0.5 + 0.5` assumed [-1, 1]
        # data and saturated nearly every pixel to white.
        return (batch[:, 0:1, :, :] / 255.0).clamp(0.0, 1.0)

    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
        save_image(first_channel_unit_range(y_fake), os.path.join(folder, f"y_gen_{epoch}.png"))
        save_image(first_channel_unit_range(x), os.path.join(folder, f"input_{epoch}.png"))
        save_image(first_channel_unit_range(y), os.path.join(folder, f"label_{epoch}.png"))
    finally:
        gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

    The saved object is a dict with the keys ``"state_dict"`` and ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a checkpoint, then force ``lr``.

    Args:
        checkpoint_file: Path of a file holding a dict with the keys
            ``"state_dict"`` and ``"optimizer"``.
        model: Module whose weights are restored in place.
        optimizer: Optimizer whose state is restored in place.
        lr: Learning rate written into every parameter group after loading.
    """
    print("=> Loading checkpoint")
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # The restored optimizer state carries the learning rate it was saved with;
    # overwrite it so the caller's value is the one actually used.
    for group in optimizer.param_groups:
        group["lr"] = lr

In [None]:
# NOTE: this flag is separate from SAVE_MODEL in the configuration cell (which is
# False); checkpointing here is controlled by SAVE_MODE only.
SAVE_MODE = True

for epoch in range(NUM_EPOCHS):
    train_fn(
        disc, gen, train_loader,
        opt_disc, opt_gen,
        L1_LOSS, BCE,
        g_scaler, d_scaler,
    )

    # Checkpoints are overwritten after every epoch.
    if SAVE_MODE:
        save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

    save_some_examples(gen, val_loader, epoch, folder='evaluation')

# PSNR and SSIM 

In [17]:
PSNR = []
SSIM = []


def _read_gray_256(path):
    """Read ``path`` as an 8-bit grayscale image resized to 256x256.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file (cv2.imread returns
            None instead of raising).
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return cv2.resize(img, (256, 256))


# Compare every reference slice with the generated slice of the same relative path.
for sub_folder in tqdm.tqdm(t1_sampled):
    for file in os.listdir(os.path.join(t1_actual, sub_folder)):
        img_actual = _read_gray_256(os.path.join(t1_actual, sub_folder, file))
        img_generated = _read_gray_256(os.path.join(t1_generated, sub_folder, file))

        # The slices are compared as single-channel images. This avoids the
        # `multichannel=` argument of structural_similarity, which newer
        # scikit-image releases no longer accept. For images whose colour
        # channels are identical the scores are the same as the per-channel average.
        psnr = peak_signal_noise_ratio(img_actual, img_generated)
        ssim = structural_similarity(img_actual, img_generated)

        PSNR.append(psnr)
        SSIM.append(ssim)

print(np.mean(PSNR))
print(np.mean(SSIM))

In [21]:
# Show the first file of (at most) nine reference folders next to its generated
# counterpart.
count = 0
output = "./"
for sub_folder in os.listdir(t1_actual):
    if count == 9:
        break
    count += 1

    files = os.listdir(os.path.join(t1_actual, sub_folder))
    if not files:
        continue
    file = files[0]  # only the first listed file of each folder is displayed

    actual_path = os.path.join(t1_actual, sub_folder, file)
    generated_path = os.path.join(output, t1_generated, sub_folder, file)
    img_actual = cv2.imread(actual_path)
    img_generated = cv2.imread(generated_path)

    # cv2.imread returns None for missing/unreadable files; skip the pair with a
    # message instead of failing inside cv2.resize.
    if img_actual is None or img_generated is None:
        print(f"skipping {sub_folder}/{file}: image could not be read")
        continue

    # OpenCV loads images as BGR while matplotlib expects RGB.
    img_actual = cv2.cvtColor(cv2.resize(img_actual, (256, 256)), cv2.COLOR_BGR2RGB)
    img_generated = cv2.cvtColor(cv2.resize(img_generated, (256, 256)), cv2.COLOR_BGR2RGB)

    fig, (ax_actual, ax_generated) = plt.subplots(1, 2)
    ax_actual.imshow(img_actual)
    ax_actual.set_title("Actual")
    ax_generated.imshow(img_generated)
    ax_generated.set_title("Generated")
    # Figure-level title; previously the title was attached to the right panel only.
    fig.suptitle("Actual Vs Generated")
    plt.show()
    plt.close(fig)  # release the figure so repeated plotting does not accumulate memory