In [1]:
import numpy as np, torch, os, sys, multiprocessing, pandas as pd
import torch.backends.cudnn as cudnn, torchio as tio, random
import segmentation_models_pytorch as smp
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
import albumentations as A
from sklearn.metrics import *
sys.path.append("..")
from utils.model_res import generate_model
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from scipy import ndimage
from tqdm import tqdm   

# Runtime setup: GPU selection, worker count, cuDNN mode and RNG seeding.
SEED = 1234

# Restrict the visible CUDA devices to index "1"; this is set ahead of the
# torch.cuda.is_available() query below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
num_workers = multiprocessing.cpu_count()
cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's and torch's global generators with the same value.
random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fa74ad49490>

In [2]:
# Load the label spreadsheet and narrow it down to the 3T test-split rows.
csv_path = '../20211104_label_1-350.csv'
table = pd.read_csv(csv_path)

# Keep rows flagged as 3T (value 1.0 in the '1/0: 3T/1.5T MRI' column).
table_3t = table.loc[table['1/0: 3T/1.5T MRI'] == 1.0]

# Of those, keep rows marked valid ('V') whose '排除' column reads 'Test data'.
table_3t_test = table_3t.loc[table_3t['Valid data'] == 'V']
table_3t_test = table_3t_test.loc[table_3t_test['排除'] == 'Test data']
table_3t_test = np.array(table_3t_test)

# Alphabetically sorted file names found in the test image directory.
# NOTE(review): these are later paired with table rows purely by position —
# confirm the CSV row order matches the sorted file order.
test_dir = os.path.join('../dataset/S2_data/', 'test')
nii_3t_test = sorted(os.listdir(test_dir))

In [3]:
# Subject Function Building
def tio_process(nii_3t_, table_3t_, basepath_='../dataset/S2_data/train/',
                maskpath_='../dataset/S2_data/test_mask/'):
    """Build one ``tio.Subject`` per (image file, label row) pair.

    The two input sequences are paired positionally with ``zip``; extra items
    in the longer one are silently ignored.

    The module-level ``params['S2_type']`` selects which label is attached:

    * ``'ap'`` -- only rows whose column 3 is ``'A'`` or ``'P'`` are kept, and
      that value is stored under the ``ap`` key.
    * ``'nl'`` -- every row is kept and column 4 is stored under the ``nl`` key.
    * any other value -- no subject is produced.

    Args:
        nii_3t_: iterable of image file names (relative to ``basepath_``).
        table_3t_: iterable of label rows, indexable by integer position.
        basepath_: directory containing the DWI volumes.
        maskpath_: directory containing the mask volumes; each mask is looked
            up under the same file name as its image. Defaults to the
            directory that was previously hard-coded.

    Returns:
        list of ``tio.Subject`` with keys ``dwi``, ``msk``, the label key
        (``ap`` or ``nl``) and an empty ``score`` list.
    """
    subjects_ = []
    for nii_path, nii_table in zip(nii_3t_, table_3t_):
        s2_type = params['S2_type']
        if s2_type == 'ap':
            if nii_table[3] not in ('A', 'P'):
                continue
            label_key, label_value = 'ap', nii_table[3]
        elif s2_type == 'nl':
            label_key, label_value = 'nl', nii_table[4]
        else:
            continue

        # Key order matches the original construction: dwi, msk, label, score.
        fields = {
            'dwi': tio.ScalarImage(os.path.join(basepath_, nii_path)),
            'msk': tio.ScalarImage(os.path.join(maskpath_, nii_path)),
            label_key: label_value,
            'score': [],
        }
        subjects_.append(tio.Subject(**fields))
    return subjects_


In [4]:
def _plot_slice_triplets(img, msk, output):
    """Show input, thresholded mask and thresholded prediction for every slice.

    All three tensors are indexed along their first axis, one figure per index.
    The thresholds (0.05 for the mask, 0.1 for the raw model output) are the
    ones used by the original debugging code.
    """
    for k in range(img.shape[0]):
        fig = plt.figure()
        panels = (
            ("DWI", img[k]),
            ("Ground Truth", msk[k] > 0.05),
            ("Predict Masks", output[k] > 0.1),
        )
        for col, (title, plane) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 3, col)
            ax.imshow(plane.squeeze(0).cpu().numpy(), cmap='bone')
            ax.set_title(title)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.show()
        plt.close(fig)


