# Recast Whole-Brain MRI to Pituitary MRI

- Load Data
- Pre-process
- Inference
- Viz

In [115]:
import os, sys, time, glob, re

import pandas as pd
import numpy as np
import torch
import SimpleITK as sitk


# transforms
from fastai.basics import *
from monai.transforms import (
    LoadImaged,
    AddChanneld,
    CenterSpatialCropd,
    Compose,
    NormalizeIntensityd,
    Spacingd,
    SpatialPadd,
    ToTensord,
)

# Models

In [116]:
from monai.networks.nets import VNet, UNet

ensemble_models = {
    "UNET3D_dice_loss": UNet(
                    dimensions=3,
                    in_channels=1,
                    out_channels=2,
                    channels=(16, 32, 64, 128, 256),
                    strides=(2, 2, 2, 2),
                    num_res_units=2,
                    dropout=0.0,
                ),

    "VNET_dice_loss": VNet(
                spatial_dims=3,
                in_channels=1,
                out_channels=2,
            ),
    
    "CONDSEG_dice_loss": UNet(
                    dimensions=3,
                    in_channels=3,
                    out_channels=2,
                    channels=(16, 32, 64, 128, 256),
                    strides=(2, 2, 2, 2),
                    num_res_units=2,
                    dropout=0.0,
                )
}

# Load pretrained weights
for model_name, model_arch in ensemble_models.items():
    loc = torch.load(f'ensemble_models/{model_name}/model.pth')
    model_arch.load_state_dict(loc['model'])
    
# Loss
def dice(input, target):
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    return ((2. * intersection) /
           (iflat.sum() + tflat.sum()))

def dice_score(input, target):
    return dice(input.argmax(1), target)

def dice_loss(input, target): 
    return 1 - dice(input.softmax(1)[:, 1], target)

def ce_loss(input, target):
    return torch.nn.BCEWithLogitsLoss()(input[:, 1], target.squeeze(1))

# Load Data

