In [1]:
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
from datasets.BRATS2018 import ToTensorVal
from datasets.BRATS2018_3D import CenterCropBRATS3D, NormalizeBRATS3D
from models.resnet3D import resnet50_3D
from models.logistic import Logistic
import os

In [2]:
# helper functions
def time_stamp() -> str:
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'.

    Returns:
        str: the wall-clock time at the moment of the call, e.g.
        '2018-07-01 13:45:09'.
    """
    # Imported locally: the notebook's import cell does not bring in
    # `datetime`, so relying on a module-level name raised NameError.
    import datetime

    return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')

In [3]:
def classification_infer(case_name, val_dir, model, device=None):
    """Predict the class index of one BraTS case from its four MRI modalities.

    The function loads `<case_name>_{t1,t1ce,t2,flair}.nii.gz` from
    `val_dir/case_name`, stacks them into a 4-channel volume, applies the
    center-crop / normalize / to-tensor transforms, and runs a single
    forward pass of `model` without gradient tracking.

    Args:
        case_name: name of the case; used both as the sub-directory name and
            as the file-name prefix of each modality.
        val_dir: directory that contains one sub-directory per case.
        model: classifier called as `model(x)` on a 5-D tensor
            (batch, channel, and three spatial dimensions); its output is
            arg-maxed over dim 1.
        device: optional `torch.device` (or device string) the input is moved
            to. Defaults to the device of the model's first parameter, or CPU
            if the model has no parameters.

    Returns:
        numpy.ndarray: rank-0 integer array holding the predicted class index.

    Raises:
        AssertionError: if any modality is not of shape (240, 240, 155).
    """
    case_dir = os.path.join(val_dir, case_name)

    # Resolve the device from the model rather than from a notebook global,
    # so the function works regardless of cell execution order.
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    expected_shape = (240, 240, 155)
    volumes = []
    for modality in ('t1', 't1ce', 't2', 'flair'):
        path = os.path.join(case_dir, f'{case_name}_{modality}.nii.gz')
        # `get_data()` is deprecated/removed in nibabel; this is the
        # documented drop-in replacement with the same return value.
        volume = np.asanyarray(nib.load(path).dataobj)
        assert volume.shape == expected_shape, \
            f'{path}: expected shape {expected_shape}, got {volume.shape}'
        volumes.append(volume)

    # channel order: t1, t1ce, t2, flair
    sc = np.array(volumes)
    assert sc.shape == (4,) + expected_shape

    # fake label for position: the crop/normalize transforms take a
    # (volume, label) pair, so a dummy label is passed through them
    fake_label = np.array([1])

    center_crop = CenterCropBRATS3D()
    normalize = NormalizeBRATS3D()
    to_tensor = ToTensorVal()

    sc, fake_label = center_crop((sc, fake_label))
    sc, fake_label = normalize((sc, fake_label))
    sc = to_tensor(sc)

    # unsqueeze to 5-dimension array, NxCxHxWxD
    sc = torch.unsqueeze(sc, dim=0)
    sc = sc.to(device)

    with torch.no_grad():
        logits = model(sc)
        # softmax is monotonic, so the arg-max of the raw logits is the
        # predicted class; skipping softmax also avoids rounding-induced ties
        label = torch.argmax(logits, dim=1, keepdim=True)

    # squeeze to rank-0 array
    label = torch.squeeze(label)
    label = label.cpu().numpy()

    return label

In [4]:
# Validation data location; defined once and reused for the case listing.
val_dir = 'BRATS2018_Validation'

# Keep only sub-directories: classification_infer opens
# `<val_dir>/<case>/<case>_t1.nii.gz`, so a stray file in val_dir
# would make it fail.
case_list = sorted(
    entry for entry in os.listdir(val_dir)
    if os.path.isdir(os.path.join(val_dir, entry))
)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = resnet50_3D(num_classes=2)
# Alternative baseline model (swap in together with a matching checkpoint):
#model = Logistic(volume=(4, 160, 160, 144), num_classes=2)

# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('../ResNet50-3D-32-CLS/trained_model.pt', map_location=device))
model.to(device)
# Switch to inference mode; as the cell's last expression this also
# displays the model architecture.
model.eval()

ResNet3D(
  (conv1): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck3D(
      (conv1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (bn2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv3d(32, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn3): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (downsample): Sequential(
        (0): Conv3d(32, 128, kernel_

## Test classification on the BraTS2018 validation set

### resnet50 3D 16 lr 7e-4

In [5]:
# Classify every validation case and print a readable grade:
# a predicted class index of 0 is shown as 'LGG', anything else as 'HGG'.
for case_name in case_list:
    label = 'LGG' if classification_infer(case_name, val_dir, model) == 0. else 'HGG'
    print(f'{case_name} - {label}')

Brats18_CBICA_AAM_1 - HGG
Brats18_CBICA_ABT_1 - HGG
Brats18_CBICA_ALA_1 - HGG
Brats18_CBICA_ALT_1 - HGG
Brats18_CBICA_ALV_1 - HGG
Brats18_CBICA_ALZ_1 - HGG
Brats18_CBICA_AMF_1 - HGG
Brats18_CBICA_AMU_1 - HGG
Brats18_CBICA_ANK_1 - HGG
Brats18_CBICA_APM_1 - HGG
Brats18_CBICA_AQE_1 - HGG
Brats18_CBICA_ARR_1 - HGG
Brats18_CBICA_ATW_1 - HGG
Brats18_CBICA_AUC_1 - HGG
Brats18_CBICA_AUE_1 - HGG
Brats18_CBICA_AZA_1 - HGG
Brats18_CBICA_BHF_1 - HGG
Brats18_CBICA_BHN_1 - HGG
Brats18_CBICA_BKY_1 - HGG
Brats18_CBICA_BLI_1 - HGG
Brats18_CBICA_BLK_1 - HGG
Brats18_MDA_1012_1 - HGG
Brats18_MDA_1015_1 - HGG
Brats18_MDA_1081_1 - HGG
Brats18_MDA_907_1 - HGG
Brats18_MDA_922_1 - HGG
Brats18_TCIA02_230_1 - HGG
Brats18_TCIA02_400_1 - HGG
Brats18_TCIA03_216_1 - HGG
Brats18_TCIA03_288_1 - HGG
Brats18_TCIA03_313_1 - HGG
Brats18_TCIA03_604_1 - HGG
Brats18_TCIA04_212_1 - HGG
Brats18_TCIA04_253_1 - HGG
Brats18_TCIA07_600_1 - HGG
Brats18_TCIA07_601_1 - HGG
Brats18_TCIA07_602_1 - HGG
Brats18_TCIA09_248_1 - LGG
Brats18

### resnet50 3D 16 lr 1e-3

In [7]:
# Print the raw predicted class index (as returned by classification_infer)
# for every validation case.
# NOTE(review): this cell uses whatever weights are currently loaded into
# `model`; the setup cell loads '../ResNet50-3D-32-CLS/trained_model.pt' --
# confirm which checkpoint produced the output recorded under this heading.
for case_name in case_list:
    label = classification_infer(case_name, val_dir, model)
    print(f'{case_name} - {label}')

Brats18_CBICA_AAM_1 - 1
Brats18_CBICA_ABT_1 - 1
Brats18_CBICA_ALA_1 - 1
Brats18_CBICA_ALT_1 - 1
Brats18_CBICA_ALV_1 - 1
Brats18_CBICA_ALZ_1 - 1
Brats18_CBICA_AMF_1 - 1
Brats18_CBICA_AMU_1 - 1
Brats18_CBICA_ANK_1 - 1
Brats18_CBICA_APM_1 - 1
Brats18_CBICA_AQE_1 - 1
Brats18_CBICA_ARR_1 - 1
Brats18_CBICA_ATW_1 - 1
Brats18_CBICA_AUC_1 - 1
Brats18_CBICA_AUE_1 - 1
Brats18_CBICA_AZA_1 - 1
Brats18_CBICA_BHF_1 - 1
Brats18_CBICA_BHN_1 - 1
Brats18_CBICA_BKY_1 - 1
Brats18_CBICA_BLI_1 - 1
Brats18_CBICA_BLK_1 - 1
Brats18_MDA_1012_1 - 1
Brats18_MDA_1015_1 - 1
Brats18_MDA_1081_1 - 1
Brats18_MDA_907_1 - 1
Brats18_MDA_922_1 - 1
Brats18_TCIA02_230_1 - 1
Brats18_TCIA02_400_1 - 1
Brats18_TCIA03_216_1 - 1
Brats18_TCIA03_288_1 - 1
Brats18_TCIA03_313_1 - 1
Brats18_TCIA03_604_1 - 1
Brats18_TCIA04_212_1 - 1
Brats18_TCIA04_253_1 - 1
Brats18_TCIA07_600_1 - 1
Brats18_TCIA07_601_1 - 1
Brats18_TCIA07_602_1 - 1
Brats18_TCIA09_248_1 - 0
Brats18_TCIA10_195_1 - 0
Brats18_TCIA10_311_1 - 0
Brats18_TCIA10_609_1 - 0
Brats18_

### resnet50 3D 32

In [5]:
# Print the raw predicted class index (as returned by classification_infer)
# for every validation case, using the weights currently loaded into `model`.
for case_name in case_list:
    label = classification_infer(case_name, val_dir, model)
    print(f'{case_name} - {label}')

Brats18_CBICA_AAM_1 - 1
Brats18_CBICA_ABT_1 - 1
Brats18_CBICA_ALA_1 - 1
Brats18_CBICA_ALT_1 - 1
Brats18_CBICA_ALV_1 - 1
Brats18_CBICA_ALZ_1 - 1
Brats18_CBICA_AMF_1 - 1
Brats18_CBICA_AMU_1 - 1
Brats18_CBICA_ANK_1 - 1
Brats18_CBICA_APM_1 - 1
Brats18_CBICA_AQE_1 - 1
Brats18_CBICA_ARR_1 - 1
Brats18_CBICA_ATW_1 - 1
Brats18_CBICA_AUC_1 - 1
Brats18_CBICA_AUE_1 - 1
Brats18_CBICA_AZA_1 - 1
Brats18_CBICA_BHF_1 - 1
Brats18_CBICA_BHN_1 - 1
Brats18_CBICA_BKY_1 - 1
Brats18_CBICA_BLI_1 - 1
Brats18_CBICA_BLK_1 - 1
Brats18_MDA_1012_1 - 1
Brats18_MDA_1015_1 - 1
Brats18_MDA_1081_1 - 1
Brats18_MDA_907_1 - 1
Brats18_MDA_922_1 - 1
Brats18_TCIA02_230_1 - 1
Brats18_TCIA02_400_1 - 1
Brats18_TCIA03_216_1 - 1
Brats18_TCIA03_288_1 - 1
Brats18_TCIA03_313_1 - 1
Brats18_TCIA03_604_1 - 1
Brats18_TCIA04_212_1 - 1
Brats18_TCIA04_253_1 - 1
Brats18_TCIA07_600_1 - 1
Brats18_TCIA07_601_1 - 1
Brats18_TCIA07_602_1 - 1
Brats18_TCIA09_248_1 - 0
Brats18_TCIA10_195_1 - 0
Brats18_TCIA10_311_1 - 0
Brats18_TCIA10_609_1 - 0
Brats18_