In [None]:
import argparse
import os
import random
import logging
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import torch.optim
import sys

import torch.distributed as dist
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import nibabel as nib

In [None]:
# Fetch the TABS repository into the current working directory (IPython shell escape).
# NOTE(review): a later cell adds '/kaggle/working/TABS' to sys.path, so this presumably
# expects the working directory to be /kaggle/working — confirm.
!git clone https://github.com/raovish6/TABS

In [None]:
# Put the cloned TABS checkout on the module search path so its packages can be imported.
tabs_repo_dir = '/kaggle/working/TABS'
sys.path.append(tabs_repo_dir)

In [None]:
from Models.TABS_Model import TABS

In [None]:
!pip install gdown
!gdown https://drive.google.com/uc?id=1Du6N8hr4lcRCjwSYuwLsepzWVXPmdjEr

In [None]:
# Scan 1R: open the CT volume and pull its voxel data out of the NIfTI image.
ct_scan_1R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_ct_1R.nii')
ct_scan_1R_data = ct_scan_1R.get_fdata()

# Scan 1R: open the corresponding mask and pull its data out the same way.
mask_1R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_mask_1R.nii')
mask_1R_data = mask_1R.get_fdata()

In [None]:
class CTScanDataset(Dataset):
    """Serve randomly positioned cubes cut from one CT volume and its mask.

    Every ``__getitem__`` call ignores the requested index and draws a fresh
    random crop, so the dataset length only controls how many crops make up
    one pass of a ``DataLoader``.

    Args:
        ct_scan_data: 3-D NumPy array holding the CT volume.
        mask_data: 3-D NumPy array holding the mask; indexed with the same
            slices as ``ct_scan_data``.
        cube_size: Crop extent as ``(x, y, z)``. The x extent is applied to
            array axis 1, the y extent to axis 2 and the z extent to axis 0.
        channels: When 1, the CT crop gets a leading channel axis and the mask
            crop is repeated 5 times along a new leading axis; for any other
            value both crops are returned without a channel axis.
        samples_per_epoch: Value reported by ``__len__`` (defaults to 120).
    """

    def __init__(self, ct_scan_data, mask_data, cube_size=(80, 80, 80), channels=1,
                 samples_per_epoch=120):
        self.ct_scan = ct_scan_data
        self.mask = mask_data
        self.cube_size = cube_size
        self.channels = channels
        self.samples_per_epoch = samples_per_epoch

    def __len__(self):
        """Return the number of random crops that make up one epoch."""
        return self.samples_per_epoch

    def __getitem__(self, idx):
        """Return a ``(ct_cube, mask_cube)`` pair of float tensors for one random crop.

        Raises:
            ValueError: If the volume is smaller than ``cube_size`` along any axis.
        """
        size_x, size_y, size_z = self.cube_size[0], self.cube_size[1], self.cube_size[2]

        # np.random.randint excludes its upper bound, so "+ 1" is required for the
        # last valid offset to be reachable (and for a volume whose axis length
        # equals the cube size to yield offset 0 instead of raising).
        x_start = np.random.randint(0, self.ct_scan.shape[1] - size_x + 1)
        y_start = np.random.randint(0, self.ct_scan.shape[2] - size_y + 1)
        z_start = np.random.randint(0, self.ct_scan.shape[0] - size_z + 1)

        # Arrays are indexed [z, x, y]; the same window is cut from scan and mask.
        window = (
            slice(z_start, z_start + size_z),
            slice(x_start, x_start + size_x),
            slice(y_start, y_start + size_y),
        )
        ct_cube = self.ct_scan[window]
        mask_cube = self.mask[window]

        # Add channel dimension
        if self.channels == 1:
            ct_cube = torch.unsqueeze(torch.from_numpy(ct_cube).float(), 0)
            # The mask is replicated into 5 identical channels.
            mask_cube = np.repeat(mask_cube[np.newaxis, :, :, :], 5, axis=0)
            mask_cube = torch.from_numpy(mask_cube).float()
        else:
            ct_cube = torch.from_numpy(ct_cube).float()
            mask_cube = torch.from_numpy(mask_cube).float()

        return ct_cube, mask_cube

