In [None]:
!pip install gdown segmentation-models-pytorch torch torchvision opencv-python scikit-learn

import os
import cv2
import gdown
import zipfile
import numpy as np
import torch
import segmentation_models_pytorch as smp
from sklearn.metrics import confusion_matrix

# ======================
# User Input Links
# ======================
# Read both share links first, then normalise them: surrounding whitespace
# (e.g. from copy/paste) is removed before the links are parsed.
raw_suim_link = input("Enter Google Drive link for SUIM dataset: ")
raw_uieb_link = input("Enter Google Drive link for UIEB dataset: ")
suim_link = raw_suim_link.strip()
uieb_link = raw_uieb_link.strip()

# ======================
# Download Function
# ======================
def download_and_extract(gdrive_link, output_dir):
    """Download a zip archive from Google Drive and unpack it into ``output_dir``.

    Two link shapes are recognised: query-style links containing ``id=<ID>``
    and path-style links containing ``/d/<ID>``.

    Args:
        gdrive_link: Google Drive share/download link.
        output_dir: Directory the archive is extracted into.

    Raises:
        ValueError: If no non-empty file id can be extracted from the link.
        RuntimeError: If ``gdown.download`` reports failure by returning None.
    """
    if "id=" in gdrive_link:
        # Drop any query parameters / fragment that follow the id
        # (e.g. "...?id=<ID>&export=download").
        file_id = gdrive_link.split("id=", 1)[1].split("&", 1)[0].split("#", 1)[0]
    elif "/d/" in gdrive_link:
        # The id ends at the next path separator or at the query string
        # (e.g. ".../d/<ID>?usp=sharing" has no trailing slash).
        file_id = gdrive_link.split("/d/", 1)[1].split("/", 1)[0].split("?", 1)[0]
    else:
        raise ValueError("Invalid Google Drive link format.")

    file_id = file_id.strip()
    if not file_id:
        raise ValueError("Invalid Google Drive link format.")

    archive_path = "temp.zip"
    try:
        downloaded = gdown.download(
            f"https://drive.google.com/uc?id={file_id}", archive_path, quiet=False
        )
        # Some gdown versions signal failure by returning None instead of
        # raising; fail here rather than with a confusing zipfile error.
        if downloaded is None:
            raise RuntimeError(f"Download failed for Google Drive file id: {file_id}")
        with zipfile.ZipFile(archive_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)
    finally:
        # Never leave the temporary archive behind, even on failure.
        if os.path.exists(archive_path):
            os.remove(archive_path)
    print(f"Extracted to: {output_dir}")

# Download datasets
# Fetch each dataset into its own directory, SUIM first and UIEB second.
for link, target_dir in ((suim_link, "SUIM"), (uieb_link, "UIEB")):
    download_and_extract(link, target_dir)

# ======================
# Metric Functions
# ======================
def pixel_accuracy(y_true, y_pred):
    """Return the fraction of positions where prediction and ground truth agree.

    Args:
        y_true: Ground-truth labels (array-like, any shape).
        y_pred: Predicted labels with the same shape as ``y_true``.

    Returns:
        Mean of the element-wise equality mask, or NaN for empty input.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Coerce first: plain Python sequences compare as whole objects with
    # ``==``, which would collapse the score to a single True/False.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )
    if y_true.size == 0:
        # Nothing to score; avoid the RuntimeWarning from np.mean([]).
        return np.float64("nan")
    return np.mean(y_true == y_pred)

def mean_iou(y_true, y_pred, num_classes):
    """Return the intersection-over-union averaged over the classes present.

    A class contributes only if it occurs in ``y_true`` or ``y_pred``
    (non-empty union); absent classes are left out of the average.

    Args:
        y_true: Ground-truth class indices (array-like).
        y_pred: Predicted class indices with the same shape as ``y_true``.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.

    Returns:
        Mean IoU over the scored classes, or NaN if none of them occurs.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Coerce so that ``== cls`` is element-wise even for plain sequences.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )

    ious = []
    for cls in range(num_classes):
        pred_is_cls = y_pred == cls
        true_is_cls = y_true == cls
        union = np.logical_or(pred_is_cls, true_is_cls).sum()
        if union == 0:
            continue
        intersection = np.logical_and(pred_is_cls, true_is_cls).sum()
        ious.append(intersection / union)

    # Explicit NaN avoids the RuntimeWarning np.mean raises on an empty list.
    return np.mean(ious) if ious else np.float64("nan")

