<a href="https://colab.research.google.com/github/Blaise-bf/thesis-files/blob/main/thesis_analysis_update_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install segmentation_models_pytorch --quiet
!pip install torchsummary --quiet
!pip install torchmetrics --quiet
!pip install -U albumentations --quiet
!pip install timm --quiet
# !pip install nnviz --quiet

In [None]:
from torchvision import transforms
import pandas as pd
import numpy as np
from PIL import Image
import cv2
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import os
import matplotlib.pyplot as plt
import pandas as pd
from torchmetrics.segmentation import DiceScore



In [None]:
# Mount Google Drive so the metadata spreadsheet and image/mask folders under
# /content/drive/MyDrive/msc_uhasselt (referenced by the path configuration
# further down) are readable from this Colab session.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import cv2
import numpy as np

def resize_to_square(
    img: np.ndarray,
    target_size: int,
    pad_color: int = 0,
    interpolation: int = cv2.INTER_CUBIC
) -> np.ndarray:
    """
    Resize an image to fit inside a square canvas while keeping its aspect ratio.

    The image is scaled so that its longer side fits ``target_size``; the
    remaining space is padded symmetrically (centred) with ``pad_color``.

    Args:
        img (np.ndarray): Input image, 2-D (grayscale) or 3-D (H, W, C).
        target_size (int): Desired width and height of the square output.
        pad_color (int): Intensity used for the padding on every channel
            (default 0 for black).
        interpolation (int): OpenCV interpolation flag used for resizing
            (default ``cv2.INTER_CUBIC``; prefer ``cv2.INTER_NEAREST`` for
            label masks so no intermediate values are created).

    Returns:
        np.ndarray: Square image of size (target_size, target_size), keeping
        the channel axis for multi-channel input.

    Raises:
        ValueError: If ``target_size`` is not positive or ``img`` is empty.
    """
    if target_size <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}")

    h, w = img.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("Cannot resize an empty image")

    # Scale the image to fit inside the target square
    scale = min(target_size / w, target_size / h)
    # Clamp each side to [1, target_size]: very elongated inputs would otherwise
    # round a side down to 0 (which cv2.resize rejects), and float error must
    # never produce a side larger than the canvas (negative padding).
    new_w = min(target_size, max(1, int(w * scale)))
    new_h = min(target_size, max(1, int(h * scale)))
    resized_img = cv2.resize(img, (new_w, new_h), interpolation=interpolation)

    # Calculate padding to center the image
    delta_w = target_size - new_w
    delta_h = target_size - new_h
    top = delta_h // 2
    bottom = delta_h - top
    left = delta_w // 2
    right = delta_w - left

    # OpenCV turns a scalar border value into (v, 0, 0, 0), which would pad a
    # colour image with a single-channel tint; repeat it per channel instead.
    border_value = pad_color if img.ndim == 2 else [pad_color] * img.shape[2]

    # Apply padding
    square_img = cv2.copyMakeBorder(
        resized_img,
        top, bottom, left, right,
        cv2.BORDER_CONSTANT,
        value=border_value
    )

    return square_img


def read_image_frompath(
    mask_paths: list[str],
    size: int | None = None,
    pad_color: int = 0,
    interpolation: int = cv2.INTER_CUBIC
) -> list[np.ndarray]:
    """
    Load every file in ``mask_paths`` as an 8-bit grayscale array.

    Files OpenCV cannot decode are reported with a warning and skipped, so the
    returned list can be shorter than the input.

    Args:
        mask_paths (list[str]): Image file paths (images or masks).
        size (int | None): When given, each image is letter-boxed to
            ``size`` x ``size`` with ``resize_to_square``.
        pad_color (int): Padding intensity forwarded to ``resize_to_square``.
        interpolation (int): OpenCV interpolation flag forwarded to
            ``resize_to_square``.

    Returns:
        list[np.ndarray]: The successfully loaded (and optionally resized) images.
    """
    loaded: list[np.ndarray] = []
    for file_path in mask_paths:
        gray = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            print(f"Warning: Could not read image at {file_path}")
            continue
        loaded.append(
            gray if size is None
            else resize_to_square(gray, size, pad_color, interpolation)
        )
    return loaded


def load_base_unetmodel(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
):
    """Build a segmentation_models_pytorch UNet++ model.

    Args:
        encoder_name: smp encoder identifier (default ``"efficientnet-b7"``).
        encoder_weights: Pretrained encoder weights to load (default
            ``"imagenet"``; ``None`` for random initialisation).
        in_channels: Number of input channels (default 1, i.e. grayscale).
        classes: Number of output channels (default 1, a single binary mask).

    Returns:
        smp.UnetPlusPlus: The constructed model.
    """
    model = smp.UnetPlusPlus(
        encoder_name=encoder_name,        # e.g. "efficientnet-b7" by default
        encoder_weights=encoder_weights,  # pretrained encoder weights
        in_channels=in_channels,          # 1 = single-channel grayscale input
        classes=classes,                  # 1 = single binary mask output
    )
    return model


In [None]:
# Runtime environment selector: 'kaggle' uses the Kaggle input dataset paths,
# any other value uses the Google Drive paths below.
service = 'colab'
# Read in meta data. NOTE(review): this Excel path is on Google Drive and is read
# regardless of `service`, so it will fail on Kaggle as written.
met_df = pd.read_excel('/content/drive/MyDrive/msc_uhasselt/cath_ai_rev_anon.xlsx')
# Shorten the quality-flag column names (header text says 1 = ok, 0 = no).
met_df.rename(columns={'arch (1-ok, 0-no)': 'arch',
                                 'tip (1-ok, 0-no)' : 'tip' }, inplace=True)


if service == 'kaggle':
    # Paths to images and masks inside the Kaggle dataset (note the doubled
    # folder names produced by the dataset upload).
    train_ctheter_mask_dir = '/kaggle/input/thesis-files/train_catheter_masks/train_catheter_masks'
    test_ctheter_mask_dir = '/kaggle/input/thesis-files/test_catheter_masks'
    train_atrium_mask_dir = '/kaggle/input/thesis-files/train_atrium_masks/train_atrium_masks'
    test_atrium_mask_dir = '/kaggle/input/thesis-files/test_atrium_masks'
    original_images_dir = '/kaggle/input/thesis-files/train_images/train_images'
    original_test_images_dir = '/kaggle/input/thesis-files/test_images'
else:
    # Same folders on the mounted Google Drive.
    train_ctheter_mask_dir = '/content/drive/MyDrive/msc_uhasselt/train_catheter_masks'
    test_ctheter_mask_dir = '/content/drive/MyDrive/msc_uhasselt/test_catheter_masks'
    train_atrium_mask_dir = '/content/drive/MyDrive/msc_uhasselt/train_atrium_masks'
    test_atrium_mask_dir = '/content/drive/MyDrive/msc_uhasselt/test_atrium_masks'
    original_images_dir = '/content/drive/MyDrive/msc_uhasselt/train_images'
    original_test_images_dir = '/content/drive/MyDrive/msc_uhasselt/test_images'




def get_filenames(dir_path):
    """Return the bare names (no directory) of the ``.tif`` files in *dir_path*.

    Only the top level of the directory is scanned, the order is whatever
    ``os.listdir`` yields, and the suffix match is case-sensitive.
    """
    return list(filter(lambda name: name.endswith('.tif'), os.listdir(dir_path)))

def get_ids(filenames):
    """Return the set of image IDs: each file name with its final extension removed.

    ``'IMG-0025-00001.tif'`` -> ``'IMG-0025-00001'``. Only the last extension
    is stripped, so names containing extra dots (``'scan.v2.tif'`` ->
    ``'scan.v2'``) still round-trip through ``get_full_paths``, which appends
    ``'.tif'`` to each ID. Duplicates collapse because a set is returned.
    """
    return {os.path.splitext(name)[0] for name in filenames}

def get_full_paths(dir_path, filenames):
    """Turn image IDs into ``.tif`` file paths inside *dir_path*.

    Despite its name, *filenames* holds extension-less IDs such as
    ``'IMG-0025-00001'``; ``'.tif'`` is appended to each one. Input order is kept.
    """
    def _to_path(img_id):
        return os.path.join(dir_path, "{}.tif".format(img_id))

    return list(map(_to_path, filenames))


# Get filenames (not full paths yet)
catheter_labels = get_filenames(train_ctheter_mask_dir)
atrial_labels = get_filenames(train_atrium_mask_dir)
test_catheter_labels = get_filenames(test_ctheter_mask_dir)
test_atrial_labels = get_filenames(test_atrium_mask_dir)

# Extract IDs. Files look like 'IMG-0025-00001.tif' and the ID is the stem
# ('IMG-0025-00001'). Sets make the intersections below straightforward.
ids_cath = get_ids(catheter_labels)
ids_atria = get_ids(atrial_labels)
test_cath_ids = get_ids(test_catheter_labels)
test_atria_ids = get_ids(test_atrial_labels)

# Keep only images that have BOTH a catheter and an atrium mask; sorting gives a
# reproducible order for the path lists built below.
common_ids = sorted(ids_cath & ids_atria)

# IMG-0273-00001 returns blank predictions for both masks (very poor image
# quality), so it is excluded from the test set. A set difference is used
# instead of list.remove(), which raised ValueError whenever the ID was absent.
excluded_test_ids = {'IMG-0273-00001'}
common_ids_test = sorted((test_cath_ids & test_atria_ids) - excluded_test_ids)

# Now build matched paths (index i of every list refers to the same ID)
valid_cathetr_paths = get_full_paths(train_ctheter_mask_dir, common_ids)
valid_atrial_paths = get_full_paths(train_atrium_mask_dir, common_ids)
valid_test_cathetr_paths = get_full_paths(test_ctheter_mask_dir, common_ids_test)
valid_test_atrial_paths = get_full_paths(test_atrium_mask_dir, common_ids_test)
valid_images_paths = get_full_paths(original_images_dir, common_ids)
valid_test_paths = get_full_paths(original_test_images_dir, common_ids_test)

print(f"Number of images: {len(valid_images_paths)}")
print(f"Number of catheter masks: {len(valid_cathetr_paths)}")
print(f"Number of atrial masks: {len(valid_atrial_paths)}")
print(f"Number of test images: {len(valid_test_paths)}")
print(f"Number of test catheter masks: {len(valid_test_cathetr_paths)}")
print(f"Number of test atrial masks: {len(valid_test_atrial_paths)}")


In [None]:
import os
import requests

def download_file(url, local_filename, timeout=60, chunk_size=8192):
    """Download *url* to *local_filename* and return *local_filename*.

    The response body is streamed into ``<local_filename>.part`` and only moved
    into place (atomically, via ``os.replace``) once the transfer completes.
    An interrupted or failed download therefore never leaves a truncated file
    that a later ``os.path.exists`` check would mistake for a valid copy.

    Args:
        url: HTTP(S) URL to fetch.
        local_filename: Destination path.
        timeout: Seconds passed to ``requests.get``: the limit for connecting
            and for each wait between received bytes. ``None`` waits forever.
        chunk_size: Bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: For non-2xx responses.
        requests.RequestException / OSError: For network or file-system errors.
    """
    tmp_path = f"{local_filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(tmp_path, local_filename)
    except BaseException:
        # Remove the partial download before propagating the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return local_filename





def save_file_locally(url, local_file_name):
    """Fetch *url* into *local_file_name* unless that file is already present.

    Progress and failures are reported on stdout instead of being raised, so
    one unreachable file does not abort the whole notebook cell.
    """
    if os.path.exists(local_file_name):
        print(f"File {local_file_name} already exists")
        return

    print(f"File {local_file_name} not found. Downloading from GitHub...")
    try:
        download_file(url, local_file_name)
    except Exception as err:
        print(f"Failed to download file: {err}")
    else:
        print(f"Successfully downloaded {local_file_name}")



# Model-definition modules fetched from GitHub into the working directory so the
# following cell can import them. Each value is (raw URL, local file name).
file_info = {
    'layers' : ("https://raw.githubusercontent.com/Blaise-bf/unet/refs/heads/master/models/layers.py", 'layers.py'),
    'model'  : ('https://raw.githubusercontent.com/Blaise-bf/unet/refs/heads/master/models/UNet_3Plus.py', 'unet_3plus.py'),
    'weights' : ('https://raw.githubusercontent.com/Blaise-bf/unet/refs/heads/master/models/init_weights.py', 'init_weights.py'),
    'atention_unet' : ('https://raw.githubusercontent.com/Blaise-bf/thesis-files/refs/heads/main/attention_unet.py', 'attention_unet.py'),
    'unet_2plus' : ('https://raw.githubusercontent.com/Blaise-bf/unet/refs/heads/master/models/UNet_2Plus.py', 'unet_2plus.py'),
    'models':      ('https://raw.githubusercontent.com/Blaise-bf/thesis-files/refs/heads/main/models.py', 'models.py'),
}

for url, local_name in file_info.values():
    save_file_locally(url, local_name)


In [None]:
from unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup
# from unet_2plus import UNet_2Plus
from attention_unet import Attention_UNet

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("Using device: " + str(device))

In [None]:
# python version
!python --version

# pytorch version
print(torch.__version__)

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import cv2
from typing import List, Optional, Tuple

# from segmentation_models_pytorch.losses import DiceLoss|
# from segmentation_models_pytorch.metrics import DiceScore


class SegmentationDataset(Dataset):
    """Paired image/mask dataset for binary segmentation of grayscale images.

    Each item is read from disk as 8-bit grayscale, the image is optionally
    contrast-enhanced with CLAHE, both are letter-boxed to
    ``image_size`` x ``image_size`` and the mask is binarised (> 0 -> 1.0).

    Args:
        image_paths: Paths to the input images.
        mask_paths: Paths to the masks, aligned index-by-index with ``image_paths``.
        image_size: Side length of the square output.
        transform: Albumentations-style callable invoked as
            ``transform(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` tensors (e.g. ending in ``ToTensorV2``).
            When ``None``, the image is converted with ``ToTensor`` +
            ``Normalize(0.5, 0.5)`` and the mask with ``torch.from_numpy``.
        apply_clahe: Apply CLAHE to the image (never to the mask).
        clahe_limit: CLAHE clip limit.
        clahe_tile_grid_size: CLAHE tile grid size.

    Raises:
        ValueError: If the number of image and mask paths differ.
    """

    def __init__(
        self,
        image_paths: List[str],
        mask_paths: List[str],
        image_size: int = 720,
        transform: Optional[transforms.Compose] = None,
        apply_clahe: bool = True,
        clahe_limit: float = 2.0,
        clahe_tile_grid_size: Tuple[int, int] = (8, 8)
    ):
        # Validate inputs
        if len(image_paths) != len(mask_paths):
            raise ValueError("Number of images and masks must be equal")

        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.image_size = image_size
        self.transform = transform
        self.apply_clahe = apply_clahe
        self.clahe_limit = clahe_limit
        self.clahe_tile_grid_size = clahe_tile_grid_size

        # Fallback image transform, used only when `transform` is None:
        # uint8 [0, 255] -> float [0, 1] -> [-1, 1].
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image_tensor, mask_tensor)``; the mask always has a leading channel dim."""
        image = self._read_grayscale(self.image_paths[idx], "image")
        mask = self._read_grayscale(self.mask_paths[idx], "mask")

        # Apply CLAHE for contrast enhancement if enabled (image only).
        if self.apply_clahe:
            clahe = cv2.createCLAHE(clipLimit=self.clahe_limit, tileGridSize=self.clahe_tile_grid_size)
            image = clahe.apply(image)

        # Bilinear for the image; nearest-neighbour for the mask, so resizing
        # does not create a fringe of intermediate values that the `> 0`
        # threshold below would turn into foreground (dilating every object).
        image = self._resize_to_square(image, self.image_size)
        mask = self._resize_to_square(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        # Convert mask to binary (0 or 1); float32 for PyTorch losses.
        mask = (mask > 0).astype(np.float32)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image_tensor = augmented['image']
            mask_tensor = augmented['mask']
        else:
            image_tensor = self.normalize(image)
            mask_tensor = torch.from_numpy(mask)

        # Add channel dimension: (H, W) -> (1, H, W)
        if mask_tensor.ndim == 2:
            mask_tensor = mask_tensor.unsqueeze(0)

        return image_tensor, mask_tensor

    @staticmethod
    def _read_grayscale(path: str, kind: str) -> np.ndarray:
        """Read *path* as 8-bit grayscale, raising instead of returning None on failure."""
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if array is None:
            raise FileNotFoundError(f"Could not read {kind} at {path}")
        return array

    @staticmethod
    def _resize_to_square(
        image: np.ndarray,
        size: int,
        interpolation: int = cv2.INTER_LINEAR
    ) -> np.ndarray:
        """Resize to a ``size`` x ``size`` square, keeping aspect ratio and zero-padding.

        The longer side is scaled to ``size`` and the shorter side is centred with
        zero padding. ``interpolation`` defaults to bilinear, which is what
        ``cv2.resize`` uses when no flag is given.
        """
        h, w = image.shape[:2]
        scale = size / max(h, w)
        # max(1, ...) keeps extremely elongated inputs from collapsing to 0 px.
        new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))

        resized = cv2.resize(image, (new_w, new_h), interpolation=interpolation)

        # Pad to make square
        delta_w = size - new_w
        delta_h = size - new_h
        top = delta_h // 2
        bottom = delta_h - top
        left = delta_w // 2
        right = delta_w - left

        return cv2.copyMakeBorder(
            resized,
            top, bottom, left, right,
            cv2.BORDER_CONSTANT,
            value=0
        )

In [None]:
from math import pi
from torch.utils.data import DataLoader
import albumentations as A
# Example transforms for data augmentation

import albumentations as A
# Side length (pixels) of the square images/masks fed to the network.
# (Name kept as-is; other cells may refer to `SZIE`.)
SZIE = 960
# Training augmentations; Albumentations applies spatial transforms to both
# image and mask, intensity transforms to the image only.
train_transform = A.Compose([
    # --- Spatial Augmentations ---
    A.HorizontalFlip(p=0.3),  # Safe for atrial anatomy
    A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),

    # Finer GridDistortion (controlled warping)


    # At most one of the two warps below is applied (40% of the time);
    # inside OneOf the child `p` values act as selection weights.
    A.OneOf([
        A.ElasticTransform(
                alpha=300,
                sigma=10,
                interpolation=cv2.INTER_LINEAR,
                approximate=False,
                same_dxdy=True,
                mask_interpolation=cv2.INTER_NEAREST,  # keep the mask binary
                noise_distribution="gaussian",
                keypoint_remapping_method="mask",
                border_mode=cv2.BORDER_CONSTANT,
                fill=0,
                fill_mask=0,
                p=0.3
),
        A.GridDistortion(distort_limit=0.1, p=0.3)  # Even subtler variant
    ], p=0.4),

    # --- Intensity Augmentations ---
    # CLAHE is already applied inside SegmentationDataset (apply_clahe=True),
    # which is presumably why this one is disabled.
    # A.CLAHE(clip_limit=2.0, p=0.3),
    A.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.1),
        contrast_limit=(-0.1, 0.1),
        p=0.4
    ),
    A.GaussNoise(std_range=(0.01, 0.05), p=0.2),
    # --- Normalization ---
    # `(0.485)` is a plain float, not a 1-tuple; these look like the ImageNet
    # red-channel statistics applied to grayscale input — TODO confirm intent.
    A.Normalize(mean=(0.485), std=(0.229)),
    A.ToTensorV2(),
])

# Validation: same normalisation, no augmentation.
val_transform = A.Compose([
    A.Normalize(mean=(0.485), std=(0.229)),
    A.ToTensorV2(),
])



# Catheter-mask datasets (train split).
train_dataset = SegmentationDataset(
    image_paths=valid_images_paths,
    mask_paths=valid_cathetr_paths,
    image_size=SZIE,
    transform=train_transform,
    apply_clahe=True
)

# NOTE(review): the validation set is built from the *test* split
# (valid_test_paths / valid_test_cathetr_paths), and Trainer keeps the checkpoint
# with the best validation Dice. Test-set metrics of that checkpoint are
# therefore optimistically biased; consider holding out part of the training set.
val_dataset = SegmentationDataset(
    image_paths=valid_test_paths,
    mask_paths=valid_test_cathetr_paths,
    image_size=SZIE,
    transform=val_transform,
    apply_clahe=True
)

train_data_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,  pin_memory=True)


In [None]:
import matplotlib.pyplot as plt

img, mask = val_dataset[0]
print(img.shape, mask.shape)

# Show the first validation sample next to its mask. Both tensors carry a
# leading channel axis, which squeeze() drops so imshow gets a 2-D array.
for position, (tensor, title) in enumerate(((img, 'Image'), (mask, 'Mask')), start=1):
    plt.subplot(1, 2, position)
    plt.imshow(tensor.squeeze().numpy(), cmap='gray')
    plt.title(title)
    plt.axis('off')

plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class HybridSegmentationLoss(nn.Module):
    def __init__(
        self,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
        window_size: int = 11,
        msssim_weights: Tensor = None,
        eps: float = 1e-8
    ):
        super().__init__()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.window_size = window_size
        self.eps = eps

        # Default weights for MS-SSIM (5 scales)
        if msssim_weights is None:
            msssim_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
        self.register_buffer('msssim_weights', msssim_weights)

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        # Ensure float32 for mixed precision
        pred = pred.float()
        target = target.float()

        # Focal Loss (pixel-level)
        focal_loss = self.focal_loss_with_logits(pred, target)

        # Convert logits to probabilities
        pred_prob = torch.sigmoid(pred)

        # MS-SSIM Loss (patch-level)
        ms_ssim_loss = self.ms_ssim_loss(pred_prob, target)

        # IoU Loss (map-level)
        iou_loss = self.iou_loss(pred_prob, target)

        # Combine losses (Eq. 6 in paper)
        return focal_loss + ms_ssim_loss + iou_loss

    def focal_loss_with_logits(self, pred: Tensor, target: Tensor) -> Tensor:
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        pt = torch.exp(-bce)  # p if target=1, 1-p otherwise

        # Alpha-balanced focal loss (Eq. in paper references [10])
        alpha_t = torch.where(target == 1,
                             self.focal_alpha,
                             1 - self.focal_alpha)
        focal_loss = alpha_t * (1 - pt) ** self.focal_gamma * bce
        return focal_loss.mean()

    def ms_ssim_loss(self, pred: Tensor, target: Tensor) -> Tensor:
        # MS-SSIM returns a similarity score [0,1] → convert to loss
        ms_ssim_val = self.calc_ms_ssim(pred, target)
        return 1.0 - ms_ssim_val

    def iou_loss(self, pred: Tensor, target: Tensor) -> Tensor:
        intersection = (pred * target).sum(dim=(1, 2, 3))
        union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - intersection
        iou = (intersection + self.eps) / (union + self.eps)
        return (1.0 - iou).mean()

    def calc_ms_ssim(self, pred: Tensor, target: Tensor) -> Tensor:
        device = pred.device
        weights = self.msssim_weights.to(device)
        levels = weights.size(0)

        # Initialize SSIM values
        ssim_per_level = []
        cs_per_level = []

        for i in range(levels):
            # Calculate SSIM at current scale
            ssim_val, cs_val = self.calc_ssim(pred, target, full=True)
            ssim_per_level.append(ssim_val)
            cs_per_level.append(cs_val)

            # Downsample for next level
            if i < levels - 1:
                pred = F.avg_pool2d(pred, kernel_size=2)
                target = F.avg_pool2d(target, kernel_size=2)

        # Combine results (Eq. 5 in paper)
        ms_ssim = torch.ones_like(ssim_per_level[0])
        for i in range(levels):
            if i < levels - 1:
                ms_ssim *= cs_per_level[i] ** weights[i]
            else:
                ms_ssim *= ssim_per_level[i] ** weights[i]

        return ms_ssim.mean()

    def calc_ssim(
        self,
        pred: Tensor,
        target: Tensor,
        window: Tensor = None,
        full: bool = False,
        val_range: float = None
    ) -> Tensor:
        # Determine value range
        if val_range is None:
            max_val = torch.max(pred).item()
            val_range = max_val if max_val > 1 else 1.0

        # Create window if needed
        c = pred.size(1)
        if window is None:
            real_size = min(self.window_size, pred.shape[2], pred.shape[3])
            window = self.create_window(real_size, c).to(pred.device)

        # Calculate means
        mu_pred = F.conv2d(pred, window, groups=c)
        mu_target = F.conv2d(target, window, groups=c)

        # Calculate variances and covariances
        mu_pred_sq = mu_pred.pow(2)
        mu_target_sq = mu_target.pow(2)
        mu_pred_target = mu_pred * mu_target

        sigma_pred_sq = F.conv2d(pred * pred, window, groups=c) - mu_pred_sq
        sigma_target_sq = F.conv2d(target * target, window, groups=c) - mu_target_sq
        sigma_pred_target = F.conv2d(pred * target, window, groups=c) - mu_pred_target

        # SSIM constants
        c1 = (0.01 * val_range) ** 2
        c2 = (0.03 * val_range) ** 2

        # Contrast sensitivity (CS)
        cs_map = (2 * sigma_pred_target + c2) / (sigma_pred_sq + sigma_target_sq + c2)

        # SSIM map
        ssim_map = ((2 * mu_pred_target + c1) / (mu_pred_sq + mu_target_sq + c1)) * cs_map

        if full:
            return ssim_map.mean(), cs_map.mean()
        return ssim_map.mean()

    def create_window(self, size: int, channel: int = 1) -> Tensor:
        # Create 1D Gaussian kernel
        sigma = 1.5
        coords = torch.arange(size).float()
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()

        # Create 2D window
        g_2d = g.unsqueeze(1) @ g.unsqueeze(0)
        g_2d = g_2d.unsqueeze(0).unsqueeze(0)
        return g_2d.expand(channel, 1, size, size).contiguous()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import os
from datetime import datetime
from torch.cuda.amp import GradScaler, autocast
import segmentation_models_pytorch as smp # Ensure smp is imported

from segmentation_models_pytorch.metrics.functional import iou_score

# from torchmetrics import DiceScore

# All previous imports remain the same

class SegmentationLoss(nn.Module):
    """Weighted mix of BCE, soft Dice and focal loss, all computed from logits.

    The weights are ``1 - alpha - beta`` for BCE, ``alpha`` for Dice and
    ``beta`` for focal loss. With the defaults (0.5 / 0.5) the BCE term
    therefore carries zero weight. Dice is computed over the whole batch at once.

    Args:
        alpha: Dice weight.
        beta: Focal weight.
        gamma: Focal focusing exponent.
        smooth: Additive smoothing for the Dice ratio.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=2.0, smooth=1e-6):
        super().__init__()
        self.alpha, self.beta, self.gamma, self.smooth = alpha, beta, gamma, smooth

    def forward(self, pred, target):
        # Promote to float32 so the loss is stable under mixed precision.
        logits = pred.float()
        labels = target.float()

        bce_term = F.binary_cross_entropy_with_logits(logits, labels)

        probs = torch.sigmoid(logits)
        overlap = torch.sum(probs * labels)
        denominator = torch.sum(probs) + torch.sum(labels) + self.smooth
        dice_term = 1 - (2. * overlap + self.smooth) / denominator

        focal_term = self._focal_loss(logits, labels)

        bce_weight = 1 - self.alpha - self.beta
        return bce_weight * bce_term + self.alpha * dice_term + self.beta * focal_term

    def _focal_loss(self, pred, target):
        """Unweighted focal loss: mean of ``(1 - pt)**gamma * BCE`` per pixel."""
        per_pixel = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        # Clamping the BCE before exp() bounds pt away from 0 for extreme logits.
        pt = torch.exp(-per_pixel.clamp(min=-100, max=50))
        modulating = (1 - pt) ** self.gamma
        return (modulating * per_pixel).mean()


