### Brain Extraction

Once the data is downloaded, this script is used to extract the brain from the skull. We used [Deep MRI Brain Extraction](https://github.com/GUR9000/Deep_MRI_brain_extraction)
(commit version 7c2db1e). However, that algorithm relies on Theano, which can cause compatibility issues with the packages we use for KeyMorph. For this demo, we trained a separate brain extractor based on the masks from Deep_MRI_brain_extraction. If you want to recreate the original masks, feel free to follow the instructions in their repo. 

The scripts within this notebook create three folders, `T1_mask`, `T2_mask` and `PD_mask`, within `../data/processed_IXI/`. Each mask has the following naming convention: `IXI###-HOSPITAL-####_mask.nii.gz`. For example, subject `../data/processed_IXI/T1/IXI261-HH-1704.nii` will have a corresponding mask `../data/processed_IXI/T1_mask/IXI261-HH-1704_mask.nii.gz`.

### Import Libraries and Define Variables


In [None]:
import sys
sys.path.insert(0, '../')

from pathlib import Path
from tqdm.notebook import tqdm
import torch.nn.functional as F
from torchio.transforms import Lambda
from keymorph.data import ixi as loader
from keymorph.model import Simple_Unet, clean_mask


import os
import torch
import torchio as tio
import SimpleITK as sitk

"""Create Folder"""
# One output folder per modality; existing folders are left untouched.
for mask_dir in ('../data/processed_IXI/T1_mask/',
                 '../data/processed_IXI/T2_mask/',
                 '../data/processed_IXI/PD_mask/'):
    Path(mask_dir).mkdir(parents=True, exist_ok=True)

# Brain-extraction network: single-channel input, single-channel output.
# enc_nf / dec_nf are the feature counts handed to Simple_Unet for its
# encoder and decoder stages.
enc_nf = [4, 8, 16, 32]
dec_nf = [32, 16, 8, 4]
u1 = torch.nn.DataParallel(
    Simple_Unet(input_ch=1,
                out_ch=1,
                use_in=False,
                enc_nf=enc_nf,
                dec_nf=dec_nf)
)
u1.cuda()

# The checkpoint is a dict; the network's state dict lives under the 'u1' key.
# The model is wrapped in DataParallel *before* loading, so the stored keys
# are presumably 'module.'-prefixed -- TODO confirm against the checkpoint.
checkpoint = torch.load('../weights/brain_extraction_model.pth.tar')
weights = checkpoint['u1']
u1.load_state_dict(weights)

### Load data

In [None]:
directory = '../data/processed_IXI/'

# Applied to every loaded volume: exchanges dims 2 and 3 of a 4-D tensor
# (presumably (C, X, Y, Z) as produced by torchio -- TODO confirm).
transform = Lambda(lambda x: x.permute(0,1,3,2))

# For each modality: count the files in its folder (this also fails early if
# the folder is missing), then build its loader. Only the second value
# returned by create_simple is kept.

# --- T1 ---
N = len(os.listdir(directory+'/T1/'))
_, t1_loader = loader.create_simple(
    directory, modality='T1', transform=transform)

# --- T2 ---
N = len(os.listdir(directory+'/T2/'))
_, t2_loader = loader.create_simple(
    directory, modality='T2', transform=transform)

# --- PD ---
N = len(os.listdir(directory+'/PD/'))
_, pd_loader = loader.create_simple(
    directory, modality='PD', transform=transform)

### Predict

In [None]:
# Predict a binary brain mask for every volume of every modality and write it
# to ../data/processed_IXI/<MODALITY>_mask/<stem>_mask.nii.gz.
loaders = [t1_loader, t2_loader, pd_loader]
modalities = ['T1','T2','PD']

# NOTE(review): u1 is not switched to eval() in the cells shown; confirm that
# Simple_Unet has no layers whose behaviour differs between train and eval.

# Inference only: disable autograd so no graph is retained per volume.
with torch.no_grad():
    # The loop variable is deliberately not named `loader`: that name is the
    # `keymorph.data.ixi` module imported at the top of the notebook, and
    # rebinding it here would break re-running the "Load data" cell.
    for modality, modality_loader in zip(modalities, loaders):
        print('Processing {}'.format(modality))

        # Iterating the loader directly (instead of enumerate(...)) lets tqdm
        # read its length and display a total.
        for data in tqdm(modality_loader):
            # First entry of the batch's stem list; assumes batch size 1
            # -- TODO confirm against loader.create_simple.
            name = data['mri']['stem'][0]

            # Resample the input to the fixed 128^3 grid fed to the network.
            x = data['mri'][tio.DATA]
            x = F.interpolate(x, size=(128,128,128), mode='trilinear', align_corners=False)
            x = x.cuda()

            # Predict at 128^3, upsample by 2 (-> 256^3) and binarise at 0.5.
            mask = u1(x)
            mask = F.interpolate(mask, scale_factor=2, mode='trilinear', align_corners=False)
            mask = (mask>=0.5).float()

            # Drop the singleton dims and reorder the three remaining axes
            # before handing the array to SimpleITK.
            mask = mask.squeeze().permute(1,2,0)
            mask = mask.detach().cpu().numpy().astype('uint8')
            mask = clean_mask(mask, 0.2)

            # NOTE(review): the image is built from the bare array, so no
            # spacing/origin/direction from the source scan is attached.
            itkimage = sitk.GetImageFromArray(mask)
            sitk.WriteImage(itkimage,
                            '../data/processed_IXI/{}_mask/{}_mask.nii.gz'.format(modality,
                                                                                  name))