In [12]:
import torch
import os
import time

from torchmetrics import Dice
from torchmetrics.aggregation import MeanMetric
import segmentation_models_pytorch as smp

from Benchmark import dataset
from methods import unet, unetplusplus

from prettytable import PrettyTable
from colorama import Fore, Style, init
from tqdm import tqdm

Arguments

In [6]:
# Device the model is moved to when it is built in the "Load Model" cell.
# The evaluation call chooses its own device separately.
device = 'cpu'
# Checkpoint file read with torch.load; its 'state_dict' entry is loaded into the model.
load_path = './saved_model/pretrain-unet-efficientnet.pth'


def segment(image, model):
  """Run `model` on `image` under inference mode and return sigmoid-activated outputs.

  Args:
    image: input forwarded unchanged to `model`.
    model: callable whose output is treated as raw scores (logits).

  Returns:
    Element-wise sigmoid of the model output, i.e. values in [0, 1].
  """
  with torch.inference_mode():
    logits = model(image)
    probabilities = logits.sigmoid()
  return probabilities

Load the test data loader

In [4]:
# Batch size handed to the loader factory below.
test_batch_size = 64

# Build the BraTS20 'test' split rooted at ./Benchmark, then call the dataset
# object to obtain the loader that the evaluation loop iterates over.
# NOTE(review): `mini` / `memory` semantics live in Benchmark.dataset — presumably
# "use a reduced subset" and "preload into RAM"; confirm there.
brats_test = dataset.BraTS20(
    root='./Benchmark',
    mode='test',
    mini=False,
    memory=False,
)
test_loader = brats_test(batch_size=test_batch_size)

Load Model

In [7]:
# Architecture under evaluation. The commented-out lines are alternative models;
# to evaluate one, activate its line and point `load_path` at a matching checkpoint.
# model = unet.UNet(n_channels=3, n_classes=3, bilinear=False).to(device)
model = unet.pre_train_unet(in_channels=4, classes=4, encoder_name='efficientnet-b1').to(device)
# model = unetplusplus.UnetPlusPlus(encoder_name='efficientnet-b3').to(device)
# model = unetplusplus.UnetPlusPlus(encoder_name='resnet18').to(device)


# map_location remaps the stored tensors onto `device`, so a checkpoint written
# during a GPU run still loads on a CPU-only machine and lands on the same device
# as the model.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
# The variable name `sate` (sic) is kept unchanged in case other cells read it.
sate = torch.load(load_path, map_location=device)
model.load_state_dict(sate['state_dict'])

<All keys matched successfully>

In [10]:
def evaluate(model, test_loader, device='cpu'):
    """Evaluate a segmentation model over every batch of `test_loader`.

    The model is switched to eval mode and moved to `device`. For each batch the
    sigmoid probabilities from `segment` are thresholded at 0.5 (multilabel mode)
    to obtain tp/fp/fn/tn statistics, from which per-batch scores are computed and
    averaged across batches, weighted by batch size. Dice is accumulated separately
    through a torchmetrics `Dice(average='micro')` metric.

    Args:
        model: network to evaluate; moved to `device` in place.
        test_loader: iterable yielding `(inputs, targets)` batches.
        device: device on which the forward pass and the Dice metric run.

    Returns:
        Tuple of Python floats:
        `(iou, f1, f2, accuracy, recall, dice)`.
    """
    model.eval().to(device)

    # Running, batch-size-weighted means of the per-batch smp scores.
    iou_score, f1_score, f2_score, accuracy, recall = (MeanMetric() for _ in range(5))
    dice_metric = Dice(average='micro').to(device)

    with torch.inference_mode():
        # tqdm gives a transient progress bar (leave=False clears it when done).
        for inputs, targets in tqdm(test_loader, desc="Evaluating", leave=False):
            inputs = inputs.to(device)
            # Targets are cast to an integer dtype before being compared with predictions.
            targets = targets.to(device).to(torch.int32)

            # Sigmoid probabilities from the model.
            outputs = segment(inputs, model)
            batch_size = len(targets)

            # Confusion statistics are computed on CPU copies.
            tp, fp, fn, tn = smp.metrics.get_stats(
                outputs.cpu(), targets.cpu(), mode='multilabel', threshold=0.5
            )

            # NOTE(review): accuracy uses "macro" and recall "micro-imagewise" while the
            # other scores use "micro" — confirm the mixed reductions are intentional.
            iou_score.update(smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro"), weight=batch_size)
            f1_score.update(smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro"), weight=batch_size)
            f2_score.update(smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro"), weight=batch_size)
            accuracy.update(smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro"), weight=batch_size)
            recall.update(smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise"), weight=batch_size)

            # update() only accumulates state; calling the metric (forward) would also
            # compute a per-batch Dice value that is never used here.
            dice_metric.update(outputs, targets)

    # Collapse every accumulator to a plain Python float.
    return (iou_score.compute().item(),
            f1_score.compute().item(),
            f2_score.compute().item(),
            accuracy.compute().item(),
            recall.compute().item(),
            dice_metric.compute().item())

In [11]:
# Use the GPU when one is present; otherwise fall back to the configured `device`
# instead of failing on a hard-coded 'cuda'.
eval_device = 'cuda' if torch.cuda.is_available() else device
iou_score, f1_score, f2_score, accuracy, recall, dice = evaluate(model, test_loader, device=eval_device)

table = PrettyTable()

# Two columns: metric name left-aligned, value right-aligned.
table.field_names = ["Metric", "Value"]
table.align["Metric"] = "l"
table.align["Value"] = "r"

# Each score is rendered as a percentage with two decimals.
# Accuracy is returned by evaluate() but currently excluded from the table.
report_rows = [
    ("IoU Score", iou_score),
    ("F1 Score", f1_score),
    ("F2 Score", f2_score),
    # ("Accuracy", accuracy),
    ("Recall", recall),
    ("Dice", dice),
]
for label, value in report_rows:
    table.add_row([label, f"{value:.2%}"])

print(table)

                                                             

+-----------+--------+
| Metric    |  Value |
+-----------+--------+
| IoU Score | 87.28% |
| F1 Score  | 93.21% |
| F2 Score  | 89.73% |
| Recall    | 87.55% |
| Dice      | 93.21% |
+-----------+--------+