def free_gpu_memory():
    """Free unreachable Python objects, then return cached CUDA memory to the driver.

    Garbage collection has to run first: tensors kept alive only by reference
    cycles are released by ``gc.collect()``, and only after that can
    ``torch.cuda.empty_cache()`` hand their blocks back. Emptying the cache
    first, as before, missed that memory.
    """
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

class Trainer:
    """Train and validate a binary segmentation model with AMP, warmup + cosine LR.

    Each epoch runs one pass over ``train_loader`` and one over ``val_loader``,
    steps the LR scheduler, records loss/Dice/IoU/LR, refreshes the metric plots
    and, when validation Dice improves, saves ``best_model.pth`` under
    ``experiments/<name>_<timestamp>/models``. The model is expected to output
    logits; sigmoid + 0.5 threshold is applied for the metrics.

    Args:
        model: Segmentation network.
        train_loader, val_loader: Loaders yielding ``(images, masks)`` batches.
        device: Device the model and batches are moved to.
        num_epochs: Number of epochs (the first 5 form a linear LR warmup).
        experiment_name: Prefix for the experiment directory.
        lr, weight_decay: AdamW hyper-parameters.
        loss_type: 'hybrid' (Dice + Focal), 'focal', 'custom' (SegmentationLoss),
            'dice' or 'bced' (Dice + BCE); any other value falls back to
            SoftBCEWithLogitsLoss.
    """

    def __init__(self, model, train_loader, val_loader,
                 device, num_epochs=100, experiment_name="",
                 lr=0.0001, weight_decay=0.001, loss_type='hybrid'): # Renamed 'loss' to 'loss_type' to avoid conflict
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_epochs = num_epochs

        # Setup experiment directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.experiment_dir = f"experiments/{experiment_name}_{timestamp}" if experiment_name else f"experiments/exp_{timestamp}"
        os.makedirs(f"{self.experiment_dir}/plots", exist_ok=True)
        os.makedirs(f"{self.experiment_dir}/models", exist_ok=True)

        # Optimizer with filtered parameters
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=lr,
            weight_decay=weight_decay
        )

        # Mixed precision training
        self.scaler = GradScaler()

        # LR scheduling with warmup. T_max is kept >= 1: with num_epochs == 5 the
        # cosine phase starts on the last step and T_max == 0 would raise
        # ZeroDivisionError inside CosineAnnealingLR.
        warmup_epochs = 5
        self.scheduler = SequentialLR(
            self.optimizer,
            schedulers=[
                LinearLR(self.optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs),
                CosineAnnealingLR(self.optimizer, T_max=max(1, num_epochs - warmup_epochs), eta_min=1e-6)
            ],
            milestones=[warmup_epochs]
        )

        # Define loss functions based on loss_type
        self.loss_type = loss_type
        if self.loss_type == 'hybrid':
            # Individual components; combined in train_epoch / validate_epoch.
            self.dice_loss = smp.losses.DiceLoss(mode='binary')
            self.focal_loss = smp.losses.FocalLoss(mode='binary')
            self.criterion_name = "Dice + Focal"
        elif self.loss_type == 'focal':
            self.criterion = smp.losses.FocalLoss(mode='binary')
            self.criterion_name = "Focal Loss"
        elif self.loss_type == 'custom':
            self.criterion = SegmentationLoss()
            self.criterion_name = "Custom Loss"
        elif self.loss_type == 'dice':
            self.criterion = smp.losses.DiceLoss(mode='binary')
            self.criterion_name = "Dice Loss"
        elif self.loss_type == 'bced':
            self.dice_loss = smp.losses.DiceLoss(mode='binary')
            self.bce_loss = smp.losses.SoftBCEWithLogitsLoss()
            self.criterion_name = "Dice + BCE"
        else: # Default or soft_bce
            self.criterion = smp.losses.SoftBCEWithLogitsLoss()
            self.criterion_name = "SoftBCEWithLogitsLoss"

        # Track metrics
        self.train_loss = []
        self.val_loss = []
        self.val_dice = []
        self.val_iou = []
        self.lr_history = []
        self.best_dice = 0.0
        self.best_iou = 0.0
        # CPU snapshot of the weights with the best validation Dice (None until
        # some epoch beats the initial best_dice).
        self.best_state_dict = None

    def _compute_loss(self, outputs, masks):
        """Apply the configured loss to logits ``outputs`` and targets ``masks``."""
        if self.loss_type == 'hybrid':
            return self.dice_loss(outputs, masks) + self.focal_loss(outputs, masks)
        if self.loss_type == 'bced':
            return self.dice_loss(outputs, masks) + self.bce_loss(outputs, masks)
        return self.criterion(outputs, masks)

    def train_epoch(self, epoch):
        """Run one training pass; return the mean per-batch loss."""
        self.model.train()
        epoch_loss = 0.0

        with tqdm(self.train_loader, unit="batch", desc=f"Epoch {epoch+1}/{self.num_epochs} [Train]") as pbar:
            for images, masks in pbar:
                images = images.to(self.device, non_blocking=True)
                masks = masks.to(self.device, non_blocking=True)

                # Mixed precision forward
                with autocast():
                    outputs = self.model(images)
                    loss = self._compute_loss(outputs, masks)

                # Backward pass
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += loss.item()
                pbar.set_postfix({"loss": f"{loss.item():.4f}", "lr": f"{self.optimizer.param_groups[0]['lr']:.2e}"})

        return epoch_loss / len(self.train_loader)

    def validate_epoch(self, epoch):
        """Run one validation pass; return (mean loss, mean Dice, mean IoU) over batches."""
        self.model.eval()
        val_loss = 0.0
        dice_scores = []
        iou_scores = []

        with tqdm(self.val_loader, unit="batch", desc=f"Epoch {epoch+1}/{self.num_epochs} [Val]") as pbar:
            with torch.no_grad():
                for images, masks in pbar:
                    images = images.to(self.device, non_blocking=True)
                    masks = masks.to(self.device, non_blocking=True)

                    with autocast():
                        outputs = self.model(images)
                        loss = self._compute_loss(outputs, masks)

                    val_loss += loss.item()

                    # Metrics on CPU; converted to Python floats so the running
                    # lists hold plain numbers rather than 0-d tensors.
                    preds = torch.sigmoid(outputs).cpu()
                    masks_cpu = masks.cpu()
                    dice = float(self._calculate_dice(preds, masks_cpu))
                    iou = float(self._calculate_iou(preds, masks_cpu))
                    dice_scores.append(dice)
                    iou_scores.append(iou)

                    pbar.set_postfix({"val_loss": f"{loss.item():.4f}", "dice": f"{dice:.4f}", "iou": f"{iou:.4f}"})

        mean_dice = np.mean(dice_scores)
        mean_iou = np.mean(iou_scores)
        return val_loss / len(self.val_loader), mean_dice, mean_iou

    def _calculate_dice(self, pred, target, smooth=1e-6):
        """Dice of ``pred > 0.5`` vs ``target`` over the whole batch (0-d tensor)."""
        pred = (pred.cpu() > 0.5).float()
        target = target.cpu()
        intersection = (pred * target).sum()
        return (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

    def _calculate_iou(self, pred, target, smooth=1e-6):
        """IoU of ``pred > 0.5`` vs ``target`` over the whole batch (0-d tensor).

        Smoothed like ``_calculate_dice``: an empty prediction on an empty mask
        scores 1 instead of 0/0 = NaN, which would poison the epoch mean.
        """
        pred = (pred.cpu() > 0.5).float()
        target = target.cpu()
        intersection = (pred * target).sum()
        union = pred.sum() + target.sum() - intersection
        return (intersection + smooth) / (union + smooth)

    def _cpu_state_dict(self):
        """Return a detached CPU copy of the model weights, unaffected by later training."""
        state = self.model.state_dict()
        snapshot = type(state)(
            (key, value.detach().cpu().clone() if torch.is_tensor(value) else value)
            for key, value in state.items()
        )
        # Preserve per-module version metadata used by load_state_dict.
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            snapshot._metadata = metadata
        return snapshot

    def _update_plots(self, epoch):
        """Update training metrics visualization"""
        plt.figure(figsize=(18, 6))

        # Loss plot
        plt.subplot(1, 3, 1)
        plt.plot(self.train_loss, label='Train', color='blue')
        plt.plot(self.val_loss, label='Val', color='red')
        plt.title('Training & Validation Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Dice plot
        plt.subplot(1, 3, 2)
        plt.plot(self.val_dice, label='Val Dice', color='green')
        plt.title('Validation Dice Score')
        plt.xlabel('Epochs')
        plt.ylabel('Dice')
        plt.legend()
        plt.grid(True)

        # LR plot
        plt.subplot(1, 3, 3)
        plt.plot(self.lr_history, label='Learning Rate', color='purple')
        plt.title('Learning Rate Schedule')
        plt.xlabel('Epochs')
        plt.ylabel('LR')
        plt.legend()
        plt.grid(True)
        plt.yscale('log')

        plt.tight_layout()
        plt.savefig(f"{self.experiment_dir}/plots/metrics_epoch_{epoch+1}.png", dpi=120)
        plt.close()

    def train(self):
        """Run the full training loop.

        Returns:
            The state dict of the epoch with the best validation Dice (a CPU
            copy), or the final weights if no epoch ever exceeded a Dice of 0.
            ``self.model`` itself keeps the final-epoch weights.
        """
        print(f"\n🚀 Starting training for {self.num_epochs} epochs...")
        print(f"📂 Experiment directory: {self.experiment_dir}")
        print(f"⚡ Using device: {self.device}")
        print(f"🔧 Loss function: {self.criterion_name}\n") # Use criterion_name for printing

        for epoch in tqdm(range(self.num_epochs), desc="Training Progress"):
            train_loss = self.train_epoch(epoch)
            val_loss, val_dice, val_iou = self.validate_epoch(epoch)

            self.scheduler.step()
            self.train_loss.append(train_loss)
            self.val_loss.append(val_loss)
            self.val_dice.append(val_dice)
            self.val_iou.append(val_iou)
            self.lr_history.append(self.optimizer.param_groups[0]['lr'])

            self._update_plots(epoch)

            if val_dice > self.best_dice:
                self.best_dice = val_dice
                self.best_iou = val_iou
                # Snapshot now: model.state_dict() shares storage with the live
                # parameters, so later epochs would silently overwrite it.
                self.best_state_dict = self._cpu_state_dict()
                torch.save({
                    'epoch': epoch+1,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': val_loss,
                    'dice': val_dice,
                    'iou': val_iou,
                }, f"{self.experiment_dir}/models/best_model.pth")
                tqdm.write(f"🎉 New best model: Dice {val_dice:.4f} IOU: {val_iou:.4f} at epoch {epoch+1}")

        self._save_final_artifacts()
        print(f"\n🏁 Training completed! Best Dice: {self.best_dice:.4f}")
        if self.best_state_dict is not None:
            return self.best_state_dict
        return self.model.state_dict()

    def _save_final_artifacts(self):
        """Save all training artifacts"""
        torch.save({
            'epoch': self.num_epochs,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_loss': self.train_loss,
            'val_loss': self.val_loss,
            'val_dice': self.val_dice,
            'val_iou': self.val_iou,
        }, f"{self.experiment_dir}/models/final_model.pth")

        np.savez(
            f"{self.experiment_dir}/plots/training_metrics.npz",
            train_loss=np.array(self.train_loss),
            val_loss=np.array(self.val_loss),
            val_dice=np.array(self.val_dice),
            val_iou=np.array(self.val_iou),
            lr_history=np.array(self.lr_history)
        )

        self._generate_final_plot()

    def _generate_final_plot(self):
        """Generate high-quality final plot"""
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))

        ax1.plot(self.train_loss, label='Train', color='blue')
        ax1.plot(self.val_loss, label='Val', color='red')
        ax1.set_title('Training & Validation Loss')
        ax1.set_xlabel('Epochs')
        ax1.legend()
        ax1.grid(True)

        ax2.plot(self.val_dice, label='Val Dice', color='green')
        ax2.set_title('Validation Dice Score')
        ax2.set_xlabel('Epochs')
        ax2.legend()
        ax2.grid(True)

        ax3.plot(self.lr_history, label='Learning Rate', color='purple')
        ax3.set_title('Learning Rate Schedule')
        ax3.set_xlabel('Epochs')
        ax3.set_yscale('log')
        ax3.legend()
        ax3.grid(True)

        plt.tight_layout()
        plt.savefig(f"{self.experiment_dir}/plots/final_metrics.png", dpi=300)
        plt.savefig(f"{self.experiment_dir}/plots/final_metrics.pdf")
        plt.close()

In [None]:
import warnings

# Suppress all warnings (process-wide) so the training progress output stays readable.
warnings.filterwarnings("ignore")

if __name__ == "__main__":
    # Catheter run: model from load_base_unetmodel(), Dice loss, 100 epochs.
    # train_data_loader / val_data_loader are created in an earlier cell.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model_catheter = load_base_unetmodel()

    trainer = Trainer(
        model=model_catheter,
        device=device,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        experiment_name="catheter_unet_dice",
        num_epochs=100,
        loss_type='dice',
    )
    # train() returns the state dict of the best model it tracked.
    best_weights_state_dict_cath = trainer.train()

In [None]:
import warnings

# Suppress all warnings (process-wide) so the training progress output stays readable.
warnings.filterwarnings("ignore")

if __name__ == "__main__":
    # Catheter run: smp U-Net with an ImageNet-pretrained EfficientNet-B7 encoder,
    # 1 input channel and 1 output class; Dice loss, 100 epochs.
    # train_data_loader / val_data_loader are created in an earlier cell.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model_catheter = smp.Unet(
        encoder_name='efficientnet-b7',
        encoder_weights='imagenet',
        in_channels=1,
        classes=1,
    )

    trainer = Trainer(
        model=model_catheter,
        device=device,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        experiment_name="catheter_unet_simple",
        num_epochs=100,
        loss_type='dice',
    )
    # train() returns the state dict of the best model it tracked.
    best_weights_state_dict_cath = trainer.train()

In [None]:
import warnings
warnings.filterwarnings("ignore")

if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_catheter = load_base_unetmodel()
    # model_catheter = load_base_unetmodel(
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_catheter,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=60,
        experiment_name="catheter_unet",
        loss_type='bced'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_cath = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/catheter_unet_20250616_041515 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet

In [None]:
free_gpu_memory()

In [None]:
import warnings
warnings.filterwarnings("ignore")
train_data_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=2, shuffle=False,  pin_memory=True)


if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_catheter_3p = UNet_3Plus(in_channels=1)
    # model_catheter = load_base_unetmodel(
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_catheter_3p,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=60,
        experiment_name="catheter_unet_3p",
        loss_type='bced'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_cath = trainer.train()


In [None]:
!cp -r /content/experiments/catheter_unet_20250614_100702 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet

In [None]:
import warnings
warnings.filterwarnings("ignore")
train_data_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,  pin_memory=True)

if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_catheter_unet = load_base_unetmodel()
    # model_catheter = load_base_unetmodel(
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_catheter_unet,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=100,
        experiment_name="catheter_unet_bce",
        # lr=0.00001
        loss_type='bce'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_cath = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
free_gpu_memory()

In [None]:
!cp -r /content/experiments/catheter_unet_bce_20250609_180115 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet

In [None]:
import warnings

# Suppress all warnings (process-wide) so the training progress output stays readable.
warnings.filterwarnings("ignore")

# Training batches of 5, validation batches of 4.
train_data_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,  pin_memory=True)

if __name__ == "__main__":
    # Catheter run: model from load_base_unetmodel(), 'custom' loss, 100 epochs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_catheter_unet = load_base_unetmodel()

    trainer = Trainer(
        model=model_catheter_unet,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        device=device,
        num_epochs=100,
        # The experiment name reflects the loss so this run is not filed under the
        # BCE experiment ("catheter_unet_bce") in the experiments folder.
        experiment_name="catheter_unet_custom",
        loss_type='custom',
    )
    # train() returns the state dict of the best model it tracked.
    best_weights_state_dict_cath = trainer.train()

In [None]:
import warnings
warnings.filterwarnings("ignore")
train_data_loader = DataLoader(train_dataset, batch_size=5, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,  pin_memory=True)

if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_catheter_unet_focal = load_base_unetmodel()
    # model_catheter = load_base_unetmodel(
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_catheter_unet_focal,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=100,
        experiment_name="catheter_unet_2plus_no_clahe",
        # lr=0.00001
        loss='focal'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_cath = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/catheter_unet_20250609_053611 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet

In [None]:
import warnings
warnings.filterwarnings("ignore")
train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,  pin_memory=True)

if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_catheter_unet = load_base_unetmodel()
    # model_catheter = load_base_unetmodel(
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_catheter_unet,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=100,
        experiment_name="catheter_unet_2plus_focal_loss",
        loss_type='focal'
        # lr=0.00001
        # hybrid_loss=False
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_cath = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/catheter_unet_2plus_focal_loss_20250609_220416 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet

In [None]:
# Load best catheter model
model_catheter_unet.load_state_dict(best_weights_state_dict_cath)

In [None]:
# Side length (pixels) of the square network input used at inference time.
SIZE = 960

# Inference preprocessing for a square grayscale PIL image:
# resize to (SIZE, SIZE) -> float tensor -> single-channel normalization.
transform_image = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),
    # Trailing commas matter: ``(0.485)`` is just a float, ``(0.485,)`` is the
    # one-element per-channel sequence Normalize documents.
    transforms.Normalize(mean=(0.485,), std=(0.229,)),
])



from torchmetrics.functional.segmentation import dice_score

def calculate_dice_score(y_true, y_pred):
    """Return the micro-averaged Dice score of two mask tensors as a NumPy scalar.

    Thin wrapper around ``torchmetrics.functional.segmentation.dice_score``: the
    result is moved to the CPU and averaged over any remaining dimensions.

    NOTE(review): callers in this notebook pass ``(N, H, W)`` binary masks, while
    ``dice_score`` defaults to ``input_format='one-hot'`` (``(N, C, ...)``), so
    image rows appear to be treated as channels. With ``average='micro'`` the
    per-row sums still combine into the whole-mask Dice. Confirm against the
    installed torchmetrics version.
    """
    per_sample = dice_score(y_true, y_pred, num_classes=2, average='micro')
    return per_sample.cpu().numpy().mean()

def get_binary_masks(masks):
    """Binarize a mask array: strictly positive values become 1.0, all others 0.0.

    Args:
        masks: NumPy array of any shape (e.g. a grayscale mask image).

    Returns:
        A float32 array of the same shape containing only 0.0 and 1.0.
    """
    foreground = masks > 0
    return foreground.astype(np.float32)





def predict_catheter_segmentation(image_path, unet_model,
                                  device,
                                  size=SIZE,
                                  transform=transform_image,
                                  clip=None,
                                  model_name='unet++',
                                  prob=False):
    """Run a segmentation model on a single grayscale image file.

    Pipeline:
        1. Read the file as grayscale.
        2. Optionally apply CLAHE.
        3. Pad/resize to ``size`` x ``size``, preserving aspect ratio.
        4. Apply ``transform`` (or a plain ``ToTensor``).
        5. Run the forward pass in inference mode.
        6. Apply a sigmoid and/or threshold depending on ``model_name`` / ``prob``.

    Args:
        image_path: Path of the image to read.
        unet_model: ``torch.nn.Module`` to run. It is moved to ``device`` and switched
            to eval mode (both affect the caller's module).
        device: Torch device for the forward pass.
        size: Target side length passed to ``resize_to_square``.
        transform: Callable mapping a PIL image to a ``(C, H, W)`` tensor. The
            default ``transform_image`` also resizes to ``(SIZE, SIZE)``, so with a
            different ``size`` the network input is rescaled back to ``SIZE``.
        clip: CLAHE clip limit; a falsy value (``None``/0) disables CLAHE.
        model_name: ``'unet++'`` applies a sigmoid to the model output;
            ``'unet_deep_sup'`` applies a sigmoid to element 0 of the model's
            outputs. Any other value uses the raw output unchanged (presumably for
            models whose output already lies in [0, 1] -- TODO confirm).
        prob: If True return the probability map, otherwise a 0/1 mask
            thresholded at 0.5.

    Returns:
        ``(mask, resized_image)``: the squeezed NumPy mask (or probability map) and
        the padded/resized, optionally CLAHE-enhanced, input image.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising; fail here
        # rather than with an opaque OpenCV error inside CLAHE/resize.
        raise FileNotFoundError(f"Could not read image: {image_path}")

    unet_model = unet_model.to(device)

    # Optional contrast enhancement on the raw grayscale image.
    if clip:
        clahe = cv2.createCLAHE(clipLimit=clip, tileGridSize=(8, 8))
        clahe_image_np = clahe.apply(image)
    else:
        clahe_image_np = image

    # Aspect-preserving resize + padding to a size x size square.
    resized_aspx = resize_to_square(img=clahe_image_np, target_size=size)
    clahe_image_pil = Image.fromarray(resized_aspx)

    # unsqueeze(0) adds the batch dimension: (C, H, W) -> (N, C, H, W).
    to_tensor = transform if transform else transforms.ToTensor()
    image_tensor = to_tensor(clahe_image_pil).unsqueeze(0).to(device)

    unet_model.eval()
    with torch.inference_mode():
        if model_name == 'unet++':
            output = torch.sigmoid(unet_model(image_tensor))
        elif model_name == 'unet_deep_sup':
            # Deep-supervision models return several outputs; use the first one.
            output = torch.sigmoid(unet_model(image_tensor)[0])
        else:
            output = unet_model(image_tensor)
        mask = output if prob else (output > 0.5).float()

    return mask.cpu().squeeze().numpy(), resized_aspx


def convert_to_tensor(np_array):
    """Wrap a NumPy array as a torch tensor that shares the array's memory (no copy)."""
    tensor = torch.from_numpy(np_array)
    return tensor

def get_id_from_path(path: str) -> str:
    """Return the file name of *path* with a trailing ``.tif``/``.tiff`` extension removed.

    Only the extension at the very end is stripped, so occurrences of ``.tif``
    elsewhere in the name are kept (``'my.tifile.tif'`` -> ``'my.tifile'``).
    Names with any other extension are returned unchanged (``'b.png'`` -> ``'b.png'``).
    """
    name = os.path.basename(path)
    # Check the longer suffix first so 'x.tiff' becomes 'x', not 'xf'.
    for ext in ('.tiff', '.tif'):
        if name.endswith(ext):
            return name[:-len(ext)]
    return name


import random
import cv2
import matplotlib.pyplot as plt

def visualize_predictions(model, test_image_paths, test_mask_paths, device, num_samples=5, clip_limit=2, figsize=(20, 10), model_name='unet++', size=SIZE):
    """
    Visualize model predictions by showing input images, ground truth masks, and predicted masks.

    For ``min(num_samples, len(test_image_paths))`` randomly chosen images, draws a
    3-row grid:
        row 0: the (CLAHE-enhanced, resized) input image
        row 1: the binarized ground-truth mask
        row 2: the predicted mask, titled with its Dice score against the ground truth

    Args:
        model: The trained model for making predictions.
        test_image_paths: List of paths to test images.
        test_mask_paths: List of ground-truth mask paths, index-aligned with
            ``test_image_paths``.
        device: Device to run the model on ('cuda' or 'cpu').
        num_samples: Number of random samples to visualize (default: 5).
        clip_limit: CLAHE clip limit for the input image (default: 2).
        figsize: Size of the matplotlib figure (default: (20, 10)).
        model_name: Output-handling mode forwarded to ``predict_catheter_segmentation``.
        size: Side length the image and the ground-truth mask are resized to.

    Raises:
        ValueError: If no image paths are given or the two path lists differ in length.
        FileNotFoundError: If a ground-truth mask cannot be read.
    """
    n_images = len(test_image_paths)
    if n_images == 0:
        raise ValueError("test_image_paths is empty; nothing to visualize")
    if len(test_mask_paths) != n_images:
        raise ValueError(
            f"Got {n_images} image paths but {len(test_mask_paths)} mask paths"
        )

    # Randomly select samples
    sample_indices = random.sample(range(n_images), min(num_samples, n_images))

    # squeeze=False always yields a 2-D (3, n) axes array, even for a single column.
    fig, axes = plt.subplots(3, len(sample_indices), figsize=figsize, squeeze=False)

    for col, idx in enumerate(sample_indices):
        img_path = test_image_paths[idx]
        mask_path = test_mask_paths[idx]

        # Ground truth: resized like the input image, then binarized (> 0 -> 1).
        raw_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if raw_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        actual_mask = resize_to_square(raw_mask, target_size=size)
        binary_mask = get_binary_masks(actual_mask)

        # Forward ``size`` so the displayed input uses the requested resolution.
        predicted_mask, clahe_image = predict_catheter_segmentation(
            img_path, model, device, size=size, clip=clip_limit, model_name=model_name
        )
        if predicted_mask.shape != binary_mask.shape:
            # The default inference transform resizes to (SIZE, SIZE) regardless of
            # ``size``; map the prediction onto the ground-truth grid so Dice
            # compares arrays of equal shape.
            predicted_mask = cv2.resize(
                predicted_mask, binary_mask.shape[::-1], interpolation=cv2.INTER_NEAREST
            )

        dice = calculate_dice_score(
            convert_to_tensor(binary_mask).unsqueeze(0),
            convert_to_tensor(predicted_mask).unsqueeze(0),
        )

        # Input image
        axes[0, col].imshow(clahe_image, cmap='gray')
        axes[0, col].axis('off')

        # Ground truth mask
        axes[1, col].imshow(binary_mask, cmap='gray')
        axes[1, col].set_title("Ground truth", fontsize=12)
        axes[1, col].axis('off')

        # Predicted mask
        axes[2, col].imshow(predicted_mask, cmap='gray')
        axes[2, col].set_title(f'Predicted mask--DSC:{dice:.3f}', fontsize=12)
        axes[2, col].axis('off')

    plt.tight_layout()
    plt.show()



