In [1]:
import cv2
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt

from utils import plot_circles, imshow
from LightGlue.lightglue import LightGlue, SuperPoint
from LightGlue.lightglue.utils import rbd

# Prepare Image

In [2]:
image_size = (2016, 1512)
img1 = cv2.imread("./images/img1.jpg")
img1 = cv2.resize(img1, image_size)
img2 = cv2.imread("./images/img2.jpg")
img2 = cv2.resize(img2, image_size)

In [3]:
def extractor_preprocess(img):
    img_tensor = torch.from_numpy(img)
    img_tensor = torch.permute(img_tensor, (2, 0, 1)) / 255.0
    return img_tensor

In [4]:
img1_tensor = extractor_preprocess(img1)
img2_tensor = extractor_preprocess(img2)

# Prepare extractor and matcher

In [5]:
extractor = SuperPoint(max_num_keypoints=2048).eval()
matcher = LightGlue(features="superpoint").eval()

# Extraction and Matching

In [6]:
img1_feats = extractor.extract(img1_tensor)
img2_feats = extractor.extract(img2_tensor)

In [7]:
matches = matcher({"image0": img1_feats, "image1": img2_feats})
img1_feats, img2_feats, matches = [rbd(x) for x in [img1_feats, img2_feats, matches]]
matched_points = matches["matches"]
matched_scores = matches["scores"]

In [8]:
matched_points = matched_points[matched_scores >= 0.9]
print(len(matched_points))

213


In [9]:
img1_points = img1_feats["keypoints"][matched_points[..., 0]]
img2_points = img2_feats["keypoints"][matched_points[... ,1]]

In [10]:
def expand_coord(points):
    """
    expand 2d points to 3d
    :param points: shape: (n, 2)
    :return: expanded points: shape (n, 3)
    """
    ones = torch.ones((len(points), 1))
    return torch.concat([points, ones], dim=1)

In [11]:
img1_points_expanded = expand_coord(img1_points)
img2_points_expanded = expand_coord(img2_points)

In [12]:
class homography(nn.Module):
    def __init__(self):
        super(homography, self).__init__()
        self.proj1 = nn.Linear(in_features=3, out_features=1, bias=True)
        self.proj2 = nn.Linear(in_features=3, out_features=1, bias=True)

    def forward(self, x):
        """
        :param x: expanded coordinates, shape (n, 3)
        :return: projected coordinates, shape (n, 3)
        """
        x2 = self.proj1(x)   # (n, 1)
        y2 = self.proj2(x)   # (n, 1)
        out = torch.concat([x2, y2], dim=1)
        return out

In [13]:
homo1 = homography()
optimizer = torch.optim.SGD(homo1.parameters(), lr=0.000001)
criterion = nn.MSELoss()
epochs = 2000
for epoch in range(epochs):
    X_batch, Y_batch = img1_points_expanded, img2_points
    Y_pred = homo1(X_batch)

    optimizer.zero_grad()
    loss = criterion(Y_pred, Y_batch)
    loss.backward()
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch} / {epochs}: {loss.item()}")

Epoch 0 / 2000: 803076.4375
Epoch 100 / 2000: 10024.0068359375
Epoch 200 / 2000: 10023.873046875
Epoch 300 / 2000: 10023.7431640625
Epoch 400 / 2000: 10023.6142578125
Epoch 500 / 2000: 10023.4833984375
Epoch 600 / 2000: 10023.35546875
Epoch 700 / 2000: 10023.2275390625
Epoch 800 / 2000: 10023.0986328125
Epoch 900 / 2000: 10022.9658203125
Epoch 1000 / 2000: 10022.8359375
Epoch 1100 / 2000: 10022.7099609375
Epoch 1200 / 2000: 10022.5791015625
Epoch 1300 / 2000: 10022.44921875
Epoch 1400 / 2000: 10022.3193359375
Epoch 1500 / 2000: 10022.19140625
Epoch 1600 / 2000: 10022.0625
Epoch 1700 / 2000: 10021.9326171875
Epoch 1800 / 2000: 10021.802734375
Epoch 1900 / 2000: 10021.673828125


In [14]:
imshow("img1", plot_circles(img1.copy(), [(500, 600)]))


In [15]:
homo1(torch.tensor([[500.0, 600.0, 1.0]])).detach().numpy()

array([[680.0786 , 570.46625]], dtype=float32)

In [16]:
imshow("img2", plot_circles(img2.copy(), [[681, 570]]))

In [17]:
Y_pred

tensor([[1006.6027,  289.6958],
        [ 793.7565,  446.4607],
        [ 347.0727,   84.3975],
        [1295.2919,  792.9963],
        [1143.5670,  306.3139],
        [1357.9669,  361.4039],
        [1249.8883,  800.9209],
        [ 984.2383,  247.4771],
        [ 632.2855,   79.9395],
        [ 778.7212,  154.4900],
        [ 855.5292,  159.3019],
        [ 712.6091,  530.3404],
        [ 881.4719,  393.3825],
        [1002.3417,  254.0484],
        [ 911.5109,  204.3621],
        [1097.6925,  288.9935],
        [ 662.4578,   90.8916],
        [1229.2501,  337.0812],
        [ 925.0629,  223.9914],
        [ 712.2010,  508.4614],
        [1284.5411,  345.1154],
        [ 870.1956,  367.2242],
        [1254.5885,  345.9443],
        [ 879.1884,  206.6711],
        [1193.6554,  485.3037],
        [ 698.9315,  503.9791],
        [1098.6106,  273.9480],
        [ 371.7128,  120.0871],
        [ 855.0819,  199.5924],
        [ 377.0175,  147.4211],
        [ 895.7224,  129.0924],
        

In [18]:
Y_batch

tensor([[1006.5156,  276.1093],
        [ 612.7656,  630.4843],
        [  16.2344,  388.3281],
        [1331.3593,  738.7656],
        [1175.8280,  236.7344],
        [1392.3905,  211.1406],
        [1301.8280,  778.1406],
        [1006.5156,  238.7031],
        [ 553.7031,  189.4844],
        [ 756.4843,  219.0156],
        [ 882.4843,  189.4844],
        [ 632.4531,  833.2656],
        [ 650.1718,  480.8593],
        [1026.2031,  238.7031],
        [ 935.6406,  219.0156],
        [1132.5155,  236.7344],
        [ 604.8906,  191.4531],
        [1264.4218,  234.7656],
        [ 933.6718,  236.7344],
        [ 616.7031,  799.7968],
        [1321.5155,  222.9531],
        [ 620.6406,  449.3593],
        [1288.0468,  234.7656],
        [ 876.5781,  236.7344],
        [1047.8594,  415.8906],
        [ 593.0781,  803.7343],
        [1144.3280,  220.9844],
        [  45.7656,  431.6406],
        [ 843.1093,  240.6719],
        [  35.9219,  476.9218],
        [ 973.0468,  138.2969],
        