In [None]:
# Import all dependencies and required functions

import numpy as np
from monai.data import Dataset, DataLoader
from monai.transforms import (
    Compose,
    ScaleIntensityd,
    AddChanneld,
    ToTensord
)
import pandas as pd
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
import torch
from torch.nn import L1Loss, MSELoss

def _dice_to_numpy(x):
    """Convert a torch tensor (on any device) or array-like into a numpy array."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# Function for calculating per-class Dice scores, averaged over a batch.
def get_dice_score(pred, target, num_class):
    """Compute the mean per-class Dice coefficient over a batch of label maps.

    Args:
        pred: Predicted integer label maps, indexed by sample along the first
            axis (e.g. shape (batch, H, W)). A torch tensor or array-like.
        target: Ground-truth label maps, indexed the same way as ``pred``.
        num_class: Number of classes; labels ``0 .. num_class - 1`` are scored.

    Returns:
        np.ndarray of shape (num_class,): for each class, the Dice score
        averaged over the samples in the batch. Because of the epsilon
        smoothing, a class absent from both prediction and target in a sample
        scores 1 for that sample. An empty batch yields NaNs (0 / 0).
    """
    epsilon = 1e-5
    dice_scores = np.zeros(num_class)
    batch_size = pred.shape[0]
    for s in range(batch_size):
        mask_p = _dice_to_numpy(pred[s])
        mask_t = _dice_to_numpy(target[s])
        for i in range(num_class):
            # Compare against the scalar class id so label maps of any
            # spatial size are supported (no fixed-size reference array).
            mask_p_i = mask_p == i
            mask_t_i = mask_t == i
            intersection = np.count_nonzero(mask_p_i & mask_t_i)
            denominator = np.count_nonzero(mask_p_i) + np.count_nonzero(mask_t_i)
            dice_scores[i] += (2 * intersection + epsilon) / (denominator + epsilon)
    dice_scores /= batch_size
    return dice_scores

In [None]:
# Load the trained R50-ViT-B_16 segmentation model from its checkpoint.

NUM_CLASSES = 5  # number of output classes passed to the config and the network
IMG_SIZE = 256   # input image size passed to the network
CHECKPOINT_PATH = "./trained_models/model_250epochs_0.0033base_lr/epoch_250.pth"

config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = NUM_CLASSES
config_vit.n_skip = 3

model = ViT_seg(config_vit, img_size=IMG_SIZE, num_classes=NUM_CLASSES)
model = model.double()
# The model is still on the CPU here, so map the checkpoint tensors there too;
# this also lets a GPU-saved checkpoint open on a machine without that GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

parameters_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('model created! total number of parameters: {}'.format(parameters_num))

In [None]:
# Load the test split and wrap it in a MONAI Dataset / DataLoader.
from utils import test_loader

test_path = './data/test.pkl'
# NOTE(review): a .pkl file is presumably unpickled inside test_loader — only load trusted data.
test_dict = test_loader(test_path)

# Per-sample preprocessing: add a channel axis to 'mri' and 'maj', scale 'mri'
# intensities channel-wise, then convert the listed keys to tensors.
# NOTE(review): AddChanneld is deprecated in newer MONAI releases — check the installed version.
test_steps = [
    AddChanneld(keys=['mri', 'maj']),
    ScaleIntensityd(keys='mri', channel_wise=True),
    ToTensord(keys=['mri', 'maj', 'var', 'level']),
]
test_transforms = Compose(test_steps, log_stats=True)

test_set = Dataset(data=test_dict, transform=test_transforms)
testloader = DataLoader(test_set, batch_size=1, pin_memory=True)

In [None]:
# Evaluate on the test set: per-class Dice for the segmentation output and
# per-channel L1 error for the "var" regression output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dice_scores = []  # one array of per-class Dice scores per test batch
l1_losses = []    # one [L1_ch0, L1_ch1, L1_ch2, L1_ch3] list per test batch
loss_fn = L1Loss()
model.eval()
model.to(device)

# Inference only: no_grad skips building the autograd graph, saving memory and time.
with torch.no_grad():
    for data_dict in testloader:
        data = data_dict["mri"].double().to(device)
        mask_majority = data_dict["maj"].long().to(device)
        var_gt = data_dict["var"].double().to(device)

        mask_pred, var_pred = model(data)

        # L1 error for each of the first four "var" channels, over the whole batch
        # (not just the first sample).
        l1_losses.append(
            [loss_fn(var_pred[:, c], var_gt[:, c]).item() for c in range(4)]
        )

        # Softmax is monotonic, so the argmax over the class axis (dim 1) of the
        # raw outputs is already the predicted label map.
        pred_labels = mask_pred.argmax(dim=1)
        dice = get_dice_score(
            pred_labels, torch.squeeze(mask_majority, 1), mask_pred.shape[1]
        )
        dice_scores.append(dice)

# Per-class Dice in percent, one list per class (dice1 is class 0).
dice1 = [scores[0] * 100 for scores in dice_scores]
dice2 = [scores[1] * 100 for scores in dice_scores]
dice3 = [scores[2] * 100 for scores in dice_scores]
dice4 = [scores[3] * 100 for scores in dice_scores]
dice5 = [scores[4] * 100 for scores in dice_scores]

# Per-channel L1 losses, one list per "var" channel.
loss1 = [losses[0] for losses in l1_losses]
loss2 = [losses[1] for losses in l1_losses]
loss3 = [losses[2] for losses in l1_losses]
loss4 = [losses[3] for losses in l1_losses]

# Mean Dice for classes 1-4; class 0 (dice1) is not reported — presumably background, confirm.
for class_dice in (dice2, dice3, dice4, dice5):
    print(np.mean(class_dice))

print(" ")
for channel_loss in (loss1, loss2, loss3, loss4):
    print(np.mean(channel_loss))