In [1]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

from unet import UNET
from unetpp import UNETPP
from att_unet import AttentionUNet

from unet_gtam import UNET_GTAM
from unetpp_gtam import  UNETPP_GTAM
from att_unet_gtam import AttentionUNet_GTAM
from fcn_resnet50_gtam import FCN_ResNet50_GTAM
from DeepLabV3_gtam import DeepLabV3_ResNet50_GTAM

from dataset_BUSI import BUSIDataset
from metrics import compute_metrics

from torchvision.models.segmentation import fcn_resnet50
from torchvision.models.segmentation import deeplabv3_resnet50

# ---------------------------
# Device and config
# ---------------------------
# Pick the compute backend: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Number of output channels every model below is built with.
NUM_CLASSES = 2

# Locations of the held-out BUSI test images and their masks.
TEST_IMG_DIR = "BUSI/Data/test/images"
TEST_MASK_DIR = "BUSI/Data/test/masks"


# ---------------------------
# Transform
# ---------------------------
# Deterministic test-time pipeline: resize, normalize, convert to a tensor.
# NOTE(review): with mean 0, std 1 and max_pixel_value 255 the Normalize step
# should amount to a plain division by 255 -- confirm against albumentations.
test_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# The names say "val", but the data comes from the test directories above.
val_ds = BUSIDataset(
    TEST_IMG_DIR,
    TEST_MASK_DIR,
    transform=test_transform,
)
# One sample per batch and no shuffling, so metrics are computed per image.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)


# ---------------------------
# Evaluation functions
# ---------------------------

# Metric names, in reporting order, that the metrics function must return.
_METRIC_NAMES = ("accuracy", "precision", "recall", "specificity", "dice", "iou")


def _run_evaluation(model, model_name, *, expand_gray_to_rgb,
                    loader=None, device=None, metrics_fn=None):
    """Shared evaluation loop behind ``evaluate`` and ``evaluate_resnet``.

    Puts ``model`` in eval mode, runs it over every batch of ``loader`` with
    gradients disabled, thresholds the softmax probability of class index 1
    at 0.5, scores each batch with ``metrics_fn`` and averages every metric
    over the batches. The averaged scores are printed and returned.

    Args:
        model: Module to evaluate. Its output is either a logits tensor or a
            dict holding the logits under ``"out"``.
        model_name: Label used in the printed results header.
        expand_gray_to_rgb: When true, batches with a single channel are
            repeated to three channels before the forward pass.
        loader: Iterable of ``(images, masks)`` batches. Defaults to the
            module-level ``val_loader``.
        device: Device the batches are moved to. Defaults to the module-level
            ``DEVICE``.
        metrics_fn: Callable ``(preds, masks) -> dict`` keyed by the names in
            ``_METRIC_NAMES``. Defaults to the module-level
            ``compute_metrics``.

    Returns:
        dict mapping each metric name to its mean over all batches, as float.

    Raises:
        KeyError: If ``metrics_fn`` returns a metric name that is not listed
            in ``_METRIC_NAMES``.
    """
    # Resolve the defaults at call time so that the notebook-level globals are
    # only required when the caller does not supply its own objects.
    if loader is None:
        loader = val_loader
    if device is None:
        device = DEVICE
    if metrics_fn is None:
        metrics_fn = compute_metrics

    model.eval()
    all_metrics = {name: [] for name in _METRIC_NAMES}

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)

            # Convert grayscale -> RGB for backbones that expect 3 channels.
            if expand_gray_to_rgb and images.shape[1] == 1:
                images = images.repeat(1, 3, 1, 1)

            masks = masks.long().to(device)

            outputs = model(images)
            # Some models wrap their logits in a dict under "out".
            logits = outputs["out"] if isinstance(outputs, dict) else outputs

            # Probability of class index 1, binarized at 0.5.
            probs = torch.softmax(logits, dim=1)[:, 1]
            preds = (probs > 0.5).float()

            batch_scores = metrics_fn(preds.cpu(), masks.cpu())
            for name, value in batch_scores.items():
                all_metrics[name].append(value)

    final_scores = {
        name: float(np.mean(values)) for name, values in all_metrics.items()
    }

    print(f"\n=== Results for {model_name} ===")
    for name, score in final_scores.items():
        print(f"{name:12s}: {score:.4f}")

    return final_scores


