In [22]:
import argparse
import os
import random
import logging
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import torch.optim
import sys
!pip install MedPy
from medpy.metric.binary import hd

import torch.distributed as dist
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import nibabel as nib

Collecting MedPy
  Downloading MedPy-0.4.0.tar.gz (151 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m151.8/151.8 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: MedPy
  Building wheel for MedPy (setup.py) ... [?25ldone
[?25h  Created wheel for MedPy: filename=MedPy-0.4.0-py3-none-any.whl size=214963 sha256=f415fdb89384f94a4b5792e59423b87a80061ddac82c097b605e1a1b00d8c014
  Stored in directory: /root/.cache/pip/wheels/d4/32/c7/6380ab2edb8cca018d39a0f1d43250fd9791922c963117de46
Successfully built MedPy
Installing collected packages: MedPy
Successfully installed MedPy-0.4.0
[0m

In [2]:
!git clone https://github.com/raovish6/TABS

Cloning into 'TABS'...
remote: Enumerating objects: 77, done.[K
remote: Counting objects: 100% (77/77), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 77 (delta 31), reused 28 (delta 9), pack-reused 0[K
Unpacking objects: 100% (77/77), 35.08 KiB | 1.30 MiB/s, done.


In [3]:
sys.path.append('/kaggle/working/TABS') 

In [4]:
from Models.TABS_Model import TABS

In [6]:
!pip install gdown
!gdown https://drive.google.com/uc?id=1Du6N8hr4lcRCjwSYuwLsepzWVXPmdjEr

[0mDownloading...
From (uriginal): https://drive.google.com/uc?id=1Du6N8hr4lcRCjwSYuwLsepzWVXPmdjEr
From (redirected): https://drive.google.com/uc?id=1Du6N8hr4lcRCjwSYuwLsepzWVXPmdjEr&confirm=t&uuid=305072b4-1e59-40b9-a398-5412ed0c40c8
To: /kaggle/working/best_model_TABS.pth
100%|█████████████████████████████████████████| 285M/285M [00:01<00:00, 211MB/s]


In [9]:
class CTScanDataset(Dataset):
    def __init__(self, ct_scan_data, mask_data, cube_size=(80, 80, 80), channels=1):
        self.ct_scan = ct_scan_data
        self.mask = mask_data
        self.cube_size = cube_size
        self.channels = channels

    def __len__(self):
        #return len(self.ct_scan)
        return 120

    def __getitem__(self, idx):
        # Get random cube of specified size
        x_start = np.random.randint(0, self.ct_scan.shape[1] - self.cube_size[0])
        y_start = np.random.randint(0, self.ct_scan.shape[2] - self.cube_size[1])
        z_start = np.random.randint(0, self.ct_scan.shape[0] - self.cube_size[2])
        x_end = x_start + self.cube_size[0]
        y_end = y_start + self.cube_size[1]
        z_end = z_start + self.cube_size[2]
        ct_cube = self.ct_scan[z_start:z_end, x_start:x_end, y_start:y_end]
        mask_cube = self.mask[z_start:z_end, x_start:x_end, y_start:y_end]

        # Add channel dimension
        if self.channels == 1:
            ct_cube = torch.unsqueeze(torch.from_numpy(ct_cube).float(), 0)
            mask_cube = np.repeat(mask_cube[np.newaxis,:,:,:], 5, axis=0)
            mask_cube = torch.from_numpy(mask_cube).float()
            
        else:
            ct_cube = torch.from_numpy(ct_cube).float()
            mask_cube = torch.from_numpy(mask_cube).float()
            

        return ct_cube, mask_cube

In [None]:
# Load the raw CT scan
ct_scan_1R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_ct_1R.nii')

# Load the corresponding mask
mask_1R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_mask_1R.nii')

# Get the CT scan data
ct_scan_1R_data = ct_scan_1R.get_fdata()

# Get the mask data
mask_1R_data = mask_1R.get_fdata()

In [None]:
dataset_train = CTScanDataset(ct_scan_1R_data, mask_1R_data, cube_size=(80, 80, 80), channels=1)
dataloader_train = DataLoader(dataset_train, batch_size=3, shuffle=True)

In [None]:
# Load the raw CT scan
ct_scan_2R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_ct_2R.nii')

# Load the corresponding mask
mask_2R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized/normalized_mask_2R.nii')

# Get the CT scan data
ct_scan_2R_data = ct_scan_2R.get_fdata()

# Get the mask data
mask_2R_data = mask_2R.get_fdata()

In [None]:
dataset_val = CTScanDataset(ct_scan_2R_data, mask_2R_data, cube_size=(80, 80, 80), channels=1)
dataloader_val = DataLoader(dataset_val, batch_size=3, shuffle=True)

In [10]:
# Load the raw CT scan
ct_scan_5R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized-test/normalized_ct_5R.nii')

# Load the corresponding mask
mask_5R = nib.load('/kaggle/input/micro-ct-scans-of-human-cochlea-normalized-test/normalized_mask_5R.nii')

# Get the CT scan data
ct_scan_5R_data = ct_scan_5R.get_fdata()

# Get the mask data
mask_5R_data = mask_5R.get_fdata()

In [44]:
x = torch.tensor(mask_5R_data)

# check if tensor is binary
is_binary = ((x == 0) | (x == 1)).all()

if is_binary:
    print("The tensor is binary.")
else:
    print("The tensor is not binary.")

The tensor is not binary.


In [45]:
num_rows_to_view = 5
for i in range(num_rows_to_view):
    print(f"Rows {i}:")
    print(np.squeeze(mask_5R_data[i, :, :]))

Rows 0:
[[-1.0062024 -1.0062024 -1.0062024 ... -1.0062024 -1.0062024 -1.0062024]
 [-1.0062024 -1.0062024 -1.0062024 ... -1.0062024 -1.0062024 -1.0062024]
 [-1.0062024 -1.0062024 -1.0062024 ... -1.0062024 -1.0062024 -1.0062024]
 ...
 [-1.0062024 -1.0062024 -1.0062024 ... -1.0062024 -1.0062024 -1.0062024]
 [-1.0062024 -1.0062024 -1.0062024 ... -1.0062024 -1.0062024 -1.0062024]
 [-1.0062024 -1.0062024 -1.0062024 ... -1.0062024 -1.0062024 -1.0062024]]
Rows 1:
[[ -38.89281136  342.82980826  476.33500174 ... 2196.41711718
  1818.15240234 1402.15142223]
 [ 343.28083932  179.40622121  407.92862445 ... 1926.85088757
  1862.20310244 1669.76318392]
 [ 640.96133828  128.13902417  220.60039127 ... 1866.26238197
  1833.48745835 1892.87321445]
 ...
 [1896.03043187 1817.85171497 2011.34403929 ... 2078.84835446
  1900.54074246 1771.39551589]
 [2020.36466047 2081.7048845  2289.32951534 ... 1938.87838247
  1867.16444409 1818.30274603]
 [2410.95755759 2330.82437277 2350.51939568 ... 1839.35086212
  1818.9

In [11]:
dataset_test = CTScanDataset(ct_scan_5R_data, mask_5R_data , cube_size=(80, 80, 80), channels=1)
dataloader_test = DataLoader(dataset_test, batch_size=3, shuffle=True)

In [12]:
# Declare variables
date = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
root = ''
lr = 0.00001
weight_decay = 1e-5
amsgrad = True
seed = 1000
no_cuda = False
num_workers = 4
batch_size = 1
start_epoch = 0
end_epoch = 2
gpu = 0
gpu_available = '0,1,2'
load_dir = ''

# Set seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
def get_loss(model, criterion, ct_scan, mask, mode):

    if mode == 'val':
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    # Gen model outputs
    output = model(ct_scan.float())

    # Construct binary brain map to consider loss only within there
    input_squeezed = torch.squeeze(ct_scan,dim=1)
    brain_map = (input_squeezed > -1).float()
    stacked_brain_map = torch.cat([brain_map.unsqueeze(1)]*5, dim=1)

    # Zero out the background of the segmentation output
    isolated_images = torch.mul(stacked_brain_map, output)

    # Calculate loss over just the brain map
    loss = criterion(isolated_images, mask)
    num_brain_voxels = stacked_brain_map.sum()
    loss = loss / num_brain_voxels

    return loss, isolated_images, stacked_brain_map


In [15]:
def tissue_wise_probability_metrics(isolated_image, target, stacked_brain_map):
    criterion = nn.MSELoss()
    criterion = criterion.cuda(args.gpu)

    # metrics dict to store metric for each tissue type
    metrics_list = ['pearson_corr', 'spearman_corr', 'mse']
    metrics = { i : [] for i in metrics_list }

    # list of flattened tensors being collected and their corresponding dict
    necessary_flattened_tensors = ['GT_flattened_0', 'GT_flattened_1', 'GT_flattened_2', 'GT_flattened_3', 'GT_flattened_4', 'iso_flattened_0', 'iso_flattened_1', 'iso_flattened_2', 'iso_flattened_3', 'iso_flattened_4']
    flattened_tensors = { i : {} for i in necessary_flattened_tensors }

    # flattened single channel brain mask (192x192x192 --> flat)
    mask_flattened = torch.flatten(stacked_brain_map[0])

    # Only save the part of the flattened GT/output that correspond to nonzero values of the brain mask
    for i in range(0,5):
        # flatten gt of channel i (each channel corresponds to a tissue type)
        flattened_tensors['GT_flattened_' + str(i)] = torch.flatten(target[i])
        # choose only the portion of the flattened gt that correspons to the brain
        flattened_tensors['GT_flattened_' + str(i)] = flattened_tensors['GT_flattened_' + str(i)][mask_flattened.nonzero(as_tuple=True)]
        # make this now a numpy array
        flattened_tensors['GT_flattened_' + str(i)] = flattened_tensors['GT_flattened_' + str(i)].cpu().detach().numpy()

        # repeat for the model output image
        flattened_tensors['iso_flattened_' + str(i)] = torch.flatten(isolated_image[i])
        flattened_tensors['iso_flattened_' + str(i)] = flattened_tensors['iso_flattened_' + str(i)][mask_flattened.nonzero(as_tuple=True)]
        flattened_tensors['iso_flattened_' + str(i)] = flattened_tensors['iso_flattened_' + str(i)].cpu().detach().numpy()

    for i in range(0,5):
        # get output and gt from dict i just constructed
        model_output = flattened_tensors['iso_flattened_' + str(i)]
        GT = flattened_tensors['GT_flattened_' + str(i)]

        # get metrics using the numpy arrays of both (cropped to brain)
        cur_pcorr = np.corrcoef(model_output, GT)[0][1]
        cur_scorr = spearmanr(model_output, GT)[0]

        cur_mse = criterion(torch.tensor(model_output).cuda(args.gpu), torch.tensor(GT).cuda(args.gpu))

        metrics['pearson_corr'].append(cur_pcorr)
        metrics['spearman_corr'].append(cur_scorr)
        metrics['mse'].append(cur_mse.item())

    return metrics


In [32]:
def tissue_wise_map_metrics(isolated_image, target, stacked_brain_map):
    # metrics dict to store metric for each tissue type
    metrics_list = ['DICE', 'HD', 'Jaccard']
    metrics = { i : [] for i in metrics_list }

    # list of flattened tensors (segmentation masks) I'm gonna collect and their corresponding dict
    necessary_masks_list = ['GT_0', 'GT_1', 'GT_2', 'GT_3', 'GT_4', 'iso_0', 'iso_1', 'iso_2', 'iso_3', 'iso_4']
    necessary_tensors = { i : {} for i in necessary_masks_list }

    # current output and gt is 3x192x192x192. Basically, each voxel of the brain has 3 probabilities assigned to it for each tissue type. Taking the argmax gives us the most likely tissue type of each voxel (now 1x192x192x192)
    full_map_model = torch.argmax(isolated_image,0)
    full_map_GT = torch.argmax(target,0)
    mask = stacked_brain_map[0]
    mask_flattened = torch.flatten(stacked_brain_map[0])

    for i in range(0,5):
        # now that we have the argmax, we can imagine the brain with each voxel having a value of 0,1,2. To get the masks for each tissue type, we save a new tensor corresponding to 1 where the argmax tensor has a value of the given tissue type and 0 otherwise.
        necessary_tensors['GT_' + str(i)] = (full_map_GT==i).float()
        necessary_tensors['iso_' + str(i)] = (full_map_model==i).float()
        if i == 0:
            # make sure background is 0
            necessary_tensors['GT_' + str(i)] = torch.mul(necessary_tensors['GT_' + str(i)], mask)
            necessary_tensors['iso_' + str(i)] = torch.mul(necessary_tensors['iso_' + str(i)], mask)

        # calc HD with the segmentation masks
        h_dist = hd(necessary_tensors['iso_' + str(i)].cpu().detach().numpy(), necessary_tensors['GT_' + str(i)].cpu().detach().numpy())
        metrics['HD'].append(h_dist)

        # now make cropped 1d numpy arrays only containing mask values for within the brain for dice calculation
        necessary_tensors['GT_' + str(i)] = torch.flatten(necessary_tensors['GT_' + str(i)])
        necessary_tensors['GT_' + str(i)] = necessary_tensors['GT_' + str(i)][mask_flattened.nonzero(as_tuple=True)]
        necessary_tensors['GT_' + str(i)] = necessary_tensors['GT_' + str(i)].cpu().detach().numpy()
        necessary_tensors['iso_' + str(i)] = torch.flatten(necessary_tensors['iso_' + str(i)])
        necessary_tensors['iso_' + str(i)] = necessary_tensors['iso_' + str(i)][mask_flattened.nonzero(as_tuple=True)]
        necessary_tensors['iso_' + str(i)] = necessary_tensors['iso_' + str(i)].cpu().detach().numpy()

    for i in range(0,5):
        model_output = necessary_tensors['iso_' + str(i)]
        GT = necessary_tensors['GT_' + str(i)]
        # dice formula
        dice = np.sum(model_output[GT==1])*2.0 / (np.sum(model_output) + np.sum(GT))
        jaccard = jaccard_score(GT, model_output)

        metrics['DICE'].append(dice)
        metrics['Jaccard'].append(jaccard)

    return metrics

In [35]:
if __name__ == '__main__':

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    
    model = TABS(img_dim = 80,
                 output_ch = 5)  # Instantiate an instance of the model's class
    
    state_dict = torch.load('/kaggle/working/best_model_TABS.pth')
    model.load_state_dict(state_dict, strict=False)
    
    #checkpoint = torch.load(load_dir, map_location=torch.device(gpu))
    #model.load_state_dict(checkpoint['state_dict'])  # Load the saved state dictionary
    model.cuda(gpu)
    
    #dataloader_train = DataLoader(dataset_train, batch_size=3, shuffle=True)
    #dataloader_val = DataLoader(dataset_val, batch_size=3, shuffle=True)
    
    criterion = nn.MSELoss(reduction='mean')
    criterion = criterion.cuda(gpu)

    probability_metrics_list = ['pearson_corr', 'spearman_corr', 'mse']
    probability_metrics = { i : [] for i in probability_metrics_list }
    map_metrics_list = ['DICE', 'HD', 'Jaccard']
    map_metrics = { i : [] for i in map_metrics_list }
    
    
    model.eval()  # Set the model to evaluation mode


    with torch.no_grad():
        val_losses = []
        test = []
        val_corr = []

        # Loop through test dataloader here.
        for i, (ct_scan, mask) in enumerate(dataloader_test):

            dataloader_test = DataLoader(dataset_test, batch_size=3, shuffle=True)

            ct_scan, mask = ct_scan.to(device), mask.to(device)

            loss, isolated_images, stacked_brain_maps  = get_loss(model, criterion, ct_scan, mask, 'val')

            val_losses.append(loss)

            for g in range(0,len(isolated_images)):
                isolated_image = isolated_images[g]
                target = mask[g]
                stacked_brain_map = stacked_brain_maps[g]
                metrics_maps = tissue_wise_map_metrics(isolated_image, target, stacked_brain_map)
                metrics =  tissue_wise_probability_metrics(isolated_image, target, stacked_brain_map)

                for metric in probability_metrics_list:
                    probability_metrics[metric].append(metrics[metric])
                for metric in map_metrics_list:
                    map_metrics[metric].append(metrics_maps[metric])

    val_net_loss = sum(val_losses)/len(val_losses)

    overall_pcorr = probability_metrics['pearson_corr']
    overall_pcorr = np.array(overall_pcorr)
    avg_pcorr = sum(overall_pcorr)/len(overall_pcorr)
    sd_pcorr = np.std(overall_pcorr, axis=0, ddof=1)

    overall_scorr = probability_metrics['spearman_corr']
    overall_scorr = np.array(overall_scorr)
    avg_scorr = sum(overall_scorr)/len(overall_scorr)
    sd_scorr = np.std(overall_scorr, axis=0, ddof=1)

    overall_mse = probability_metrics['mse']
    overall_mse = np.array(overall_mse)
    avg_mse = sum(overall_mse)/len(overall_mse)
    sd_mse = np.std(overall_mse, axis=0, ddof=1)

    overall_DICE = map_metrics['DICE']
    overall_DICE = np.array(overall_DICE)
    avg_DICE = sum(overall_DICE)/len(overall_DICE)
    sd_DICE = np.std(overall_DICE, axis=0, ddof=1)

    overall_HD = map_metrics['HD']
    overall_HD = np.array(overall_HD)
    avg_HD = sum(overall_HD)/len(overall_HD)
    sd_HD = np.std(overall_HD, axis=0, ddof=1)

    overall_jaccard = map_metrics['Jaccard']
    overall_jaccard = np.array(overall_jaccard)
    avg_jaccard = sum(overall_jaccard)/len(overall_jaccard)
    sd_jaccard = np.std(overall_jaccard, axis=0, ddof=1)

    print('Probability-Based Metrics:')
    print('Val Loss: {} | Pearson: {} SD: {} | Spearman: {} SD: {} | MSE: {} SD: {}'.format(val_net_loss, avg_pcorr, sd_pcorr, avg_scorr, sd_scorr, avg_mse, sd_mse))

    print('Map-Based Metrics:')
    print('DICE: {} SD: {} | HD: {} SD: {} | Jaccard: {} SD: {}'.format(avg_DICE, sd_DICE, avg_HD, sd_HD, avg_jaccard, sd_jaccard))

    log_name = os.path.join(args.root, 'test_TABS.txt')
    with open(log_name, "a") as log_file:
        log_file.write('Pearson: {} SD: {} | Spearman: {} SD: {} | MSE: {} SD: {}'.format(avg_pcorr, sd_pcorr, avg_scorr, sd_scorr, avg_mse, sd_mse))
        log_file.write('\n')
        log_file.write('DICE: {} SD: {} | HD: {} SD: {} | Jaccard: {} SD: {}'.format(avg_DICE, sd_DICE, avg_HD, sd_HD, avg_jaccard, sd_jaccard))
        log_file.write('\n')
        log_file.write('pcorr')
        log_file.write('%s\n' % overall_pcorr)
        log_file.write('scorr')
        log_file.write('%s\n' % overall_scorr)
        log_file.write('MSE')
        log_file.write('%s\n' % overall_mse)
        log_file.write('dice')
        log_file.write('%s\n' % overall_DICE)
        log_file.write('jaccard')
        log_file.write('%s\n' % overall_jaccard)
        log_file.write('hd')
        log_file.write('%s\n' % overall_HD)

RuntimeError: The second supplied array does not contain any binary object.