def S1_evaluate(model, valid_loader, show_slices=False):
    """Run ``model`` over every batch of ``valid_loader`` without gradients.

    Each batch's ``dwi`` and ``msk`` tensors have their leading axis squeezed
    and their last axis moved to the front, so the model sees one 2D slice per
    batch element.
    NOTE(review): the original inline comment gives the loader shape as
    ``[1, 1, 384, 384, 26]``, i.e. ``batch_size=1`` is assumed -- confirm.

    Args:
        model: the segmentation network; switched to eval mode here.
        valid_loader: iterable of batches exposing ``['dwi'][tio.DATA]`` and
            ``['msk'][tio.DATA]``.
        show_slices: when True, print the source path of each volume and plot
            every slice (input / mask / prediction). Off by default.

    Returns:
        ``(predict_array, ground_array)``: two lists with one tensor per
        batch. Tensors are returned on the CPU so that results do not pile up
        in GPU memory and can be converted to NumPy by downstream code.
    """
    predict_array = []
    ground_array = []
    model.eval()
    stream = tqdm(valid_loader)
    with torch.no_grad():
        for images in stream:
            img = images['dwi'][tio.DATA].squeeze(0).permute(3, 0, 1, 2)
            msk = images['msk'][tio.DATA].squeeze(0).permute(3, 0, 1, 2)

            # Only the network input needs to live on the compute device.
            output = model(img.to(device)).cpu()
            msk = msk.cpu()

            if show_slices:
                print(images['dwi']['path'])
                _plot_slice_triplets(img, msk, output)

            predict_array.append(output)
            ground_array.append(msk)
    return predict_array, ground_array

def _preview_replacement(dwi_image, replacement):
    """Plot each last-axis slice of the current image next to its replacement."""
    print(dwi_image.shape)
    current = dwi_image[tio.DATA]
    for k in range(replacement.shape[-1]):
        fig = plt.figure(figsize=(12, 12))
        for col, volume in enumerate((current, replacement), start=1):
            ax = fig.add_subplot(1, 3, col)
            ax.imshow(np.asarray(volume[..., k].squeeze(0)), cmap='bone')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.show()
        plt.close(fig)


def change_subject_img(subject_, tesnor_ary, preview=False):
    """Replace the ``'dwi'`` data of every subject with the matching tensor.

    Args:
        subject_: indexable collection whose items expose ``['dwi']`` with a
            ``set_data`` method; modified in place.
        tesnor_ary: indexable collection of replacement tensors, looked up at
            the same index as the subject. Must be at least as long as
            ``subject_`` (a shorter one raises ``IndexError``).
        preview: when True, plot every slice of the old and new data side by
            side before overwriting. Off by default.

    Returns:
        The same ``subject_`` object that was passed in.
    """
    for idx, _ in enumerate(subject_):
        replacement = tesnor_ary[idx]
        if preview:
            _preview_replacement(subject_[idx]['dwi'], replacement)
        subject_[idx]['dwi'].set_data(replacement)
    return subject_

def label2value(label):
    """Encode string labels as a 0/1 ``LongTensor`` on ``device``.

    The label mapped to 0 depends on the module-level ``params["S2_type"]``:
    ``'N'`` when it equals ``'nl'``, otherwise ``'A'``. Every other label
    value maps to 1.
    """
    zero_label = 'N' if params["S2_type"] == 'nl' else 'A'
    encoded = []
    for item in label:
        encoded.append(0 if item == zero_label else 1)
    return torch.LongTensor(encoded).to(device)

# Stage-1 segmentation: restore the trained U-Net, then run it over the test set.
test_transform = tio.Compose([])  # empty pipeline: volumes are used as loaded
S1_weight = '../checkpoint/2021.11.18.t3 - 2DDenseNet121Unet/2DDenseNet121Unet - lr_0.001 - FTL --  epoch:101 | vDice:0.7688 | vLoss:0.06955.pt'
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
S1_checkpoint = torch.load(S1_weight, map_location=device)
S1_model = smp.Unet(
    encoder_name='densenet121',
    encoder_weights=None,
    in_channels=1,
    classes=1,
)
S1_model.load_state_dict(S1_checkpoint['model_state_dict'])
S1_model.to(device)

# `params` is a module-level dict read by tio_process / label2value.
for idx, i in enumerate(['nl']):
    params = {"S1_type": None, "S2_type": i}
    S1_subjects = tio_process(
        nii_3t_test, table_3t_test, basepath_='../dataset/S2_data/test/')
    S1_set = tio.SubjectsDataset(S1_subjects, transform=test_transform)
    test_loader = DataLoader(
        S1_set, batch_size=1, shuffle=False, num_workers=6)
    S1_reply, S1_ans = S1_evaluate(S1_model, test_loader)




