In [1]:
import os
import sys
import json
import time 
import copy
import argparse
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import SimpleITK as sitk
import tqdm
import pandas as pd 
from PIL import Image
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from torchvision import transforms
from torch.utils.data import DataLoader
from models import get_model
from data_loader.LiTS import LiTSDataset
from utils.train import ObjFromDict


def load_model(run_dir, metric='validation_dice', map_location='cpu'):
    """Rebuild a model from a training run directory and load its best checkpoint.

    Parameters
    ----------
    run_dir : str
        Run directory containing ``config.json`` and ``best_<metric>.pth``.
    metric : str
        Metric name embedded in the checkpoint file name (``best_{metric}.pth``).
    map_location : str, torch.device or callable
        Forwarded to ``torch.load``. The default ``'cpu'`` lets checkpoints saved
        on a GPU load on machines without CUDA.

    Returns
    -------
    tuple
        ``(model, config)``: the model built by ``get_model(config.model)`` with
        the checkpoint's state dict loaded, and the run config wrapped in
        ``ObjFromDict``. ``.eval()`` is not called here.
    """
    with open(os.path.join(run_dir, 'config.json'), encoding='utf-8') as json_file:
        config = ObjFromDict(json.load(json_file))
    model = get_model(config.model)
    checkpoint_path = os.path.join(run_dir, 'best_{}.pth'.format(metric))
    # The checkpoint must hold a state dict (it is passed to load_state_dict).
    model.load_state_dict(torch.load(checkpoint_path, map_location=map_location))
    return model, config

def compute_dice(gt, pred):
    """Dice coefficient ``2|A∩B| / (|A| + |B|)`` between two masks.

    Parameters
    ----------
    gt, pred : array-like
        Masks of identical (or broadcastable) shape. Binary masks (bool or 0/1
        integers) give the usual hard Dice; float values give a soft Dice.

    Returns
    -------
    float
        Dice score, smoothed by ``eps`` so that two empty masks score 1.0.
    """
    eps = 1e-5
    gt = np.asarray(gt)
    pred = np.asarray(pred)
    intersection = np.sum(gt * pred)
    # Sum each mask separately: for bool arrays ``gt + pred`` is a logical OR
    # (giving |A∪B| instead of |A|+|B|), and for narrow integer dtypes the
    # element-wise addition can overflow before summation.
    return (2 * intersection + eps) / (np.sum(gt) + np.sum(pred) + eps)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Training run to evaluate; load_model reads its config.json and, by default,
# the best_validation_dice.pth checkpoint.
run_dir = 'runs/2020-03-28_17h33min'
model, config = load_model(run_dir)

feature_scale not specified in config, setting to default 4


  nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
  nn.init.normal(m.weight.data, 1.0, 0.02)
  nn.init.constant(m.bias.data, 0.0)


In [3]:
# Use the first GPU when available; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing at model.to(device).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
# Root of the LiTS data, taken from the run's config.
data_path = config.dataset.root
print('data_path ', data_path)

# Fix the global NumPy seed so the permutation below is repeatable.
split_seed = 0 
np.random.seed(split_seed)

# NOTE(review): os.listdir order is filesystem-dependent and is not sorted here, so the
# same seed can give a different split on another machine/filesystem. This split must
# match the one used during training; confirm against the training code before sorting.
image_dir = os.listdir(os.path.join(data_path,'Training Batch 1')) + os.listdir(os.path.join(data_path,'Training Batch 2'))
# Assumes names like 'volume-<k>.nii': [7:-4] strips the 7-char 'volume-' prefix and the 4-char '.nii' suffix.
all_indexes = [ int(file_name[7:-4]) for file_name in image_dir if 'volume' in file_name]
split = np.random.permutation(all_indexes)
# 80/10/10 split. n_test is not used: the test slice takes every remaining index,
# so it also absorbs the rounding leftovers.
n_train, n_val, n_test = int(0.8 * len(split)), int(0.1 * len(split)), int(0.1 * len(split))

train = split[: n_train]
val = split[n_train : n_train+n_val]
test = split[n_train + n_val :]


# Setup Data Loader
# Only the training set is augmented. The inference cell below builds its own test
# dataset with different resampling settings rather than using these loaders.
train_dataset = LiTSDataset(data_path, train, augment=True, no_tumor=True)
val_dataset = LiTSDataset(data_path, val, no_tumor=True)
test_dataset = LiTSDataset(data_path, test, no_tumor=True)
train_dataloader = DataLoader(dataset=train_dataset, num_workers=config.dataset.num_workers, batch_size=config.training.batch_size, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, num_workers=config.dataset.num_workers, batch_size=config.training.batch_size, shuffle=False)
test_dataloader  = DataLoader(dataset=test_dataset,  num_workers=config.dataset.num_workers, batch_size=config.training.batch_size, shuffle=False)