In [None]:
visualize_predictions(model_catheter, valid_test_paths, valid_test_cathetr_paths, device, num_samples=5, clip_limit=2, figsize=(20, 10))

In [None]:
!cp -r /content/experiments/catheter_unet_20250605_075915 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet

In [None]:
from math import pi
from torch.utils.data import DataLoader

import albumentations as A

# Training-time augmentation pipeline, applied jointly to image and mask.
train_transform = A.Compose([
    # Spatial: mild flip/rotation, then (with p=0.4) one of two non-rigid warps.
    A.HorizontalFlip(p=0.3),  # Safe for atrial anatomy
    A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),
    A.OneOf(
        [
            A.ElasticTransform(
                alpha=300,
                sigma=10,
                interpolation=cv2.INTER_LINEAR,
                approximate=False,
                same_dxdy=True,
                mask_interpolation=cv2.INTER_NEAREST,
                noise_distribution="gaussian",
                keypoint_remapping_method="mask",
                border_mode=cv2.BORDER_CONSTANT,
                fill=0,
                fill_mask=0,
                p=0.3,
            ),
            A.GridDistortion(distort_limit=0.1, p=0.3),  # subtler alternative
        ],
        p=0.4,
    ),
    # Intensity: small brightness/contrast jitter and light Gaussian noise.
    A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.4),
    A.GaussNoise(std_range=(0.01, 0.05), p=0.2),
    # Single-channel normalization, then conversion to a tensor.
    A.Normalize(mean=(0.485), std=(0.229)),
    A.ToTensorV2(),
])

# Evaluation pipeline: normalization and tensor conversion only.
val_transform = A.Compose([A.Normalize(mean=(0.485), std=(0.229)), A.ToTensorV2()])

# NOTE(review): the training split below uses val_transform, so train_transform is
# defined but unused and no augmentation is applied. Confirm whether this is the
# intended no-augmentation setting or should be transform=train_transform.
train_dataset_atria_seg = SegmentationDataset(
    image_paths=valid_images_paths,
    mask_paths=valid_atrial_paths,
    transform=val_transform,
    image_size=960,
    apply_clahe=True,
    clahe_limit=2,  # same CLAHE clip limit as the validation split
)

val_dataset_atria_seg = SegmentationDataset(
    image_paths=valid_test_paths,
    mask_paths=valid_test_atrial_paths,
    transform=val_transform,
    image_size=960,
    apply_clahe=True,
    clahe_limit=2,
)

# Training batches of 5 (shuffled), validation batches of 4.
train_data_loader = DataLoader(train_dataset_atria_seg, batch_size=5, shuffle=True, pin_memory=True)
val_data_loader = DataLoader(val_dataset_atria_seg, batch_size=4, shuffle=False, pin_memory=True)



In [None]:

img, mask = train_dataset_atria_seg[0]
print(img.shape, mask.shape)

# Plot the first training sample beside its mask; squeeze() drops the channel dim.
for position, (tensor, title) in enumerate([(img, 'Image'), (mask, 'Mask')], start=1):
    plt.subplot(1, 2, position)
    plt.imshow(tensor.squeeze().numpy(), cmap='gray')
    plt.title(title)
    plt.axis('off')

plt.show()

In [None]:
if __name__ == "__main__":
    # Atrium run: model from load_base_unetmodel(), 'hybrid' loss, 100 epochs.
    # NOTE(review): the experiment name says "focal" but loss_type is 'hybrid' --
    # confirm the intended label.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model_atrium_unet_2plus = load_base_unetmodel()

    trainer = Trainer(
        model=model_atrium_unet_2plus,
        device=device,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        experiment_name="atrium_unet_focal",
        num_epochs=100,
        loss_type='hybrid',
    )
    # train() returns the state dict of the best model it tracked.
    best_weights_state_dict_atrium = trainer.train()

In [None]:
if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_atrium_unet_2plus = load_base_unetmodel()
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_atrium_unet_2plus,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=100,
        experiment_name="atrium_unet-dice",
        loss_type='dice'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_atrium = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/catheter_unet_dice_20250617_084412 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/

In [None]:
free_gpu_memory()

In [None]:
if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_atrium_unet_2plus = load_base_unetmodel()
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_atrium_unet_2plus,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        num_epochs=100,
        experiment_name="atrium_unet_bce_dice",
        loss_type='bced'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_atrium = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/atrium_unet_2plus_no_augmentation_20250614_115800 /content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet

In [None]:
model_atrium_unet_2plus.load_state_dict(best_weights_state_dict_atrium)

In [None]:



class UNet3PlusDeepSupTrainer(Trainer):
    """``Trainer`` variant for UNet 3+ models trained with deep supervision.

    The model's forward pass is expected to return several side outputs
    (d1, d2, ...) that all have the same shape as the mask. The loss is the
    weighted sum of ``self.criterion`` over those outputs, using the weights in
    ``self.ds_weights`` with the first output first. The Dice score is computed
    on ``sigmoid(d1)`` only.

    NOTE(review): ``validate_epoch`` returns ``(loss, dice)``; confirm this matches
    what ``Trainer.train`` unpacks (its progress message also reports an IoU).
    """

    def __init__(self, model, train_loader, val_loader, device, num_epochs=100, experiment_name="", **trainer_kwargs):
        """Create the trainer.

        The named arguments are forwarded positionally to ``Trainer.__init__``.
        Any extra keyword arguments (e.g. ``loss_type``) are forwarded unchanged,
        so the same options as the base ``Trainer`` can be used.
        """
        super().__init__(model, train_loader, val_loader, device, num_epochs, experiment_name, **trainer_kwargs)
        # optimizer, scaler and criterion are provided by Trainer.
        # Loss weight per deep-supervision output, highest resolution (d1) first.
        self.ds_weights = [0.5, 0.25, 0.125, 0.0625, 0.0625]

    @staticmethod
    def _as_output_tuple(outs):
        """Wrap a bare tensor in a 1-tuple so it is not iterated over its batch axis."""
        return (outs,) if torch.is_tensor(outs) else tuple(outs)

    def _deep_sup_loss(self, outs, masks):
        """Return the ``ds_weights``-weighted sum of ``self.criterion`` over ``outs``."""
        if not outs:
            raise ValueError("Model returned no outputs")
        if len(outs) > len(self.ds_weights):
            raise ValueError(
                f"Model returned {len(outs)} outputs but only "
                f"{len(self.ds_weights)} deep-supervision weights are defined"
            )
        total_loss = 0
        for weight, out in zip(self.ds_weights, outs):
            total_loss += weight * self.criterion(out, masks)
        return total_loss

    def train_epoch(self, epoch):
        """Run one mixed-precision training epoch and return the mean batch loss."""
        self.model.train()
        epoch_loss = 0.0

        with tqdm(self.train_loader, unit="batch", desc=f"Epoch {epoch+1}/{self.num_epochs} [Train]") as pbar:
            for images, masks in pbar:
                images = images.to(self.device, non_blocking=True)
                masks = masks.to(self.device, non_blocking=True)

                # Mixed precision forward
                with autocast():
                    outs = self._as_output_tuple(self.model(images))  # d1, d2, ...
                    total_loss = self._deep_sup_loss(outs, masks)

                self.optimizer.zero_grad()
                self.scaler.scale(total_loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()

                epoch_loss += total_loss.item()
                pbar.set_postfix({"loss": f"{total_loss.item():.4f}", "lr": f"{self.optimizer.param_groups[0]['lr']:.2e}"})

        return epoch_loss / len(self.train_loader)

    def validate_epoch(self, epoch):
        """Evaluate without gradients; return ``(mean batch loss, mean batch Dice)``."""
        self.model.eval()
        val_loss = 0.0
        dice_scores = []

        with tqdm(self.val_loader, unit="batch", desc=f"Epoch {epoch+1}/{self.num_epochs} [Val]") as pbar:
            with torch.no_grad():
                for images, masks in pbar:
                    images = images.to(self.device, non_blocking=True)
                    masks = masks.to(self.device, non_blocking=True)

                    with autocast():
                        outs = self._as_output_tuple(self.model(images))
                        total_loss = self._deep_sup_loss(outs, masks)

                    val_loss += total_loss.item()

                    # Dice is measured on the first (highest-resolution) output only.
                    preds = torch.sigmoid(outs[0])
                    dice = self._calculate_dice(preds.cpu(), masks.cpu())
                    dice_scores.append(dice)

                    pbar.set_postfix({"val_loss": f"{total_loss.item():.4f}", "dice": f"{dice:.4f}"})

        mean_dice = np.mean(dice_scores)
        return val_loss / len(self.val_loader), mean_dice


In [None]:
free_gpu_memory()

In [None]:
visualize_predictions(model_atrium_unet_2plus, valid_test_paths, valid_test_atrial_paths, device, num_samples=5, figsize=(20, 10), clip_limit=2)

In [None]:
if __name__ == "__main__":
    # Atrium run: model from load_base_unetmodel(), BCE loss, Trainer's default
    # number of epochs.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model_atrium = load_base_unetmodel()

    trainer = Trainer(
        model=model_atrium,
        device=device,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        experiment_name="atrium_unet_bce_loss",
        loss_type='bce',
    )
    # train() returns the state dict of the best model it tracked.
    best_weights_state_dict_atrium = trainer.train()

In [None]:
if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_atrium = load_base_unetmodel()
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_atrium,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        # num_epochs=150,
        experiment_name="atrium_unet_focal_loss", loss_type='focal'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_atrium = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/atrium_unet_focal_loss_20250610_101136 /content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet

In [None]:
model_atrium.load_state_dict(best_weights_state_dict_atrium)

In [None]:
visualize_predictions(model_atrium, valid_test_paths, valid_test_atrial_paths, device, num_samples=5, figsize=(20, 10), clip_limit=2)

In [None]:

# train_transform = A.Compose([
#     # --- Spatial Augmentations ---
#     A.HorizontalFlip(p=0.3),  # Safe for atrial anatomy

#     # Finer GridDistortion (controlled warping)

#     A.Rotate(limit=15, p=0.4, border_mode=cv2.BORDER_CONSTANT),
#     A.GridDistortion(
#         num_steps=5,           # Smoother transitions (default=5)
#         distort_limit=0.15,    # Reduced from 0.3 (gentler distortion)
#         interpolation=cv2.INTER_LINEAR,
#         border_mode=cv2.BORDER_CONSTANT,
#         p=0.3
#     ),

#     # Optional: Combined with ElasticTransform
#     A.OneOf([
#         A.ElasticTransform(
#             alpha=120,         # Magnitude of distortion (reduced from default)
#             sigma=120 * 0.05,  # Smoothness (sigma=alpha*factor)
#             # alpha_affine=0,    # No longer a parameter - removed
#             p=0.5
#         ),
#         A.GridDistortion(distort_limit=0.1, p=0.5)  # Even subtler variant
#     ], p=0.4),

#     # --- Intensity Augmentations ---
#     # A.CLAHE(clip_limit=2.0, p=0.3),
#     A.RandomBrightnessContrast(
#         brightness_limit=(-0.1, 0.1),
#         contrast_limit=(-0.1, 0.1),
#         p=0.4
#     ),
#     A.GaussNoise(std_range=(0.01, 0.02), p=0.2),
#     # --- Normalization ---
#     A.Normalize(mean=(0.485), std=(0.229)),
#     A.ToTensorV2(),
# ])

# val_transform = A.Compose([
#     A.Normalize(mean=(0.485), std=(0.229)),
#     A.ToTensorV2(),
# ])




# train_dataset = SegmentationDataset(
#     image_paths=valid_images_paths,
#     mask_paths=valid_cathetr_paths,
#     image_size=960,
#     transform=train_transform,
#     apply_clahe=True
# )

# val_dataset = SegmentationDataset(
#     image_paths=valid_test_paths,
#     mask_paths=valid_test_cathetr_paths,
#     image_size=960,
#     transform=val_transform,
#     apply_clahe=True
# )

# train_data_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
# val_data_loader = DataLoader(val_dataset, batch_size=4, shuffle=False,  pin_memory=True)


In [None]:
free_gpu_memory()

In [None]:
if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_atrium = load_base_unetmodel()
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_atrium,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        # num_epochs=150,
        experiment_name="atrium_unet_bce_loss", loss_type='bce'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_atrium = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
if __name__ == "__main__":
    # Initialize components outside the Trainer class's __main__ block
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_atrium = load_base_unetmodel()
    # Data loaders are defined above this __main__ block
    # train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

    # Train
    trainer = Trainer(
        model=model_atrium,
        train_loader=train_data_loader, # Use the data loaders defined outside __main__
        val_loader=val_data_loader,   # Use the data loaders defined outside __main__
        device=device,
        # num_epochs=150,
        experiment_name="atrium_unet_focal_update_loss", loss_type='bce'
    )
    # The train method now returns the best weights state dictionary
    best_weights_state_dict_atrium = trainer.train()

    # Load best weights for inference
    # model1.load_state_dict(best_weights_state_dict)

In [None]:
!cp -r /content/experiments/atrium_unet_focal_update_loss_20250610_151901  /content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet

In [None]:
if __name__ == "__main__":
    # Atrium segmentation run trained with the hybrid loss.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_atrium = load_base_unetmodel()

    # train_data_loader / val_data_loader are created in earlier cells.
    trainer = Trainer(
        model=model_atrium,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        device=device,
        # Named after the loss actually used so it cannot be confused with the BCE
        # run ("atrium_unet_bce_loss"); matches the "atrium_unet_2plus_hybrid_loss_*"
        # checkpoint path referenced in the loading cell.
        experiment_name="atrium_unet_2plus_hybrid_loss",
        loss_type='hybrid',
    )
    # train() returns the best weights' state dict.
    best_weights_state_dict_atrium = trainer.train()

In [None]:
if __name__ == "__main__":
    # Catheter segmentation with the deeply supervised UNet 3+ variant.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_cath_deep_sup = UNet_3Plus_DeepSup(in_channels=1)

    # The loaders (train_data_loader / val_data_loader) come from earlier cells.
    trainer = UNet3PlusDeepSupTrainer(
        model=model_cath_deep_sup,
        train_loader=train_data_loader,
        val_loader=val_data_loader,
        device=device,
        experiment_name="catheter_unet_deep_sup",
    )

    # The trainer hands back the state dict of its best weights.
    best_weights_state_dict_cath_deep_sup = trainer.train()

In [None]:
# Visualise the deep-supervision catheter model on 5 test samples (CLAHE clip limit 3.5).
visualize_predictions(
    model_cath_deep_sup,
    valid_test_paths,
    valid_test_cathetr_paths,
    device,
    num_samples=5,
    clip_limit=3.5,
    figsize=(20, 10),
    deep_sup=True,
)

### Get masks from the trained models

In [None]:
# !cp -r /content/experiments/catheter_unet_deep_sup_20250527_043857 /content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/

In [None]:
# Load the best UNet 3+ deep-supervision atrium checkpoint from Drive.
unet_atrium_model_load = UNet_3Plus_DeepSup(in_channels=1)
# The checkpoint is a dict whose weights live under 'model_state_dict'.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only runtime;
# load_state_dict copies the values into the model's parameters either way.
# weights_only=False unpickles arbitrary objects: only load checkpoints you produced.
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_deep_sup_20250526_235713/models/best_model.pth', map_location='cpu', weights_only=False)

unet_atrium_model_load.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Best UNet 3+ deep-supervision catheter checkpoint.
# map_location='cpu' keeps GPU-saved checkpoints loadable on CPU-only runtimes;
# weights_only=False unpickles arbitrary objects, so only load trusted files.
unet_cath_model_load = UNet_3Plus_DeepSup(in_channels=1)
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_deep_sup_20250527_043857/models/best_model.pth', map_location='cpu', weights_only=False)

unet_cath_model_load.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Base UNet 3+ (no deep supervision) checkpoints for atrium and catheter.
# map_location='cpu' keeps GPU-saved checkpoints loadable on CPU-only runtimes;
# weights_only=False unpickles arbitrary objects, so only load trusted files.
unet_atrium_base_model = UNet_3Plus(in_channels=1)
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_20250526_180651/models/best_model.pth', map_location='cpu', weights_only=False)

unet_atrium_base_model.load_state_dict(checkpoint['model_state_dict'])

unet_cath_base_model = UNet_3Plus(in_channels=1)
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_20250526_131827/models/best_model.pth', map_location='cpu', weights_only=False)

unet_cath_base_model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Checkpoints built with load_base_unetmodel(): catheter (with and, per the
# `_noaug` name, without augmentation), atrium, and the no-CLAHE variants.
# map_location='cpu' keeps GPU-saved checkpoints loadable on CPU-only runtimes;
# weights_only=False unpickles arbitrary objects, so only load trusted files.
unet_2plus_cath = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_2plus_20250606_100927/models/best_model.pth', map_location='cpu', weights_only=False)
unet_2plus_cath.load_state_dict(checkpoint['model_state_dict'])

unet_2plus_cath_noaug = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_2plus_20250605_233907/models/best_model.pth', map_location='cpu', weights_only=False)
unet_2plus_cath_noaug.load_state_dict(checkpoint['model_state_dict'])


unet_2plus_atrium = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_2plus_20250606_073341/models/best_model.pth', map_location='cpu', weights_only=False)
unet_2plus_atrium.load_state_dict(checkpoint['model_state_dict'])


unet_cath_no_clahe = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_2plus_no_clahe_20250608_155846/models/best_model.pth', map_location='cpu', weights_only=False)
unet_cath_no_clahe.load_state_dict(checkpoint['model_state_dict'])

unet_atrium_no_clahe = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_no_clahe_20250608_174402/models/best_model.pth', map_location='cpu', weights_only=False)
unet_atrium_no_clahe.load_state_dict(checkpoint['model_state_dict'])


In [None]:
# Loss-function ablation checkpoints (all built with load_base_unetmodel()).
# map_location='cpu' keeps GPU-saved checkpoints loadable on CPU-only runtimes;
# weights_only=False unpickles arbitrary objects, so only load trusted files.
unet_atrium_bce = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_bce_loss_20250610_115239/models/best_model.pth', map_location='cpu', weights_only=False)
unet_atrium_bce.load_state_dict(checkpoint['model_state_dict'])

unet_cath_bce = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_bce_20250609_180115/models/best_model.pth', map_location='cpu', weights_only=False)
unet_cath_bce.load_state_dict(checkpoint['model_state_dict'])


# NOTE(review): this "focal" atrium run (..._20250610_151901) appears to come from the
# training cell named "atrium_unet_focal_update_loss" that was configured with
# loss_type='bce'. Confirm which loss produced it before reporting it as focal loss.
unet_atrium_focal = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_focal_update_loss_20250610_151901/models/best_model.pth', map_location='cpu', weights_only=False)
unet_atrium_focal.load_state_dict(checkpoint['model_state_dict'])

unet_cath_focal = load_base_unetmodel()
checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/catheter_unet/catheter_unet_2plus_focal_loss_20250609_220416/models/best_model.pth', map_location='cpu', weights_only=False)
unet_cath_focal.load_state_dict(checkpoint['model_state_dict'])

# unet_atrium_hybrid = load_base_unetmodel()
# checkpoint = torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/atrium_unet_2plus_hybrid_loss_20250609_234508/models/best_model.pth', map_location='cpu', weights_only=False)
# unet_atrium_hybrid.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# !cp -r /content/experiments/atrium_unet_focal_loss_20250610_101136 /content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet

In [None]:
# UNet++ catheter predictions on 5 test samples, CLAHE clip limit 2.
visualize_predictions(
    unet_2plus_cath,
    valid_test_paths,
    valid_test_cathetr_paths,
    device,
    num_samples=5,
    clip_limit=2,
    figsize=(20, 12),
    model_name='unet++',
)

In [None]:
# No-CLAHE catheter model on 5 test samples. clip_limit is intentionally not passed.
# NOTE(review): confirm visualize_predictions' default does not apply CLAHE for this model.
visualize_predictions(
    unet_cath_no_clahe,
    valid_test_paths,
    valid_test_cathetr_paths,
    device,
    num_samples=5,
    figsize=(20, 10),
    model_name='unet++',
)

In [None]:
# UNet++ atrium predictions on 5 test samples, CLAHE clip limit 2.
visualize_predictions(
    unet_2plus_atrium,
    valid_test_paths,
    valid_test_atrial_paths,
    device,
    num_samples=5,
    clip_limit=2,
    figsize=(20, 12),
    model_name='unet++',
)

In [None]:
def get_individual_masks(image_paths, unet_model, device, clahe=None, clip=2, model_name='unet_deep_sup'):
    """
    Segment every image in ``image_paths`` with a single model.

    Args:
        image_paths: Paths of the images to segment.
        unet_model: Trained segmentation model.
        device: Device forwarded to ``predict_catheter_segmentation``.
        clahe: When truthy, the preprocessed images returned alongside each mask
            are collected too. It does not toggle preprocessing; ``clip`` is
            forwarded either way.
        clip: CLAHE clip limit forwarded to ``predict_catheter_segmentation``.
        model_name: Architecture tag forwarded to ``predict_catheter_segmentation``.

    Returns:
        tuple: ``(np.array(masks), np.array(clahe_images))`` when ``clahe`` is
        truthy, otherwise ``(np.array(masks), None)``.
    """
    outputs = [
        predict_catheter_segmentation(path, unet_model, device, clip=clip, model_name=model_name)
        for path in image_paths
    ]
    masks = np.array([mask for mask, _ in outputs])
    if not clahe:
        return masks, None
    return masks, np.array([image for _, image in outputs])


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Get masks from the pretrained models
* For both masks (catheter and right-atrial border), predictions from the respective base UNet+++ and UNet+++ with deep supervision models are obtained.


In [None]:
def process_masks(image_paths,
                  unet_cath_model,
                  unet_atrium_model,
                  device,
                  clip=None,
                  clip_atrium=None,
                  model_name='unet_deep_sup',
                  prob=False):
    """
    Predict catheter and atrial masks for a list of images with two trained models.

    Every image is run through ``predict_catheter_segmentation`` twice: once with
    ``unet_cath_model`` (CLAHE clip ``clip``) and once with ``unet_atrium_model``
    (CLAHE clip ``clip_atrium``).

    Args:
        image_paths (Iterable[str]): Paths to the input images.
        unet_cath_model: Trained model for catheter segmentation.
        unet_atrium_model: Trained model for atrial segmentation.
        device: Device forwarded to ``predict_catheter_segmentation``.
        clip (float | None): CLAHE clip limit forwarded for the catheter model.
        clip_atrium (float | None): CLAHE clip limit forwarded for the atrium model.
        model_name (str): Architecture tag forwarded to
            ``predict_catheter_segmentation`` (values used in this notebook:
            'unet3+', 'unet_deep_sup', 'unet++').
        prob (bool): Forwarded to ``predict_catheter_segmentation``. The
            calibration cells set it to obtain probability maps, presumably
            instead of thresholded masks (confirm in that helper).

    Returns:
        tuple: ``(catheter_masks_tensor, atrial_masks_tensor, clahe_images)``.
        The first two are ``torch.Tensor`` objects stacked along a new leading
        (image) axis. ``clahe_images`` is the list of preprocessed images returned
        by the catheter-model pass, one per input path.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    # Local import so the function does not rely on a later cell having run first.
    from tqdm.notebook import tqdm

    catheter_masks = []
    atrial_masks = []
    clahe_images = []

    for img_path in tqdm(image_paths, desc="Processing Images"):
        # Catheter pass: keep both the mask and the preprocessed image.
        cath_mask, clahe_image = predict_catheter_segmentation(
            img_path, unet_cath_model, device, model_name=model_name, prob=prob, clip=clip
        )
        catheter_masks.append(cath_mask)
        clahe_images.append(clahe_image)

        # Atrium pass: only the mask is kept.
        atrial_mask, _ = predict_catheter_segmentation(
            img_path, unet_atrium_model, device, clip=clip_atrium, model_name=model_name, prob=prob
        )
        atrial_masks.append(atrial_mask)

    if not catheter_masks:
        # np.stack would otherwise fail with an opaque "need at least one array" message.
        raise ValueError("process_masks: no images to process (image_paths is empty).")

    # Stack per-image arrays into one array, then convert to tensors.
    catheter_masks_tensor = convert_np_tensor(np.stack(catheter_masks))
    atrial_masks_tensor = convert_np_tensor(np.stack(atrial_masks))

    return catheter_masks_tensor, atrial_masks_tensor, clahe_images

In [None]:
def get_orginal_maks(train_paths, test_paths, size=960):
    """
    Read two lists of mask files at ``size`` via ``read_image_frompath``.

    The (misspelled) name is kept because other cells call it. Despite the
    parameter names it is used for any pair of path lists, e.g. catheter vs.
    atrial masks.

    Returns:
        tuple: ``(read_image_frompath(train_paths, size=size),
        read_image_frompath(test_paths, size=size))``.
    """
    return (read_image_frompath(train_paths, size=size),
            read_image_frompath(test_paths, size=size))


def get_binary_masks(masks):
    """Binarise a mask tensor: strictly positive values become 1.0, all others 0.0."""
    return (masks > 0).float()


def convert_np_tensor(np_array):
    """
    Convert a tensor, NumPy array, or list/tuple of arrays/tensors to a ``torch.Tensor``.

    * ``torch.Tensor``: returned unchanged.
    * ``np.ndarray``: wrapped with ``torch.from_numpy`` (shares memory, no copy).
    * list/tuple: each element is converted with ``torch.as_tensor`` and the
      results are stacked along a new leading axis, so mixed lists of arrays
      and tensors are accepted.

    Raises:
        ValueError: For any other input type.
    """
    if isinstance(np_array, torch.Tensor):
        return np_array
    if isinstance(np_array, np.ndarray):
        return torch.from_numpy(np_array)
    if isinstance(np_array, (list, tuple)):
        return torch.stack([torch.as_tensor(item) for item in np_array])
    raise ValueError(f"Unsupported data type: {type(np_array).__name__}")



In [None]:
from tqdm.notebook import tqdm

# Base UNet+++ (no deep supervision): CLAHE clip 2 for the catheter model and
# 3.5 for the atrium model, on both splits.
train_catheter_tensor_base, train_atrial_tensor_base, train_clahe_images = process_masks(
    valid_images_paths,
    unet_cath_base_model,
    unet_atrium_base_model,
    device,
    model_name='unet3+',
    clip=2,
    clip_atrium=3.5,
)

test_catheter_tensor_base, test_atrial_tensor_base, test_clahe_images = process_masks(
    valid_test_paths,
    unet_cath_base_model,
    unet_atrium_base_model,
    device,
    model_name='unet3+',
    clip=2,
    clip_atrium=3.5,
)

In [None]:
# UNet+++ with deep supervision, same CLAHE clip limits as the base model.
train_catheter_tensor_ds, train_atrial_tensor_ds, _ = process_masks(
    valid_images_paths,
    unet_cath_model_load,
    unet_atrium_model_load,
    device,
    model_name='unet_deep_sup',
    clip=2,
    clip_atrium=3.5,
)
test_catheter_tensor_ds, test_atrial_tensor_ds, _ = process_masks(
    valid_test_paths,
    unet_cath_model_load,
    unet_atrium_model_load,
    device,
    model_name='unet_deep_sup',
    clip=2,
    clip_atrium=3.5,
)


In [None]:
# UNet++ models with process_masks' default clip settings.
train_catheter_tensor_u2plus, train_atrial_tensor_u2plus, _ = process_masks(
    valid_images_paths, unet_2plus_cath, unet_2plus_atrium, device, model_name='unet++'
)
test_catheter_tensor_u2plus, test_atrial_tensor_u2plus, _ = process_masks(
    valid_test_paths, unet_2plus_cath, unet_2plus_atrium, device, model_name='unet++'
)

In [None]:
# Models trained without CLAHE, run with process_masks' default clip settings.
train_catheter_tensor_u2plus_no_clahe, train_atrial_tensor_u2plus_no_clahe, _ = process_masks(
    valid_images_paths, unet_cath_no_clahe, unet_atrium_no_clahe, device, model_name='unet++'
)
test_catheter_tensor_u2plus_no_clahe, test_atrial_tensor_u2plus_no_clahe, _ = process_masks(
    valid_test_paths, unet_cath_no_clahe, unet_atrium_no_clahe, device, model_name='unet++'
)


In [None]:
# Loss ablation (test split only), CLAHE clip 2 for both catheter and atrium.
test_catheter_tensor_bce, test_atrial_tensor_bce, _ = process_masks(
    valid_test_paths,
    unet_cath_bce,
    unet_atrium_bce,
    device,
    model_name='unet++',
    clip=2,
    clip_atrium=2,
)

test_catheter_tensor_focal, test_atrial_tensor_focal, _ = process_masks(
    valid_test_paths,
    unet_cath_focal,
    unet_atrium_focal,
    device,
    model_name='unet++',
    clip=2,
    clip_atrium=2,
)

In [None]:
# !cp -r /content/experiments/atrium_unet_focal_loss_20250610_101136 /content/drive/MyDrive/msc_uhasselt/experiments/atrium_unet/

Load the ground-truth masks for both the catheter and the right-atrial border; these let us compute validation ("test") metrics for the different segmentation models.

In [None]:
# Ground-truth masks for train/test, read at get_orginal_maks' default size (960).
original_catheter_mask_train, original_catheter_mask_test = get_orginal_maks(
    valid_cathetr_paths, valid_test_cathetr_paths
)
original_atrial_mask_train, original_atrial_mask_test = get_orginal_maks(
    valid_atrial_paths, valid_test_atrial_paths
)

In [None]:
# Inspect the shape of the first training catheter mask.
original_catheter_mask_train[0].shape

In [None]:
# Binarised ground truth (strictly positive pixel -> 1.0): catheter first, then atrium.
ground_truth_catheter_tensor_train = get_binary_masks(convert_np_tensor(original_catheter_mask_train))
ground_truth_catheter_tensor_test = get_binary_masks(convert_np_tensor(original_catheter_mask_test))

ground_truth_atrial_tensor_train = get_binary_masks(convert_np_tensor(original_atrial_mask_train))
ground_truth_atrial_tensor_test = get_binary_masks(convert_np_tensor(original_atrial_mask_test))

In [None]:
# Preprocessed images returned by process_masks for the base UNet+++ run,
# train images first, then test images, stacked into one array.
clahe_images = np.stack(train_clahe_images + test_clahe_images)
# A bare expression mid-cell is not displayed by the notebook, so print it.
print(f"Stacked CLAHE images shape: {clahe_images.shape}")

plt.imshow(clahe_images[0], cmap='gray')
plt.axis('off')
plt.show()

## Visually inspect some difficult images with UNet+++ with/without deep supervision

- Does deep supervision help in segmenting difficult regions?

In [None]:
import os

# Test images that are hard to segment, identified by file name up to the first '.'.
difficult_images = ['IMG-0456-00001', 'IMG-0468-00001', 'IMG-0462-00001', 'IMG-0464-00001', 'IMG-0459-00001', 'IMG-0273-00001']
_difficult_ids = set(difficult_images)


def _image_id(path):
    """Return the file name up to its first '.', e.g. '.../IMG-0456-00001.png' -> 'IMG-0456-00001'."""
    return os.path.basename(path).split('.')[0]


# Keep only the paths whose id IS in `difficult_images`; each list keeps its input order.
difficult_path = [p for p in valid_test_paths if _image_id(p) in _difficult_ids]
difficult_catheter_path = [p for p in valid_test_cathetr_paths if _image_id(p) in _difficult_ids]
difficult_atrial_path = [p for p in valid_test_atrial_paths if _image_id(p) in _difficult_ids]

print(f"Number of images: {len(difficult_path)}, catheter masks: {len(difficult_catheter_path)}, "
      f"atrial masks: {len(difficult_atrial_path)}")

# Later cells pair images and masks by index, so the three lists must list the same ids in the same order.
_ids_per_list = [[_image_id(p) for p in paths]
                 for paths in (difficult_path, difficult_catheter_path, difficult_atrial_path)]
if any(ids != _ids_per_list[0] for ids in _ids_per_list[1:]):
    print("Warning: difficult image/mask lists are not aligned by id; index-based comparisons will be wrong.")


In [None]:
# Predictions on the difficult test images. The CLAHE clip limits mirror the full
# test-set evaluation of each model (base / deep-supervision UNet+++: catheter 2,
# atrium 3.5; UNet++: defaults), so these per-image results use the same
# preprocessing as the reported metrics.
diff_cath_tensor_base, diff_atrial_tensor_base, diff_images = process_masks(
    difficult_path,
    unet_cath_base_model,
    unet_atrium_base_model,
    device,
    model_name='unet3+',
    clip=2,
    clip_atrium=3.5,
)
diff_cath_tensor_ds, diff_atrial_tensor_ds, _ = process_masks(
    difficult_path,
    unet_cath_model_load,
    unet_atrium_model_load,
    device,
    model_name='unet_deep_sup',
    clip=2,
    clip_atrium=3.5,
)

diff_cath_tensor_u2plus, diff_atrial_tensor_u2plus, _ = process_masks(
    difficult_path,
    unet_2plus_cath,
    unet_2plus_atrium,
    device,
    model_name='unet++',
)

# Ground truth for the same images: catheter masks first, atrial masks second.
diff_cath_gt, diff_atrial_gt = get_orginal_maks(difficult_catheter_path, difficult_atrial_path)

In [None]:
# Binarise both difficult-case ground truths (strictly positive pixel -> 1.0).
diff_cath_gt, diff_atrial_gt = [
    get_binary_masks(convert_np_tensor(masks)) for masks in (diff_cath_gt, diff_atrial_gt)
]

In [None]:
# Check the Python type of one preprocessed image returned by process_masks.
type(diff_images[0])

In [None]:
from torchmetrics.functional.segmentation import dice_score, mean_iou

def _binary_overlap_stats(y_true, y_pred):
    """
    Per-image overlap counts for two batches of binary masks.

    Args:
        y_true, y_pred (torch.Tensor): Masks of identical shape ``(N, ...)``;
            strictly positive values count as foreground (same rule as
            ``get_binary_masks``).

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: float tensors of shape
        ``(N,)`` with |true & pred|, |true| and |pred| for each image.

    Raises:
        ValueError: If the two shapes differ.
    """
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {tuple(y_true.shape)} vs y_pred {tuple(y_pred.shape)}"
        )
    true_fg = (y_true > 0).flatten(start_dim=1)
    pred_fg = (y_pred > 0).flatten(start_dim=1)
    intersection = (true_fg & pred_fg).sum(dim=1).float()
    return intersection, true_fg.sum(dim=1).float(), pred_fg.sum(dim=1).float()


def calculate_dice_score(y_true, y_pred):
    """
    Mean per-image foreground Dice, 2|A & B| / (|A| + |B|), over a batch ``(N, ...)``.

    Computed directly with torch because torchmetrics' segmentation functions
    expect one-hot ``(N, C, ...)`` input by default and would treat the second
    axis of an ``(N, H, W)`` mask batch as the class axis.
    An image where both masks are empty scores 1.0.

    Returns:
        numpy.float32: Dice averaged over the N images.
    """
    intersection, true_sum, pred_sum = _binary_overlap_stats(y_true, y_pred)
    denominator = true_sum + pred_sum
    # clamp avoids 0/0 in the branch torch.where discards for empty-vs-empty images.
    score = torch.where(denominator > 0,
                        2 * intersection / denominator.clamp(min=1),
                        torch.ones_like(denominator))
    return score.cpu().numpy().mean()


def calculate_iou_score(y_true, y_pred):
    """
    Mean per-image foreground IoU, |A & B| / |A | B|, over a batch ``(N, ...)``.

    Computed over the whole image (not per row, which is what the one-hot default
    of torchmetrics' ``mean_iou`` does with ``(N, H, W)`` input).
    An image where both masks are empty scores 1.0.

    Returns:
        numpy.float32: IoU averaged over the N images.
    """
    intersection, true_sum, pred_sum = _binary_overlap_stats(y_true, y_pred)
    union = true_sum + pred_sum - intersection
    score = torch.where(union > 0,
                        intersection / union.clamp(min=1),
                        torch.ones_like(union))
    return score.cpu().numpy().mean()

# Sanity check: the ground truth compared with itself must give a mean IoU of exactly 1.
calculate_iou_score(ground_truth_catheter_tensor_train, ground_truth_catheter_tensor_train)


In [None]:


def visualize_comparison(original, gt, pred_no_ds, pred_ds,
                         title_no_ds="UNet+++ (No DS)", title_ds="UNet+++ (With DS)"):
    """
    Show an image, its ground truth and two predictions side by side.

    Args:
        original: Image drawn in grayscale under every panel.
        gt: Ground-truth mask (NumPy array); overlaid in blue.
        pred_no_ds: First prediction (NumPy array); overlaid in red.
        pred_ds: Second prediction (NumPy array); overlaid in green.
        title_no_ds: Label for the first prediction; its Dice score is appended.
        title_ds: Label for the second prediction; its Dice score is appended.
    """
    # calculate_dice_score expects a leading batch axis.
    gt_batch = torch.from_numpy(gt).unsqueeze(0)
    dice_first = calculate_dice_score(gt_batch, torch.from_numpy(pred_no_ds).unsqueeze(0))
    dice_second = calculate_dice_score(gt_batch, torch.from_numpy(pred_ds).unsqueeze(0))

    panels = [
        ("Input Image", None, None),
        ("Ground Truth", gt, 'Blues'),
        (f"{title_no_ds} dice: {dice_first:.3f}", pred_no_ds, 'Reds'),
        (f"{title_ds} dice: {dice_second:.3f}", pred_ds, 'Greens'),
    ]

    _, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (title, overlay, cmap) in zip(axes, panels):
        ax.imshow(original, cmap='gray')
        if overlay is not None:
            ax.imshow(overlay, alpha=0.3, cmap=cmap)
        ax.set_title(title, fontsize=9)
        ax.axis('off')

    plt.show()

idx = 3

# Example usage: base UNet+++ vs. UNet++ on one difficult image (atrium).
original = diff_images[idx]
gt = diff_atrial_gt[idx].cpu().numpy()
pred_no_ds = diff_atrial_tensor_base[idx].cpu().numpy()
# The second prediction comes from the UNet++ model, so label it as such
# (not as the deep-supervised UNet+++).
pred_ds = diff_atrial_tensor_u2plus[idx].cpu().numpy()
visualize_comparison(original, gt, pred_no_ds, pred_ds,
                     title_no_ds="UNet+++ (No DS)", title_ds="UNet++")

In [None]:
def do_clahe(image, clip_limit=4.0):
    """Return ``image`` enhanced with CLAHE (8x8 tile grid, the given clip limit)."""
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=(8, 8))
    enhanced = equalizer.apply(image)
    return enhanced

# Next cell: randomly pick a few images and show each original (top row)
# above its CLAHE-enhanced version (bottom row).



In [None]:


def visualize_random_images_with_clahe(image_paths: list[str],
                                       target_size: int = 960,
                                       num_samples: int = 5,
                                       clip_limit: float = 3.0,
                                       tile_grid_size: tuple[int, int] = (8, 8)):
    """
    Randomly select images and show them before (top row) and after (bottom row) CLAHE.

    CLAHE is applied to the full-resolution grayscale image. Both versions are then
    resized to ``target_size`` x ``target_size`` with ``resize_to_square``.

    Args:
        image_paths (list[str]): Paths to image files.
        target_size (int): Side length of the square display images.
        num_samples (int): Number of random images to show (capped at ``len(image_paths)``).
        clip_limit (float): CLAHE clip limit (default 3.0, the previously hard-coded value).
        tile_grid_size (tuple[int, int]): CLAHE tile grid (default (8, 8)).
    """
    import random  # local import: don't depend on a notebook-level `random` import

    if not image_paths:
        print("No image paths provided.")
        return

    num_samples = min(num_samples, len(image_paths))
    if num_samples < 1:
        print("num_samples must be at least 1.")
        return

    selected_paths = random.sample(image_paths, num_samples)

    # One CLAHE object serves every image.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    _, axes = plt.subplots(2, num_samples, figsize=(num_samples * 4, 8), squeeze=False)

    for i, img_path in enumerate(selected_paths):
        # Hide ticks first so an unreadable image leaves clean empty panels.
        axes[0, i].axis('off')
        axes[1, i].axis('off')

        original_image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if original_image is None:
            print(f"Warning: Could not read image at {img_path}")
            continue

        resized_original = resize_to_square(original_image, target_size)
        resized_clahe = resize_to_square(clahe.apply(original_image), target_size)

        axes[0, i].imshow(resized_original, cmap='gray')
        axes[0, i].set_title("Original", fontsize=10)

        axes[1, i].imshow(resized_clahe, cmap='gray')
        axes[1, i].set_title("CLAHE Applied", fontsize=10)

    plt.tight_layout()
    plt.show()

# Example: 5 random training images.
visualize_random_images_with_clahe(valid_images_paths, target_size=960, num_samples=5)

In [None]:
# ---- Base UNet+++ ----
dice_atrium_base_train = calculate_dice_score(ground_truth_atrial_tensor_train, train_atrial_tensor_base)
dice_atrium_base_test = calculate_dice_score(ground_truth_atrial_tensor_test, test_atrial_tensor_base)
iou_atrium_base_train = calculate_iou_score(ground_truth_atrial_tensor_train, train_atrial_tensor_base)
iou_atrium_base_test = calculate_iou_score(ground_truth_atrial_tensor_test, test_atrial_tensor_base)
dice_cath_base_train = calculate_dice_score(ground_truth_catheter_tensor_train, train_catheter_tensor_base)
dice_cath_base_test = calculate_dice_score(ground_truth_catheter_tensor_test, test_catheter_tensor_base)
iou_cath_base_train = calculate_iou_score(ground_truth_catheter_tensor_train, train_catheter_tensor_base)
iou_cath_base_test = calculate_iou_score(ground_truth_catheter_tensor_test, test_catheter_tensor_base)

# ---- UNet+++ with deep supervision ----
dice_atrium_ds_train = calculate_dice_score(ground_truth_atrial_tensor_train, train_atrial_tensor_ds)
dice_atrium_ds_test = calculate_dice_score(ground_truth_atrial_tensor_test, test_atrial_tensor_ds)
iou_atrium_ds_train = calculate_iou_score(ground_truth_atrial_tensor_train, train_atrial_tensor_ds)
iou_atrium_ds_test = calculate_iou_score(ground_truth_atrial_tensor_test, test_atrial_tensor_ds)
dice_cath_ds_train = calculate_dice_score(ground_truth_catheter_tensor_train, train_catheter_tensor_ds)
dice_score_cath_ds_test = calculate_dice_score(ground_truth_catheter_tensor_test, test_catheter_tensor_ds)
iou_cath_ds_train = calculate_iou_score(ground_truth_catheter_tensor_train, train_catheter_tensor_ds)
iou_cath_ds_test = calculate_iou_score(ground_truth_catheter_tensor_test, test_catheter_tensor_ds)

# ---- UNet++ ----
dice_atrium_u2plus_train = calculate_dice_score(ground_truth_atrial_tensor_train, train_atrial_tensor_u2plus)
dice_atrium_u2plus_test = calculate_dice_score(ground_truth_atrial_tensor_test, test_atrial_tensor_u2plus)
iou_atrium_u2plus_train = calculate_iou_score(ground_truth_atrial_tensor_train, train_atrial_tensor_u2plus)
iou_atrium_u2plus_test = calculate_iou_score(ground_truth_atrial_tensor_test, test_atrial_tensor_u2plus)
dice_cath_u2plus_train = calculate_dice_score(ground_truth_catheter_tensor_train, train_catheter_tensor_u2plus)
dice_cath_u2plus_test = calculate_dice_score(ground_truth_catheter_tensor_test, test_catheter_tensor_u2plus)
iou_cath_u2plus_train = calculate_iou_score(ground_truth_catheter_tensor_train, train_catheter_tensor_u2plus)
iou_cath_u2plus_test = calculate_iou_score(ground_truth_catheter_tensor_test, test_catheter_tensor_u2plus)

# ---- UNet++ without CLAHE ----
dice_atrium_u2plus_no_clahe_train = calculate_dice_score(ground_truth_atrial_tensor_train, train_atrial_tensor_u2plus_no_clahe)
dice_atrium_u2plus_no_clahe_test = calculate_dice_score(ground_truth_atrial_tensor_test, test_atrial_tensor_u2plus_no_clahe)
iou_atrium_u2plus_no_clahe_train = calculate_iou_score(ground_truth_atrial_tensor_train, train_atrial_tensor_u2plus_no_clahe)
iou_atrium_u2plus_no_clahe_test = calculate_iou_score(ground_truth_atrial_tensor_test, test_atrial_tensor_u2plus_no_clahe)
dice_cath_u2plus_no_clahe_train = calculate_dice_score(ground_truth_catheter_tensor_train, train_catheter_tensor_u2plus_no_clahe)
dice_cath_u2plus_no_clahe_test = calculate_dice_score(ground_truth_catheter_tensor_test, test_catheter_tensor_u2plus_no_clahe)
iou_cath_u2plus_no_clahe_train = calculate_iou_score(ground_truth_catheter_tensor_train, train_catheter_tensor_u2plus_no_clahe)
iou_cath_u2plus_no_clahe_test = calculate_iou_score(ground_truth_catheter_tensor_test, test_catheter_tensor_u2plus_no_clahe)

# ---- Loss ablation (test split only) ----
dice_atrium_bce_test = calculate_dice_score(ground_truth_atrial_tensor_test, test_atrial_tensor_bce)
iou_atrium_bce_test = calculate_iou_score(ground_truth_atrial_tensor_test, test_atrial_tensor_bce)
dice_cathe_bce_test = calculate_dice_score(ground_truth_catheter_tensor_test, test_catheter_tensor_bce)
iou_cath_bce_test = calculate_iou_score(ground_truth_catheter_tensor_test, test_catheter_tensor_bce)

dice_atrium_focal_test = calculate_dice_score(ground_truth_atrial_tensor_test, test_atrial_tensor_focal)
iou_atrium_focal_test = calculate_iou_score(ground_truth_atrial_tensor_test, test_atrial_tensor_focal)
dice_cathe_focal_test = calculate_dice_score(ground_truth_catheter_tensor_test, test_catheter_tensor_focal)
iou_cath_focal_test = calculate_iou_score(ground_truth_catheter_tensor_test, test_catheter_tensor_focal)

# Test-split summary for the two UNet+++ variants.
print(f'Model: Base UNet+++; Atrial Segmentation-- Dice: {dice_atrium_base_test:.4f}; mIoU: {iou_atrium_base_test:.4f} , CVC segmentation--Dice {dice_cath_base_test:.4f};   mIoU: {iou_cath_base_test:.4f}')
print('==\n')
print(f'Model: UNet+++ with Deep Supervision; Atrial Segmentation--Dice: {dice_atrium_ds_test:.4f}; mIoU: {iou_atrium_ds_test:.4f} , CVC segmentation--Dice {dice_score_cath_ds_test:.4f};   mIoU: {iou_cath_ds_test:.4f}')



In [None]:
# UNet++ test-split summary.
print(
    f'UNet++ model: Atrial Segmentation -- Dice: {dice_atrium_u2plus_test:.4f}; '
    f'mIoU: {iou_atrium_u2plus_test:.4f} , '
    f'CVC segmentation--Dice {dice_cath_u2plus_test:.4f};   '
    f'mIoU: {iou_cath_u2plus_test:.4f}'
)

In [None]:
# Label the variant so this row is distinguishable from the other UNet++ summaries.
print(f'UNet++ (no CLAHE) model: Atrial Segmentation -- Dice: {dice_atrium_u2plus_no_clahe_test:.4f}; mIoU: {iou_atrium_u2plus_no_clahe_test:.4f} , CVC segmentation--Dice {dice_cath_u2plus_no_clahe_test:.4f};   mIoU: {iou_cath_u2plus_no_clahe_test:.4f}')

In [None]:
# Label the variant so this row is distinguishable from the other UNet++ summaries.
print(f'UNet++ (BCE loss) model: Atrial Segmentation -- Dice: {dice_atrium_bce_test:.4f}; mIoU: {iou_atrium_bce_test:.4f} , CVC segmentation--Dice {dice_cathe_bce_test:.4f};   mIoU: {iou_cath_bce_test:.4f}')

In [None]:
# Label the variant so this row is distinguishable from the other UNet++ summaries.
# NOTE(review): the atrium "focal" checkpoint (atrium_unet_focal_update_loss_20250610_151901)
# appears to come from a training cell configured with loss_type='bce'; confirm before
# reporting its atrial numbers as focal loss.
print(f'UNet++ (focal loss) model: Atrial Segmentation -- Dice: {dice_atrium_focal_test:.4f}; mIoU: {iou_atrium_focal_test:.4f} , CVC segmentation--Dice {dice_cathe_focal_test:.4f};   mIoU: {iou_cath_focal_test:.4f}')

In [None]:
# Download and save utils file from git hub
url = 'https://raw.githubusercontent.com/Blaise-bf/thesis-files/refs/heads/main/utills.py'

save_file = 'utills.py'

save_file_locally(url, save_file)

save_file_locally('https://raw.githubusercontent.com/Blaise-bf/thesis-files/refs/heads/main/model_setup.py', 'model_setup.py')

### Calibration of the segmentation threshold

In [None]:
# Probability maps (prob=True) from the deep-supervision models, using the same
# CLAHE clip limits as their evaluation (catheter 2, atrium 3.5) so the calibrated
# threshold matches the preprocessing of the reported metrics.
train_catheter_tensor_ds_prob, train_atrial_tensor_ds_prob, _ = process_masks(
    valid_images_paths,
    unet_cath_model_load,
    unet_atrium_model_load,
    device,
    model_name='unet_deep_sup',
    clip=2,
    clip_atrium=3.5,
    prob=True,
)

In [None]:
# Test-split probability maps, with the same CLAHE clip limits as the
# deep-supervision evaluation (catheter 2, atrium 3.5).
test_catheter_tensor_ds_prob, test_atrial_tensor_ds_prob, _ = process_masks(
    valid_test_paths,
    unet_cath_model_load,
    unet_atrium_model_load,
    device,
    model_name='unet_deep_sup',
    clip=2,
    clip_atrium=3.5,
    prob=True,
)

In [None]:
# Inspect the whole test catheter ground truth flattened into a single 1-D NumPy array.
ground_truth_catheter_tensor_test.cpu().numpy().flatten()

In [None]:
import numpy as np


# Calibration below uses only the FIRST test image: its predicted probability map
# and ground truth, each flattened to a 1-D pixel vector.
y_prob_flat = test_catheter_tensor_ds_prob[0].cpu().numpy().flatten()
y_true_flat = ground_truth_catheter_tensor_test[0].cpu().numpy().flatten()

In [None]:
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

# Reliability diagram: observed fraction of positive pixels per probability
# bin (10 bins) against the mean predicted probability in that bin.
prob_true, prob_pred = calibration_curve(y_true_flat, y_prob_flat, n_bins=10)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(prob_pred, prob_true, marker='o', label='Your Model')
ax.plot([0, 1], [0, 1], linestyle='--', label='Perfectly Calibrated')
ax.set_xlabel('Predicted Probability')
ax.set_ylabel('True Probability')
ax.set_title('Calibration Curve')
ax.legend()
plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve, f1_score

# Precision and recall of the pixel-wise catheter probabilities at every
# candidate threshold.
precision, recall, thresholds = precision_recall_curve(y_true_flat, y_prob_flat)

# F1 for every (precision, recall) pair; the epsilon avoids 0/0.
f1_scores = 2 * precision * recall / (precision + recall + 1e-9)

# precision/recall contain one more entry than `thresholds`: the final
# (precision=1, recall=0) point has no threshold, so it is excluded here.
candidate_f1 = f1_scores[:-1]
optimal_idx = np.argmax(candidate_f1)
optimal_threshold = thresholds[optimal_idx]

print(f"Optimal threshold: {optimal_threshold:.3f}")
print(f"Max F1-score: {f1_scores[optimal_idx]:.3f}")

In [None]:
from sklearn.calibration import CalibratedClassifierCV
from sklearn.linear_model import LogisticRegression

# Platt scaling: a one-feature logistic regression maps raw probabilities to
# calibrated ones. C=1e5 makes the L2 regularisation negligible.
# NOTE(review): the calibrator is fitted and evaluated on the same pixels,
# so y_prob_calibrated is an in-sample estimate.
X_calib = np.expand_dims(y_prob_flat, axis=1)  # sklearn expects 2-D features
y_calib = y_true_flat

calibrator = LogisticRegression(C=1e5, solver='lbfgs').fit(X_calib, y_calib)

# Column 1 of predict_proba is the positive-class probability.
y_prob_calibrated = calibrator.predict_proba(X_calib)[:, 1]

In [None]:
# Histogram of the Platt-calibrated probabilities, split by the ground-truth
# label of each pixel (y_true_flat), so the two class distributions can be
# compared. y_prob_calibrated and y_true_flat are aligned element-wise because
# the calibrator was applied to y_prob_flat, which is aligned with y_true_flat.
is_positive = y_true_flat > 0

plt.figure(figsize=(10, 6))
plt.hist(y_prob_calibrated[~is_positive], bins=20, alpha=0.5,
         label='Ground truth: background (0)')
plt.hist(y_prob_calibrated[is_positive], bins=20, alpha=0.5,
         label='Ground truth: catheter (1)')
plt.xlabel('Predicted Probability')
plt.ylabel('Frequency')
plt.title('Histogram of Calibrated Probabilities by Ground-Truth Label')
plt.legend()
plt.show()

### Non Deep Learning approach:

This section sets up a baseline model that will later be compared with the deep-learning approach. The general idea is to mimic the logic radiologists use to decide whether a catheter is adequately positioned. The baseline relies on the following features:
* Distance from the catheter's tip to the upper border of the atrial mask
* Distance from the catheter's tip to the lower border of the atrium (a distance that is too short or too large may indicate inadequate positioning)
* Fraction of the catheter in the upper third of the right atrial border. This mimics the recommended location of CVCs
* Intersection over union (IoU)

In [None]:
from utills import extract_features, plot_catheter_feature_correlation

In [None]:
def combine_channels(chanel1, chanel2):
    """Stack two equally shaped tensors along a new dimension 1.

    For example, two (N, H, W) tensors become one (N, 2, H, W) tensor with
    `chanel1` at channel 0 and `chanel2` at channel 1.
    """
    pair = (chanel1, chanel2)
    return torch.stack(pair, dim=1)


def combine_tensor_batch(batch1, batch2):
    """Concatenate two batches along dim 0 (rows of `batch1` first, then `batch2`)."""
    return torch.cat((batch1, batch2), 0)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def extract_features_from_tensors(atrium_tensor, catheter_tensor, extract_features_func):
    """
    Extract per-sample features from two aligned batches of masks.

    For every sample position ``i``, ``extract_features_func(atrium_tensor[i],
    catheter_tensor[i], index=i)`` is called. It must return a feature mapping
    (one DataFrame row) or ``None`` to signal that extraction failed.

    NOTE(review): every call site in this notebook passes the *catheter* batch
    as the first argument and the *atrial* batch as the second. Confirm which
    order `utills.extract_features` expects, and rename or swap accordingly.

    Args:
        atrium_tensor: Batch of masks indexed along dim 0 (documented as atrium
            masks, (N, H, W)).
        catheter_tensor: Batch of masks with the same length N.
        extract_features_func (callable): Extracts the features of one sample.

    Returns:
        tuple: A tuple containing:
            - features_df (pd.DataFrame): One row per successful sample. The
              index is the original sample position ``i``, so rows can still
              be aligned with per-sample labels when some samples fail.
            - error_indices (list[int]): Positions where extraction returned None.

    Raises:
        ValueError: If the two batches do not contain the same number of masks.
    """
    n_atrium = atrium_tensor.shape[0]
    n_catheter = catheter_tensor.shape[0]
    if n_atrium != n_catheter:
        # Without this check a shorter catheter batch raises an obscure
        # IndexError, and a longer one is silently truncated.
        raise ValueError(
            f"Batch size mismatch: {n_atrium} atrium masks vs {n_catheter} catheter masks"
        )

    all_features = []
    kept_indices = []
    error_indices = []

    for i in range(n_atrium):
        features = extract_features_func(atrium_tensor[i], catheter_tensor[i], index=i)
        if features is None:
            error_indices.append(i)
        else:
            all_features.append(features)
            kept_indices.append(i)

    if all_features:
        # Keep the original sample positions as the index (not 0..k-1).
        features_df = pd.DataFrame(all_features, index=kept_indices)
    else:
        print("No features extracted successfully. DataFrame could not be created.")
        features_df = pd.DataFrame()  # Return an empty DataFrame

    return features_df, error_indices

# Example usage (assuming atrium_tensor, catheter_tensor, and extract_features exist)

In [None]:
ground_truth_atrial_tensor_test.shape

In [None]:
# Ground-truth masks: train samples first, then test samples.
catheter_tensor_original = combine_tensor_batch(ground_truth_catheter_tensor_train,
                                                ground_truth_catheter_tensor_test)
atria_tensor_original = combine_tensor_batch(ground_truth_atrial_tensor_train,
                                             ground_truth_atrial_tensor_test)

# NOTE(review): the catheter batch is passed in the first (`atrium_tensor`)
# slot, as in every extraction cell — confirm the argument order that
# `extract_features` expects.
original_features_df, error_indices = extract_features_from_tensors(
    catheter_tensor_original, atria_tensor_original, extract_features
)

print(f"Error indices: {error_indices}")

In [None]:
# Baseline-model predicted masks: train samples first, then test samples.
catheter_tensor_base = combine_tensor_batch(train_catheter_tensor_base, test_catheter_tensor_base)
atria_tensor_base = combine_tensor_batch(train_atrial_tensor_base, test_atrial_tensor_base)

# NOTE(review): catheter batch first, as in every extraction cell.
predicted_features_base_df, error_indices = extract_features_from_tensors(
    catheter_tensor_base, atria_tensor_base, extract_features
)
print(f"Error indices: {error_indices}")

In [None]:
# Deep-supervision predicted masks: train samples first, then test samples.
catheter_tensor_ds = combine_tensor_batch(train_catheter_tensor_ds, test_catheter_tensor_ds)
atria_tensor_ds = combine_tensor_batch(train_atrial_tensor_ds, test_atrial_tensor_ds)

# NOTE(review): catheter batch first, as in every extraction cell.
predicted_features_ds_df, error_indices = extract_features_from_tensors(
    catheter_tensor_ds, atria_tensor_ds, extract_features
)
print(f"Error indices: {error_indices}")

In [None]:
# U-Net++ predicted masks: train samples first, then test samples.
catheter_tensor_u2plus = combine_tensor_batch(train_catheter_tensor_u2plus, test_catheter_tensor_u2plus)
atria_tensor_u2plus = combine_tensor_batch(train_atrial_tensor_u2plus, test_atrial_tensor_u2plus)

# NOTE(review): catheter batch first, as in every extraction cell.
predicted_features_u2plus_df, error_indices = extract_features_from_tensors(
    catheter_tensor_u2plus, atria_tensor_u2plus, extract_features
)
# Report failures like the other extraction cells: any failed sample leaves
# this table shorter than the label list assigned to it later.
print(f"Error indices: {error_indices}")

In [None]:
# U-Net++ (no CLAHE) predicted masks: train samples first, then test samples.
catheter_tensor_u2plus_no_clahe = combine_tensor_batch(train_catheter_tensor_u2plus_no_clahe,
                                                       test_catheter_tensor_u2plus_no_clahe)
atria_tensor_u2plus_no_clahe = combine_tensor_batch(train_atrial_tensor_u2plus_no_clahe,
                                                    test_atrial_tensor_u2plus_no_clahe)

# NOTE(review): catheter batch first, as in every extraction cell.
predicted_features_u2plus_no_clahe_df, error_indices = extract_features_from_tensors(
    catheter_tensor_u2plus_no_clahe, atria_tensor_u2plus_no_clahe, extract_features
)
# Report failures like the other extraction cells: any failed sample leaves
# this table shorter than the label list assigned to it later.
print(f"Error indices: {error_indices}")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle, Patch
from skimage.measure import regionprops




def viz_spatial_features(atrium_mask, catheter_mask, features):
    """
    Show an atrium mask and a catheter mask with derived geometry on a new figure.

    Draws the atrium mask (translucent green), the catheter mask (red), the
    atrium bounding box (blue), the catheter centroid (yellow) and a cyan band
    anchored at the top of the atrium bounding box, then shows the figure.

    NOTE(review): the cyan band is 2 * (H // 3) rows tall (H = bbox height),
    i.e. roughly the upper two-thirds, although its legend says "Upper-third".
    Confirm this against the frac_upper definition in utills.

    Args:
        atrium_mask: Atrium segmentation mask (NumPy array).
        catheter_mask: Catheter segmentation mask (NumPy array).
        features: Extracted features for this sample (not used in the plot).

    Raises:
        IndexError: If either mask has no foreground, because regionprops then
            returns no regions.
    """

    def _overlay(mask, channel, alpha):
        # RGBA layer: the foreground pixels of `mask` coloured in one RGB channel.
        layer = np.zeros(mask.shape + (4,), dtype=np.float32)
        foreground = mask > 0
        layer[foreground, channel] = 1
        layer[foreground, 3] = alpha
        return layer

    top, left, bottom, right = regionprops(atrium_mask.astype(int))[0].bbox
    box_height = bottom - top
    box_width = right - left
    centroid_row, centroid_col = regionprops(catheter_mask.astype(int))[0].centroid
    band_height = 2 * (box_height // 3)

    fig, ax = plt.subplots()

    ax.imshow(_overlay(atrium_mask, channel=1, alpha=0.3))     # atrium: green
    ax.imshow(_overlay(catheter_mask, channel=0, alpha=0.8))   # catheter: red

    ax.add_patch(Rectangle((left, top), box_width, box_height,
                           fill=False, linewidth=2, edgecolor='blue'))
    ax.add_patch(Circle((centroid_col, centroid_row), radius=5, fill=True, color='yellow'))
    ax.add_patch(Rectangle((left, top), box_width, band_height,
                           fill=True, alpha=0.2, color='cyan'))

    ax.legend(
        handles=[
            Patch(facecolor='blue', edgecolor='blue', label='Atrium bbox'),
            Patch(facecolor='red', edgecolor='red', label='Catheter', alpha=0.8),
            Patch(facecolor='yellow', edgecolor='yellow', label='Centroid'),
            Patch(facecolor='cyan', edgecolor='cyan', label='Upper-third', alpha=0.2),
        ],
        loc='upper right',
    )

    plt.axis("off")
    plt.show()


# Visualise one sample (position 10): first the U-Net++ predicted masks, then
# the ground-truth masks, so the two figures can be compared side by side.
# viz_spatial_features draws only the masks it is given, so the second figure
# must be built from the ground-truth tensors.
# NOTE(review): both combined tensors are built as combine_tensor_batch(train,
# test), so position 10 presumably refers to the same image — confirm.
sample_idx = 10

atrium_mask_np = atria_tensor_u2plus[sample_idx].cpu().numpy()
catheter_mask_np = catheter_tensor_u2plus[sample_idx].cpu().numpy()
viz_spatial_features(atrium_mask_np, catheter_mask_np,
                     predicted_features_u2plus_df.iloc[sample_idx])

gt_atrium_mask_np = atria_tensor_original[sample_idx].cpu().numpy()
gt_catheter_mask_np = catheter_tensor_original[sample_idx].cpu().numpy()
viz_spatial_features(gt_atrium_mask_np, gt_catheter_mask_np,
                     original_features_df.iloc[sample_idx])

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle, Circle, Patch
from skimage.measure import regionprops

# 1) Modify viz_spatial_features to accept an Axes:
def viz_spatial_features(atrium_mask, catheter_mask, features, ax):
    """
    Draw the atrium/catheter overlay for one sample onto an existing Axes.

    Draws the atrium mask (green, alpha 0.3), the catheter mask (red,
    alpha 0.8), the atrium bounding box (blue), the catheter centroid (yellow)
    and a cyan band 2 * (H // 3) rows tall starting at the top of the bounding
    box. Tick marks are removed; no legend is drawn.

    Args:
        atrium_mask: Atrium mask (NumPy array).
        catheter_mask: Catheter mask (NumPy array).
        features: Extracted features for the sample (not used in the drawing).
        ax: matplotlib Axes to draw into.
    """
    top, left, bottom, right = regionprops(atrium_mask.astype(int))[0].bbox
    box_h, box_w = bottom - top, right - left
    cent_r, cent_c = regionprops(catheter_mask.astype(int))[0].centroid

    # Atrium first (green channel), then catheter on top (red channel).
    for mask, channel, alpha in ((atrium_mask, 1, 0.3), (catheter_mask, 0, 0.8)):
        layer = np.zeros((*mask.shape, 4), float)
        layer[mask > 0, channel] = 1
        layer[mask > 0, 3] = alpha
        ax.imshow(layer)

    ax.add_patch(Rectangle((left, top), box_w, box_h, fill=False, lw=2, edgecolor='blue'))
    ax.add_patch(Circle((cent_c, cent_r), 5, color='yellow'))
    ax.add_patch(Rectangle((left, top), box_w, 2 * (box_h // 3),
                           fill=True, alpha=0.2, color='cyan'))

    ax.set_xticks([])
    ax.set_yticks([])

# Plot the first 9 U-Net++ samples in a 3x3 grid.
n_rows, n_cols = 3, 3
n_samples = n_rows * n_cols

atria = atria_tensor_u2plus[:n_samples]
caths = catheter_tensor_u2plus[:n_samples]
# One dict per row. Iterating a DataFrame directly yields its column names,
# not its rows, so convert explicitly.
feats = predicted_features_u2plus_df.head(n_samples).to_dict('records')

fig, axes = plt.subplots(n_rows, n_cols, figsize=(9, 9))
for ax, atr_mask, cath_mask, fdict in zip(axes.flat, atria, caths, feats):
    # .cpu() so this also works when the masks live on a GPU.
    viz_spatial_features(atr_mask.cpu().numpy(), cath_mask.cpu().numpy(), fdict, ax)

plt.tight_layout()
plt.show()


In [None]:
met_df.head()

In [None]:


# Reverse-code 'tip': 1 -> 0 and every other value -> 1, so that afterwards
# 1 means 'Malpositioned'.
# NOTE(review): this rewrites the column in place, so running the cell a
# second time flips the labels back. Reload met_df before re-running.
met_df['tip'] = (met_df['tip'] != 1).astype(int)
met_df['tip_lab'] = met_df['tip'].eq(1).map({True: 'Malpositioned', False: 'Not malpositioned'})

In [None]:
met_df.head()

In [None]:
id_label_pair = dict(zip(met_df['ap_id'], met_df['tip']))

In [None]:
# Attach each sample's positioning label to every feature table. Ids are
# ordered train first, then test, presumably matching the
# combine_tensor_batch(train, test) order used to build the tables — confirm.
all_ids = common_ids + common_ids_test
labels = [int(id_label_pair[sample_id]) for sample_id in all_ids]


def _tip_label(tip):
    """Map the binary tip code to its human-readable label."""
    return 'Malpositioned' if tip == 1 else 'Not malpositioned'


# Every table, including the no-CLAHE one, gets both 'tip' and 'tip_lab'.
for feature_df in (
    original_features_df,
    predicted_features_base_df,
    predicted_features_ds_df,
    predicted_features_u2plus_df,
    predicted_features_u2plus_no_clahe_df,
):
    # List assignment needs exactly one row per id; it raises if feature
    # extraction dropped samples.
    feature_df['tip'] = labels
    feature_df['tip_lab'] = feature_df['tip'].map(_tip_label)


In [None]:
met_df['tip'].value_counts(normalize=True)

In [None]:
plot_catheter_feature_correlation(original_features_df)

In [None]:
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Candidate features: geometric descriptors plus seven catheter shape
# moments (hu_catheter_0 .. hu_catheter_6).
selected_columns = (
    ["loc_norm", "frac_atrium_covered", "dist_to_atria_top", "dist_to_atria_bottom",
     "length", "orientation", "eccentricity", "curvature"]
    + [f"hu_catheter_{i}" for i in range(7)]
)

# Drop incomplete rows once, with the target in the same frame, so X and y stay aligned.
filtered_df = original_features_df[selected_columns + ['tip']].dropna()
X, y = filtered_df[selected_columns], filtered_df['tip']

# Standardise -> 10 principal components -> random forest; every step is
# refitted inside each CV training fold.
pipeline = Pipeline(steps=[
    ('scaler', StandardScaler()),
    ('pca', PCA(n_components=10)),
    ('clf', RandomForestClassifier(random_state=42)),
])

# 5-fold stratified cross-validated accuracy.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(pipeline, X, y, cv=cv, scoring='accuracy')

mean_score, std_score = scores.mean(), scores.std()

mean_score, std_score




In [None]:
from sklearn.metrics import make_scorer, f1_score, roc_auc_score
from sklearn.model_selection import cross_validate

# Define custom scoring
scoring = {
    'accuracy': 'accuracy',
    'f1': make_scorer(f1_score),
    'roc_auc': 'roc_auc'
}

'tip_lab' in predicted_features_ds_df.columns

In [None]:
import seaborn as sns

pairplot_columns = ["loc_norm", "iou", "frac_upper", "dist_to_atria_top", "dist_to_atria_bottom"]
pair_grid = sns.pairplot(predicted_features_ds_df[pairplot_columns + ["tip_lab"]], hue='tip_lab')

# Readable axis labels: "dist_to_atria_top" -> "dist to atria top".
for pair_ax in pair_grid.axes.flat:
    if pair_ax is None:  # unused cells when corner=True
        continue
    pair_ax.set_xlabel(pair_ax.get_xlabel().replace('_', ' '))
    pair_ax.set_ylabel(pair_ax.get_ylabel().replace('_', ' '))

In [None]:
sns.pairplot(predicted_features_u2plus_df[[
       "loc_norm", "iou", "frac_upper", "dist_to_atria_top", "dist_to_atria_bottom", "tip_lab"]], hue='tip_lab')

In [None]:
# "loc_norm", "frac_atrium_covered", "dist_to_atria_top", "dist_to_atria_bottom"

# create a box plot for each of the columns above with respect to the target variable (tip_lab)
sns.boxplot(x='tip_lab', y='iou', hue='tip_lab', data=predicted_features_u2plus_df)
plt.xlabel('Tip Position', fontsize=12);
plt.ylabel('IOU', fontsize=12);




In [None]:
sns.boxplot(x='tip_lab', y='frac_upper', hue='tip_lab', data=predicted_features_u2plus_df)
plt.ylabel('Upper third fraction covered', fontsize=12)
plt.xlabel('Tip Position', fontsize=12);


In [None]:

sns.boxplot(x='tip_lab', y='dist_to_atria_top', hue='tip_lab', data=predicted_features_u2plus_df)
plt.ylabel('Distance to top of atrial border', fontsize=12)
plt.xlabel('Tip Position', fontsize=12);


In [None]:

sns.boxplot(x='tip_lab', y='dist_to_atria_bottom', hue='tip_lab', data=predicted_features_u2plus_df)
plt.ylabel('Distance to bottom of atrial border', fontsize=12)
plt.xlabel('Tip Position', fontsize=12);

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import make_scorer, f1_score, roc_auc_score, matthews_corrcoef, recall_score, precision_score
import numpy as np
import pandas as pd

def evaluate_classifiers_with_pca(dataframe, target_column='tip', n_components=10,
                                  use_pca=False, selected_columns=None):
    """
    Evaluate several classifiers on hand-crafted features with 5-fold stratified CV.

    Each model runs in a pipeline StandardScaler -> [PCA] -> classifier, so
    scaling (and PCA, when enabled) is fitted on the training folds only. By
    default PCA is *not* applied, which matches how the results in this
    notebook were produced.

    Args:
        dataframe (pd.DataFrame): Feature table; must contain `target_column`
            and the feature columns.
        target_column (str): Name of the binary target column (default 'tip').
        n_components (int): PCA components to keep when `use_pca` is True.
            The value is capped at the number of feature columns.
        use_pca (bool): Insert a PCA step after scaling (default False).
        selected_columns (list[str] | None): Feature columns to use. Defaults to
            ["loc_norm", "frac_upper", "dist_to_atria_top", "dist_to_atria_bottom", "iou"].

    Returns:
        pd.DataFrame: One row per model with the mean and std across folds of
        accuracy, F1, ROC AUC, recall and precision (columns mean_accuracy,
        std_accuracy, mean_f1, std_f1, mean_auc, std_auc, mean_recall,
        std_recall, mean_precision, std_precision).
    """
    if selected_columns is None:
        selected_columns = ["loc_norm", "frac_upper", "dist_to_atria_top", "dist_to_atria_bottom", "iou"]
    selected_columns = list(selected_columns)

    classifiers = {
        "Logistic Regression": LogisticRegression(max_iter=1000, random_state=42),
        "SVM (RBF Kernel)": SVC(kernel='rbf', probability=True, random_state=42),
        "Random Forest": RandomForestClassifier(random_state=42)
    }

    scoring = {
        'accuracy': 'accuracy',
        'f1': make_scorer(f1_score),
        'roc_auc': 'roc_auc',
        'recall': make_scorer(recall_score),
        'precision': make_scorer(precision_score)
    }
    # scoring key -> suffix used in the result column names
    metric_suffixes = [('accuracy', 'accuracy'), ('f1', 'f1'), ('roc_auc', 'auc'),
                       ('recall', 'recall'), ('precision', 'precision')]

    # Drop missing values once so X and y stay aligned.
    filtered_df = dataframe[selected_columns + [target_column]].dropna()
    X = filtered_df[selected_columns]
    y = filtered_df[target_column]

    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    results = {}
    for name, clf in classifiers.items():
        steps = [('scaler', StandardScaler())]
        if use_pca:
            # PCA cannot keep more components than there are input features.
            steps.append(('pca', PCA(n_components=min(n_components, len(selected_columns)))))
        steps.append(('classifier', clf))

        scores = cross_validate(Pipeline(steps), X, y, cv=cv, scoring=scoring)

        row = {}
        for metric, suffix in metric_suffixes:
            fold_scores = scores[f'test_{metric}']
            row[f'mean_{suffix}'] = fold_scores.mean()
            row[f'std_{suffix}'] = fold_scores.std()
        results[name] = row

    return pd.DataFrame(results).T
predicted_mask_metrics_ds = evaluate_classifiers_with_pca(predicted_features_ds_df)
predicted_mask_metrics_base = evaluate_classifiers_with_pca(predicted_features_base_df)
predicted_mask_metrics_u2plus = evaluate_classifiers_with_pca(predicted_features_u2plus_df)

original_mask_metrics = evaluate_classifiers_with_pca(original_features_df)
predicted_mask_metrics_u2plus_no_clahe = evaluate_classifiers_with_pca(predicted_features_u2plus_no_clahe_df)

In [None]:
predicted_mask_metrics_ds

In [None]:
predicted_mask_metrics_base

In [None]:
predicted_mask_metrics_u2plus

In [None]:
original_mask_metrics

In [None]:
predicted_mask_metrics_u2plus_no_clahe

In [None]:
original_features_df.columns

In [None]:
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold
from scipy.stats import loguniform
import numpy as np
from xgboost import XGBClassifier

def nested_cv_classifiers(dataframe, target_column='tip', n_components=10):
    """
    Performs nested cross-validation with random search hyperparameter tuning.

    Args:
        dataframe (pd.DataFrame): Input feature dataframe
        target_column (str): Name of target column
        n_components (int): PCA components (currently unused)

    Returns:
        pd.DataFrame: Performance metrics with standard deviations
        dict: Best hyperparameters for each model
    """
    # Define classifiers with hyperparameter grids
    classifiers = {
        "Logistic Regression": {
            "estimator": LogisticRegression(max_iter=1000, random_state=42),
            "params": {
                'classifier__C': loguniform(1e-4, 100),
                'classifier__penalty': ['l1', 'l2'],
                'classifier__solver': ['saga']
            }
        },
        "SVM (RBF Kernel)": {
            "estimator": SVC(probability=True, random_state=42),
            "params": {
                'classifier__C': loguniform(1e-4, 100),
                'classifier__gamma': loguniform(1e-4, 100),
                'classifier__kernel': ['rbf']
            }
        },
        "Random Forest": {
            "estimator": RandomForestClassifier(random_state=42),
            "params": {
                'classifier__n_estimators': [50, 100, 200, 400],
                'classifier__max_depth': [None, 10, 20, 30, 50],
                'classifier__min_samples_split': [2, 5, 10],
                'classifier__min_samples_leaf': [1, 2, 4],
                'classifier__max_features': ['sqrt', 'log2']
            }
        },
        'XGBoost': {
            "estimator": XGBClassifier(random_state=42),
            "params": {
                'classifier__n_estimators': [50, 100, 200, 400],
                'classifier__max_depth': [3, 5, 7, 10],
                'classifier__learning_rate': [0.01, 0.1, 0.2, 0.3],
                'classifier__subsample': [0.8, 0.9, 1.0],
                'classifier__colsample_bytree': [0.8, 0.9, 1.0],
                'classifier__gamma': [0, 0.1, 0.2, 0.3, 0.4]
            }
        }
    }

    # Define metrics
    scoring = {
        'accuracy': 'accuracy',
        'f1': make_scorer(f1_score),
        'roc_auc': 'roc_auc',
        'recall': make_scorer(recall_score),
        'precision': make_scorer(precision_score)

        # 'matthews_corrcoef': make_scorer(matthews_corrcoef)
    }

    # Feature selection
    selected_columns = [
        "frac_upper", "loc_norm",
        "dist_to_atria_top", "dist_to_atria_bottom", "iou"
    ]

    # Prepare data
    filtered_df = dataframe[selected_columns + [target_column]].dropna()
    X = filtered_df[selected_columns]
    y = filtered_df[target_column]

    # Cross-validation setup
    inner_cv = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
    outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # Results storage
    results = {}
    best_params = {}

    for name, config in classifiers.items():
        # Create pipeline
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('classifier', config["estimator"])
        ])

        # Randomized search setup
        search = RandomizedSearchCV(
            estimator=pipeline,
            param_distributions=config["params"],
            n_iter=20,
            scoring='f1',
            cv=inner_cv,
            refit=True,
            random_state=42,
            n_jobs=-1
        )

        # Nested CV with multiple metrics
        cv_results = cross_validate(
            search,
            X, y,
            cv=outer_cv,
            scoring=scoring,
            return_estimator=True
        )

        # Store best parameters
        best_params[name] = [est.best_params_ for est in cv_results['estimator']]

        # Aggregate results
        results[name] = {
            'mean_accuracy': np.mean(cv_results['test_accuracy']),
            'std_accuracy': np.std(cv_results['test_accuracy']),
            'mean_f1': np.mean(cv_results['test_f1']),
            'std_f1': np.std(cv_results['test_f1']),
            'mean_auc': np.mean(cv_results['test_roc_auc']),
            'std_auc': np.std(cv_results['test_roc_auc']),
            'mean_recall': np.mean(cv_results['test_recall']),
            'std_recall': np.std(cv_results['test_recall']),
            'mean_precision': np.mean(cv_results['test_precision']),
            'std_precision': np.std(cv_results['test_precision'])
        }

    return pd.DataFrame(results).T, best_params



In [None]:
original_features_df.columns

In [None]:
from sklearn.model_selection import RandomizedSearchCV, StratifiedKFold, cross_validate
from sklearn.linear_model import LogisticRegression # Added missing import
from sklearn.svm import SVC                       # Added missing import
from sklearn.ensemble import RandomForestClassifier # Added missing import
from sklearn.preprocessing import StandardScaler   # Added missing import
from sklearn.pipeline import Pipeline             # Added missing import
from sklearn.metrics import make_scorer, f1_score, recall_score, precision_score # Added missing imports
from scipy.stats import loguniform
import numpy as np
from xgboost import XGBClassifier

def _nested_cv_search_spaces(random_state=42):
    """Return {model name: {"estimator", "params"}} for nested_cv_classifiers.

    Parameter names carry the 'classifier__' prefix because each estimator is
    tuned as the 'classifier' step of a Pipeline.
    """
    return {
        "Logistic Regression": {
            "estimator": LogisticRegression(max_iter=1000, random_state=random_state),
            "params": {
                'classifier__C': loguniform(1e-4, 100),
                'classifier__penalty': ['l1', 'l2'],
                'classifier__solver': ['saga']
            }
        },
        "SVM (RBF Kernel)": {
            "estimator": SVC(probability=True, random_state=random_state),
            "params": {
                'classifier__C': loguniform(1e-4, 100),
                'classifier__gamma': loguniform(1e-4, 100),
                'classifier__kernel': ['rbf']
            }
        },
        "Random Forest": {
            "estimator": RandomForestClassifier(random_state=random_state),
            "params": {
                'classifier__n_estimators': [50, 100, 200, 400],
                'classifier__max_depth': [None, 10, 20, 30, 50],
                'classifier__min_samples_split': [2, 5, 10],
                'classifier__min_samples_leaf': [1, 2, 4],
                'classifier__max_features': ['sqrt', 'log2']
            }
        },
        'XGBoost': {
            "estimator": XGBClassifier(random_state=random_state),
            "params": {
                'classifier__n_estimators': [50, 100, 200, 400],
                'classifier__max_depth': [3, 5, 7, 10],
                'classifier__learning_rate': [0.01, 0.1, 0.2, 0.3],
                'classifier__subsample': [0.8, 0.9, 1.0],
                'classifier__colsample_bytree': [0.8, 0.9, 1.0],
                'classifier__gamma': [0, 0.1, 0.2, 0.3, 0.4]
            }
        }
    }


def _summarize_outer_folds(cv_results):
    """Collapse cross_validate output into mean/std per metric (ROC AUC -> *_auc columns)."""
    summary = {}
    for metric, suffix in (('accuracy', 'accuracy'), ('f1', 'f1'), ('roc_auc', 'auc'),
                           ('recall', 'recall'), ('precision', 'precision')):
        fold_scores = cv_results[f'test_{metric}']
        summary[f'mean_{suffix}'] = np.mean(fold_scores)
        summary[f'std_{suffix}'] = np.std(fold_scores)
    return summary


def nested_cv_classifiers(dataframe, target_column='tip', n_components=10, *,
                          selected_columns=None, n_iter=20, inner_splits=2,
                          outer_splits=5, random_state=42, n_jobs=1):
    """
    Estimate classifier performance with nested cross-validation.

    For each model, a RandomizedSearchCV (inner loop, selecting on F1) is
    evaluated by an outer StratifiedKFold. The outer scores therefore measure
    the whole tuning procedure, not one fixed hyper-parameter setting.

    Args:
        dataframe (pd.DataFrame): Feature table containing `target_column` and
            the feature columns.
        target_column (str): Name of the binary target column.
        n_components (int): Unused; kept for backward compatibility.
        selected_columns (list[str] | None): Feature columns. Defaults to
            ["frac_upper", "loc_norm", "dist_to_atria_top", "dist_to_atria_bottom", "iou"].
        n_iter (int): Random-search candidates per model and outer fold.
        inner_splits (int): Folds of the inner (tuning) CV.
        outer_splits (int): Folds of the outer (evaluation) CV.
        random_state (int): Seed for the estimators, the CV shuffles and the search.
        n_jobs (int): Parallel jobs for the inner search. The default of 1
            avoids the multiprocessing issues seen with -1.

    Returns:
        tuple:
            pd.DataFrame: One row per model with the mean and std over the
                outer folds of accuracy, F1, ROC AUC, recall and precision.
            dict: Model name -> list of `best_params_`, one dict per outer fold.
    """
    if selected_columns is None:
        selected_columns = ["frac_upper", "loc_norm",
                            "dist_to_atria_top", "dist_to_atria_bottom", "iou"]
    selected_columns = list(selected_columns)

    # zero_division=0 yields the same scores as sklearn's default ('warn' -> 0.0)
    # without a flood of warnings when a candidate predicts no positives.
    f1_scorer = make_scorer(f1_score, zero_division=0)
    scoring = {
        'accuracy': 'accuracy',
        'f1': f1_scorer,
        'roc_auc': 'roc_auc',
        'recall': make_scorer(recall_score, zero_division=0),
        'precision': make_scorer(precision_score, zero_division=0),
    }

    # Drop incomplete rows once so X and y stay aligned.
    filtered_df = dataframe[selected_columns + [target_column]].dropna()
    X = filtered_df[selected_columns]
    y = filtered_df[target_column]

    inner_cv = StratifiedKFold(n_splits=inner_splits, shuffle=True, random_state=random_state)
    outer_cv = StratifiedKFold(n_splits=outer_splits, shuffle=True, random_state=random_state)

    results = {}
    best_params = {}

    for name, config in _nested_cv_search_spaces(random_state).items():
        pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('classifier', config["estimator"])
        ])

        # Inner loop: random search selected on F1; refit=True refits the best
        # candidate on the whole outer-training fold.
        search = RandomizedSearchCV(
            estimator=pipeline,
            param_distributions=config["params"],
            n_iter=n_iter,
            scoring=f1_scorer,
            cv=inner_cv,
            refit=True,
            random_state=random_state,
            n_jobs=n_jobs
        )

        # Outer loop: evaluate the tuning procedure on held-out folds (serial).
        cv_results = cross_validate(
            search,
            X, y,
            cv=outer_cv,
            scoring=scoring,
            return_estimator=True,
            n_jobs=1
        )

        # With return_estimator=True, cross_validate returns one fitted
        # RandomizedSearchCV per outer fold (independent of n_jobs).
        best_params[name] = [est.best_params_ for est in cv_results['estimator']]
        results[name] = _summarize_outer_folds(cv_results)

    return pd.DataFrame(results).T, best_params

In [None]:
results_original, _ = nested_cv_classifiers(original_features_df)
results_predicted_ds, best_parms = nested_cv_classifiers(predicted_features_ds_df)
results_predicted_base, _ = nested_cv_classifiers(predicted_features_base_df)
results_predicted_u2plus, _ = nested_cv_classifiers(predicted_features_u2plus_df)
results_predicted_u2plus_no_clahe, _ = nested_cv_classifiers(predicted_features_u2plus_no_clahe_df)

In [None]:
results_original

In [None]:
results_predicted_ds

In [None]:
results_predicted_base

In [None]:
results_predicted_u2plus

In [None]:
results_predicted_u2plus_no_clahe

In [None]:
from sklearn.model_selection import train_test_split

# Random Forest hyper-parameters picked in the *first* outer fold of the
# nested CV on the UNet deep-supervision features, with the Pipeline prefix
# stripped so they can be passed straight to RandomForestClassifier.
# NOTE(review): these are then used to fit on original_features_df — confirm
# that reusing parameters tuned on another feature set is intended.
cleaned_params = dict(
    (param_name.replace('classifier__', ''), param_value)
    for param_name, param_value in best_parms['Random Forest'][0].items()
)
#


In [None]:
selected_columns = [
    "frac_upper", "loc_norm",
    "dist_to_atria_top", "dist_to_atria_bottom", "iou"
]

# Drop incomplete rows exactly as the cross-validation helpers do, so the
# hold-out evaluation uses the same samples and X/y stay aligned.
model_df = original_features_df[selected_columns + ['tip']].dropna()
X = model_df[selected_columns]
y = model_df['tip']

# Stratify so both splits keep the same malpositioned / not-malpositioned ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

In [None]:
# Do a correlation heat map for selected columns
import seaborn as sns

corr_matrix = original_features_df[selected_columns].corr()
sns.heatmap(corr_matrix, annot=True)
plt.show()


In [None]:
pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('classifier', RandomForestClassifier(random_state=42, **cleaned_params))
        ])

In [None]:
model_rf = pipeline.fit(X_train, y_train)


In [None]:
model_rf.score(X_test, y_test)

In [None]:
# Get shapley values for RF
import shap

rf_model = model_rf.named_steps['classifier']


explainer = shap.TreeExplainer(rf_model, data=X_train)


shap_values = explainer.shap_values(X_test)


shap.summary_plot(shap_values[:,:, 1], X_test)


In [None]:
# Horizontal bar chart of the fitted random forest's feature_importances_,
# most important feature on top.
import matplotlib.pyplot as plt
import numpy as np

feature_importances = model_rf.named_steps['classifier'].feature_importances_
# Feature indices from most to least important.
sorted_idx = np.argsort(feature_importances)[::-1]
sorted_feature_names = np.array(selected_columns)[sorted_idx]

# Readable tick labels: "dist_to_atria_top" -> "Dist to atria top".
processed_feature_names = [str(feat).replace('_', ' ').capitalize() for feat in sorted_feature_names]

bar_positions = range(feature_importances.shape[0])
fig, ax = plt.subplots()
ax.barh(bar_positions, feature_importances[sorted_idx], color='dodgerblue')
ax.set_yticks(bar_positions)
ax.set_yticklabels(processed_feature_names)
ax.set_xlabel('Feature Importance')
ax.set_ylabel('')
ax.set_title('Random Forest Feature Importance')

# barh draws the first bar at the bottom; invert so the most important is on top.
ax.invert_yaxis()

plt.show()

### Deep Learning approach

* This section builds classification models from the masks predicted by the segmentation models, using transfer learning.
* The key idea is to apply well-chosen data augmentations to reduce overfitting.
Some augmentations to consider include:
  - Mild Grid and Elastic distortion
  - Horizontal flip (the prediction should be robust to this)
  - Mild Rotations

`Albumentations` is a good package for such transformations.

In [None]:
!pip install -U albumentations --quiet

from sklearn.model_selection import StratifiedKFold

from model_setup import train_step, test_step, call_model, train_kfold_model


In [None]:
import albumentations as A


import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR
from sklearn.model_selection import KFold
import numpy as np
from tqdm.auto import tqdm
from torchmetrics import F1Score
import gc

def free_gpu_memory():
    """Collect unreachable Python objects, then release cached CUDA memory.

    The order matters: ``gc.collect()`` first frees tensors that are only kept
    alive by reference cycles, returning their blocks to PyTorch's caching
    allocator; ``torch.cuda.empty_cache()`` then hands those cached blocks back
    to the driver. Calling them the other way round leaves anything freed by
    the collector sitting in the cache. Safe to call when CUDA is unavailable.
    """
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import albumentations as A
import cv2

class ClassificationDataset(Dataset):
    """Binary-classification dataset built from segmentation masks.

    Each sample is a 3-channel float tensor stacked as
    ``[catheter_mask, atrium_mask, third_channel]``. The third channel is the
    matching original image divided by 255 when ``original_images`` is given,
    otherwise an all-zero plane.
    """

    def __init__(self, catheter_predictions, atria_predictions,
                 labels, ids=None, original_images=None,
                 transform=None, normalize=False):
        """
        Args:
            catheter_predictions: Tensor/NumPy array indexable as (N, H, W).
            atria_predictions: Tensor/NumPy array indexable as (N, H, W).
            labels: Sequence/array/tensor of N labels; stored as float32.
            ids: Optional sequence of N sample identifiers. When given,
                ``__getitem__`` returns ``(feature, label, id)``; otherwise
                ``(feature, label)``.
            original_images: Optional single-channel images indexable as
                (N, H, W); each is divided by 255 and used as the third
                channel. Multi-channel (N, 3, H, W) images are NOT supported:
                they cannot be stacked with the 2-D masks.
            transform: Optional Albumentations pipeline, called as
                ``transform(image=hwc_array)``; its ``'image'`` output becomes
                the returned feature.
            normalize: If True, divide both masks by 255 (0-255 -> 0-1).
        """
        self.predictions = catheter_predictions
        # as_tensor also accepts an existing tensor without the copy-construct
        # warning that torch.tensor(tensor) emits.
        self.labels = torch.as_tensor(labels, dtype=torch.float32)
        self.transform = transform
        self.atrial_mask = atria_predictions
        self.original_images = original_images
        self.ids = ids
        self.normalize = normalize

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _to_float_tensor(array):
        """Return ``array`` as a float32 tensor (NumPy input is wrapped via from_numpy)."""
        if isinstance(array, np.ndarray):
            array = torch.from_numpy(array)
        return array.float()

    def __getitem__(self, idx):
        mask = self._to_float_tensor(self.predictions[idx])
        atria = self._to_float_tensor(self.atrial_mask[idx])
        label = self.labels[idx]

        if self.normalize:
            # Masks stored as 0/255 -> 0/1.
            mask = mask.div(255)
            atria = atria.div(255)

        if self.original_images is not None:
            third_channel = self._to_float_tensor(self.original_images[idx]).div(255)
        else:
            third_channel = torch.zeros_like(mask)

        # (3, H, W)
        feature = torch.stack([mask, atria, third_channel], dim=0)

        # Explicit None check: an Albumentations Compose defines __len__, so
        # truthiness would wrongly skip an (empty) pipeline.
        if self.transform is not None:
            # Albumentations works on HWC NumPy images; the pipelines built in
            # train_kfold_model end with ToTensorV2, which returns CHW tensors.
            transformed = self.transform(image=feature.permute(1, 2, 0).numpy())
            feature = transformed['image']

        if self.ids is not None:
            return feature, label, self.ids[idx]
        return feature, label

In [None]:
from sklearn.model_selection import StratifiedKFold
from torchmetrics import Accuracy, Precision, Recall, AUROC, F1Score
from tqdm.auto import tqdm
from typing import Dict, List, Tuple

In [None]:
def test_step(model: nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: nn.Module,
              device: torch.device,
              return_predictions: bool = False
              ) -> Tuple[float, float, float, float, float, float,
                         Tuple[np.ndarray, np.ndarray]]:
    """Evaluate ``model`` on ``dataloader`` and compute binary metrics.

    Batches must unpack as ``(images, labels, ids)``. The model output (the
    first element if it returns a tuple) is treated as logits: the loss is
    ``loss_fn(outputs, labels.float().unsqueeze(1))`` and probabilities are
    ``sigmoid(outputs)``. Accuracy, precision, recall and F1 use predictions
    thresholded at ``prob > 0.5``; AUROC is computed from the probabilities.

    Args:
        model: Model to evaluate (switched to eval mode).
        dataloader: Validation/test loader.
        loss_fn: Loss applied to the raw outputs.
        device: Device the batches and metrics are moved to.
        return_predictions: Kept for call-site compatibility; the
            concatenated (probabilities, labels) are always returned.

    Returns:
        ``(loss, acc, precision, recall, auc, f1, (probs, labels))`` where
        ``loss`` is averaged over samples and ``probs``/``labels`` are NumPy
        arrays concatenated over all batches.
    """
    model.eval()

    # Metrics are created fresh on every call, so no state carries over
    # between epochs (no reset needed).
    threshold_metrics = {
        'acc': Accuracy(task='binary').to(device),
        'precision': Precision(task='binary').to(device),
        'recall': Recall(task='binary').to(device),
        'f1': F1Score(task='binary').to(device),
    }
    # AUROC is a ranking metric: it needs continuous scores. Feeding it the
    # 0/1 thresholded predictions (as before) collapses the ROC curve to a
    # single operating point and reports a wrong AUC.
    auroc = AUROC(task='binary').to(device)

    test_loss = 0.0
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels, _ in dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = loss_fn(outputs, labels.float().unsqueeze(1))
            test_loss += loss.item() * images.size(0)

            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()
            all_probs.append(probs.cpu())
            all_labels.append(labels.cpu())

            # Flatten so preds/probs and targets always have matching shapes,
            # and pass integer targets as torchmetrics expects for binary tasks.
            target = labels.reshape(-1).long()
            for metric in threshold_metrics.values():
                metric.update(preds.reshape(-1), target)
            auroc.update(probs.reshape(-1), target)

    avg_loss = float(test_loss / len(dataloader.dataset))
    acc = float(threshold_metrics['acc'].compute())
    precision = float(threshold_metrics['precision'].compute())
    recall = float(threshold_metrics['recall'].compute())
    f1 = float(threshold_metrics['f1'].compute())
    auc_value = float(auroc.compute())

    predictions = (torch.cat(all_probs).numpy(), torch.cat(all_labels).numpy())

    return avg_loss, acc, precision, recall, auc_value, f1, predictions



In [None]:
def _build_kfold_transforms(img_size, data_aug):
    """Return the (train, val) Albumentations pipelines used by train_kfold_model.

    Both pipelines resize with nearest-neighbour interpolation and end with
    ToTensorV2. With ``data_aug`` the training pipeline additionally applies a
    random horizontal flip, a small rotation and a grid elastic deformation.
    """
    if data_aug:
        train_transform = A.Compose([
            A.Resize(img_size, img_size, interpolation=cv2.INTER_NEAREST),
            A.HorizontalFlip(p=0.3),
            A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),
            A.GridElasticDeform(
                num_grid_xy=(5, 5),
                magnitude=3,
                interpolation=cv2.INTER_NEAREST,
                mask_interpolation=cv2.INTER_NEAREST,
                p=0.2,
            ),
            A.ToTensorV2(),
        ])
    else:
        train_transform = A.Compose([
            A.Resize(img_size, img_size, interpolation=cv2.INTER_NEAREST),
            A.ToTensorV2(),
        ])

    val_transform = A.Compose([
        A.Resize(img_size, img_size, interpolation=cv2.INTER_NEAREST),
        A.ToTensorV2(),
    ])
    return train_transform, val_transform


def _fold_dataset(catheter_predictions, atria_predictions, labels, ids,
                  original_images, indices, transform):
    """Build a ClassificationDataset restricted to the samples in ``indices``."""
    # `is not None and len(...)` instead of plain truthiness: ids may be a
    # NumPy array, whose truth value is ambiguous (ValueError).
    has_ids = ids is not None and len(ids) > 0
    return ClassificationDataset(
        catheter_predictions=catheter_predictions[indices],
        atria_predictions=atria_predictions[indices],
        labels=[labels[i] for i in indices],
        ids=[ids[i] for i in indices] if has_ids else None,
        original_images=original_images[indices] if original_images is not None else None,
        transform=transform,
    )


def train_kfold_model(catheter_predictions,
                      atria_predictions,
                      labels,
                      ids=None,
                      original_images=None,
                      size=600,
                      model_name='efficientnet_b7',
                      num_epochs=60,
                      batch_size=8,
                      k=5,
                      fine_tune='last_two',
                      data_aug=True):
    """Train and evaluate a classifier with stratified k-fold cross-validation.

    For each of ``k`` stratified folds (shuffled, ``random_state=42``) a fresh
    model is built with ``call_model``, trained for ``num_epochs`` epochs with
    AdamW + BCEWithLogitsLoss and a linear-warm-up/cosine LR schedule, and
    evaluated on the held-out fold after every epoch with ``test_step``.

    Args:
        catheter_predictions: Catheter masks, indexable by an integer index
            array (e.g. NumPy array or tensor of shape (N, H, W)).
        atria_predictions: Atrium masks, indexed the same way.
        labels: Length-N sequence of binary labels (used for stratification).
        ids: Optional length-N sample identifiers. NOTE: ``test_step`` as
            defined in this notebook unpacks batches as (images, labels, ids),
            so ids are effectively required.
        original_images: Optional images used as the third input channel.
        size: Side length images are resized to.
        model_name: Backbone name passed to ``call_model``.
        num_epochs: Training epochs per fold.
        batch_size: Per-GPU batch size (multiplied by the number of GPUs).
        k: Number of folds.
        fine_tune: Fine-tuning mode passed to ``call_model``.
        data_aug: If True, apply flip/rotation/elastic augmentation in training.

    Returns:
        ``(foldperf, fold_predictions)``: ``foldperf['foldN']`` is the
        per-epoch metric history; ``fold_predictions['foldN']`` is the
        ``(probs, labels)`` pair from the final epoch's validation pass.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_gpus = torch.cuda.device_count()
    warmup_epochs = 5

    train_transform, val_transform = _build_kfold_transforms(size, data_aug)

    splits = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
    foldperf = {}
    fold_predictions = {}

    labels_np = np.asarray(labels)
    # np.bincount only accepts non-negative ints; labels may arrive as floats.
    label_ints = labels_np.astype(int)
    # StratifiedKFold only uses X for its length, so indices stand in for it.
    dummy_X = np.arange(len(labels_np))

    effective_batch = batch_size * max(1, num_gpus)

    for fold, (train_idx, val_idx) in enumerate(splits.split(dummy_X, labels_np)):
        fold_key = f'fold{fold + 1}'
        print(f'\n=== Fold {fold + 1}/{k} ===')
        print(f"Train class distribution: {np.bincount(label_ints[train_idx])}")
        print(f"Val class distribution: {np.bincount(label_ints[val_idx])}")

        train_dataset = _fold_dataset(catheter_predictions, atria_predictions, labels,
                                      ids, original_images, train_idx, train_transform)
        val_dataset = _fold_dataset(catheter_predictions, atria_predictions, labels,
                                    ids, original_images, val_idx, val_transform)

        train_loader = DataLoader(train_dataset, batch_size=effective_batch,
                                  shuffle=True, pin_memory=True, num_workers=2)
        test_loader = DataLoader(val_dataset, batch_size=effective_batch,
                                 shuffle=False, pin_memory=True, num_workers=2)

        # Build on CPU, wrap for multi-GPU if available, then move to device.
        model = call_model(model_name=model_name, device='cpu', fine_tune=fine_tune)
        if num_gpus > 1:
            model = nn.DataParallel(model)
        model = model.to(device)

        criterion = nn.BCEWithLogitsLoss()
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=0.0001, weight_decay=0.001)

        # Linear warm-up for `warmup_epochs` scheduler steps, then cosine decay.
        # The scheduler is handed to train_step, which decides when to step it.
        warmup_scheduler = LinearLR(optimizer, start_factor=0.01,
                                    end_factor=1.0, total_iters=warmup_epochs)
        # max(1, ...): a non-positive T_max (num_epochs <= warmup_epochs) would
        # make CosineAnnealingLR divide by zero.
        cosine_scheduler = CosineAnnealingLR(optimizer,
                                             T_max=max(1, num_epochs - warmup_epochs),
                                             eta_min=1e-6)
        scheduler = SequentialLR(optimizer,
                                 schedulers=[warmup_scheduler, cosine_scheduler],
                                 milestones=[warmup_epochs])

        history = {
            'train_loss': [], 'test_loss': [],
            'train_acc': [], 'test_acc': [],
            'test_precision': [], 'test_recall': [],
            'test_auc': [], 'test_f1': []
        }
        # Final-epoch validation predictions for this fold.
        fold_probs = None
        fold_labels = None

        for epoch in tqdm(range(num_epochs), desc='Epochs'):
            free_gpu_memory()

            train_loss, train_acc = train_step(
                model, train_loader, criterion, optimizer, scheduler, device)

            # Predictions are only kept from the last epoch.
            return_preds = (epoch == num_epochs - 1)
            (test_loss, test_acc, test_precision,
             test_recall, test_auc, test_f1, preds) = test_step(
                model, test_loader, criterion, device, return_predictions=return_preds)

            if return_preds and preds is not None:
                fold_probs, fold_labels = preds

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['test_loss'].append(test_loss)
            history['test_acc'].append(test_acc)
            history['test_precision'].append(test_precision)
            history['test_recall'].append(test_recall)
            history['test_auc'].append(test_auc)
            history['test_f1'].append(test_f1)

            if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == num_epochs:
                print(f"Epoch {epoch+1:03d} | "
                      f"Train Loss: {train_loss:.4f} | "
                      f"Train Acc: {train_acc:.2f}% | "
                      f"Test Loss: {test_loss:.4f}\n"
                      f"Test Metrics: "
                      f"Acc: {test_acc*100:.2f}% | "
                      f"Precision: {test_precision:.4f} | "
                      f"Recall: {test_recall:.4f} | "
                      f"AUC: {test_auc:.4f} | "
                      f"F1: {test_f1:.4f}")

        foldperf[fold_key] = history
        fold_predictions[fold_key] = (fold_probs, fold_labels)

        # Drop this fold's model/optimizer before the next fold allocates new
        # ones, so GPU memory from earlier folds does not accumulate.
        del model, optimizer, scheduler, warmup_scheduler, cosine_scheduler
        free_gpu_memory()

    return foldperf, fold_predictions



In [None]:
def summarize_kfold_metrics(foldperf):
    """Print per-fold and aggregate final-epoch test metrics.

    For every fold in ``foldperf`` (as returned by ``train_kfold_model``) the
    LAST epoch's test accuracy, precision, recall, AUC and F1 are printed,
    followed by the mean ± population standard deviation (``np.std``) across
    folds. Accuracy is displayed as a percentage everywhere; the other metrics
    are displayed as fractions.

    Args:
        foldperf: mapping fold name -> history dict with 'test_acc',
            'test_precision', 'test_recall', 'test_auc', 'test_f1' lists.

    Returns:
        dict mapping 'acc', 'precision', 'recall', 'auc', 'f1' to the list of
        per-fold final-epoch values (unscaled; accuracy is in [0, 1]).
    """
    metric_keys = ('acc', 'precision', 'recall', 'auc', 'f1')
    metrics = {key: [] for key in metric_keys}

    print("\n=== Comprehensive Fold Performance ===")
    for fold in sorted(foldperf):
        history = foldperf[fold]
        fold_metrics = {key: history[f'test_{key}'][-1] for key in metric_keys}

        print(f"\n{fold}:")
        print(f"Accuracy: {fold_metrics['acc']*100:.2f}%")
        print(f"Precision: {fold_metrics['precision']:.4f}")
        print(f"Recall: {fold_metrics['recall']:.4f}")
        print(f"AUC: {fold_metrics['auc']:.4f}")
        print(f"F1: {fold_metrics['f1']:.4f}")

        for key in metric_keys:
            metrics[key].append(fold_metrics[key])

    print("\n=== Aggregate Metrics ===")
    for metric, values in metrics.items():
        if metric == 'acc':
            # Accuracy is stored as a fraction; scale it so the '%' suffix is
            # truthful (previously e.g. "0.8500%" was printed for 85%).
            scaled = np.asarray(values) * 100
            print(f"Mean {metric.capitalize()}: {np.mean(scaled):.2f}% ± {np.std(scaled):.2f}%")
        else:
            print(f"Mean {metric.capitalize()}: {np.mean(values):.4f} ± {np.std(values):.4f}")

    return metrics

In [None]:
import numpy as np
from sklearn.metrics import roc_auc_score, auc, roc_curve

def calculate_auc_metrics(fold_predictions):
    """
    Summarise ROC AUC across cross-validation folds.

    Two summaries are produced:
      * the ROC AUC of each fold, with their mean and population std;
      * 'macro_auc': the ROC AUC of all folds' predictions pooled into one
        set (despite the key name, this is a pooled AUC, not an average).

    Args:
        fold_predictions: mapping fold name -> (probs, labels), as returned
            by train_kfold_model.

    Returns:
        dict with 'average_auc', 'std_auc', 'macro_auc', 'fold_aucs' (list in
        fold order), and the pooled 'all_probs' / 'all_labels' arrays.
    """
    fold_aucs = [roc_auc_score(fold_labels, fold_probs)
                 for fold_probs, fold_labels in fold_predictions.values()]

    pooled_probs = np.concatenate([p for p, _ in fold_predictions.values()])
    pooled_labels = np.concatenate([lbl for _, lbl in fold_predictions.values()])

    summary = dict(
        average_auc=np.mean(fold_aucs),
        std_auc=np.std(fold_aucs),
        macro_auc=roc_auc_score(pooled_labels, pooled_probs),
        fold_aucs=fold_aucs,
        all_probs=pooled_probs,
        all_labels=pooled_labels,
    )

    print(f"Average AUC (mean of folds): {summary['average_auc']:.4f} ± {summary['std_auc']:.4f}")
    print(f"Macro AUC (combined data):   {summary['macro_auc']:.4f}")
    print("\nIndividual Fold AUCs:")
    for fold_name, fold_auc in zip(fold_predictions, fold_aucs):
        print(f"{fold_name}: {fold_auc:.4f}")

    return summary



In [None]:
# UNet++-predicted catheter + atrium masks, trained WITHOUT data augmentation.
fold_perf_no_aug, fold_predictions_no_aug = train_kfold_model(catheter_tensor_u2plus,
                                                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              batch_size=8, data_aug=False)

In [None]:
metrics_no_aug = summarize_kfold_metrics(fold_perf_no_aug)

In [None]:
# DS-predicted masks with clahe_images as the third input channel.
fold_perf_ds_ori_img, fold_predictions_ds_ori_img = train_kfold_model(catheter_tensor_ds,
                              atria_tensor_ds,

                              labels, all_ids, original_images=clahe_images,
                                num_epochs=120,
                              batch_size=8)

In [None]:
metrics_ori_img_ds = summarize_kfold_metrics(fold_perf_ds_ori_img)

In [None]:
# Base-model-predicted masks with clahe_images as the third input channel.
fold_perf_base_ori_img, fold_predictions_base_ori_img = train_kfold_model(catheter_tensor_base,
                              atria_tensor_base,

                              labels, all_ids, original_images=clahe_images,
                                num_epochs=120,
                              batch_size=8)

In [None]:
metrics_ori_img_base = summarize_kfold_metrics(fold_perf_base_ori_img)

In [None]:
free_gpu_memory()

In [None]:
# Ground-truth ("original") catheter + atrium masks.
fold_perf, fold_predictions = train_kfold_model(catheter_tensor_original,
                              atria_tensor_original,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
original_metrics = summarize_kfold_metrics(fold_perf)



In [None]:
# NOTE(review): duplicate of original_metrics above (same fold_perf).
metrics_real = summarize_kfold_metrics(fold_perf)

In [None]:
auc_metrics_real_mask = calculate_auc_metrics(fold_predictions)

In [None]:
# NOTE(review): re-trains the original-mask model and OVERWRITES fold_perf /
# fold_predictions; later cells that use these names see this second run.
fold_perf, fold_predictions = train_kfold_model(catheter_tensor_original,
                              atria_tensor_original,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:

from sklearn.metrics import roc_curve, precision_recall_curve, auc
def plot_combined_roc_curves(fold_predictions):
    """Overlay every fold's ROC curve, labelled with its AUC, on one figure.

    Args:
        fold_predictions: mapping fold name -> (probs, labels).
    """
    plt.figure(figsize=(8, 6))

    for fold_name, (fold_probs, fold_labels) in fold_predictions.items():
        false_pos, true_pos, _thresholds = roc_curve(fold_labels, fold_probs)
        fold_auc = auc(false_pos, true_pos)
        plt.plot(false_pos, true_pos, label=f'{fold_name} (AUC = {fold_auc:.2f})')

    # Chance-level reference diagonal.
    plt.plot([0, 1], [0, 1], 'k--', label='Random (AUC = 0.50)')

    plt.xlabel('False Positive Rate', fontsize=10)
    plt.ylabel('True Positive Rate', fontsize=10)
    plt.legend(loc='lower right', fontsize=10)
    plt.grid(True, alpha=0.1)
    plt.tight_layout()
    plt.show()

def plot_combined_pr_curves(fold_predictions):
    """Overlay every fold's precision-recall curve on one figure.

    Each curve is labelled with its average precision (AP). AP is used instead
    of the trapezoidal area under the PR curve, which linearly interpolates
    between points and can be overly optimistic. The dashed horizontal line is
    the no-skill baseline: the prevalence of positives over ALL folds pooled
    (previously only the first fold was used).

    Args:
        fold_predictions: mapping fold name -> (probs, labels).
    """
    from sklearn.metrics import average_precision_score

    plt.figure(figsize=(8, 6))

    pooled_labels = np.concatenate(
        [np.ravel(fold_labels) for _, fold_labels in fold_predictions.values()])
    baseline = np.mean(pooled_labels)

    for fold, (probs, labels) in fold_predictions.items():
        # Flatten: probabilities may arrive as (N, 1).
        probs = np.ravel(probs)
        labels = np.ravel(labels)
        precision, recall, _ = precision_recall_curve(labels, probs)
        ap = average_precision_score(labels, probs)
        plt.plot(recall, precision, label=f'{fold} (AP = {ap:.2f})')

    plt.axhline(y=baseline, color='k', linestyle='--',
                label=f'Baseline (AP = {baseline:.2f})')

    plt.xlabel('Recall', fontsize=10)
    plt.ylabel('Precision', fontsize=10)
    plt.legend(loc='best', fontsize=10)
    plt.grid(True, alpha=0.1)
    plt.tight_layout()
    plt.show()

# Usage after training:
# foldperf, fold_predictions = train_kfold_model(...)
# plot_combined_roc_curves(fold_predictions)
# plot_combined_pr_curves(fold_predictions)

In [None]:
# UNet++-predicted masks with fine_tune='head_only'.
# NOTE(review): variable names say "last_conv" but the mode passed is
# 'head_only'; the exact meaning of each mode lives in call_model (not shown).
fold_perf_u2plus_last_conv, fold_predictions_u2plus_last_conv = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              fine_tune='head_only',
                              batch_size=8)

In [None]:
metrics_u2plus_last_conv = summarize_kfold_metrics(fold_perf_u2plus_last_conv)

In [None]:
auc_metrics_u2plus_last_conv = calculate_auc_metrics(fold_predictions_u2plus_last_conv)

In [None]:
# UNet++-predicted masks with fine_tune=None (named "feat_extra"; presumably a
# frozen backbone / feature extraction — confirm in call_model) and batch 16.
fold_perf_u2plus_feat_extra, fold_predictions_u2plus_feat_extra = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              fine_tune=None,
                              batch_size=16)

In [None]:
metrics_u2plus_feat_extra = summarize_kfold_metrics(fold_perf_u2plus_feat_extra)

In [None]:
auc_metrics_feat_extra = calculate_auc_metrics(fold_predictions_u2plus_feat_extra)

In [None]:
# UNet++-predicted masks with the default fine_tune='last_two' and augmentation.
fold_perf_uplus2, fold_predictions_uplus2 = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
metrics_uplus2 = summarize_kfold_metrics(fold_perf_uplus2)

In [None]:
auc_metrics_uplus2 = calculate_auc_metrics(fold_predictions_uplus2)

In [None]:
plot_combined_roc_curves(fold_predictions_uplus2)
plot_combined_pr_curves(fold_predictions_uplus2)

In [None]:
# Inspect the (probs, labels) pair stored for the first fold.
fold_predictions_uplus2['fold1']

In [None]:
# NOTE(review): this cell cannot run as written. `kf` is not defined in any
# visible cell, and `probs` is not a visible top-level variable. If `labels` is
# the Python list passed to train_kfold_model, `labels[test_index]` would also
# fail. Re-splitting here would not match the folds used in
# train_kfold_model anyway; evaluate_folds / get_kfold_confusion_matrix
# (defined below) compute the pooled confusion matrix from fold_predictions.
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold


all_preds = []
all_true = []

for train_index, test_index in kf.split(probs.reshape(-1, 1), labels):
    # Usually, you'd train your model here using train_index
    # For this static example, just collect predictions and true labels
    prob_fold = probs[test_index]
    true_fold = labels[test_index]

    # Convert probabilities to binary predictions (threshold = 0.5)
    preds_fold = (prob_fold >= 0.5).astype(int)

    all_preds.extend(preds_fold)
    all_true.extend(true_fold)

# Compute confusion matrix
cm = confusion_matrix(all_true, all_preds)
print("Confusion Matrix:")
print(cm)


In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix

def get_kfold_confusion_matrix(fold_predictions, threshold=0.5):
    """Pool all folds' predictions and return their confusion matrix.

    Args:
        fold_predictions: mapping fold name -> (probs, labels).
        threshold: probability at or above which a sample is predicted 1.

    Returns:
        2x2 confusion matrix (rows = true class 0/1, columns = predicted 0/1).
    """
    all_preds = []
    all_true = []

    for probs, labels in fold_predictions.values():
        # Flatten so (N, 1) probability arrays become plain 1-D vectors.
        all_preds.extend((np.ravel(probs) >= threshold).astype(int))
        all_true.extend(np.ravel(labels))

    # labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
    cm = confusion_matrix(all_true, all_preds, labels=[0, 1])
    print("Confusion Matrix:")
    print(cm)
    return cm


def evaluate_folds(fold_predictions, threshold=0.5):
    """Pool all folds' predictions and report confusion-matrix metrics.

    Args:
        fold_predictions: mapping fold name -> (probs, labels).
        threshold: probability at or above which a sample is predicted 1.

    Returns:
        ``(cm, recall, specificity)``. ``cm`` is the 2x2 confusion matrix.
        A ratio whose denominator is zero (no samples of that true class) is
        returned as NaN instead of raising or warning.
    """
    all_preds = []
    all_true = []

    for probs, labels in fold_predictions.values():
        # Flatten so (N, 1) probability arrays become plain 1-D vectors.
        all_preds.extend((np.ravel(probs) >= threshold).astype(int))
        all_true.extend(np.ravel(labels))

    # labels=[0, 1] forces a 2x2 matrix; without it, a pooled set containing a
    # single class yields a 1x1 matrix and the 4-way unpacking below crashes.
    cm = confusion_matrix(all_true, all_preds, labels=[0, 1])
    print("Confusion Matrix:")
    print(cm)

    TN, FP, FN, TP = cm.ravel()

    recall = TP / (TP + FN) if (TP + FN) else float('nan')  # Sensitivity
    specificity = TN / (TN + FP) if (TN + FP) else float('nan')
    total = TP + TN + FP + FN
    accuracy = (TP + TN) / total if total else float('nan')

    print("\nPerformance Metrics:")
    print(f"Recall (Sensitivity): {recall:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"Accuracy:    {accuracy:.4f}")

    return cm, recall, specificity

In [None]:
# Pooled confusion matrices for the UNet++ run at decreasing thresholds.
# NOTE(review): plot_confusion_matrix is not imported in any visible cell
# (presumably mlxtend.plotting.plot_confusion_matrix — confirm).
cm_05, _, _= evaluate_folds(fold_predictions_uplus2, threshold=0.5)
plot_confusion_matrix(conf_mat=cm_05, figsize=(5, 5));

In [None]:
cm_04, _, _= evaluate_folds(fold_predictions_uplus2, threshold=0.4)
plot_confusion_matrix(conf_mat=cm_04, figsize=(5, 5));


In [None]:
cm_03, _, _= evaluate_folds(fold_predictions_uplus2, threshold=0.3)
plot_confusion_matrix(conf_mat=cm_03, figsize=(5, 5));

In [None]:
cm_02, _, _= evaluate_folds(fold_predictions_uplus2, threshold=0.2)
plot_confusion_matrix(conf_mat=cm_02, figsize=(5, 5));

In [None]:
cm_01, _, _= evaluate_folds(fold_predictions_uplus2, threshold=0.1)
plot_confusion_matrix(conf_mat=cm_01, figsize=(5, 5));

In [None]:
# NOTE(review): recomputes cm_02 / cm_01 with get_kfold_confusion_matrix;
# duplicates the evaluate_folds cells above (same pooling, no metrics).
cm_02 = get_kfold_confusion_matrix(fold_predictions_uplus2, threshold=0.2)
plot_confusion_matrix(conf_mat=cm_02, figsize=(5, 5));

In [None]:
cm_01 = get_kfold_confusion_matrix(fold_predictions_uplus2, threshold=0.1)
plot_confusion_matrix(conf_mat=cm_01, figsize=(5, 5));

In [None]:
# NOTE(review): this and the next cell run the identical configuration
# (UNet++ catheter + atrium, defaults); results differ only by run-to-run noise.
fold_perf_uplus, fold_predictions_uplus = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
fold_perf_unet2, fold_predictions_unet2 = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
metrics_unetp2 = summarize_kfold_metrics(fold_perf_unet2)

In [None]:
auc_metrics_unetp2 = calculate_auc_metrics(fold_predictions_unet2)

In [None]:
plot_combined_roc_curves(fold_predictions_unet2)

In [None]:
plot_combined_pr_curves(fold_predictions_unet2)

In [None]:
metris_uplus = summarize_kfold_metrics(fold_perf_uplus)

In [None]:
metrics_uplus2 =  summarize_kfold_metrics(fold_perf_uplus2)

In [None]:
# Original (ground-truth) catheter mask + UNet++-predicted atrium mask.
# NOTE(review): the "_ds" suffix does not match the inputs used — confirm naming.
fold_perf_uplus2_ds, fold_predictions_uplus2_ds = train_kfold_model(catheter_tensor_original,
                              atria_tensor_u2plus,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
metrics_uplus2_ds = summarize_kfold_metrics(fold_perf_uplus2_ds)

In [None]:
# summarize_kfold_metrics(fold_perf_uplus2_ds)
metrics_dict = {
    'Original masks': original_metrics,
    'UNet++': metrics_unetp2,
    'UNet++ OA': metrics_uplus2_ori_atr,
    'UNet++ OC': metrics_uplus_oricath
    # 'Uplus': metris_uplus
}


In [None]:
fold_perf_uplus2_ori_cath, fold_predictions_uplus2_ori_img = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_original,
                              labels, all_ids,
                              num_epochs=120,
                              batch_size=8)


In [None]:
metrics_uplus2_ori_atr = summarize_kfold_metrics(fold_perf_uplus2_ori_cath)

In [None]:

import json
# with open('/content/drive/MyDrive/msc_uhasselt/experiments/classification/metrics_uplus2.json', 'w') as f:
#     json.dump(metrics_dict, f)

with open('/content/drive/MyDrive/msc_uhasselt/experiments/classification/metrics_uplus2.json', 'r') as f:
    metrics_dict_prev = json.load(f)

In [None]:
metrics_uplus_oricath = summarize_kfold_metrics(fold_perf_uplus2_ori_cath)

In [None]:
fold_perf_uplus2_ori_atr, fold_predictions_uplus2_ori_atr = train_kfold_model(catheter_tensor_u2plus,
                              atria_tensor_original,
                              labels, all_ids,
                              num_epochs=120,
                              batch_size=8)

In [None]:
metrics_uplus2_ori_atr = summarize_kfold_metrics(fold_perf_uplus2_ori_atr)

In [None]:
mwtrics_uplus2_ori_atr = summarize_kfold_metrics(fold_perf_uplus2_ori_atr)

In [None]:
metrics_uplus2_ori_img = summarize_kfold_metrics(fold_perf_uplus2_ori_img)

In [None]:
def create_metrics_dataframe(metrics_dict, dataset_name,
                             metric_names=('acc', 'precision', 'recall', 'auc', 'f1-score')):
    """Build a per-metric mean/std summary table for one experiment.

    Args:
        metrics_dict: Mapping of metric key -> sequence of per-fold values
            (all sequences the same length, one entry per fold).
        dataset_name: Label written into the ``Dataset`` column.
        metric_names: Short display names assigned, in order, to the columns
            of ``pd.DataFrame(metrics_dict)``. The mapping is positional, so
            the dict's key order must match. Defaults to the five names this
            notebook has always used.

    Returns:
        pd.DataFrame indexed by the original metric keys, with columns
        ``Mean`` and ``Std`` (pandas sample std; both rounded to 4 decimals),
        ``Dataset`` and ``metric``.

    Raises:
        ValueError: if the number of metrics differs from ``len(metric_names)``.
    """
    df = pd.DataFrame(metrics_dict)
    metric_names = list(metric_names)
    # Fail with an explicit message instead of pandas' generic length error.
    if len(metric_names) != df.shape[1]:
        raise ValueError(
            f"Expected {len(metric_names)} metrics {metric_names}, "
            f"got {df.shape[1]}: {list(df.columns)}"
        )
    stats = pd.DataFrame({
        'Mean': df.mean(),
        'Std': df.std()
    }).round(4)
    stats['Dataset'] = dataset_name
    stats['metric'] = metric_names
    return stats

# Create comparison table for all datasets: one long-format stats frame per
# experiment in metrics_dict, stacked together.
comparison_dfs = []
for dataset_name, metrics in metrics_dict.items():
    comparison_dfs.append(create_metrics_dataframe(metrics, dataset_name))

final_comparison = pd.concat(comparison_dfs)
# Convert to wide format: one row per Dataset, (Mean|Std) x metric columns.
wide_df = final_comparison.pivot(index="Dataset", columns="metric", values=["Mean", "Std"])

# Flatten multi-index columns and rename, e.g. ("Mean", "auc") -> "Mean_auc".
wide_df.columns = [f"{stat}_{metric}" for stat, metric in wide_df.columns]
wide_df = wide_df.reset_index()

# Reorder columns logically (Mean/Std pairs per metric).
# NOTE(review): metric_order must match the labels assigned inside
# create_metrics_dataframe, otherwise the column selection raises KeyError.
metric_order = ["acc", "precision", "recall", "auc", "f1-score"]
column_order = ["Dataset"] + [f"{stat}_{met}" for met in metric_order for stat in ["Mean", "Std"]]
wide_df = wide_df[column_order]
wide_df


In [None]:
# Per-fold AUC summaries for each experiment (calculate_auc_metrics is
# defined elsewhere in the notebook).
print('AUC fold summary for original masks\n')
auc_metrics_real_mask = calculate_auc_metrics(fold_predictions)

print('AUC fold summary for UNet++\n')
auc_metrics_unetp2 = calculate_auc_metrics(fold_predictions_unet2)

print('AUC fold summary for UNet++  OC\n')
auc_metrics_uplus_oricath = calculate_auc_metrics(fold_predictions_uplus2_ori_img)

print('AUC fold summary for UNet++ OA\n')
auc_metrics_uplus2_ori_atr = calculate_auc_metrics(fold_predictions_uplus2_ori_atr)
print('AUC fold summary for UNet++ no clahe\n')
# NOTE(review): the `_ds` variable name does not match the "no clahe"
# predictions it holds; kept unchanged in case later cells read it.
auc_metrics_uplus2_ds = calculate_auc_metrics(fold_predictions_no_clahe)


In [None]:
# ROC and precision-recall curves for the `_ori_atr` run.
plot_combined_roc_curves(fold_predictions_uplus2_ori_atr)
plot_combined_pr_curves(fold_predictions_uplus2_ori_atr)

In [None]:
# NOTE(review): in this section fold_predictions_uplus2_ori_img comes from the
# run with catheter_tensor_u2plus + atria_tensor_original, i.e. UNet++
# catheter + original atrium — the opposite of the old "predicted atrium +
# original catheter" label. Confirm before using these plots in the write-up.
plot_combined_pr_curves(fold_predictions_uplus2_ori_img)
plot_combined_roc_curves(fold_predictions_uplus2_ori_img)

In [None]:
# K-fold run: original catheter masks + UNet++ atrium masks, additionally
# passing clahe_images as `original_images` (how they are used is defined in
# train_kfold_model, not visible here).
fold_perf_uplus2_ori_atrium, fold_predictions_uplus2_ori_atrium = train_kfold_model(catheter_tensor_original,
                              atria_tensor_u2plus,
                              labels, all_ids, original_images=clahe_images,
                              num_epochs=120,
                              batch_size=8)

In [None]:
free_gpu_memory()

# Square input resolution used by both pipelines below.
IMG_SIZE = 600

# Training augmentation: nearest-neighbour resize/warps (presumably to keep
# mask-derived inputs crisp — confirm), light flip/rotation, and an elastic
# grid deformation; ToTensorV2 converts HWC numpy -> CHW tensor.
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.3),
        A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),
        A.GridElasticDeform(
       num_grid_xy=(5, 5),
           magnitude=4,
         interpolation=cv2.INTER_NEAREST,
                mask_interpolation=cv2.INTER_NEAREST,
           p=0.3
       )  ,
        A.ToTensorV2()
    ])

# Validation: deterministic resize + tensor conversion only.
val_transform = A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.ToTensorV2()
    ])

