In [2]:
import tensorflow as tf
import torch
import onnx
from onnx_tf.backend import prepare
from torchsummary import summary
import numpy as np
import cv2

import numpy as np
import os
import torch
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import tqdm



TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class BasicLayer(nn.Module):
	"""
	  Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
	"""
	def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
		super().__init__()
		self.layer = nn.Sequential(
									  nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
									  nn.BatchNorm2d(out_channels, affine=False),
									  nn.ReLU(inplace = True),
									)

	def forward(self, x):
	  return self.layer(x)

class XFeatModel(nn.Module):
	"""
	   Implementation of architecture described in 
	   "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
	"""

	def __init__(self):
		super().__init__()
		self.norm = nn.InstanceNorm2d(1)


		########### ⬇️ CNN Backbone & Heads ⬇️ ###########

		self.skip1 = nn.Sequential(	 nn.AvgPool2d(4, stride = 4),
			  						 nn.Conv2d (1, 24, 1, stride = 1, padding=0) )

		self.block1 = nn.Sequential(
										BasicLayer( 1,  4, stride=1),
										BasicLayer( 4,  8, stride=2),
										BasicLayer( 8,  8, stride=1),
										BasicLayer( 8, 24, stride=2),
									)

		self.block2 = nn.Sequential(
										BasicLayer(24, 24, stride=1),
										BasicLayer(24, 24, stride=1),
									 )

		self.block3 = nn.Sequential(
										BasicLayer(24, 64, stride=2),
										BasicLayer(64, 64, stride=1),
										BasicLayer(64, 64, 1, padding=0),
									 )
		self.block4 = nn.Sequential(
										BasicLayer(64, 64, stride=2),
										BasicLayer(64, 64, stride=1),
										BasicLayer(64, 64, stride=1),
									 )

		self.block5 = nn.Sequential(
										BasicLayer( 64, 128, stride=2),
										BasicLayer(128, 128, stride=1),
										BasicLayer(128, 128, stride=1),
										BasicLayer(128,  64, 1, padding=0),
									 )

		self.block_fusion =  nn.Sequential(
										BasicLayer(64, 64, stride=1),
										BasicLayer(64, 64, stride=1),
										nn.Conv2d (64, 64, 1, padding=0)
									 )

		self.heatmap_head = nn.Sequential(
										BasicLayer(64, 64, 1, padding=0),
										BasicLayer(64, 64, 1, padding=0),
										nn.Conv2d (64, 1, 1),
										nn.Sigmoid()
									)


		self.keypoint_head = nn.Sequential(
										BasicLayer(64, 64, 1, padding=0),
										BasicLayer(64, 64, 1, padding=0),
										BasicLayer(64, 64, 1, padding=0),
										nn.Conv2d (64, 65, 1),
									)


  		########### ⬇️ Fine Matcher MLP ⬇️ ###########

		self.fine_matcher =  nn.Sequential(
											nn.Linear(128, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 64),
										)

	def _unfold2d(self, x, ws = 2):
		"""
			Unfolds tensor in 2D with desired ws (window size) and concat the channels
		"""
		B, C, H, W = x.shape
		x = x.unfold(2,  ws , ws).unfold(3, ws,ws)                             \
			.reshape(B, C, H//ws, W//ws, ws**2)
		return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws)


	def forward(self, x):
		"""
			input:
				x -> torch.Tensor(B, C, H, W) grayscale or rgb images
			return:
				feats     ->  torch.Tensor(B, 64, H/8, W/8) dense local features
				keypoints ->  torch.Tensor(B, 65, H/8, W/8) keypoint logit map
				heatmap   ->  torch.Tensor(B,  1, H/8, W/8) reliability map

		"""
		#dont backprop through normalization
		with torch.no_grad():
			x = x.mean(dim=1, keepdim = True)
			x = self.norm(x)

		#main backbone
		x1 = self.block1(x)
		x2 = self.block2(x1 + self.skip1(x))
		x3 = self.block3(x2)
		x4 = self.block4(x3)
		x5 = self.block5(x4)

		#pyramid fusion
		x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
		x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
		feats = self.block_fusion( x3 + x4 + x5 )

		#heads
		heatmap = self.heatmap_head(feats) # Reliability map
		keypoints = self.keypoint_head(self._unfold2d(x, ws=8)) #Keypoint map logits

		return feats, keypoints, heatmap


