In [1]:
!pip install gradio pyngrok segmentation_models_pytorch opencv-python-headless torch torchvision torchaudio

Collecting pyngrok
  Downloading pyngrok-7.4.1-py3-none-any.whl.metadata (8.1 kB)
Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Downloading pyngrok-7.4.1-py3-none-any.whl (25 kB)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.8/154.8 kB 4.6 MB/s eta 0:00:00
Installing collected packages: pyngrok, segmentation_models_pytorch
Successfully installed pyngrok-7.4.1 segmentation_models_pytorch-0.5.0


In [2]:
# ==============================
# 1Ô∏è‚É£ INSTALL DEPENDENCIES
# ==============================
!pip install -q gradio segmentation_models_pytorch opencv-python-headless torch torchvision torchaudio

# ==============================
# 2️⃣ IMPORTS
# ==============================
import gradio as gr
import torch
import cv2
import numpy as np
import tempfile
import os
import segmentation_models_pytorch as smp

# ==============================
# 3️⃣ MODEL CONFIGURATION
# ==============================
# Directory holding the trained checkpoints. Defaults to Colab's /content and can be
# overridden via the CHECKPOINT_DIR environment variable instead of editing four paths.
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/content")

# Display name -> (checkpoint path, architecture instance the checkpoint is loaded into).
# Every model takes 3-channel input and predicts a single channel (classes=1).
# encoder_weights=None: no pretrained download; all weights come from the checkpoint.
MODEL_CHECKPOINTS = {
    "UNet++": (os.path.join(CHECKPOINT_DIR, "best_model_unet++_epoch13.pth"),
               smp.UnetPlusPlus(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)),
    "PSPNet": (os.path.join(CHECKPOINT_DIR, "best_model_pspnet_epoch13.pth"),
               smp.PSPNet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)),
    # NOTE(review): despite the "DINOv2" label this is an smp FPN with a ResNet-50
    # encoder; no DINOv2 backbone is involved. Confirm the label is intended.
    "DINOv2": (os.path.join(CHECKPOINT_DIR, "best_model_dinov2_epoch31.pth"),
               smp.FPN(encoder_name="resnet50", encoder_weights=None, in_channels=3, classes=1)),
    "DeepLabV3": (os.path.join(CHECKPOINT_DIR, "best_model_deeplabv3_epoch19.pth"),
                  smp.DeepLabV3(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1)),
}

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================
# 4️⃣ LOAD MODELS
# ==============================
def load_models():
    """Load every checkpoint listed in ``MODEL_CHECKPOINTS`` into its model.

    Missing or incompatible checkpoints are reported and skipped, so the app
    still starts with whichever subset loaded successfully.

    Returns:
        dict mapping display name -> model moved to ``device`` and set to eval mode.
    """
    # Deliberately NOT wrapped in torch.inference_mode(): this is loading, not
    # inference. Tensors created under inference mode (which can include the
    # parameters produced by model.to(device)) become "inference tensors". Once
    # inference mode is exited, they cannot be used in autograd-recorded
    # computations or modified in-place.
    models = {}
    for name, (ckpt_path, model) in MODEL_CHECKPOINTS.items():
        if not os.path.exists(ckpt_path):
            print(f"⚠️ Checkpoint not found for {name}: {ckpt_path}")
            continue
        try:
            # The file is consumed as a plain state_dict, so refuse arbitrary
            # pickled objects (plain torch.load may execute code while unpickling).
            state_dict = torch.load(ckpt_path, map_location=device, weights_only=True)
            model.load_state_dict(state_dict)
            model.to(device)
            model.eval()
            models[name] = model
            print(f"✅ Loaded {name}")
        except Exception as e:
            # Best-effort: one bad checkpoint should not prevent the others from loading.
            print(f"❌ Failed to load {name}: {e}")
    return models

MODELS = load_models()

