In [None]:
# !git clone https://github.com/Soham-Gaonkar/BubbleSegmentation.git

In [None]:
# STEP 0: Imports
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import GroupShuffleSplit
import numpy as np
import matplotlib.pyplot as plt


# STEP 1: Custom Dataset for Bubble Images
import random
from torchvision.transforms import functional as TF

import random
from torchvision.transforms import functional as TF
from torchvision import transforms

class BubbleDataset(Dataset):
    """Paired grayscale image / binary mask dataset for bubble segmentation.

    Every sample is center-cropped, resized to 256x256 and, when ``augment``
    is set, randomly augmented. Geometric augmentations are applied
    identically to the image and its label; photometric ones only to the
    image.

    Args:
        image_paths: Sequence of image file paths.
        label_paths: Sequence of label file paths, index-aligned with
            ``image_paths``.
        augment: Enable random flips, small rotations, brightness/contrast
            jitter, mild zoom and Gaussian noise.
    """

    def __init__(self, image_paths, label_paths, augment=False):
        self.image_paths = image_paths
        self.label_paths = label_paths
        self.augment = augment

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def center_crop(self, img, target_width=750, target_height=554):
        """Crop a ``target_width`` x ``target_height`` window centered in ``img``.

        ``img`` only needs a ``size`` attribute ``(width, height)`` and a
        ``crop(box)`` method (as PIL images have).
        """
        w, h = img.size
        left = (w - target_width) // 2
        top = (h - target_height) // 2
        right = left + target_width
        bottom = top + target_height
        return img.crop((left, top, right, bottom))

    def __getitem__(self, idx):
        """Load, preprocess and (optionally) augment sample ``idx``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is a float tensor of
            shape (3, 256, 256) in [0, 1] (the gray channel repeated three
            times) and ``label`` is an int64 tensor holding 1 where the
            label pixel value is > 127 and 0 elsewhere.
        """
        # Image is forced to single-channel; the label keeps its file mode.
        # NOTE(review): a multi-channel label file would survive .squeeze()
        # below as a 3-D tensor -- confirm labels are single-channel.
        image = Image.open(self.image_paths[idx]).convert('L')
        label = Image.open(self.label_paths[idx])

        # Center crop (height falls back to center_crop's default of 554).
        image = self.center_crop(image, target_width=750)
        label = self.center_crop(label, target_width=750)

        # Labels always use nearest-neighbour so no intermediate values are
        # invented; the enum is torchvision's documented argument type.
        nearest = transforms.InterpolationMode.NEAREST

        # Resize to model input
        image = TF.resize(image, (256, 256))
        label = TF.resize(label, (256, 256), interpolation=nearest)

        if self.augment:
            if random.random() > 0.5:
                image = TF.hflip(image)
                label = TF.hflip(label)
            if random.random() > 0.5:
                image = TF.vflip(image)
                label = TF.vflip(label)
            if random.random() > 0.5:
                angle = random.uniform(-5, 5)
                image = TF.rotate(image, angle)
                label = TF.rotate(label, angle, interpolation=nearest)

            # Brightness / contrast jitter (image only)
            if random.random() > 0.5:
                image = TF.adjust_brightness(image, random.uniform(0.9, 1.1))
            if random.random() > 0.5:
                image = TF.adjust_contrast(image, random.uniform(0.9, 1.1))

            # Random square crop of 90-100% area, resized back (mild zoom)
            if random.random() > 0.5:
                i, j, h, w = transforms.RandomResizedCrop.get_params(
                    image, scale=(0.9, 1.0), ratio=(1.0, 1.0))
                image = TF.resized_crop(image, i, j, h, w, (256, 256))
                label = TF.resized_crop(label, i, j, h, w, (256, 256), interpolation=nearest)

            # Additive Gaussian noise (std 0.01), clamped to the valid range
            if random.random() > 0.5:
                img_tensor = TF.to_tensor(image)
                noise = torch.randn_like(img_tensor) * 0.01
                img_tensor = (img_tensor + noise).clamp(0, 1)
                image = TF.to_pil_image(img_tensor)

        image = TF.to_tensor(image)
        # Repeat the gray channel as a 3-channel view (no copy).
        image = image.expand(3, -1, -1)
        label = TF.pil_to_tensor(label).squeeze().long()
        label = (label > 127).long()

        return image, label


# STEP 2: Parsing and Splitting Data Based on Dataset Number
# Label paths are derived from image paths by swapping every 'US' for
# 'Label' and the '.jpg' extension for '.png'.
all_images = sorted(glob.glob('../Data/US_2/*.jpg'))
all_labels = [
    path.replace('US', 'Label').replace('.jpg', '.png')
    for path in all_images
]

# Sanity check on the first pair: print the paths and the raw sizes.
first_image_path, first_label_path = all_images[0], all_labels[0]
print("Sample image:", first_image_path)
print("Sample label:", first_label_path)

img = Image.open(first_image_path)
lbl = Image.open(first_label_path)

print("Image size:", img.size)
print("Label size:", lbl.size)


def extract_dataset_number(path):
    """Return the integer between the last '_' in ``path`` and the next '.'.

    Example: ``'t3US100_738983_1.jpg'`` -> ``1``.
    """
    tail = path.rsplit('_', 1)[-1]
    number_text = tail.partition('.')[0]
    return int(number_text)

def extract_pulse_number(path):
    """Return the pulse number encoded as ``US<digits>`` in the file name.

    Only the base name is inspected, so a directory component that itself
    contains ``US`` (e.g. ``../Data/US_2/``) does not interfere with parsing.

    Raises:
        ValueError: If the file name contains no ``US<digits>`` token.
    """
    import os
    import re

    match = re.search(r'US(\d+)', os.path.basename(path))
    if match is None:
        raise ValueError(f"No 'US<digits>' pulse token in file name: {path!r}")
    return int(match.group(1))

# One group id per image: all frames of a dataset number land on the same
# side of the split.
groups = [extract_dataset_number(p) for p in all_images]

# Fixed seed so the train/validation partition is reproducible across runs.
splitter = GroupShuffleSplit(n_splits=1, test_size=1/6, random_state=42)
print("Number of unique groups:", len(np.unique(groups)))
train_idx, val_idx = next(splitter.split(all_images, groups=groups))

