In [None]:
# Given the dataset at google drive, unzip it to Google Colab to
# minimize file access times and improve performance

# I seperated this part in another block of code, in case the user doesn't use Colab

!pip install zipfile36
import zipfile
z=zipfile.ZipFile('/content/drive/MyDrive/CV PROJECT/dataset.zip','r')
z.extractall('/content/')
z.close()

# Install RoMa (! is used for Colab)

!git clone https://github.com/Parskatt/RoMa
%cd RoMa
!pip install -r requirements.txt
!pip install -e .

In [None]:
import os
import math
import csv
import time
from tqdm import tqdm
import warnings
from warnings import warn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from PIL import Image
import cv2
import numpy as np
from typing import Union

from romatch.models.matcher import ConvRefiner,CosKernel,GP,Decoder
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.utils import get_tuple_transform_ops
from romatch.utils.utils import check_rgb, cls_to_flow_refine, check_not_i16
# from romatch.utils.kde import kde
# No need to import this because we use our optimized version of kde

# Delete if Colab is not used
from google.colab import files

"""
Computer-Vision-SfM-relative-pose-estimator
Author: Guni Deyo Haness
"""

# Modified RoMa code
class RegressionMatcher(nn.Module):
    """Dense RoMa matcher, adapted from romatch's ``RegressionMatcher``.

    Local changes relative to upstream: :meth:`sample` uses :meth:`fast_kde`
    for the balancing density, and ``sample_thresh`` is set to 2.5.

    Args:
        encoder: Backbone called as ``encoder(x, upsample=...)``; in batched
            mode its output is iterated as a ``{scale: features}`` dict.
        decoder: Module called on the two feature pyramids; its output is
            indexed as ``corresps[scale]["flow"]`` / ``corresps[scale]["certainty"]``.
        h, w: Resolution images are resized to for the coarse pass.
        sample_mode: If it contains ``"threshold"`` / ``"balanced"``, the
            corresponding step of :meth:`sample` is enabled.
        upsample_preds: If True, :meth:`match` runs a second refinement pass
            at ``self.upsample_res``.
        symmetric: If True, A->B and B->A are estimated in one pass.
        name: Optional label; only stored.
        attenuate_cert: If truthy, :meth:`match` adjusts the final certainty
            logits using the scale-16 certainty.
    """

    def __init__(
        self,
        encoder,
        decoder,
        h=448,
        w=448,
        sample_mode="threshold_balanced",
        upsample_preds=False,
        symmetric=False,
        name=None,
        attenuate_cert=None,
    ):
        super().__init__()
        self.attenuate_cert = attenuate_cert
        self.encoder = encoder
        self.decoder = decoder
        self.name = name
        self.w_resized = w
        self.h_resized = h
        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
        self.sample_mode = sample_mode
        self.upsample_preds = upsample_preds
        # Default upsampling resolution (1344 x 1344); roma_outdoor overwrites it
        # after construction.
        self.upsample_res = (14*16*6, 14*16*6)
        self.symmetric = symmetric
        # Optimized parameter after fine-tuning (default value in source code is 0.5).
        # NOTE(review): match() returns sigmoid certainties in [0, 1], so with a
        # value above 1 the clamp in sample() never fires, and match_keypoints()
        # (which requires cert > sample_thresh) can never keep a match.
        self.sample_thresh = 2.5

    def get_output_resolution(self):
        """Return the ``(h, w)`` of the warp produced by :meth:`match`."""
        if self.upsample_preds:
            return self.upsample_res
        return self.h_resized, self.w_resized

    def extract_backbone_features(self, batch, batched=True, upsample=False):
        """Run the encoder on ``batch["im_A"]`` and ``batch["im_B"]``.

        With ``batched=True`` both images are concatenated along dim 0 and
        encoded in a single call (callers split the result with ``chunk(2)``).
        Otherwise the pair ``(pyramid_A, pyramid_B)`` is returned.
        """
        x_q = batch["im_A"]
        x_s = batch["im_B"]
        if batched:
            X = torch.cat((x_q, x_s), dim=0)
            return self.encoder(X, upsample=upsample)
        return self.encoder(x_q, upsample=upsample), self.encoder(x_s, upsample=upsample)

    def fast_kde(self, x, std=0.1, half=True, down=None):
        """
        A fast version of KDE that computes the pairwise squared Euclidean distances
        using matrix multiplications rather than torch.cdist. This should be faster
        than the original if memory permits.

        This version computes:

            dist_sq = ||x||^2 + ||x2||^2.T - 2 * (x @ x2.T)
            scores = exp(-dist_sq / (2*std^2))
            density = scores.sum(dim=-1)

        NOTE(review): with ``half=True`` the expanded form above is evaluated in
        fp16, which is less precise than ``torch.cdist`` for nearby points.
        Confirm the density cut-off used in :meth:`sample` still behaves as tuned.

        Args:
            x (torch.Tensor): Input tensor of shape [N, d].
            std (float): Standard deviation for the Gaussian kernel.
            half (bool): Whether to convert x to half precision.
            down (int or None): If provided, use x[::down] as the second argument.

        Returns:
            torch.Tensor: A tensor of shape [N] containing the density estimates.
        """
        if half:
            x = x.half()

        # Choose second tensor
        x2 = x[::down] if down is not None else x

        # Squared norms, shapes [N, 1] and [M, 1]
        x_norm = (x ** 2).sum(dim=1, keepdim=True)
        x2_norm = (x2 ** 2).sum(dim=1, keepdim=True)

        # dist_sq[i, j] = ||x[i]||^2 + ||x2[j]||^2 - 2*x[i]·x2[j], via broadcasting.
        dist_sq = x_norm + x2_norm.T - 2 * (x @ x2.T)
        # Clamp small negative values caused by floating-point cancellation.
        dist_sq = torch.clamp(dist_sq, min=0.0)

        # Gaussian kernel scores, summed per row into a density estimate.
        scores = torch.exp(-dist_sq / (2 * std**2))
        return scores.sum(dim=-1)

    # Modified sample function to use fast_kde
    def sample(self, matches, certainty, num=10000):
        """Draw up to ``num`` matches from a dense warp, weighted by certainty.

        Steps (each gated by ``self.sample_mode``):
          1. "threshold": certainties above ``self.sample_thresh`` are set to 1.
          2. Draw ``4 * num`` ("balanced") or ``num`` candidates without
             replacement, with probability proportional to certainty.
          3. "balanced": redraw ``num`` of those with weights
             ``1 / (density + 1)`` (density from :meth:`fast_kde`), favouring
             sparse regions. Candidates with density < 10 get weight 1e-7.

        Returns:
            ``(matches, certainty)`` for the sampled rows. ``matches`` has 4
            columns because the warp is reshaped to ``(-1, 4)``.
        """
        if "threshold" in self.sample_mode:
            upper_thresh = self.sample_thresh
            certainty = certainty.clone()
            certainty[certainty > upper_thresh] = 1

        # Flatten matches and certainty
        matches = matches.reshape(-1, 4)
        certainty = certainty.reshape(-1)

        expansion_factor = 4 if "balanced" in self.sample_mode else 1
        good_samples = torch.multinomial(
            certainty,
            num_samples=min(expansion_factor * num, len(certainty)),
            replacement=False
        )
        good_matches = matches[good_samples]
        good_certainty = certainty[good_samples]

        if "balanced" not in self.sample_mode:
            return good_matches, good_certainty

        # Use the fast_kde instead of the original:
        density = self.fast_kde(good_matches, std=0.1, half=True, down=None)

        p = 1 / (density + 1)
        # Require roughly 10 near-identical neighbours before a sample is eligible.
        p[density < 10] = 1e-7

        balanced_samples = torch.multinomial(
            p,
            num_samples=min(num, len(good_certainty)),
            replacement=False
        )
        return good_matches[balanced_samples], good_certainty[balanced_samples]

    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
        """Encode both images and decode A->B correspondences.

        ``batch["corresps"]``, if present, is forwarded to the decoder as
        keyword arguments (used by the upsampling pass in :meth:`match`).
        """
        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample=upsample)
        if batched:
            f_q_pyramid = {
                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
            }
            f_s_pyramid = {
                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
            }
        else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
        return self.decoder(f_q_pyramid,
                            f_s_pyramid,
                            upsample=upsample,
                            **(batch["corresps"] if "corresps" in batch else {}),
                            scale_factor=scale_factor)

    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
        """Decode both directions in one pass.

        Pairs the batched ``[A; B]`` features with ``[B; A]``. The first half of
        the decoder output is therefore A->B and the second half is B->A.
        """
        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample=upsample)
        f_q_pyramid = feature_pyramid
        f_s_pyramid = {
            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
            for scale, f_scale in feature_pyramid.items()
        }
        return self.decoder(f_q_pyramid,
                            f_s_pyramid,
                            upsample=upsample,
                            **(batch["corresps"] if "corresps" in batch else {}),
                            scale_factor=scale_factor)

    def conf_from_fb_consistency(self, flow_forward, flow_backward, th=2):
        """Return a float mask that is 1 where the forward-backward cycle is consistent.

        Flows are ``(..., H, W, 2)`` in normalized coordinates; a 3-D input is
        treated as unbatched. A pixel passes if warping forward and then
        backward lands within ``th`` pixels (``2 * th / max(H, W)`` normalized)
        of where it started.
        """
        has_batch = False
        if len(flow_forward.shape) == 3:
            flow_forward, flow_backward = flow_forward[None], flow_backward[None]
        else:
            has_batch = True
        H, W = flow_forward.shape[-3:-1]
        th_n = 2 * th / max(H, W)
        coords = torch.stack(torch.meshgrid(
            torch.linspace(-1 + 1 / W, 1 - 1 / W, W),
            torch.linspace(-1 + 1 / H, 1 - 1 / H, H), indexing="xy"),
                             dim=-1).to(flow_forward.device)
        coords_fb = F.grid_sample(
            flow_backward.permute(0, 3, 1, 2),
            flow_forward,
            align_corners=False, mode="bilinear").permute(0, 2, 3, 1)
        diff = (coords - coords_fb).norm(dim=-1)
        in_th = (diff < th_n).float()
        if not has_batch:
            in_th = in_th[0]
        return in_th

    def to_pixel_coordinates(self, coords, H_A, W_A, H_B=None, W_B=None):
        """Convert normalized [-1, 1] coordinates to pixel coordinates.

        ``coords`` may be:
          * a ``(kpts_A, kpts_B)`` list/tuple -> returns ``(kpts_A_px, kpts_B_px)``;
          * a tensor whose last dim is 2 -> returns one tensor scaled by ``(H_A, W_A)``;
          * any other tensor (e.g. last dim 4, ``[xA, yA, xB, yB]``) -> split into
            ``[..., :2]`` / ``[..., 2:]`` and returned as a pair.
        """
        # Check list/tuple first: those have no ``.shape`` attribute.
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        elif coords.shape[-1] == 2:
            return self._to_pixel_coordinates(coords, H_A, W_A)
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        return self._to_pixel_coordinates(kpts_A, H_A, W_A), self._to_pixel_coordinates(kpts_B, H_B, W_B)

    def _to_pixel_coordinates(self, coords, H, W):
        """Map x from [-1, 1] to [0, W] and y from [-1, 1] to [0, H]."""
        return torch.stack((W/2 * (coords[..., 0]+1), H/2 * (coords[..., 1]+1)), dim=-1)

    def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
        """Inverse of :meth:`to_pixel_coordinates` for a pair of keypoint sets.

        Accepts a ``(kpts_A, kpts_B)`` list/tuple or a tensor split into
        ``[..., :2]`` / ``[..., 2:]``; always returns ``(kpts_A, kpts_B)``.
        """
        if isinstance(coords, (list, tuple)):
            kpts_A, kpts_B = coords[0], coords[1]
        else:
            kpts_A, kpts_B = coords[..., :2], coords[..., 2:]
        kpts_A = torch.stack((2/W_A * kpts_A[..., 0] - 1, 2/H_A * kpts_A[..., 1] - 1), dim=-1)
        kpts_B = torch.stack((2/W_B * kpts_B[..., 0] - 1, 2/H_B * kpts_B[..., 1] - 1), dim=-1)
        return kpts_A, kpts_B

    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple=True, return_inds=False):
        """Mutual-nearest-neighbour matching of sparse keypoints through a dense warp.

        Each ``x_A`` point (assumed ``(N, 2)`` normalized coords; TODO confirm)
        is warped into B with ``warp[..., -2:]`` and its certainty is sampled
        from ``certainty``. A pair is kept when it is a mutual nearest neighbour
        under ``torch.cdist`` and the sampled certainty exceeds
        ``self.sample_thresh``.

        NOTE(review): with ``sample_thresh = 2.5`` and sigmoid certainties from
        :meth:`match`, no pair can pass the certainty test.
        """
        x_A_to_B = F.grid_sample(warp[..., -2:].permute(2, 0, 1)[None], x_A[None, None], align_corners=False, mode="bilinear")[0, :, 0].mT
        cert_A_to_B = F.grid_sample(certainty[None, None, ...], x_A[None, None], align_corners=False, mode="bilinear")[0, 0, 0]
        D = torch.cdist(x_A_to_B, x_B)
        inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim=True).values) * (D == D.min(dim=-2, keepdim=True).values) * (cert_A_to_B[:, None] > self.sample_thresh), as_tuple=True)

        if return_tuple:
            if return_inds:
                return inds_A, inds_B
            return x_A[inds_A], x_B[inds_B]
        if return_inds:
            return torch.cat((inds_A, inds_B), dim=-1)
        return torch.cat((x_A[inds_A], x_B[inds_B]), dim=-1)

    @staticmethod
    def _load_rgb(path, check_i16=True):
        """Open an image file, optionally reject 16-bit images, and return an RGB image.

        The file is closed before returning. ``convert`` reads the pixel data,
        so the returned image does not depend on the closed file handle.
        """
        with Image.open(path) as im:
            if check_i16:
                check_not_i16(im)
            return im.convert("RGB")

    @torch.inference_mode()
    def match(
        self,
        im_A_input,
        im_B_input,
        *args,
        batched=False,
        device=None,
    ):
        """Estimate a dense warp and certainty between two images.

        Args:
            im_A_input, im_B_input: Image file paths, or already-loaded images
                (validated with ``check_rgb``). With ``batched=True`` they are
                ``(B, C, H, W)`` tensors of equal size.
            *args: Ignored.
            batched: See above.
            device: Target device; defaults to CUDA if available, else CPU.

        Returns:
            ``(warp, certainty)``. ``warp[..., :2]`` holds normalized coordinates
            in A and ``warp[..., 2:]`` the corresponding coordinates in B, all in
            [-1, 1]. For a symmetric matcher the A->B and B->A halves are
            concatenated along the width. ``certainty`` holds sigmoid
            probabilities and is 0 where the raw warp left [-1, 1]. Without
            ``batched`` the leading batch dimension is dropped.
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Path inputs are loaded once (files closed immediately). The loaded images
        # are kept in src_A / src_B so the upsampling pass can reuse them.
        if isinstance(im_A_input, (str, os.PathLike)):
            im_A = self._load_rgb(im_A_input)
        else:
            check_rgb(im_A_input)
            im_A = im_A_input

        if isinstance(im_B_input, (str, os.PathLike)):
            im_B = self._load_rgb(im_B_input)
        else:
            check_rgb(im_B_input)
            im_B = im_B_input
        src_A, src_B = im_A, im_B

        symmetric = self.symmetric
        self.train(False)
        # No extra no_grad() needed: @torch.inference_mode() already disables autograd.
        if not batched:
            b = 1
            w, h = im_A.size
            w2, h2 = im_B.size
            # Get images in good format
            ws = self.w_resized
            hs = self.h_resized

            test_transform = get_tuple_transform_ops(
                resize=(hs, ws), normalize=True, clahe=False
            )
            im_A, im_B = test_transform((im_A, im_B))
            batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
        else:
            b, c, h, w = im_A.shape
            b, c, h2, w2 = im_B.shape
            assert w == w2 and h == h2, "For batched images we assume same size"
            batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
            if h != self.h_resized or self.w_resized != w:
                warn("Model resolution and batch resolution differ, may produce unexpected results")
            hs, ws = h, w
        finest_scale = 1
        # Run matcher
        if symmetric:
            corresps = self.forward_symmetric(batch)
        else:
            corresps = self.forward(batch, batched=True)

        if self.upsample_preds:
            hs, ws = self.upsample_res

        if self.attenuate_cert:
            low_res_certainty = F.interpolate(
                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
            )
            cert_clamp = 0
            factor = 0.5
            # Keep only the negative part of the coarse logits, scaled by 0.5; this
            # term is subtracted from the final certainty logits below.
            low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)

        if self.upsample_preds:
            # Second pass at upsample_res, seeded with the coarse finest-scale result.
            finest_corresps = corresps[finest_scale]
            torch.cuda.empty_cache()
            test_transform = get_tuple_transform_ops(
                resize=(hs, ws), normalize=True
            )
            im_A, im_B = test_transform((src_A, src_B))

            im_A, im_B = im_A[None].to(device), im_B[None].to(device)
            scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
            batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
            if symmetric:
                corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
            else:
                corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)

        im_A_to_im_B = corresps[finest_scale]["flow"]
        certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
        if finest_scale != 1:
            im_A_to_im_B = F.interpolate(
                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
            )
            certainty = F.interpolate(
                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
            )
        im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
        # Pixel-centre grid of A in normalized coordinates, stored as (x, y).
        im_A_coords = torch.meshgrid(
            (
                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
            ),
            indexing='ij'
        )
        im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
        im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
        certainty = certainty.sigmoid()  # logits -> probs
        im_A_coords = im_A_coords.permute(0, 2, 3, 1)
        if (im_A_to_im_B.abs() > 1).any():
            # Zero the certainty of predictions that fall outside image B.
            wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
            certainty[wrong[:, None]] = 0
        im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
        if symmetric:
            A_to_B, B_to_A = im_A_to_im_B.chunk(2)
            q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
            im_B_coords = im_A_coords
            s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
            warp = torch.cat((q_warp, s_warp), dim=2)
            certainty = torch.cat(certainty.chunk(2), dim=3)
        else:
            warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
        if batched:
            return (
                warp,
                certainty[:, 0]
            )
        return (
            warp[0],
            certainty[0, 0],
        )

    def visualize_warp(self, warp, certainty, im_A=None, im_B=None,
                       im_A_path=None, im_B_path=None, device="cuda", symmetric=True, save_path=None, unnormalize=False):
        """Render ``warp`` as image B resampled into A's frame, faded to white by ``1 - certainty``.

        For ``symmetric=True`` (warp width = 2 * W) the B->A rendering is
        concatenated on the right. Images come from ``im_A`` / ``im_B`` or are
        loaded from ``im_A_path`` / ``im_B_path``. Non-tensor images are resized
        to the warp size and scaled to [0, 1]; tensors are used as-is. If
        ``save_path`` is given, the result is also saved via
        ``romatch.utils.tensor_to_pil``.
        """
        H, W2, _ = warp.shape
        W = W2 // 2 if symmetric else W2
        if im_A is None:
            im_A = self._load_rgb(im_A_path, check_i16=False)
            im_B = self._load_rgb(im_B_path, check_i16=False)
        if not isinstance(im_A, torch.Tensor):
            im_A = im_A.resize((W, H))
            im_B = im_B.resize((W, H))
            x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
            if symmetric:
                x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
        else:
            if symmetric:
                x_A = im_A
            x_B = im_B
        im_A_transfer_rgb = F.grid_sample(
            x_B[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
        )[0]
        if symmetric:
            im_B_transfer_rgb = F.grid_sample(
                x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
            )[0]
            warp_im = torch.cat((im_A_transfer_rgb, im_B_transfer_rgb), dim=2)
            white_im = torch.ones((H, 2*W), device=device)
        else:
            warp_im = im_A_transfer_rgb
            white_im = torch.ones((H, W), device=device)
        vis_im = certainty * warp_im + (1 - certainty) * white_im
        if save_path is not None:
            from romatch.utils import tensor_to_pil
            tensor_to_pil(vis_im, unnormalize=unnormalize).save(save_path)
        return vis_im


def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
    """Build the full RoMa architecture and load ``weights`` into it.

    The layer sizes below must match the checkpoint exactly, because the
    state dict is loaded with ``load_state_dict`` (strict by default).

    Args:
        resolution: ``(h, w)`` coarse matching resolution for ``RegressionMatcher``.
        upsample_preds: Forwarded to ``RegressionMatcher``.
        device: The matcher is moved here with ``.to(device)``.
        weights: State dict for the whole matcher.
        dinov2_weights: Forwarded to ``CNNandDinov2``.
        amp_dtype: Mixed-precision dtype forwarded to ``CNNandDinov2``.
        **kwargs: Extra keyword arguments for ``RegressionMatcher``.

    Returns:
        The ``RegressionMatcher`` with ``weights`` loaded.

    Side effect: installs a process-wide ``warnings`` filter that ignores the
    "TypedStorage is deprecated" ``UserWarning``.
    """
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    # Coarse coordinates are classified over a cls_to_coord_res x cls_to_coord_res
    # grid; the transformer decoder outputs cls_to_coord_res**2 + 1 channels.
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp = True,
        pos_enc = False,)
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True

    # One refiner per scale, keyed by downsampling factor. The input width
    # appears to be 2 * proj_dim + displacement_emb_dim + (2*local_corr_radius+1)**2
    # (two feature maps, the displacement embedding and a local correlation
    # window). NOTE(review): inferred from the matching arguments; confirm
    # against romatch's ConvRefiner.
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512+128+(2*7+1)**2,
                2 * 512+128+(2*7+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius = 7,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "8": ConvRefiner(
                2 * 512+64+(2*3+1)**2,
                2 * 512+64+(2*3+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius = 3,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "4": ConvRefiner(
                2 * 256+32+(2*2+1)**2,
                2 * 256+32+(2*2+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius = 2,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            # Scales 2 and 1 pass no local_corr_radius, so no correlation term appears in their widths.
            "2": ConvRefiner(
                2 * 64+16,
                128+16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks = hidden_blocks,
                displacement_emb = displacement_emb,
                displacement_emb_dim = 6,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
        }
    )
    # Gaussian-process style global matcher, used only at scale 16.
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    # 1x1 conv + BatchNorm projections of the encoder features at each scale.
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
        })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(coordinate_decoder,
                      gps,
                      proj,
                      conv_refiner,
                      detach=True,
                      scales=["16", "8", "4", "2", "1"],
                      displacement_dropout_p = displacement_dropout_p,
                      gm_warp_dropout_p = gm_warp_dropout_p)

    # VGG CNN + DINOv2 encoder. The CNN is built with pretrained=False; its
    # parameters come from `weights` via load_state_dict below.
    encoder = CNNandDinov2(
        cnn_kwargs = dict(
            pretrained=False,
            amp = True),
        amp = True,
        use_vgg = True,
        dinov2_weights = dinov2_weights,
        amp_dtype=amp_dtype,
    )
    h,w = resolution
    symmetric = True
    attenuate_cert = True
    sample_mode = "threshold_balanced"
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds,
                                symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
    matcher.load_state_dict(weights)
    return matcher


# Download locations for pretrained checkpoints, grouped by model family and
# then by training domain ("outdoor" / "indoor"). The DINOv2 backbone entry is a
# single URL (ViT-L/14 pretrain checkpoint), not a per-domain mapping.
weight_urls = dict(
    romatch=dict(
        outdoor="https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
        indoor="https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
    ),
    tiny_roma_v1=dict(
        outdoor="https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
    ),
    # Hosted outside this project (Meta's DINOv2 release); if it moves, downloads break.
    dinov2="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
)


def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
    """Build the outdoor RoMa matcher, downloading any checkpoints not supplied.

    Args:
        device: Target device. If ``str(device) == 'cpu'``, ``amp_dtype`` is
            forced to ``torch.float32``.
        weights: Matcher state dict; fetched from ``weight_urls`` when None.
        dinov2_weights: DINOv2 state dict; fetched from ``weight_urls`` when None.
        coarse_res: Coarse resolution, an int (square) or ``(h, w)``. Both
            sides must be multiples of 14.
        upsample_res: Resolution of the upsampling pass, an int or ``(h, w)``.
        amp_dtype: Mixed-precision dtype for the encoder.

    Returns:
        The matcher, with ``upsample_res`` set to the requested value.

    Raises:
        AssertionError: If a side of ``coarse_res`` is not a multiple of 14.
            This is checked before anything is downloaded.
    """
    def _pair(res):
        return (res, res) if isinstance(res, int) else res

    coarse_res = _pair(coarse_res)
    upsample_res = _pair(upsample_res)

    if str(device) == 'cpu':
        amp_dtype = torch.float32

    for side in (coarse_res[0], coarse_res[1]):
        assert side % 14 == 0, "Needs to be multiple of 14 for backbone"

    if weights is None:
        weights = torch.hub.load_state_dict_from_url(
            weight_urls["romatch"]["outdoor"], map_location=device
        )
    if dinov2_weights is None:
        dinov2_weights = torch.hub.load_state_dict_from_url(
            weight_urls["dinov2"], map_location=device
        )

    matcher = roma_model(
        resolution=coarse_res,
        upsample_preds=True,
        weights=weights,
        dinov2_weights=dinov2_weights,
        device=device,
        amp_dtype=amp_dtype,
    )
    # roma_model leaves RegressionMatcher's default upsample_res; override it here.
    matcher.upsample_res = upsample_res
    print(f"Using coarse resolution {coarse_res}, and upsample res {matcher.upsample_res}")
    return matcher


def _format_fundamental(F_est):
    """Serialize a 3x3 matrix as 9 space-separated ``%e`` values in row-major order."""
    return " ".join(f"{val:e}" for val in F_est.flatten())


def _estimate_fundamental(matcher, img1_path, img2_path, device):
    """Estimate the fundamental matrix for one image pair with RoMa + MAGSAC.

    Returns an all-zero 3x3 float64 matrix when either image file is missing,
    when fewer than 8 matches are sampled, or when OpenCV raises ``cv2.error``
    or does not return a single 3x3 matrix. Other exceptions (e.g. a CUDA
    out-of-memory error inside RoMa) propagate to the caller.
    """
    zero_F = np.zeros((3, 3), dtype=np.float64)
    if not (os.path.isfile(img1_path) and os.path.isfile(img2_path)):
        return zero_F

    # Only the original sizes are needed here; close the files right away.
    with Image.open(img1_path) as im:
        W_A, H_A = im.size
    with Image.open(img2_path) as im:
        W_B, H_B = im.size

    warp, certainty = matcher.match(img1_path, img2_path, device=device)
    matches, _ = matcher.sample(warp, certainty)

    # Normalized [-1, 1] coordinates -> pixel coordinates of the original images.
    kpts1, kpts2 = matcher.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
    kpts1_np = kpts1.cpu().numpy()
    kpts2_np = kpts2.cpu().numpy()

    if kpts1_np.shape[0] < 8:
        # Not enough matches
        return zero_F

    try:
        F_est, _ = cv2.findFundamentalMat(
            kpts1_np,
            kpts2_np,
            ransacReprojThreshold=0.7,  # Changed from 0.2 after fine-tuning and optimization
            method=cv2.USAC_MAGSAC,
            confidence=0.999999,
            maxIters=10000,
        )
    except cv2.error:
        return zero_F
    if F_est is None or F_est.shape != (3, 3):
        return zero_F
    return F_est


def main():
    """Estimate a fundamental matrix for every test pair and write the submission CSV.

    Reads pairs from ``TEST_CSV`` (columns ``sample_id``, ``batch_id``,
    ``image_1_id``, ``image_2_id``). Images are read from
    ``TEST_DIR/<batch_id>/<image_id>.jpg`` and each pair goes through
    ``_estimate_fundamental``. The results are written to ``SUBMISSION_CSV``
    as ``sample_id,fundamental_matrix`` rows. A pair whose processing raises
    gets an all-zero matrix. The function then saves the model weights to
    ``WEIGHTS_PTH`` and triggers a Colab download of the CSV.
    """
    # Speed up PyTorch's conv algorithms if input sizes are consistent
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    TEST_DIR = "/content/dataset/test_images"
    TEST_CSV = "/content/dataset/test.csv"
    SUBMISSION_CSV = "/content/submission.csv"
    WEIGHTS_PTH = "/content/roma_full_weights.pth"

    if not os.path.isfile(TEST_CSV) or not os.path.isdir(TEST_DIR):
        print("Test Files not found!")
        return

    with open(TEST_CSV, "r", newline="") as f:
        test_rows = [
            (row["sample_id"], row["batch_id"], row["image_1_id"], row["image_2_id"])
            for row in csv.DictReader(f)
        ]
    print(f"Found {len(test_rows)} pairs in {TEST_CSV}.")

    # Named `matcher` so it does not shadow the module-level `roma_model` factory.
    matcher = roma_outdoor(device=device)

    results = []
    start_inference = time.time()

    try:
        # Use inference_mode for speed
        with torch.inference_mode():
            with tqdm(total=len(test_rows), desc="Estimating F on test pairs") as pbar:
                for sample_id, scene_name, im1_id, im2_id in test_rows:
                    try:
                        img1_path = os.path.join(TEST_DIR, scene_name, im1_id + ".jpg")
                        img2_path = os.path.join(TEST_DIR, scene_name, im2_id + ".jpg")
                        F_est = _estimate_fundamental(matcher, img1_path, img2_path, device)
                    except Exception as e:  # Usually happens when VRAM is full
                        print(f"Error processing pair {sample_id}: {e}")
                        F_est = np.zeros((3, 3), dtype=np.float64)
                        # Release cached GPU memory so one OOM is less likely to cascade.
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                    finally:
                        pbar.update(1)
                    results.append((sample_id, _format_fundamental(F_est)))

        end_inference = time.time()
        print(f"Done estimating F for all pairs in {end_inference - start_inference:.2f} seconds.")

        # Save F predictions and model weights
        with open(SUBMISSION_CSV, "w", newline="") as fout:
            writer = csv.writer(fout)
            writer.writerow(["sample_id", "fundamental_matrix"])
            writer.writerows(results)

        print(f"Wrote {len(results)} rows to {SUBMISSION_CSV}.")
        torch.save(matcher.state_dict(), WEIGHTS_PTH)
        print('Saved model weights')
        files.download(SUBMISSION_CSV)  # Remove if code is not run in Colab

    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()
