In [None]:
# ========================
# Google Colab: SUIM + UIEB Evaluation
# ========================

import os
import subprocess
import time
import zipfile
from urllib.parse import parse_qs, urlparse

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import transforms
from google.colab import drive
from thop import profile, clever_format
import torchmetrics

# ====== 1. MOUNT DRIVE AND ACCEPT LINKS ======
# Make Google Drive available under /content/drive inside the Colab runtime.
drive.mount('/content/drive')

# Ask for the share link of each dataset archive, one prompt at a time
# (SUIM first, then UIEB); stray whitespace from copy/paste is stripped.
suim_link, uieb_link = (
    input(prompt).strip()
    for prompt in (
        "Enter Google Drive link for SUIM dataset: ",
        "Enter Google Drive link for UIEB dataset: ",
    )
)

# Convert Google Drive share link to direct download if needed
def gdrive_download(link, dest_folder):
    if "drive.google.com" in link:
        if "id=" in link:
            file_id = link.split("id=")[-1]
        elif "/d/" in link:
            file_id = link.split("/d/")[1].split("/")[0]
        os.system(f"gdown --id {file_id} -O {dest_folder} --fuzzy")
    else:
        print("Invalid Google Drive link!")

os.makedirs("/content/SUIM", exist_ok=True)
os.makedirs("/content/UIEB", exist_ok=True)

gdrive_download(suim_link, "/content/SUIM.zip")
gdrive_download(uieb_link, "/content/UIEB.zip")

# Extract with the standard library instead of `os.system("unzip ...")`:
# - the shell exit status was discarded, so a failed download only surfaced
#   later as a confusing missing-folder error;
# - on a re-run, `unzip` without -o may stop to prompt before overwriting the
#   previous extraction.
# ZipFile.extractall overwrites existing files and raises on a corrupt archive.
for archive_path, extract_dir in (
    ("/content/SUIM.zip", "/content/SUIM"),
    ("/content/UIEB.zip", "/content/UIEB"),
):
    if not zipfile.is_zipfile(archive_path):  # False for missing files too
        raise RuntimeError(
            f"{archive_path} is missing or not a valid zip archive; "
            "check the Google Drive link and the download output above."
        )
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall(extract_dir)

# ====== 2. PLACEHOLDER MODELS ======
class DummySegModel(nn.Module):
    """Stand-in segmentation network, used until the real models are wired in.

    The model is a single 3x3 convolution followed by a sigmoid. Its output is
    a per-pixel probability map with the same spatial size as the input.

    Args:
        name: Identifier stored on the instance as ``self.name``.
        in_channels: Channels of the input image tensor (3 for RGB).
        out_channels: Channels of the predicted map. Defaults to 1, so the
            prediction matches the single-channel ("L") masks it is scored
            against. A 3-channel output makes the binary mIoU/Dice metrics
            fail with a shape mismatch.
    """

    def __init__(self, name, in_channels=3, out_channels=1):
        super().__init__()
        self.name = name
        # kernel_size=3 with padding=1 keeps H and W unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels, H, W) to sigmoid probabilities (B, out_channels, H, W)."""
        return torch.sigmoid(self.conv(x))

# Replace with actual implementations
_MODEL_SPECS = (
    ("FUnIEGAN+ResNet50", "funie_resnet50"),
    ("DeepLab+MobileNetV2", "deeplab_mobilenet"),
    ("DeepLab+TinyFUnIEGAN+PrunedMobileOne50", "deeplab_tinyfunie_mobileone"),
)
# Display name -> model instance; insertion order is the reporting order.
models = {display: DummySegModel(tag) for display, tag in _MODEL_SPECS}

# Evaluate on CUDA when a GPU is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ====== 3. DATA LOADING (SUIM + UIEB) ======
# Image preprocessing: resize to 256x256, then convert the PIL image to a
# float tensor with values in [0, 1].
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

def load_dataset_images(path, max_samples=20):
    """Load up to ``max_samples`` image/mask pairs from ``<path>/images`` and ``<path>/masks``.

    Files are visited in sorted order, so the same subset is chosen on every
    run. Each image is paired with a mask of the same file name. If there is
    none, it falls back to a mask with the same stem and any extension
    (e.g. ``x.jpg`` image, ``x.bmp`` mask). Images without a mask, and files
    that cannot be opened as images, are skipped and do not count toward
    ``max_samples``.

    Images go through the module-level ``transform``. Masks are converted to
    grayscale ("L") and resized to the transformed image's (H, W) with
    nearest-neighbour sampling. Sending them through the interpolating
    ``Resize`` in ``transform`` would blend label values at region borders.
    Masks are then converted with ``ToTensor`` to shape (1, H, W), values in [0, 1].

    Returns:
        ``(images, masks)``: ``torch.stack`` of the image tensors and of the
        (1, H, W) mask tensors.

    Raises:
        RuntimeError: if no image/mask pair could be loaded.
        FileNotFoundError: if ``<path>/images`` does not exist.
    """
    img_folder = os.path.join(path, "images")
    mask_folder = os.path.join(path, "masks")

    # Index mask files by stem once, to pair images with differently-suffixed masks.
    mask_by_stem = {}
    if os.path.isdir(mask_folder):
        for mask_name in sorted(os.listdir(mask_folder)):
            mask_by_stem.setdefault(os.path.splitext(mask_name)[0], mask_name)

    to_tensor = transforms.ToTensor()
    images, masks = [], []
    # os.listdir order is arbitrary; sorting makes the capped subset reproducible.
    for file in sorted(os.listdir(img_folder)):
        if len(images) >= max_samples:  # count loaded pairs, not visited files
            break
        img_path = os.path.join(img_folder, file)
        if not os.path.isfile(img_path):
            continue
        if os.path.isfile(os.path.join(mask_folder, file)):
            mask_name = file
        else:
            mask_name = mask_by_stem.get(os.path.splitext(file)[0])
        if mask_name is None:
            continue
        mask_path = os.path.join(mask_folder, mask_name)

        try:
            # `with` closes the underlying file handles promptly.
            with Image.open(img_path) as raw_img:
                img = transform(raw_img.convert("RGB"))
            with Image.open(mask_path) as raw_mask:
                mask_img = raw_mask.convert("L")
        except OSError as err:  # PIL's UnidentifiedImageError is an OSError
            print(f"Skipping {file}: {err}")
            continue

        height, width = img.shape[-2:]
        mask_img = mask_img.resize((width, height), resample=Image.NEAREST)
        images.append(img)
        masks.append(to_tensor(mask_img))

    if not images:
        raise RuntimeError(
            f"No image/mask pairs found in {img_folder!r} / {mask_folder!r}"
        )
    return torch.stack(images), torch.stack(masks)