In [None]:
# K-fold run on `_ds` catheter + `_ds` atrium masks.
fold_perf_ds, fold_predictions_ds = train_kfold_model(catheter_tensor_ds,
                              atria_tensor_ds,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)



In [None]:
metrics_ds = summarize_kfold_metrics(fold_perf_ds)

In [None]:
plot_combined_roc_curves(fold_predictions_ds)
plot_combined_pr_curves(fold_predictions_ds)

In [None]:
# K-fold run on `_base` catheter + `_base` atrium masks.
fold_perf_base, fold_predictions_base = train_kfold_model(catheter_tensor_base,
                              atria_tensor_base,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
plot_combined_roc_curves(fold_predictions_base)

In [None]:
metrics_base = summarize_kfold_metrics(fold_perf_base)

In [None]:
# `_ds` catheter + original atrium. The (perf, predictions) tuple is kept
# unpacked; later cells index it with [0] and [1].
fold_ori_atrium_pred_cath_perf = train_kfold_model(catheter_tensor_ds,
                              atria_tensor_original,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
metrics_ori_atrium = summarize_kfold_metrics(fold_ori_atrium_pred_cath_perf[0])

In [None]:
plot_combined_roc_curves(fold_ori_atrium_pred_cath_perf[1])

In [None]:
# Original catheter + `_ds` atrium.
fold_ori_cath_pred_atrium_perf, fold_predictins_ori_cath_pred = train_kfold_model(catheter_tensor_original,
                              atria_tensor_ds,
                              labels, all_ids, num_epochs=120,
                              batch_size=8)

In [None]:
metrics_ori_cath = summarize_kfold_metrics(fold_ori_cath_pred_atrium_perf)

In [None]:
# ROC curves of the run trained just above (original catheter + atria_tensor_ds).
# Previously this re-plotted fold_predictions_ds, which already has its own plot.
plot_combined_roc_curves(fold_predictins_ori_cath_pred)

In [None]:
# Raw per-fold performance of each experiment, keyed by display name.
kfol_experiments = {
    'Original masks': fold_perf,
    'UNet +++ DS': fold_perf_ds,
    'UNet +++': fold_perf_base,
    'Cath DS + Original Atrium': fold_ori_atrium_pred_cath_perf[0],
    'Atrium DS + Original Cath': fold_ori_cath_pred_atrium_perf

}
# Save as json file
# NOTE(review): json.dump only succeeds if the fold_perf structures contain
# JSON-serialisable values — confirm what train_kfold_model returns.
import json
with open('/content/drive/MyDrive/msc_uhasselt/experiments/classification/kfold_experiments.json', 'w') as f:
    json.dump(kfol_experiments, f)

In [None]:
import json
# Summarised metrics per experiment, keyed by display name.
experiments = {
    'Original masks': original_metrics,
    'UNet +++ DS': metrics_ds,
    'UNet +++': metrics_base,
    'Cath DS + Original Atrium': metrics_ori_atrium,
    'Atrium DS + Original Cath': metrics_ori_cath
}

with open('/content/drive/MyDrive/msc_uhasselt/experiments/classification/metrics.json', 'w') as f:


    json.dump(experiments, f)

# read back saved json (replaces `experiments` with the JSON-decoded copy)
with open('/content/drive/MyDrive/msc_uhasselt/experiments/classification/metrics.json', 'r') as f:
    experiments = json.load(f)


# Displays the comparison table built earlier from metrics_dict (not from
# `experiments`).
wide_df


In [None]:
IMG_SIZE = 600
# Training augmentation (milder elastic deform than the earlier 5x5/magnitude-4 setup).
train_transform = A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.3),
        A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),
        A.GridElasticDeform(
       num_grid_xy=(7, 7),
           magnitude=3,
         interpolation=cv2.INTER_NEAREST,
                mask_interpolation=cv2.INTER_NEAREST,
           p=0.2
       )  ,
        A.ToTensorV2()
    ])