In [None]:
# Training split: random 80x80x80 single-channel cubes cut from scan 1R, three per batch.
dataset_train = CTScanDataset(
    ct_scan_data=ct_scan_1R_data,
    mask_data=mask_1R_data,
    cube_size=(80, 80, 80),
    channels=1,
)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=3, shuffle=True)

In [None]:
#for ct_scan, mask in dataloader_train:
    #print(ct_scan.shape, mask.shape)

In [None]:
# Scan 2R: open the CT volume and pull its voxel data out of the NIfTI image.
ct_scan_2R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_ct_2R.nii')
ct_scan_2R_data = ct_scan_2R.get_fdata()

# Scan 2R: open the corresponding mask and pull its data out the same way.
mask_2R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_mask_2R.nii')
mask_2R_data = mask_2R.get_fdata()

In [None]:
# Validation split: random 80x80x80 single-channel cubes cut from scan 2R, three per batch.
dataset_val = CTScanDataset(
    ct_scan_data=ct_scan_2R_data,
    mask_data=mask_2R_data,
    cube_size=(80, 80, 80),
    channels=1,
)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=3, shuffle=True)

In [None]:
#for ct_scan, mask in dataloader_val:
    #print(ct_scan.shape, mask.shape)  # should print (batch_size, channels, height, width, depth)

In [None]:
# ---- Run bookkeeping ----
date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
root = ''

# ---- Optimiser settings ----
lr = 0.00001
weight_decay = 1e-5
amsgrad = True

# ---- Reproducibility / hardware ----
seed = 1000
no_cuda = False
gpu = 0
gpu_available = '0,1,2'

# ---- Data loading and schedule ----
num_workers = 4
batch_size = 1
start_epoch, end_epoch = 0, 2

# Seed every random number generator used in the notebook with the same value.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)

