In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from scipy import ndimage
import json
from collections import defaultdict

import pathlib
import os, sys
from os.path import join, isfile, dirname, basename, exists
import pickle
import shutil
import random
from itertools import repeat
from functools import partial
import multiprocessing as mp
import multiprocessing

import cc3d
import SimpleITK as sitk
import torchio as tio

from skimage.measure import regionprops
from scipy.ndimage.morphology import binary_dilation
from skimage.morphology import disk, ball, octagon, octahedron

# relative import of get_recist
from rtnet.utils.get_recist import get_RECIST

  from .autonotebook import tqdm as notebook_tqdm


#### List data files

In [2]:
lns_suffix = "_CT_TaskLNSSeg_pred.nii.gz"
data_root = "/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred"
patient_folder = [os.path.join(data_root, p) for p in os.listdir(data_root) if os.path.isdir(os.path.join(data_root, p)) and p.isnumeric()]
print(patient_folder[:5])

['/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/78739', '/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/128367', '/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/163816', '/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/239962', '/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/165081']


#### Load testing list and filter out training cases

In [3]:
data_root = '/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred'
dataset_json_path = "/data/yirui/datasets/LN_classify/rtNetData/images/rtNet_raw_data/Task020_Chuanzhong_9LNS"
with open(os.path.join(dataset_json_path, "dataset.json"), 'rb') as f:
    dataset = json.load(f)
testset = dataset['test']
test_patient = [(os.path.join(data_root, p.split('/')[2]), p.split('_')[1]) for p in testset]
print(len(test_patient))
test_patient[:5]

2401


[('/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/100241',
  'T07'),
 ('/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/100306',
  'T01'),
 ('/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/100306',
  'T03.P'),
 ('/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/100346',
  'T05'),
 ('/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/Cleaned_SichuanData_CorrectHU_Pred/100346',
  'T07')]

#### RECIST-based clasification

In [4]:
ln_suffix = "_LN_Mediastinal.nii.gz"
lns_suffix = "_CT_TaskLNSSeg_pred.nii.gz"

In [5]:
label_to_LNS = {
    1: 'T01',
    2: 'T01',
    3: 'T02',
    4: 'T02',
    6: 'T03.P',
    7: 'T04',
    8: 'T04',
    9: 'T05',
    11: 'T07',
    12: 'T08',
}

lns_to_label = {v:k for k,v in label_to_LNS.items()}

lns_to_df_col_map = {
    'T01': ['Meta LNs 1L', 'Meta LNs 1R'],
    'T02': ['Meta LNs 2L', 'Meta LNs 2R'],
    'T03.P': ['Meta LNs 8U'],
    'T04': ['Meta LNs 4L', 'Meta LNs 4R'],
    'T05': ['Meta LNs 5'],
    'T07': ['Meta LNs 7'],
    'T08': ['Meta LNs 8M', 'Meta LNs 8L']
}

In [6]:
def get_diameter_stats(data, spacing):
    labels = np.unique(data.reshape(-1))
    labels = np.delete(labels, np.where(labels == 0))
    num_of_label = len(labels)
    
    short_diameters = []
    for n in labels:
        tmp_dat = np.copy(data)
        tmp_dat[tmp_dat != n] = 0
        tmp_dat[tmp_dat > 0] = 1
        short_d = get_RECIST(tmp_dat.transpose(2,0,1), spacing[0])
        short_diameters.append(short_d)
    return short_diameters

