# Vertex AI / Apex-X: Model Visualization

This notebook provides tools to visualize the predictions of `TeacherModelV3` on satellite imagery. 

It demonstrates the following "World-Class" capabilities:
- **Instance Segmentation**: Precise roof masks.
- **Mask Quality Scores**: Predicted IoU for each object.
- **Cascade Refinement**: High-precision bounding boxes.

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path

# Make the parent of the current working directory importable so that the
# project-local `apex_x` package can be found when running from this folder.
parent_dir = os.path.join(os.getcwd(), os.pardir)
project_root = os.path.abspath(parent_dir)
already_importable = project_root in sys.path
if not already_importable:
    sys.path.append(project_root)

from apex_x.model import TeacherModelV3
from apex_x.utils import seed_all

# Seed via the project helper first (presumably seeds all relevant RNGs —
# confirm in apex_x.utils), then pick the GPU when one is visible to PyTorch.
seed_all(42)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Load Model

In [None]:
# Build the teacher model with the configuration used throughout this notebook.
# No checkpoint is loaded here, so predictions are only meaningful after
# trained weights have been restored.
model = TeacherModelV3(
    num_classes=80,
    backbone_model="facebook/dinov2-large",
    lora_rank=8,
    fpn_channels=256,
    num_cascade_stages=3
)
# Move the model to the selected device and switch it to evaluation mode
# before running inference.
model.to(device)
model.eval()
print("TeacherModelV3 loaded successfully. Note: Weights are random unless you load a checkpoint.")

## 2. Visualization Logic

In [None]:
def visualize_prediction(image_path, model, threshold=0.5):
    """Run ``model`` on a single image and plot its detections.

    The image is read with OpenCV, converted to RGB, resized so both sides
    are multiples of 14, scaled to ``[0, 1]`` and passed to the model as a
    one-image batch. Detections whose score is strictly greater than
    ``threshold`` are drawn on the resized image: masks as a translucent red
    overlay, boxes as cyan rectangles labelled with the detection score and
    the ``predicted_quality`` value.

    Args:
        image_path: Path (``str`` or path-like) of the image file to load.
        model: Callable such that ``model(tensor)`` returns a mapping with the
            keys ``'boxes'``, ``'scores'``, ``'predicted_quality'`` and
            ``'masks'`` (tensors; ``'masks'`` may be ``None``).
            NOTE(review): masks are indexed as ``masks[i, 0]`` and each box is
            unpacked into four values, so presumably masks are ``(N, 1, h, w)``
            logits and boxes are ``(N, 4)`` ``x1, y1, x2, y2`` in input-pixel
            coordinates — confirm against the model's output contract.
        threshold: Minimum score (exclusive) a detection needs to be drawn.

    Returns:
        None. The figure is shown with ``plt.show()``. If the file does not
        exist or cannot be decoded, a message is printed and nothing is drawn.
    """
    # Load image
    if not os.path.exists(image_path):
        print(f"Image not found: {image_path}")
        return

    img_bgr = cv2.imread(str(image_path))
    if img_bgr is None:
        # cv2.imread reports an unreadable/unsupported file by returning None
        # instead of raising, which would otherwise crash cvtColor below.
        print(f"Could not read image: {image_path}")
        return
    img_raw = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    # Preprocess: DINOv2 needs multiples of 14. Clamp to at least one patch so
    # images smaller than 14 px on a side do not produce a 0-sized resize.
    h, w = img_raw.shape[:2]
    new_h = max(14, (h // 14) * 14)
    new_w = max(14, (w // 14) * 14)
    img = cv2.resize(img_raw, (new_w, new_h))

    # Send the input to wherever the given model lives; fall back to the
    # notebook-level `device` when the model does not expose parameters.
    try:
        target_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        target_device = device

    img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
    img_tensor = img_tensor.unsqueeze(0).to(target_device)

    # Inference
    with torch.no_grad():
        outputs = model(img_tensor)

    # Parse outputs
    boxes = outputs['boxes'].detach().cpu().numpy()
    scores = outputs['scores'].detach().cpu().numpy()
    quality = outputs['predicted_quality'].detach().cpu().numpy()
    masks = outputs['masks'].detach().cpu().numpy() if outputs['masks'] is not None else None

    # Filter
    keep = scores > threshold
    boxes = boxes[keep]
    scores = scores[keep]
    quality = quality[keep]
    if masks is not None:
        masks = masks[keep]

    # Create display image
    disp_img = img.copy()
    img_h, img_w = disp_img.shape[:2]

    # Overlay masks
    if masks is not None and len(masks) > 0:
        mask_overlay = np.zeros_like(disp_img)
        for i in range(len(masks)):
            x1, y1, x2, y2 = boxes[i].astype(int)
            bw, bh = x2 - x1, y2 - y1
            if bw <= 0 or bh <= 0:
                # Degenerate box: nothing to paint.
                continue

            # Part of the box that actually lies inside the image.
            cx1, cy1 = max(x1, 0), max(y1, 0)
            cx2, cy2 = min(x2, img_w), min(y2, img_h)
            if cx2 <= cx1 or cy2 <= cy1:
                continue

            # Resize the mask to the full box, then crop it to the visible part
            # so its shape always matches the overlay region being painted.
            mask_logit = np.asarray(masks[i, 0], dtype=np.float32)
            mask_full = cv2.resize(mask_logit, (int(bw), int(bh)))
            mask_bool = mask_full[cy1 - y1:cy2 - y1, cx1 - x1:cx2 - x1] > 0.0

            # Color (Red)
            roi = mask_overlay[cy1:cy2, cx1:cx2]
            roi[mask_bool] = [255, 0, 0]  # RGB for Red

        disp_img = cv2.addWeighted(disp_img, 1.0, mask_overlay, 0.4, 0)

    # Visualization
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(disp_img)

    num_dets = len(boxes)
    print(f"Found {num_dets} objects")

    for i in range(num_dets):
        x1, y1, x2, y2 = boxes[i]
        sc = scores[i]
        q = quality[i]

        rect = plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False, color='cyan', linewidth=2)
        ax.add_patch(rect)
        ax.text(x1, y1-5, f"{sc:.2f} Q:{q:.2f}", color='white', fontsize=10, backgroundcolor='black')

    ax.axis('off')
    plt.show()

In [None]:
# Run: uncomment the call below and point it at an existing image file.
# visualize_prediction("../data/sample.jpg", model, threshold=0.1)