- Data source: LONI Imaging & Data Archive (https://ida.loni.usc.edu/login.jsp). 
- Datasets: ABIDE, ABVIB, ADNI1_Complete_1Yr_1.5T, AIBL, ICMB, and PPMI.

In [117]:
# Data Source
dset_src = '/gpfs/data/oermannlab/private_data/DeepPit/PitMRdata'

# Load example files
inference_df = pd.read_csv("inference_example.csv")
display(inference_df)

# We are doing inference on one example from each dataset
mr_paths = inference_df.fn.values
print("MR paths: ", *mr_paths, sep="\n")

# for conditional segmentation models, we input also contains an atlas MR and segmentation
example_atlas = {
    'image': f'{dset_src}/samir_labels/50373-50453/50437/MP-RAGE/2000-01-01_00_00_00.0/S165191/ABIDE_50437_MRI_MP-RAGE_br_raw_20120830214425874_S165191_I329201_corrected_n4.nii',
    'label': f'{dset_src}/samir_labels/50373-50453/50437/seg.nii'
}

Unnamed: 0,fn,imputedSeq,sz,px,sp,dir
0,ABIDE/ABIDE_1/50412/MP-RAGE/2000-01-01_00_00_00.0/S164292,MPR,"(106, 256, 256)",16-bit signed integer,"(1.4, 1.0, 1.0)","(1, 0, 0, 0, -1, 0, 0, 0, 1)"
1,ABVIB/ABVIB/3830/MPRAGE/2012-06-18_11_46_49.0/S341930,MPR,"(256, 256, 192)",16-bit unsigned integer,"(1.0, 1.0, 1.0)","(0, 0, -1, 1, 0, 0, 0, -1, 0)"
2,ADNI/ADNI1_Complete_1Yr_1.5T/ADNI/023_S_0139/MPR-R__GradWarp__B1_Correction__N3__Scaled/2007-02-09_09_44_27.0/S26343,MPR,"(192, 192, 160)",32-bit float,"(1.26, 1.25, 1.19)","(0, 0, 1, 0, 1, 0, -1, 0, 0)"
3,AIBL/AIBL/338/MPRAGE_SAG_ISO_p2_ND/2012-10-06_11_32_11.0/S231118,MPR,"(256, 256, 176)",16-bit unsigned integer,"(1.0, 1.0, 1.0)","(0, 0, -1, 1, 0, 0, 0, -1, 0)"
4,ICMB/ICBM/UTHC_1098/MPRAGE_T1_AX_0.8_mm_TI-780/2009-03-13_13_01_09.0/S68959,MPR,"(220, 320, 208)",16-bit unsigned integer,"(0.8, 0.8, 1.0)","(1, 0, 0, 0, 1, 0, 0, 0, 1)"
5,PPMI/PPMI/3505/MPRAGEadni/2010-12-23_10_50_52.0/S189286,MPR,"(288, 288, 170)",16-bit unsigned integer,"(0.92, 0.92, 1.2)","(0, 0, -1, 1, 0, 0, 0, -1, 0)"


MR paths: 
ABIDE/ABIDE_1/50412/MP-RAGE/2000-01-01_00_00_00.0/S164292
ABVIB/ABVIB/3830/MPRAGE/2012-06-18_11_46_49.0/S341930
ADNI/ADNI1_Complete_1Yr_1.5T/ADNI/023_S_0139/MPR-R__GradWarp__B1_Correction__N3__Scaled/2007-02-09_09_44_27.0/S26343
AIBL/AIBL/338/MPRAGE_SAG_ISO_p2_ND/2012-10-06_11_32_11.0/S231118
ICMB/ICBM/UTHC_1098/MPRAGE_T1_AX_0.8_mm_TI-780/2009-03-13_13_01_09.0/S68959
PPMI/PPMI/3505/MPRAGEadni/2010-12-23_10_50_52.0/S189286


# Pre-process

SITK: N4 bias correction and Re-orient to LAS coordinates. Note: N4 bias correction + re-orientation takes ~20s per input. In our workflow, we compute and save the result of this pre-processing step once over all the raw inputs.

MONAI: Preprocess inputs to standard voxel spacing, intensity, center crop to standard dimensions. On-the-fly.

In [118]:
def read_dcm(fn):
    """IO for reading .dcm files in terminal folder"""
    dcms = sitk.ImageSeriesReader_GetGDCMSeriesFileNames(fn)
    if len(dcms) == 1: dcms = dcms[0]   
    im = sitk.ReadImage(dcms, sitk.sitkFloat32)
    return im

def read_nii(fn):
    """IO for reading .nii files in terminal folder"""
    if not fn.endswith(".nii"):
        niis = [f for f in os.listdir(fn) if f.endswith(".nii") and not f.startswith("._")]
        nii   = niis[0]
        im = sitk.ReadImage(f"{fn}/{nii}", sitk.sitkFloat32)    
    else:
        im = sitk.ReadImage(fn, sitk.sitkFloat32)
    return im

def n4_bias_correct(mr_path, do_write=False):
    """ Perform N4 bias correction. Input: path to folder with .dcms or .niis. """
    start = time.time()
    
    # Read in image
    try:
        inputImage = read_nii(mr_path)
    except:
        inputImage = read_dcm(mr_path) 

    # Mask the head to estimate bias
    maskImage = sitk.OtsuThreshold(inputImage, 0, 1, 200)

    # Set corrector
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    corrector.SetMaximumNumberOfIterations([3] * 3)
    corrected_image = corrector.Execute(inputImage, maskImage)

    # LPI coordinates
    corrected_image = sitk.DICOMOrient(inputImage, "LAS")
    
    # write image
    if do_write:
        corrected_fn = f"{mr_path}/corrected_n4.nii"
        sitk.WriteImage(corrected_image, corrected_fn)

    elapsed = time.time() - start
    print(f"Elapsed: {elapsed:0.2f} s")
    
    return corrected_image

class UndoDict(ItemTransform):
    """Convert dictionary to tuple """
    split_idx = None

    def __init__(self, keys=["image"]):
        self.keys = keys
        
    def encodes(self, d):
        item = tuple(d[key] for key in self.keys)
        # for condseg, 3-ch input
        item = torch.cat(item, dim=0)
        return item

    def __str__(self):
        return f"UndoDict({self.keys})"
    
def get_inference_transforms(keys, sp=(1.5,1.5,1.5), sz=(96,96,96), do_condseg = False):    
    """Preprocess inputs to standard voxel spacing, intensity, center crop to standard dimensions"""
    
    # Z-scale intensity values in image (not labels)
    image_keys = [k for k in keys if "image" in k]
    
    # nearest neighbor interpolation for labels
    interp_mode = tuple(["bilinear" if "image" in k else "nearest" for k in keys])
        
    return Compose([
        Spacingd(keys, pixdim=sp, mode = interp_mode),
        NormalizeIntensityd(image_keys, nonzero=True, channel_wise=False),
        AddChanneld(keys),
        SpatialPadd(keys, spatial_size=sz, method="symmetric", mode="constant"),
        CenterSpatialCropd(keys, roi_size=sz),
        ToTensord(keys),
        UndoDict(keys),
    ])

# sitk obj and np array have different index conventions
def sitk2np(obj): return np.swapaxes(sitk.GetArrayFromImage(obj), 0, 2)
def np2sitk(arr): return sitk.GetImageFromArray(np.swapaxes(arr, 0, 2))

def torch2sitk(t): return sitk.GetImageFromArray(torch.transpose(t, 0, 2))
def sitk2torch(o): return torch.transpose(torch.tensor(sitk.GetArrayFromImage(o)), 0, 2)

In [119]:
# Optional: save the results as time-intensive, eg 
# our example atlas for condseg has already been N4 bias corrected and LAS oriented.
n4_las_images = [n4_bias_correct(f"{dset_src}/{mr_path}") for mr_path in inference_df.fn.values]

Elapsed: 25.69 s
Elapsed: 24.29 s
Elapsed: 9.00 s
Elapsed: 36.08 s
Elapsed: 34.02 s
Elapsed: 28.59 s


In [120]:
# the atlas is only used for CONDSEG
inputsd  = [
    {
        "image": sitk2np(im), 
        "atlas_image": sitk2np(read_nii(example_atlas["image"])),
        "atlas_label": sitk2np(read_nii(example_atlas["label"]))
    } 
    for im in n4_las_images
]

preproc_inputs         = get_inference_transforms(keys = ("image",))(inputsd)
condseg_preproc_inputs = get_inference_transforms(keys = ("image","atlas_image", "atlas_label"), do_condseg=True)(inputsd)

print("UNET/VNET input shape: ", inputs[0].shape, ". CONDSEG input shape: ", condseg_inputs[0].shape)

UNET/VNET input shape:  torch.Size([3, 96, 96, 96]) . CONDSEG input shape:  torch.Size([3, 96, 96, 96])


# Inference

Post-process: (a)  keep largest connected component (b) majority vote among ensemble models.

In [121]:
# source sitk 36_Microscopy_Colocalization_Distance_Analysis.html
def get_largest_connected_component(binary_seg):
    # connected components in sitkSeg
    labeled_seg = sitk.ConnectedComponent(binary_seg)

    # re-order labels according to size (at least 1_000 pixels = 10x10x10)
    labeled_seg = sitk.RelabelComponent(labeled_seg, minimumObjectSize=1000, sortByObjectSize=True)

    # return segm of largest label
    binary_seg = labeled_seg == 1
    
    return binary_seg

In [122]:
# each item in the list is the model's preds
all_outputs = []
all_preds = []

for model_name, model_arch in ensemble_models.items():
    
    # get pre-processed inputs
    inputs = preproc_inputs if model_name != "CONDSEG_dice_loss" else condseg_preproc_inputs
    inputs = torch.stack(inputs, dim=0)
    print(inputs.shape)
    # apply model to the 6 example inputs
    model_arch.eval()
    with torch.no_grad():
        preds = model_arch(inputs).cpu()
        all_preds.append(preds)
        print(preds.shape)
        
    # keep largest connected component for each model prediction
    lcc = [get_largest_connected_component(torch2sitk(pred.argmax(0).byte())) for pred in preds]
    all_outputs.append(lcc)
    print(len(lcc), lcc[0].GetSize())
    
# get majority vote for each input
def get_votes(i):
    return [lcc[i] for lcc in all_outputs]

labelForUndecidedPixels = 0
majority_votes = [sitk.LabelVoting(get_votes(i), labelForUndecidedPixels) for i in range(len(preproc_inputs))]

torch.Size([6, 1, 96, 96, 96])
torch.Size([6, 2, 96, 96, 96])
6 (96, 96, 96)
torch.Size([6, 1, 96, 96, 96])
torch.Size([6, 2, 96, 96, 96])
6 (96, 96, 96)
torch.Size([6, 3, 96, 96, 96])
torch.Size([6, 2, 96, 96, 96])
6 (96, 96, 96)


In [123]:
for i in range(6):
    for j in range(3):
        print(len(torch.nonzero(sitk2torch(all_outputs[j][i]))))

9626
10126
10222
10661
11766
11556
5531
5578
6341
9529
9528
9507
12876
11665
13773
7834
8561
7219


In [130]:
# get majority vote for each input
def get_votes(i):
    return [lcc[i] for lcc in outputs]

labelForUndecidedPixels = 0
majority_votes = [sitk.LabelVoting(get_votes(i), labelForUndecidedPixels) for i in range(len(preproc_inputs))]

In [131]:
len(majority_votes)

6

In [132]:
for i in range(6):
    print(len(torch.nonzero(sitk2torch(majority_votes[i]))))

0
11304
5771
9532
12689
8024


### Viz inputs and outputs

In [133]:
# ROI bounding box

def mask2bbox(mask):
    k = torch.any(torch.any(mask, dim=0), dim=0) # 0 -> 1,2 -> 1 -> 2 left
    j = torch.any(torch.any(mask, dim=0), dim=1) # 0 -> 1,2 -> 2 -> 1 left
    i = torch.any(torch.any(mask, dim=1), dim=1) # 1 -> 0,2 -> 0 -> 0 left
    
    imin, imax = torch.where(i)[0][[0, -1]]
    jmin, jmax = torch.where(j)[0][[0, -1]]
    kmin, kmax = torch.where(k)[0][[0, -1]]
    
    # inclusive indices
    return torch.tensor([imin, imax+1, jmin, jmax+1, kmin, kmax+1])

In [134]:
import matplotlib.pyplot as plt
%matplotlib inline

from matplotlib import colors
bin_cmap2  = colors.ListedColormap(['white', 'yellow'])

def get_mid_idx(bbox, is_mask=False):
    
    if is_mask:
        bbox = mask2bbox(bbox)
       
    axis_len = bbox[1] - bbox[0]
    mid0     = bbox[0] + axis_len//2
        
    axis_len = bbox[3] - bbox[2]
    mid1     = bbox[2] + axis_len//2

    axis_len = bbox[5] - bbox[4]
    mid2     = bbox[4] + axis_len//2

    return mid0, mid1, mid2

def imshows(ims):
    nrow = len(ims)
    ncol = 6
    
    fig, axes = plt.subplots(nrow, ncol, figsize=(
        ncol * 3, nrow * 3), facecolor='white')
    for i, im_dict in enumerate(ims):
        im = im_dict["image"].detach().squeeze().cpu().numpy()
        label = im_dict["label"].detach().squeeze().cpu().numpy()
        fname = im_dict["fname"]
        
        # get bbox
        bbox = mask2bbox(torch.tensor(label))
        mids = get_mid_idx(bbox)
        for axis_idx in range(3):
            slice_idx = mids[axis_idx]

            # plot image 0-2
            ax  = axes[i, 0+axis_idx]
            #ax.set_title(f"{fname} Slice {slice_idx} (Axis {axis_idx})")
            ax.imshow(np.rot90(np.take(im, slice_idx, axis=axis_idx)), cmap=plt.cm.gray)
            ax.imshow(np.rot90(np.take(label, slice_idx, axis=axis_idx)), alpha=0.2, cmap=bin_cmap2)
            ax.axis("off")
                
            # plot labels 3-5
            ax        = axes[i, 3+axis_idx]
            ax.axis("off")
            im_show = ax.imshow(np.rot90(np.take(label, slice_idx, axis=axis_idx)))
            #fig.colorbar(im_show, ax=ax)
            
        # print fname
        axes[i,0].set_title(f"{fname}")
        axes[i,3].set_title(f"Ensemble Seg")

In [136]:
# test on one
to_imshow = []

for i in range(len(preproc_inputs)):
    # get dset name
    input_fn = mr_paths[i]
    dset_name = input_fn[:input_fn.index("/")]
    
    # get input and output
    input_im   = preproc_inputs[i]
    output_seg = sitk2torch(majority_votes[i])
    
    print(mask2bbox(output_seg))
    
    to_imshow.append({"image": input_im, "label": output_seg, "fname": dset_name})

imshows(to_imshow)

IndexError: index is out of bounds for dimension with size 0