In [60]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models
from torchvision.models.vgg import VGG
from PIL import Image
import numpy as np
import os
import nibabel as nib
import time
import datetime
import torch.nn.functional as F
from models.unet import UNet
import matplotlib.pyplot as plt
from metrics.torch_seg_metrics import dice_score, iou
from datasets.BRATS2018 import ToTensorVal, NormalizeBRATSVal, ZeroPadVal
from datasets.BRATS2018_3D import CenterCropBRATS3D, NormalizeBRATS3D
from models.resnet3D import resnet50_3D
from tqdm import tqdm

%matplotlib inline

In [2]:
def _load_case_modalities(case_dir, case_name, modalities):
    """Load one case's MRI modalities as arrays, in the order requested.

    Each modality is read from ``<case_dir>/<case_name>_<modality>.nii.gz``.

    Returns:
        A list with one array per entry of ``modalities``.
    """
    volumes = []
    for modality in modalities:
        path = os.path.join(case_dir, case_name + '_' + modality + '.nii.gz')
        # ``img.get_data()`` is deprecated (nibabel 3) and removed (nibabel 5).
        # It returned a cached ``np.asanyarray(img.dataobj)``, so this yields the
        # same array; ``get_fdata()`` would instead force float64.
        volumes.append(np.asanyarray(nib.load(path).dataobj))
    return volumes


def infer(case_name, seg_type, model_path, device, val_dir):
    """Run slice-by-slice 2D UNet inference over one BraTS validation case.

    Args:
        case_name: case ID; also the name of the case's folder in ``val_dir``.
        seg_type: one of
            'et' / 'tc' -- T1ce input, 1-output UNet thresholded at 0.5;
            'wt'        -- T2 + FLAIR input, 1-output UNet thresholded at 0.5;
            'seg'       -- T1, T1ce, T2, FLAIR input, 4-class UNet (argmax).
        model_path: path to a saved ``state_dict`` for the matching UNet.
        device: torch device used for inference.
        val_dir: directory holding one sub-directory per case.

    Returns:
        A float ndarray of shape (155, 240, 240) whose first axis is the last
        axis of the NIfTI volumes. It holds 0/1 for 'et'/'tc'/'wt' and class
        indices 0-3 for 'seg'.

    Raises:
        ValueError: if ``seg_type`` is not one of the values above.
    """
    # seg_type -> (input modalities in channel order, UNet constructor kwargs)
    configs = {
        'et': (('t1ce',), dict(n_channels=1, n_classes=1, residual=True)),
        'tc': (('t1ce',), dict(n_channels=1, n_classes=1, residual=True)),
        'wt': (('t2', 'flair'), dict(n_channels=2, n_classes=1, residual=True)),
        'seg': (('t1', 't1ce', 't2', 'flair'),
                dict(n_channels=4, n_classes=4, residual=True, expansion=2)),
    }
    if seg_type not in configs:
        raise ValueError('seg_type should only be et, tc, wt or seg')
    modalities, unet_kwargs = configs[seg_type]

    case_dir = os.path.join(val_dir, case_name)
    volumes = _load_case_modalities(case_dir, case_name, modalities)
    # (C, H, W, S) -> (S, C, H, W): slices along the last NIfTI axis, with the
    # modalities stacked as the channel axis of every slice (C == 1 for et/tc).
    sc = np.array(volumes).transpose((3, 0, 1, 2))
    assert sc.shape == (155, len(modalities), 240, 240)

    model = UNet(**unet_kwargs)
    # NOTE: torch.load unpickles the checkpoint -- only load trusted files.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Transforms are built once instead of once per slice.
    # NOTE(review): assumes NormalizeBRATSVal / ToTensorVal keep no per-call
    # state -- confirm in datasets.BRATS2018.
    normalize = NormalizeBRATSVal()
    totensor = ToTensorVal()

    n_slices, _, height, width = sc.shape
    preds = np.zeros((n_slices, height, width))
    with torch.no_grad():
        for i in range(n_slices):
            # C x H x W -> 1 x C x H x W (a batch of one slice)
            slice_i = torch.unsqueeze(totensor(normalize(sc[i])), dim=0).to(device)
            output = model(slice_i)
            if seg_type == 'seg':
                pred = torch.argmax(F.softmax(output, dim=1), dim=1, keepdim=True)
            else:
                pred = torch.sigmoid(output) > 0.5
            # Drop the size-1 batch/channel dims -> H x W.
            preds[i] = torch.squeeze(pred).cpu().numpy()

    return preds