In [None]:
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def main_worker(num_epochs=20, save_after_epoch=150):
    """Train TABS on the CT cube loaders, validate after every epoch, and log the results.

    Reads these notebook-level names: ``seed``, ``lr``, ``weight_decay``,
    ``amsgrad``, ``root``, ``device``, ``dataloader_train`` and
    ``dataloader_val``. It also calls the helpers ``adjust_learning_rate``,
    ``get_loss``, ``overall_metrics`` and ``learning_curve``.

    Side effects: prints one summary line per epoch, appends the loss history to
    ``loss_log_TABS.txt`` under ``root`` and plots the learning curve.

    Args:
        num_epochs: Number of epochs to run; also the horizon of the polynomial
            learning-rate decay applied by ``adjust_learning_rate``.
        save_after_epoch: A checkpoint is written only for epochs strictly
            greater than this value that also reach a new minimum validation
            loss. NOTE(review): with the defaults (20 epochs, threshold 150) no
            checkpoint is ever written - lower this value if the trained weights
            should be kept.
    """

    # Set seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    model = TABS(img_dim=80,
                 output_ch=5)

    # Place the model on the same device the batches are moved to below.
    model.to(device)

    print('Model Built!')

    # Using adam optimizer (amsgrad variant) with weight decay
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=amsgrad)

    # MSE loss with reduction='sum': get_loss divides the result by the number of
    # voxels in the brain map, so a 'mean' reduction here would normalise the
    # loss twice and shrink it (and its gradients) by the total element count.
    criterion = nn.MSELoss(reduction='sum')
    criterion = criterion.to(device)

    start_time = time.time()

    # Enable gradient calculation for training
    torch.set_grad_enabled(True)

    # Declare lists to keep track of training and val losses over the epochs
    train_global_losses = []
    val_global_losses = []
    best_epoch = 0

    print('Start to train!')

    # Main training/validation loop
    for epoch in range(num_epochs):

        # Declare lists to keep track of losses and metrics within the epoch
        train_epoch_losses = []
        val_epoch_losses = []
        val_epoch_pcorr = []

        model.train()

        # The decayed rate depends only on the epoch index and the fixed epoch
        # count, so it is set once per epoch rather than once per batch.
        adjust_learning_rate(optimizer, epoch, num_epochs, lr)

        # Loop through train dataloader here.
        for ct_scan, mask in dataloader_train:
            ct_scan, mask = ct_scan.to(device), mask.to(device)

            loss, _, _ = get_loss(model, criterion, ct_scan, mask, 'train')

            train_epoch_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Transition to val mode
        model.eval()

        with torch.no_grad():

            # Loop through validation dataloader here.
            for ct_scan, mask in dataloader_val:
                ct_scan, mask = ct_scan.to(device), mask.to(device)

                loss, isolated_images, stacked_brain_map = get_loss(model, criterion, ct_scan, mask, 'val')

                val_epoch_losses.append(loss.item())

                # One Pearson correlation per sample in the batch.
                for isolated, target, brain_map in zip(isolated_images, mask, stacked_brain_map):
                    val_epoch_pcorr.append(overall_metrics(isolated, target, brain_map))

        # Average train and val loss over every batch in the epoch. Save to global losses which tracks across epochs
        train_net_loss = sum(train_epoch_losses) / len(train_epoch_losses)
        val_net_loss = sum(val_epoch_losses) / len(val_epoch_losses)
        train_global_losses.append(train_net_loss)
        val_global_losses.append(val_net_loss)
        pcorr = sum(val_epoch_pcorr) / len(val_epoch_pcorr)

        print('Epoch: {} | Train Loss: {} | Val Loss: {} | Pearson: {}'.format(epoch, train_net_loss, val_net_loss, pcorr))

        checkpoint_dir = root
        # Track the epoch with the lowest validation loss so far
        if val_global_losses[-1] == min(val_global_losses):
            print('saving model at the end of epoch ' + str(epoch))
            best_epoch = epoch
            # Only write the checkpoint at higher epochs
            if epoch > save_after_epoch:
                file_name = os.path.join(checkpoint_dir, 'TABS_model_epoch_{}_val_loss_{}.pth'.format(epoch, val_global_losses[-1]))
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                    },
                    file_name)

    end_time = time.time()
    total_time = (end_time - start_time) / 3600
    print('The total training time is {:.2f} hours'.format(total_time))

    print('----------------------------------The training process finished!-----------------------------------')

    log_name = os.path.join(root, 'loss_log_TABS.txt')

    # The file is opened in append mode, so the last line ends with a newline to
    # keep the header of the next run from being glued onto it.
    with open(log_name, "a") as log_file:
        now = time.strftime("%c")
        log_file.write('================ Loss (%s) ================\n' % now)
        log_file.write('best_epoch: ' + str(best_epoch) + '\n')
        log_file.write('train_losses: ')
        log_file.write('%s\n' % train_global_losses)
        log_file.write('val_losses: ')
        log_file.write('%s\n' % val_global_losses)
        log_file.write('train_time: ' + str(total_time) + '\n')

    learning_curve(best_epoch, train_global_losses, val_global_losses)


def count_parameters(model):
    """Return the total element count over the model's parameters that require gradients."""
    trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]
    return sum(trainable_sizes)

def learning_curve(best_epoch, train_global_losses, val_global_losses):
    """Plot the per-epoch training and validation losses and save the figure.

    Args:
        best_epoch: Epoch index marked with a vertical line.
        train_global_losses: Sequence of training losses, one per epoch.
        val_global_losses: Sequence of validation losses, one per epoch.

    The image is written as ``Learning_Curve_TABS.png`` inside the notebook-level
    ``root`` directory.
    """
    _, loss_axis = plt.subplots(figsize=(12, 8))

    # Axis labels, with an x tick every 10 epochs.
    tick_stop = len(train_global_losses) + 1
    loss_axis.set_xlabel('Epochs')
    loss_axis.set_ylabel('Loss')
    loss_axis.set_xticks(np.arange(0, tick_stop, 10))

    # Training curve in red, validation curve in blue.
    curves = (
        (train_global_losses, '-r', 'Training loss'),
        (val_global_losses, '-b', 'Validation loss'),
    )
    for losses, line_format, curve_label in curves:
        loss_axis.plot(losses, line_format, label=curve_label, markersize=3)

    loss_axis.axvline(best_epoch, color='m', lw=4, alpha=0.5, label='Best epoch')
    loss_axis.legend(loc='upper left')

    save_name = 'Learning_Curve_TABS' + '.png'
    plt.savefig(os.path.join(root, save_name))

