In [1]:
from network import enc_dec
from dataloader import image_loader
from dataloader import cevae_loader
from utils import loss_fn, tb_utils, utils
from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from copy import deepcopy
from datetime import date
from metrics import metrics

In [2]:
# Read the CEVAE settings, choose the compute device and apply the configured seed.
config = utils.read_config('./config/cevae_config.yml')

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# The seed lives in the inference section of the config; utils.set_seed is a
# project helper (presumably seeds every RNG library in use — confirm).
seed = config['inference']['seed']
utils.set_seed(seed)

## Slice-wise inference

In [3]:
# Data-loading settings, all read from config['inference']['dataloader'].
loader_cfg = config['inference']['dataloader']

normal_path = loader_cfg['normal_path']
abnormal_path = loader_cfg['abnormal_path']

# The three geometry settings are converted to tuples.
resize, patch_size, margin = (
    tuple(loader_cfg[key]) for key in ('resize', 'patchsize', 'margin')
)

batch_size = loader_cfg['batch']
num_workers = loader_cfg['num_workers']

In [4]:
# Model hyper-parameters, all read from config['inference']['model'].
model_cfg = config['inference']['model']

h_size, input_size, z_dim = (
    model_cfg[key] for key in ('h_size', 'input_size', 'z_dim')
)

# The two loss weights are wrapped as tensors.
lamda, beta = (torch.tensor(model_cfg[key]) for key in ('lamda', 'beta'))

In [5]:
# Location of the saved model, read from config['inference']['model']['load_model'].
model_path = config['inference']['model']['load_model']
# NOTE(review): this freshly built VAE is not referenced again in the visible
# cells; `model` below comes straight from torch.load. Constructing it may draw
# from the seeded RNG, so confirm before removing it.
model_module = enc_dec.VAE(input_size, h_size, z_dim)
# torch.load unpickles the checkpoint and maps its tensors onto `device`.
# Presumably the file stores the whole module rather than a state_dict, since
# the result is used directly as the model in later cells — verify.
# Unpickling can execute arbitrary code: only load checkpoints you trust.
model = torch.load(model_path,map_location=device)

In [6]:
# Both slice sets share the same dataset geometry and the same loader settings.
dataset_kwargs = dict(patchsize=patch_size, margin=margin, resize=resize)
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

normal_data = cevae_loader.cevae(normal_path, **dataset_kwargs)
normal_loader = DataLoader(normal_data, **loader_kwargs)

abnormal_data = cevae_loader.cevae(abnormal_path, **dataset_kwargs)
abnormal_loader = DataLoader(abnormal_data, **loader_kwargs)

In [7]:
# Score the normal slices: one combined CEVAE `loss` entry is appended per batch.
# NOTE(review): whether an entry is per slice depends on the batch size and on
# the reduction inside loss_fn — confirm.
model.eval()
normal_score = []
with torch.no_grad():
    progress = tqdm(normal_loader, desc='val_iter', leave=False)
    for idx, (data, masked_data) in enumerate(progress):
        data = data.to(device)
        masked_data = masked_data.to(device)

        # Two forward passes: the unmasked input first, then the masked input.
        rec_vae, mu, std = model(data)
        rec_ce, _, _ = model(masked_data)

        # VAE part: reconstruction loss plus the beta-weighted KL term.
        kl_loss = loss_fn.kl_divergence(mu, std)
        rec_loss_vae = loss_fn.rec_loss_fn(rec_vae, data)
        loss_vae = rec_loss_vae + beta * kl_loss

        # CE part: the masked-input reconstruction is compared with the unmasked input.
        rec_loss_ce = loss_fn.rec_loss_fn(rec_ce, data)

        # lamda blends the two parts into the final score.
        loss = (1 - lamda) * loss_vae + lamda * rec_loss_ce
        normal_score.append(loss)

                                                             

In [8]:
# Score the abnormal slices: one combined CEVAE `loss` entry is appended per batch.
# NOTE(review): whether an entry is per slice depends on the batch size and on
# the reduction inside loss_fn — confirm.
model.eval()
abnormal_score = []
with torch.no_grad():
    progress = tqdm(abnormal_loader, desc='iter', leave=False)
    for idx, (data, masked_data) in enumerate(progress):
        data = data.to(device)
        masked_data = masked_data.to(device)

        # Two forward passes: the unmasked input first, then the masked input.
        rec_vae, mu, std = model(data)
        rec_ce, _, _ = model(masked_data)

        # VAE part: reconstruction loss plus the beta-weighted KL term.
        kl_loss = loss_fn.kl_divergence(mu, std)
        rec_loss_vae = loss_fn.rec_loss_fn(rec_vae, data)
        loss_vae = rec_loss_vae + beta * kl_loss

        # CE part: the masked-input reconstruction is compared with the unmasked input.
        rec_loss_ce = loss_fn.rec_loss_fn(rec_ce, data)

        # lamda blends the two parts into the final score.
        loss = (1 - lamda) * loss_vae + lamda * rec_loss_ce
        abnormal_score.append(loss)

                                                         

