# In [None]:
import cv2
import torch
import numpy as np
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from PIL import Image
from tqdm import tqdm

# ================= CASCADED REFEREE CONFIGURATION =================
# Stage 1 checkpoint directory: loaded in main() and run on every full frame
# to find the region that is cropped and handed to the Stage 2 models.
# NOTE(review): the dataset folder says "b2" while the checkpoint folder says
# "b3" -- confirm which backbone this checkpoint actually is.
STAGE1_PATH = "/kaggle/input/datasets/gonoszgonosz/b2-1024-weights/final_rat_model_b3_1024"

# Stage 2 checkpoint directories: both models are run on the same Stage 1 crop
# so their masks can be shown side by side in the output grid.
STAGE2_BCE_PATH = "/kaggle/working/final_rat_model_stage2_bce"
STAGE2_DICE_PATH = "/kaggle/working/final_rat_model_stage2_dice"

# Video read frame by frame by main(), and the 2x2 grid video it writes.
INPUT_VIDEO = "/kaggle/input/datasets/gonoszgonosz/rat-test-video/test.mp4"
OUTPUT_VIDEO = "/kaggle/working/Cascaded_Referee_Grid.mp4"

# A pixel is marked in a mask when its class-1 softmax probability exceeds this.
CONFIDENCE = 0.5
# Fraction of the detected box width/height added on each side before cropping.
PADDING_RATIO = 0.25
# ==================================================================

# Run the models on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

