In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from dataset import *
from model import *
import os
import SimpleITK as sitk
import math
from itkwidgets import view 
%matplotlib widget

In [2]:
mode='gpu'

In [3]:
if mode=='gpu':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # after switch device, you need restart the kernel
    torch.cuda.set_device(1)
    torch.set_default_tensor_type('torch.cuda.DoubleTensor')
else:
    device = torch.device('cpu')
    torch.set_default_dtype(torch.float64)

## Testing
### initialization

In [5]:
epoch = 35
output_dir = '/home/sci/hdai/Projects/LnSeg/Models/UNet1024'
checkpoint = torch.load(f'{output_dir}/epoch_{epoch}_checkpoint.pth.tar')
model = UNet1024()

model.load_state_dict(checkpoint['model_state_dict'])
net = torch.nn.DataParallel(model, device_ids=[0, 1])

# params 120237649, # conv layers 62


### save files

In [6]:
case_info = []
root_dir = '/home/sci/hdai/Projects/Dataset/LymphNodes'
patch_size = 128
field_list = ['Series UID', 'Collection', '3rd Party Analysis', 
                      'Data Description URI', 'Subject ID', 'Study UID', 
                      'Study Description', 'Study Date', 'Series Description', 
                      'Manufacturer', 'Modality', 'SOP Class Name', 
                      'SOP Class UID', 'Number of Images', 'File Size', 
                      'File Location', 'Download Timestamp']
with open(f'{root_dir}/metadata.csv', mode='r') as infile:
    reader = csv.reader(infile)
    for row in reader:
        case_info.append({field_list[i]:row[i] for i in range(len(row))})

case_info = case_info[87:]
        
for case in tqdm(case_info):         
#         construct 3d CT from dicom folder
        # '/CT Lymph Nodes/ABD_LYMPH_003/09-14-2014-ABDLYMPH003-abdominallymphnodes-39052/abdominallymphnodes-65663'
    relative_ct_folder_path = case['File Location'][1:].replace('\\','/')
        # '/home/sci/hdai/Projects/LymphNodes/CT Lymph Nodes/ABD_LYMPH_003/09-14-2014-ABDLYMPH003-abdominallymphnodes-39052/abdominallymphnodes-65663'
    ct_folder_path = f'{root_dir}{relative_ct_folder_path}'
    slice_name_list = [f for f in os.listdir(ct_folder_path)]
    slice_name_list.sort()
    slice_list = []
    for slice_name in slice_name_list:
        ds = pd.dcmread(f'{ct_folder_path}/{slice_name}')
        slice_list.append(torch.from_numpy(ds.pixel_array.transpose()))
    img = torch.stack(slice_list,-1).to(device)
    
    case_name = case['File Location'][17:30].replace('\\','/')
    mask_path = f'{root_dir}/MED_ABD_LYMPH_MASKS/{case_name}/{case_name}_mask.nii.gz'
    mask = torch.from_numpy(nib.load(mask_path).get_fdata()).to(device)
    mask[mask>1] = 1
    
    half_patch_size = int(patch_size/2)
    idx_x, idx_y, idx_z = torch.where(mask!=0)
    centroid_x, centroid_y, centroid_z = 256, 256, 300
    if int(torch.mean(idx_x.float())) < mask.shape[0]-half_patch_size and int(torch.mean(idx_x.float())) > half_patch_size:
        centroid_x = int(torch.mean(idx_x.float()))
    if int(torch.mean(idx_y.float())) < mask.shape[1]-half_patch_size and int(torch.mean(idx_y.float())) > half_patch_size:
        centroid_y = int(torch.mean(idx_y.float()))
    if int(torch.mean(idx_z.float())) < mask.shape[2]-half_patch_size and int(torch.mean(idx_z.float())) > half_patch_size:
        centroid_z = int(torch.mean(idx_z.float()))
    img = img[centroid_x-half_patch_size:centroid_x+half_patch_size, centroid_y-half_patch_size:centroid_y+half_patch_size, centroid_z-half_patch_size:centroid_z+half_patch_size]
    mask_pred = model(img.unsqueeze(0).unsqueeze(0))
    sigmoid = torch.nn.Sigmoid()
    mask_pred = sigmoid(mask_pred).squeeze()
    
