In [1]:
# load library
import argparse
import os
import numpy as np
from tqdm import tqdm

from mypath import Path
from dataloaders import make_data_loader
from modeling.sync_batchnorm.replicate import patch_replication_callback
from modeling.deeplab import *
from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
# from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator

from dataloaders.datasets.lits import LiverSegmentation, TumorSegmentation
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import argparse
from PIL import Image
import cv2
import time
import torch

In [2]:
from scipy.ndimage import morphology

def surfd(input1, input2, sampling=1, connectivity=1):
    
    input_1 = np.atleast_1d(input1.astype(np.bool))
    input_2 = np.atleast_1d(input2.astype(np.bool))
    

    conn = morphology.generate_binary_structure(input_1.ndim, connectivity)

    S = input_1 ^ morphology.binary_erosion(input_1, conn)
    Sprime = input_2 ^ morphology.binary_erosion(input_2, conn)

    
    dta = morphology.distance_transform_edt(~S,sampling)
    dtb = morphology.distance_transform_edt(~Sprime,sampling)
    
    sds = np.concatenate([np.ravel(dta[Sprime!=0]), np.ravel(dtb[S!=0])])
       
    
    return sds

In [9]:
parser = argparse.ArgumentParser()
parser.add_argument('-f')
parser.add_argument('--batch-size', type=int, default=200)
parser.add_argument('--base-size', type=int, default=256)
parser.add_argument('--crop-size', type=int, default=256)
parser.add_argument('--mode', type=str, default='val')
parser.add_argument('--kind', type=str, default='liver')
parser.add_argument('--model-path', type=str, default='models/95_liver33.pth.tar')
parser.add_argument('--backbone', type=str, default='xception')
# parser.add_argument('--model-path', type=str, default='models/95_liver33.pth.tar')
# parser.add_argument('--backbone', type=str, default='xception')

args = parser.parse_args()

In [10]:
# load model
model = DeepLab(num_classes=2, backbone=args.backbone, output_stride=16, sync_bn=False, freeze_bn=False)

# ckpt = torch.load('run/lits_tumor/resume-tumor-bce-crop/experiment_0/_checkpoint37.pth.tar')#67 0.8809 0.8809
ckpt = torch.load(args.model_path)#72 
state_dict = ckpt['state_dict']
model.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
args.mode = 'val'
args.mode

'val'

In [12]:
# load data
for sn in range(111, 131):
    if args.kind == 'liver':
        dataset_test = LiverSegmentation(args, split=args.mode, study_num=sn)
    if args.kind == 'tumor':
        dataset_test = TumorSegmentation(args, split=args.mode, study_num=sn)
    print("num test img: ", len(dataset_test))
    if len(dataset_test) == 0:
        continue
    dataloader = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # gpu use
    device = 'cuda'
    model.to(device)
    model.eval()

    # initialize scores
    cnt = 0
    total_precision = 0
    total_recall = 0
    total_time = 0
    total_cos = 0
    total_voe = 0
    total_assd = 0
    total_vd = 0
    # Dice, jaccard, VOE, ASSD, RVD, MSSD 
    # run inference
    for i, sample in enumerate(dataloader):
        image, target = sample['image'], sample['label']
        image = image.to(device)

        start_time = time.time()
        with torch.no_grad():
            output = model(image)
        timedelta = time.time() - start_time
        total_time += timedelta

        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        # print(np.unique(pred))
        # print(np.unique(target))

        image = image.cpu().numpy()
        for idx in range(len(pred)):
            if args.mode == 'val':
                ## scoring
                pred_ = pred[idx].astype(np.uint8)
                target_ = target[idx].astype(np.uint8)

                intersection = np.logical_and(target_, pred_)
                union = np.logical_or(target_, pred_)
                voe = 1.0 - np.sum(intersection)/np.sum(union)
                sds = surfd(target_, pred_)
                if len(sds) == 0:
                    assd = 0
                else:
                    assd = sds.mean()
                if np.sum(target_) == 0:
                    vd = 1.0
                else:
                    vd = abs((int(np.sum(pred_)) - int(np.sum(target_))) / args.crop_size**2)
                # iou_score = np.sum(intersection) / np.sum(union)
                tp = np.sum(np.logical_and(target_ == 1, pred_ == 1))/256**2
                fp = np.sum(np.logical_and(target_ == 0, pred_ == 1))/256**2
                tn = np.sum(np.logical_and(target_ == 0, pred_ == 0))/256**2
                fn = np.sum(np.logical_and(target_ == 1, pred_ == 0))/256**2

                target_ = target_.ravel()
                pred_ = pred_.ravel()

                cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
                precision = tp/(tp+fp)
                recall = tp/(tp+fn)
                voe = np.nan_to_num(voe, nan=1.0)
                cos_sim = np.nan_to_num(cos_sim, nan=1.0)
                precision = np.nan_to_num(precision, nan=1.0)
                recall = np.nan_to_num(recall, nan=1.0)

                total_cos += cos_sim
                total_precision+=precision
                total_recall+=recall
                total_voe += voe
                total_assd+=assd
                total_vd+=vd
            elif args.mode == 'vis':
                ##visualize(save)
                pred_ = pred[idx].astype(np.uint8)
                target_ = target[idx].astype(np.uint8)
                pred_[pred_ != 0] = 255
                target_[target_ != 0] = 255

                img_tmp = np.transpose(image[idx], axes=[1, 2, 0])
                img_tmp *= (0.229, 0.224, 0.225)
                img_tmp += (0.485, 0.456, 0.406)
                img_tmp *= 255.0
                img_tmp = img_tmp.astype(np.uint8)

                fig = plt.figure()
                fig.tight_layout()
                
                ax1 = fig.add_subplot(1, 3, 1)
                ax1.imshow(target_, cmap='gray')
                # ax1.set_title('Label')
                ax1.axes.xaxis.set_visible(False)
                ax1.axes.yaxis.set_visible(False)
                ax2 = fig.add_subplot(1, 3, 2)
                ax2.imshow(img_tmp, cmap=plt.cm.bone)
                # ax2.set_title('Original')
                ax2.axes.xaxis.set_visible(False)
                ax2.axes.yaxis.set_visible(False)
                ax3 = fig.add_subplot(1, 3, 3)
                ax3.imshow(pred_, cmap='gray')
                # ax3.set_title('Predict')
                ax3.axes.xaxis.set_visible(False)
                ax3.axes.yaxis.set_visible(False)

                # plt.show()
                os.makedirs('val/'+args.kind+f'/{str(sn)}/', exist_ok=True)
                plt.savefig('val/'+args.kind+f'/{str(sn)}/'+str(cnt)+'.png')
                plt.close(fig)
            cnt+=1
            print(cnt, end='\r')

