# Document Layout Detection - Demo & Visualization

This notebook demonstrates:
1. Model loading and inference
2. Prediction visualization
3. Comparison between Baseline and GroundingDINO

## 1. Setup

In [None]:
import os
import json
import random
from typing import Any, Dict, List, Optional

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

import torch
import torch.nn.functional as F

from preprocess import Vocab, UniDSet, find_jsons, read_json
from model import build_model

# Reproducibility: seed every random number generator the notebook relies on.
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


seed_everything(42)

# Run on CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

## 2. Load Model

In [None]:
def load_checkpoint(ckpt_path: str):
    """Load a training checkpoint and rebuild the model and vocabulary.

    Args:
        ckpt_path: Path to a checkpoint file. It must contain ``'vocab_itos'``
            (the index -> token list) and ``'model_state'`` (a state dict);
            ``'config'`` is optional.

    Returns:
        Tuple ``(model, vocab, config)``:
        - the model on the module-level ``device``, switched to eval mode;
        - the restored ``Vocab``;
        - the config dict, or ``{}`` if none was saved.
    """
    # NOTE(review): newer PyTorch versions default torch.load to
    # weights_only=True; if the saved config contains non-primitive objects
    # the load will fail -- confirm against the training script.
    ckpt = torch.load(ckpt_path, map_location=device)

    # Rebuild the vocabulary from the saved index -> token list.
    vocab = Vocab()
    vocab.itos = ckpt['vocab_itos']
    vocab.stoi = {t: i for i, t in enumerate(vocab.itos)}

    # `or {}` also covers a checkpoint that stored config=None (or
    # config['model']=None) explicitly; `.get(key, {})` would return None there.
    config = ckpt.get('config') or {}
    model_config = config.get('model') or {}

    model = build_model(model_config, vocab_size=len(vocab))
    model.load_state_dict(ckpt['model_state'])
    model = model.to(device)
    model.eval()

    return model, vocab, config

# Example: Load your trained model
# model, vocab, config = load_checkpoint('outputs/ckpt/grounding_dino_best.pth')
print("Model loading function defined.")

## 3. Visualization Functions

In [None]:
def denormalize_bbox(bbox, img_w, img_h):
    """Map a normalized center-format box to a pixel-space top-left box.

    Args:
        bbox: Four numbers ``(cx, cy, w, h)`` given as fractions of the image size.
        img_w: Image width in pixels.
        img_h: Image height in pixels.

    Returns:
        Tuple ``(x, y, w, h)`` in pixels, where ``(x, y)`` is the top-left corner.
    """
    center_x, center_y, rel_w, rel_h = bbox
    left = (center_x - rel_w / 2.0) * img_w
    top = (center_y - rel_h / 2.0) * img_h
    return left, top, rel_w * img_w, rel_h * img_h


def plot_prediction(img_path: str, query_text: str, pred_bbox: Optional[List[float]],
                    gt_bbox: Optional[List[float]] = None, title: str = "Prediction"):
    """Show an image with its predicted and (optionally) ground-truth boxes.

    Args:
        img_path: Path to the image file.
        query_text: Query string, shown on the second line of the title.
        pred_bbox: Predicted box as normalized ``(cx, cy, w, h)``; drawn as a
            solid red rectangle. ``None`` skips it.
        gt_bbox: Ground-truth box in the same format; drawn as a dashed green
            rectangle. ``None`` skips it.
        title: First line of the figure title.
    """
    # Decode inside a context manager so the source file handle is released.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    img_w, img_h = img.size

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(img)

    if pred_bbox is not None:
        px, py, pw, ph = denormalize_bbox(pred_bbox, img_w, img_h)
        rect_pred = patches.Rectangle(
            (px, py), pw, ph,
            linewidth=3, edgecolor='red', facecolor='none',
            label='Prediction'
        )
        ax.add_patch(rect_pred)

    if gt_bbox is not None:
        gx, gy, gw, gh = denormalize_bbox(gt_bbox, img_w, img_h)
        rect_gt = patches.Rectangle(
            (gx, gy), gw, gh,
            linewidth=3, edgecolor='green', facecolor='none',
            linestyle='--', label='Ground Truth'
        )
        ax.add_patch(rect_gt)

    ax.set_title(f"{title}\nQuery: {query_text}", fontsize=14, fontweight='bold')
    ax.axis('off')
    # Only draw a legend when a labelled box was added; legend() with no
    # labelled artists makes matplotlib emit a "No artists" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='upper right', fontsize=12)

    plt.tight_layout()
    plt.show()