data_path  /home/raubyb/LiTS


In [19]:
batch_size = 1
model = model.to(device)
model.eval()
# Inference-mode dataset over the held-out test volumes. Each item's `target` also
# carries the original image/mask paths and a bounding box used below to paste the
# prediction back into the reference grid.
dataset = LiTSDataset(data_path, test, physical_reference_size=(768, 768, 768), spacing=3, no_tumor=True, inference_mode=True)
dataloader = DataLoader(dataset=dataset, num_workers=0, batch_size=batch_size, shuffle=False)
dices = []
with torch.no_grad():
    for batch_idx, (data, target) in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
        data = data.to(device)
        output = model(data)
        # output = target["one_hot_target"]  # debug toggle: push the ground truth through the same pipeline
        # Move predictions and bounding boxes to the host once per batch, not once per item.
        output_numpy = output.cpu().numpy()
        bounding_boxes = target['bounding_box'].cpu().numpy()
        # Iterate over the items actually present. The final batch can be smaller than
        # `batch_size` (DataLoader keeps partial batches by default), and indexing past
        # it would raise IndexError.
        for i in range(data.shape[0]):
            img = sitk.ReadImage(target['original_image_path'][i])
            normalization_transform = dataset.get_normalization_transform(img)
            inv_normalization_transform = normalization_transform.GetInverse()
            # Channel 1 of the model output, rounded to a binary mask
            # (assumes values in [0, 1] — TODO confirm).
            out_mask = np.round(output_numpy[i, 1, :, :, :])

            bb = bounding_boxes[i]

            # Resample the CT onto the dataset's reference grid. Only that grid's
            # shape, dtype and geometry are used below.
            ref_img = sitk.Resample(img, dataset.reference_image, normalization_transform)

            # Paste the cropped prediction into an all-zero volume on the reference grid.
            big_mask_numpy = sitk.GetArrayFromImage(ref_img)
            big_mask_numpy[:, :, :] = 0
            big_mask_numpy[bb[0]:bb[1], bb[2]:bb[3], bb[4]:bb[5]] = out_mask

            out_mask_image = sitk.GetImageFromArray(big_mask_numpy)
            # Same origin, spacing and direction as ref_img (sizes match by construction).
            out_mask_image.CopyInformation(ref_img)
            # Map the prediction back to the original image space; nearest neighbour keeps labels discrete.
            out_mask_image_original_space = sitk.Resample(out_mask_image, img, inv_normalization_transform, sitk.sitkNearestNeighbor)
            original_mask = sitk.ReadImage(target['original_mask_path'][i])
            # Clipping to [0, 1] folds every label > 1 into foreground.
            original_mask_array = np.clip(sitk.GetArrayFromImage(original_mask), 0, 1)
            out_mask_image_original_space_array = np.clip(sitk.GetArrayFromImage(out_mask_image_original_space), 0, 1)
            dice = compute_dice(out_mask_image_original_space_array, original_mask_array)
            dices.append(dice)

14it [00:46,  3.29s/it]


In [20]:
# Per-volume Dice scores collected by the inference loop (one per test volume).
dices

[0.8513404395052372,
 0.9100924853410346,
 0.9165751480415626,
 0.9142655420725183,
 0.43112118995357535,
 0.8740405738258199,
 0.9179437127462833,
 0.9096868976777376,
 0.9355061261182154,
 0.8886765722969694,
 0.9017660041069219,
 0.37521614171300277,
 0.6221982711237583,
 0.9152808335922661]

In [21]:
# Mean Dice over the test volumes.
np.mean(dices)

0.8116935670082074

In [7]:
# Sanity check: the inference `dataset` (built without `augment=`) should have augmentation off.
dataset.augment

False

In [8]:
# Index of the last batch processed by the inference loop.
batch_idx

13

In [9]:
# Fetch test item 7 directly (rebinds the loop's `data` / `target`).
data, target = dataset[7]

In [10]:
# Path of the ground-truth mask file backing item 7.
dataset.mask_filenames[7]

'/home/raubyb/LiTS/Training Batch 1/segmentation-1.nii'

In [11]:
# Raw (un-resampled) ground-truth mask for item 7.
target_raw = sitk.ReadImage(dataset.mask_filenames[7])

In [12]:
# Raw (un-resampled) CT image for item 7.
data_raw = sitk.ReadImage(dataset.image_filenames[7])

In [13]:
# FIXME(review): `mask` is only assigned in a later cell, so on a fresh kernel this raises
# NameError (as the saved output shows). Intent is unclear — confirm or delete this cell.
target = mask