In [9]:
# Slice-wise detection metrics; the abnormal scores are passed first, the normal ones second.
slice_metrics = metrics.get_metrics(abnormal_score, normal_score)
auroc, aupr, roc_curve, pr_curve = slice_metrics

In [10]:
# Display the slice-wise AUROC and AUPR values returned by metrics.get_metrics.
auroc,aupr

(0.6515214190846292, 0.3687309393921744)

In [11]:
# Split the slice-wise curve tuples into their arrays. The third element of
# each tuple is discarded (presumably the thresholds — confirm against
# metrics.get_metrics).
fpr,tpr,_ = roc_curve
precision,recall, _ = pr_curve

In [19]:
def save_scores(array,filename,savepath):
    """Save ``array`` with ``np.save`` as ``filename`` inside the directory ``savepath``.

    Args:
        array: Array-like object to store.
        filename: Name of the output file; ``np.save`` appends ``.npy`` when
            the name does not already end with it.
        savepath: Destination directory, with or without a trailing path
            separator. It is created (including parents) if it does not exist.
    """
    import os  # local import keeps this cell self-contained

    # An empty savepath means the current directory; makedirs('') would raise.
    if savepath:
        os.makedirs(savepath, exist_ok=True)
    # Join instead of concatenating so a missing trailing separator cannot
    # glue the directory name onto the file name.
    np.save(os.path.join(savepath, filename), array)

In [20]:
# Persist the slice-wise ROC and PR curve arrays under ./results/.
slice_outputs = (
    (fpr, 'false_positive_slicewise_CEVAE'),
    (tpr, 'True_positive_slicewise_CEVAE'),
    (precision, 'precision_slicewise_CEVAE'),
    (recall, 'recall_slicewise_CEVAE'),
)
for scores, stem in slice_outputs:
    save_scores(scores, stem, './results/')

## Volume-wise scores

In [21]:
# Locations of the volume-level data, read from config['inference']['dataloader'].
volume_cfg = config['inference']['dataloader']
normal_volume, abnormal_volume = (
    volume_cfg[key] for key in ('normal_volumes', 'abnormal_volumes')
)

In [22]:
# Volume datasets reuse the patch/resize settings and loader settings of the slice-wise run.
dataset_kwargs = dict(patchsize=patch_size, margin=margin, resize=resize)
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

normal_vdata = cevae_loader.cevae(normal_volume, **dataset_kwargs)
abnormal_vdata = cevae_loader.cevae(abnormal_volume, **dataset_kwargs)

normal_vloader = DataLoader(normal_vdata, **loader_kwargs)
abnormal_vloader = DataLoader(abnormal_vdata, **loader_kwargs)

In [23]:
# Score the normal volume data: one combined CEVAE `loss` entry is appended per batch.
# NOTE(review): how a batch relates to a volume depends on the dataset and the
# configured batch size — confirm.
model.eval()
volume_normal = []
with torch.no_grad():
    progress = tqdm(normal_vloader, desc='iter', leave=False)
    for idx, (data, masked_data) in enumerate(progress):
        data = data.to(device)
        masked_data = masked_data.to(device)

        # Two forward passes: the unmasked input first, then the masked input.
        rec_vae, mu, std = model(data)
        rec_ce, _, _ = model(masked_data)

        # VAE part: reconstruction loss plus the beta-weighted KL term.
        kl_loss = loss_fn.kl_divergence(mu, std)
        rec_loss_vae = loss_fn.rec_loss_fn(rec_vae, data)
        loss_vae = rec_loss_vae + beta * kl_loss

        # CE part: the masked-input reconstruction is compared with the unmasked input.
        rec_loss_ce = loss_fn.rec_loss_fn(rec_ce, data)

        # lamda blends the two parts into the final score.
        loss = (1 - lamda) * loss_vae + lamda * rec_loss_ce
        volume_normal.append(loss)

                                                         

In [24]:
# Score the abnormal volume data: one combined CEVAE `loss` entry is appended per batch.
# NOTE(review): how a batch relates to a volume depends on the dataset and the
# configured batch size — confirm.
model.eval()
volume_abnormal = []
with torch.no_grad():
    progress = tqdm(abnormal_vloader, desc='iter', leave=False)
    for idx, (data, masked_data) in enumerate(progress):
        data = data.to(device)
        masked_data = masked_data.to(device)

        # Two forward passes: the unmasked input first, then the masked input.
        rec_vae, mu, std = model(data)
        rec_ce, _, _ = model(masked_data)

        # VAE part: reconstruction loss plus the beta-weighted KL term.
        kl_loss = loss_fn.kl_divergence(mu, std)
        rec_loss_vae = loss_fn.rec_loss_fn(rec_vae, data)
        loss_vae = rec_loss_vae + beta * kl_loss

        # CE part: the masked-input reconstruction is compared with the unmasked input.
        rec_loss_ce = loss_fn.rec_loss_fn(rec_ce, data)

        # lamda blends the two parts into the final score.
        loss = (1 - lamda) * loss_vae + lamda * rec_loss_ce
        volume_abnormal.append(loss)

                                                         

