In [1]:
import numpy as np, torch, os, glob, nibabel as nib, multiprocessing, pandas as pd, sys
import torch.backends.cudnn as cudnn, torchio as tio, random
import segmentation_models_pytorch as smp
import torch.backends.cudnn as cudnn
import matplotlib.pyplot as plt
import albumentations as A
from sklearn.metrics import *
sys.path.append("..")
from utils.model_res import generate_model
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
from scipy import ndimage
from tqdm import tqdm   

num_workers = multiprocessing.cpu_count()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True
random.seed(1234)
torch.manual_seed(1234)

<torch._C.Generator at 0x7fc8bb2aa450>

In [2]:
# load csv label
if True:
    csv_path = '../20211104_label_1-350_1.5&3.0.csv'
    table_3t =  pd.read_csv(csv_path)
    # table_3t = table[table['1/0: 3T/1.5T MRI']==1.0]
    table_3t_test = table_3t[table_3t['Valid data']=='V']
    table_3t_test = np.array(table_3t_test[table_3t_test['排除']=='Test data'])
    nii_3t_test = sorted([i for i in os.listdir(os.path.join('../dataset/S2_data1.5&3.0/','test'))])

In [3]:
# Subject Function Building
def tio_process(nii_3t_, table_3t_, basepath_='./dataset/S2_data/train/'):
    subjects_ = []
    for  (nii_path, nii_table) in zip(nii_3t_ , table_3t_):
        if (params['S2_type']=='ap') and (nii_table[3]=='A' or nii_table[3]=='P'):
            subject = tio.Subject(
                dwi = tio.ScalarImage(os.path.join(basepath_, nii_path)), 
                ap = nii_table[3], 
                score=[])
            subjects_.append(subject)
        elif (params['S2_type']=='nl'):
            subject = tio.Subject(
                dwi = tio.ScalarImage(os.path.join(basepath_, nii_path)), 
                nl  = nii_table[4], 
                score=[])
            subjects_.append(subject)
    return subjects_


In [4]:
def S1_evaluate(model, valid_loader):
    predict_array = []
    model.eval()
    stream = tqdm(valid_loader)
    with torch.no_grad():
        for i, images in enumerate(stream, start=1):
            images = images['dwi'][tio.DATA]
            # torch.Size([1, 1, 384, 384, 26])
            images = images.squeeze(0)
            images = images.permute(3,0,1,2)
            # .permute(1,0,2).
            images = images.to(device)
            output =  model(images)
            output = torch.sigmoid(output)
            output = (output - output.min()) / (output.max() - output.min() + 1e-8) #****!!!!
            # print(torch.max(output), torch.min(output))
            output = output.cpu() 
            images = images.cpu() 
            # torch.Size([26, 1, 384, 384])
            pred_mask = torch.where(output>0.5, images, images*0)
            pred_mask = pred_mask.cpu() 
            if False:
                for idx in range(26):
                    fig = plt.figure()
                    
                    ax1 = fig.add_subplot(1,3,1)
                    ax1.imshow(np.squeeze(images[idx], axis=0), cmap='bone')
                    ax1.get_xaxis().set_visible(False)
                    ax1.get_yaxis().set_visible(False)

                    ax2 = fig.add_subplot(1,3,2)
                    ax2.imshow(np.squeeze(output[idx]>0.3, axis=0), cmap='bone')
                    ax2.get_xaxis().set_visible(False)
                    ax2.get_yaxis().set_visible(False)
                    
                    ax3 = fig.add_subplot(1,3,3)
                    ax3.imshow(np.squeeze(pred_mask[idx], axis=0), cmap='bone')
                    ax3.get_xaxis().set_visible(False)
                    ax3.get_yaxis().set_visible(False)
                    plt.show()
                    plt.close(fig)
            
            pred_mask = (pred_mask.permute(1,2,3,0))
            # torch.Size([1, 1, 384, 384, 26])
            predict_array.append(pred_mask)
            # if i ==1:
            #     break
    return predict_array

def change_subject_img(subject_, tesnor_ary):
    for idx, i in enumerate(subject_):
        if False:
            print( subject_[idx]['dwi'].shape)
            for idx2 in range(26):
                fig = plt.figure(figsize=(12,12))
                ax1 = fig.add_subplot(1,3,1)
                ax1.imshow(np.squeeze((subject_[idx]['dwi'][tio.DATA])[...,idx2], axis=0), cmap='bone')
                ax1.get_xaxis().set_visible(False)
                ax1.get_yaxis().set_visible(False)
                ax2 = fig.add_subplot(1,3,2)
                ax2.imshow(np.squeeze((tesnor_ary[idx])[...,idx2], axis=0), cmap='bone')
                ax2.get_xaxis().set_visible(False)
                ax2.get_yaxis().set_visible(False)
                plt.show()
                plt.close(fig)
        subject_[idx]['dwi'].set_data(tesnor_ary[idx])
        # image =  subject_[idx]['dwi'][tio.DATA]
    return subject_

def label2value(label):
    if params["S2_type"]=='nl':
        target = [0 if i=='N' else 1 for i in label]
    else:
        target = [0 if i=='A' else 1 for i in label]
    return torch.LongTensor(target).to(device)