NameError: name 'mask' is not defined

In [None]:
# Compare the physical geometry of the raw mask and raw image, one attribute at a time
# (mask first, then image, for origin, spacing and direction).
for _geometry_getter in ('GetOrigin', 'GetSpacing', 'GetDirection'):
    for _raw_img in (target_raw, data_raw):
        print(getattr(_raw_img, _geometry_getter)())

In [None]:
# NumPy copies of the raw volumes (SimpleITK returns arrays indexed [z, y, x]).
target_raw_array = sitk.GetArrayFromImage(target_raw)
data_raw_array = sitk.GetArrayFromImage(data_raw)

In [None]:
# Raw CT at hard-coded index 135 along the first array axis.
plt.imshow(data_raw_array[135,:,:])

In [None]:
# Raw ground-truth mask at the same hard-coded index 135.
plt.imshow(target_raw_array[135,:,:])

In [None]:
# Origin of the raw image (also printed by the geometry-comparison cell).
data_raw.GetOrigin()

In [None]:
# Duplicate inspection of the per-volume Dice scores from the inference loop.
dices

In [None]:
# Duplicate of the mean-Dice cell.
np.mean(dices)

In [None]:
# Middle slice along the last axis of the pasted-back prediction volume (from the last
# item of the inference loop). SimpleITK arrays are indexed [z, y, x], so this slices
# along x; the anatomical plane depends on the image direction.
# (Removed a bare `big_mask_numpy` expression that displayed nothing because it was not the last line.)
h, w, d = big_mask_numpy.shape
plt.imshow(big_mask_numpy[:, :, d // 2])

In [None]:
# Voxel indices where the last resampled prediction is foreground. Returns one index array
# per axis, which can be very large; np.count_nonzero gives a compact summary.
np.where(out_mask_image_original_space_array==1)

In [None]:
# Last prediction in original image space at hard-coded index 135 along the first axis.
# h, w, d are not used in this cell.
h, w,d = out_mask_image_original_space_array.shape
plt.imshow(out_mask_image_original_space_array[135,:,:])

In [None]:
# NOTE(review): `data` is rebound both by the inference loop (batched and moved to `device`)
# and by `dataset.__getitem__(7)`. This 4-index slice presumably expects the unbatched CPU
# item (index 0 on the first axis, 145 on the second) — confirm which binding is live.
plt.imshow(data[0,145,:,:])

In [None]:
# Channel 1 of the one-hot target at index 145 on the next axis. NOTE(review): assumes
# `target` is the unbatched item from `dataset.__getitem__(7)` — confirm execution order.
plt.imshow(target["one_hot_target"][1,145,:,:])
# (target["one_hot_target"][1,:,:,:]==1)

In [None]:
# Ground-truth mask (clipped to [0, 1]) of the last volume seen by the inference loop,
# at hard-coded index 135. h, w, d are not used in this cell.
h, w,d = original_mask_array.shape
plt.imshow(original_mask_array[135,:,:])

In [None]:
# Exact duplicate of the previous cell — candidate for deletion.
h, w,d = original_mask_array.shape
plt.imshow(original_mask_array[135,:,:])

In [None]:
# FIXME(review): stale debugging cell that mixes state from different cells.
#  - `target` is subscripted with string keys elsewhere (target['bounding_box'],
#    target["one_hot_target"]). If it is that dict here, `target.cpu()` raises
#    AttributeError; the tensor is presumably target["one_hot_target"] — confirm.
#  - `output` comes from the last inference batch, while `data` / `target` may have been
#    rebound by `dataset.__getitem__(7)`, depending on execution order.
#  - intersection / union is the Jaccard index (IoU), not Dice.
volume = output.cpu().numpy()[0,1,:,:,:]
image = data.cpu().numpy()[0,0,:,:,:]
# Rounding thresholds the prediction at 0.5.
volume = np.round(volume)
mask = target.cpu().numpy()[0,1,:,:,:]
intersection = np.sum(mask*volume)
# Clipping the sum to 1 turns it into the voxel-wise union of the two binary masks.
union = np.sum(np.clip(mask+volume,0,1))
print(intersection)
print(union)
print(intersection/union)

In [None]:
# Distinct values of the rounded prediction; expected to be {0, 1} if the model
# outputs lie in [0, 1] — confirm.
np.unique(volume)

In [None]:
fg, ax = plt.subplots(3, 1, figsize=(10, 10))
h, w, d = mask.shape
# Middle index along the last axis, shared by all three panels
# (top to bottom: image, ground-truth mask, rounded prediction).
mid_index = d // 2
for panel, vol in zip(ax, (image, mask, volume)):
    panel.imshow(vol[:, :, mid_index])