def adjust_learning_rate(optimizer, epoch, max_epoch, init_lr, power=0.9):
    """Set every parameter group's learning rate to the polynomially decayed value.

    The rate is ``init_lr * (1 - epoch / max_epoch) ** power``, rounded to 8
    decimal places, and is written into ``optimizer.param_groups`` in place.
    """
    def decayed_rate():
        remaining_fraction = 1 - (epoch / max_epoch)
        return round(init_lr * np.power(remaining_fraction, power), 8)

    for group in optimizer.param_groups:
        group['lr'] = decayed_rate()

def overall_metrics(isolated_image, target, stacked_brain_map):
    """Return the Pearson correlation between output and target inside the brain map.

    All three tensors are flattened; only the positions where the flattened
    ``stacked_brain_map`` is nonzero are kept from ``isolated_image`` and
    ``target`` before the correlation coefficient is computed with NumPy.
    """
    # Positions of the flattened brain map that are nonzero.
    inside = torch.flatten(stacked_brain_map).nonzero(as_tuple=True)

    # Restrict prediction and ground truth to those positions and move them to NumPy.
    predicted = torch.flatten(isolated_image)[inside].cpu().detach().numpy()
    ground_truth = torch.flatten(target)[inside].cpu().detach().numpy()

    correlation_matrix = np.corrcoef(predicted, ground_truth)
    return correlation_matrix[0][1]

def get_loss(model, criterion, ct_scan, mask, mode):
    """Run the model on a batch and compute the loss restricted to the brain map.

    Args:
        model: Callable applied to ``ct_scan.float()``.
        criterion: Loss callable applied to the masked output and ``mask``.
        ct_scan: Input batch; its dimension 1 is squeezed to build the brain map.
        mask: Ground-truth tensor compared against the masked output.
        mode: When equal to ``'val'`` all random number generators are re-seeded
            with the notebook-level ``seed`` before the forward pass.

    Returns:
        Tuple ``(loss, isolated_images, stacked_brain_map)``: the criterion value
        divided by the sum of the stacked brain map, the model output with the
        background zeroed, and the brain map repeated 5 times along dimension 1.
    """
    if mode == 'val':
        for reseed in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
            reseed(seed)

    # Forward pass
    prediction = model(ct_scan.float())

    # Binary brain map: 1.0 wherever the input is greater than -1, replicated 5 times
    inside_brain = (torch.squeeze(ct_scan, dim=1) > -1).float()
    stacked_brain_map = torch.stack([inside_brain] * 5, dim=1)

    # Zero out everything in the output that lies outside the brain map
    isolated_images = stacked_brain_map * prediction

    # Criterion value divided by the number of brain-map entries
    loss = criterion(isolated_images, mask) / stacked_brain_map.sum()

    return loss, isolated_images, stacked_brain_map

if __name__ == '__main__':
    # Restrict the GPUs visible to this process.
    # NOTE(review): CUDA_VISIBLE_DEVICES normally has to be set before CUDA is
    # initialised - confirm it still takes effect this late in the notebook.
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_available

    # Training is CUDA-only; stop here when no GPU is usable.
    assert torch.cuda.is_available(), "Currently, we only support CUDA version"

    # Turn on cuDNN and its benchmark mode.
    cudnn.enabled = True
    cudnn.benchmark = True

    main_worker()

In [None]:
# Read the saved checkpoint back from the Kaggle working directory.
# NOTE(review): no cell visible here writes a file with exactly this name -
# presumably it comes from the gdown download or an earlier run; confirm.
best_model_path = '/kaggle/working/best_model_TABS.pth'
best_model = torch.load(best_model_path)

In [None]:
# Ask PyTorch's CUDA caching allocator to release its unused cached memory.
torch.cuda.empty_cache()