In [None]:
import torch.nn as nn
from segmentation_models_pytorch import UnetPlusPlus, Segformer, DeepLabV3, DeepLabV3Plus, Unet
model = DeepLabV3Plus(encoder_name="resnet34", in_channels=1, classes=1)
# Swap element 2 of the segmentation head for a sigmoid (presumably the head's
# output activation in smp's SegmentationHead — confirm for the installed version).
model.segmentation_head[2] = nn.Sigmoid()
# Report the parameter count in millions, rounded to two decimals.
total_params = sum(weight.numel() for weight in model.parameters())
total_params = round(total_params / 1e6, 2)
print(f"Total Parameters: {total_params} M\n")

In [1]:
from Scripts.model import TEUnet, AAUnet, AttUnet, Unet, TEUnet2
# Positional arguments (1, 1, 32) — presumably in/out channels and base width;
# confirm against Scripts.model.TEUnet2.
model = TEUnet2(1, 1, 32)
# Total trainable + non-trainable parameter count, reported in millions.
total_params = sum(tensor.numel() for tensor in model.parameters())
total_params = round(total_params / 1e6, 2)
print(f"Total Parameters: {total_params} M\n")



Total Parameters: 25.7 M



In [2]:
import torch
# Forward-pass smoke test on a random batch of shape (2, 1, 512, 512).
# A locally seeded generator makes the input reproducible without touching the
# global RNG, and no_grad avoids building/retaining an autograd graph for this
# inference-only check.
images = torch.rand(size=(2, 1, 512, 512), generator=torch.Generator().manual_seed(0))
with torch.no_grad():
    out = model(images)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Evaluation configuration — edit these values to choose what is evaluated.
# Use the second GPU when it exists (the original setting); otherwise fall back
# to the first GPU, then the CPU, instead of failing with an invalid device ordinal.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
database_name = "BUSI"  # folder under Database/ and part of the checkpoint filename
model_name = "TEUnet"   # one of: Unet, AttUnet, AAUnet, TEUnet, Unet++, Segformer, DeepLabV3+, DeepLabV3
seed = 0                # selects the checkpoint file {model_name}_{database_name}_{seed}.pth
in_channels = 1         # channels of the input images
out_channels = 1        # channels of the predicted mask
hidden_channels = 64    # width argument for the custom networks (halved for AAUnet/TEUnet)
p = 0.0                 # presumably the dropout probability of the custom networks — confirm

def calc_dice_score(predicted, target, smooth=1e-5):
    """Dice coefficient of two masks, pooled over every element of the inputs.

    Dice = (2 * sum(predicted * target) + smooth) / (sum(predicted) + sum(target) + smooth)

    Args:
        predicted: predicted mask; in this notebook a boolean tensor from ``outputs > 0.5``.
        target: ground-truth mask, broadcast-compatible with ``predicted``.
        smooth: constant added to numerator and denominator to avoid 0/0.

    Returns:
        The Dice score as a Python float (1.0 when both masks are empty).
    """
    overlap = (predicted * target).sum()
    total = predicted.sum() + target.sum()
    return ((2. * overlap + smooth) / (total + smooth)).item()

def calc_iou_score(predicted, target):
    """Intersection-over-Union (Jaccard index) of two masks over the whole tensor.

    Elements count as positive when non-zero (``logical_and`` / ``logical_or``).

    Args:
        predicted: predicted mask; in this notebook a boolean tensor from ``outputs > 0.5``.
        target: ground-truth mask, broadcast-compatible with ``predicted``.

    Returns:
        IoU as a Python float; 1.0 when both inputs sum to zero, 0.0 if the
        union is empty.
    """
    # Convention: two empty masks are a perfect match.
    if predicted.sum() == 0 and target.sum() == 0:
        return 1.0
    overlap = predicted.logical_and(target).sum()
    combined = predicted.logical_or(target).sum()
    if combined == 0:
        return 0.0
    return (overlap / combined).item()

def calc_recall_score(predicted, target, smooth=1e-5):
    """Recall (sensitivity): TP / (TP + FN + smooth), pooled over all elements.

    Args:
        predicted: predicted mask; in this notebook a boolean tensor from
            ``outputs > 0.5``. Integer 0/1 masks are also handled correctly.
        target: ground-truth mask (boolean or numeric), broadcast-compatible
            with ``predicted``.
        smooth: constant added to the denominator to avoid division by zero
            (default matches the previously hard-coded 1e-5).

    Returns:
        Recall as a Python float (0.0 when the target has no positives).
    """
    true_positive = torch.sum(predicted * target)
    # logical_not is a true boolean negation for every dtype; ``~`` on an
    # integer mask flips bits (e.g. uint8 1 -> 254) and corrupts the FN count.
    false_negative = torch.sum(target * torch.logical_not(predicted))
    recall = true_positive / (true_positive + false_negative + smooth)
    return recall.item()


def calc_precision_score(predicted, target, smooth=1e-5):
    """Precision: TP / (TP + FP + smooth), pooled over all elements.

    Args:
        predicted: predicted mask; in this notebook a boolean tensor from ``outputs > 0.5``.
        target: ground-truth mask, broadcast-compatible with ``predicted``;
            may be boolean or numeric.
        smooth: constant added to the denominator to avoid division by zero.

    Returns:
        Precision as a Python float (0.0 when nothing is predicted positive).
    """
    # ``1 - target`` raises for boolean tensors, so negate those logically;
    # numeric targets keep the original ``1 - target`` complement.
    if target.dtype == torch.bool:
        target_negative = torch.logical_not(target)
    else:
        target_negative = 1 - target
    true_positive = torch.sum(predicted * target)
    false_positive = torch.sum(predicted * target_negative)
    precision = true_positive / (true_positive + false_positive + smooth)
    return precision.item()

