In [None]:
# Install the pre-downloaded wheels from the attached Kaggle dataset; --no-index keeps pip
# away from any package index and --no-deps skips dependency resolution.
!pip install --no-index /kaggle/input/imc2025-packages-python-11-new/* --no-deps
# Pre-seed the torch hub checkpoint cache with the ALIKED and LightGlue weights.
!mkdir -p /root/.cache/torch/hub/checkpoints
!cp /kaggle/input/aliked/pytorch/aliked-n16/1/aliked-n16.pth /root/.cache/torch/hub/checkpoints/
!cp /kaggle/input/lightglue/pytorch/aliked/1/aliked_lightglue.pth /root/.cache/torch/hub/checkpoints/
# Same LightGlue weights again under a second file name.
# NOTE(review): the "-pth" ending looks like a typo for ".pth", but it is presumably the exact
# cache name kornia's LightGlueMatcher asks for - confirm before "fixing" it.
!cp /kaggle/input/lightglue/pytorch/aliked/1/aliked_lightglue.pth /root/.cache/torch/hub/checkpoints/aliked_lightglue_v0-1_arxiv-pth

import sys
import os
from tqdm import tqdm
from time import time, sleep
import gc
import numpy as np
import h5py
import dataclasses
import pandas as pd
from IPython.display import clear_output
from collections import defaultdict
from copy import deepcopy
from PIL import Image
import cv2
import torch
import torch.nn.functional as F
import kornia as K
import kornia.feature as KF
from lightglue import match_pair
from lightglue import ALIKED, LightGlue
from lightglue.utils import load_image, rbd
from transformers import AutoImageProcessor, AutoModel
import pycolmap
sys.path.append('/kaggle/input/imc25-utils')
from database import *
from h5_to_db import *
import metric

@dataclasses.dataclass
class Config:
    """Pipeline hyperparameters plus the rules that adapt them per dataset."""

    # Largest L2 distance between global descriptors for a pair to be shortlisted.
    sim_threshold: float = 0.40

    # Fallback neighbour count when too few images pass ``sim_threshold``.
    min_pairs_per_image: int = 25

    # Datasets with at most this many images are matched exhaustively.
    exhaustive_threshold: int = 20

    # Passed to ALIKED as ``max_num_keypoints``.
    num_features: int = 6000

    # Passed to ALIKED as ``detection_threshold``.
    detection_threshold: float = 0.01

    # Passed to ALIKED as ``resize``.
    resize_to: int = 1600

    # A pair is written to matches.h5 only with at least this many matches.
    min_matches: int = 20

    # Copied onto ``pycolmap.IncrementalPipelineOptions.min_model_size``.
    min_model_size: int = 3

    # Copied onto ``pycolmap.IncrementalPipelineOptions.max_num_models``.
    max_num_models: int = 30

    def get_for_dataset(self, dataset_name: str, num_images: int):
        """Return a copy of this config adapted to one dataset.

        Size-based rules are applied first, then name-based rules, so a
        name-based value wins when both touch the same field. ``self`` is
        left untouched.

        Args:
            dataset_name: Dataset identifier; matched case-insensitively.
            num_images: Number of images in the dataset.

        Returns:
            A new ``Config`` carrying the adapted values.
        """
        overrides = {}

        if num_images <= 30:
            # Small scenes - can afford more features and higher resolution
            overrides.update(
                num_features=8000,
                resize_to=2048,
                exhaustive_threshold=30,
                detection_threshold=0.005,
            )
        elif num_images >= 200:
            # Large scenes - need more connectivity
            overrides.update(
                sim_threshold=0.35,
                min_pairs_per_image=35,
                max_num_models=50,
            )

        # Dataset-specific tweaks (add your observations here)
        lowered = dataset_name.lower()
        if 'vineyard' in lowered:
            # Repetitive structures - be more strict
            overrides.update(detection_threshold=0.02, min_matches=25)
        elif 'brandenburg' in lowered or 'church' in lowered:
            # Architectural scenes with detail
            overrides.update(num_features=7000, resize_to=2048)

        return dataclasses.replace(self, **overrides)

# Default hyperparameters; per-dataset variants are derived from this instance.
CONFIG = Config()
# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'{device=}')

def load_torch_image(fname, device=torch.device('cpu')):
    """Load an image with kornia's ``RGB32`` load type and add a leading axis.

    Args:
        fname: Path of the image file.
        device: Device handed to ``K.io.load_image``.

    Returns:
        The loaded tensor with one extra leading dimension
        (presumably shape (1, 3, H, W) - TODO confirm against kornia).
    """
    rgb = K.io.load_image(fname, K.io.ImageLoadType.RGB32, device=device)
    return rgb[None, ...]

def get_global_desc(fnames, device=torch.device('cpu')):
    """Compute one L2-normalised global descriptor per image with DINOv2.

    The descriptor is the element-wise maximum over the model's output
    tokens, skipping token 0 (presumably the CLS token - TODO confirm).

    Args:
        fnames: Image paths, processed in order.
        device: Device the model runs on; images are always loaded on CPU.

    Returns:
        A CPU tensor with the per-image descriptors stacked along dim 0.
    """
    weights_dir = '/kaggle/input/dinov2/pytorch/base/1'
    processor = AutoImageProcessor.from_pretrained(weights_dir)
    model = AutoModel.from_pretrained(weights_dir)
    model = model.eval().to(device)

    per_image = []
    for img_path in tqdm(fnames, total=len(fnames)):
        timg = load_torch_image(img_path, device='cpu')
        with torch.inference_mode():
            # do_rescale=False: the image is fed to the processor without its rescale step.
            inputs = processor(images=timg, return_tensors="pt", do_rescale=False).to(device)
            tokens = model(**inputs).last_hidden_state
            pooled = tokens[:, 1:].max(dim=1)[0]
            descriptor = F.normalize(pooled, dim=1, p=2)
        per_image.append(descriptor.detach().cpu())

    return torch.cat(per_image, dim=0)

def get_img_pairs_exhaustive(img_fnames):
    """Return every index pair ``(i, j)`` with ``i < j``, in lexicographic order.

    Only the length of ``img_fnames`` is used.
    """
    count = len(img_fnames)
    return [(first, second)
            for first in range(count)
            for second in range(first + 1, count)]

def get_image_pairs_shortlist(fnames, sim_th, min_pairs, exhaustive_if_less, device):
    """Select the image pairs worth matching, using global-descriptor distance.

    Small datasets (``len(fnames) <= exhaustive_if_less``) get every pair.
    Otherwise each image is paired with all other images whose descriptor
    lies within ``sim_th`` (L2 distance); if that yields fewer than
    ``min_pairs`` partners, the ``min_pairs`` nearest other images are used
    instead.

    Unlike the previous version, the fallback is applied to every image
    (the last one used to be skipped) and the image itself no longer counts
    towards ``min_pairs``, so each image really gets ``min_pairs`` candidate
    partners when that many other images exist.

    Args:
        fnames: Image paths.
        sim_th: Maximum descriptor distance for a pair to be shortlisted.
        min_pairs: Minimum number of partners tried per image.
        exhaustive_if_less: Image count up to which all pairs are returned.
        device: Device used to compute the global descriptors.

    Returns:
        A sorted list of unique ``(i, j)`` index tuples with ``i < j``.
    """
    num_imgs = len(fnames)
    if num_imgs <= exhaustive_if_less:
        return get_img_pairs_exhaustive(fnames)

    descs = get_global_desc(fnames, device=device)
    dm = torch.cdist(descs, descs, p=2).detach().cpu().numpy()

    pairs = set()
    for st_idx in range(num_imgs):
        # Other images, nearest first; the image itself is never a partner.
        by_distance = np.argsort(dm[st_idx])
        by_distance = by_distance[by_distance != st_idx]

        close_enough = by_distance[dm[st_idx, by_distance] <= sim_th]
        if len(close_enough) >= min_pairs:
            to_match = close_enough
        else:
            to_match = by_distance[:min_pairs]

        for idx in to_match:
            # Sanity guard kept from the original: a NaN/inf distance fails this test.
            if dm[st_idx, idx] < 1000:
                pairs.add(tuple(sorted((st_idx, int(idx)))))

    return sorted(pairs)

def detect_aliked(img_fnames, feature_dir, num_features, detection_threshold, resize_to, device):
    """Extract ALIKED keypoints and descriptors for every image into HDF5 files.

    Writes ``keypoints.h5`` and ``descriptors.h5`` inside ``feature_dir``
    (created if missing, existing files overwritten). Each image is stored
    under the last ``/``-separated component of its path.

    Args:
        img_fnames: Image paths.
        feature_dir: Output directory.
        num_features: ``max_num_keypoints`` for ALIKED.
        detection_threshold: ``detection_threshold`` for ALIKED.
        resize_to: ``resize`` for ALIKED.
        device: Device the extractor and the images are placed on.
    """
    dtype = torch.float32
    extractor = ALIKED(
        max_num_keypoints=num_features,
        detection_threshold=detection_threshold,
        resize=resize_to,
    )
    extractor = extractor.eval().to(device, dtype)

    os.makedirs(feature_dir, exist_ok=True)
    kp_file = f'{feature_dir}/keypoints.h5'
    desc_file = f'{feature_dir}/descriptors.h5'

    with h5py.File(kp_file, mode='w') as f_kp, h5py.File(desc_file, mode='w') as f_desc:
        for img_path in tqdm(img_fnames):
            with torch.inference_mode():
                image = load_torch_image(img_path, device=device).to(dtype)
                feats = extractor.extract(image)

                # Flatten to one row per keypoint before moving to numpy.
                kpts = feats['keypoints'].reshape(-1, 2).detach().cpu().numpy()
                descs = feats['descriptors'].reshape(len(kpts), -1).detach().cpu().numpy()

            key = img_path.split('/')[-1]
            f_kp[key] = kpts
            f_desc[key] = descs

def match_with_lightglue(img_fnames, index_pairs, feature_dir, min_matches, device, verbose=True):
    """Match stored ALIKED features for each index pair with LightGlue.

    Reads ``keypoints.h5`` / ``descriptors.h5`` from ``feature_dir`` and
    writes ``matches.h5`` there (overwriting it). Matches for a pair are
    stored as an (N, 2) index array under ``matches[key1][key2]``, and only
    when the pair has at least ``min_matches`` matches.

    Args:
        img_fnames: Image paths; the last path component is the HDF5 key.
        index_pairs: Iterable of ``(idx1, idx2)`` indices into ``img_fnames``.
        feature_dir: Directory holding the feature files.
        min_matches: Minimum match count for a pair to be kept.
        device: Device the matcher and features are placed on.
        verbose: Print the match count of every pair with at least one match.
    """
    # NOTE(review): -1 presumably disables LightGlue's adaptive width/depth
    # shortcuts - confirm against the LightGlue docs. "mp" is enabled on CUDA only.
    matcher_conf = {"width_confidence": -1, "depth_confidence": -1, "mp": 'cuda' in str(device)}
    lg_matcher = KF.LightGlueMatcher("aliked", matcher_conf).eval().to(device)

    def _load(dataset):
        # Read the whole HDF5 dataset and move it to the target device.
        return torch.from_numpy(dataset[...]).to(device)

    with h5py.File(f'{feature_dir}/keypoints.h5', mode='r') as f_kp, \
         h5py.File(f'{feature_dir}/descriptors.h5', mode='r') as f_desc, \
         h5py.File(f'{feature_dir}/matches.h5', mode='w') as f_match:

        for idx1, idx2 in tqdm(index_pairs):
            key1 = img_fnames[idx1].split('/')[-1]
            key2 = img_fnames[idx2].split('/')[-1]

            kp1, kp2 = _load(f_kp[key1]), _load(f_kp[key2])
            desc1, desc2 = _load(f_desc[key1]), _load(f_desc[key2])

            with torch.inference_mode():
                # Only keypoint centres are supplied when building the LAFs.
                laf1 = KF.laf_from_center_scale_ori(kp1[None])
                laf2 = KF.laf_from_center_scale_ori(kp2[None])
                dists, idxs = lg_matcher(desc1, desc2, laf1, laf2)

            n_matches = len(idxs)
            if n_matches == 0:
                continue

            if verbose:
                print(f'{key1}-{key2}: {n_matches} matches')

            if n_matches < min_matches:
                continue

            pair_matches = idxs.detach().cpu().numpy().reshape(-1, 2)
            f_match.require_group(key1).create_dataset(key2, data=pair_matches)

def import_into_colmap(img_dir, feature_dir, database_path):
    """Create a COLMAP database from the stored keypoints and matches.

    Args:
        img_dir: Directory containing the images.
        feature_dir: Directory read by ``add_keypoints`` / ``add_matches``.
        database_path: Path of the database file to create.

    The connection is now always closed, also when an import step raises,
    so pycolmap does not open the file while a stale handle is still alive.
    NOTE(review): assumes ``COLMAPDatabase`` is the ``sqlite3.Connection``
    subclass from COLMAP's ``database.py`` and therefore has ``close()``.
    """
    db = COLMAPDatabase.connect(database_path)
    try:
        db.create_tables()
        fname_to_id = add_keypoints(db, feature_dir, img_dir, '', 'simple-pinhole', False)
        add_matches(db, feature_dir, fname_to_id)
        db.commit()
    finally:
        db.close()

@dataclasses.dataclass
class Prediction:
    """One image of a dataset together with its estimated pose, if any."""
    # Identifier copied from the sample submission; None when running on the training labels.
    image_id: str | None
    # Name of the dataset the image belongs to.
    dataset: str
    # Image file name relative to the dataset's image directory.
    filename: str
    # Index of the reconstruction the image was registered in; None is written as 'outliers'.
    cluster_index: int | None = None
    # cam_from_world rotation matrix (flattened to 9 values on output); None is written as nan.
    rotation: np.ndarray | None = None
    # cam_from_world translation (3 values on output); None is written as nan.
    translation: np.ndarray | None = None

# Configuration
is_train = False  # True: run on train_labels.csv; False: run on the test sample submission
data_dir = '/kaggle/input/image-matching-challenge-2025'
workdir = '/kaggle/working/result/'
os.makedirs(workdir, exist_ok=True)

sample_submission_csv = os.path.join(
    data_dir,
    'train_labels.csv' if is_train else 'sample_submission.csv',
)

# Load datasets: group one Prediction per CSV row under its dataset name,
# keeping the CSV row order inside each dataset.
samples = {}
competition_data = pd.read_csv(sample_submission_csv)
for _, row in competition_data.iterrows():
    samples.setdefault(row.dataset, []).append(
        Prediction(
            # The image_id column is only read in test mode.
            image_id=None if is_train else row.image_id,
            dataset=row.dataset,
            filename=row.image,
        )
    )

for dataset in samples:
    print(f'Dataset "{dataset}" -> num_images={len(samples[dataset])}')

gc.collect()

# Processing configuration
max_images = None  # optional cap on images per dataset (None/0 = no cap)
# Restrict the run to a few datasets in training mode; None means "all".
datasets_to_process = ['amy_gardens', 'ETs', 'fbk_vineyard', 'stairs'] if is_train else None

# Processing: per-stage wall-clock durations, one entry per processed dataset.
timings = {
    stage: []
    for stage in ("shortlisting", "feature_detection", "feature_matching",
                  "RANSAC", "Reconstruction")
}
mapping_result_strs = []

print(f"\nExtracting on device {device}")
print("\nDefault hyperparameters:")
print(
    "\n".join(
        f"  {name}: {getattr(CONFIG, name)}"
        for name in ("sim_threshold", "min_pairs_per_image", "num_features",
                     "detection_threshold", "resize_to", "min_matches")
    )
    + "\n"
)

# Main loop: shortlist pairs, extract and match features, then reconstruct each dataset
# with COLMAP and copy the resulting poses onto the Prediction objects.
import traceback  # imported here so this cell is self-contained

for dataset, predictions in samples.items():
    if datasets_to_process and dataset not in datasets_to_process:
        print(f'Skipping "{dataset}"')
        continue
    
    images_dir = os.path.join(data_dir, 'train' if is_train else 'test', dataset)
    images = [os.path.join(images_dir, p.filename) for p in predictions]
    if max_images:
        images = images[:max_images]
    
    num_images = len(images)
    print(f'\n{"="*70}')
    print(f'Processing dataset "{dataset}": {num_images} images')
    print(f'{"="*70}')
    
    # Get adaptive config for this dataset
    cfg = CONFIG.get_for_dataset(dataset, num_images)
    print(f"Adaptive params: features={cfg.num_features}, resize={cfg.resize_to}, "
          f"sim_th={cfg.sim_threshold:.2f}, min_pairs={cfg.min_pairs_per_image}")
    
    # COLMAP reports images by name; map them back to their slot in `predictions`.
    filename_to_index = {p.filename: idx for idx, p in enumerate(predictions)}
    feature_dir = os.path.join(workdir, 'featureout', dataset)
    os.makedirs(feature_dir, exist_ok=True)
    
    try:
        # Shortlisting
        t = time()
        index_pairs = get_image_pairs_shortlist(
            images, cfg.sim_threshold, cfg.min_pairs_per_image,
            cfg.exhaustive_threshold, device
        )
        elapsed = time() - t
        timings['shortlisting'].append(elapsed)
        print(f'✓ Shortlisting: {len(index_pairs)} pairs in {elapsed:.2f}s')
        gc.collect()
        
        # Feature detection
        t = time()
        detect_aliked(images, feature_dir, cfg.num_features, 
                     cfg.detection_threshold, cfg.resize_to, device)
        gc.collect()
        elapsed = time() - t
        timings['feature_detection'].append(elapsed)
        print(f'✓ Feature detection in {elapsed:.2f}s')
        
        # Feature matching
        t = time()
        match_with_lightglue(images, index_pairs, feature_dir, 
                           cfg.min_matches, device, verbose=False)
        elapsed = time() - t
        timings['feature_matching'].append(elapsed)
        print(f'✓ Feature matching in {elapsed:.2f}s')
        
        # Import to COLMAP (always start from a fresh database file)
        database_path = os.path.join(feature_dir, 'colmap.db')
        if os.path.isfile(database_path):
            os.remove(database_path)
        gc.collect()
        sleep(1)
        import_into_colmap(images_dir, feature_dir, database_path)
        output_path = f'{feature_dir}/colmap_rec_aliked'
        
        # RANSAC (pycolmap.match_exhaustive is timed under this label)
        t = time()
        pycolmap.match_exhaustive(database_path)
        elapsed = time() - t
        timings['RANSAC'].append(elapsed)
        print(f'✓ RANSAC in {elapsed:.2f}s')
        
        # Reconstruction
        mapper_options = pycolmap.IncrementalPipelineOptions()
        mapper_options.min_model_size = cfg.min_model_size
        mapper_options.max_num_models = cfg.max_num_models
        os.makedirs(output_path, exist_ok=True)
        
        t = time()
        maps = pycolmap.incremental_mapping(
            database_path=database_path,
            image_path=images_dir,
            output_path=output_path,
            options=mapper_options
        )
        sleep(1)
        elapsed = time() - t  # includes the one-second sleep above, as before
        timings['Reconstruction'].append(elapsed)
        print(f'✓ Reconstruction in {elapsed:.2f}s')
        # Wipes everything printed so far in this cell, including earlier error output.
        clear_output(wait=False)
        
        # Store results: one cluster per reconstructed model
        registered = 0
        for map_index, cur_map in maps.items():
            for index, image in cur_map.images.items():
                prediction_index = filename_to_index[image.name]
                predictions[prediction_index].cluster_index = map_index
                predictions[prediction_index].rotation = deepcopy(image.cam_from_world.rotation.matrix())
                predictions[prediction_index].translation = deepcopy(image.cam_from_world.translation)
                registered += 1
        
        result = f'✓ "{dataset}": {registered}/{num_images} images, {len(maps)} clusters'
        mapping_result_strs.append(result)
        print(result)
        gc.collect()
        
    except Exception as e:
        # Top-level boundary: one failing dataset must not abort the whole submission.
        # The printed error is erased by the next successful dataset's clear_output(),
        # so the reason is also kept in the summary that is printed at the very end.
        print(f"✗ Error: {e}")
        traceback.print_exc()
        mapping_result_strs.append(f'✗ "{dataset}": Failed! ({type(e).__name__}: {e})')

# Results: one summary line per processed dataset
print('\n' + '=' * 70)
print('FINAL RESULTS')
print('=' * 70)
for summary in mapping_result_strs:
    print(summary)

# Total seconds spent in each pipeline stage, summed over all datasets
print('\n' + '=' * 70)
print('TIMING SUMMARY')
print('=' * 70)
for stage, durations in timings.items():
    total_seconds = sum(durations)
    print(f'{stage:20s}: {total_seconds:7.2f}s')

# Create submission
def array_to_str(array):
    """Join the values of ``array`` with ';', each printed with 9 decimals."""
    return ';'.join([f"{x:.09f}" for x in array])


def none_to_str(n):
    """Return ``n`` 'nan' placeholders joined with ';'."""
    return ';'.join(['nan'] * n)


submission_file = '/kaggle/working/submission.csv'
with open(submission_file, 'w') as f:
    # Training format: dataset,scene,image,rotation_matrix,translation_vector
    # Test format:     image_id,dataset,scene,image,rotation_matrix,translation_vector
    # IMPORTANT: image_id must be from sample_submission.csv
    header = 'dataset,scene,image,rotation_matrix,translation_vector\n'
    if not is_train:
        header = 'image_id,' + header
    f.write(header)

    for dataset in samples:
        for prediction in samples[dataset]:
            # Unregistered images go to 'outliers' with nan rotation and translation.
            if prediction.cluster_index is None:
                cluster = 'outliers'
            else:
                cluster = f'cluster{prediction.cluster_index}'

            if prediction.rotation is None:
                rot = none_to_str(9)
            else:
                rot = array_to_str(prediction.rotation.flatten())

            if prediction.translation is None:
                trans = none_to_str(3)
            else:
                trans = array_to_str(prediction.translation)

            # Only the test format carries the leading image_id column.
            id_prefix = '' if is_train else f'{prediction.image_id},'
            f.write(f'{id_prefix}{prediction.dataset},{cluster},{prediction.filename},{rot},{trans}\n')

print(f'\n✓ Submission saved to: {submission_file}')

# Validate submission format
print("\n" + "="*70)
print("SUBMISSION VALIDATION")
print("="*70)

# Check file exists and show first few lines
# (IPython shell escape; {submission_file} is filled in from the Python variable, so
# this line only works inside a notebook/IPython session.)
!head -20 {submission_file}

# Validate format
# Re-reads the written CSV and reports column, per-row and per-cluster statistics.
# Problems are printed, never raised, so an odd file cannot crash the notebook here.
import csv

# Read everything up front so the file is closed before the checks run.
with open(submission_file, 'r', newline='') as f:
    reader = csv.DictReader(f)
    rows = list(reader)
    # fieldnames is None for a completely empty file.
    actual_cols = reader.fieldnames or []

print(f"\n✓ Total rows: {len(rows)}")

if is_train:
    required_cols = ['dataset', 'scene', 'image', 'rotation_matrix', 'translation_vector']
else:
    required_cols = ['image_id', 'dataset', 'scene', 'image', 'rotation_matrix', 'translation_vector']

print(f"✓ Columns: {actual_cols}")

# Check all required columns present
missing_cols = set(required_cols) - set(actual_cols)
if missing_cols:
    print(f"✗ ERROR: Missing columns: {missing_cols}")
else:
    print(f"✓ All required columns present")

# The per-row checks index rows by column name, so they only run on a complete header.
checked_rows = [] if missing_cols else rows

# Check format of the rows
outlier_count = 0
registered_count = 0
clusters = {}

for row in checked_rows:
    # A short row yields None for its missing trailing fields; treat that as empty text.
    rot_parts = (row['rotation_matrix'] or '').split(';')
    trans_parts = (row['translation_vector'] or '').split(';')

    if row['scene'] == 'outliers':
        outlier_count += 1
        # Verify outliers have nan values
        if len(rot_parts) != 9 or len(trans_parts) != 3:
            print(f"✗ ERROR: Outlier {row['image']} has wrong number of values")
        if not all(p == 'nan' for p in rot_parts + trans_parts):
            print(f"✗ WARNING: Outlier {row['image']} should have all nan values")
    else:
        registered_count += 1
        # Verify registered images have numeric values
        if len(rot_parts) != 9:
            print(f"✗ ERROR: Image {row['image']} rotation has {len(rot_parts)} values (need 9)")
        if len(trans_parts) != 3:
            print(f"✗ ERROR: Image {row['image']} translation has {len(trans_parts)} values (need 3)")
        try:
            for part in rot_parts + trans_parts:
                float(part)
        except ValueError:
            print(f"✗ ERROR: Image {row['image']} has non-numeric pose values")

    # Count images per (dataset, scene); insertion order is kept for stable tie-breaking.
    cluster_key = (row['dataset'], row['scene'])
    clusters[cluster_key] = clusters.get(cluster_key, 0) + 1

# Guard the percentage against an empty (or unreadable) submission.
registration_rate = 100 * registered_count / len(checked_rows) if checked_rows else 0.0

print(f"\n✓ Registered images: {registered_count}")
print(f"✓ Outlier images: {outlier_count}")
print(f"✓ Registration rate: {registration_rate:.1f}%")

# Check cluster distribution; outliers are identified by the scene itself, not by a
# substring test on "dataset/scene", so dataset names cannot be misclassified.
real_cluster_count = sum(1 for _, scene in clusters if scene != 'outliers')
print(f"\n✓ Total clusters: {real_cluster_count}")
print("\nCluster sizes:")
for (dataset, scene), count in sorted(clusters.items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f"  {dataset}/{scene}: {count} images")

print("\n" + "="*70)
print("✓ SUBMISSION COMPLETE")
print("="*70)

✓ "stairs": 39/51 images, 1 clusters

FINAL RESULTS
✓ "ETs": 20/22 images, 1 clusters
✗ "amy_gardens": Failed!
✗ "fbk_vineyard": Failed!
✗ "imc2023_haiper": Failed!
✗ "imc2023_heritage": Failed!
✗ "imc2023_theather_imc2024_church": Failed!
✗ "imc2024_dioscuri_baalshamin": Failed!
✗ "imc2024_lizard_pond": Failed!
✗ "pt_brandenburg_british_buckingham": Failed!
✗ "pt_piazzasanmarco_grandplace": Failed!
✗ "pt_sacrecoeur_trevi_tajmahal": Failed!
✗ "pt_stpeters_stpauls": Failed!
✓ "stairs": 39/51 images, 1 clusters

TIMING SUMMARY
shortlisting        :   10.64s
feature_detection   :    9.26s
feature_matching    :  313.84s
RANSAC              :    1.78s
Reconstruction      :   98.54s

✓ Submission saved to: /kaggle/working/submission.csv

SUBMISSION VALIDATION
image_id,dataset,scene,image,rotation_matrix,translation_vector
ETs_another_et_another_et001.png_public,ETs,cluster0,another_et_another_et001.png,0.999776171;0.002971523;-0.020947008;-0.006434279;0.985894798;-0.167242481;0.020154582;0.