In [35]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
import torch
from PIL import Image
import torch.nn.functional as F
import numpy as np
from RoMa.romatch.utils.utils import tensor_to_pil

from RoMa.romatch import roma_outdoor
from argparse import ArgumentParser

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
if torch.backends.mps.is_available():
    device = torch.device('mps')

parser = ArgumentParser()
parser.add_argument("--im_A_path", default="RoMa/assets/sacre_coeur_A.jpg", type=str)
parser.add_argument("--im_B_path", default="RoMa/assets/sacre_coeur_B.jpg", type=str)
parser.add_argument("--save_path", default="RoMa/demo/roma_warp_sacre_coeur.jpg", type=str)

args, _ = parser.parse_known_args()
im1_path = args.im_A_path
im2_path = args.im_B_path
save_path = args.save_path

# Create model
roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(216, 288))

H, W = roma_model.get_output_resolution()

im1 = Image.open(im1_path).resize((W, H))
im2 = Image.open(im2_path).resize((W, H))

# Match
warp, certainty = roma_model.match(im1_path, im2_path, device=device)
# Sampling not needed, but can be done with model.sample(warp, certainty)
x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)

[32m2025-10-30 16:50:03.406[0m | [1mINFO    [0m | [36mRoMa.romatch.models.model_zoo.roma_models[0m:[36mroma_model[0m:[36m61[0m - [1mUsing coarse resolution (560, 560), and upsample res (216, 288)[0m


In [44]:
print("Image 2 to 1")
print(warp[:,:,:W,2:])
print("Image 1 to 2")
print(warp[:,:,W:,:2])

Image 2 to 1
tensor([[[[-0.6894, -0.2032],
          [-0.6867, -0.2059],
          [-0.6875, -0.2115],
          ...,
          [ 0.9765, -0.4433],
          [ 0.9814, -0.4405],
          [ 0.9841, -0.4375]],

         [[-0.6905, -0.2010],
          [-0.6881, -0.2033],
          [-0.6876, -0.2071],
          ...,
          [ 0.9772, -0.4346],
          [ 0.9808, -0.4320],
          [ 0.9829, -0.4317]],

         [[-0.6904, -0.1982],
          [-0.6884, -0.2011],
          [-0.6886, -0.2051],
          ...,
          [ 0.9787, -0.4246],
          [ 0.9822, -0.4226],
          [ 0.9861, -0.4222]],

         ...,

         [[-0.4797,  0.8542],
          [-0.4758,  0.8540],
          [-0.4713,  0.8530],
          ...,
          [ 0.7370,  0.8057],
          [ 0.7418,  0.8062],
          [ 0.7463,  0.8061]],

         [[-0.4790,  0.8587],
          [-0.4747,  0.8581],
          [-0.4701,  0.8580],
          ...,
          [ 0.7364,  0.8098],
          [ 0.7404,  0.8099],
          [ 0.7436,

In [36]:
grid_im2 = warp[:, :, :W, 2:]
grid_im1 = warp[:, :, W:, :2]
print(f"Shape of grid_im2: {grid_im2.shape}")
print(f"Shape of grid_im1: {grid_im1.shape}")
print(f"Shape of x2[None]: {x2[None].shape}")
print(f"Shape of x1[None]: {x1[None].shape}")
print(f"Spahe of warp: {warp.shape}")

im2_transfer_rgb = F.grid_sample(
x2[None], grid_im2, mode="bilinear", align_corners=False
)[0]
im1_transfer_rgb = F.grid_sample(
x1[None], grid_im1, mode="bilinear", align_corners=False
)[0]
warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
white_im = torch.ones((H,2*W),device=device)
print(f"Shape of warp_im: {warp_im.shape}")
print(f"Shape of white_im: {white_im.shape}")
print(f"Shape of certainty: {certainty.shape}")

Shape of grid_im2: torch.Size([1, 216, 288, 2])
Shape of grid_im1: torch.Size([1, 216, 288, 2])
Shape of x2[None]: torch.Size([1, 3, 216, 288])
Shape of x1[None]: torch.Size([1, 3, 216, 288])
Spahe of warp: torch.Size([1, 216, 576, 4])
Shape of warp_im: torch.Size([3, 216, 576])
Shape of white_im: torch.Size([216, 576])
Shape of certainty: torch.Size([1, 216, 576])


In [37]:
vis_im = certainty * warp_im + (1 - certainty) * white_im
tensor_to_pil(vis_im, unnormalize=False).save(save_path)