In [2]:
import time
import os
import math
import argparse
from glob import glob
from collections import OrderedDict
import random
import warnings
import datetime
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import joblib
from skimage.io import imread

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torchvision
from torchvision import datasets, models, transforms

from dataset.liver import Dataset_liver
import pandas as pd
from medpy import metric

In [3]:
def precision_recall_f(output, target):
    smooth = 1e-5
    num = output.shape[0]
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()

    input_1 = output[:, 0, :, :]


    target_1 = target[:, 0, :, :]
    input_1 = input_1 > 0.5
    target_1 = target_1 > 0.5
    tp = (input_1 & target_1).sum()
    fn = ((input_1 == 0) & (target_1 == 1)).astype('int')
    fp = ((input_1 == 1) & (target_1 == 0)).astype('int')
 
    precision = tp / (tp.sum() + fp.sum() + smooth)
    recall = tp / (tp.sum() + fn.sum() + smooth)


    return precision, recall

def rvd_f(output, target):
    smooth = 1e-5
    num = output.shape[0]
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()

    input_1 = output[:, 0, :, :]


    target_1 = target[:, 0, :, :]


    rvd_score_1 = (target_1.sum() - input_1.sum() + smooth) / (input_1.sum() + smooth)

    return rvd_score_1
def dice_coef(output, target):
    smooth = 1e-5
    num = output.shape[0]
    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()

    input_1 = output[:, 0, :, :]


    target_1 = target[:, 0, :, :]


    intersection_1 = (input_1 * target_1)


    dice_1 = (2. * intersection_1.sum() + smooth) / (input_1.sum() + target_1.sum() + smooth)


    return dice_1
def voe_f(output, target):
    smooth = 1e-5

    if torch.is_tensor(output):
        output = torch.sigmoid(output).data.cpu().numpy()
    if torch.is_tensor(target):
        target = target.data.cpu().numpy()
    input_1 = output[:, 0, :, :]


    target_1 = target[:, 0, :, :]

    input_1 = input_1 > 0.5
    target_1 = target_1 > 0.5

    intersection_1 = (input_1 & target_1).sum()
    union_1 = (input_1 | target_1).sum()


    voe_1 = 1 - (intersection_1 + smooth) / (union_1 + smooth)


    return voe_1

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.li = []

    def update(self, val, n=1):
        self.li.append(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        

def validate(val_loader, model):

    voes = AverageMeter()

    dices = AverageMeter()

    rvds = AverageMeter()
    
    precisions = AverageMeter()
    recalls = AverageMeter()


    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (input, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()

        

            output = model(input)

            voe = voe_f(output, target)
            dice = dice_coef(output, target)
            rvd = rvd_f(output, target)
            precision, recall = precision_recall_f(output, target)


            voes.update(torch.tensor(voe), input.size(0))
            
            dices.update(torch.tensor(dice), input.size(0))
           
            rvds.update(torch.tensor(rvd), input.size(0))
            precisions.update(torch.tensor(precision), input.size(0))
            recalls.update(torch.tensor(recall), input.size(0))
            
            

    log = OrderedDict([
        ('voe_1', voes.avg),

        ('rvd_1', rvds.avg),

        ('dice_1', dices.avg),

        ('voe_1_var', np.std(voes.li)),

        ('rvd_1_var', np.std(rvds.li)),

        ('dice_1_var', np.std(dices.li)),
        ('precision_1', precisions.avg),
        ('recall_1', recalls.avg),
        ('precision_1_var', np.std(precisions.li)),

        ('recall_1_var', np.std(recalls.li)),
    ])

    return log

In [4]:
from model.ResTransUNet import ResTransUNet_224 as ResTransUNet


model = ResTransUNet(0)

model = torch.nn.DataParallel(model).cuda()
# ours
model.load_state_dict(torch.load('./weight/best.pth'))




# dataset
val_img_paths = glob('./data/liver/validImage/*')
val_mask_paths = glob('./data/liver/validMask/*')
# val_img_paths = glob('./data/3Diradb/liver/Image/*')
# val_mask_paths = glob('./data/3Diradb/liver/Mask/*')
# val_img_paths = glob('./data/chaos/Image/*')
# val_mask_paths = glob('./data/chaos/Mask/*')
# val_img_paths = glob('./data/sliver07/liver/Image/*')
# val_mask_paths = glob('./data/sliver07/liver/Mask/*')
val_dataset = Dataset_liver(0, val_img_paths, val_mask_paths, transform=False)

val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    pin_memory=True,
    drop_last=False)

log = pd.DataFrame(index=[], columns=[
    'dice_1', 'voe_1', 'rvd_1', 'dice_2', 'voe_2', 'rvd_2'
])


first_time = time.time()

val_log = validate(val_loader, model)

print('dice: %.4f+%.3f - voe: %.4f+%.3f - rvd: %.4f+%.3f - precision: %.4f+%.3f - recall: %.4f+%.3f'
          %(val_log['dice_1'], val_log['dice_1_var'], val_log['voe_1'], val_log['voe_1_var'], val_log['rvd_1'], val_log['rvd_1_var'],
           val_log['precision_1'], val_log['precision_1_var'], val_log['recall_1'], val_log['recall_1_var']))

end_time = time.time()
print("time:", (end_time - first_time) / 60)




torch.cuda.empty_cache()

53777211


100%|██████████| 534/534 [02:47<00:00,  3.18it/s]

dice: 0.9535+0.045 - voe: 0.0804+0.068 - rvd: -0.0007+0.095 - precision: 0.9502+0.053 - recall: 0.9661+0.056
time: 2.796311370531718