def calc_specificity_score(predicted, target, smooth=1e-5):
    """Specificity: TN / (TN + FP + smooth), pooled over all elements.

    Args:
        predicted: predicted mask; in this notebook a boolean tensor from
            ``outputs > 0.5``. Integer 0/1 masks are also handled correctly.
        target: ground-truth mask, broadcast-compatible with ``predicted``;
            may be boolean or numeric.
        smooth: constant added to the denominator to avoid division by zero.

    Returns:
        Specificity as a Python float (0.0 when the target has no negatives).
    """
    # logical_not is a true boolean negation for every dtype; ``~`` on an
    # integer mask flips bits (e.g. uint8 1 -> 254) and corrupts the TN count.
    predicted_negative = torch.logical_not(predicted)
    # ``1 - target`` raises for boolean tensors, so negate those logically.
    if target.dtype == torch.bool:
        target_negative = torch.logical_not(target)
    else:
        target_negative = 1 - target
    true_negative = torch.sum(predicted_negative * target_negative)
    false_positive = torch.sum(predicted * target_negative)
    specificity = true_negative / (true_negative + false_positive + smooth)
    return specificity.item()

from Scripts.model import TEUnet, AAUnet, AttUnet, Unet
from segmentation_models_pytorch import UnetPlusPlus, Segformer, DeepLabV3
from Scripts.utils.data import Database2Dataloader

# ---- Build the model selected by ``model_name`` -----------------------------
# DeepLabV3Plus is not among the imports above, so bring it in here.
from segmentation_models_pytorch import DeepLabV3Plus

# Each entry builds one architecture on demand, so only the chosen model is
# instantiated. AAUnet and TEUnet receive half of ``hidden_channels``.
model_builders = {
    "Unet": lambda: Unet(in_channels, out_channels, hidden_channels, p),
    "AttUnet": lambda: AttUnet(in_channels, out_channels, hidden_channels, p),
    "AAUnet": lambda: AAUnet(in_channels, out_channels, hidden_channels // 2, p),
    "TEUnet": lambda: TEUnet(in_channels, out_channels, hidden_channels // 2, p),
    "Unet++": lambda: UnetPlusPlus(encoder_name="resnet34", in_channels=in_channels, classes=out_channels),
    "Segformer": lambda: Segformer(encoder_name="resnet34", in_channels=in_channels, classes=out_channels),
    "DeepLabV3+": lambda: DeepLabV3Plus(encoder_name="resnet34", in_channels=in_channels, classes=out_channels),
    "DeepLabV3": lambda: DeepLabV3(encoder_name="resnet34", in_channels=in_channels, classes=out_channels),
}
if model_name not in model_builders:
    # Fail loudly: otherwise ``model`` could silently be a stale object left in
    # the kernel by an earlier cell.
    raise ValueError(f"Unknown model_name {model_name!r}; expected one of {list(model_builders)}")
model = model_builders[model_name]()

# The segmentation_models_pytorch models are constructed without an output
# activation, so a sigmoid is applied to their logits; the Scripts.model
# networks are assumed to output probabilities already (TODO confirm).
logit_models = {"Unet++", "Segformer", "DeepLabV3+", "DeepLabV3"}

checkpoint_path = f"./Checkpoints/{model_name}_{database_name}_{seed}.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()

# Pass the configured ``seed`` (not a literal 0) so the dataloader uses the same
# seed as the checkpoint filename; presumably this seed determines the
# train/test split — confirm in Scripts.utils.data.
dataloader = Database2Dataloader(database_path=f"Database/{database_name}", seed=seed)

# ---- Evaluate on the test split ---------------------------------------------
threshold = 0.5  # probability cut-off for a positive pixel
iou_score = 0.0
dice_score = 0.0
recall_score = 0.0
precision_score = 0.0
specificity_score = 0.0
n = 0  # number of test batches seen
maxpooler = nn.MaxPool2d(16,16)  # NOTE(review): unused in this cell; kept in case later cells use it

with torch.no_grad():
    for images, masks in dataloader["test"]:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        if model_name == "TEUnet":
            # TEUnet returns several outputs; element [1] is the one evaluated.
            outputs = outputs[1]
        if model_name in logit_models:
            outputs = torch.sigmoid(outputs)
        predicted = outputs > threshold  # binarize once, reuse for every metric
        dice_score += calc_dice_score(predicted, masks)
        iou_score += calc_iou_score(predicted, masks)
        recall_score += calc_recall_score(predicted, masks)
        precision_score += calc_precision_score(predicted, masks)
        specificity_score += calc_specificity_score(predicted, masks)
        n += 1

if n == 0:
    raise RuntimeError("The test dataloader yielded no batches; cannot average metrics.")

# Scores are averaged over batches; each batch score pools all pixels of that
# batch, so a smaller final batch carries the same weight as a full one.
print(f"Test Dice Score: {dice_score/n:.4f}")
print(f"Test IoU Score: {iou_score/n:.4f}")
print(f"Test Recall Score: {recall_score/n:.4f}")
print(f"Test precision Score: {precision_score/n:.4f}")
print(f"Test specificity Score: {specificity_score/n:.4f}")