def evaluate(model, model_name, *, loader=None, device=None, metrics_fn=None):
    """For UNet / UNet++ / Attention / GTAM models.

    Feeds the batches to ``model`` unchanged. See ``_run_evaluation`` for the
    scoring procedure and for the optional ``loader`` / ``device`` /
    ``metrics_fn`` overrides; without them the module-level ``val_loader``,
    ``DEVICE`` and ``compute_metrics`` are used.

    Returns:
        dict mapping each metric name to its mean over all batches.
    """
    return _run_evaluation(
        model,
        model_name,
        expand_gray_to_rgb=False,
        loader=loader,
        device=device,
        metrics_fn=metrics_fn,
    )


def evaluate_resnet(model, model_name, *, loader=None, device=None,
                    metrics_fn=None):
    """For FCN-ResNet50 / DeepLabV3 models.

    Same as ``evaluate``, except that single-channel batches are repeated to
    three channels before the forward pass. A model output that is a dict is
    read through its ``"out"`` entry; a bare logits tensor is accepted too.

    Returns:
        dict mapping each metric name to its mean over all batches.
    """
    return _run_evaluation(
        model,
        model_name,
        expand_gray_to_rgb=True,
        loader=loader,
        device=device,
        metrics_fn=metrics_fn,
    )


# Every section below follows the same recipe: build the architecture, load
# the "state_dict" entry of its best checkpoint, then score it on the test
# loader. Each checkpoint file must therefore be a mapping that contains a
# "state_dict" key.