# Load the capped evaluation subset (default max_samples) from each
# extracted dataset folder: SUIM first, then UIEB.
suim_imgs, suim_masks = load_dataset_images(path="/content/SUIM")
uieb_imgs, uieb_masks = load_dataset_images(path="/content/UIEB")

# ====== 4. METRICS ======
# Metric objects shared by every evaluate_model call; both live on the
# evaluation device selected above.
miou_metric = torchmetrics.JaccardIndex(num_classes=2, task="binary")
miou_metric = miou_metric.to(device)
dice_metric = torchmetrics.Dice()
dice_metric = dice_metric.to(device)

def evaluate_model(model, images, masks, *, warmup_runs=3, timed_runs=10):
    """Profile a segmentation model and score it on one dataset.

    Args:
        model: ``nn.Module`` that maps a (1, 3, 256, 256) tensor to a
            probability map. Predictions are thresholded at 0.5.
        images: Iterable of image tensors, presumably (3, H, W) each, e.g. the
            stacked tensor from ``load_dataset_images``.
        masks: Ground-truth masks matching ``images`` one-to-one. They are
            binarised at 0.5 and cast to integer before scoring.
        warmup_runs: Untimed forward passes run before measuring latency, so
            one-off start-up costs are excluded from the figure.
        timed_runs: Number of forward passes averaged for the latency figure.

    Returns:
        ``(params, flops, latency, power_watts, miou, dice)``:

        - ``params`` and ``flops``: strings from ``thop.clever_format``.
        - ``latency``: mean milliseconds per forward pass on the dummy input.
        - ``power_watts``: a random mock value, not a measurement.
        - ``miou`` and ``dice``: Python floats.

    Raises:
        ValueError: if ``images`` is empty, ``images`` and ``masks`` differ
            in length, or ``timed_runs`` < 1.
    """
    if len(images) == 0:
        raise ValueError("evaluate_model needs at least one image")
    if len(images) != len(masks):
        raise ValueError(f"got {len(images)} images but {len(masks)} masks")
    if timed_runs < 1:
        raise ValueError("timed_runs must be >= 1")

    model.to(device)
    model.eval()
    use_cuda = device == "cuda"

    # Params and op count on a single dummy input.
    # NOTE(review): thop's `profile` is usually documented as counting MACs
    # rather than FLOPs; confirm before reporting the value as FLOPs.
    dummy_input = torch.randn(1, 3, 256, 256).to(device)
    flops, params = profile(model, inputs=(dummy_input,), verbose=False)
    flops, params = clever_format([flops, params], "%.3f")

    # Latency: warm up, then average several passes. CUDA kernels run
    # asynchronously, so synchronize before reading each clock.
    with torch.no_grad():
        for _ in range(warmup_runs):
            model(dummy_input)
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(timed_runs):
            model(dummy_input)
        if use_cuda:
            torch.cuda.synchronize()
    latency = (time.perf_counter() - start) * 1000 / timed_runs  # ms per pass

    # Power consumption placeholder: mock value, NOT a measurement.
    power_watts = np.random.uniform(3, 7)

    # mIoU & Dice on thresholded predictions, kept on the metric device.
    preds, gts = [], []
    with torch.no_grad():
        for img, mask in zip(images, masks):
            prob = model(img.unsqueeze(0).to(device))
            preds.append((prob > 0.5).float())
            gts.append(mask.unsqueeze(0).to(device))
    preds = torch.cat(preds)
    # Binary metrics require {0, 1} targets, and torchmetrics' classification
    # Dice rejects floating-point targets, so binarise and cast to integer.
    gts = (torch.cat(gts) > 0.5).long()

    miou = miou_metric(preds, gts).item()
    dice = dice_metric(preds, gts).item()

    return params, flops, latency, power_watts, miou, dice

# ====== 5. RUN EVALUATION ======
eval_sets = {
    "SUIM": (suim_imgs, suim_masks),
    "UIEB": (uieb_imgs, uieb_masks),
}
for dataset_name, (imgs, masks) in eval_sets.items():
    print(f"\n=== Dataset: {dataset_name} ===")
    for model_name, model in models.items():
        params, flops, latency, power, miou, dice = evaluate_model(model, imgs, masks)
        # clever_format already appends a unit suffix (K/M/G/...), so a fixed
        # "GFLOPs" label misstated the magnitude (e.g. "GFLOPs=5.308M").
        # Power comes from evaluate_model's mock value, not a measurement.
        print(
            f"{model_name}: Params={params}, FLOPs={flops}, "
            f"Latency={latency:.2f} ms, Power={power:.2f} W, "
            f"mIoU={miou:.4f}, Dice={dice:.4f}"
        )