In [3]:
# Root folder of the BraTS 2018 validation set: one sub-directory per case.
val_dir = 'BRATS2018_Validation/'

# Prefer the first CUDA GPU; fall back to the CPU on machines without one.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
# Peek at the first ten validation cases (alphabetical). Only sub-directories
# are listed: infer() treats each entry as a per-case folder, so any loose
# files sitting in val_dir are not cases and are skipped here.
sorted(
    name for name in os.listdir(val_dir)
    if os.path.isdir(os.path.join(val_dir, name))
)[:10]

['Brats18_CBICA_AAM_1',
 'Brats18_CBICA_ABT_1',
 'Brats18_CBICA_ALA_1',
 'Brats18_CBICA_ALT_1',
 'Brats18_CBICA_ALV_1',
 'Brats18_CBICA_ALZ_1',
 'Brats18_CBICA_AMF_1',
 'Brats18_CBICA_AMU_1',
 'Brats18_CBICA_ANK_1',
 'Brats18_CBICA_APM_1']

In [5]:
# 4-class end-to-end segmentation of one validation case.
preds = infer(
    case_name='Brats18_CBICA_ALA_1',
    seg_type='seg',
    model_path=(
        '../UNet-ResidualBlock-Expansion_210_end_to_end_manual/'
        'UNet-ResidualBlock-Expansion-BRATS2018-End-to-End_batch6_training_epochs15_Adam_scheduler-step10-gamma1.0_lr5e-05_w_decay3e-05/'
        'trained_model.pt'
    ),
    device=device,
    val_dir=val_dir,
)



In [34]:
def _label_coords(volume, label):
    """Return the coordinates of every voxel where ``volume == label``.

    Returns three plain Python lists holding the index along axis 0, 1 and 2
    respectively (empty lists when the label is absent). Lists, not arrays,
    are returned on purpose: the combined point-cloud cell joins them with
    ``+``, which on numpy arrays would add element-wise instead.
    """
    axis0, axis1, axis2 = (axis.tolist() for axis in np.argwhere(volume == label).T)
    return axis0, axis1, axis2


# Label ids of the 4-class 'seg' output, following the variable names below:
# 3 -> enhancing tumor, 2 -> edema, 1 -> necrotic.
et_xs, et_ys, et_zs = _label_coords(preds, 3)
edema_xs, edema_ys, edema_zs = _label_coords(preds, 2)
necrotic_xs, necrotic_ys, necrotic_zs = _label_coords(preds, 1)

In [19]:
import plotly.plotly as py
import plotly.graph_objs as go
import plotly

In [20]:
# Read the Plotly API key from the environment instead of hard-coding it in the
# notebook, where it lands in version control and every shared copy.
# NOTE(review): the key that was previously committed in this cell must be
# treated as leaked and regenerated in the Plotly account settings.
plotly_api_key = os.environ.get('PLOTLY_API_KEY')
if not plotly_api_key:
    raise RuntimeError('Set the PLOTLY_API_KEY environment variable before running this cell.')
plotly.tools.set_credentials_file(username='MartinMa28', api_key=plotly_api_key)

In [28]:
# Shared figure layout with all outer margins removed so the 3D scene fills
# the frame; the edema, necrotic and combined plots below reuse `layout`.
layout = go.Layout(margin={'l': 0, 'r': 0, 'b': 0, 't': 0})

# Enhancing tumor: one marker per voxel predicted as label 3.
trace = go.Scatter3d(
    x=et_xs, y=et_ys, z=et_zs,
    mode='markers',
    marker={
        'size': 3,
        'opacity': 0.6,
        'line': {'color': 'rgba(217, 57, 57, 0.34)', 'width': 0.5},
    },
)
data = [trace]
fig = go.Figure(data=data, layout=layout)
py.iplot(fig, filename='3D-Enhancing-tumor')


Consider using IPython.display.IFrame instead