# ---------------------------
# Evaluate U-Net
# ---------------------------
unet = UNET(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_unet = torch.load("BUSI/Best_BUSI/UNet_best.pth", map_location=DEVICE)
unet.load_state_dict(ckpt_unet["state_dict"])
scores_unet = evaluate(unet, "U-Net")

# ---------------------------
# Evaluate U-Net++
# ---------------------------
unetpp = UNETPP(in_channels=1, out_channels=NUM_CLASSES, deep_supervision=False).to(DEVICE)
ckpt_unetpp = torch.load("BUSI/Best_BUSI/unet_PP_best.pth", map_location=DEVICE)
unetpp.load_state_dict(ckpt_unetpp["state_dict"])
scores_unetpp = evaluate(unetpp, "U-Net++")

# ---------------------------
# Evaluate Attention U-Net
# ---------------------------
att_unet = AttentionUNet(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_att = torch.load("BUSI/Best_BUSI/AttentionUNet_best.pth", map_location=DEVICE)
att_unet.load_state_dict(ckpt_att["state_dict"])
scores_att = evaluate(att_unet, "AttentionUNet")

# ---------------------------
# Evaluate FCN-ResNet50
# ---------------------------
# Built without pretrained weights; the last classifier layer is swapped for a
# 1x1 convolution with NUM_CLASSES outputs before the checkpoint is loaded.
fcn = fcn_resnet50(weights=None, weights_backbone=None)
fcn.classifier[4] = torch.nn.Conv2d(512, NUM_CLASSES, kernel_size=1)
fcn = fcn.to(DEVICE)
ckpt_fcn = torch.load("BUSI/Best_BUSI/FCN_ResNet50_best.pth", map_location=DEVICE)
fcn.load_state_dict(ckpt_fcn["state_dict"])
scores_fcn = evaluate_resnet(fcn, "FCN_ResNet50")


# ---------------------------
# Evaluate DeepLabV3-ResNet50
# ---------------------------
# Same head replacement as for the FCN, here on the last classifier layer.
deeplab = deeplabv3_resnet50(weights=None, weights_backbone=None, aux_loss=False)
deeplab.classifier[-1] = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
deeplab = deeplab.to(DEVICE)
ckpt_deeplab = torch.load("BUSI/Best_BUSI/DeepLabV3_best.pth", map_location=DEVICE)
deeplab.load_state_dict(ckpt_deeplab["state_dict"])
scores_deeplab = evaluate_resnet(deeplab, "DeepLabV3_ResNet50")



# ---------------------------
# Evaluate GTAM Models
# ---------------------------
unet_GTAM = UNET_GTAM(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_unet_GTAM = torch.load("BUSI/Best_BUSI_GTAM/unet_GTAM_best.pth", map_location=DEVICE)
unet_GTAM.load_state_dict(ckpt_unet_GTAM["state_dict"])
scores_unet_GTAM = evaluate(unet_GTAM, "U-Net_GTAM")

unetpp_GTAM = UNETPP_GTAM(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_unetpp_GTAM = torch.load("BUSI/Best_BUSI_GTAM/UNetPP_GTAM_best.pth", map_location=DEVICE)
unetpp_GTAM.load_state_dict(ckpt_unetpp_GTAM["state_dict"])
scores_unetpp_GTAM = evaluate(unetpp_GTAM, "U-Net++_GTAM")

att_GTAM = AttentionUNet_GTAM(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_att_GTAM = torch.load("BUSI/Best_BUSI_GTAM/AttentionUNet_GTAM_best.pth", map_location=DEVICE)
att_GTAM.load_state_dict(ckpt_att_GTAM["state_dict"])
scores_att_GTAM = evaluate(att_GTAM, "AttentionUNet_GTAM")

# Use the shared NUM_CLASSES constant here as well, so that this model cannot
# drift from the others if the class count is ever changed.
fcn_gtam = FCN_ResNet50_GTAM(num_classes=NUM_CLASSES).to(DEVICE)
ckpt_fcn_gtam = torch.load("BUSI/Best_BUSI_GTAM/FCN_ResNet50_GTAM_best.pth", map_location=DEVICE)
fcn_gtam.load_state_dict(ckpt_fcn_gtam["state_dict"])
scores_fcn_gtam = evaluate_resnet(fcn_gtam, "FCN_ResNet50_GTAM")

deeplab_gtam = DeepLabV3_ResNet50_GTAM(num_classes=NUM_CLASSES).to(DEVICE)
ckpt_deeplab_gtam = torch.load("BUSI/Best_BUSI_GTAM/DeepLabV3_ResNet50_GTAM_best.pth", map_location=DEVICE)
deeplab_gtam.load_state_dict(ckpt_deeplab_gtam["state_dict"])
scores_deeplab_gtam = evaluate_resnet(deeplab_gtam, "DeepLabV3_ResNet50_GTAM")

# ---------------------------
# Summary
# ---------------------------
# One line per metric, with one "label: score" column per evaluated model.
print("\n=== COMPARISON (Models on BUSI Dataset) ===")
for key in scores_unet:
    # Joining the columns applies the same three-space separator between all
    # of them; the hand-written f-strings had dropped it before the last two
    # columns, which made those entries run together.
    print(f"{key:12s}  " + "   ".join(
        f"{label}: {scores[key]:.4f}"
        for label, scores in (
            ("UNet", scores_unet),
            ("UNet++", scores_unetpp),
            ("AttUNet", scores_att),
            ("FCN50", scores_fcn),
            ("DeepLabV3", scores_deeplab),
            ("UNet_GTAM", scores_unet_GTAM),
            ("UNet++_GTAM", scores_unetpp_GTAM),
            ("AttUNet_GTAM", scores_att_GTAM),
            ("FCN50_GTAM", scores_fcn_gtam),
            ("DeepLabV3_GTAM", scores_deeplab_gtam),
        )
    ))



=== Results for U-Net ===
accuracy    : 0.9235
precision   : 0.5548
recall      : 0.4389
specificity : 0.9805
dice        : 0.4341
iou         : 0.3441

=== Results for U-Net++ ===
accuracy    : 0.9227
precision   : 0.5952
recall      : 0.4397
specificity : 0.9784
dice        : 0.4429
iou         : 0.3541

=== Results for AttentionUNet ===
accuracy    : 0.9308
precision   : 0.5940
recall      : 0.5030
specificity : 0.9801
dice        : 0.5036
iou         : 0.4216

=== Results for FCN_ResNet50 ===
accuracy    : 0.9092
precision   : 0.4544
recall      : 0.5109
specificity : 0.9405
dice        : 0.4265
iou         : 0.3471

=== Results for DeepLabV3_ResNet50 ===
accuracy    : 0.9235
precision   : 0.5693
recall      : 0.4968
specificity : 0.9698
dice        : 0.4782
iou         : 0.3881

=== Results for U-Net_GTAM ===
accuracy    : 0.9272
precision   : 0.5763
recall      : 0.4952
specificity : 0.9688
dice        : 0.4728
iou         : 0.3813

=== Results for U-Net++_GTAM ===
accuracy    :

In [1]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2

from unet import UNET
from unetpp import UNETPP
from att_unet import AttentionUNet

from unet_gtam import UNET_GTAM
from unetpp_gtam import  UNETPP_GTAM
from att_unet_gtam import AttentionUNet_GTAM
from fcn_resnet50_gtam import FCN_ResNet50_GTAM
from DeepLabV3_gtam import DeepLabV3_ResNet50_GTAM

from dataset_UDIAT import UDIATDataset
from metrics import compute_metrics

from torchvision.models.segmentation import fcn_resnet50
from torchvision.models.segmentation import deeplabv3_resnet50

# ---------------------------
# Device and config
# ---------------------------
# Pick the compute backend: Apple MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Number of output channels every model below is built with.
NUM_CLASSES = 2

# Locations of the held-out UDIAT test images and their masks.
TEST_IMG_DIR = "UDIAT_Data/test/images"
TEST_MASK_DIR = "UDIAT_Data/test/masks"


# ---------------------------
# Transform
# ---------------------------
# Deterministic test-time pipeline: resize, normalize, convert to a tensor.
# NOTE(review): with mean 0, std 1 and max_pixel_value 255 the Normalize step
# should amount to a plain division by 255 -- confirm against albumentations.
test_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# The names say "val", but the data comes from the test directories above.
val_ds = UDIATDataset(
    TEST_IMG_DIR,
    TEST_MASK_DIR,
    transform=test_transform,
)
# One sample per batch and no shuffling, so metrics are computed per image.
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)


# ---------------------------
# Evaluation functions
# ---------------------------

# Metric names, in reporting order, that the metrics function must return.
_METRIC_NAMES = ("accuracy", "precision", "recall", "specificity", "dice", "iou")


def _run_evaluation(model, model_name, *, expand_gray_to_rgb,
                    loader=None, device=None, metrics_fn=None):
    """Shared evaluation loop behind ``evaluate`` and ``evaluate_resnet``.

    Puts ``model`` in eval mode, runs it over every batch of ``loader`` with
    gradients disabled, thresholds the softmax probability of class index 1
    at 0.5, scores each batch with ``metrics_fn`` and averages every metric
    over the batches. The averaged scores are printed and returned.

    Args:
        model: Module to evaluate. Its output is either a logits tensor or a
            dict holding the logits under ``"out"``.
        model_name: Label used in the printed results header.
        expand_gray_to_rgb: When true, batches with a single channel are
            repeated to three channels before the forward pass.
        loader: Iterable of ``(images, masks)`` batches. Defaults to the
            module-level ``val_loader``.
        device: Device the batches are moved to. Defaults to the module-level
            ``DEVICE``.
        metrics_fn: Callable ``(preds, masks) -> dict`` keyed by the names in
            ``_METRIC_NAMES``. Defaults to the module-level
            ``compute_metrics``.

    Returns:
        dict mapping each metric name to its mean over all batches, as float.

    Raises:
        KeyError: If ``metrics_fn`` returns a metric name that is not listed
            in ``_METRIC_NAMES``.
    """
    # Resolve the defaults at call time so that the notebook-level globals are
    # only required when the caller does not supply its own objects.
    if loader is None:
        loader = val_loader
    if device is None:
        device = DEVICE
    if metrics_fn is None:
        metrics_fn = compute_metrics

    model.eval()
    all_metrics = {name: [] for name in _METRIC_NAMES}

    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)

            # Convert grayscale -> RGB for backbones that expect 3 channels.
            if expand_gray_to_rgb and images.shape[1] == 1:
                images = images.repeat(1, 3, 1, 1)

            masks = masks.long().to(device)

            outputs = model(images)
            # Some models wrap their logits in a dict under "out".
            logits = outputs["out"] if isinstance(outputs, dict) else outputs

            # Probability of class index 1, binarized at 0.5.
            probs = torch.softmax(logits, dim=1)[:, 1]
            preds = (probs > 0.5).float()

            batch_scores = metrics_fn(preds.cpu(), masks.cpu())
            for name, value in batch_scores.items():
                all_metrics[name].append(value)

    final_scores = {
        name: float(np.mean(values)) for name, values in all_metrics.items()
    }

    print(f"\n=== Results for {model_name} ===")
    for name, score in final_scores.items():
        print(f"{name:12s}: {score:.4f}")

    return final_scores


def evaluate(model, model_name, *, loader=None, device=None, metrics_fn=None):
    """For UNet / UNet++ / Attention / GTAM models.

    Feeds the batches to ``model`` unchanged. See ``_run_evaluation`` for the
    scoring procedure and for the optional ``loader`` / ``device`` /
    ``metrics_fn`` overrides; without them the module-level ``val_loader``,
    ``DEVICE`` and ``compute_metrics`` are used.

    Returns:
        dict mapping each metric name to its mean over all batches.
    """
    return _run_evaluation(
        model,
        model_name,
        expand_gray_to_rgb=False,
        loader=loader,
        device=device,
        metrics_fn=metrics_fn,
    )


def evaluate_resnet(model, model_name, *, loader=None, device=None,
                    metrics_fn=None):
    """For FCN-ResNet50 / DeepLabV3 models.

    Same as ``evaluate``, except that single-channel batches are repeated to
    three channels before the forward pass. A model output that is a dict is
    read through its ``"out"`` entry; a bare logits tensor is accepted too.

    Returns:
        dict mapping each metric name to its mean over all batches.
    """
    return _run_evaluation(
        model,
        model_name,
        expand_gray_to_rgb=True,
        loader=loader,
        device=device,
        metrics_fn=metrics_fn,
    )


# Every section below follows the same recipe: build the architecture, load
# the "state_dict" entry of its best checkpoint, then score it on the test
# loader. The resulting score dicts are kept for the summary table.

# ---------------------------
# Evaluate U-Net
# ---------------------------
unet = UNET(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_unet = torch.load("Best_UDIAT/UNet_best.pth", map_location=DEVICE)
unet.load_state_dict(ckpt_unet["state_dict"])
scores_unet = evaluate(unet, "U-Net")

# ---------------------------
# Evaluate U-Net++
# ---------------------------
unetpp = UNETPP(in_channels=1, out_channels=NUM_CLASSES, deep_supervision=False).to(DEVICE)
ckpt_unetpp = torch.load("Best_UDIAT/UNetPP_best.pth", map_location=DEVICE)
unetpp.load_state_dict(ckpt_unetpp["state_dict"])
scores_unetpp = evaluate(unetpp, "U-Net++")

# ---------------------------
# Evaluate Attention U-Net
# ---------------------------
att_unet = AttentionUNet(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_att = torch.load("Best_UDIAT/AttentionUNet_best.pth", map_location=DEVICE)
att_unet.load_state_dict(ckpt_att["state_dict"])
scores_att = evaluate(att_unet, "AttentionUNet")

# ---------------------------
# Evaluate DeepLabV3-ResNet50
# ---------------------------
# Built without pretrained weights; the last classifier layer is swapped for a
# 1x1 convolution with NUM_CLASSES outputs before the checkpoint is loaded.
deeplab = deeplabv3_resnet50(weights=None, weights_backbone=None, aux_loss=False)
deeplab.classifier[-1] = torch.nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
deeplab = deeplab.to(DEVICE)
ckpt_deeplab = torch.load("Best_UDIAT/DeepLabV3_best.pth", map_location=DEVICE)
deeplab.load_state_dict(ckpt_deeplab["state_dict"])
# evaluate_resnet repeats single-channel batches to three channels first.
scores_deeplab = evaluate_resnet(deeplab, "DeepLabV3_ResNet50")


# ---------------------------
# Evaluate GTAM Models
# ---------------------------
# GTAM counterparts of the models above, loaded from the Best_UDIAT_GTAM
# checkpoint directory.
unet_GTAM = UNET_GTAM(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_unet_GTAM = torch.load("Best_UDIAT_GTAM/UNet_GTAM_best.pth", map_location=DEVICE)
unet_GTAM.load_state_dict(ckpt_unet_GTAM["state_dict"])
scores_unet_GTAM = evaluate(unet_GTAM, "U-Net_GTAM")

unetpp_GTAM = UNETPP_GTAM(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_unetpp_GTAM = torch.load("Best_UDIAT_GTAM/UNetPP_GTAM_best.pth", map_location=DEVICE)
unetpp_GTAM.load_state_dict(ckpt_unetpp_GTAM["state_dict"])
scores_unetpp_GTAM = evaluate(unetpp_GTAM, "U-Net++_GTAM")

att_GTAM = AttentionUNet_GTAM(in_channels=1, out_channels=NUM_CLASSES).to(DEVICE)
ckpt_att_GTAM = torch.load("Best_UDIAT_GTAM/AttentionUNet_GTAM_best.pth", map_location=DEVICE)
att_GTAM.load_state_dict(ckpt_att_GTAM["state_dict"])
scores_att_GTAM = evaluate(att_GTAM, "AttentionUNet_GTAM")

deeplab_gtam = DeepLabV3_ResNet50_GTAM(num_classes=NUM_CLASSES).to(DEVICE)
ckpt_deeplab_gtam = torch.load("Best_UDIAT_GTAM/DeepLabV3_ResNet50_GTAM_best.pth", map_location=DEVICE)
deeplab_gtam.load_state_dict(ckpt_deeplab_gtam["state_dict"])
scores_deeplab_gtam = evaluate_resnet(deeplab_gtam, "DeepLabV3_ResNet50_GTAM")

# ---------------------------
# Summary
# ---------------------------
# One line per metric, with one "label: score" column per evaluated model.
print("\n=== COMPARISON (Models on UDIAT Dataset) ===")
for key in scores_unet:
    # Columns are joined with the same three-space separator as before.
    print(f"{key:12s}  " + "   ".join(
        f"{label}: {scores[key]:.4f}"
        for label, scores in (
            ("UNet", scores_unet),
            ("UNet++", scores_unetpp),
            ("AttUNet", scores_att),
            ("DeepLabV3", scores_deeplab),
            ("UNet_GTAM", scores_unet_GTAM),
            ("UNet++_GTAM", scores_unetpp_GTAM),
            ("AttUNet_GTAM", scores_att_GTAM),
            ("DeepLabV3_GTAM", scores_deeplab_gtam),
        )
    ))



=== Results for U-Net ===
accuracy    : 0.8926
precision   : 0.4728
recall      : 0.6682
specificity : 0.9422
dice        : 0.4387
iou         : 0.3445

=== Results for U-Net++ ===
accuracy    : 0.8853
precision   : 0.5479
recall      : 0.5324
specificity : 0.9753
dice        : 0.3595
iou         : 0.2738

=== Results for AttentionUNet ===
accuracy    : 0.9237
precision   : 0.6575
recall      : 0.7633
specificity : 0.9772
dice        : 0.5869
iou         : 0.4531

=== Results for DeepLabV3_ResNet50 ===
accuracy    : 0.9255
precision   : 0.5490
recall      : 0.4044
specificity : 0.9886
dice        : 0.4210
iou         : 0.3328

=== Results for U-Net_GTAM ===
accuracy    : 0.9075
precision   : 0.7148
recall      : 0.5086
specificity : 0.9919
dice        : 0.4702
iou         : 0.3703

=== Results for U-Net++_GTAM ===
accuracy    : 0.8968
precision   : 0.5408
recall      : 0.6588
specificity : 0.9559
dice        : 0.4576
iou         : 0.3348

=== Results for AttentionUNet_GTAM ===
accurac