def dice_coefficient(y_true, y_pred, num_classes):
    """Return the Dice score averaged over the classes present.

    Per class the score is ``2 * |pred & true| / (|pred| + |true|)``. A class
    that occurs in neither input is left out of the average.

    Args:
        y_true: Ground-truth class indices (array-like).
        y_pred: Predicted class indices with the same shape as ``y_true``.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.

    Returns:
        Mean Dice over the scored classes, or NaN if none of them occurs.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Coerce so that ``== cls`` is element-wise even for plain sequences.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )

    dice_scores = []
    for cls in range(num_classes):
        pred_is_cls = y_pred == cls
        true_is_cls = y_true == cls
        total = pred_is_cls.sum() + true_is_cls.sum()
        if total == 0:
            continue
        intersection = np.logical_and(pred_is_cls, true_is_cls).sum()
        dice_scores.append(2 * intersection / total)

    # Explicit NaN avoids the RuntimeWarning np.mean raises on an empty list.
    return np.mean(dice_scores) if dice_scores else np.float64("nan")

def mean_pixel_accuracy(y_true, y_pred, num_classes):
    """Return the per-class pixel accuracy averaged over ground-truth classes.

    For every class that occurs in ``y_true`` the score is the fraction of
    its pixels that were predicted as that class; classes absent from
    ``y_true`` are left out of the average.

    Args:
        y_true: Ground-truth class indices (array-like).
        y_pred: Predicted class indices with the same shape as ``y_true``.
        num_classes: Classes ``0 .. num_classes - 1`` are scored.

    Returns:
        Mean per-class accuracy, or NaN if no scored class occurs in
        ``y_true``.

    Raises:
        ValueError: If the two inputs do not have the same shape.
    """
    # Coerce so that ``== cls`` is element-wise even for plain sequences.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )

    acc_per_class = []
    for cls in range(num_classes):
        true_is_cls = y_true == cls
        cls_pixels = true_is_cls.sum()
        if cls_pixels == 0:
            continue
        correct = np.logical_and(y_pred == cls, true_is_cls).sum()
        acc_per_class.append(correct / cls_pixels)

    # Explicit NaN avoids the RuntimeWarning np.mean raises on an empty list.
    return np.mean(acc_per_class) if acc_per_class else np.float64("nan")

# ======================
# Model Evaluation
# ======================
def evaluate_dataset(dataset_path, model, device, num_classes, pad_multiple=32):
    """Run ``model`` over every image/mask pair of a dataset and average the metrics.

    Images are read from ``<dataset_path>/images`` and masks from
    ``<dataset_path>/masks``. A mask is matched to an image by identical
    filename, falling back to the same filename stem with another extension.
    Files that cannot be read, or that have no mask, are skipped and counted.

    Args:
        dataset_path: Dataset root containing ``images`` and ``masks``.
        model: Callable network mapping a ``(1, 3, H, W)`` float tensor to
            per-class scores of shape ``(1, C, H, W)``.
        device: Device the input tensor is moved to.
        num_classes: Number of classes passed to the metric functions.
        pad_multiple: Inputs are zero-padded at the bottom/right so height
            and width are multiples of this value; the prediction is cropped
            back afterwards. Use 1 to disable padding.

    Returns:
        Tuple ``(pixel accuracy, mIoU, Dice, mean pixel accuracy)`` averaged
        over all evaluated images; four NaNs if nothing could be evaluated.

    Raises:
        ValueError: If ``pad_multiple`` is smaller than 1.
    """
    if pad_multiple < 1:
        raise ValueError("pad_multiple must be >= 1")

    image_dir = os.path.join(dataset_path, "images")
    mask_dir = os.path.join(dataset_path, "masks")

    # Index the masks once so each image lookup is O(1); the stem index lets
    # e.g. "x.jpg" pair with "x.bmp" or "x.png".
    mask_names = sorted(os.listdir(mask_dir)) if os.path.isdir(mask_dir) else []
    exact_masks = set(mask_names)
    mask_by_stem = {}
    for mask_name in mask_names:
        mask_by_stem.setdefault(os.path.splitext(mask_name)[0], mask_name)

    pa_list, miou_list, dice_list, mpa_list = [], [], [], []
    skipped = 0

    model.eval()
    with torch.no_grad():
        # Sorted for a reproducible processing order.
        for fname in sorted(os.listdir(image_dir)):
            if fname in exact_masks:
                mask_name = fname
            else:
                mask_name = mask_by_stem.get(os.path.splitext(fname)[0])
            if mask_name is None:
                skipped += 1
                continue

            img = cv2.imread(os.path.join(image_dir, fname))
            # NOTE(review): the mask is read as single-channel grayscale and
            # compared directly with class indices; this presumes its pixel
            # values already are class ids in [0, num_classes) - confirm for
            # colour-coded mask formats.
            mask_gt = cv2.imread(os.path.join(mask_dir, mask_name), cv2.IMREAD_GRAYSCALE)
            if img is None or mask_gt is None:
                # cv2.imread returns None for unreadable / non-image files.
                skipped += 1
                continue

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            height, width = img.shape[:2]

            img_tensor = torch.tensor(
                img.transpose(2, 0, 1) / 255.0, dtype=torch.float32
            ).unsqueeze(0).to(device)

            # Encoder-decoder networks typically need spatial sizes divisible
            # by their total stride; pad here and crop the prediction below.
            pad_h = (-height) % pad_multiple
            pad_w = (-width) % pad_multiple
            if pad_h or pad_w:
                img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h))

            logits = model(img_tensor)
            # Index the batch dimension explicitly: squeeze() would also drop
            # a spatial dimension of size 1.
            pred_mask = torch.argmax(logits[0], dim=0)[:height, :width].cpu().numpy()

            if mask_gt.shape != pred_mask.shape:
                # Nearest-neighbour keeps label values intact while aligning
                # the mask with the image grid.
                mask_gt = cv2.resize(
                    mask_gt, (width, height), interpolation=cv2.INTER_NEAREST
                )

            pa_list.append(pixel_accuracy(mask_gt, pred_mask))
            miou_list.append(mean_iou(mask_gt, pred_mask, num_classes))
            dice_list.append(dice_coefficient(mask_gt, pred_mask, num_classes))
            mpa_list.append(mean_pixel_accuracy(mask_gt, pred_mask, num_classes))

    if skipped:
        print(
            f"Skipped {skipped} file(s) in {image_dir} "
            "(unreadable image or missing/unreadable mask)"
        )

    if not pa_list:
        # Nothing was evaluated; return NaNs without np.mean's empty-list warning.
        nan = float("nan")
        return nan, nan, nan, nan

    return np.mean(pa_list), np.mean(miou_list), np.mean(dice_list), np.mean(mpa_list)

# ======================
# Main Execution
# ======================
device = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 8  # Adjust according to dataset

# Architectures under comparison. All share one encoder configuration so the
# constructor arguments are stated once.
# NOTE: encoder_weights=None leaves every model randomly initialised; load
# trained weights before treating the reported scores as meaningful.
architectures = {
    "UNet": smp.Unet,
    "PSPNet": smp.PSPNet,
    "UNet++": smp.UnetPlusPlus,
    "DeepLabv3+": smp.DeepLabV3Plus,
}
models = {
    label: build(encoder_name="resnet34", classes=NUM_CLASSES, encoder_weights=None).to(device)
    for label, build in architectures.items()
}

# Each dataset was extracted into a directory of the same name, so the name
# doubles as the dataset path. SUIM is evaluated first, then UIEB.
for dataset_name in ("SUIM", "UIEB"):
    print(f"\n=== {dataset_name} Dataset Evaluation ===")
    for name, model in models.items():
        pa, miou, dice, mpa = evaluate_dataset(dataset_name, model, device, NUM_CLASSES)
        print(f"{name}: Pixel Acc={pa:.4f}, mIoU={miou:.4f}, Dice={dice:.4f}, mPA={mpa:.4f}")