def S2_evaluate(model,valid_loader):
    predict_array = {'target':[], 'predict':[]}
    model.eval()
    stream_v = tqdm(S2_subjects)
    with torch.no_grad():
        for i, data in enumerate(stream_v, start=1):
            images = data['dwi'][tio.DATA].to(device).unsqueeze(0)
            target = label2value(data[params["S2_type"]])
            images = images.to(device)
            target = target.to(device)
            output = model(images).squeeze(1)
            output = torch.sigmoid(output)
            output = (output - output.min()) / (output.max() - output.min() + 1e-8) #****!!!!
            # print(output)
            _, outputs = torch.max(output, 1)
            predict_array['predict'].append(outputs.item())
            predict_array['target'].append(target.item())
    return predict_array
if True:
    test_transform = tio.Compose([])
    # test_transform = tio.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])
    test_transform2 = tio.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])
    S1_weight = '../checkpoint/2021.11.23.t2 - 2DDenseNet121Unet/2DDenseNet121Unet - lr_0.001 - FTL --  epoch:105 | vDice:0.7908 | vLoss:0.0634.pt'
    S2_weight_ap = '../checkpoint/2021.11.25.t2 - 3DResNet18 - ap/ap - 3dresnet18 - lr_0.001 - FL --  epoch:93 | vLoss:0.01088 | vAcc:100.0.pt'
    S2_weight_nl = '../checkpoint/2021.11.24.t2 - 3DResNet18 - nl/nl - 3dresnet18 - lr_0.001 - FL --  epoch:22 | vLoss:0.01902 | vAcc:93.75.pt'
    S2_weight_stack = {'ap': S2_weight_ap, 'nl': S2_weight_nl}
    S1_checkpoint = torch.load(S1_weight, map_location=torch.device(device))
    S1_model = smp.Unet(encoder_name='densenet121', encoder_weights=None, in_channels=1, classes=1)
    S1_model.load_state_dict(S1_checkpoint['model_state_dict'])
    S1_model.to(device)
    S2_reply = {'ap': [], 'nl':[]}
for idx, i in enumerate(['ap', 'nl']):
# for idx, i in enumerate(['nl']):
    # ///////////////////////////////////////////////////
    S2_model = generate_model(model_depth=18, n_input_channels=1, n_classes=2)
    S2_checkpoint= torch.load(S2_weight_stack[i], map_location=torch.device(device))
    S2_model.load_state_dict(S2_checkpoint['model_state_dict'])
    S2_model.to(device)
    # //////////////////////////////////////////////////
    params = {"S1_type": None, "S2_type": i}
    S1_subjects = tio_process(nii_3t_test, table_3t_test, basepath_ = '../dataset/S2_data1.5&3.0/test/')
    S1_set = tio.SubjectsDataset(S1_subjects, transform=test_transform)
    test_loader = torch.utils.data.DataLoader(S1_set, batch_size=1, shuffle=False, num_workers=4)
    S1_reply = S1_evaluate(S1_model, test_loader)
    S2_subjects = change_subject_img(S1_subjects, S1_reply)
    S2_reply[i].append(S2_evaluate(S2_model, S2_subjects))

100%|██████████| 78/78 [00:24<00:00,  3.14it/s]
100%|██████████| 78/78 [00:05<00:00, 13.65it/s]
100%|██████████| 80/80 [00:21<00:00,  3.64it/s]
100%|██████████| 80/80 [00:05<00:00, 14.87it/s]


In [5]:
for i in ['ap', 'nl']:
    GT =np.array(S2_reply[i][0]['target'])
    SR = np.array(S2_reply[i][0]['predict'])
    TP = int((SR * GT).sum()) #TP
    FN = int((GT * (1-SR)).sum()) #FN
    TN = int(((1-GT) * (1-SR)).sum()) #TN
    FP = int(((1-GT) * SR).sum()) #FP
    print(f'{i}  -   Accuracy  :', round((TP + TN)/(TP + TN + FP + FN), 2)*100, '%')
    print(f'{i}  -   Sensitivity  :', round(float(TP)/(float(TP+FN) + 1e-6), 5))
    print(f'{i}  -   Specificity  :', round(float(TN)/(float(TN+FP) + 1e-6), 5))

ap  -   Accuracy  : 91.0 %
ap  -   Sensitivity  : 0.96296
ap  -   Specificity  : 0.88235
nl  -   Accuracy  : 93.0 %
nl  -   Sensitivity  : 0.97297
nl  -   Specificity  : 0.88372


## 11/23 only 3.0T
#### ap  -   Accuracy  : 91.0 %
#### ap  -   Sensitivity  : 0.71429
#### ap  -   Specificity  : 1.0
#### nl  -   Accuracy  : 93.0 %
#### nl  -   Sensitivity  : 0.95833
#### nl  -   Specificity  : 0.90476

In [6]:



# for i in test_loader2:
#     print(i['dwi'][tio.DATA].shape)
#     image = i['dwi'][tio.DATA].squeeze(0)
#     for idx2 in range(26):
#         fig = plt.figure()
#         ax1 = fig.add_subplot(1,1,1)
#         print(idx2+1)
#         ax1.imshow(np.squeeze(image[...,idx2], axis=0), cmap='bone')
#         # ax1.get_xaxis().set_visible(False)
#         # ax1.get_yaxis().set_visible(False)
#         plt.show()
#         plt.close(fig)