# Validation: deterministic resize + tensor conversion only.
val_transform = A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.ToTensorV2()
    ])

In [None]:
# Full dataset built from UNet++ catheter/atrium predictions, used here for a visual check.
dataset_uplus = ClassificationDataset(catheter_tensor_u2plus,
                                         atria_tensor_u2plus,
                                         labels, ids=all_ids,
                                         transform=train_transform)

In [None]:
idx = 100

# Each item is (image, label, id). NOTE: `id` shadows the builtin in the notebook namespace.
img, label, id = dataset_uplus[idx]

# Image is CHW; permute to HWC for matplotlib.
plt.imshow(img.permute(1, 2, 0))
plt.title(f"Label: {label}, ID: {id}")
plt.axis('off')
plt.show()


Fine-tune only the last conv block of ConvNeXt

In [None]:
# Per-fold AUC summaries for the three UNet++ variants, saved to .npz in a later cell.
auc_metrics_uplus2 = calculate_auc_metrics(fold_predictions_uplus2)

auc_metrics_up2_oricath = calculate_auc_metrics(fold_predictions_uplus2_ori_img)

auc_metrics_up2_ori_atrium = calculate_auc_metrics(fold_predictions_uplus2_ori_atr)

# np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/fold_pred_uplus2.npz', **fold_predictions_uplus2)