100%|██████████| 45/45 [07:38<00:00, 10.18s/it]


In [78]:
class avg_metric:
    """Accumulate soft Dice, IoU, sensitivity and specificity over several calls.

    Each ``metric_calc`` call adds one value per metric to a running sum;
    ``return_metric`` divides those sums by a caller-supplied count.
    """

    def __init__(self):
        self.dice_sum = 0.0
        self.iou_sum = 0.0
        self.sens_sum = 0.0
        self.spec_sum = 0.0

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return ``numerator / denominator`` as a float, or 0.0 if the denominator is zero."""
        if float(denominator) == 0.0:
            return 0.0
        return float(numerator / denominator)

    def metric_calc(self, batch_inputs, batch_targets):
        """Add the metrics of one prediction/target pair to the running sums.

        Args:
            batch_inputs: tensor of raw model outputs; a sigmoid is applied
                here, and the resulting probabilities are used without
                thresholding.
            batch_targets: tensor with the same number of elements.
                NOTE(review): presumably a 0/1 mask -- confirm with the caller.

        Sensitivity and specificity are computed from the soft TP/FN/TN/FP
        sums truncated to integers. A ratio whose denominator is zero (for
        example sensitivity on a target with no positive elements)
        contributes 0.0 instead of raising ``ZeroDivisionError`` or yielding
        NaN.
        """
        inputs = torch.sigmoid(batch_inputs.contiguous().view(-1))
        targets = batch_targets.contiguous().view(-1)

        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        TP = int(intersection)
        FN = int((targets * (1 - inputs)).sum())
        TN = int(((1 - targets) * (1 - inputs)).sum())
        FP = int(((1 - targets) * inputs).sum())

        self.dice_sum += self._safe_ratio(2. * intersection, inputs.sum() + targets.sum())
        self.iou_sum += self._safe_ratio(intersection, union)
        self.sens_sum += round(self._safe_ratio(TP, TP + FN), 5)
        self.spec_sum += round(self._safe_ratio(TN, TN + FP), 5)

    def return_metric(self, len):
        """Return the running sums divided by ``len``, rounded to 5 decimals.

        Args:
            len: number of ``metric_calc`` calls to average over. The name
                shadows the builtin inside this method and is kept only for
                interface compatibility. Passing 0 raises
                ``ZeroDivisionError``.
        """
        return {'Dice': round(self.dice_sum / len, 5),
                'IoU': round(self.iou_sum / len, 5),
                'Sensitivity': round(self.sens_sum / len, 5),
                'Specificity': round(self.spec_sum / len, 5), }


In [79]:
# Score every predicted volume against its mask and report the averages.
metric_class = avg_metric()
for inputs, targets in zip(S1_reply, S1_ans):
    metric_class.metric_calc(inputs, targets)
summary = metric_class.return_metric(len(S1_reply))
print(summary)
# Results recorded from earlier runs. NOTE(review): the trailing number on each
# line is presumably the setting (e.g. threshold) used for that run -- confirm.
# {'Dice': 0.68367, 'IoU': 0.55184, 'Sensitivity': 0.73547, 'Specificity': 0.99978} 0.2
# {'Dice': 0.60646, 'IoU': 0.48663, 'Sensitivity': 0.54321, 'Specificity': 0.9999} 0.0001
# {'Dice': 0.59049, 'IoU': 0.46844, 'Sensitivity': 0.50486, 'Specificity': 0.99993} 0.1
# {'Dice': 0.58312, 'IoU': 0.46012, 'Sensitivity': 0.49119, 'Specificity': 0.99994} 0.5
# {'Dice': 0.57463, 'IoU': 0.45072, 'Sensitivity': 0.47765, 'Specificity': 0.99994} 0.9

{'Dice': 0.68262, 'IoU': 0.55184, 'Sensitivity': 0.73547, 'Specificity': 0.99978}


In [7]:
# for idx2 in range(26):
#     fig = plt.figure(figsize=(12,12))
#     ax1 = fig.add_subplot(1,3,1)
#     print(max(S1_reply[0][idx2].flatten()), min(S1_reply[0][idx2].flatten()))
#     ax1.imshow(np.squeeze(S1_reply[0][idx2], axis=0), cmap='bone')
#     ax1.get_xaxis().set_visible(False)
#     ax1.get_yaxis().set_visible(False)

#     plt.show()
#     plt.close(fig)