def compare_models(img_path: str, query_text: str,
                   pred_baseline: Optional[List[float]], pred_grounding: Optional[List[float]],
                   gt_bbox: Optional[List[float]] = None):
    """Show Baseline and GroundingDINO predictions side by side on one image.

    Args:
        img_path: Path to the image file.
        query_text: Query string, shown in both panel titles.
        pred_baseline: Baseline prediction as normalized ``(cx, cy, w, h)``,
            or ``None`` to draw nothing in the left panel.
        pred_grounding: GroundingDINO prediction in the same format, or ``None``.
        gt_bbox: Optional ground-truth box, drawn dashed green in both panels.
    """
    # Decode inside a context manager so the source file handle is released.
    with Image.open(img_path) as src:
        img = src.convert('RGB')
    img_w, img_h = img.size

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    panels = zip(axes, [pred_baseline, pred_grounding], ['Baseline', 'GroundingDINO'])
    for ax, pred, title in panels:
        ax.imshow(img)

        if pred is not None:
            px, py, pw, ph = denormalize_bbox(pred, img_w, img_h)
            rect_pred = patches.Rectangle(
                (px, py), pw, ph,
                linewidth=3, edgecolor='red', facecolor='none',
                label='Prediction'
            )
            ax.add_patch(rect_pred)

        if gt_bbox is not None:
            gx, gy, gw, gh = denormalize_bbox(gt_bbox, img_w, img_h)
            rect_gt = patches.Rectangle(
                (gx, gy), gw, gh,
                linewidth=3, edgecolor='green', facecolor='none',
                linestyle='--', label='Ground Truth'
            )
            ax.add_patch(rect_gt)

        ax.set_title(f"{title}\nQuery: {query_text}", fontsize=14, fontweight='bold')
        ax.axis('off')
        # Skip the legend on a panel with no labelled boxes to avoid
        # matplotlib's "No artists with labels" warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='upper right', fontsize=12)

    plt.tight_layout()
    plt.show()

print("Visualization functions defined.")

## 4. Inference Function

In [None]:
@torch.no_grad()
def predict_single(model, vocab, img_path: str, query_text: str, img_size: int = 512,
                   max_len: int = 40):
    """Run the model on one image/query pair and return the predicted box.

    Args:
        model: Model called as ``model(images, tokens, lengths)``.
        vocab: Vocabulary whose ``encode(text, max_len=...)`` tokenizes the query.
        img_path: Path to the input image.
        query_text: Natural-language query.
        img_size: Side length of the square resize applied before inference.
        max_len: ``max_len`` forwarded to ``vocab.encode`` (default 40, the
            value previously hard-coded).

    Returns:
        The first row of the model output as a list of floats: a normalized
        ``[cx, cy, w, h]`` box, the format ``denormalize_bbox`` consumes.
    """
    from torchvision import transforms as T

    # Decode inside a context manager so the source file handle is released.
    with Image.open(img_path) as src:
        img = src.convert('RGB')

    # Only resize + ToTensor ([0, 1] scaling), no mean/std normalization.
    # NOTE(review): presumably this matches the UniDSet training pipeline -- confirm.
    transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
    ])
    img_tensor = transform(img).unsqueeze(0).to(device)  # (1, 3, img_size, img_size)

    tokens = vocab.encode(query_text, max_len=max_len)
    tokens_tensor = torch.tensor([tokens], dtype=torch.long, device=device)  # (1, L)
    # NOTE(review): lengths is placed on `device`; if the model hands it to
    # pack_padded_sequence it must live on CPU instead -- confirm in model.py.
    lengths = torch.tensor([len(tokens)], dtype=torch.long, device=device)  # (1,)

    model.eval()
    pred = model(img_tensor, tokens_tensor, lengths)  # expected shape (1, 4)
    return pred[0].tolist()