In [None]:
# Flatten {fold: (probs, labels)} into "<fold>_probs" / "<fold>_labels" arrays
# so np.savez can store them as named entries.
# The loop variables are deliberately NOT named `labels`: at notebook top level
# that would overwrite the global label list used by train_kfold_model and
# train_test_split(stratify=labels) with the last fold's labels.
fold_predictions_to_save = {}
for fold, (fold_probs, fold_labels) in fold_predictions_uplus2_ori_atr.items():
    # f-string keys also work when the fold keys are not strings (e.g. ints).
    fold_predictions_to_save[f'{fold}_probs'] = np.asarray(fold_probs)
    fold_predictions_to_save[f'{fold}_labels'] = np.asarray(fold_labels)

# Save the prepared dictionary
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/fold_pred_uplus2_oriatr.npz', **fold_predictions_to_save)

In [None]:
# Persist the UNet++ AUC summaries (dict keys become .npz entry names).
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_uplus2.npz', **auc_metrics_uplus2)
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_uplus2_oricath.npz', **auc_metrics_up2_oricath)
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_uplus2_oriatr.npz', **auc_metrics_up2_ori_atrium)

In [None]:
# NOTE(review): `auc_metrics` is not assigned in this section; the original-mask
# summary computed above is stored in `auc_metrics_real_mask` — confirm which
# object is meant here.
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_original.npz', **auc_metrics)
# loaded_data = np.load('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_original.npz', allow_pickle=True)
# data_loaded = {key: loaded_data[key] for key in loaded_data.files}
# data_loaded
# Compute and persist AUC summaries for the remaining experiments.
auc_metrics_ds = calculate_auc_metrics(fold_predictions_ds)
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_ds.npz', **auc_metrics_ds)
auc_metrics_base = calculate_auc_metrics(fold_predictions_base)
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_base.npz', **auc_metrics_base)
auc_metrics_ori_atrium = calculate_auc_metrics(fold_ori_atrium_pred_cath_perf[1])
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_ori_atrium.npz', **auc_metrics_ori_atrium)
auc_metrics_ori_cath = calculate_auc_metrics(fold_predictins_ori_cath_pred)
np.savez('/content/drive/MyDrive/msc_uhasselt/experiments/classification/auc_metrics_ori_cath.npz', **auc_metrics_ori_cath)