#     segment_depth = 128
#     segment_num = math.ceil(img.shape[2]/segment_depth)
#     mask_pred_segment_list = []
#     for i in range(segment_num):
#         begin = i*segment_depth
#         end = min(i*segment_depth+segment_depth,img.shape[2])
#         mask_pred_segment = model(img[192:320,192:320,begin:end].unsqueeze(0).unsqueeze(0))
#         mask_pred_segment_list.append(mask_pred_segment)
        
#     mask_pred = torch.stack(mask_pred_segment_list,-1)
    print(case_name)
    mask_path = f'{output_dir}/PredResult/{case_name}_pred_mask.nii.gz'
    nib.save(nib.Nifti1Image(mask_pred.cpu().detach().numpy(), None), mask_path)

  0%|          | 0/88 [00:04<?, ?it/s]


KeyboardInterrupt: 

## Visualization

In [4]:
case_info = []
root_dir = '/home/sci/hdai/Projects/Dataset/LymphNodes'
patch_size = 128
field_list = ['Series UID', 'Collection', '3rd Party Analysis', 
                      'Data Description URI', 'Subject ID', 'Study UID', 
                      'Study Description', 'Study Date', 'Series Description', 
                      'Manufacturer', 'Modality', 'SOP Class Name', 
                      'SOP Class UID', 'Number of Images', 'File Size', 
                      'File Location', 'Download Timestamp']
with open(f'{root_dir}/metadata.csv', mode='r') as infile:
    reader = csv.reader(infile)
    for row in reader:
        case_info.append({field_list[i]:row[i] for i in range(len(row))})

case_info = case_info[87:]

In [12]:
idx = 50 #1

relative_ct_folder_path = case_info[idx]['File Location'][1:].replace('\\','/')
# '/home/sci/hdai/Projects/LymphNodes/CT Lymph Nodes/ABD_LYMPH_003/09-14-2014-ABDLYMPH003-abdominallymphnodes-39052/abdominallymphnodes-65663'
ct_folder_path = f'{root_dir}{relative_ct_folder_path}'
slice_name_list = [f for f in os.listdir(ct_folder_path)]
slice_name_list.sort()
slice_list = []
for slice_name in slice_name_list:
    ds = pd.dcmread(f'{ct_folder_path}/{slice_name}')
    slice_list.append(torch.from_numpy(ds.pixel_array.transpose()))
img = torch.stack(slice_list,-1)

case_name = case_info[idx]['File Location'][17:30].replace('\\','/')
mask_path = f'/home/sci/hdai/Projects/Dataset/LymphNodes/MED_ABD_LYMPH_MASKS/{case_name}/{case_name}_mask.nii.gz'
mask = torch.from_numpy(nib.load(mask_path).get_fdata())
mask[mask>1] = 1

mask_pred_path = f'/home/sci/hdai/Projects/LnSeg/Models/UNet1024/PredResult/{case_name}_pred_mask.nii.gz'
mask_pred = torch.from_numpy(nib.load(mask_pred_path).get_fdata())
# mask_pred[mask_pred>=0.5] = 1
# mask_pred[mask_pred<0.5] = 0

In [6]:
half_patch_size = int(patch_size/2)
idx_x, idx_y, idx_z = torch.where(mask!=0)
centroid_x, centroid_y, centroid_z = 256, 256, 300
if int(torch.mean(idx_x.float())) < mask.shape[0]-half_patch_size and int(torch.mean(idx_x.float())) > half_patch_size:
    centroid_x = int(torch.mean(idx_x.float()))
if int(torch.mean(idx_y.float())) < mask.shape[1]-half_patch_size and int(torch.mean(idx_y.float())) > half_patch_size:
    centroid_y = int(torch.mean(idx_y.float()))
if int(torch.mean(idx_z.float())) < mask.shape[2]-half_patch_size and int(torch.mean(idx_z.float())) > half_patch_size:
    centroid_z = int(torch.mean(idx_z.float()))
img = img[centroid_x-half_patch_size:centroid_x+half_patch_size, centroid_y-half_patch_size:centroid_y+half_patch_size, centroid_z-half_patch_size:centroid_z+half_patch_size]
mask = mask[centroid_x-half_patch_size:centroid_x+half_patch_size, centroid_y-half_patch_size:centroid_y+half_patch_size, centroid_z-half_patch_size:centroid_z+half_patch_size]

In [7]:
view(img)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageSS3; pr…

In [8]:
view(mask)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageD3; pro…

In [15]:
view(mask_pred)