In [25]:
# Volume-wise detection metrics; the abnormal scores are passed first, the normal ones second.
volume_metrics = metrics.get_metrics(volume_abnormal, volume_normal)
vauroc, vaupr, vroc, vpr = volume_metrics

In [26]:
# Display the volume-wise AUROC and AUPR values returned by metrics.get_metrics.
vauroc,vaupr

(0.6688839979127315, 0.7037036897423272)

In [27]:
# Split the volume-wise curve tuples into their arrays. The third element of
# each tuple is discarded (presumably the thresholds — confirm against
# metrics.get_metrics).
vfpr,vtpr,_ = vroc
vprec,vrec,_ = vpr

In [28]:
# Persist the volume-wise ROC and PR curve arrays under ./results/.
# NOTE(review): the precision file name is spelled 'clincwise' while the other
# three use 'clinicwise'. Kept unchanged because other code may read these
# files by name — confirm before renaming.
volume_outputs = (
    (vfpr, 'false_positive_clinicwise_CEVAE'),
    (vtpr, 'True_positive_clinicwise_CEVAE'),
    (vprec, 'precision_clincwise_CEVAE'),
    (vrec, 'recall_clinicwise_CEVAE'),
)
for scores, stem in volume_outputs:
    save_scores(scores, stem, './results/')

## Mask creation and segmentation evaluation

In [9]:
from utils.image_utils import *
from utils.mask_utils import *
import matplotlib.pyplot as plt

In [10]:
# Dataset built from the abnormal slices at the configured resize.
# load_data is not imported by name in this notebook; presumably it comes from
# one of the wildcard utils imports — confirm.
data = load_data(abnormal_path,resize)
# Batch size and worker count are the same config values used by the scoring loaders.
dataloader = DataLoader(data,batch_size=batch_size,num_workers=num_workers)

In [16]:
# Output folder for the generated masks; the segmentation-metric cell reads the same folder.
save_path = './masks/cevae_masks/'

In [20]:
# Build a binary mask from the reconstruction residual of every input and save
# it under the input's file name.
MASK_THRESHOLD = 0.45  # cut-off applied to the normalised residual

# This section can be executed without the scoring cells, so switch to eval
# mode here instead of relying on an earlier cell having done it.
model.eval()
with torch.no_grad():
    for batch, filenames in dataloader:
        batch = batch.float().to(device)
        output, mu, std = model(batch)
        residuals = (output - batch).detach().cpu().numpy()

        # Process every sample of the batch. The previous version squeezed
        # axis 0 and used filenames[0], which only works for batch size 1,
        # although the loader is built with the configured batch size.
        for residual, name in zip(residuals, filenames):
            mask = normalise_mask(residual[:, :, :]) < MASK_THRESHOLD
            save_masks(mask, name, save_path)
        

## Segmentation metric

In [21]:
# Import glob explicitly: this notebook never imports it by name, so the cell
# previously depended on it leaking in through a wildcard import.
from glob import glob

gt_path = './masks/segmentations/'
mask_path = './masks/cevae_masks/'

# Sorted listings of the ground-truth and predicted mask files.
# recursive=True has no effect here because the patterns contain no '**'.
gt_files = sorted(glob(gt_path+'*.npy',recursive=True))
mask_files = sorted(glob(mask_path+'*.npy',recursive=True))
len(gt_files),len(mask_files)

(3505, 3505)

In [23]:
# Dice between each predicted mask and the ground-truth file at the same
# position of the two sorted listings.
# NOTE(review): pairing is purely positional — presumably both folders use
# matching file names so that sorting aligns them; verify.
if len(gt_files) != len(mask_files):
    # Without this check a longer ground-truth list is silently truncated and
    # a shorter one fails mid-loop with an IndexError.
    raise ValueError(
        f'Cannot pair files: {len(gt_files)} ground-truth files vs '
        f'{len(mask_files)} predicted masks'
    )

dice_scores = []
file_pairs = zip(mask_files, gt_files)
for mask_file_path, gt_file_path in tqdm(file_pairs, total=len(mask_files)):
    predicted = np.load(mask_file_path)
    target = np.load(gt_file_path)
    dice_scores.append(metrics.dice(target, predicted))

100%|██████████| 3505/3505 [00:21<00:00, 164.11it/s]


In [24]:
# Mean Dice over all evaluated mask / ground-truth pairs.
avg_dice = np.mean(dice_scores)
avg_dice

0.11842989473200181