# ==============================
# 5️⃣ IMAGE UTILITIES
# ==============================
def preprocess_frame(frame, img_width=256, img_height=256):
    """Turn a BGR OpenCV frame into a model-ready batch of one.

    The frame is converted to RGB, resized to ``(img_width, img_height)`` and
    scaled to [0, 1]. Returns a float32 tensor shaped (1, 3, img_height, img_width).
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (img_width, img_height))
    # NOTE(review): only /255 scaling, no mean/std normalization; presumably this
    # matches the training pipeline. Confirm.
    scaled = resized / 255.0
    # HWC -> CHW, prepend the batch axis, then downcast float64 -> float32.
    return torch.from_numpy(scaled).permute(2, 0, 1)[None].float()

def postprocess_mask(mask, original_size):
    """Binarize a probability map and scale it back to the frame size.

    Args:
        mask: array of per-pixel probabilities; values > 0.5 count as foreground.
        original_size: ``(width, height)`` target, in OpenCV's ``dsize`` order.

    Returns:
        uint8 array with 255 for foreground and 0 for background.
    """
    foreground = mask > 0.5
    binary = np.where(foreground, np.uint8(255), np.uint8(0))
    # Nearest-neighbour keeps the resized mask strictly two-valued.
    return cv2.resize(binary, original_size, interpolation=cv2.INTER_NEAREST)

def apply_mask_overlay(image, mask, alpha=0.6, color=(0, 0, 255)):
    """Tint the masked pixels of ``image`` with ``color``.

    Args:
        image: uint8 image of shape (H, W, 3). Frames from ``cv2.VideoCapture``
            are BGR, which is why the default ``color`` (0, 0, 255) is red.
        mask: single-channel mask (nonzero = foreground), or a 3-channel BGR
            mask that is converted to grayscale. Resized to ``image`` if needed.
        alpha: opacity of ``color`` over masked pixels (0 = invisible, 1 = solid).
        color: tint, in the same channel order as ``image``.

    Returns:
        A new uint8 image. Pixels outside the mask are left unchanged, and
        ``image`` itself is not modified.
    """
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    if image.shape[:2] != mask.shape[:2]:
        # Nearest-neighbour keeps a binary mask binary (bilinear would add soft edges).
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)

    overlay = image.copy()
    region = mask > 0
    # Blend only inside the mask. Blending the whole frame against a black
    # canvas would also scale every unmasked pixel by (1 - alpha), darkening it.
    blended = (1.0 - alpha) * image[region].astype(np.float64) + alpha * np.asarray(color, dtype=np.float64)
    overlay[region] = np.clip(np.rint(blended), 0, 255).astype(image.dtype)
    return overlay

# ==============================
# 6️⃣ VIDEO PROCESSING
# ==============================
def _segment_frame(model, frame, size):
    """Segment one BGR frame with ``model`` and return it with the mask overlaid.

    ``size`` is the ``(width, height)`` the predicted mask is resized back to.
    """
    tensor = preprocess_frame(frame).to(device)
    with torch.no_grad():
        # Single-channel output (classes=1): take batch 0, channel 0 as a 2-D map.
        probs = model(tensor).sigmoid().cpu().numpy()[0, 0]
    mask = postprocess_mask(probs, size)
    return apply_mask_overlay(frame, mask)


def process_video(video_path, model_name, frame_skip=5):
    """Run segmentation on a video file and write an annotated copy.

    Every ``frame_skip + 1``-th frame is segmented with ``MODELS[model_name]``;
    the frames in between are dropped to save time. The output frame rate is
    divided by the same factor, so the output lasts as long as the input.

    Args:
        video_path: path to the input video (what ``gr.Video`` provides), or None.
        model_name: key into ``MODELS``.
        frame_skip: frames to drop after each processed frame. 0/None processes
            every frame; negative values are treated as 0.

    Returns:
        ``(status_message, output_path)``. ``output_path`` is None on failure.
    """
    if video_path is None:
        return "❌ No video uploaded", None

    if model_name not in MODELS:
        return f"❌ Model '{model_name}' not loaded", None

    # Keep every `stride`-th frame. Clamping prevents a zero stride
    # (ZeroDivisionError for frame_skip=-1) or a negative one.
    stride = max(int(frame_skip or 0), 0) + 1

    model = MODELS[model_name]
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        return "❌ Could not open video", None

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 (or NaN) FPS; VideoWriter needs a positive rate.
            fps = 25.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Reserve a unique path, then close our handle so VideoWriter is the only
        # writer and no file descriptor leaks per request.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
            output_path = tmp.name

        # Dropped frames are not written, so divide the FPS by the stride.
        # Otherwise the output would play `stride` times faster than the input.
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps / stride, (width, height))
        if not out.isOpened():
            return "❌ Could not create output video", None

        frames_read = 0
        processed = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            keep = frames_read % stride == 0
            frames_read += 1
            if not keep:
                continue
            out.write(_segment_frame(model, frame, (width, height)))
            processed += 1
    finally:
        # Release capture/writer even if inference raises mid-video.
        cap.release()
        if out is not None:
            out.release()

    # CAP_PROP_FRAME_COUNT can be 0 or negative when the container lacks it;
    # fall back to the number of frames actually read.
    if total_frames <= 0:
        total_frames = frames_read
    return f"✅ Processed {processed}/{total_frames} frames successfully!", output_path

# ==============================
# 7️⃣ GRADIO INTERFACE
# ==============================
def segment_interface(video, model_name):
    """Gradio click handler: segment ``video`` with ``model_name``.

    Returns the ``(status, output_path)`` pair from ``process_video``. Any
    exception becomes an ``"Error: ..."`` status with no output video, so the
    UI shows the failure instead of crashing.
    """
    try:
        # gr.Video delivers a filesystem path, which process_video opens directly.
        status, result_path = process_video(video, model_name)
    except Exception as err:
        return "Error: " + str(err), None
    return status, result_path

with gr.Blocks(title="Character Segmentation") as demo:
    gr.Markdown("# 🎬 Character Segmentation App (Gradio Version)")
    gr.Markdown("Upload a video and choose a segmentation model to process it frame by frame.")

    # Offer only the models that actually loaded, and make sure the preselected
    # value is one of them: prefer "UNet++", else the first loaded model, else
    # None (nothing loaded). A hard-coded default could name a missing choice.
    available_models = list(MODELS.keys())
    default_model = "UNet++" if "UNet++" in available_models else next(iter(available_models), None)

    with gr.Row():
        video_input = gr.Video(label="Upload Video")
        model_choice = gr.Dropdown(available_models, label="Choose Model", value=default_model)

    output_message = gr.Textbox(label="Status")
    output_video = gr.Video(label="Segmented Output")

    run_button = gr.Button("Run Segmentation")
    run_button.click(fn=segment_interface, inputs=[video_input, model_choice],
                     outputs=[output_message, output_video])

# share=True also publishes a temporary public gradio.live URL for the app.
demo.launch(share=True)


✅ Loaded UNet++
✅ Loaded PSPNet
✅ Loaded DINOv2
✅ Loaded DeepLabV3
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://deb8d9e0bff96341d1.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