Viewer(geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.itkImagePython.itkImageD3; pro…

## DSC evaluation

In [20]:
dice_score_list = []

for idx in tqdm(range(len(case_info))):
#     idx = 50 #1

    relative_ct_folder_path = case_info[idx]['File Location'][1:].replace('\\','/')
    # '/home/sci/hdai/Projects/LymphNodes/CT Lymph Nodes/ABD_LYMPH_003/09-14-2014-ABDLYMPH003-abdominallymphnodes-39052/abdominallymphnodes-65663'
    ct_folder_path = f'{root_dir}{relative_ct_folder_path}'
    slice_name_list = [f for f in os.listdir(ct_folder_path)]
    slice_name_list.sort()
    slice_list = []
    for slice_name in slice_name_list:
        ds = pd.dcmread(f'{ct_folder_path}/{slice_name}')
        slice_list.append(torch.from_numpy(ds.pixel_array.transpose()))
    img = torch.stack(slice_list,-1)

    case_name = case_info[idx]['File Location'][17:30].replace('\\','/')
    mask_path = f'/home/sci/hdai/Projects/Dataset/LymphNodes/MED_ABD_LYMPH_MASKS/{case_name}/{case_name}_mask.nii.gz'
    mask = torch.from_numpy(nib.load(mask_path).get_fdata())
    mask[mask>1] = 1
    
    half_patch_size = int(patch_size/2)
    idx_x, idx_y, idx_z = torch.where(mask!=0)
    centroid_x, centroid_y, centroid_z = 256, 256, 300
    if int(torch.mean(idx_x.float())) < mask.shape[0]-half_patch_size and int(torch.mean(idx_x.float())) > half_patch_size:
        centroid_x = int(torch.mean(idx_x.float()))
    if int(torch.mean(idx_y.float())) < mask.shape[1]-half_patch_size and int(torch.mean(idx_y.float())) > half_patch_size:
        centroid_y = int(torch.mean(idx_y.float()))
    if int(torch.mean(idx_z.float())) < mask.shape[2]-half_patch_size and int(torch.mean(idx_z.float())) > half_patch_size:
        centroid_z = int(torch.mean(idx_z.float()))
#     img = img[centroid_x-half_patch_size:centroid_x+half_patch_size, centroid_y-half_patch_size:centroid_y+half_patch_size, centroid_z-half_patch_size:centroid_z+half_patch_size]
    mask = mask[centroid_x-half_patch_size:centroid_x+half_patch_size, centroid_y-half_patch_size:centroid_y+half_patch_size, centroid_z-half_patch_size:centroid_z+half_patch_size]

    mask_pred_path = f'/home/sci/hdai/Projects/LnSeg/Models/UNet1024/PredResult/{case_name}_pred_mask.nii.gz'
    mask_pred = torch.from_numpy(nib.load(mask_pred_path).get_fdata())
    
    threshold = 0.05
    mask_pred[mask_pred>=threshold]=1
    mask_pred[mask_pred<threshold]=0
    
    dice_score = torch.sum(2*mask*mask_pred)/torch.sum(mask+mask_pred)
    dice_score_list.append(dice_score.item())

100%|██████████| 88/88 [17:53<00:00, 12.20s/it]


In [21]:
dice_score_list

[0.037837500334489925,
 0.07237193678906786,
 0.21423269126278977,
 0.12874634536612703,
 0.19702454373820014,
 0.27397494062264444,
 0.13945103833029213,
 0.034085166980568696,
 0.41503299516883674,
 0.026016720742991536,
 0.009768211299464698,
 0.1736218720753617,
 0.13160417350108,
 0.18062853382305818,
 0.10605138195543651,
 0.058498294330489885,
 0.06520400526709674,
 0.05287595701641331,
 0.3022813903674629,
 0.03110995012285614,
 0.11043663951929895,
 0.14953188835677222,
 0.06161135511603011,
 0.2649540820269957,
 0.3959286896250273,
 0.0,
 0.09487177248238345,
 0.0054200706675133,
 0.0036163749457543757,
 0.031155581308720665,
 0.10472545052626465,
 0.04981060606060606,
 0.003690313321794533,
 0.16764394407840963,
 0.12997151676931565,
 0.16021520323706545,
 0.03507762362860753,
 0.09390544052598063,
 0.06882976478138582,
 0.2410515710836961,
 0.03467887782294831,
 0.0,
 0.012413556118723423,
 0.004694795750869643,
 0.28851732715031375,
 0.07543336259957802,
 0.049290097001688