# In [None]:
# ===================================================================================
# Training Script for U-Net / U-Net++ / Attention U-Net / FCN / DeepLabV3 (+ GTAM variants) ON UDIAT
# ===================================================================================

import os
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader

# Base models
from unet import UNET
from unetpp import UNETPP
from att_unet import AttentionUNet
from torchvision.models.segmentation import fcn_resnet50
from torchvision.models.segmentation import deeplabv3_resnet50

# GTAM variants
from unet_gtam import UNET_GTAM
from unetpp_gtam import UNETPP_GTAM
from att_unet_gtam import AttentionUNet_GTAM
from fcn_resnet50_gtam import FCN_ResNet50_GTAM
from DeepLabV3_gtam import DeepLabV3_ResNet50_GTAM

from dataset_UDIAT import UDIATDataset
from utils import train_fn, validate_fn, save_checkpoint, save_predictions, train_fn_resnet , validate_fn_resnet,save_predictions_resnet


# ==============================================================
# CONFIGURATION
# ==============================================================

# Optimisation / data-loading hyper-parameters shared by every model.
LEARNING_RATE = 1e-4
# Prefer Apple-silicon MPS, then CUDA, then CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)

BATCH_SIZE = 8
NUM_EPOCHS = 60
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
# Pinned host memory only speeds up host->CUDA copies; on MPS/CPU PyTorch
# ignores the option and emits a warning, so enable it only for CUDA.
PIN_MEMORY = DEVICE == "cuda"

NUM_CLASSES = 2  # output channels / classes scored by CrossEntropyLoss
IN_CHANNELS = 1  # default; the driver below resets it per model (3 for ResNet-based models)


# ==============================================================
# MODEL FLAGS — ONLY ONE SHOULD BE TRUE
# ==============================================================
# NOTE(review): main() dispatches on MODEL_NAME, not on these flags; nothing
# visible in this notebook reads them, so they are informational only.

Train_UNet = False
Train_UNetPP = False
Train_AttentionUNet = False
Train_FCN_ResNet50 = False
Train_DeepLabV3 = False

Train_UNET_GTAM = False
Train_UNetPP_GTAM = False
Train_AttentionUNet_GTAM = False
Train_FCN_ResNet50_GTAM = False
Train_DeepLabV3_ResNet50_GTAM = False

# True when the selected model is a *_GTAM variant; selects the *_GTAM output
# directories in set_paths(). Defined here so it exists before the driver runs.
IS_GTAM = False

# ==============================================================
# SET OUTPUT PATHS BASED ON MODEL TYPE
# ==============================================================

def set_paths(model_name, is_gtam=False):
    """Create the output directories for `model_name` and return its file paths.

    GTAM models write under the *_UDIAT_GTAM roots, all others under *_UDIAT.

    Returns:
        (best_path, ckpt_path, out_path): the best-weights file, the rolling
        checkpoint file, and the (already created) folder for predicted masks.
    """
    best_root, ckpt_root, out_root = (
        ("Best_UDIAT_GTAM", "Checkpoints_UDIAT_GTAM", "Outputs_UDIAT_GTAM")
        if is_gtam
        else ("Best_UDIAT", "Checkpoints_UDIAT", "Outputs_UDIAT")
    )

    out_dir = f"{out_root}/{model_name}_outputs"
    # Only the parent folders of the .pth files are needed; the outputs
    # folder itself is created because predictions are written into it.
    for directory in (best_root, ckpt_root, out_dir):
        os.makedirs(directory, exist_ok=True)

    return (
        f"{best_root}/{model_name}_best.pth",
        f"{ckpt_root}/{model_name}_checkpoint.pth",
        out_dir,
    )