In [7]:
def process_patient(case, df, all_stats):
    folder, lns_name = case
    cur_patient = {}
    patient_id = folder.split('/')[-1]
    lymph_node_mask_pth = os.path.join(folder, patient_id + ln_suffix)
    lns_mask_pth = os.path.join(folder, patient_id + lns_suffix)
    
    ln_mask = tio.LabelMap(lymph_node_mask_pth)
    lns_mask = tio.LabelMap(lns_mask_pth)

    # statistics for patient-wise lymph node
    ln_dat = ln_mask.numpy().squeeze().astype(int)
    ln_siz = ln_mask.spacing # W, H, D
    ln_vox = np.prod(ln_siz)
    
    lns_dat = lns_mask.numpy().squeeze().astype(int)

    cur_lns_dat = (lns_dat == lns_to_label[lns_name]).astype(int)
    cur_ln_dat = ln_dat * cur_lns_dat
    ln_short_diameter_stats = get_diameter_stats(cur_ln_dat, ln_mask.spacing)
    max_recist = np.max(ln_short_diameter_stats) if ln_short_diameter_stats else 0
    
    df_row = df.loc[int(patient_id)]
    try:
        labels = (np.sum([int(df_row[p]) for p in lns_to_df_col_map[lns_name]]) > 0) * 1.0
    except:
        labels = [df_row[p] for p in lns_to_df_col_map[lns_name]]
        print(labels)
    # labels = 1 if 1 in labels else 0
        
    all_stats[patient_id + '_' + lns_name] = [max_recist, labels]

csv_pth = "/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/ChuanzhongClinical_mined_with_MRN.csv"
df_main = pd.read_csv(csv_pth)
df_main = df_main.set_index('MRN')
df_main = df_main.drop_duplicates()

manager = mp.Manager()
all_stats = manager.dict()
pool = multiprocessing.Pool(processes=30, maxtasksperchild=1)
pool.map(partial(process_patient, df=df_main, all_stats=all_stats), test_patient)
pool.close()

all_results = dict(all_stats)
print("Total {} cases".format(len(all_results)))

Total 2401 cases


In [None]:
# csv_pth = "/data/yirui/datasets/LN_classify/Sichuan_thoracic_LN/ChuanzhongClinical_mined_with_MRN.csv"
# df_main = pd.read_csv(csv_pth)
# df_main = df_main.set_index('MRN')
# df_main = df_main.drop_duplicates()

# manager = mp.Manager()
# all_stats = manager.dict()
# pool = multiprocessing.Pool(processes=30, maxtasksperchild=1)
# pool.map(partial(process_patient, df=df_main, all_stats=all_stats), test_patient)
# pool.close()

# all_results = dict(all_stats)
# print("Total {} cases".format(len(all_results)))

In [8]:
per_station_pred = defaultdict(list)
for k, v in all_results.items():
    cur_lns = k.split('_')[1]
    per_station_pred[cur_lns].append(v)

In [9]:
np.array(per_station_pred['T02'])[:,0].max()

19.152578711509705

In [42]:
for k, v in per_station_pred.items():
    pred_label = [[p[0], p[1]] for p in v]
    pred_label = np.array(pred_label)
    preds = (pred_label[:, 0] > 9) * 1.0
    labels = pred_label[:, 1]
    acc = np.sum(preds == labels) / labels.shape[0]
    recall = np.sum(preds[labels == 1]) / np.sum(labels)
    specificity = 1 - np.sum(preds[labels == 0]) / np.sum(labels == 0)
    print("LNS {} pos ratio {} / {}".format(k, np.sum(labels), labels.shape[0]))
    print("LNS {} -- Accuracy: {:.4f}, recall: {:.4f}, specificity: {:.4f}".format(k, acc, recall, specificity))

LNS T02 pos ratio 83.0 / 383
LNS T02 -- Accuracy: 0.7963, recall: 0.2651, specificity: 0.9433
LNS T05 pos ratio 6.0 / 354
LNS T05 -- Accuracy: 0.9746, recall: 0.0000, specificity: 0.9914
LNS T07 pos ratio 54.0 / 337
LNS T07 -- Accuracy: 0.5994, recall: 0.5556, specificity: 0.6078
LNS T01 pos ratio 18.0 / 335
LNS T01 -- Accuracy: 0.8925, recall: 0.3889, specificity: 0.9211
LNS T03.P pos ratio 17.0 / 343
LNS T03.P -- Accuracy: 0.9417, recall: 0.0000, specificity: 0.9908
LNS T08 pos ratio 58.0 / 338
LNS T08 -- Accuracy: 0.8136, recall: 0.1207, specificity: 0.9571
LNS T04 pos ratio 13.0 / 311
LNS T04 -- Accuracy: 0.8682, recall: 0.2308, specificity: 0.8960