print("Inference function defined.")

## 5. Example Usage

In [None]:
# Example: Load model and make prediction
# Uncomment and modify paths as needed
#
# predict_single resizes the image to `img_size`; below it is read from the
# checkpoint config's data.img_size, falling back to 512 when that key is absent.
# Boxes are normalized (cx, cy, w, h), the format denormalize_bbox expects.

# # Load model
# model, vocab, config = load_checkpoint('outputs/ckpt/grounding_dino_best.pth')
# img_size = config.get('data', {}).get('img_size', 512)

# # Example image and query
# # (the query below is Korean for "Please find the table")
# img_path = './data/val/jpg/example.jpg'
# query_text = '표를 찾아주세요'
# gt_bbox = [0.5, 0.5, 0.3, 0.2]  # Ground truth (if available)

# # Predict
# pred_bbox = predict_single(model, vocab, img_path, query_text, img_size)

# # Visualize
# plot_prediction(img_path, query_text, pred_bbox, gt_bbox, title="GroundingDINO Prediction")

print("Example usage template provided.")

## 6. Batch Visualization from Dataset

In [None]:
def visualize_random_samples(model, vocab, json_dir: str, jpg_dir: str,
                             num_samples: int = 5, img_size: int = 512):
    """Plot predictions for a random subset of dataset samples that have ground truth.

    For each chosen sample, the IoU between prediction and ground truth is
    printed and the image is shown with both boxes.

    Args:
        model: Trained model (e.g. from ``load_checkpoint``).
        vocab: Vocabulary used to build the dataset and tokenize queries.
        json_dir: Directory passed to ``find_jsons`` to collect annotation files.
        jpg_dir: Image directory passed to ``UniDSet``.
        num_samples: Maximum number of samples to visualize.
        img_size: Square size used for both the dataset and ``predict_single``.
    """
    json_files = find_jsons(json_dir)
    dataset = UniDSet(json_files, jpg_dir=jpg_dir, vocab=vocab,
                      build_vocab=False, resize_to=(img_size, img_size))

    # Keep samples that carry a ground-truth box. This indexes every item
    # (dataset[i]), which presumably loads each sample -- slow on big sets.
    valid_indices = [i for i in range(len(dataset)) if dataset[i]['target'] is not None]

    if not valid_indices:
        print("No samples with ground truth found.")
        return

    sample_indices = random.sample(valid_indices, min(num_samples, len(valid_indices)))

    # Imported once (after sampling, so any import-time side effects cannot
    # shift the RNG draw above) instead of on every loop iteration.
    # NOTE(review): `test` can collide with the stdlib `test` package if the
    # project directory is not first on sys.path.
    from test import iou_xywh_pixel

    for idx in sample_indices:
        sample = dataset[idx]

        img_path = dataset.items[idx]['img']
        query_text = sample['query_text']
        gt_bbox = sample['target'].numpy().tolist() if sample['target'] is not None else None

        pred_bbox = predict_single(model, vocab, img_path, query_text, img_size)

        if gt_bbox is not None:
            # Only the image size is needed. The context manager closes the
            # handle that Image.open keeps open for lazy loading.
            with Image.open(img_path) as img:
                W, H = img.size

            # Convert both normalized boxes to pixel (x, y, w, h).
            px, py, pw, ph = denormalize_bbox(pred_bbox, W, H)
            gx, gy, gw, gh = denormalize_bbox(gt_bbox, W, H)

            iou = iou_xywh_pixel([px, py, pw, ph], [gx, gy, gw, gh])
            print(f"\nSample {idx} - Query: {query_text}")
            print(f"IoU: {iou:.4f}")

        plot_prediction(img_path, query_text, pred_bbox, gt_bbox,
                        title=f"Sample {idx}")