def set_all_flags_false():
    """Reset every Train_* model flag and IS_GTAM (module globals) to False.

    Called by the driver before it enables the flag for the next model.
    """
    # Every name assigned below must be declared global here; otherwise Python
    # silently creates a function-local variable and the module flag is not
    # reset (previously the case for Train_UNetPP_GTAM, misspelled as
    # Train_UnetPP_GTAM, and for Train_FCN_ResNet50_GTAM).
    global Train_UNet, Train_UNetPP, Train_AttentionUNet
    global Train_FCN_ResNet50, Train_DeepLabV3
    global Train_UNET_GTAM, Train_UNetPP_GTAM, Train_AttentionUNet_GTAM
    global Train_FCN_ResNet50_GTAM, Train_DeepLabV3_ResNet50_GTAM
    global IS_GTAM

    Train_UNet = False
    Train_UNetPP = False
    Train_AttentionUNet = False
    Train_FCN_ResNet50 = False
    Train_DeepLabV3 = False

    Train_UNET_GTAM = False
    Train_UNetPP_GTAM = False
    Train_AttentionUNet_GTAM = False
    Train_FCN_ResNet50_GTAM = False
    Train_DeepLabV3_ResNet50_GTAM = False

    IS_GTAM = False



# UDIAT split layout: UDIAT_Data/<split>/images and UDIAT_Data/<split>/masks.
TRAIN_IMG_DIR = "UDIAT_Data/train/images"
VAL_IMG_DIR = "UDIAT_Data/validation/images"

TRAIN_MASK_DIR = "UDIAT_Data/train/masks"
VAL_MASK_DIR = "UDIAT_Data/validation/masks"

# MAIN TRAINING

# Models built on a ResNet-50 backbone. They use the *_resnet train/validate/
# predict helpers from utils and a 3-channel Normalize.
_RESNET_MODEL_NAMES = (
    "FCN_ResNet50",
    "DeepLabV3",
    "FCN_ResNet50_GTAM",
    "DeepLabV3_ResNet50_GTAM",
)


def _build_transforms(num_channels):
    """Return (train_transform, val_transform) for `num_channels`-channel images.

    Both pipelines resize to IMAGE_HEIGHT x IMAGE_WIDTH and divide pixel values
    by 255 (mean 0, std 1, max_pixel_value 255); only training augments.
    """
    mean = [0.0] * num_channels
    std = [1.0] * num_channels
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ])
    return train_transform, val_transform


