# LightGlue Based Image-Map Matching
This notebook is based on the [LightGlue Demo](https://github.com/cvg/LightGlue/blob/main/demo.ipynb).

In [None]:
# Colab bootstrap (currently disabled): the commented-out lines below would clone the
# LightGlue repository and install it in editable mode. As written, this cell assumes
# that `lightglue` is already importable in the running kernel.
from pathlib import Path

# if Path.cwd().name != "LightGlue":
#     if not Path("LightGlue").exists():
#         !git clone --quiet https://github.com/cvg/LightGlue/
#     %cd LightGlue
#     !pip install --progress-bar off --quiet -e .

from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image, rbd
from lightglue import viz2d
import torch

# Inference only: switch autograd off globally for the rest of the notebook.
torch.set_grad_enabled(False)
# Folder holding the demo image pairs loaded by the "Easy example" / "Difficult example" cells.
images = Path("assets")

## Load extractor and matcher module
In this example we use SuperPoint features combined with LightGlue.

In [None]:
# Report the PyTorch/CUDA setup and pick the device every later cell runs on.
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    # No GPU visible to PyTorch: everything runs on the CPU.
    print("CUDA not available, using CPU")
    device = torch.device("cpu")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    device = torch.device("cuda")

print(f"Using device: {device}")

In [None]:
# Keypoint budget handed to SuperPoint. A fixed cap such as 2048 is the alternative;
# None presumably leaves the number of keypoints uncapped (see the LightGlue README).
max_num_keypoints = None

# Feature extractor, in eval mode, on the selected device.
extractor = SuperPoint(max_num_keypoints=max_num_keypoints)
extractor = extractor.eval().to(device)

# Matcher configured for SuperPoint descriptors, in eval mode, on the same device.
matcher = LightGlue(features="superpoint")
matcher = matcher.eval().to(device)


## Own example

In [None]:
# Root folder of the data set, relative to the notebook's working directory.
data_path = Path("../data")
# Reference map, loaded once and moved to the selected device (CUDA or CPU).
image_map = load_image(data_path.joinpath('map.png')).to(device)

In [None]:
# Ground-truth positions of the training images; the helpers below read the
# columns `id`, `x_pixel` and `y_pixel` from this table.
import pandas as pd
import matplotlib.pyplot as plt
train_pos = pd.read_csv(data_path / 'train_data' / 'train_pos.csv')

def extract_and_match(image0, image_map, map_features=None):
    """Extract keypoints from both images and match them with the global matcher.

    Uses the notebook-level ``extractor`` and ``matcher`` objects.

    Args:
        image0: Query image passed to ``extractor.extract``.
        image_map: Map image (or tile); only extracted when ``map_features`` is None.
        map_features: Optional precomputed features for the map side. They are
            handed to the matcher unchanged, so they presumably still need the
            batch dimension that ``extractor.extract`` produces - verify before use.

    Returns:
        Tuple ``(m_kpts0, m_kpts1, matches01, kpts0, kpts1)``: matched keypoints
        of the query and of the map (row i of one corresponds to row i of the
        other), the matcher output, and all keypoints of both images - each with
        the batch dimension removed.
    """
    query_feats = extractor.extract(image0)
    map_feats = extractor.extract(image_map) if map_features is None else map_features

    raw_matches = matcher({"image0": query_feats, "image1": map_feats})

    # Remove the batch dimension from everything that is indexed or returned below.
    query_feats = rbd(query_feats)
    map_feats = rbd(map_feats)
    matches01 = rbd(raw_matches)

    kpts0 = query_feats["keypoints"]
    kpts1 = map_feats["keypoints"]
    pairs = matches01["matches"]

    # Column 0 indexes into the query keypoints, column 1 into the map keypoints.
    m_kpts0 = kpts0[pairs[..., 0]]
    m_kpts1 = kpts1[pairs[..., 1]]
    return m_kpts0, m_kpts1, matches01, kpts0, kpts1

def get_ground_truth_positions(train_pos, img_path):
    """Look up the ground-truth map position of a training image.

    The numeric image id is parsed from the file name (everything before the
    first '.') and matched against the ``id`` column of ``train_pos``. If
    several rows share the id, the first one is used.

    Args:
        train_pos: DataFrame with at least the columns ``id``, ``x_pixel``
            and ``y_pixel``.
        img_path: ``pathlib.Path`` of the image, e.g. ``.../0012.JPG``.

    Returns:
        Tuple ``(gt_x, gt_y, img_id)``.

    Raises:
        ValueError: If the file name does not start with an integer.
        IndexError: If ``train_pos`` has no row for the parsed id.
    """
    # Named img_id (not id) so the builtin id() is not shadowed.
    img_id = int(img_path.name.split('.')[0])

    rows = train_pos[train_pos['id'] == img_id]
    if rows.empty:
        # Same exception type as the bare .iloc[0] would raise, but with a
        # message that tells the reader which image is missing.
        raise IndexError(
            f"No ground-truth row with id {img_id} (from '{img_path.name}') in train_pos"
        )

    first_row = rows.iloc[0]
    return first_row['x_pixel'], first_row['y_pixel'], img_id

def get_zoomed_map(image_map, gt_x, gt_y, zoom_factor=5):
    """Crop a window of the map centred on a ground-truth point.

    The window spans ``1 / zoom_factor`` of the map along each axis (integer
    division). It is clipped at the map border, so crops close to an edge are
    smaller than the nominal window.

    Args:
        image_map: Array/tensor indexed as ``[channel, row, column]``.
        gt_x: Column coordinate of the point in full-map pixels.
        gt_y: Row coordinate of the point in full-map pixels.
        zoom_factor: Divisor applied to the map width and height.

    Returns:
        Tuple ``(zoomed_map, adj_gt_x, adj_gt_y)``: the crop and the point's
        coordinates relative to the crop's top-left corner.
    """
    full_height = image_map.shape[1]
    full_width = image_map.shape[2]

    # Half of the nominal window size along each axis.
    half_width = (full_width // zoom_factor) // 2
    half_height = (full_height // zoom_factor) // 2

    # Window bounds around the point, clamped to the map extent.
    x_lo = max(0, int(gt_x - half_width))
    x_hi = min(full_width, int(gt_x + half_width))
    y_lo = max(0, int(gt_y - half_height))
    y_hi = min(full_height, int(gt_y + half_height))

    crop = image_map[:, y_lo:y_hi, x_lo:x_hi]

    # Shift the point into the crop's coordinate frame.
    return crop, gt_x - x_lo, gt_y - y_lo

def get_map_tile_features(image_map, zoom_factor=5):
    """Split the map into a ``zoom_factor`` x ``zoom_factor`` grid and extract features per tile.

    Tile sizes come from integer division of the map size, so when the map
    size is not a multiple of ``zoom_factor`` the remainder pixels at the right
    and bottom edges are not part of any tile.

    Args:
        image_map: Map tensor indexed as ``[channel, row, column]``.
        zoom_factor: Number of tiles per axis.

    Returns:
        Tuple ``(tile_features, tile_positions, tiles)`` of equal-length lists
        in row-major order: the extracted features of each tile (batch
        dimension removed), a dict with the keys ``row``, ``col``, ``left``,
        ``right``, ``top``, ``bottom`` per tile, and the tile crops themselves.
    """
    map_height = image_map.shape[1]
    map_width = image_map.shape[2]

    tile_height = map_height // zoom_factor
    tile_width = map_width // zoom_factor

    tile_features, tile_positions, tiles = [], [], []

    for row in range(zoom_factor):
        # Vertical extent is shared by every tile in this grid row.
        top = row * tile_height
        bottom = min(top + tile_height, map_height)

        for col in range(zoom_factor):
            left = col * tile_width
            right = min(left + tile_width, map_width)

            tile = image_map[:, top:bottom, left:right]

            # Features of this tile, with the batch dimension removed.
            feats = rbd(extractor.extract(tile))

            tile_features.append(feats)
            tile_positions.append(
                dict(row=row, col=col, left=left, right=right, top=top, bottom=bottom)
            )
            tiles.append(tile)

    return tile_features, tile_positions, tiles

def plot_matches(image0, img_path, image_map, train_pos, zoom_factor=5, map_features=None, plot_gt=False):
    """Match a drone image against a map (tile), plot the matches and return the match centre.

    Args:
        image0: Query image tensor.
        img_path: Path of the query image; its file name yields the id used to
            look up the ground truth in ``train_pos``.
        image_map: Map tensor to match against. When ``zoom_factor > 1`` it is
            first cropped around the ground-truth position.
        train_pos: DataFrame with the ground-truth positions.
        zoom_factor: Crop factor forwarded to ``get_zoomed_map``; values <= 1
            leave ``image_map`` untouched.
        map_features: Optional precomputed map features forwarded to
            ``extract_and_match``.
        plot_gt: Draw the ground-truth marker on the map axes. With
            ``zoom_factor <= 1`` the marker uses the unmodified ``train_pos``
            coordinates, so it is only meaningful if ``image_map`` is the full map.

    Returns:
        Tuple ``(center_x, center_y)``: mean position of the matched map
        keypoints in the coordinates of the (possibly cropped) map image, or
        ``(nan, nan)`` when there are no matches.
    """
    gt_x, gt_y, img_id = get_ground_truth_positions(train_pos, img_path)
    print(f"__Plotting ground truth for ID: {img_id}")
    print(f"__Ground truth position: ({gt_x:.1f}, {gt_y:.1f})")

    # Optionally zoom into the map around the ground truth; the ground-truth
    # coordinates are shifted into the crop's frame at the same time.
    if zoom_factor > 1:
        zoomed_map, gt_x, gt_y = get_zoomed_map(image_map, gt_x, gt_y, zoom_factor)
        image_map = zoomed_map.to(device)

    m_kpts0, m_kpts1, matches01, _, _ = extract_and_match(image0, image_map, map_features)

    # Side-by-side figure; axes[0] is the drone image, axes[1] the map.
    _ = viz2d.plot_images([image0, image_map], titles=[f"Image {img_id}", f"Map (Zoom {zoom_factor}x)"])
    fig = plt.gcf()
    axes = fig.axes

    if plot_gt:
        axes[1].plot(gt_x, gt_y, 'ro', markersize=8, markeredgecolor='white', markeredgewidth=1)
        axes[1].text(gt_x + 20, gt_y - 20, f'GT ({gt_x:.0f}, {gt_y:.0f})',
                     color='red', fontsize=10, fontweight='bold',
                     bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

    viz2d.plot_matches(m_kpts0, m_kpts1, color="blue", lw=0.3)
    viz2d.add_text(0, f'Stop after {matches01["stop"]} layers')

    # Default result for the no-match case; previously the names were left
    # unbound and the return statement raised UnboundLocalError.
    center_x = center_y = float("nan")
    if len(m_kpts1) > 0:
        center_x = m_kpts1[:, 0].mean().item()
        center_y = m_kpts1[:, 1].mean().item()
        axes[1].plot(center_x, center_y, 'go', markersize=10, markeredgecolor='white', markeredgewidth=2)
        axes[1].text(center_x + 20, center_y + 20, f'Center ({center_x:.0f}, {center_y:.0f})',
                     color='green', fontsize=10, fontweight='bold',
                     bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

    plt.show()

    return center_x, center_y

    # # Plot pruned keypoints
    # kpc0, kpc1 = viz2d.cm_prune(matches01["prune0"]), viz2d.cm_prune(matches01["prune1"])
    # viz2d.plot_images([image0, image_map])
    # viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=6)


def compute_pose_prediction(image0, img_path, image_map, train_pos, zoom_factor=5, map_features=None, plot_gt=False):
    """Predict the image position as the centre of its matched map keypoints (no plotting).

    Args:
        image0: Query image tensor.
        image_map: Map tensor (or tile) to match against.
        map_features: Optional precomputed map features forwarded to
            ``extract_and_match``.
        img_path, train_pos, zoom_factor, plot_gt: Unused; kept so the
            signature mirrors ``plot_matches``.

    Returns:
        Tuple ``(center_x, center_y)`` in the pixel coordinates of
        ``image_map``, or ``(nan, nan)`` when there are no matches.
    """
    _, m_kpts1, _, _, _ = extract_and_match(image0, image_map, map_features)

    # Without matches there is nothing to average; return NaN instead of
    # failing on unbound locals as the previous version did.
    if len(m_kpts1) == 0:
        return float("nan"), float("nan")

    center_x = m_kpts1[:, 0].mean().item()
    center_y = m_kpts1[:, 1].mean().item()
    return center_x, center_y

In [None]:
# Folders with the test and training drone images.
TEST_IMAGE_PATH = data_path.joinpath("test_data", "test_images")
TRAIN_IMAGE_PATH = data_path.joinpath("train_data", "train_images")

# Every entry of the training folder, sorted so the processing order is deterministic.
image_paths = sorted(TRAIN_IMAGE_PATH.iterdir())

# --- Tunable parameters ---
DRONE_SCALE_FACTOR = 8  # Drone images are shrunk by this factor before matching
MAP_ZOOM_FACTOR = int(DRONE_SCALE_FACTOR * 0.625)  # Map window divisor (5 for a scale factor of 8)
STRIDE_FACTOR = 2  # 2 = 50% overlap; 1 = stride equals window size

# --- Known resolutions (pixels) ---
DRONE_WIDTH, DRONE_HEIGHT = 8004, 6001
MAP_WIDTH, MAP_HEIGHT = 5000, 2500

# Maximum number of top-ranked tiles that are kept per image
NUM_BEST_TILES = 5

# --- Derived sizes ---
# Size the drone image is resized to
target_width = DRONE_WIDTH // DRONE_SCALE_FACTOR    # 8004 // 8 = 1000
target_height = DRONE_HEIGHT // DRONE_SCALE_FACTOR  # 6001 // 8 = 750

# Sliding-window size on the map and the step between neighbouring windows
window_height = MAP_HEIGHT // MAP_ZOOM_FACTOR
window_width = MAP_WIDTH // MAP_ZOOM_FACTOR
stride_y = window_height // STRIDE_FACTOR
stride_x = window_width // STRIDE_FACTOR

print(f"Drone image scaling: {DRONE_WIDTH}x{DRONE_HEIGHT} -> {target_width}x{target_height} (factor: {DRONE_SCALE_FACTOR})")
print(f"Map dimensions: {MAP_WIDTH}x{MAP_HEIGHT} -> Tile size: {window_width}x{window_height} (zoom factor: {MAP_ZOOM_FACTOR})")
print(f"Stride: {stride_x}x{stride_y}")

In [None]:
# Visualise the best-matching map tiles for a subsample of the training images.
IMAGE_STEP = 70  # only every IMAGE_STEP-th training image is processed

# The sliding-window grid depends only on the map/window geometry, so it is built
# once instead of once per image rotation.
map_tile_positions = [
    {"top": top, "bottom": top + window_height, "left": left, "right": left + window_width}
    for top in range(0, MAP_HEIGHT - window_height + 1, stride_y)
    for left in range(0, MAP_WIDTH - window_width + 1, stride_x)
]
windows = [
    image_map[:, pos["top"]:pos["bottom"], pos["left"]:pos["right"]]
    for pos in map_tile_positions
]

# Channel-last copy of the full map for imshow, converted once for all figures.
map_for_display = image_map.cpu().numpy().transpose(1, 2, 0)

for i in range(0, len(image_paths), IMAGE_STEP):
    img_path = image_paths[i]
    print(f"Processing image: {img_path.name}...")

    # Resize drone image using calculated dimensions
    image0 = load_image(img_path, resize=(target_height, target_width)).to(device)

    # All 90-degree rotations of the image (0, 90, 180, 270 degrees)
    rotated_images = [image0] + [torch.rot90(image0, k, dims=[1, 2]) for k in range(1, 4)]

    # Match every rotation against every window and record the match counts
    tile_results = []
    for rot_idx, rotated_image in enumerate(rotated_images):
        for tile_pos, tile_img in zip(map_tile_positions, windows):
            m_kpts0, m_kpts1, matches01, kpts0, kpts1 = extract_and_match(rotated_image, tile_img, map_features=None)
            tile_results.append({
                'tile_pos': tile_pos,
                'tile_img': tile_img,
                'num_matches': len(m_kpts0),
                'matches_data': (m_kpts0, m_kpts1, matches01, kpts0, kpts1),
                'rotation_index': rot_idx,
                'rotated_image': rotated_image
            })

    # Rank by number of matches (descending) and keep the NUM_BEST_TILES best
    best_tiles = sorted(tile_results, key=lambda x: x['num_matches'], reverse=True)[:NUM_BEST_TILES]

    print(f"\n=== Top {len(best_tiles)} tiles for image {img_path.name} ===")
    for idx, tile_result in enumerate(best_tiles, 1):
        tile_pos = tile_result['tile_pos']
        num_matches = tile_result['num_matches']
        print(f"{idx}. Tile {tile_pos['left']}, {tile_pos['top']}: {num_matches} matches")

    # Ground truth is per image, not per tile; named img_id so the builtin id()
    # is not shadowed for the rest of the kernel session.
    gt_x, gt_y, img_id = get_ground_truth_positions(train_pos, img_path)

    # Plot each of the best tiles
    for idx, tile_result in enumerate(best_tiles, 1):
        tile_pos = tile_result['tile_pos']
        tile_img = tile_result['tile_img']
        num_matches = tile_result['num_matches']
        best_rotation_idx = tile_result['rotation_index']
        rotated_image = tile_result['rotated_image']

        # A tile without matches has no match centre to draw.
        if num_matches == 0:
            print(f"\n--- Skipping #{idx} best tile at: [{tile_pos['left']}, {tile_pos['top']}]: no matches to plot ---")
            continue

        print(f"\n--- Plotting #{idx} best tile (image rot= {best_rotation_idx *90}) at: [{tile_pos['left']}, {tile_pos['top']}] with {num_matches} matches ---")
        pred_x, pred_y = plot_matches(rotated_image, img_path, tile_img, train_pos,
                                      zoom_factor=1, map_features=None, plot_gt=False)

        # Whole map with the tile highlighted, the predicted position and the ground truth
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(map_for_display)

        # Highlight the tile with a rectangle
        rect = plt.Rectangle((tile_pos['left'], tile_pos['top']),
                             tile_pos['right'] - tile_pos['left'],
                             tile_pos['bottom'] - tile_pos['top'],
                             linewidth=1, edgecolor='blue', facecolor='none', alpha=0.8)
        ax.add_patch(rect)

        # The prediction is in tile coordinates; shift it into full-map coordinates
        full_map_pred_x = pred_x + tile_pos['left']
        full_map_pred_y = pred_y + tile_pos['top']

        # Ground truth position (red dot)
        ax.plot(gt_x, gt_y, 'ro', markersize=10, markeredgecolor='white', markeredgewidth=2)
        ax.text(gt_x + 50, gt_y - 50, f'GT ({gt_x:.0f}, {gt_y:.0f})',
                color='red', fontsize=12, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

        # Predicted position (green dot)
        ax.plot(full_map_pred_x, full_map_pred_y, 'go', markersize=10, markeredgecolor='white', markeredgewidth=2)
        ax.text(full_map_pred_x + 50, full_map_pred_y + 50, f'Pred ({full_map_pred_x:.0f}, {full_map_pred_y:.0f})',
                color='green', fontsize=12, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

        # Tile information
        ax.text(tile_pos['left'] + 10, tile_pos['top'] + 30,
                f'Best Tile \n{num_matches} matches',
                color='blue', fontsize=10, fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8))

        ax.set_title(f'Full Map - Image {img_id} - Tile Ranking #{idx}')
        ax.set_xlabel('X coordinate')
        ax.set_ylabel('Y coordinate')
        plt.show()


## Run on test data

In [None]:
# Predict a map position for every test image and stream the results to a CSV file.
TEST_IMAGE_PATH = data_path / "test_data" / "test_images"
test_image_paths = sorted(TEST_IMAGE_PATH.iterdir())
MIN_MATCHES_THRESHOLD = 10  # Minimum matches for a tile to count as "good"
HIGH_CONFIDENCE_THRESHOLD = 40  # If any tile reaches this, only such tiles are averaged
print(f"Processing {len(test_image_paths)} test images...")

# Prepare CSV file with headers
output_path = data_path / 'predictions_lightglue_test.csv'
with open(output_path, 'w') as f:
    f.write('id,x_pixel,y_pixel\n')

# The sliding-window grid depends only on the map/window geometry, so it is built
# once instead of once per image rotation.
map_tile_positions = [
    {"top": top, "bottom": top + window_height, "left": left, "right": left + window_width}
    for top in range(0, MAP_HEIGHT - window_height + 1, stride_y)
    for left in range(0, MAP_WIDTH - window_width + 1, stride_x)
]
windows = [
    image_map[:, pos["top"]:pos["bottom"], pos["left"]:pos["right"]]
    for pos in map_tile_positions
]

for i, img_path in enumerate(test_image_paths):
    if i % 10 == 0:
        print(f"Processing {i+1}/{len(test_image_paths)}: {img_path.name}")

    img_id = int(img_path.name.split('.')[0])

    # Resize drone image using calculated dimensions
    image0 = load_image(img_path, resize=(target_height, target_width)).to(device)

    # All 90-degree rotations of the image (0, 90, 180, 270 degrees)
    rotated_images = [image0] + [torch.rot90(image0, k, dims=[1, 2]) for k in range(1, 4)]

    # Match every rotation against every window and record the match counts
    tile_results = []
    for rot_idx, rotated_image in enumerate(rotated_images):
        for tile_pos, tile_img in zip(map_tile_positions, windows):
            m_kpts0, m_kpts1, matches01, kpts0, kpts1 = extract_and_match(rotated_image, tile_img, map_features=None)
            tile_results.append({
                'tile_pos': tile_pos,
                'tile_img': tile_img,
                'num_matches': len(m_kpts0),
                'matches_data': (m_kpts0, m_kpts1, matches01, kpts0, kpts1),
                'rotation_index': rot_idx,
                'rotated_image': rotated_image
            })

    # Rank by number of matches (descending) and select the tiles to average
    sorted_tiles = sorted(tile_results, key=lambda x: x['num_matches'], reverse=True)
    high_confidence_tiles = [tile for tile in sorted_tiles if tile['num_matches'] >= HIGH_CONFIDENCE_THRESHOLD]

    if high_confidence_tiles:
        # Use all high confidence tiles
        best_tiles = high_confidence_tiles
    else:
        good_tiles = [tile for tile in sorted_tiles if tile['num_matches'] >= MIN_MATCHES_THRESHOLD]
        # Up to NUM_BEST_TILES good tiles; if none qualifies, fall back to the single best tile
        best_tiles = good_tiles[:NUM_BEST_TILES] if good_tiles else sorted_tiles[:1]

    # Collect one full-map prediction per selected tile
    predictions_x = []
    predictions_y = []

    for idx, tile_result in enumerate(best_tiles, 1):
        tile_pos = tile_result['tile_pos']
        tile_img = tile_result['tile_img']
        num_matches = tile_result['num_matches']
        best_rotation_idx = tile_result['rotation_index']
        rotated_image = tile_result['rotated_image']

        # Only use tiles with at least some matches
        if num_matches > 0:
            pred_x, pred_y = compute_pose_prediction(rotated_image, img_path, tile_img, train_pos,
                                                     zoom_factor=1, map_features=None, plot_gt=False)

            # The prediction is in tile coordinates; shift it into full-map coordinates
            predictions_x.append(pred_x + tile_pos['left'])
            predictions_y.append(pred_y + tile_pos['top'])

    if predictions_x:
        # Mean prediction over the selected tiles
        final_pred_x = sum(predictions_x) / len(predictions_x)
        final_pred_y = sum(predictions_y) / len(predictions_y)
        num_tiles_used = len(predictions_x)
    else:
        # No tile produced a match: fall back to the map centre rather than reusing
        # the previous image's values (or failing on undefined names).
        final_pred_x, final_pred_y = MAP_WIDTH / 2, MAP_HEIGHT / 2
        num_tiles_used = 0
        print(f"Warning: no matches for {img_path.name}; falling back to the map centre")

    # Append the averaged prediction immediately so partial results survive an interruption.
    # (The previous version wrote the last tile's prediction instead of the mean.)
    with open(output_path, 'a') as f:
        f.write(f'{img_id},{final_pred_x:.1f},{final_pred_y:.1f}\n')

print(f"\nSaved predictions to {output_path}")

# Quick validation
predictions_df = pd.read_csv(output_path)
print(f"Total predictions: {len(predictions_df)}")
print("Sample predictions:")
print(predictions_df.head())

### Evaluate on train data

In [None]:
# Run the pipeline on training images and evaluate the predictions against ground truth.
TRAIN_IMAGE_PATH = data_path / "train_data" / "train_images"
train_image_paths = sorted(TRAIN_IMAGE_PATH.iterdir())

# Only the first MAX_EVAL_IMAGES images are evaluated to keep the run short;
# set to None to evaluate every training image.
MAX_EVAL_IMAGES = 10
eval_image_paths = train_image_paths if MAX_EVAL_IMAGES is None else train_image_paths[:MAX_EVAL_IMAGES]

print(f"Processing {len(eval_image_paths)} of {len(train_image_paths)} train images...")

MIN_MATCHES_THRESHOLD = 10  # Minimum matches for a tile to count as "good"
HIGH_CONFIDENCE_THRESHOLD = 40  # If any tile reaches this, only such tiles are averaged

# Prepare CSV file with headers
output_path = data_path / 'predictions_lightglue_train.csv'
with open(output_path, 'w') as f:
    f.write('id,x_pixel,y_pixel,gt_x_pixel,gt_y_pixel,error_distance,num_tiles_used\n')

# The sliding-window grid depends only on the map/window geometry, so it is built
# once instead of once per image rotation.
map_tile_positions = [
    {"top": top, "bottom": top + window_height, "left": left, "right": left + window_width}
    for top in range(0, MAP_HEIGHT - window_height + 1, stride_y)
    for left in range(0, MAP_WIDTH - window_width + 1, stride_x)
]
windows = [
    image_map[:, pos["top"]:pos["bottom"], pos["left"]:pos["right"]]
    for pos in map_tile_positions
]

for i, img_path in enumerate(eval_image_paths):
    if i % 10 == 0:
        print(f"Processing {i+1}/{len(eval_image_paths)}: {img_path.name}")

    img_id = int(img_path.name.split('.')[0])

    # Get ground truth for this image
    gt_x, gt_y, _ = get_ground_truth_positions(train_pos, img_path)

    # Resize drone image using calculated dimensions
    image0 = load_image(img_path, resize=(target_height, target_width)).to(device)

    # All 90-degree rotations of the image (0, 90, 180, 270 degrees)
    rotated_images = [image0] + [torch.rot90(image0, k, dims=[1, 2]) for k in range(1, 4)]

    # Match every rotation against every window and record the match counts
    tile_results = []
    for rot_idx, rotated_image in enumerate(rotated_images):
        for tile_pos, tile_img in zip(map_tile_positions, windows):
            m_kpts0, m_kpts1, matches01, kpts0, kpts1 = extract_and_match(rotated_image, tile_img, map_features=None)
            tile_results.append({
                'tile_pos': tile_pos,
                'tile_img': tile_img,
                'num_matches': len(m_kpts0),
                'matches_data': (m_kpts0, m_kpts1, matches01, kpts0, kpts1),
                'rotation_index': rot_idx,
                'rotated_image': rotated_image
            })

    # Rank by number of matches (descending) and select the tiles to average
    sorted_tiles = sorted(tile_results, key=lambda x: x['num_matches'], reverse=True)
    high_confidence_tiles = [tile for tile in sorted_tiles if tile['num_matches'] >= HIGH_CONFIDENCE_THRESHOLD]

    if high_confidence_tiles:
        # Use all high confidence tiles
        best_tiles = high_confidence_tiles
    else:
        good_tiles = [tile for tile in sorted_tiles if tile['num_matches'] >= MIN_MATCHES_THRESHOLD]
        # Up to NUM_BEST_TILES good tiles; if none qualifies, fall back to the single best tile
        best_tiles = good_tiles[:NUM_BEST_TILES] if good_tiles else sorted_tiles[:1]

    # Collect one full-map prediction per selected tile
    predictions_x = []
    predictions_y = []

    for idx, tile_result in enumerate(best_tiles, 1):
        tile_pos = tile_result['tile_pos']
        tile_img = tile_result['tile_img']
        num_matches = tile_result['num_matches']
        best_rotation_idx = tile_result['rotation_index']
        rotated_image = tile_result['rotated_image']

        # Only use tiles with at least some matches
        if num_matches > 0:
            pred_x, pred_y = compute_pose_prediction(rotated_image, img_path, tile_img, train_pos,
                                                     zoom_factor=1, map_features=None, plot_gt=False)

            # The prediction is in tile coordinates; shift it into full-map coordinates
            predictions_x.append(pred_x + tile_pos['left'])
            predictions_y.append(pred_y + tile_pos['top'])

    if predictions_x:
        # Mean prediction over the selected tiles
        final_pred_x = sum(predictions_x) / len(predictions_x)
        final_pred_y = sum(predictions_y) / len(predictions_y)
        num_tiles_used = len(predictions_x)
    else:
        # No tile produced a match: fall back to the map centre rather than reusing
        # the previous image's values (or failing on undefined names).
        final_pred_x, final_pred_y = MAP_WIDTH / 2, MAP_HEIGHT / 2
        num_tiles_used = 0
        print(f"Warning: no matches for {img_path.name}; falling back to the map centre")

    # Euclidean distance between prediction and ground truth, in map pixels
    error_distance = ((final_pred_x - gt_x)**2 + (final_pred_y - gt_y)**2)**0.5

    # Append the row immediately so partial results survive an interruption
    with open(output_path, 'a') as f:
        f.write(f'{img_id},{final_pred_x:.1f},{final_pred_y:.1f},{gt_x:.1f},{gt_y:.1f},{error_distance:.1f},{num_tiles_used}\n')

print(f"\nSaved predictions to {output_path}")

# Quick validation and evaluation
predictions_df = pd.read_csv(output_path)
print(f"Total predictions: {len(predictions_df)}")
print("Sample predictions:")
print(predictions_df.head())

# Calculate evaluation metrics
mean_error = predictions_df['error_distance'].mean()
median_error = predictions_df['error_distance'].median()
std_error = predictions_df['error_distance'].std()

print(f"\nEvaluation Results:")
print(f"Mean error: {mean_error:.1f} pixels")
print(f"Median error: {median_error:.1f} pixels")
print(f"Std error: {std_error:.1f} pixels")
print(f"Min error: {predictions_df['error_distance'].min():.1f} pixels")
print(f"Max error: {predictions_df['error_distance'].max():.1f} pixels")
print(f"Average tiles used: {predictions_df['num_tiles_used'].mean():.1f}")

# Share of images whose error stays within each pixel threshold
thresholds = [50, 100, 200, 500]
for threshold in thresholds:
    accuracy = (predictions_df['error_distance'] <= threshold).mean() * 100
    print(f"Accuracy within {threshold} pixels: {accuracy:.1f}%")

## Easy example
The top image shows the matches, while the bottom image shows the point pruning across layers. In this case, LightGlue prunes a few points with occlusions, but is able to stop the context aggregation after 4/9 layers.

In [None]:
# Demo pair from the assets folder; image0/image1 are kept as loaded for plotting.
image0 = load_image(images / "DSC_0411.JPG")
image1 = load_image(images / "DSC_0410.JPG")

# Extract features from device copies of the images and match them.
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
matches01 = matcher({"image0": feats0, "image1": feats1})

# Remove the batch dimension.
feats0 = rbd(feats0)
feats1 = rbd(feats1)
matches01 = rbd(matches01)

# Keypoints of both images and the index pairs that link them.
kpts0 = feats0["keypoints"]
kpts1 = feats1["keypoints"]
matches = matches01["matches"]
m_kpts0 = kpts0[matches[..., 0]]
m_kpts1 = kpts1[matches[..., 1]]

# First figure: the matches between the two images.
axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
viz2d.add_text(0, f'Stop after {matches01["stop"]} layers', fs=20)

# Second figure: all keypoints, coloured from the matcher's pruning output.
kpc0 = viz2d.cm_prune(matches01["prune0"])
kpc1 = viz2d.cm_prune(matches01["prune1"])
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)

## Difficult example
For pairs with significant viewpoint and illumination changes, LightGlue can exclude a lot of points early in the matching process (red points), which significantly reduces the inference time.

In [None]:
# Demo pair from the assets folder; image0/image1 are kept as loaded for plotting.
image0 = load_image(images / "sacre_coeur1.jpg")
image1 = load_image(images / "sacre_coeur2.jpg")

# Extract features from device copies of the images and match them.
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
matches01 = matcher({"image0": feats0, "image1": feats1})

# Remove the batch dimension.
feats0 = rbd(feats0)
feats1 = rbd(feats1)
matches01 = rbd(matches01)

# Keypoints of both images and the index pairs that link them.
kpts0 = feats0["keypoints"]
kpts1 = feats1["keypoints"]
matches = matches01["matches"]
m_kpts0 = kpts0[matches[..., 0]]
m_kpts1 = kpts1[matches[..., 1]]

# First figure: the matches between the two images.
axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
viz2d.add_text(0, f'Stop after {matches01["stop"]} layers')

# Second figure: all keypoints, coloured from the matcher's pruning output.
kpc0 = viz2d.cm_prune(matches01["prune0"])
kpc1 = viz2d.cm_prune(matches01["prune1"])
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=6)

###  