def load_model_and_proc(path):
    """Load a SegFormer image processor and segmentation model from ``path``.

    The model is moved to the module-level ``device`` and switched to eval
    mode before being returned.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = SegformerImageProcessor.from_pretrained(path)
    segmenter = SegformerForSemanticSegmentation.from_pretrained(path)
    segmenter = segmenter.to(device)
    segmenter.eval()
    return processor, segmenter

def get_prediction(model, processor, image_array):
    """Return a binary ``uint8`` mask of class 1 at the size of ``image_array``.

    The input is converted from BGR to RGB, preprocessed by ``processor`` and
    passed through ``model`` without gradient tracking. The logits are resized
    bilinearly to the input's height and width, soft-maxed over the class
    dimension, and the class-1 probability is thresholded at ``CONFIDENCE``.
    """
    height, width = image_array.shape[:2]
    pil_image = Image.fromarray(cv2.cvtColor(image_array, cv2.COLOR_BGR2RGB))
    batch = processor(images=pil_image, return_tensors="pt").to(device)

    with torch.no_grad():
        raw_logits = model(**batch).logits
        upsampled = torch.nn.functional.interpolate(
            raw_logits, size=(height, width), mode="bilinear", align_corners=False
        )
        class_probs = upsampled.softmax(dim=1)
        # Index 0 selects the only image in the batch, index 1 the target class.
        foreground = class_probs[0, 1] > CONFIDENCE

    return foreground.cpu().numpy().astype(np.uint8)

def get_padded_bbox(mask, padding):
    """Return a padded, square-where-possible crop box around the mask's foreground.

    The tight box around all pixels greater than 0 is grown by ``padding``
    times its width/height on each side, clipped to the mask, and then expanded
    to a square whose side is the longer edge of the padded box. The square is
    shifted to stay inside the mask at every edge; it is only non-square when
    its side exceeds one of the mask's dimensions.

    Args:
        mask: 2-D array; values greater than 0 count as foreground.
        padding: Fraction of the tight box width/height added on each side.

    Returns:
        ``(y_min, y_max, x_min, x_max)`` as plain ints with half-open bounds,
        usable directly as ``array[y_min:y_max, x_min:x_max]``; ``None`` when
        the mask has no foreground pixel.
    """
    foreground = mask > 0
    rows = np.flatnonzero(foreground.any(axis=1))
    cols = np.flatnonzero(foreground.any(axis=0))
    if rows.size == 0:
        return None

    img_h, img_w = mask.shape[0], mask.shape[1]

    # Half-open extents: the +1 keeps the last foreground row/column inside
    # the slice, so even a single-pixel mask yields a non-empty box.
    y_min, y_max = int(rows[0]), int(rows[-1]) + 1
    x_min, x_max = int(cols[0]), int(cols[-1]) + 1

    pad_w = int((x_max - x_min) * padding)
    pad_h = int((y_max - y_min) * padding)
    x_min, x_max = max(0, x_min - pad_w), min(img_w, x_max + pad_w)
    y_min, y_max = max(0, y_min - pad_h), min(img_h, y_max + pad_h)

    side = max(x_max - x_min, y_max - y_min)

    def _place(center, length):
        # Centre the square on the padded box, then slide it back inside
        # [0, length) instead of cutting it off at the far edge.
        start = max(0, center - side // 2)
        start = min(start, max(0, length - side))
        return start, min(length, start + side)

    x_min, x_max = _place((x_min + x_max) // 2, img_w)
    y_min, y_max = _place((y_min + y_max) // 2, img_h)

    return y_min, y_max, x_min, x_max

def apply_overlay(image, mask, color, alpha=0.4):
    """Return a copy of ``image`` tinted with ``color`` where ``mask`` equals 1.

    The tint is an ``alpha``-weighted blend of ``color`` over the original
    pixels; pixels where ``mask`` is not 1 are returned unchanged. ``image``
    itself is not modified.
    """
    tint = np.full_like(image, color)
    mixed = cv2.addWeighted(image, 1 - alpha, tint, alpha, 0)

    selected = mask == 1
    result = image.copy()
    result[selected] = mixed[selected]
    return result

def _render_grid(frame, stage1, stage2_bce, stage2_dice):
    """Build one 2x2 comparison frame, twice the width and height of ``frame``.

    Each ``stage*`` argument is a ``(processor, model)`` pair as returned by
    ``load_model_and_proc``. Top row: full frame with the BCE and Dice masks
    overlaid plus the Stage 1 crop box. Bottom row: the crop itself with each
    overlay, resized to the frame size. When Stage 1 finds nothing, the top row
    is the plain frame and the bottom row is black.
    """
    proc1, model1 = stage1
    proc2_bce, model2_bce = stage2_bce
    proc2_dice, model2_dice = stage2_dice
    h, w = frame.shape[:2]

    # 1. STAGE 1: Global Search
    stage1_mask = get_prediction(model1, proc1, frame)
    bbox = get_padded_bbox(stage1_mask, PADDING_RATIO)

    # Initialize the 4 quadrants
    res_full_bce = frame.copy()
    res_full_dice = frame.copy()
    res_mag_bce = np.zeros_like(frame)
    res_mag_dice = np.zeros_like(frame)

    if bbox is not None:
        # Plain ints keep the slicing and OpenCV drawing calls independent of
        # the numpy integer type the box was computed in.
        y1, y2, x1, x2 = (int(v) for v in bbox)
        crop = frame[y1:y2, x1:x2]

        # Avoid crashing on zero-dimension crops at extreme edges
        if crop.shape[0] > 0 and crop.shape[1] > 0:
            # 2. STAGE 2: Magnified Expert Inference
            crop_mask_bce = get_prediction(model2_bce, proc2_bce, crop)
            crop_mask_dice = get_prediction(model2_dice, proc2_dice, crop)

            # 3. RECONSTRUCTION: Full Scale Top Half
            full_mask_bce = np.zeros((h, w), dtype=np.uint8)
            full_mask_dice = np.zeros((h, w), dtype=np.uint8)
            full_mask_bce[y1:y2, x1:x2] = crop_mask_bce
            full_mask_dice[y1:y2, x1:x2] = crop_mask_dice

            res_full_bce = apply_overlay(frame, full_mask_bce, (0, 0, 255))  # Red
            res_full_dice = apply_overlay(frame, full_mask_dice, (0, 255, 0))  # Green

            # Draw Bounding Box to show Stage 1's contribution
            cv2.rectangle(res_full_bce, (x1, y1), (x2, y2), (255, 255, 255), 2)
            cv2.rectangle(res_full_dice, (x1, y1), (x2, y2), (255, 255, 255), 2)

            # 4. MAGNIFICATION: Bottom Half Visualization
            crop_bce_colored = apply_overlay(crop, crop_mask_bce, (0, 0, 255))
            crop_dice_colored = apply_overlay(crop, crop_mask_dice, (0, 255, 0))

            # Scale the crop up to WxH so it matches the grid quadrant size perfectly
            res_mag_bce = cv2.resize(crop_bce_colored, (w, h), interpolation=cv2.INTER_LANCZOS4)
            res_mag_dice = cv2.resize(crop_dice_colored, (w, h), interpolation=cv2.INTER_LANCZOS4)

    # 5. LABELS
    cv2.putText(res_full_bce, "FULL: BCE WEIGHTED", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.putText(res_full_dice, "FULL: DICE LOSS", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.putText(res_mag_bce, "MAGNIFIED PERCEPTION: BCE", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
    cv2.putText(res_mag_dice, "MAGNIFIED PERCEPTION: DICE", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

    # 6. STITCH THE 2x2 GRID
    top_row = np.hstack((res_full_bce, res_full_dice))
    bottom_row = np.hstack((res_mag_bce, res_mag_dice))
    return np.vstack((top_row, bottom_row))


def main():
    """Run the cascaded pipeline over ``INPUT_VIDEO`` and write the 2x2 grid video.

    Loads the Stage 1 model and both Stage 2 models, renders one grid frame per
    input frame, and writes the result to ``OUTPUT_VIDEO``.

    Raises:
        RuntimeError: If the input video or the output writer cannot be opened.
    """
    print("--- LOADING PIPELINE MODELS ---")
    stage1 = load_model_and_proc(STAGE1_PATH)
    stage2_bce = load_model_and_proc(STAGE2_BCE_PATH)
    stage2_dice = load_model_and_proc(STAGE2_DICE_PATH)

    cap = cv2.VideoCapture(INPUT_VIDEO)
    out = None
    try:
        # Fail loudly: an unopened capture would otherwise skip the loop and
        # still report the video as saved.
        if not cap.isOpened():
            raise RuntimeError(f"Could not open input video: {INPUT_VIDEO}")

        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Grid output: 2W width, 2H height
        out = cv2.VideoWriter(OUTPUT_VIDEO, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w * 2, h * 2))
        if not out.isOpened():
            raise RuntimeError(f"Could not open output video for writing: {OUTPUT_VIDEO}")

        print("--- STARTING 2x2 GRID INFERENCE ---")
        with tqdm(total=total_frames) as pbar:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                out.write(_render_grid(frame, stage1, stage2_bce, stage2_dice))
                pbar.update(1)
    finally:
        # Release the video handles even when inference raises part-way through.
        cap.release()
        if out is not None:
            out.release()

    print(f"Cascaded 2x2 Referee Video Saved: {OUTPUT_VIDEO}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()