In [None]:
import io
import time
from urllib.request import urlopen

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import spaces
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel
from transformers.image_utils import to_numpy_array

In [None]:
def _draw_matches(left, right, keypoints0, keypoints1, scores):
    """
    Draw two RGB images side by side and connect their matched keypoints.

    Args:
        left: First image as an (H, W, 3) array; drawn on the left.
        right: Second image as an (H, W, 3) array; drawn on the right.
        keypoints0: Sequence of (x, y) pixel coordinates in `left`.
        keypoints1: Sequence of (x, y) pixel coordinates in `right`, aligned
            index-by-index with `keypoints0`.
        scores: One matching score per keypoint pair; each score is mapped
            through the RdYlGn colormap to pick the colour of that match.

    Returns:
        The matplotlib figure holding the drawing. The figure is left open.
    """
    height0, width0 = left.shape[:2]
    height1, width1 = right.shape[:2]
    canvas_height = max(height0, height1)
    canvas_width = width0 + width1

    fig, ax = plt.subplots(figsize=(20, 10))

    # Paste both images onto a single black canvas; the shorter image ends up
    # padded with black rows at the bottom.
    composite_image = np.zeros((canvas_height, canvas_width, 3), dtype=np.uint8)
    composite_image[:height0, :width0] = left
    composite_image[:height1, width0:] = right
    ax.imshow(composite_image)

    colormap = cm.RdYlGn

    for (x0, y0), (x1, y1), score in zip(keypoints0, keypoints1, scores):
        # float() guarantees the colormap treats the score as a position in
        # [0, 1] rather than as an integer lookup-table index.
        rgba_color = colormap(float(score))
        # Keypoints of the second image live to the right of the first image.
        x1_shifted = x1 + width0

        # Line connecting the two matched keypoints
        ax.plot(
            [x0, x1_shifted],
            [y0, y1],
            color=rgba_color,
            linewidth=1.5,
            alpha=0.8,
        )
        # The keypoints themselves
        ax.plot(x0, y0, "o", color=rgba_color, markersize=4)
        ax.plot(x1_shifted, y1, "o", color=rgba_color, markersize=4)

    ax.set_xlim(0, canvas_width)
    ax.set_ylim(canvas_height, 0)  # Invert y-axis for image coordinates
    ax.axis("off")
    # Lay out this figure explicitly instead of pyplot's global "current" one.
    fig.tight_layout()

    return fig


def process_images(image1, image2, threshold=0.2):
    """
    Match keypoints between two images and plot the matches with matplotlib.

    Relies on the module-level `processor` and `model` objects.

    Args:
        image1: First PIL image (its `height`, `width` and `convert` are used).
        image2: Second PIL image.
        threshold: Value forwarded as `threshold` to
            `processor.post_process_keypoint_matching`. Defaults to 0.2.

    Returns:
        A matplotlib figure showing both images side by side with one coloured
        line per match, or None if either image is None. The figure is not
        closed here; callers that create many figures should close them.
    """
    if image1 is None or image2 is None:
        return None

    images = [image1, image2]
    inputs = processor(images, return_tensors="pt")
    inputs = inputs.to(model.device)
    print(
        "Model is on device: ",
        model.device,
        "and inputs are on device: ",
        inputs["pixel_values"].device,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # One list of (height, width) pairs per image pair in the batch.
    image_sizes = [[(image.height, image.width) for image in images]]
    outputs = processor.post_process_keypoint_matching(
        outputs, image_sizes, threshold=threshold
    )
    output = outputs[0]

    # The canvas is 3-channel, so convert the display copies to RGB; without
    # this, grayscale / RGBA / palette inputs fail when pasted onto it.
    # .tolist() converts each tensor once instead of calling .item() per value.
    return _draw_matches(
        to_numpy_array(image1.convert("RGB")),
        to_numpy_array(image2.convert("RGB")),
        output["keypoints0"].tolist(),
        output["keypoints1"].tolist(),
        output["matching_scores"].tolist(),
    )


In [None]:
# Both the image processor and the matcher weights come from the same Hub checkpoint.
CHECKPOINT = "ETH-CVG/lightglue_superpoint"

processor = AutoImageProcessor.from_pretrained(CHECKPOINT)
model = AutoModel.from_pretrained(CHECKPOINT, device_map="auto")

In [None]:
def load_image_from_url(url, timeout=30):
    """
    Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image file.
        timeout: Seconds to wait on the connection before giving up.

    Returns:
        The opened PIL image.

    Raises:
        urllib.error.URLError: If the request fails; HTTP error statuses
            surface as its subclass `urllib.error.HTTPError`.
    """
    with urlopen(url, timeout=timeout) as response:
        payload = response.read()
    # Buffer the bytes so PIL gets a seekable file object that outlives the connection.
    return Image.open(io.BytesIO(payload))


url_image1 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg"
image1 = load_image_from_url(url_image1)
url_image2 = "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg"
image2 = load_image_from_url(url_image2)

In [None]:
# Run the matcher on the two sample images and keep the resulting figure.
processed_fig = process_images(image1=image1, image2=image2)