# Example usage:
# visualize_random_samples(model, vocab, './data/val/json', './data/val/jpg', num_samples=3)

print("Batch visualization function defined.")

## 7. Model Comparison (Baseline vs GroundingDINO)

In [None]:
# Example: Compare two models
# Uncomment and modify as needed
#
# NOTE: predict_single is called below without img_size, so both models run at
# its default of 512. If a checkpoint's config stores a different
# data.img_size, keep that config (instead of `_`) and pass it explicitly.

# # Load both models
# model_baseline, vocab_baseline, _ = load_checkpoint('outputs/ckpt/baseline_best.pth')
# model_grounding, vocab_grounding, _ = load_checkpoint('outputs/ckpt/grounding_dino_best.pth')

# # (the query below is Korean for "Please find the chart")
# img_path = './data/val/jpg/example.jpg'
# query_text = '차트를 찾아주세요'
# gt_bbox = [0.5, 0.5, 0.3, 0.2]  # Ground truth

# # Predict with both models
# pred_baseline = predict_single(model_baseline, vocab_baseline, img_path, query_text)
# pred_grounding = predict_single(model_grounding, vocab_grounding, img_path, query_text)

# # Compare
# compare_models(img_path, query_text, pred_baseline, pred_grounding, gt_bbox)

print("Model comparison template provided.")

## 8. Attention Visualization (Advanced)

In [None]:
def visualize_attention_map(model, vocab, img_path: str, query_text: str, img_size: int = 512):
    """Placeholder for attention-map visualization of GroundingDINO models.

    Nothing is computed or plotted yet: the function only prints a notice
    and returns ``None``. The arguments are accepted for a future
    implementation and are currently unused.
    """
    notice = (
        "Attention visualization requires model hooks to extract attention weights.",
        "This is left as an exercise for advanced users.",
    )
    for line in notice:
        print(line)

    # TODO: implement by
    #   1. registering forward hooks on the attention layers,
    #   2. capturing the attention weights during a forward pass,
    #   3. overlaying those weights on the image as a heatmap.

print("Attention visualization placeholder defined.")

## 9. Interactive Demo

In [None]:
# Interactive widget for Jupyter (requires ipywidgets)
try:
    from ipywidgets import interact, widgets
    from IPython.display import display
except ImportError:
    print("ipywidgets not installed. Install with: pip install ipywidgets")
else:
    def interactive_demo(model, vocab, json_dir: str, jpg_dir: str, img_size: int = 512):
        """Pick a labelled sample from a dropdown and plot the model's prediction.

        Args:
            model: Trained model (e.g. from ``load_checkpoint``).
            vocab: Vocabulary used to build the dataset and tokenize queries.
            json_dir: Directory passed to ``find_jsons`` to collect annotation files.
            jpg_dir: Image directory passed to ``UniDSet``.
            img_size: Square size used for both the dataset and ``predict_single``.
        """
        json_files = find_jsons(json_dir)
        dataset = UniDSet(json_files, jpg_dir=jpg_dir, vocab=vocab,
                          build_vocab=False, resize_to=(img_size, img_size))

        valid_indices = [i for i in range(len(dataset)) if dataset[i]['target'] is not None]

        # With no options the Dropdown's value is None, and interact would
        # immediately call show_prediction(None) -> dataset[None].
        if not valid_indices:
            print("No samples with ground truth found.")
            return

        @interact(sample_idx=widgets.Dropdown(options=valid_indices, description='Sample:'))
        def show_prediction(sample_idx):
            sample = dataset[sample_idx]
            img_path = dataset.items[sample_idx]['img']
            query_text = sample['query_text']
            gt_bbox = sample['target'].numpy().tolist() if sample['target'] is not None else None

            pred_bbox = predict_single(model, vocab, img_path, query_text, img_size)
            plot_prediction(img_path, query_text, pred_bbox, gt_bbox,
                            title=f"Sample {sample_idx}")

    print("Interactive demo function defined.")