In [1]:
import os
import sys

sys.path.append("../../../")

import yaml
import torch
import numpy as np
from typing import Dict
from architectures.build_architecture import build_architecture
from dataloaders.build_dataset import build_dataset
from typing import Tuple, Dict
from fvcore.nn import FlopCountAnalysis
from tqdm.notebook import tqdm
from sklearn.metrics import (
    jaccard_score,
    accuracy_score,
    confusion_matrix,
)
import monai

Load Config File

In [2]:
def load_config(config_path: str) -> Dict:
    """loads the yaml config file

    Args:
        config_path (str): _description_

    Returns:
        Dict: _description_
    """
    with open(config_path, "r") as file:
        config = yaml.safe_load(file)
    return config


config = load_config("config.yaml")

Build Dataset and DataLoaders

In [3]:
# build validation dataset & validataion data loader
testset = build_dataset(
    dataset_type=config["dataset_parameters"]["dataset_type"],
    dataset_args=config["dataset_parameters"]["val_dataset_args"],
    augmentation_args=config["test_augmentation_args"],
)

testloader = torch.utils.data.DataLoader(
    testset, batch_size=1, shuffle=False, num_workers=1
)

print(len(testset))

40


In [4]:
model = build_architecture(config=config)
checkpoint = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(checkpoint)
model = model.to("cpu")
model = model.eval()

Model Complexity

In [5]:
import torchvision.models as models
import torch
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    net = model
    macs, params = get_model_complexity_info(
        net, (3, 256, 256), as_strings=True, print_per_layer_stat=False, verbose=False
    )
    print("{:<30}  {:<8}".format("Computational complexity: ", macs))
    print("{:<30}  {:<8}".format("Number of parameters: ", params))

Computational complexity:       1.3 GMac
Number of parameters:           3.01 M  


In [6]:
def flop_count_analysis(
    model: torch.nn.Module,
    input_dim: Tuple,
) -> Dict:
    """_summary_

    Args:
        input_dim (Tuple): shape: (batchsize=1, C, H, W, D(optional))
        model (torch.nn.Module): _description_
    """
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    input_tensor = torch.ones(()).new_empty(
        (1, *input_dim),
        dtype=next(model.parameters()).dtype,
        device=next(model.parameters()).device,
    )
    flops = FlopCountAnalysis(model, input_tensor)
    model_flops = flops.total()
    # print(f"Total trainable parameters: {round(trainable_params * 1e-6, 2)} M")
    # print(f"MAdds: {round(model_flops * 1e-9, 2)} G")

    out = {
        "params": round(trainable_params * 1e-6, 2),
        "flops": round(model_flops * 1e-9, 2),
    }

    return out


inference_result = flop_count_analysis(model, (3, 256, 256))
print("{:<30}  {:<8}".format("Computational complexity: ", inference_result["params"]))
print("{:<30}  {:<8}".format("Number of parameters: ", inference_result["flops"]))

Unsupported operator aten::silu encountered 84 time(s)
Unsupported operator aten::add encountered 45 time(s)
Unsupported operator aten::div encountered 15 time(s)
Unsupported operator aten::ceil encountered 6 time(s)
Unsupported operator aten::mul encountered 118 time(s)
Unsupported operator aten::softmax encountered 21 time(s)
Unsupported operator aten::clone encountered 4 time(s)
Unsupported operator aten::mul_ encountered 24 time(s)
Unsupported operator aten::upsample_bicubic2d encountered 2 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
encoder.encoder.conv_1x1_exp, encoder.encoder.conv_1x1_exp.activation, encoder.encoder.conv_1x1_exp.convolution, encoder.encoder.conv_1x1_exp.normalization


Computational complexity:       3.01    
Number of parameters:           1.24    


Calculate IoU Metric

In [7]:
iou = []
with torch.no_grad():
    for idx, data in tqdm(enumerate(testloader)):
        image = data["image"].cuda()
        mask = data["mask"].cuda()
        out = model.forward(image)
        out = torch.sigmoid(out)
        out[out < 0.5] = 0
        out[out >= 0.5] = 1
        mean_iou = jaccard_score(
            mask.detach().cpu().numpy().ravel(),
            out.detach().cpu().numpy().ravel(),
            average="binary",
            pos_label=1,
        )
        iou.append(mean_iou.item())

print(f"test iou: {np.mean(iou)}")

0it [00:00, ?it/s]

test iou: 0.9229712867991552


Accuracy

In [8]:
accuracy = []
with torch.no_grad():
    for idx, data in tqdm(enumerate(testloader)):
        image = data["image"].cuda()
        mask = data["mask"].cuda()
        out = model.forward(image)
        out = torch.sigmoid(out)
        out[out < 0.5] = 0
        out[out >= 0.5] = 1
        acc = accuracy_score(
            mask.detach().cpu().numpy().ravel(),
            out.detach().cpu().numpy().ravel(),
        )
        accuracy.append(acc.item())

print(f"test accuracy: {np.mean(accuracy)}")

0it [00:00, ?it/s]

test accuracy: 0.9771324157714844


Calculate Dice

In [9]:
dice = []
with torch.no_grad():
    for idx, data in tqdm(enumerate(testloader)):
        image = data["image"].cuda()
        mask = data["mask"].cuda()
        out = model.forward(image)
        out = torch.sigmoid(out)
        out[out < 0.5] = 0
        out[out >= 0.5] = 1
        mean_dice = monai.metrics.compute_dice(out, mask.unsqueeze(1))
        dice.append(mean_dice.item())

print(f"test dice: {np.mean(dice)}")

0it [00:00, ?it/s]

test dice: 0.9569526433944702


Calculate Specificity

In [10]:
specificity = []
with torch.no_grad():
    for idx, data in tqdm(enumerate(testloader)):
        image = data["image"].cuda()
        mask = data["mask"].cuda()
        out = model.forward(image)
        out = torch.sigmoid(out)
        out[out < 0.5] = 0
        out[out >= 0.5] = 1
        confusion = confusion_matrix(
            mask.detach().cpu().numpy().ravel(),
            out.detach().cpu().numpy().ravel(),
        )
        if float(confusion[0, 0] + confusion[0, 1]) != 0:
            sp = float(confusion[0, 0]) / float(confusion[0, 0] + confusion[0, 1])

        specificity.append(sp)

print(f"test specificity: {np.mean(specificity)}")

0it [00:00, ?it/s]

test specificity: 0.9660395899750096


Calculate Sensitivity

In [11]:
sensitivity = []
with torch.no_grad():
    for idx, data in tqdm(enumerate(testloader)):
        image = data["image"].cuda()
        mask = data["mask"].cuda()
        out = model.forward(image)
        out = torch.sigmoid(out)
        out[out < 0.5] = 0
        out[out >= 0.5] = 1
        confusion = confusion_matrix(
            mask.detach().cpu().numpy().ravel(),
            out.detach().cpu().numpy().ravel(),
        )
        if float(confusion[1, 1] + confusion[1, 0]) != 0:
            se = float(confusion[1, 1]) / float(confusion[1, 1] + confusion[1, 0])

        sensitivity.append(se)

print(f"test sensitivity: {np.mean(sensitivity)}")

0it [00:00, ?it/s]

test sensitivity: 0.9604657601219436


In [12]:
# DONE