def _build_model(model_name):
    """Instantiate the network called `model_name` and move it to DEVICE.

    Names without a dedicated branch (including "UNet") fall back to UNET.
    """
    if model_name == "UNet_GTAM":
        model = UNET_GTAM(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "UNetPP":
        model = UNETPP(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "UNetPP_GTAM":
        model = UNETPP_GTAM(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "AttentionUNet":
        model = AttentionUNet(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "AttentionUNet_GTAM":
        model = AttentionUNet_GTAM(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "FCN_ResNet50":
        # Randomly initialised torchvision FCN; swap the head's final 1x1 conv
        # so it predicts NUM_CLASSES channels.
        model = fcn_resnet50(weights=None, weights_backbone=None)
        model.classifier[4] = nn.Conv2d(512, NUM_CLASSES, kernel_size=1)
    elif model_name == "FCN_ResNet50_GTAM":
        model = FCN_ResNet50_GTAM(num_classes=NUM_CLASSES)
    elif model_name == "DeepLabV3":
        model = deeplabv3_resnet50(weights=None, weights_backbone=None)
        model.classifier[-1] = nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
    elif model_name == "DeepLabV3_ResNet50_GTAM":
        model = DeepLabV3_ResNet50_GTAM(num_classes=NUM_CLASSES)
    else:
        model = UNET(IN_CHANNELS, NUM_CLASSES)
    return model.to(DEVICE)


def main():
    """Train MODEL_NAME on the UDIAT train split, validating every epoch.

    Reads module globals set by the driver: MODEL_NAME, IN_CHANNELS and the
    paths `best`, `checkpt`, `final_folder` returned by set_paths().

    Behaviour:
      * weights with the best validation Dice are saved to `best`;
      * a rolling checkpoint goes to `checkpt` every 5 epochs;
      * training stops after `patience` epochs without improvement;
      * example validation predictions are finally written to `final_folder`.
    """
    is_resnet = MODEL_NAME in _RESNET_MODEL_NAMES

    # 3-channel Normalize for the ResNet models, 1-channel for the U-Net family.
    # NOTE(review): nothing here converts grayscale to 3 channels; presumably
    # the dataset or the *_resnet utils handle that — confirm.
    train_transform, val_transform = _build_transforms(3 if is_resnet else 1)

    #  DATASETS
    train_ds = UDIATDataset(
        image_dir=TRAIN_IMG_DIR,
        mask_dir=TRAIN_MASK_DIR,
        transform=train_transform,
    )
    val_ds = UDIATDataset(
        image_dir=VAL_IMG_DIR,
        mask_dir=VAL_MASK_DIR,
        transform=val_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=True,
    )
    # Evaluate on every validation sample. With drop_last=True, up to
    # BATCH_SIZE-1 images would be excluded from val_dice, and a split smaller
    # than BATCH_SIZE would yield no batches at all. Assumes the validate/
    # predict helpers put the model in eval mode, so a short last batch is safe.
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=False,
    )

    # ---------------- Model Selection ----------------
    print(f"\nTraining model: {MODEL_NAME}\n")
    model = _build_model(MODEL_NAME)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # The ResNet models need the *_resnet variants of the utils helpers.
    train_step = train_fn_resnet if is_resnet else train_fn
    validate_step = validate_fn_resnet if is_resnet else validate_fn
    write_predictions = save_predictions_resnet if is_resnet else save_predictions

    #  TRAINING LOOP (early stopping on validation Dice)
    patience = 4
    patience_counter = 0
    best_val_dice = -1.0

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_dice = train_step(
            train_loader, model, optimizer, loss_fn, DEVICE, NUM_CLASSES
        )
        val_loss, val_dice = validate_step(
            val_loader, model, loss_fn, DEVICE, NUM_CLASSES
        )

        print(f"Epoch {epoch:03d} | "
              f"train_loss: {train_loss:.4f}  train_dice: {train_dice:.4f} | "
              f"val_loss: {val_loss:.4f}  val_dice: {val_dice:.4f}")

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch, path=best)
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch}.")
            break

        # Occasional rolling checkpoint
        if epoch % 5 == 0:
            save_checkpoint(model, optimizer, epoch, path=checkpt)

    # NOTE(review): `model` holds the last-epoch weights here, not the best
    # ones saved to `best`; reload them first if predictions should reflect
    # the best checkpoint.
    write_predictions(model, val_loader, DEVICE, folder=final_folder)
    print("Saved predicted example masks")


def run_training():
    """Run one training session for the model currently named by MODEL_NAME."""
    return main()

if __name__ == "__main__":
    # Models trained in sequence; uncomment entries to add runs.
    models_to_train = [
        # "UNetPP",
        "DeepLabV3_ResNet50_GTAM",
        # "DeepLabV3",
        # "UNet",
        # "AttentionUNet",
        # "UNet_GTAM",
        # "UNetPP_GTAM",
        # "AttentionUNet_GTAM",
        # "FCN_ResNet50",
        # "FCN_ResNet50_GTAM",
    ]

    for model in models_to_train:
        print("\n========================================")
        print(f"   TRAINING MODEL: {model}")
        print("========================================\n")

        # turn all switches off (also resets IS_GTAM)
        set_all_flags_false()

        # main() reads MODEL_NAME, IN_CHANNELS and the paths below as globals.
        MODEL_NAME = model

        # Turn on only the required flag; IS_GTAM selects the *_GTAM output dirs.
        if model == "UNet":
            Train_UNet = True
        elif model == "UNetPP":
            Train_UNetPP = True
        elif model == "AttentionUNet":
            Train_AttentionUNet = True
        elif model == "FCN_ResNet50":
            Train_FCN_ResNet50 = True
        elif model == "DeepLabV3":
            Train_DeepLabV3 = True
        elif model == "UNet_GTAM":
            Train_UNET_GTAM = True
            IS_GTAM = True
        elif model == "UNetPP_GTAM":
            Train_UNetPP_GTAM = True
            IS_GTAM = True
        elif model == "AttentionUNet_GTAM":
            Train_AttentionUNet_GTAM = True
            IS_GTAM = True
        elif model == "FCN_ResNet50_GTAM":
            Train_FCN_ResNet50_GTAM = True
            IS_GTAM = True
        elif model == "DeepLabV3_ResNet50_GTAM":
            Train_DeepLabV3_ResNet50_GTAM = True
            IS_GTAM = True
        else:
            # Fail fast: main() would otherwise silently train a plain UNET
            # and save it under the misspelled name.
            raise ValueError(f"Unknown model name: {model!r}")

        best, checkpt, final_folder = set_paths(MODEL_NAME, is_gtam=IS_GTAM)

        # ResNet-50-based models take 3-channel input; the U-Net family takes 1.
        if MODEL_NAME in ["FCN_ResNet50", "DeepLabV3", "FCN_ResNet50_GTAM", "DeepLabV3_ResNet50_GTAM"]:
            IN_CHANNELS = 3
        else:
            IN_CHANNELS = 1

        # run training
        run_training()

        print(f"\n FINISHED TRAINING {model} \n")

    print("\n ALL MODELS TRAINED SUCCESSFULLY! ")
    


# In [None]:
# =================================================================================
# Training Script for U-Net / U-Net++ / Attention U-Net / FCN / DeepLabV3 (+ GTAM variants) ON BUSI
# =================================================================================

import os
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader

# Base models
from unet import UNET
from unetpp import UNETPP
from att_unet import AttentionUNet
from torchvision.models.segmentation import fcn_resnet50
from torchvision.models.segmentation import deeplabv3_resnet50

# GTAM variants
from unet_gtam import UNET_GTAM
from unetpp_gtam import UNETPP_GTAM
from att_unet_gtam import AttentionUNet_GTAM
from fcn_resnet50_gtam import FCN_ResNet50_GTAM
from DeepLabV3_gtam import DeepLabV3_ResNet50_GTAM

from dataset_BUSI import BUSIDataset
from utils import train_fn, validate_fn, save_checkpoint, save_predictions, train_fn_resnet , validate_fn_resnet,save_predictions_resnet


# ==============================================================
# CONFIGURATION
# ==============================================================

# Optimisation / data-loading hyper-parameters shared by every model.
LEARNING_RATE = 1e-4
# Prefer Apple-silicon MPS, then CUDA, then CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)

BATCH_SIZE = 8
NUM_EPOCHS = 60
NUM_WORKERS = 2
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256
# Pinned host memory only speeds up host->CUDA copies; on MPS/CPU PyTorch
# ignores the option and emits a warning, so enable it only for CUDA.
PIN_MEMORY = DEVICE == "cuda"

NUM_CLASSES = 2  # output channels / classes scored by CrossEntropyLoss
IN_CHANNELS = 1  # default; the driver below resets it per model (3 for ResNet-based models)


# ==============================================================
# MODEL FLAGS — ONLY ONE SHOULD BE TRUE
# ==============================================================
# NOTE(review): main() dispatches on MODEL_NAME, not on these flags; nothing
# visible in this notebook reads them, so they are informational only.

Train_UNet = False
Train_UNetPP = False
Train_AttentionUNet = False
Train_FCN_ResNet50 = False
Train_DeepLabV3 = False

Train_UNET_GTAM = False
Train_UNetPP_GTAM = False
Train_AttentionUNet_GTAM = False
Train_FCN_ResNet50_GTAM = False
Train_DeepLabV3_ResNet50_GTAM = False

# True when the selected model is a *_GTAM variant; selects the *_GTAM output
# directories in set_paths(). Defined here so it exists before the driver runs.
IS_GTAM = False


# ==============================================================
# SET OUTPUT PATHS BASED ON MODEL TYPE
# ==============================================================

def set_paths(model_name, is_gtam=False):
    """Create the BUSI output directories for `model_name` and return its paths.

    GTAM models write under the BUSI/*_BUSI_GTAM roots, all others under
    BUSI/*_BUSI.

    Returns:
        (best_path, ckpt_path, out_path): the best-weights file, the rolling
        checkpoint file, and the (already created) folder for predicted masks.
    """
    best_root, ckpt_root, out_root = (
        ("BUSI/Best_BUSI_GTAM", "BUSI/Checkpoints_BUSI_GTAM", "BUSI/Outputs_BUSI_GTAM")
        if is_gtam
        else ("BUSI/Best_BUSI", "BUSI/Checkpoints_BUSI", "BUSI/Outputs_BUSI")
    )

    out_dir = f"{out_root}/{model_name}_outputs"
    # Only the parent folders of the .pth files are needed; the outputs
    # folder itself is created because predictions are written into it.
    for directory in (best_root, ckpt_root, out_dir):
        os.makedirs(directory, exist_ok=True)

    return (
        f"{best_root}/{model_name}_best.pth",
        f"{ckpt_root}/{model_name}_checkpoint.pth",
        out_dir,
    )



def set_all_flags_false():
    """Reset every Train_* model flag and IS_GTAM (module globals) to False.

    Called by the driver before it enables the flag for the next model.
    """
    # Every name assigned below must be declared global here; otherwise Python
    # silently creates a function-local variable and the module flag is not
    # reset (previously the case for Train_UNetPP_GTAM, misspelled as
    # Train_UnetPP_GTAM, and for Train_FCN_ResNet50_GTAM).
    global Train_UNet, Train_UNetPP, Train_AttentionUNet
    global Train_FCN_ResNet50, Train_DeepLabV3
    global Train_UNET_GTAM, Train_UNetPP_GTAM, Train_AttentionUNet_GTAM
    global Train_FCN_ResNet50_GTAM, Train_DeepLabV3_ResNet50_GTAM
    global IS_GTAM

    Train_UNet = False
    Train_UNetPP = False
    Train_AttentionUNet = False
    Train_FCN_ResNet50 = False
    Train_DeepLabV3 = False

    Train_UNET_GTAM = False
    Train_UNetPP_GTAM = False
    Train_AttentionUNet_GTAM = False
    Train_FCN_ResNet50_GTAM = False
    Train_DeepLabV3_ResNet50_GTAM = False

    IS_GTAM = False



# BUSI split layout: BUSI/Data/<split>/images and BUSI/Data/<split>/masks.
TRAIN_IMG_DIR = "BUSI/Data/train/images"
VAL_IMG_DIR = "BUSI/Data/validation/images"

TRAIN_MASK_DIR = "BUSI/Data/train/masks"
VAL_MASK_DIR = "BUSI/Data/validation/masks"

# MAIN TRAINING

# Models built on a ResNet-50 backbone. They use the *_resnet train/validate/
# predict helpers from utils and a 3-channel Normalize.
_RESNET_MODEL_NAMES = (
    "FCN_ResNet50",
    "DeepLabV3",
    "FCN_ResNet50_GTAM",
    "DeepLabV3_ResNet50_GTAM",
)


def _build_transforms(num_channels):
    """Return (train_transform, val_transform) for `num_channels`-channel images.

    Both pipelines resize to IMAGE_HEIGHT x IMAGE_WIDTH and divide pixel values
    by 255 (mean 0, std 1, max_pixel_value 255); only training augments.
    """
    mean = [0.0] * num_channels
    std = [1.0] * num_channels
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ])
    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ])
    return train_transform, val_transform