# train_idx = np.concatenate([train_idx, val_idx])


train_images = [all_images[i] for i in train_idx]
train_labels = [all_labels[i] for i in train_idx]
val_images = [all_images[i] for i in val_idx]
val_labels = [all_labels[i] for i in val_idx]


# Sorted list first: np.array(set(...)) would only wrap the set object in a
# 0-d object array instead of producing an array of dataset numbers.
print(np.array(sorted({extract_dataset_number(p) for p in train_images})))
print(np.array(sorted({extract_dataset_number(p) for p in val_images})))

print("Sample mapping:")
for img, lbl in zip(train_images[:3], train_labels[:3]):
    print(f"{img}  -->  {lbl}")


# STEP 3: Transforms
# Standalone torchvision pipelines for images and labels.
# NOTE(review): BubbleDataset performs its own preprocessing and nothing in
# the visible code references these two pipelines -- confirm they are needed.
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

label_transform = transforms.Compose([
    # Nearest-neighbour keeps label values discrete; InterpolationMode is
    # torchvision's documented argument type for ``interpolation``.
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
    transforms.Lambda(lambda x: x.squeeze().long())
])


# Augmentation is enabled for training only.
train_dataset = BubbleDataset(train_images, train_labels, augment=True)
val_dataset = BubbleDataset(val_images, val_labels, augment=False)

# Dataset numbers present on each side of the split.
train_datasets = sorted({extract_dataset_number(p) for p in train_images})
val_datasets = sorted({extract_dataset_number(p) for p in val_images})

print("Train dataset numbers:", train_datasets)
print("Validation dataset numbers:", val_datasets)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Pull one training batch and report what the loader yields.
img_batch, lbl_batch = next(iter(train_loader))
print("Image shape:", img_batch.shape)      # expected [B, 3, 256, 256]
print("Label shape:", lbl_batch.shape)      # expected [B, 256, 256]
print("Label dtype:", lbl_batch.dtype)      # expected torch.int64
print("Label values:", lbl_batch.unique())  # expected tensor([0, 1])

In [None]:
def show_images(img_batch, lbl_batch, num_samples=5):
    """Plot up to ``num_samples`` image/label pairs from a batch.

    Images are drawn on the top row (first channel only) and the matching
    labels on the bottom row, both with a gray colormap.

    Args:
        img_batch: Batch of images indexed as ``img_batch[i][0]``.
        lbl_batch: Batch of labels indexed as ``lbl_batch[i]``.
        num_samples: Maximum number of pairs to draw; clamped to the batch
            size so a short batch does not raise IndexError.
    """
    num_samples = min(num_samples, len(img_batch), len(lbl_batch))
    if num_samples <= 0:
        return

    plt.figure(figsize=(num_samples * 3, 6))

    for i in range(num_samples):
        img = img_batch[i][0].cpu().numpy()  # first channel; no permute needed
        lbl = lbl_batch[i].cpu().numpy()

        # Image
        plt.subplot(2, num_samples, i + 1)
        plt.imshow(img, cmap='gray')
        plt.axis('off')
        plt.title("Image")

        # Label
        plt.subplot(2, num_samples, i + 1 + num_samples)
        plt.imshow(lbl, cmap='gray')
        plt.axis('off')
        plt.title("Label")

    plt.tight_layout()
    plt.show()

# Draw a fresh batch from the (shuffled) training loader and display it.
img_batch, lbl_batch = next(iter(train_loader))
show_images(img_batch, lbl_batch)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os, glob, re
import numpy as np
from PIL import Image, ImageEnhance
from tqdm import tqdm
import random

# ---- Loss Functions ----
class AsymmetricTverskyLoss(nn.Module):
    """Per-sample Tversky loss on two-class segmentation logits.

    False negatives are weighted by ``delta`` and false positives by
    ``1 - delta``, so ``delta > 0.5`` penalizes false negatives more.

    Args:
        delta: False-negative weight in the Tversky index.
        smooth: Constant added to numerator and denominator.
        class_weights: Optional pair ``[background_weight,
            foreground_weight]``. When given, the loss is the weighted sum
            of the background and foreground terms; otherwise only the
            foreground term is used.
    """

    def __init__(self, delta=0.7, smooth=1e-6, class_weights=None):
        super().__init__()
        self.delta = delta
        self.smooth = smooth
        self.class_weights = class_weights

    def _tversky_index(self, probs, mask):
        """Tversky index per sample for one class (inputs shaped (batch, H, W))."""
        tp = (probs * mask).sum(dim=[1, 2])
        fn = (mask * (1 - probs)).sum(dim=[1, 2])
        fp = ((1 - mask) * probs).sum(dim=[1, 2])
        return (tp + self.smooth) / (tp + self.delta * fn + (1 - self.delta) * fp + self.smooth)

    def forward(self, preds, targets):
        """Return the loss per sample, shape (batch,).

        ``preds`` are logits of shape (batch, num_classes, H, W); channel 0
        is treated as background and channel 1 as foreground. ``targets``
        holds class indices of shape (batch, H, W).
        """
        probs = F.softmax(preds, dim=1)

        fg_index = self._tversky_index(probs[:, 1, :, :], (targets == 1).float())
        if self.class_weights is None:
            # No weights: foreground term only.
            return 1 - fg_index

        bg_index = self._tversky_index(probs[:, 0, :, :], (targets == 0).float())
        return (1 - bg_index) * self.class_weights[0] + (1 - fg_index) * self.class_weights[1]

class FocalLoss(nn.Module):
    """Focal loss on class logits, computed per element.

    Args:
        gamma: Exponent of the ``(1 - pt)`` modulating factor.
        reduction: ``'mean'`` or ``'sum'`` reduce over every element; any
            other value (default ``'none'``) returns the spatial mean per
            sample, shape (batch,).
    """

    def __init__(self, gamma=2, reduction='none'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        # Element-wise CE is always needed; reduction happens in forward().
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, preds, targets):
        """Return the focal loss of ``preds`` (logits) against ``targets``."""
        element_ce = self.ce(preds, targets)  # (batch, H, W)
        modulating = (1 - torch.exp(-element_ce)) ** self.gamma
        focal = modulating * element_ce

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        # 'none': average over the spatial dims, keep the batch dim.
        return focal.view(focal.shape[0], -1).mean(dim=1)