In [33]:
# Edema: one marker per voxel predicted as label 2, with a green outline,
# drawn with the zero-margin `layout` defined in the enhancing-tumor cell.
trace_edema = go.Scatter3d(
    mode='markers',
    x=edema_xs, y=edema_ys, z=edema_zs,
    marker={
        'size': 3,
        'opacity': 0.6,
        'line': {'width': 0.5, 'color': 'rgba(57, 217, 57, 0.34)'},
    },
)
data_edema = [trace_edema]
fig_edema = go.Figure(data=data_edema, layout=layout)
py.iplot(fig_edema, filename='3D-Edema')


Woah there! Look at all those points! Due to browser limitations, the Plotly SVG drawing functions have a hard time graphing more than 500k data points for line charts, or 40k points for other types of charts. Here are some suggestions:
(1) Use the `plotly.graph_objs.Scattergl` trace object to generate a WebGl graph.
(2) Trying using the image API to return an image instead of a graph URL
(3) Use matplotlib
(4) See if you can create your visualization with fewer data points



Consider using IPython.display.IFrame instead



In [36]:
# Necrotic tumor: one marker per voxel predicted as label 1, with a blue
# outline and slightly higher opacity than the other classes.
# (The misspelt global name `trace_necrotc` is kept so existing references work.)
trace_necrotc = go.Scatter3d(
    mode='markers',
    x=necrotic_xs, y=necrotic_ys, z=necrotic_zs,
    marker={
        'size': 3,
        'opacity': 0.8,
        'line': {'width': 0.5, 'color': 'rgba(57, 57, 217, 0.34)'},
    },
)
data_necrotic = [trace_necrotc]
fig_necrotic = go.Figure(data=data_necrotic, layout=layout)
py.iplot(fig_necrotic, filename='3D-Necrotic-tumor')


Consider using IPython.display.IFrame instead



## 3D scatter plot of glioma segmentation

In [43]:
# Merge the three classes into one point cloud. seg_color holds each point's
# label (3 = enhancing tumor, 2 = edema, 1 = necrotic) in the same order, so a
# single colour scale can tell the classes apart.
seg_xs = [*et_xs, *edema_xs, *necrotic_xs]
seg_ys = [*et_ys, *edema_ys, *necrotic_ys]
seg_zs = [*et_zs, *edema_zs, *necrotic_zs]
seg_color = [
    label
    for label, class_xs in ((3, et_xs), (2, edema_xs), (1, necrotic_xs))
    for _ in class_xs
]

In [63]:
# Combined segmentation: every tumor voxel as one marker, coloured by its
# class label through the Viridis colour scale.
seg_marker = dict(size=3, color=seg_color, colorscale='Viridis', opacity=0.8)
trace_seg = go.Scatter3d(x=seg_xs, y=seg_ys, z=seg_zs, mode='markers', marker=seg_marker)
data_seg = [trace_seg]
fig_seg = go.Figure(data=data_seg, layout=layout)
py.iplot(fig_seg, filename='3D-Glioma-segmentation')


Woah there! Look at all those points! Due to browser limitations, the Plotly SVG drawing functions have a hard time graphing more than 500k data points for line charts, or 40k points for other types of charts. Here are some suggestions:
(1) Use the `plotly.graph_objs.Scattergl` trace object to generate a WebGl graph.
(2) Trying using the image API to return an image instead of a graph URL
(3) Use matplotlib
(4) See if you can create your visualization with fewer data points




The draw time for this plot will be slow for clients without much RAM.



Estimated Draw Time Slow


Consider using IPython.display.IFrame instead



In [58]:
# Importing the submodule explicitly guarantees IPython.display is loaded
# (this still binds the name `IPython`, as the original `import IPython` did).
import IPython.display

# Embed plot 108 from the Plotly account.
# NOTE(review): presumably the '3D-Glioma-segmentation' figure uploaded above
# -- confirm. The URL uses an explicit https:// scheme. A protocol-relative
# "//plot.ly/..." URL inherits the page's scheme, so it resolves to file://
# when the notebook is opened from disk, and to plain http:// on a non-TLS
# Jupyter server.
IPython.display.IFrame(src="https://plot.ly/~MartinMa28/108.embed", width=900, height=800)