def _build_model(model_name):
    """Instantiate the network called `model_name` and move it to DEVICE.

    Names without a dedicated branch (including "UNet") fall back to UNET.
    """
    if model_name == "UNet_GTAM":
        model = UNET_GTAM(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "UNetPP":
        model = UNETPP(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "UNetPP_GTAM":
        model = UNETPP_GTAM(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "AttentionUNet":
        model = AttentionUNet(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "AttentionUNet_GTAM":
        model = AttentionUNet_GTAM(IN_CHANNELS, NUM_CLASSES)
    elif model_name == "FCN_ResNet50":
        # Randomly initialised torchvision FCN; swap the head's final 1x1 conv
        # so it predicts NUM_CLASSES channels.
        model = fcn_resnet50(weights=None, weights_backbone=None)
        model.classifier[4] = nn.Conv2d(512, NUM_CLASSES, kernel_size=1)
    elif model_name == "FCN_ResNet50_GTAM":
        model = FCN_ResNet50_GTAM(num_classes=NUM_CLASSES)
    elif model_name == "DeepLabV3":
        model = deeplabv3_resnet50(weights=None, weights_backbone=None)
        model.classifier[-1] = nn.Conv2d(256, NUM_CLASSES, kernel_size=1)
    elif model_name == "DeepLabV3_ResNet50_GTAM":
        model = DeepLabV3_ResNet50_GTAM(num_classes=NUM_CLASSES)
    else:
        model = UNET(IN_CHANNELS, NUM_CLASSES)
    return model.to(DEVICE)


def main():
    """Train MODEL_NAME on the BUSI train split, validating every epoch.

    Reads module globals set by the driver: MODEL_NAME, IN_CHANNELS and the
    paths `best`, `checkpt`, `final_folder` returned by set_paths().

    Behaviour:
      * weights with the best validation Dice are saved to `best`;
      * a rolling checkpoint goes to `checkpt` every 5 epochs;
      * training stops after `patience` epochs without improvement;
      * example validation predictions are finally written to `final_folder`.
    """
    is_resnet = MODEL_NAME in _RESNET_MODEL_NAMES

    # 3-channel Normalize for the ResNet models, 1-channel for the U-Net family.
    # NOTE(review): nothing here converts grayscale to 3 channels; presumably
    # the dataset or the *_resnet utils handle that — confirm.
    train_transform, val_transform = _build_transforms(3 if is_resnet else 1)

    #  DATASETS
    train_ds = BUSIDataset(
        image_dir=TRAIN_IMG_DIR,
        mask_dir=TRAIN_MASK_DIR,
        transform=train_transform,
    )
    val_ds = BUSIDataset(
        image_dir=VAL_IMG_DIR,
        mask_dir=VAL_MASK_DIR,
        transform=val_transform,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=True,
    )
    # Evaluate on every validation sample. With drop_last=True, up to
    # BATCH_SIZE-1 images would be excluded from val_dice, and a split smaller
    # than BATCH_SIZE would yield no batches at all. Assumes the validate/
    # predict helpers put the model in eval mode, so a short last batch is safe.
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        drop_last=False,
    )

    # ---------------- Model Selection ----------------
    print(f"\nTraining model: {MODEL_NAME}\n")
    model = _build_model(MODEL_NAME)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # The ResNet models need the *_resnet variants of the utils helpers.
    train_step = train_fn_resnet if is_resnet else train_fn
    validate_step = validate_fn_resnet if is_resnet else validate_fn
    write_predictions = save_predictions_resnet if is_resnet else save_predictions

    #  TRAINING LOOP (early stopping on validation Dice)
    patience = 10
    patience_counter = 0
    best_val_dice = -1.0

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss, train_dice = train_step(
            train_loader, model, optimizer, loss_fn, DEVICE, NUM_CLASSES
        )
        val_loss, val_dice = validate_step(
            val_loader, model, loss_fn, DEVICE, NUM_CLASSES
        )

        print(f"Epoch {epoch:03d} | "
              f"train_loss: {train_loss:.4f}  train_dice: {train_dice:.4f} | "
              f"val_loss: {val_loss:.4f}  val_dice: {val_dice:.4f}")

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            patience_counter = 0
            save_checkpoint(model, optimizer, epoch, path=best)
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch}.")
            break

        # Occasional rolling checkpoint
        if epoch % 5 == 0:
            save_checkpoint(model, optimizer, epoch, path=checkpt)

    # NOTE(review): `model` holds the last-epoch weights here, not the best
    # ones saved to `best`; reload them first if predictions should reflect
    # the best checkpoint.
    write_predictions(model, val_loader, DEVICE, folder=final_folder)
    print("Saved predicted example masks")


def run_training():
    """Run one BUSI training session for the model currently named by MODEL_NAME."""
    return main()

if __name__ == "__main__":
    # Models trained in sequence; uncomment entries to add runs.
    models_to_train = [
        # "UNetPP",
        "DeepLabV3_ResNet50_GTAM",
        # "DeepLabV3",
        # "UNet",
        # "AttentionUNet",
        # "UNet_GTAM",
        # "UNetPP_GTAM",
        # "AttentionUNet_GTAM",
        # "FCN_ResNet50",
        # "FCN_ResNet50_GTAM",
    ]

    for model in models_to_train:
        print("\n========================================")
        print(f"   TRAINING MODEL: {model} on BUSI")
        print("========================================\n")

        # turn all switches off (also resets IS_GTAM)
        set_all_flags_false()

        # main() reads MODEL_NAME, IN_CHANNELS and the paths below as globals.
        MODEL_NAME = model

        # Turn on only the required flag; IS_GTAM selects the *_GTAM output dirs.
        if model == "UNet":
            Train_UNet = True
        elif model == "UNetPP":
            Train_UNetPP = True
        elif model == "AttentionUNet":
            Train_AttentionUNet = True
        elif model == "FCN_ResNet50":
            Train_FCN_ResNet50 = True
        elif model == "DeepLabV3":
            Train_DeepLabV3 = True
        elif model == "UNet_GTAM":
            Train_UNET_GTAM = True
            IS_GTAM = True
        elif model == "UNetPP_GTAM":
            Train_UNetPP_GTAM = True
            IS_GTAM = True
        elif model == "AttentionUNet_GTAM":
            Train_AttentionUNet_GTAM = True
            IS_GTAM = True
        elif model == "FCN_ResNet50_GTAM":
            Train_FCN_ResNet50_GTAM = True
            IS_GTAM = True
        elif model == "DeepLabV3_ResNet50_GTAM":
            Train_DeepLabV3_ResNet50_GTAM = True
            IS_GTAM = True
        else:
            # Fail fast: main() would otherwise silently train a plain UNET
            # and save it under the misspelled name.
            raise ValueError(f"Unknown model name: {model!r}")

        best, checkpt, final_folder = set_paths(MODEL_NAME, is_gtam=IS_GTAM)

        # ResNet-50-based models take 3-channel input; the U-Net family takes 1.
        if MODEL_NAME in ["FCN_ResNet50", "DeepLabV3", "FCN_ResNet50_GTAM", "DeepLabV3_ResNet50_GTAM"]:
            IN_CHANNELS = 3
        else:
            IN_CHANNELS = 1

        # run training
        run_training()

        print(f"\n FINISHED TRAINING {model} \n")

    print("\n ALL MODELS TRAINED SUCCESSFULLY! ")
    


# In [None]:
# visualize_gtam_heatmaps: GTAM attention heatmaps for the UNet++-GTAM UDIAT checkpoint on the UDIAT test split

import os
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt   

import albumentations as A
from albumentations.pytorch import ToTensorV2

from dataset_BUSI import BUSIDataset    
from att_unet_gtam import AttentionUNet_GTAM
from unetpp_gtam import UNETPP_GTAM
from gtam import GaborTAM
from metrics import compute_metrics     

# This cell evaluates on the UDIAT test split, but its imports only bring in
# BUSIDataset. Import the UDIAT dataset class here so the cell also runs
# without the UDIAT training cell having been executed first.
from dataset_UDIAT import UDIATDataset

# Prefer Apple-silicon MPS, then CUDA, then CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else (
    "cuda" if torch.cuda.is_available() else "cpu"
)

NUM_CLASSES = 2

TEST_IMG_DIR = "UDIAT_Data/test/images"
TEST_MASK_DIR = "UDIAT_Data/test/masks"

CKPT_PATH = "Best_UDIAT_GTAM/UNetPP_GTAM_best.pth"
OUT_FOLDER = "GTAM_Heatmaps_UDIAT"


# -----------------------------
# Albumentations transform (1-channel; divides pixel values by 255)
# -----------------------------
test_transform = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# batch_size=1: save_gtam_heatmaps() indexes sample 0 of every batch.
test_ds = UDIATDataset(TEST_IMG_DIR, TEST_MASK_DIR, transform=test_transform)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)


