In [None]:
# =======================================
# Google Colab - SUIM Visual Quality Demo
# =======================================
!pip install gdown tensorflow opencv-python matplotlib --quiet

import os
import cv2
import gdown
import zipfile
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model

# ------------------------------
# 1. Accept SUIM dataset link from user
# ------------------------------
drive_link = input("Enter Google Drive link for SUIM dataset: ").strip()
# An empty answer would otherwise be passed on as an empty URL and only fail
# later, with a far less obvious error, during the download step.
if not drive_link:
    raise ValueError("No Google Drive link was entered.")

# Convert share link to direct download link
def gdrive_to_direct(link):
    if "drive.google.com" in link:
        if "id=" in link:
            file_id = link.split("id=")[1]
        elif "/d/" in link:
            file_id = link.split("/d/")[1].split("/")[0]
        else:
            raise ValueError("Invalid Google Drive link format.")
        return f"https://drive.google.com/uc?id={file_id}"
    else:
        return link

direct_link = gdrive_to_direct(drive_link)

# ------------------------------
# 2. Download and extract dataset
# ------------------------------
output_zip = "suim_dataset.zip"
downloaded = gdown.download(direct_link, output_zip, quiet=False)
# Depending on the gdown version, a failed download (private file, quota
# exceeded, ...) may return None instead of raising; stop here with a clear
# message rather than letting ZipFile complain about a missing file.
if downloaded is None:
    raise RuntimeError(
        f"Download failed for {direct_link!r}; check that the file is shared publicly."
    )

extract_root = "suim_dataset"
with zipfile.ZipFile(output_zip, "r") as zip_ref:
    zip_ref.extractall(extract_root)

image_dir = os.path.join(extract_root, "images")
mask_dir = os.path.join(extract_root, "masks")
if not os.path.isdir(image_dir):
    # Archives often wrap their contents in a top-level folder, so fall back to
    # the first directory (in sorted walk order) that holds both an ``images``
    # and a ``masks`` sub-folder.
    for root, dirs, _files in os.walk(extract_root):
        dirs.sort()  # deterministic traversal order
        if "images" in dirs and "masks" in dirs:
            image_dir = os.path.join(root, "images")
            mask_dir = os.path.join(root, "masks")
            break
    else:
        raise FileNotFoundError(
            f"No directory containing both 'images' and 'masks' found under {extract_root!r}."
        )

# ------------------------------
# 3. Load models (replace the placeholder paths with your trained weights)
# ------------------------------
# Pre-trained / fine-tuned Keras models. The "/path/to/..." values are
# placeholders and must be edited before running this cell.
FUNIE_GAN_PATH = "/path/to/FUnIEGAN_model.h5"
TINY_FUNIE_GAN_PATH = "/path/to/TinyFUnIEGAN_model.h5"
MOBILENET_SEG_PATH = "/path/to/MobileNetV2_segmentation_model.h5"

# Check every path up front so a bad path is reported before any (slow) model
# load; os.path.exists also accepts SavedModel directories.
for _model_path in (FUNIE_GAN_PATH, TINY_FUNIE_GAN_PATH, MOBILENET_SEG_PATH):
    if not os.path.exists(_model_path):
        raise FileNotFoundError(
            f"Model not found: {_model_path!r}. Replace the placeholder path "
            "with the location of your trained weights."
        )

# compile=False: the models are only used with predict() below, so no
# optimizer/loss configuration is needed.
funie_gan = load_model(FUNIE_GAN_PATH, compile=False)
tiny_funie_gan = load_model(TINY_FUNIE_GAN_PATH, compile=False)
mobilenet_seg = load_model(MOBILENET_SEG_PATH, compile=False)