class AsymmetricFocalTverskyLoss(nn.Module):
    """Weighted sum of the asymmetric Tversky loss and the focal loss.

    Both component losses are evaluated per sample, combined, and then
    averaged over the batch.

    Args:
        tversky_weight: Multiplier of the Tversky term.
        focal_weight: Multiplier of the focal term.
        delta: ``delta`` forwarded to ``AsymmetricTverskyLoss``.

    NOTE(review): ``FocalLoss`` is looked up when an instance is created and
    must accept ``reduction``; make sure the per-sample variant is the
    definition in scope at that point.
    """

    def __init__(self, tversky_weight=0.5, focal_weight=0.5, delta=0.3):
        super().__init__()
        self.tversky_weight = tversky_weight
        self.focal_weight = focal_weight
        self.tversky = AsymmetricTverskyLoss(delta=delta)
        self.focal = FocalLoss(gamma=2, reduction='none')

    def forward(self, preds, targets):
        """Return the batch-mean combined loss as a scalar tensor."""
        per_sample = (
            self.tversky_weight * self.tversky(preds, targets)
            + self.focal_weight * self.focal(preds, targets)
        )
        return per_sample.mean()

In [None]:
import torch.nn as nn

import torch
import torch.nn as nn
import torch.nn.functional as F

class DiceLoss(nn.Module):
    """Soft Dice loss on the foreground (class 1) probability.

    Sums are taken over the whole batch, so the result is one scalar.

    Args:
        smooth: Constant added to numerator and denominator.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, preds, targets):
        """Return ``1 - dice`` for logits ``preds`` and index mask ``targets``."""
        fg_prob = F.softmax(preds, dim=1)[:, 1, :, :]
        fg_mask = (targets == 1).float()
        overlap = (fg_prob * fg_mask).sum()
        numerator = 2. * overlap + self.smooth
        denominator = fg_prob.sum() + fg_mask.sum() + self.smooth
        return 1 - numerator / denominator

class FocalLoss(nn.Module):
    """Focal loss on class logits, modulated per element.

    The ``(1 - pt) ** gamma`` factor is computed from each element's own
    cross-entropy. Applying it to the batch-mean cross-entropy instead would
    give every pixel the same factor and lose the down-weighting of easy
    pixels.

    Args:
        gamma: Exponent of the modulating factor.
        reduction: ``'mean'`` (default) returns the scalar mean over all
            elements, ``'sum'`` the scalar sum; any other value returns the
            spatial mean per sample, shape (batch,). Accepting this argument
            keeps the class usable by callers that request
            ``reduction='none'``.
    """

    def __init__(self, gamma=2, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        # Keep element-wise losses; the reduction is applied in forward().
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, preds, targets):
        """Return the focal loss of ``preds`` (logits) against ``targets``."""
        ce_loss = self.ce(preds, targets)
        pt = torch.exp(-ce_loss)
        focal = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal.view(focal.shape[0], -1).mean(dim=1)


class DiceFocalLoss(nn.Module):
    """Weighted sum of ``DiceLoss`` and ``FocalLoss`` (both with defaults).

    Args:
        dice_weight: Multiplier of the Dice term.
        focal_weight: Multiplier of the focal term.
    """

    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super().__init__()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.dice = DiceLoss()
        self.focal = FocalLoss()

    def forward(self, preds, targets):
        """Return ``dice_weight * dice + focal_weight * focal``."""
        dice_term = self.dice_weight * self.dice(preds, targets)
        focal_term = self.focal_weight * self.focal(preds, targets)
        return dice_term + focal_term


In [None]:
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_resnet101, DeepLabV3_ResNet101_Weights

# Pretrained DeepLabV3-ResNet101 whose last classifier layer is replaced by
# a 1x1 convolution with 2 output channels.
weights = DeepLabV3_ResNet101_Weights.DEFAULT

try:
    model = deeplabv3_resnet101(weights=weights)
    model.classifier[4] = nn.Conv2d(256, 2, kernel_size=1)
    model = model.cuda()
    print("✅ Model successfully moved to CUDA")
except Exception as e:
    print("❌ CUDA error during model setup:", e)
    # Re-raise: carrying on would leave `model` undefined or still on the
    # CPU, and later cells would fail with a less helpful error.
    raise

In [None]:
# import os
# import yaml
# from tqdm import tqdm
# import numpy as np
# import torch
# from datetime import datetime

# # --- IoU ---
# def compute_class_iou(preds, targets, num_classes):
#     ious = []
#     total_intersection = 0
#     total_union = 0

#     for cls in range(num_classes):
#         pred_inds = (preds == cls)
#         target_inds = (targets == cls)
#         intersection = (pred_inds & target_inds).sum().item()
#         union = (pred_inds | target_inds).sum().item()

#         if union == 0:
#             ious.append(float('nan'))  # Undefined IoU
#         else:
#             ious.append(intersection / union)
#             total_intersection += intersection
#             total_union += union

#     mean_iou = np.nanmean(ious)
#     weighted_iou = total_intersection / total_union if total_union != 0 else float('nan')

#     return {
#         'per_class_iou': ious,
#         'mean_iou': mean_iou,
#         'weighted_iou': weighted_iou
#     }

# num_classes =2
# # === Config Params === #
# experiment_name = f"BubbleSeg_lr1e-4_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
# save_dir = os.path.join("checkpoints", experiment_name)
# os.makedirs(save_dir, exist_ok=True)

# params = {
#     'model': str(model.__class__.__name__),
#     'criterion': 'DiceFocalLoss(dice_weight=0.5, focal_weight=0.6))',
#     'optimizer': 'Adam',
#     'lr': 3e-4,
#     'scheduler': 'ReduceLROnPlateau',
#     'num_epochs': 2
# }

# with open(os.path.join(save_dir, "initial_config.yaml"), 'w') as f:
#     yaml.dump(params, f)

# # === Training Setup === #
# criterion = DiceFocalLoss(dice_weight=0.5, focal_weight=0.6)
# optimizer = torch.optim.Adam(model.parameters(), lr= 3e-4)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)

# num_epochs = 5
# best_val_iou = 0

# for epoch in range(num_epochs):
#     model.train()
#     train_loss = 0.0

#     train_loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)
#     for imgs, masks in train_loop:
#         imgs, masks = imgs.cuda(), masks.cuda()

#         optimizer.zero_grad()
#         outputs = model(imgs)['out']
#         loss = criterion(outputs, masks)
#         loss.backward()
#         optimizer.step()

#         train_loss += loss.item()
#         train_loop.set_postfix(loss=loss.item())

#     avg_train_loss = train_loss / len(train_loader)

#     # === Validation === #
#     model.eval()
#     val_loss = 0.0
#     val_iou = 0.0

#     with torch.no_grad():
#         for imgs, masks in tqdm(val_loader, desc="Validating", leave=False):
#             imgs, masks = imgs.cuda(), masks.cuda()
#             outputs = model(imgs)['out']
#             loss = criterion(outputs, masks)
#             val_loss += loss.item()

#             preds = torch.argmax(outputs, dim=1)
#             val_iou += compute_class_iou(preds.cpu(), masks.cpu(), num_classes=num_classes)['mean_iou']

#     avg_val_loss = val_loss / len(val_loader)
#     avg_val_iou = val_iou / len(val_loader)

#     print(f"📊 Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f} | Val Loss = {avg_val_loss:.4f} | Val IoU = {avg_val_iou:.4f}")

#     scheduler.step(avg_val_iou)

#     # === Save per epoch === #
#     torch.save(model.state_dict(), os.path.join(save_dir, f"epoch_{epoch+1}.pth"))

#     # === Save best model === #
#     if avg_val_iou > best_val_iou:
#         best_val_iou = avg_val_iou
#         torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pth"))
#         print("✅ Best model saved!")

# # === Save Final model === #
# torch.save(model.state_dict(), os.path.join(save_dir, "last_model.pth"))
# print("🏁 Final model saved.")

In [None]:
# metric.py
import torch
import numpy as np
from sklearn.metrics import roc_auc_score
from scipy.spatial.distance import directed_hausdorff, cdist
from scipy.ndimage import binary_erosion, distance_transform_edt
import warnings
from config import Config

# --- Extract binary boundary map ---
def get_boundary(mask):
    """Return the pixels of ``mask`` that one binary erosion step removes.

    ``mask`` is XOR-ed with its erosion (outside the array counts as
    background), leaving the one-pixel inner boundary.
    """
    interior = binary_erosion(mask, border_value=0)
    return mask ^ interior

# --- Boundary F1 score (MATLAB-style) ---
def compute_bf_score_single(pred, target, tolerance=2):
    """Boundary F1 score between two boolean masks.

    A boundary pixel counts as matched when the other mask's boundary lies
    within ``tolerance`` pixels (Euclidean distance).

    Returns:
        1.0 when both masks are empty, 0.0 when exactly one is empty or
        nothing matches, otherwise the harmonic mean of boundary precision
        and recall.
    """
    pred_empty = pred.sum() == 0
    target_empty = target.sum() == 0
    if pred_empty and target_empty:
        return 1.0
    if pred_empty or target_empty:
        return 0.0

    pred_edge = get_boundary(pred)
    target_edge = get_boundary(target)

    # Distance from every pixel to the nearest boundary pixel of each mask.
    dist_to_pred_edge = distance_transform_edt(~pred_edge)
    dist_to_target_edge = distance_transform_edt(~target_edge)

    matched_target_px = (target_edge & (dist_to_pred_edge <= tolerance)).sum()
    matched_pred_px = (pred_edge & (dist_to_target_edge <= tolerance)).sum()

    precision = matched_pred_px / (pred_edge.sum() + 1e-8)
    recall = matched_target_px / (target_edge.sum() + 1e-8)

    if precision + recall == 0:
        return 0.0

    return 2 * precision * recall / (precision + recall)

def compute_mean_bf_score(preds, targets, tolerance=2):
    """Mean boundary F1 score over a batch of binary masks.

    Args:
        preds: Array shaped [B, H, W] or [B, 1, H, W].
        targets: Array shaped like ``preds``.
        tolerance: Matching distance forwarded to
            ``compute_bf_score_single``.

    Returns:
        float: Mean score over the batch, or -1.0 for an empty batch.
    """
    scores = []
    for i in range(preds.shape[0]):
        pred = preds[i]
        target = targets[i]
        # Drop a channel axis only when there is one. Indexing [i, 0]
        # unconditionally would select the first *row* of a [B, H, W] batch
        # and score a 1-D slice instead of the mask.
        if pred.ndim == 3:
            pred = pred[0]
        if target.ndim == 3:
            target = target[0]
        score = compute_bf_score_single(
            pred.astype(np.bool_), target.astype(np.bool_), tolerance
        )
        scores.append(score)
    return float(np.mean(scores)) if scores else -1.0

# --- Hausdorff Distance ---
def compute_hausdorff(preds, targets):
    """Boundary distance statistics over a batch of binary masks.

    For each sample where both masks have a non-empty boundary, two values
    are computed between the boundary pixel sets: the symmetric Hausdorff
    distance, and the average of the two mean nearest-neighbour distances.

    Args:
        preds: Array shaped [B, H, W] or [B, 1, H, W].
        targets: Array shaped like ``preds``.

    Returns:
        tuple: ``(mean_distance, hausdorff)``, each averaged over the
        samples that contributed, or -1.0 when none did.
    """
    mean_hd, max_hd = [], []
    for i in range(preds.shape[0]):
        p = preds[i]
        t = targets[i]
        # Drop a channel axis only when there is one. Indexing [i, 0]
        # unconditionally would select the first *row* of a [B, H, W] batch.
        if p.ndim == 3:
            p = p[0]
        if t.ndim == 3:
            t = t[0]
        p = p.astype(np.bool_)
        t = t.astype(np.bool_)

        pred_boundary = get_boundary(p)
        target_boundary = get_boundary(t)

        p_coords = np.argwhere(pred_boundary)
        t_coords = np.argwhere(target_boundary)

        # Distances are undefined when either boundary is empty.
        if p_coords.size == 0 or t_coords.size == 0:
            continue

        try:
            hd1 = directed_hausdorff(p_coords, t_coords)[0]
            hd2 = directed_hausdorff(t_coords, p_coords)[0]
            max_hd.append(max(hd1, hd2))

            dist1 = cdist(p_coords, t_coords).min(axis=1)
            dist2 = cdist(t_coords, p_coords).min(axis=1)
            mean_hd.append((dist1.mean() + dist2.mean()) / 2)
        except Exception as e:
            # Best effort: a failing sample is reported and skipped.
            warnings.warn(f"Hausdorff calculation failed: {e}")

    return (
        float(np.nanmean(mean_hd)) if mean_hd else -1.0,
        float(np.nanmean(max_hd)) if max_hd else -1.0,
    )

# --- Per-class metrics ---
def compute_class_metrics(y_true, y_pred, class_val, epsilon=1e-7):
    """One-vs-rest metrics for ``class_val`` on flat label arrays.

    Returns:
        tuple: ``(accuracy, iou, dice, precision, recall, f1_score, TP, FP,
        FN)``. ``accuracy`` is ``TP / (TP + FN)``, i.e. the same ratio as
        ``recall``. ``epsilon`` is added to every denominator.
    """
    predicted = (y_pred == class_val)
    actual = (y_true == class_val)

    TP = np.sum(predicted & actual)
    FP = np.sum(predicted & ~actual)
    FN = np.sum(~predicted & actual)

    recall = TP / (TP + FN + epsilon)
    precision = TP / (TP + FP + epsilon)
    accuracy = recall  # per-class accuracy uses the same ratio

    iou = TP / (TP + FP + FN + epsilon)
    dice = (2 * TP) / (2 * TP + FP + FN + epsilon)
    f1_score = (2 * precision * recall) / (precision + recall + epsilon)

    return accuracy, iou, dice, precision, recall, f1_score, TP, FP, FN

# --- Main Metric Function ---
def calculate_all_metrics(predictions, targets, threshold=0.5):
    """Compute segmentation metrics for a batch of binary predictions.

    Args:
        predictions: Tensor of foreground logits shaped [B, H, W] or
            [B, 1, H, W]; a sigmoid is applied before thresholding.
        targets: Tensor of 0/1 labels with the same shape.
        threshold: Probability above which a pixel is predicted foreground.

    Returns:
        dict: Global/mean accuracy, mean and weighted IoU, mean boundary F1,
        Dice, AUROC, precision, recall, and mean/max Hausdorff values.
        AUROC falls back to 0.5 when only one class is present in
        ``targets`` or when its computation fails.
    """
    # Ensure shape is [B, H, W]
    if predictions.ndim == 4 and predictions.shape[1] == 1:
        predictions = predictions[:, 0]
    if targets.ndim == 4 and targets.shape[1] == 1:
        targets = targets[:, 0]

    assert predictions.shape == targets.shape and predictions.ndim == 3, "Expecting [B, H, W] for both predictions and targets"

    probs = torch.sigmoid(predictions)
    preds_bin = (probs > threshold).int()

    preds_np = preds_bin.cpu().numpy()
    targets_np = targets.cpu().numpy()

    y_true_flat = targets_np.flatten()
    y_pred_flat = preds_np.flatten()

    global_acc = np.sum(y_pred_flat == y_true_flat) / len(y_true_flat)

    class_ids = [0, 1]
    class_accuracies, class_ious, class_dices = [], [], []
    class_precisions, class_recalls = [], []
    weighted_ious = []

    for cls in class_ids:
        acc, iou, dice, prec, rec, _f1, _tp, _fp, _fn = compute_class_metrics(y_true_flat, y_pred_flat, cls)
        class_accuracies.append(acc)
        class_ious.append(iou)
        class_dices.append(dice)
        class_precisions.append(prec)
        class_recalls.append(rec)
        # Weight each class IoU by its ground-truth pixel count.
        class_px_count = np.sum(y_true_flat == cls)
        weighted_ious.append(iou * class_px_count)

    mean_acc = np.mean(class_accuracies)
    mean_iou = np.mean(class_ious)
    weighted_iou = np.sum(weighted_ious) / (len(y_true_flat) + 1e-7)
    mean_dice = np.mean(class_dices)
    mean_precision = np.mean(class_precisions)
    mean_recall = np.mean(class_recalls)

    # The boundary helpers index their inputs as [i, 0]; add an explicit
    # channel axis so that selects the full mask, not its first row.
    preds_4d = preds_np[:, np.newaxis]
    targets_4d = targets_np[:, np.newaxis]
    mean_bf_score = compute_mean_bf_score(preds_4d, targets_4d)
    mean_hd, max_hd = compute_hausdorff(preds_4d, targets_4d)

    try:
        probs_flat = probs.detach().cpu().numpy().flatten()
        if len(np.unique(y_true_flat)) > 1:
            auroc = roc_auc_score(y_true_flat, probs_flat)
        else:
            # AUROC is undefined with a single class present.
            auroc = 0.5
    except Exception as e:
        warnings.warn(f"AUROC computation failed: {e}")
        auroc = 0.5

    return {
        "GlobalAccuracy": global_acc,
        "MeanAccuracy": mean_acc,
        "MeanIoU": mean_iou,
        "WeightedIoU": weighted_iou,
        "MeanBFScore": mean_bf_score,
        "Dice (F1 Score)": mean_dice,
        "AUROC": auroc,
        "Precision": mean_precision,
        "Recall": mean_recall,
        "MeanHausdorff": mean_hd,
        "MaxHausdorff": max_hd,
    }


In [None]:
# import gdown
# url = 'https://drive.google.com/uc?id=FILE_ID_OR_FULL_URL'
# gdown.download(url, 'model.pth', quiet=False)

In [None]:
def recover_original(img, final_width=1024, final_height=256, cropped_width=750, cropped_height=554):
    """Undo the resize and center crop applied to a PIL image.

    The image is resized (nearest-neighbour) to ``cropped_width`` x
    ``cropped_height`` and then given a zero-filled border so that it ends
    up ``final_width`` x ``final_height``, with the content centered.

    NOTE(review): with the defaults, ``final_height`` (256) is smaller than
    ``cropped_height`` (554), so the top/bottom border values are negative
    -- confirm the intended original frame size.
    """
    # Step 1: back to the cropped resolution.
    restored = TF.resize(img, [cropped_height, cropped_width], interpolation=Image.NEAREST)

    # Step 2: split the size difference between the two sides of each axis;
    # the right/bottom side takes the remainder.
    extra_w = final_width - cropped_width
    extra_h = final_height - cropped_height
    left = extra_w // 2
    top = extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)

    return ImageOps.expand(restored, border, fill=0)

In [None]:
from torch.utils.data import Dataset, DataLoader
import glob

# --- Test Setup ---
test_images = sorted(glob.glob('../Data/US_Test_2023April7/*.jpg'))
# Label directory uses the singular 'Label_' prefix, matching the later
# test-set definition in this cell and the 'US' -> 'Label' naming used for
# the training data ('Labels_...' matched a different directory name).
test_labels = sorted(glob.glob('../Data/Label_Test_2023April7/*.png'))
test_dataset = BubbleDataset(test_images, test_labels, augment=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)


import re

def extract_pulse_and_dataset(filename):
    """Parse ``(pulse, dataset)`` from an image file name.

    - pulse: the digits following the first 'US' in the base name
    - dataset: the last '_'-separated field, if it is all digits

    Either value is -1 when it cannot be parsed.
    Example: 't3US100_738983_1.jpg' → pulse=100, dataset=1
    """
    stem = os.path.basename(filename).replace(".jpg", "")

    last_field = stem.split('_')[-1]
    dataset = int(last_field) if last_field.isdigit() else -1

    found = re.search(r'US(\d+)', stem)
    pulse = -1 if found is None else int(found.group(1))

    return pulse, dataset

from sklearn.metrics import confusion_matrix
from scipy.spatial.distance import directed_hausdorff
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F

# STEP 1: Create test dataset
# Images and labels are globbed independently and paired by sorted order.
# NOTE(review): this assumes both directories hold the same number of files
# with matching sort order -- verify, nothing here checks it.
test_images = sorted(glob.glob('../Data/US_Test_2023April7/*.jpg'))
test_labels = sorted(glob.glob('../Data/Label_Test_2023April7/*.png'))


# Rebinds test_dataset/test_loader; no augmentation and no shuffling, so
# batch order follows test_images.
test_dataset = BubbleDataset(test_images, test_labels, augment=False)
test_loader = DataLoader(test_dataset, batch_size= 16, shuffle=False)


import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision.transforms.functional as TF
from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
import torch
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
import random
from time import perf_counter



def compute_area(mask):
    """Return the number of elements in ``mask`` equal to 1, as a Python scalar."""
    foreground = (mask == 1)
    return foreground.sum().item()


# Function to compute per-class IoU
def compute_class_iou(preds, targets, num_classes):
    """IoU per class plus two aggregates for index-valued masks.

    Returns:
        dict with
        - 'per_class_iou': list of IoUs, NaN for a class absent from both
          ``preds`` and ``targets``;
        - 'mean_iou': NaN-ignoring mean of the per-class values;
        - 'weighted_iou': summed intersections over summed unions of the
          classes that are present (NaN if none is).
    """
    per_class = []
    pooled_intersection = 0
    pooled_union = 0

    for cls in range(num_classes):
        pred_mask = (preds == cls)
        target_mask = (targets == cls)
        overlap = (pred_mask & target_mask).sum().item()
        combined = (pred_mask | target_mask).sum().item()

        if combined == 0:
            # Class absent everywhere: IoU is undefined.
            per_class.append(float('nan'))
            continue

        per_class.append(overlap / combined)
        pooled_intersection += overlap
        pooled_union += combined

    if pooled_union != 0:
        pooled_iou = pooled_intersection / pooled_union
    else:
        pooled_iou = float('nan')

    return {
        'per_class_iou': per_class,
        'mean_iou': np.nanmean(per_class),
        'weighted_iou': pooled_iou
    }



# === Load best model === #
model.load_state_dict(torch.load("best_model_final.pth"))
model.eval()
# checkpoint = torch.load("../code_files/checkpoints/TorchvisionDeepLabV3_DiceFocalLoss_Dice0.5_Tversky0.5_Focal0.6_Epochs5_LR0.0003/best.pth.tar")
# model.load_state_dict(checkpoint["state_dict"])
# model.eval()


# === Run on test_loader === #
# Evaluation
records = []
all_mean_ious = []
total_intersection = 0
total_union = 0
n_visualize = 5  # Change as needed


all_samples = []
# === Evaluation Loop ===
start_time = perf_counter()
with torch.no_grad():
    for idx, (imgs, masks) in enumerate(tqdm(test_loader, desc="Evaluating on Test")):
        imgs, masks = imgs.cuda(), masks.cuda()

        outputs = model(imgs)['out']
        preds = torch.argmax(outputs, dim=1)
        # Foreground logit margin: sigmoid(z1 - z0) equals the two-class
        # softmax foreground probability. calculate_all_metrics applies the
        # sigmoid and threshold itself, so it must receive this margin
        # rather than hard 0/1 labels (which make AUROC degenerate).
        fg_logits = outputs[:, 1] - outputs[:, 0]

        # preds = recover_original(preds)
        # masks = recover_original(masks)

        if preds.shape != masks.shape:
            preds = F.interpolate(preds.unsqueeze(1).float(), size=masks.shape[-2:], mode='nearest').squeeze(1).long()
            fg_logits = F.interpolate(fg_logits.unsqueeze(1), size=masks.shape[-2:], mode='nearest').squeeze(1)

        # Map predictions and masks back to the pre-crop frame before
        # measuring areas; to_tensor rescales 255 -> 1.0, hence .long() -> 1.
        recovered_preds = torch.stack([
            TF.to_tensor(recover_original(TF.to_pil_image(pred.cpu().byte() * 255))).squeeze(0).long()
            for pred in preds
        ])

        recovered_masks = torch.stack([
            TF.to_tensor(recover_original(TF.to_pil_image(mask.cpu().byte() * 255))).squeeze(0).long()
            for mask in masks
        ])

        # Compute IoU metrics
        iou_metrics = compute_class_iou(recovered_preds.cpu().numpy(), recovered_masks.cpu().numpy(), num_classes=2)
        all_mean_ious.append(iou_metrics['mean_iou'])

        # NOTE(review): despite the names, these accumulate the per-batch
        # pooled IoU weighted by the predicted foreground pixel count, not
        # raw intersection/union counts -- confirm this is the intended
        # "weighted IoU".
        total_intersection += iou_metrics['weighted_iou'] * (recovered_preds == 1).sum().item()
        total_union += (recovered_preds == 1).sum().item()

        for i in range(imgs.size(0)):
            # The loader is not shuffled, so batch position maps to file order.
            image_name = os.path.basename(test_images[idx * test_loader.batch_size + i])
            pulse, dataset = extract_pulse_and_dataset(image_name)

            gt_area = compute_area(recovered_masks[i].cpu())
            pred_area = compute_area(recovered_preds[i].cpu())

            # === Main Metric Computation ===
            pred = fg_logits[i].unsqueeze(0).cpu()
            target = masks[i].unsqueeze(0).cpu()
            metrics = calculate_all_metrics(pred, target)
            all_samples.append((pred, target))

            # Store everything
            records.append({
                'image': image_name,
                'pulse': pulse,
                'dataset': dataset,
                'gt_area_px': gt_area,
                'pred_area_px': pred_area,
                **metrics
            })

end_time = perf_counter()
# Wall-clock time of the whole loop, metric computation included.
total_inference_time = end_time - start_time
print(f"Total Inference Time: {total_inference_time:.2f} seconds")

final_mean_iou = np.mean(all_mean_ious)
final_weighted_iou = total_intersection / total_union if total_union > 0 else float('nan')

print(f"Mean IoU (unweighted)   : {final_mean_iou:.4f}")
print(f"Weighted IoU (area)     : {final_weighted_iou:.4f}")



# --- Aggregate Metrics ---
# Both stacks are [N, H, W]: each stored sample carries a leading axis of 1.
all_preds = torch.stack([s[0].squeeze(0) for s in all_samples])
all_targets = torch.stack([s[1].squeeze(0) for s in all_samples])

final_metrics = calculate_all_metrics(all_preds, all_targets)
# print(final_metrics)
print("\n--- Average Test Metrics ---")
if final_metrics:
    for k, v in sorted(final_metrics.items()):
        print(f"{k}: {v:.4f}" if isinstance(v, (float, np.number)) and pd.notna(v) else f"{k}: {v}")

# === Convert to DataFrame ===
df = pd.DataFrame(records)

# Convert to mm² using correct scaling
pixel_area_mm2 = 0.0025
df['gt_area_mm2'] = df['gt_area_px'] * pixel_area_mm2
df['pred_area_mm2'] = df['pred_area_px'] * pixel_area_mm2

# === Show first few entries for verification ===
print("✅ Evaluation complete. Sample entries:")
display(df.head())

# === Save to CSV ===
df.to_csv("gt_pred_areas_per_image.csv", index=False)
print("✅ Saved: gt_pred_areas_per_image.csv")

# === Optional: Group by (pulse, dataset) for plotting later ===
grouped_df = df.groupby(['pulse', 'dataset'])[['gt_area_mm2', 'pred_area_mm2']].mean().reset_index()
grouped_df.to_csv("area_grouped_by_pulse_dataset.csv", index=False)
print("✅ Saved: area_grouped_by_pulse_dataset.csv")



display(grouped_df)


In [None]:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# The mm² columns are expected to exist on `df` already (added by the
# evaluation cell); the lines below are a disabled alternative scaling.
# pixel_area_mm2 = 0.005
# df['gt_area_mm2'] = df['gt_area_px'] * pixel_area_mm2
# df['pred_area_mm2'] = df['pred_area_px'] * pixel_area_mm2

# Per-pulse mean and standard deviation of both areas, across datasets
pulse_agg = df.groupby('pulse')[['gt_area_mm2', 'pred_area_mm2']].agg(['mean', 'std']).reset_index()

# Flatten the two-level column index produced by agg()
pulse_agg.columns = ['pulse',
                     'gt_mean', 'gt_std',
                     'pred_mean', 'pred_std']
# display(pulse_agg)

# Rescale the pulse index for the x axis.
# NOTE(review): presumably each index step stands for 20 pulses -- confirm.
pulse_agg['pulse'] = pulse_agg['pulse'].astype(int)*20
import matplotlib.pyplot as plt
import pandas as pd

# pulse_agg is the DataFrame built above.

# Plot setup
plt.figure(figsize=(8,8))


# Predicted area: mean with +/- one standard deviation error bars
plt.errorbar(
    pulse_agg['pulse'], pulse_agg['pred_mean'], yerr=pulse_agg['pred_std'],
    fmt='-o', label='Prediction', color='#439cce', capsize=3, alpha=0.7
)

# Ground-truth area, drawn the same way
plt.errorbar(
    pulse_agg['pulse'], pulse_agg['gt_mean'], yerr=pulse_agg['gt_std'],
    fmt='-o', label='Ground Truth', color='#ba4000', capsize=3, alpha=0.7
)



# Labels and legend
plt.xlabel('Number of Pulses')
plt.ylabel('Ablation Area (mm²)')
plt.title('Ablation Area vs Number of Pulses')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Per-pulse mean/std of the ground-truth and predicted areas. The two-level
# column header produced by agg() is flattened before 'pulse' is moved back
# from the index into a regular column.
pulse_agg = df.groupby('pulse')[['gt_area_mm2', 'pred_area_mm2']].agg(['mean', 'std'])
pulse_agg.columns = ['gt_mean', 'gt_std', 'pred_mean', 'pred_std']
pulse_agg = pulse_agg.reset_index()
pulse_agg['pulse'] = pulse_agg['pulse'].astype(int) * 20

# Figure with inward-pointing ticks and a heavier frame
plt.figure(figsize=(6, 5))
ax = plt.gca()
ax.tick_params(direction='in')
for frame_edge in ax.spines.values():
    frame_edge.set_linewidth(1.5)  # or 2.0 for thicker

x_pulses = pulse_agg['pulse']

# One entry per curve, drawn in this order: CNN prediction (blue, default
# marker) first, then ground truth (reddish, triangle marker).
curve_specs = (
    ('pred_mean', 'pred_std', '#0072ba', {'label': 'CNN', 's': 25}),
    ('gt_mean', 'gt_std', '#ba4000', {'label': 'Truth', 's': 40, 'marker': '^'}),
)
for mean_col, std_col, curve_color, scatter_kwargs in curve_specs:
    band_low = pulse_agg[mean_col] - pulse_agg[std_col]
    band_high = pulse_agg[mean_col] + pulse_agg[std_col]

    # Mean as markers, mean ± std as a translucent band ...
    plt.scatter(x_pulses, pulse_agg[mean_col], color=curve_color, **scatter_kwargs)
    plt.fill_between(x_pulses, band_low, band_high, color=curve_color, alpha=0.4)
    # ... and a thin black outline along the lower, then the upper band edge.
    for band_edge in (band_low, band_high):
        plt.plot(x_pulses, band_edge, color='black', linewidth=0.3)

# Axis labels (no title or grid on this figure)
plt.xlabel('Number of Pulses', fontsize=20, fontweight='bold')
plt.ylabel('Ablation Area (mm²)', fontsize=20, fontweight='bold')

# Fixed axis ranges and tick label size
plt.xlim(0, 2000)
plt.ylim(0, 120)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

plt.legend()
plt.tight_layout()
plt.show()


In [None]:
def plot_metrics_vs_pulses(metrics_csv_path, save_dir, experiment_name,
                           tick_label_range=(0, 2000), num_ticks=5):
    """Plot mean ± std of accuracy and Dice against the pulse column.

    Reads a metrics CSV, groups the rows by the 'pulses' column (a column
    named 'pulse' is accepted and renamed), and draws one panel per metric
    ("GlobalAccuracy" and "Dice (F1 Score)"). Metric values are multiplied
    by 100 before plotting and the y axis is fixed to 0-100. The figure is
    saved as ``<save_dir>/<experiment_name>_metrics_vs_pulses.png`` and then
    shown.

    Args:
        metrics_csv_path: Path of the CSV file to read.
        save_dir: Directory for the PNG; created if it does not exist.
        experiment_name: Used in the figure title and the output file name.
        tick_label_range: ``(first, last)`` values of the x tick labels.
            Defaults to ``(0, 2000)``.
        num_ticks: Number of evenly spaced x ticks. Defaults to 5.

    Returns:
        None.

    Note:
        The x tick *positions* are spread evenly over the observed minimum
        and maximum of 'pulses', while the tick *labels* are evenly spaced
        integers taken from ``tick_label_range``; the labels are not read
        from the data.
        NOTE(review): presumably the CSV stores a pulse index and 0-2000 is
        the real pulse-count range it stands for — confirm.
    """
    import os

    import matplotlib.pyplot as plt
    # Imported locally like the other dependencies so the function does not
    # rely on a global `np` being defined by an earlier cell.
    import numpy as np
    import pandas as pd

    # Load and preprocess
    metrics_df = pd.read_csv(metrics_csv_path)
    # Normalise the column name; skip the rename when 'pulses' already exists
    # so that no duplicate 'pulses' columns are created.
    if 'pulse' in metrics_df.columns and 'pulses' not in metrics_df.columns:
        metrics_df = metrics_df.rename(columns={'pulse': 'pulses'})

    # CSV column name -> y-axis label of its panel
    metrics_to_plot = {
        "GlobalAccuracy": "Predictive Accuracy (%)",
        "Dice (F1 Score)": "Dice Similarity Coefficient (%)"
    }

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes = axes.flatten()

    label_start, label_stop = tick_label_range

    for ax, (metric, ylabel) in zip(axes, metrics_to_plot.items()):
        # A missing metric only produces a warning; its panel stays empty.
        if metric not in metrics_df.columns:
            print(f"Warning: Metric '{metric}' not found in DataFrame columns.")
            continue

        # Mean and std per pulse value, ordered along the x axis. pandas
        # returns NaN as std for a pulse value with a single row, which
        # leaves a gap in the band at that position.
        grouped = (
            metrics_df.groupby('pulses')[metric]
            .agg(['mean', 'std'])
            .reset_index()
            .sort_values(by='pulses')
        )
        # Scale to match the "(%)" axis labels.
        grouped['mean'] *= 100
        grouped['std'] *= 100
        lower = grouped['mean'] - grouped['std']
        upper = grouped['mean'] + grouped['std']

        # === Styling Consistency ===
        ax.tick_params(direction='in')
        for spine in ax.spines.values():
            spine.set_linewidth(1.5)

        # === Plot scatter + fill_between, with a thin outline on the band ===
        ax.scatter(grouped['pulses'], grouped['mean'], color='#0071bd', s=25,
                   alpha=0.9, edgecolors='black', linewidths=0.3)
        ax.fill_between(grouped['pulses'], lower, upper, color='#0071bd', alpha=0.4)
        ax.plot(grouped['pulses'], lower, color='black', linewidth=0.3)
        ax.plot(grouped['pulses'], upper, color='black', linewidth=0.3)

        ax.set_xlabel('Number of Pulses', fontsize=20, fontweight='bold')
        ax.set_ylabel(ylabel, fontsize=20, fontweight='bold')

        # Ticks sit at evenly spaced positions across the observed data
        # range but are labelled with values from tick_label_range.
        actual_min = grouped['pulses'].min()
        actual_max = grouped['pulses'].max()
        tick_positions = np.linspace(actual_min, actual_max, num_ticks)
        tick_labels = np.linspace(label_start, label_stop, num_ticks).astype(int)

        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels)

        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_ylim(0, 100)

    # Super Title, placed slightly above the panels (y=1.02)
    plt.tight_layout()
    plt.suptitle(f"{experiment_name} Metrics vs. Number of Pulses", fontsize=16, fontweight='bold', y=1.02)

    # Save; bbox_inches='tight' sizes the saved image to the drawn artists.
    os.makedirs(save_dir, exist_ok=True)
    plot_path = os.path.join(save_dir, f"{experiment_name}_metrics_vs_pulses.png")
    plt.savefig(plot_path, dpi=200, bbox_inches='tight')
    print(f"✅ Metrics plot saved to {plot_path}")

    plt.show()

# Draw the metrics-vs-pulses figure for the "BubbleSeg" experiment from
# test_metrics_per_image.csv and save it under the "this_studio" directory.
plot_metrics_vs_pulses(
    "test_metrics_per_image.csv",
    "this_studio",
    "BubbleSeg",
)