if args.mode == 'val':
    # print scores
    avg_time = total_time/cnt
    p = total_precision/cnt*100
    r = total_recall/cnt*100
    cos = total_cos/cnt*100
    f1 = 2*p*r/(p+r)
    voe = total_voe/cnt*100
    assd = total_assd/cnt
    vd = total_vd/cnt*100
    print(f"avg_time:{round(avg_time,4)} precision:{round(p,4)} recall:{round(r,4)} dice:{round(f1,4)} jaccard:{round(cos,4)} voe:{round(voe,4)} assd:{round(assd,4)} vd:{round(vd,4)}")

num test img:  761


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  751


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  836


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  846


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  846


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  908


  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  recall = tp/(tp+fn)
  voe = 1.0 - np.sum(intersection)/np.sum(union)
  precision = tp/(tp+fp)


num test img:  836


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  427


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  461


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  424


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  463


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  422


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  432


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  407


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  410


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  401


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  987


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  654


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  338


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


num test img:  624


  voe = 1.0 - np.sum(intersection)/np.sum(union)
  cos_sim = np.dot(target_, pred_)/(np.linalg.norm(target_)*np.linalg.norm(pred_))
  precision = tp/(tp+fp)
  recall = tp/(tp+fn)


avg_time:0.0002 precision:98.5506 recall:91.7224 dice:95.014 jaccard:65.764 voe:70.4582 assd:7.0141 vd:61.6383


In [None]:
#liver encoder
avg_time:0.0058 precision:82.0091 recall:96.3349 dice:88.5966 jaccard:77.34 voe:79.59 assd:33.9379 vd:74.6414
#liver decoder
avg_time:0.0002 precision:43.4459 recall:78.314 dice:55.8874 jaccard:65.1172 voe:83.3158 assd:94.6391 vd:62.6889
#liver aspp
avg_time:0.0002 precision:54.8324 recall:94.1485 dice:69.3027 jaccard:78.0577 voe:81.2244 assd:96.2175 vd:74.6727

In [9]:
print(f"avg_time:{round(avg_time,4)} precision:{round(p,4)} recall:{round(r,4)} dice:{round(f1,4)} jaccard:{round(cos,4)} voe:{round(voe,4)} assd:{round(assd,4)} vd:{round(vd,4)}")

avg_time:0.0058 precision:82.0091 recall:96.3349 dice:88.5966 jaccard:77.34 voe:79.59 assd:33.9379 vd:74.6414