# ------------------------------
# 4. Helper functions
# ------------------------------
def load_image(path, target_size=(256, 256), interpolation=None):
    """Read an image from disk as an RGB array resized to ``target_size``.

    Args:
        path: File path passed to ``cv2.imread``.
        target_size: ``(width, height)`` in pixels, the order ``cv2.resize``
            expects for its ``dsize`` argument.
        interpolation: Optional OpenCV interpolation flag. ``None`` keeps
            OpenCV's default (bilinear). Pass ``cv2.INTER_NEAREST`` for label
            masks so resizing does not blend neighbouring labels into new values.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread``
            returns ``None`` instead of raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if interpolation is None:
        return cv2.resize(img, target_size)
    return cv2.resize(img, target_size, interpolation=interpolation)

def preprocess(img):
    """Scale 8-bit pixel values into the unit interval as ``float32``.

    A pixel of 0 maps to 0.0 and 255 maps to 1.0; shape is unchanged.
    """
    as_float = img.astype(np.float32)
    return as_float / 255.0

def postprocess(mask, threshold=0.5):
    """Binarise a per-pixel score map into a ``uint8`` mask of 0s and 1s.

    Args:
        mask: Array of per-pixel scores (presumably probabilities in [0, 1]).
        threshold: Scores strictly greater than this become 1, all others 0.
            Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        ``uint8`` array with the same shape as ``mask``.
    """
    return (mask > threshold).astype(np.uint8)

# ------------------------------
# 5. Visualize outputs
# ------------------------------
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
NUM_SAMPLES = 5


def _find_mask_path(mask_directory, img_name):
    """Return the mask file in ``mask_directory`` that matches ``img_name``, or None.

    The exact file name is tried first; after that, any file with the same stem
    and a known image extension, since images and masks may be stored with
    different extensions.
    """
    exact = os.path.join(mask_directory, img_name)
    if os.path.isfile(exact):
        return exact
    stem = os.path.splitext(img_name)[0]
    for ext in IMAGE_EXTENSIONS:
        candidate = os.path.join(mask_directory, stem + ext)
        if os.path.isfile(candidate):
            return candidate
    return None


def _load_mask(path, target_size=(256, 256)):
    """Read a mask as RGB, resized with nearest-neighbour interpolation.

    Nearest-neighbour keeps the original label values; cv2.resize's default
    bilinear interpolation would blend neighbouring labels into new ones.
    """
    mask = cv2.imread(path)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
    return cv2.resize(mask, target_size, interpolation=cv2.INTER_NEAREST)


def _to_uint8(img):
    """Convert a float image to ``uint8`` for display, clipping to [0, 1] first.

    A plain float->uint8 cast does not saturate, so out-of-range model outputs
    would turn into wrapped/garbage pixel values.
    NOTE(review): assumes the GANs output values in [0, 1], as the ``* 255``
    scaling implies; if a generator ends in tanh ([-1, 1]), map with
    ``(x + 1) / 2`` first — TODO confirm against the trained models.
    """
    return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)


# Only consider image files (skips stray entries such as sub-folders or OS metadata).
sample_images = sorted(
    name for name in os.listdir(image_dir) if name.lower().endswith(IMAGE_EXTENSIONS)
)[:NUM_SAMPLES]

for img_name in sample_images:
    img_path = os.path.join(image_dir, img_name)
    mask_path = _find_mask_path(mask_dir, img_name)
    if mask_path is None:
        print(f"Skipping {img_name}: no matching mask found in {mask_dir}")
        continue

    # Load
    orig_img = load_image(img_path)
    # NOTE(review): keeps only channel 0 (red, after BGR->RGB). If the masks are
    # colour-coded RGB rather than grayscale, one channel does not identify a
    # single class — confirm the mask format.
    gt_mask = _load_mask(mask_path, target_size=(256, 256))[:, :, 0]

    # Preprocess once; the same batch of one image feeds all three models.
    batch = np.expand_dims(preprocess(orig_img), axis=0)

    # Enhance (verbose=0 suppresses Keras' progress bar on every call)
    enhanced_funie = funie_gan.predict(batch, verbose=0)[0]
    enhanced_tiny_funie = tiny_funie_gan.predict(batch, verbose=0)[0]

    # Segment
    pred_mask_mobilenet = mobilenet_seg.predict(batch, verbose=0)[0]
    pred_mask_mobilenet = postprocess(pred_mask_mobilenet[:, :, 0])

    # Plot: (image, title, colormap) in row-major subplot order.
    panels = [
        (orig_img, "Original Image", None),
        (_to_uint8(enhanced_funie), "Enhanced (FUnIE-GAN)", None),
        (_to_uint8(enhanced_tiny_funie), "Enhanced (Tiny FUnIE-GAN)", None),
        (gt_mask, "Ground Truth Mask", "gray"),
        (pred_mask_mobilenet, "Predicted Mask (MobileNetV2)", "gray"),
    ]
    fig, axes = plt.subplots(2, 3, figsize=(14, 6))
    for ax in axes.flat:
        ax.axis("off")  # also hides the unused sixth panel
    for ax, (image, title, cmap) in zip(axes.flat, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
    fig.tight_layout()
    plt.show()
    plt.close(fig)