In [4]:
# Load the model and its weights
net = XFeatModel().to(dev)
net.load_state_dict(torch.load('./weights/xfeat.pt'))

# Create a dummy input with the shape (1, 3, 480, 640)
dummy_input = torch.randn(1, 3, 480, 640, device=dev)

# Optionally, you can check the output
dummy_output = net(dummy_input)
print(dummy_output)

# Export the model to ONNX format
torch.onnx.export(
    net, 
    dummy_input, 
    './weights/model_simple.onnx',
    input_names=['test_input'],
    output_names=['test_output'],
    export_params=True,  # Export the trained parameters
    opset_version=12,    # ONNX opset version (you can choose according to your needs)
    do_constant_folding=False,  # Whether to fold constants for optimization
    dynamic_axes={
        'test_input': {0: 'batch_size'},  # Variable batch size
        'test_output': {0: 'batch_size'}  # Variable batch size
    }
)

  net.load_state_dict(torch.load('./weights/xfeat.pt'))


(tensor([[[[ 2.0429,  2.0751,  2.7110,  ...,  2.5637,  2.8497,  3.1169],
          [ 2.2302,  2.3369,  4.0044,  ...,  3.3140,  3.7896,  4.1522],
          [ 2.7945,  3.3579,  4.4605,  ...,  3.1975,  3.9799,  4.5383],
          ...,
          [ 2.3023,  3.2868,  3.9855,  ...,  1.4351,  2.0857,  1.9563],
          [ 2.0496,  2.3789,  3.6429,  ...,  1.4296,  1.8734,  1.5088],
          [ 2.0182,  2.7893,  3.6996,  ...,  0.9894,  1.6685,  1.8546]],

         [[-1.4128, -2.1542, -1.9796,  ..., -2.0406, -1.7984, -1.4717],
          [-1.8793, -1.5273, -1.2732,  ..., -2.3919, -2.2580, -1.6899],
          [-1.8530, -1.5059, -1.5702,  ..., -2.3976, -2.7612, -1.6402],
          ...,
          [-1.7468, -2.2380, -1.7260,  ..., -0.9426, -1.5519, -1.0316],
          [-2.4974, -2.7083, -2.4247,  ..., -2.1692, -1.8448, -1.4998],
          [-1.2959, -2.3545, -2.1141,  ..., -1.7102, -0.7927, -0.8185]],

         [[-0.3779, -0.8449, -1.2185,  ..., -1.2664, -0.6083, -1.0680],
          [-1.5750, -2.8422, 

In [17]:
# Load ONNX model and convert to TensorFlow format
model_onnx = onnx.load('./weights/xfeat.onnx')

tf_rep = prepare(model_onnx)
# Export model as .pb file
tf_rep.export_graph('./weights/xfeat.pb')

INFO:absl:Function `__call__` contains input name(s) x, y, tensor with unsupported characters which will be renamed to transpose_348_x, mul_63_y, reshape_47_tensor in the SavedModel.
INFO:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: ./weights/xfeat.pb/assets


INFO:tensorflow:Assets written to: ./weights/xfeat.pb/assets
INFO:absl:Writing fingerprint to ./weights/xfeat.pb/fingerprint.pb


In [25]:
model_dir = './weights/xfeat.pb'  # Path to the saved model directory
model = tf.saved_model.load(model_dir)


In [26]:
print("Available functions:")
for func in model.signatures:
    print(func)

Available functions:
serving_default


In [27]:
# Access the default serving function
infer = model.signatures['serving_default']

# Print input and output details
print("Inputs:")
for input_tensor in infer.inputs:
    print(f"{input_tensor.name}: {input_tensor.shape}")

print("Outputs:")
for output_tensor in infer.outputs:
    print(f"{output_tensor.name}: {output_tensor.shape}")

Inputs:
images0:0: (None, 3, None, None)
images1:0: (None, 3, None, None)
unknown:0: (4,)
unknown_0:0: (4,)
unknown_1:0: ()
unknown_2:0: ()
unknown_3:0: ()
unknown_4:0: (1,)
unknown_5:0: (1,)
unknown_6:0: (1,)
unknown_7:0: (1,)
unknown_8:0: (4, 1, 3, 3)
unknown_9:0: (4,)
unknown_10:0: (8, 4, 3, 3)
unknown_11:0: (8,)
unknown_12:0: (8, 8, 3, 3)
unknown_13:0: (8,)
unknown_14:0: (24, 8, 3, 3)
unknown_15:0: (24,)
unknown_16:0: (24, 1, 1, 1)
unknown_17:0: (24,)
unknown_18:0: (24, 24, 3, 3)
unknown_19:0: (24,)
unknown_20:0: (24, 24, 3, 3)
unknown_21:0: (24,)
unknown_22:0: (64, 24, 3, 3)
unknown_23:0: (64,)
unknown_24:0: (64, 64, 3, 3)
unknown_25:0: (64,)
unknown_26:0: (64, 64, 1, 1)
unknown_27:0: (64,)
unknown_28:0: (64, 64, 3, 3)
unknown_29:0: (64,)
unknown_30:0: (64, 64, 3, 3)
unknown_31:0: (64,)
unknown_32:0: (64, 64, 3, 3)
unknown_33:0: (64,)
unknown_34:0: (128, 64, 3, 3)
unknown_35:0: (128,)
unknown_36:0: (128, 128, 3, 3)
unknown_37:0: (128,)
unknown_38:0: (128, 128, 3, 3)
unknown_39:0: 

In [7]:
# Get the concrete function from the model
concrete_func = model.signatures['serving_default']

# Print graph details
print("Concrete function details:")
print(concrete_func.graph.as_graph_def())


Concrete function details:
node {
  name: "images0"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 3
        }
        dim {
          size: -1
        }
        dim {
          size: -1
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_user_specified_name"
    value {
      s: "images0"
    }
  }
}
node {
  name: "images1"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 3
        }
        dim {
          size: -1
        }
        dim {
          size: -1
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_user_specified_name"
    value {
      s: "images1"
    }
  }
}
node {
  name: "unknown"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape 

In [16]:
import tensorflow as tf
import numpy as np

# Load SavedModel
model = tf.saved_model.load('./weights/xfeat.pb')

# Get the inference function from the model
infer = model.signatures["serving_default"]

# Prepare the inputs
input_array_1 = im1.transpose(2, 0, 1).astype(np.float32)
input_array_1 = np.expand_dims(input_array_1, axis=0)
input_array_2 = im2.transpose(2, 0, 1).astype(np.float32)
input_array_2 = np.expand_dims(input_array_2, axis=0)

# Create batch
batch_size = 1
input_array_1 = np.concatenate([input_array_1 for _ in range(batch_size)], axis=0)
input_array_2 = np.concatenate([input_array_2 for _ in range(batch_size)], axis=0)

# Feed inputs to the model
inputs = {
    "input_1": tf.convert_to_tensor(input_array_1),  # Adjust name based on signature
    "input_2": tf.convert_to_tensor(input_array_2)   # Adjust name based on signature
}

# Run inference
output = infer(**inputs)

# Extract and print the output
output_tensor = output['output'].numpy()  # Adjust name based on signature
print("Model output:", output_tensor)


(480, 640, 3)
(480, 640, 3)
(1, 3, 480, 640)


TypeError: Binding inputs to tf.function `signature_wrapper` failed due to `too many positional arguments`. Received args: (array([[[[ 12.,  12.,  15., ...,  52.,  55.,  58.],
         [ 11.,  11.,  14., ...,  63.,  62.,  56.],
         [ 11.,  11.,  14., ...,  55.,  55.,  51.],
         ...,
         [ 97.,  93.,  95., ...,  47.,  48.,  52.],
         [ 77.,  80.,  82., ...,  45.,  45.,  49.],
         [ 64.,  66.,  66., ...,  45.,  46.,  48.]],

        [[ 14.,  15.,  19., ...,  56.,  58.,  61.],
         [ 14.,  15.,  20., ...,  68.,  67.,  60.],
         [ 14.,  15.,  21., ...,  60.,  61.,  56.],
         ...,
         [117., 117., 119., ...,  56.,  57.,  58.],
         [100., 104., 107., ...,  53.,  53.,  55.],
         [ 84.,  91.,  91., ...,  52.,  51.,  52.]],

        [[ 15.,  18.,  24., ...,  55.,  58.,  60.],
         [ 17.,  19.,  27., ...,  67.,  66.,  59.],
         [ 18.,  20.,  29., ...,  59.,  60.,  55.],
         ...,
         [123., 126., 128., ...,  66.,  67.,  67.],
         [107., 114., 117., ...,  63.,  63.,  62.],
         [ 89.,  99., 100., ...,  60.,  58.,  57.]]]], dtype=float32),) and kwargs: {} for signature: (*, images0: TensorSpec(shape=(None, 3, None, None), dtype=tf.float32, name='images0'), images1: TensorSpec(shape=(None, 3, None, None), dtype=tf.float32, name='images1')).

In [91]:
x = torch.randn(1,3,480,640)
dev = torch.device ('cuda' if torch.cuda.is_available() else 'cpu')
net = XFeatModel().to(dev)
summary(net, input_size=(3, 480, 640))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
    InstanceNorm2d-1          [-1, 1, 480, 640]               0
            Conv2d-2          [-1, 4, 480, 640]              36
       BatchNorm2d-3          [-1, 4, 480, 640]               0
              ReLU-4          [-1, 4, 480, 640]               0
        BasicLayer-5          [-1, 4, 480, 640]               0
            Conv2d-6          [-1, 8, 240, 320]             288
       BatchNorm2d-7          [-1, 8, 240, 320]               0
              ReLU-8          [-1, 8, 240, 320]               0
        BasicLayer-9          [-1, 8, 240, 320]               0
           Conv2d-10          [-1, 8, 240, 320]             576
      BatchNorm2d-11          [-1, 8, 240, 320]               0
             ReLU-12          [-1, 8, 240, 320]               0
       BasicLayer-13          [-1, 8, 240, 320]               0
           Conv2d-14         [-1, 24, 1

In [101]:
net = XFeatModel().to(dev)
net.load_state_dict(torch.load('./weights/xfeat.pt',weights_only=True))
dummy_input = torch.tensor(dummy_input_np).to(dev)

output = net(dummy_input)
print(output)

(tensor([[[[ 2.0339,  2.0162,  2.4294,  ...,  2.9459,  2.7474,  3.0622],
          [ 2.2621,  2.2193,  3.5295,  ...,  3.5592,  3.6374,  4.0465],
          [ 2.5319,  2.6700,  3.6643,  ...,  3.3428,  3.3196,  3.8285],
          ...,
          [ 2.6655,  3.5698,  3.8302,  ...,  2.4511,  2.7391,  3.1545],
          [ 2.4450,  3.1618,  3.6899,  ...,  3.0513,  3.0640,  2.9057],
          [ 1.8802,  2.9788,  4.0835,  ...,  3.5376,  2.8929,  3.3399]],

         [[-1.7989, -1.6022, -1.8334,  ..., -2.0971, -1.8097, -1.1508],
          [-2.0277, -1.5544, -1.4291,  ..., -2.2789, -2.1056, -1.4540],
          [-1.9461, -1.8042, -1.7633,  ..., -2.5108, -2.5897, -1.8712],
          ...,
          [-0.8095, -1.4622, -1.1222,  ..., -1.7730, -1.3524, -1.5609],
          [-0.9608, -1.8014, -1.6007,  ..., -1.3224, -1.3463, -1.2188],
          [-0.6025, -1.3109, -1.1311,  ..., -0.2823, -0.2996, -0.3582]],

         [[-0.7027, -1.1888, -1.5610,  ..., -0.7867, -0.3383, -0.9603],
          [-1.6852, -2.7409, 

In [38]:
onnx_model = onnx.load("./weights/xfeat.onnx")  # load onnx model
output = prepare(onnx_model).run(dummy_input_np)  # run the loaded model
print(output)

2024-09-19 14:52:45.380244: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: assert_equal_9/Assert/AssertGuard/branch_executed/_49


Outputs(feats=array([[[[ 2.782367  ,  2.8502855 ,  3.1972933 , ...,  3.5725322 ,
           3.3227463 ,  4.157421  ],
         [ 2.1695764 ,  2.318455  ,  2.359754  , ...,  3.2151914 ,
           3.2509837 ,  4.336571  ],
         [ 2.3427782 ,  1.9400274 ,  2.4111996 , ...,  3.514491  ,
           3.4012427 ,  4.3216267 ],
         ...,
         [ 1.380759  ,  0.5372354 ,  0.12256491, ...,  2.6955793 ,
           2.81711   ,  3.513661  ],
         [ 1.2264051 ,  0.48460555,  0.6335151 , ...,  2.8777127 ,
           3.079503  ,  3.6800308 ],
         [ 1.2445734 ,  0.72885823,  0.30385983, ...,  3.0515566 ,
           3.337317  ,  3.796597  ]],

        [[-0.88819695, -2.183567  , -1.664429  , ..., -0.88600373,
          -0.6753097 ,  0.21018314],
         [-0.8816378 , -2.1159673 , -1.1256247 , ..., -1.7265385 ,
          -1.265811  , -0.39530993],
         [-1.4451569 , -2.9287925 , -1.3577006 , ..., -2.0345795 ,
          -1.0804574 , -0.6398294 ],
         ...,
         [-0.1007826

In [51]:
# Converting a SavedModel to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_saved_model('./weights/xfeat.pb')
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable LiteRT ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()

2024-09-19 12:53:04.282401: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2024-09-19 12:53:04.282419: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2024-09-19 12:53:04.282586: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: ./weights/xfeat.pb
2024-09-19 12:53:04.295811: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2024-09-19 12:53:04.295831: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: ./weights/xfeat.pb
2024-09-19 12:53:04.332105: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2024-09-19 12:53:04.559495: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: ./weights/xfeat.pb
2024-09-19 12:53:04.780204: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 4

In [52]:
# Save the model.
with open('xfeat.tflite', 'wb') as f:
  f.write(tflite_model)

In [56]:
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Get input and output tensors details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Step 3: Prepare the dummy input
#np.random.seed(42)
#dummy_input_np = np.random.uniform(low=0.0, high=1.0, size=(1, 3, 480, 640)).astype(np.float32)  # Shape in HWC format for TFLite
interpreter.set_tensor(input_details[0]['index'], dummy_input_np)

# Step 4: Run the inference
interpreter.invoke()

# Step 5: Get the model output
output = interpreter.get_tensor(output_details[0]['index'])
print("Model output:", output)

Model output: [[[[-0.8923591   0.45524266 -0.35286513 ... -1.3290024   0.7108636
    -3.2854712 ]
   [-0.37151223 -0.8519242  -0.7518886  ... -2.6276107  -2.3510516
    -2.095671  ]
   [-0.9914355   0.5469847  -0.90970635 ...  1.42835    -1.8041741
    -0.9023521 ]
   ...
   [ 0.08948848 -3.2772434   1.1196729  ... -1.5692592   0.3214281
    -2.4999745 ]
   [-0.9280126   0.5464524  -4.7493234  ... -0.0358384  -2.7542782
    -2.113748  ]
   [ 0.02598569 -1.9881015  -0.29764718 ... -2.7359507  -0.1029776
    -0.9546582 ]]

  [[-5.174063   -0.25325716  0.3866041  ... -1.3457057   1.7016315
    -1.4441957 ]
   [-1.2119255  -2.1607335   0.35146227 ... -1.4366907  -2.1652384
    -2.2140827 ]
   [-2.831359    0.20584694 -0.81023866 ...  2.0288918  -2.8721623
    -0.7928697 ]
   ...
   [-1.3220953  -2.310955   -0.39477628 ... -0.7281338  -0.12125605
    -1.8855897 ]
   [-1.2102162  -1.5876826  -3.905679   ... -2.001757   -3.3585353
     0.41321608]
   [-0.4862812  -1.3204458   1.1829753  ... -

In [48]:
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [49]:
def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
    # Calculate the Homography matrix
    H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
    mask = mask.flatten()

    # Get corners of the first image (image1)
    h, w = img1.shape[:2]
    corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)

    # Warp corners to the second image (image2) space
    warped_corners = cv2.perspectiveTransform(corners_img1, H)

    # Draw the warped corners in image2
    img2_with_corners = img2.copy()
    for i in range(len(warped_corners)):
        start_point = tuple(warped_corners[i-1][0].astype(int))
        end_point = tuple(warped_corners[i][0].astype(int))
        cv2.line(img2_with_corners, start_point, end_point, (0, 255, 0), 4)  # Using solid green for corners

    # Prepare keypoints and matches for drawMatches function
    keypoints1 = [cv2.KeyPoint(p[0], p[1], 5) for p in ref_points]
    keypoints2 = [cv2.KeyPoint(p[0], p[1], 5) for p in dst_points]
    matches = [cv2.DMatch(i,i,0) for i in range(len(mask)) if mask[i]]

    # Draw inlier matches
    img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
                                  matchColor=(0, 255, 0), flags=2)

    return img_matches


In [50]:
class InterpolateSparse2d(nn.Module):
    """ Efficiently interpolate tensor at given sparse 2D positions. """ 
    def __init__(self, mode = 'bicubic', align_corners = False): 
        super().__init__()
        self.mode = mode
        self.align_corners = align_corners

    def normgrid(self, x, H, W):
        """ Normalize coords to [-1,1]. """
        return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.

    def forward(self, x, pos, H, W):
        """
        Input
            x: [B, C, H, W] feature tensor
            pos: [B, N, 2] tensor of positions
            H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]

        Returns
            [B, N, C] sampled channels at 2d positions
        """
        grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
        x = F.grid_sample(x, grid, mode = self.mode , align_corners = False)
        return x.permute(0,2,3,1).squeeze(-2)

In [58]:
def preprocess_tensor(x):
		""" Guarantee that image is divisible by 32 to avoid aliasing artifacts. """
		if isinstance(x, np.ndarray) and len(x.shape) == 3:
			x = torch.tensor(x).permute(2,0,1)[None]
		x = x.to(dev).float()

		H, W = x.shape[-2:]
		_H, _W = (H//32) * 32, (W//32) * 32
		rh, rw = H/_H, W/_W

		x = F.interpolate(x, (_H, _W), mode='bilinear', align_corners=False)
		return x, rh, rw

In [52]:
def get_kpts_heatmap( kpts, softmax_temp = 1.0):
	scores = F.softmax(kpts*softmax_temp, 1)[:, :64]
	B, _, H, W = scores.shape
	heatmap = scores.permute(0, 2, 3, 1).reshape(B, H, W, 8, 8)
	heatmap = heatmap.permute(0, 1, 3, 2, 4).reshape(B, 1, H*8, W*8)
	return heatmap

def NMS( x, threshold = 0.05, kernel_size = 5):
	B, _, H, W = x.shape
	pad=kernel_size//2
	local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
	pos = (x == local_max) & (x > threshold)
	pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]

	pad_val = max([len(x) for x in pos_batched])
	pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)

	#Pad kpts and build (B, N, 2) tensor
	for b in range(len(pos_batched)):
		pos[b, :len(pos_batched[b]), :] = pos_batched[b]

	return pos

In [95]:
def detectAndCompute(x, top_k = None, detection_threshold = None):
		"""
			Compute sparse keypoints & descriptors. Supports batched mode.

			input:
				x -> torch.Tensor(B, C, H, W): grayscale or rgb image
				top_k -> int: keep best k features
			return:
				List[Dict]: 
					'keypoints'    ->   torch.Tensor(N, 2): keypoints (x,y)
					'scores'       ->   torch.Tensor(N,): keypoint scores
					'descriptors'  ->   torch.Tensor(N, 64): local features
		"""
		if top_k is None: top_k = top_k
		if detection_threshold is None: detection_threshold = detection_threshold
		x, rh1, rw1 = preprocess_tensor(x)

		B, _, _H1, _W1 = x.shape
		infer = model.signatures['serving_default']
		#output = infer(dummy_input_tf)
		MT,KT,HT = net(x)
		print(KT)
		x_np= x.numpy()
		x_tf = tf.convert_to_tensor(x_np)
		print("Before inference")
		#M1, K1, H1 = infer(x_tf)
		output = infer(x_tf)
		#print(output)
		print("After inference")
		M1_np = output['feats'].numpy()
		M1 = torch.from_numpy(M1_np)
		K1_np = output['heatmap'].numpy()
		print(K1_np)
		K1 = torch.from_numpy(K1_np)
		H1_np = output['keypoints'].numpy()
		H1 = torch.from_numpy(H1_np)
		M1 = F.normalize(M1, dim=1)

		#Convert logits to heatmap and extract kpts
		K1h = K1
		mkpts = NMS(K1h, threshold=detection_threshold, kernel_size=5)

		#Compute reliability scores
		_nearest = InterpolateSparse2d('nearest')
		_bilinear = InterpolateSparse2d('bilinear')
		scores = (_nearest(K1h, mkpts, _H1, _W1) * _bilinear(H1, mkpts, _H1, _W1)).squeeze(-1)
		scores[torch.all(mkpts == 0, dim=-1)] = -1

		#Select top-k features
		idxs = torch.argsort(-scores)
		mkpts_x  = torch.gather(mkpts[...,0], -1, idxs)[:, :top_k]
		mkpts_y  = torch.gather(mkpts[...,1], -1, idxs)[:, :top_k]
		mkpts = torch.cat([mkpts_x[...,None], mkpts_y[...,None]], dim=-1)
		scores = torch.gather(scores, -1, idxs)[:, :top_k]
		interpolator = InterpolateSparse2d('bicubic')
		#Interpolate descriptors at kpts positions
		feats = interpolator(M1, mkpts, H = _H1, W = _W1)

		#L2-Normalize
		feats = F.normalize(feats, dim=-1)

		#Correct kpt scale
		mkpts = mkpts * torch.tensor([rw1,rh1], device=mkpts.device).view(1, 1, -1)

		valid = scores > 0
		return [  
				   {'keypoints': mkpts[b][valid[b]],
					'scores': scores[b][valid[b]],
					'descriptors': feats[b][valid[b]]} for b in range(B) 
			   ]


In [54]:
def parse_input( x):
		if len(x.shape) == 3:
			x = x[None, ...]

		if isinstance(x, np.ndarray):
			x = torch.tensor(x).permute(0,3,1,2)/255

		return x

In [55]:
def match( feats1, feats2, min_cossim = 0.82):

		cossim = feats1 @ feats2.t()
		cossim_t = feats2 @ feats1.t()
		
		_, match12 = cossim.max(dim=1)
		_, match21 = cossim_t.max(dim=1)

		idx0 = torch.arange(len(match12), device=match12.device)
		mutual = match21[match12] == idx0

		if min_cossim > 0:
			cossim, _ = cossim.max(dim=1)
			good = cossim > min_cossim
			idx0 = idx0[mutual & good]
			idx1 = match12[mutual & good]
		else:
			idx0 = idx0[mutual]
			idx1 = match12[mutual]

		return idx0, idx1

In [69]:
def match_xfeat(img1, img2, top_k = None, min_cossim = -1):
		"""
			Simple extractor and MNN matcher.
			For simplicity it does not support batched mode due to possibly different number of kpts.
			input:
				img1 -> torch.Tensor (1,C,H,W) or np.ndarray (H,W,C): grayscale or rgb image.
				img2 -> torch.Tensor (1,C,H,W) or np.ndarray (H,W,C): grayscale or rgb image.
				top_k -> int: keep best k features
			returns:
				mkpts_0, mkpts_1 -> np.ndarray (N,2) xy coordinate matches from image1 to image2
		"""
		if top_k is None: top_k = top_k
		img1 = parse_input(img1)
		img2 = parse_input(img2)
		print("Before detectAndCompute")
		out1 = detectAndCompute(img1, top_k=top_k)[0]
		out2 = detectAndCompute(img2, top_k=top_k)[0]

		idxs0, idxs1 = match(out1['descriptors'], out2['descriptors'], min_cossim=min_cossim )

		return out1['keypoints'][idxs0].cpu().numpy(), out2['keypoints'][idxs1].cpu().numpy()


In [96]:
import matplotlib.pyplot as plt
#Load some example images
im1 = cv2.imread('./assets/ref2.png')
im2 = cv2.imread('./assets/tgt2.png')

mkpts_0, mkpts_1 = match_xfeat(im1, im2, top_k = 4096)

canvas = warp_corners_and_draw_matches(mkpts_0, mkpts_1, im1, im2)
plt.figure(figsize=(12,12))
plt.imshow(canvas[..., ::-1]), plt.show()

Before detectAndCompute
tensor([[[[-3.8992e+00, -1.2332e-01, -4.6616e+00,  ..., -2.5235e+00,
           -5.2348e+00, -4.2108e+00],
          [-2.0392e+00,  3.8189e-01, -7.9618e-01,  ..., -1.8770e+00,
           -3.4422e+00, -4.0921e+00],
          [ 7.7332e-02,  1.4201e+00, -2.6118e+00,  ..., -3.3361e+00,
           -4.9777e+00, -6.4316e+00],
          ...,
          [-4.2194e+00, -3.4710e+00, -4.6246e+00,  ..., -6.3243e+00,
           -2.5982e+00, -1.5106e+00],
          [-3.8231e+00, -4.7220e+00, -2.6725e+00,  ..., -3.7552e+00,
           -5.6677e+00, -3.8638e+00],
          [-5.7435e+00, -5.6427e+00, -7.1438e+00,  ..., -4.0333e+00,
           -4.2356e+00, -4.7967e+00]],

         [[-1.9709e-01, -1.3329e+00, -3.0519e+00,  ..., -3.7070e+00,
           -6.7399e+00, -3.0513e+00],
          [-2.3295e-01,  6.8636e-02, -7.7883e-02,  ..., -2.2640e+00,
           -4.0410e+00, -3.8703e+00],
          [-4.2295e-01, -5.8513e-02, -3.9685e+00,  ..., -4.3035e+00,
           -5.1683e+00, -7.0628e+0

TypeError: '>' not supported between instances of 'Tensor' and 'NoneType'