In [None]:
free_gpu_memory()

In [None]:
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    f1_score,
    matthews_corrcoef,
    confusion_matrix,
    recall_score
)

def calculate_classification_metrics(fold_predictions, threshold=0.5):
    """
    Calculate binary classification metrics (accuracy, precision, F1, MCC)
    from per-fold predictions.

    Args:
        fold_predictions (dict): Mapping ``fold -> (probs, labels)`` where
            ``probs`` are positive-class probabilities and ``labels`` the 0/1
            ground truth of that fold. Lists and arrays are both accepted.
        threshold (float): Probability at or above which a sample is predicted
            positive (default=0.5).

    Returns:
        dict with:
            'fold_metrics': per-fold value lists for each metric.
            'avg_metrics': mean and population std (np.std, ddof=0) of the
                per-fold values.
            'macro_metrics': metrics computed once on the predictions of all
                folds pooled together. Despite the key name this is a pooled
                score, not a mean of per-fold scores (that is 'avg_metrics');
                the key is kept for compatibility with saved results.

    Raises:
        ValueError: if ``fold_predictions`` is empty.
    """
    if not fold_predictions:
        raise ValueError("fold_predictions is empty")

    fold_metrics = {
        'accuracy': [],
        'precision': [],
        'f1': [],
        'mcc': []
    }
    pooled_probs = []
    pooled_labels = []

    for probs, labels in fold_predictions.values():
        # Coerce to flat arrays: `list >= float` would raise TypeError.
        probs = np.asarray(probs, dtype=float).ravel()
        labels = np.asarray(labels).ravel()
        preds = (probs >= threshold).astype(int)

        fold_metrics['accuracy'].append(accuracy_score(labels, preds))
        fold_metrics['precision'].append(precision_score(labels, preds, zero_division=0))
        # zero_division=0 gives the same value as sklearn's default for F1,
        # consistent with precision, without a warning on folds with no positives.
        fold_metrics['f1'].append(f1_score(labels, preds, zero_division=0))
        fold_metrics['mcc'].append(matthews_corrcoef(labels, preds))

        pooled_probs.append(probs)
        pooled_labels.append(labels)

    all_probs = np.concatenate(pooled_probs)
    all_labels = np.concatenate(pooled_labels)
    all_preds = (all_probs >= threshold).astype(int)

    # Pooled metrics: all folds' predictions concatenated and scored once.
    macro_metrics = {
        'accuracy': accuracy_score(all_labels, all_preds),
        'precision': precision_score(all_labels, all_preds, zero_division=0),
        'f1': f1_score(all_labels, all_preds, zero_division=0),
        'mcc': matthews_corrcoef(all_labels, all_preds),
    }

    # Fold-wise mean and std.
    avg_metrics = {
        'avg_accuracy': np.mean(fold_metrics['accuracy']),
        'std_accuracy': np.std(fold_metrics['accuracy']),
        'avg_precision': np.mean(fold_metrics['precision']),
        'std_precision': np.std(fold_metrics['precision']),
        'avg_f1': np.mean(fold_metrics['f1']),
        'std_f1': np.std(fold_metrics['f1']),
        'avg_mcc': np.mean(fold_metrics['mcc']),
        'std_mcc': np.std(fold_metrics['mcc'])
    }

    return {
        'fold_metrics': fold_metrics,  # Metrics per fold
        'avg_metrics': avg_metrics,   # Mean and std across folds
        'macro_metrics': macro_metrics  # Pooled over all folds
    }

In [None]:
# Threshold-based classification metrics (default threshold 0.5) per experiment.
class_metrics_original = calculate_classification_metrics(fold_predictions)
class_metrics_ds = calculate_classification_metrics(fold_predictions_ds)
class_metrics_base = calculate_classification_metrics(fold_predictions_base)
class_metrics_ori_atrium = calculate_classification_metrics(fold_ori_atrium_pred_cath_perf[1])
class_metrics_ori_cath = calculate_classification_metrics(fold_predictins_ori_cath_pred)


# NOTE: rebinds the notebook-global `metrics` (also used as a loop variable in
# the comparison-table cell).
metrics = {
    'Original masks': class_metrics_original,
    'UNet +++ DS': class_metrics_ds,
    'UNet +++': class_metrics_base,
    'Cath DS + Original Atrium': class_metrics_ori_atrium,
    'Atrium DS + Original Cath': class_metrics_ori_cath
}

# `json` is imported in earlier cells.
with open('/content/drive/MyDrive/msc_uhasselt/experiments/classification/class_metrics.json', 'w') as f:
    json.dump(metrics, f)

In [None]:
# Same metrics for the original masks with a lower decision threshold.
calculate_classification_metrics(fold_predictions, threshold=0.2)

In [None]:
calculate_classification_metrics(fold_predictions_ds)

### Try ConvNeXt

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import (
    EfficientNet_B0_Weights,
    EfficientNet_B7_Weights,
    ConvNeXt_Tiny_Weights,
    ConvNeXt_Base_Weights
)


def call_model(model_name='convnext_tiny', device=None, fine_tune=None, drop_out_prob=0.2):
    """
    Build an ImageNet-pretrained torchvision ConvNeXt or EfficientNet with a
    single-logit head for binary classification.

    Args:
        model_name: One of ['convnext_tiny', 'convnext_base', 'efficientnet_b0', 'efficientnet_b7']
        device: torch.device to place the model on; defaults to CUDA when
            available, otherwise CPU.
        fine_tune: Which pretrained parameters stay trainable:
            None        -> none (backbone frozen; only the new head trains),
            'head_only' -> the final block / final conv layer + new head,
            'last_two'  -> the last two entries of ``model.features`` + new head,
            'all'       -> every parameter.
            The newly created Dropout/Linear head is always trainable. The
            pretrained ConvNeXt classifier LayerNorm that is reused in the new
            head follows the backbone rule (trainable only with 'all').
        drop_out_prob: Dropout probability applied before the final Linear.

    Returns:
        nn.Module on ``device`` producing one logit per sample (pair it with
        ``nn.BCEWithLogitsLoss``).

    Raises:
        ValueError: for an unsupported ``model_name`` or ``fine_tune`` value.
    """
    # Validate before any pretrained weights are downloaded/constructed.
    if fine_tune not in (None, 'head_only', 'last_two', 'all'):
        raise ValueError("fine_tune must be None, 'head_only', 'last_two', or 'all'")

    # Device setup
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model initialization with pretrained weights
    if model_name.startswith('convnext'):
        if model_name == 'convnext_tiny':
            weights = ConvNeXt_Tiny_Weights.IMAGENET1K_V1
            model = models.convnext_tiny(weights=weights)
        elif model_name == 'convnext_base':
            weights = ConvNeXt_Base_Weights.IMAGENET1K_V1
            model = models.convnext_base(weights=weights)
        else:
            raise ValueError(f"Unsupported ConvNeXt variant: {model_name}")

        # Last block of the last ConvNeXt stage (features[-1] is a Sequential of blocks).
        final_conv = model.features[-1][-1]

    elif model_name.startswith('efficientnet'):
        if model_name == 'efficientnet_b0':
            weights = EfficientNet_B0_Weights.IMAGENET1K_V1
            model = models.efficientnet_b0(weights=weights)
        elif model_name == 'efficientnet_b7':
            weights = EfficientNet_B7_Weights.IMAGENET1K_V1
            model = models.efficientnet_b7(weights=weights)
        else:
            raise ValueError(f"Unsupported EfficientNet variant: {model_name}")

        # Final convolutional layer of the EfficientNet feature extractor.
        final_conv = model.features[-1]
    else:
        raise ValueError(f"Unknown model: {model_name}")

    def _set_trainable(module, trainable):
        """Set requires_grad on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = trainable

    if fine_tune != 'all':
        # Freeze everything first, then unfreeze the selected part.
        # Bug fix: 'last_two' previously skipped this step, so nothing was
        # frozen and it trained the whole network exactly like 'all'.
        _set_trainable(model, False)

    if fine_tune == 'head_only':
        _set_trainable(final_conv, True)
    elif fine_tune == 'last_two':
        # Same rule for both families: the last two top-level feature entries.
        for block in model.features[-2:]:
            _set_trainable(block, True)

    # Replace classifier head for binary classification (new layers are trainable).
    if model_name.startswith('convnext'):
        # torchvision's ConvNeXt classifier is [LayerNorm2d, Flatten(1), Linear];
        # global pooling happens earlier in `model.avgpool`. classifier[1] is
        # therefore already a Flatten and the extra nn.Flatten() is a no-op. It is
        # kept so the Linear stays at `classifier.4` and saved state_dicts still load.
        in_features = model.classifier[-1].in_features
        model.classifier = nn.Sequential(
            model.classifier[0],  # pretrained LayerNorm2d
            model.classifier[1],  # Flatten
            nn.Flatten(),
            nn.Dropout(p=drop_out_prob, inplace=True),
            nn.Linear(in_features, 1)
        )
    else:  # EfficientNet
        in_features = model.classifier[-1].in_features
        model.classifier = nn.Sequential(
            nn.Dropout(p=drop_out_prob, inplace=True),
            nn.Linear(in_features, 1)
        )

    return model.to(device)

In [None]:
# ConvNeXt-tiny on original catheter + original atrium masks at 224x224.
original_convnext_perf, original_convnext_predictions = train_kfold_model(catheter_tensor_original,
                              atria_tensor_original,
                              labels, all_ids,
                              num_epochs=120, size=224,
                              model_name='convnext_tiny',
                              batch_size=16)


In [None]:
# NOTE(review): `merics_` is a typo of `metrics_`; kept so any later reference still resolves.
merics_convnext_original = summarize_kfold_metrics(original_convnext_perf)

In [None]:
plot_combined_roc_curves(original_convnext_predictions)

In [None]:
# ConvNeXt-base on original catheter + original atrium masks.
original_convnext_large_perf, original_convnext_large_predictions = train_kfold_model(catheter_tensor_original,
                              atria_tensor_original,
                              labels, all_ids,
                              num_epochs=120, size=224,
                              model_name='convnext_base',
                              batch_size=16)


In [None]:
metrics_convnext_large = summarize_kfold_metrics(original_convnext_large_perf)

In [None]:
plot_combined_roc_curves(original_convnext_large_predictions)

In [None]:
# ConvNeXt-base on `_ds` catheter + ORIGINAL atrium.
# NOTE(review): convnext_ds_perf / convnext_ds_predictions are overwritten by
# the `_ds` catheter + `_ds` atrium run further down.
convnext_ds_perf, convnext_ds_predictions = train_kfold_model(catheter_tensor_ds,
                              atria_tensor_original,
                              labels, all_ids,
                              num_epochs=120, size=224, model_name='convnext_base',
                              batch_size=16)


In [None]:
metrics_convnext_ds = summarize_kfold_metrics(convnext_ds_perf)

In [None]:
plot_combined_roc_curves(convnext_ds_predictions)
#

In [None]:
# ConvNeXt-base on `_base` catheter + `_base` atrium.
convnext_base_perf, convnext_base_predictions = train_kfold_model(catheter_tensor_base,
                              atria_tensor_base,
                              labels, all_ids,
                              num_epochs=120, size=224, model_name='convnext_base',
                              batch_size=16)


In [None]:
metrics_convnext_base = summarize_kfold_metrics(convnext_base_perf)

In [None]:
# ConvNeXt-base on `_ds` catheter + `_ds` atrium (overwrites the earlier convnext_ds_* results).
convnext_ds_perf, convnext_ds_predictions = train_kfold_model(catheter_tensor_ds,
                              atria_tensor_ds,
                              labels, all_ids,
                              num_epochs=120, size=224, model_name='convnext_base',
                              batch_size=16)


In [None]:
# Summarise the k-fold metrics of the ConvNeXt-base `_ds`/`_ds` run above.
metrics_convnext_ds = summarize_kfold_metrics(convnext_ds_perf)

In [None]:
plot_combined_roc_curves(convnext_ds_predictions)

In [None]:
from sklearn.model_selection import train_test_split
import numpy as np
IMG_SIZE = 600
# Training augmentation; nearest-neighbour interpolation throughout.
train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(p=0.3),
    A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),
    A.GridElasticDeform(
        num_grid_xy=(7, 7),
        magnitude=3,
        interpolation=cv2.INTER_NEAREST,
        mask_interpolation=cv2.INTER_NEAREST,
        p=0.2,
    ),
    A.ToTensorV2(),
])

# Validation: deterministic resize + tensor conversion only.
val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
    A.ToTensorV2(),
])

# 2. Create a stratified train-test split (80-20) based on IDs
train_ids, test_ids = train_test_split(
    all_ids,
    test_size=0.2,
    stratify=labels,
    random_state=42  # For reproducibility
)

# 3. Create boolean masks for indexing
train_mask = np.isin(all_ids, train_ids)
test_mask = np.isin(all_ids, test_ids)

# Positions of each split in the ORIGINAL all_ids order. Tensors, labels and ids
# must all be selected in this same order. train_test_split returns
# train_ids/test_ids SHUFFLED, so passing them directly (as before) paired
# samples with the wrong IDs.
train_idx = np.flatnonzero(train_mask)
test_idx = np.flatnonzero(test_mask)

# 4. Create datasets with appropriate transforms
train_dataset = ClassificationDataset(
    catheter_predictions=catheter_tensor_u2plus[train_mask],
    atria_predictions=atria_tensor_u2plus[train_mask],
    labels=[labels[i] for i in train_idx],
    # original_images=original_images[train_mask] if original_images is not None else None,
    ids=[all_ids[i] for i in train_idx],
    transform=train_transform,
    # normalize=True
)

test_dataset = ClassificationDataset(
    catheter_predictions=catheter_tensor_u2plus[test_mask],
    atria_predictions=atria_tensor_u2plus[test_mask],
    labels=[labels[i] for i in test_idx],
    # original_images=original_images[test_mask] if original_images is not None else None,
    ids=[all_ids[i] for i in test_idx],
    transform=val_transform,
    # normalize=True
)




In [None]:

# Visual sanity check of one augmented training sample.
idx = 5

img, lab, id = train_dataset[idx]

# CHW tensor -> HWC for matplotlib.
plt.imshow(img.permute(1, 2, 0))
plt.title(f"Label: {lab}, ID: {id}")
plt.axis('off')
plt.show()

In [None]:
num_epochs = 120
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create DataLoaders
# NOTE(review): the hold-out *test* split is used as the validation set here.
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False)



model = call_model(model_name='efficientnet_b7', device=device, fine_tune='last_two', drop_out_prob=0.5)
model.to(device)

# Define loss function (single-logit binary head)
criterion = nn.BCEWithLogitsLoss()



# Optimizer over the trainable (non-frozen) parameters only
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, weight_decay=0.001)


# LR scheduling
warmup_epochs = 5  # Number of epochs for warmup
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs)

  # Step 2: Cosine decay scheduler (after warmup)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs, eta_min=1e-6)

  # # Combined scheduler: warmup, then cosine (switches at warmup_epochs scheduler steps)
  # assumes train_step steps the scheduler once per epoch — TODO confirm
from torch.optim.lr_scheduler import SequentialLR
scheduler = SequentialLR(
      optimizer,
      schedulers=[warmup_scheduler, cosine_scheduler],
      milestones=[warmup_epochs]
      )

# Training loop: record train/val loss and accuracy per epoch for plotting.
epoch_losses = []
epoch_accuracies = []
val_accuracies = []
val_loss = []
for epoch in tqdm(range(num_epochs), desc=f'Epochs'):
  free_gpu_memory()
  train_loss, train_acc = train_step(model, train_dataloader, criterion, optimizer, scheduler, device)
  # test_step's result is indexed as (loss, accuracy, ...) below.
  val_metrics = test_step(model, val_dataloader, criterion, device)
  epoch_losses.append(train_loss)
  epoch_accuracies.append(train_acc)
  val_loss.append(val_metrics[0])
  val_accuracies.append(val_metrics[1])
  # Log on the first epoch, every 10th epoch and the last epoch.
  if (epoch + 1) % 10 == 0 or epoch == 0 or (epoch + 1) == num_epochs:
    print(f"Epoch {epoch+1:03d} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Train Acc: {train_acc * 100:.2f}% | "
          f"Val Loss: {val_metrics[0]:.4f} | "
          f"Val Acc: {val_metrics[1] * 100:.2f}%")





In [None]:
# plot loss and accuracies recorded by the training loop (left: loss, right: accuracy)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(epoch_losses, label='Train Loss')
plt.plot(val_loss, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epoch_accuracies, label='Train Accuracy')
plt.plot(val_accuracies, label='Val Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

In [None]:
# Save (commented out) / reload the trained EfficientNet-B7 weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# torch.save(model.state_dict(), '/content/drive/MyDrive/msc_uhasselt/experiments/classification/effnet_unetp2.pth')

# Rebuild the same architecture, then load the weights. map_location lets a
# checkpoint saved from a CUDA session load on a CPU-only runtime (otherwise
# torch.load tries to restore tensors onto the original CUDA device).
model = call_model(model_name='efficientnet_b7', device=device, fine_tune='last_two')
model.load_state_dict(torch.load('/content/drive/MyDrive/msc_uhasselt/experiments/classification/effnet_unetp2.pth', map_location=device))
model = model.to(device)

In [None]:
!pip install optuna
import optuna
from optuna.trial import TrialState

IMG_SIZE = 224
train_transform = A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.3),
        A.Rotate(limit=15, p=0.2, border_mode=cv2.BORDER_CONSTANT),
        A.GridElasticDeform(
       num_grid_xy=(7, 7),
           magnitude=3,
         interpolation=cv2.INTER_NEAREST,
                mask_interpolation=cv2.INTER_NEAREST,
           p=0.2
       )  ,
        A.ToTensorV2()
    ])


val_transform = A.Compose([
        A.Resize(IMG_SIZE, IMG_SIZE, interpolation=cv2.INTER_NEAREST),
        A.ToTensorV2()
    ])

train_dataset = ClassificationDataset(
    catheter_predictions=catheter_tensor_original[train_mask],
    atria_predictions=atria_tensor_original[train_mask],
    labels=[labels[i] for i in np.where(train_mask)[0]],
    # original_images=original_images[train_mask] if original_images is not None else None,
    ids=train_ids,
    transform=train_transform,

    # normalize=True
)

test_dataset = ClassificationDataset(
    catheter_predictions=catheter_tensor_original[test_mask],
    atria_predictions=atria_tensor_original[test_mask],
    labels=[labels[i] for i in np.where(test_mask)[0]],
    # original_images=original_images[test_mask] if original_images is not None else None,
    ids=test_ids,
    transform=val_transform,
    # normalize=True
)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

def objective(trial):
    """Optuna objective: train ConvNeXt-base with sampled settings and return
    the best validation accuracy.

    Samples ``dropout_prob``, ``weight_decay`` and ``lr`` from ``trial`` and
    trains for ``trial_epochs`` epochs. Training uses the notebook-level
    ``train_dataloader``, ``val_dataloader``, ``criterion``, ``device``,
    ``train_step`` and ``test_step``. Validation accuracy is reported every
    epoch so the study's pruner can stop unpromising trials.

    NOTE(review): ``val_dataloader`` is built from the hold-out *test* split,
    so the tuned accuracy is an optimistic estimate.

    Returns:
        Highest validation accuracy seen during the trial. Higher is better,
        so the study should use ``direction='maximize'``.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
    """
    trial_epochs = 50
    warmup_epochs = 5  # Number of epochs for warmup

    # 1. Suggest hyperparameters
    params = {
        'dropout_prob': trial.suggest_float('dropout_prob', 0.1, 0.5),
        'weight_decay': trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True),
        'lr': trial.suggest_float('lr', 1e-5, 1e-3, log=True),
    }

    # 2. Initialize model with suggested dropout
    model = call_model(model_name='convnext_base', device=device, fine_tune='last_two', drop_out_prob=params['dropout_prob'])
    model.to(device)

    # 3. Create optimizer with weight decay
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=params['lr'],
        weight_decay=params['weight_decay']
    )

    # LR scheduling: linear warmup, then cosine decay over the rest of THIS trial.
    # T_max previously used the global num_epochs (120) while only 50 epochs run,
    # so the cosine phase stopped far from eta_min.
    # Assumes train_step steps the scheduler once per epoch — TODO confirm.
    warmup_scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs)
    cosine_scheduler = CosineAnnealingLR(optimizer, T_max=trial_epochs - warmup_epochs, eta_min=1e-6)

    from torch.optim.lr_scheduler import SequentialLR
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_epochs]
    )

    # 4. Train/validate with the notebook's existing step functions
    best_acc = 0
    for epoch in range(trial_epochs):
        train_step(model, train_dataloader, criterion, optimizer, scheduler, device)
        _, val_acc, *_ = test_step(model, val_dataloader, criterion, device)

        # Report back to Optuna
        trial.report(val_acc, epoch)

        # Early stopping/pruning
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        # Track best validation accuracy
        if best_acc < val_acc:
            best_acc = val_acc

    return best_acc

In [None]:
# Create study with pruning.
# `objective` returns the best validation ACCURACY, so the study must maximise
# it. With 'minimize', best_trial was the least accurate trial and the
# MedianPruner pruned trials that were doing better than the median.
study = optuna.create_study(
    direction='maximize',
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=10)
)

# Run optimization (20 trials, no timeout)
study.optimize(objective, n_trials=20)

# Best found parameters
print("Best trial:")
trial = study.best_trial
print(f"  Dropout: {trial.params['dropout_prob']:.4f}")
print(f"  Weight decay: {trial.params['weight_decay']:.2e}")
print(f"  Learning rate: {trial.params['lr']:.2e}")

In [None]:
model.eval()

In [None]:
# Predict one hold-out test sample and show it with a green (correct) / red (wrong) title.
ind = 10
img, lab, id = test_dataset[ind]

# Add a batch dimension. NOTE: `input` (and `id`) shadow builtins; the
# attribution cells below reuse `input`, so the name is kept.
input = img.unsqueeze(0)

# Ensure the input tensor is on the same device as the model
input = input.to(device)

model.to(device)

with torch.no_grad():

  # `logit` ends up holding the sigmoid probability of the positive class.
  logit = model(input)
  logit = torch.sigmoid(logit)
  logit = logit.item()
print(logit)

predicted_lab = 1 if logit > 0.5 else 0

title_color = "green" if predicted_lab == lab else "red"

# Confidence in the predicted class.
prob = logit if predicted_lab == 1 else 1 - logit


plt.imshow(img.permute(1, 2, 0))
plt.title(f"actual: {int(lab)}, predicted: {int(logit > 0.5)}, prop: {prob:.4f}", color=title_color)
plt.axis('off')
plt.show()

# Enable gradients w.r.t. the input, presumably for the attribution cells below.
input.requires_grad = True

In [None]:
!pip install captum --quiet

In [None]:
from captum.attr import IntegratedGradients
from captum.attr import Saliency
from captum.attr import DeepLift
from captum.attr import NoiseTunnel
from captum.attr import visualization as viz
import torch
import numpy as np


# move model to cpu, to use local ram (NOTE: rebinds the global `device` to CPU for later cells)
device = torch.device("cpu")
model.to(device)

# Ensure the input tensor is on the same device as the model
input = input.to(device)


saliency = Saliency(model)
# No `target`: the classifier has a single output logit, so there is nothing to select.
grads = saliency.attribute(input)
# Drop the batch dim and reorder CHW -> HWC for Captum's visualisation helpers.
grads = np.transpose(grads.squeeze().cpu().detach().numpy(), (1, 2, 0))

In [None]:
# move model and input to cpu
# device = torch.device("cpu")
# model.to(device)
# input = input.to(device)

In [None]:
# Convert the img tensor to a NumPy array before visualizing
img_np = img.squeeze().cpu().numpy()



_ = viz.visualize_image_attr(grads, img_np, method="heat_map", sign="absolute_value",
                          show_colorbar=True, title="Overlayed Gradient Magnitudes")

In [None]:
# file ipython-input-139-8aa18f246882
def attribute_image_features(algorithm, input, **kwargs):
    """Run a Captum attribution ``algorithm`` on ``input`` and return the result.

    Gradients accumulated on the notebook-level ``model`` are cleared first.
    No ``target`` is passed because the classifier has a single output logit
    (binary classification), so there is only one output to attribute. Extra
    keyword arguments (e.g. ``baselines``) are forwarded unchanged to
    ``algorithm.attribute``.
    """
    model.zero_grad()
    return algorithm.attribute(input, **kwargs)

In [None]:
# file ipython-input-141-8aa18f246882
dl = DeepLift(model)
# DeepLift against an all-zero baseline; no target (single-logit model).
attr_dl = attribute_image_features(dl, input, baselines=input * 0)
# Drop the batch dim and reorder CHW -> HWC for visualisation.
attr_dl = np.transpose(attr_dl.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

In [None]:

# NOTE(review): the blended overlay draws img_np underneath; it needs to be
# channels-last (HWC) like attr_dl — check how img_np was built.
_ = viz.visualize_image_attr(attr_dl, img_np, method="blended_heat_map",sign="all",show_colorbar=True,
                          title="Overlayed DeepLift")

In [None]:
# Counts of each 'tip' value in met_df (defined elsewhere in the notebook).
met_df['tip'].value_counts()

In [None]:
# Relative frequencies of each 'tip' value in the predicted-feature table.
predicted_features_base_df['tip'].value_counts(normalize=True)