# -----------------------------
# Helper: find first GaborTAM
# -----------------------------
def find_gtam_module(model):
    """Return the first GaborTAM submodule of `model`, in `model.modules()` order.

    Raises:
        RuntimeError: if `model` contains no GaborTAM module.
    """
    candidates = (module for module in model.modules() if isinstance(module, GaborTAM))
    found = next(candidates, None)
    if found is None:
        raise RuntimeError("No GaborTAM module found in model.")
    return found


# -----------------------------
# Load model + checkpoint
# -----------------------------
def load_gtam_model():
    """Build a UNet++-GTAM, load its weights from CKPT_PATH and return it in eval mode.

    Accepts either a checkpoint dict holding the weights under "state_dict"
    or a bare state dict.
    """
    # 1 input channel, matching the 1-channel test_transform.
    model = UNETPP_GTAM(1, NUM_CLASSES).to(DEVICE)
    ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.eval()
    return model
 

        
# -----------------------------
# Main visualization function
# -----------------------------
def _to_unit_range(arr):
    """Min-max scale a NumPy array to [0, 1]; a constant array maps to all zeros."""
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        return (arr - lo) / (hi - lo)
    return np.zeros_like(arr)


@torch.no_grad()
def save_gtam_heatmaps():
    """Run the GTAM model over the test set and save per-sample visualisations.

    For each sample index `idx`, writes these files into OUT_FOLDER:
      sample_{idx:03d}_image.png       input image, min-max scaled to 0-255
      sample_{idx:03d}_mask.png        ground-truth mask (0/1 -> 0/255)
      sample_{idx:03d}_pred.png        prediction (class-1 probability > 0.5)
      sample_{idx:03d}_gtam_gray.png   GTAM attention, upsampled and min-max scaled
      sample_{idx:03d}_gtam_color.png  the same attention through the 'jet' colormap
      sample_{idx:03d}_overlay.png     50/50 blend of the gray image and the heatmap

    Raises:
        RuntimeError: if the model has no GaborTAM module, or the module did
            not record `last_attn` during the forward pass.
    """
    os.makedirs(OUT_FOLDER, exist_ok=True)

    model = load_gtam_model()
    gtam = find_gtam_module(model)

    print(f"Found GTAM module: {gtam.__class__.__name__}")
    print(f"Saving outputs to: {OUT_FOLDER}")

    def out_path(idx, suffix):
        return os.path.join(OUT_FOLDER, f"sample_{idx:03d}_{suffix}.png")

    for idx, (images, masks) in enumerate(test_loader):
        images = images.to(DEVICE)   # presumably (1, 1, H, W): batch_size=1, 1-channel transform
        masks = masks.long()         # presumably (1, H, W)

        # Forward pass: produces logits and fills gtam.last_attn.
        logits = model(images)       # (1, NUM_CLASSES, H, W)

        # Prediction: class-1 probability thresholded at 0.5.
        probs = torch.softmax(logits, dim=1)[:, 1]
        preds = (probs > 0.5).float().cpu().numpy()[0]

        attn = gtam.last_attn
        if attn is None:
            raise RuntimeError("GTAM last_attn is None. Did you forward through the module?")

        # Upsample attention to the input resolution and scale to [0, 1].
        # Assumes attn is 4-D (N, C, h, w), as bilinear interpolate requires.
        attn_up = F.interpolate(
            attn, size=images.shape[-2:], mode="bilinear", align_corners=False
        )
        attn_norm = _to_unit_range(attn_up.cpu().numpy()[0, 0])

        # 'jet' returns RGBA floats in [0, 1]; keep RGB as uint8.
        attn_color = (plt.cm.jet(attn_norm)[..., :3] * 255).astype(np.uint8)

        img_uint8 = (_to_unit_range(images.cpu().numpy()[0, 0]) * 255).astype(np.uint8)
        mask_uint8 = (masks.numpy()[0] * 255).astype(np.uint8)
        pred_uint8 = (preds * 255).astype(np.uint8)
        attn_gray_uint8 = (attn_norm * 255).astype(np.uint8)

        # Overlay: gray image replicated to RGB, blended with the heatmap.
        alpha = 0.5
        img_rgb = np.stack([img_uint8, img_uint8, img_uint8], axis=-1)
        overlay = (alpha * img_rgb + (1 - alpha) * attn_color).astype(np.uint8)

        # 2-D uint8 arrays are saved as 8-bit grayscale ("L") automatically;
        # passing `mode=` to Image.fromarray is deprecated since Pillow 11.3.
        Image.fromarray(img_uint8).save(out_path(idx, "image"))
        Image.fromarray(mask_uint8).save(out_path(idx, "mask"))
        Image.fromarray(pred_uint8).save(out_path(idx, "pred"))
        Image.fromarray(attn_gray_uint8).save(out_path(idx, "gtam_gray"))
        Image.fromarray(attn_color).save(out_path(idx, "gtam_color"))
        Image.fromarray(overlay).save(out_path(idx, "overlay"))

        if idx % 20 == 0:
            print(f"Saved heatmaps for sample {idx}")

    print("Done. Saved GTAM visualizations for all test images.")


if __name__ == "__main__":
    # Entry point for this cell: write GTAM visualisations for every test sample.
    save_gtam_heatmaps()
