In [71]:
import tensorflow as tf
import torch
import onnx
from onnx_tf.backend import prepare
from torchsummary import summary


In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

class BasicLayer(nn.Module):
	"""
	  Basic Convolutional Layer: Conv2d -> BatchNorm -> ReLU
	"""
	def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
		super().__init__()
		self.layer = nn.Sequential(
									  nn.Conv2d( in_channels, out_channels, kernel_size, padding = padding, stride=stride, dilation=dilation, bias = bias),
									  nn.BatchNorm2d(out_channels, affine=False),
									  nn.ReLU(inplace = True),
									)

	def forward(self, x):
	  return self.layer(x)

class XFeatModel(nn.Module):
	"""
	   Implementation of architecture described in 
	   "XFeat: Accelerated Features for Lightweight Image Matching, CVPR 2024."
	"""

	def __init__(self):
		super().__init__()
		self.norm = nn.InstanceNorm2d(1)


		########### ⬇️ CNN Backbone & Heads ⬇️ ###########

		self.skip1 = nn.Sequential(	 nn.AvgPool2d(4, stride = 4),
			  						 nn.Conv2d (1, 24, 1, stride = 1, padding=0) )

		self.block1 = nn.Sequential(
										BasicLayer( 1,  4, stride=1),
										BasicLayer( 4,  8, stride=2),
										BasicLayer( 8,  8, stride=1),
										BasicLayer( 8, 24, stride=2),
									)

		self.block2 = nn.Sequential(
										BasicLayer(24, 24, stride=1),
										BasicLayer(24, 24, stride=1),
									 )

		self.block3 = nn.Sequential(
										BasicLayer(24, 64, stride=2),
										BasicLayer(64, 64, stride=1),
										BasicLayer(64, 64, 1, padding=0),
									 )
		self.block4 = nn.Sequential(
										BasicLayer(64, 64, stride=2),
										BasicLayer(64, 64, stride=1),
										BasicLayer(64, 64, stride=1),
									 )

		self.block5 = nn.Sequential(
										BasicLayer( 64, 128, stride=2),
										BasicLayer(128, 128, stride=1),
										BasicLayer(128, 128, stride=1),
										BasicLayer(128,  64, 1, padding=0),
									 )

		self.block_fusion =  nn.Sequential(
										BasicLayer(64, 64, stride=1),
										BasicLayer(64, 64, stride=1),
										nn.Conv2d (64, 64, 1, padding=0)
									 )

		self.heatmap_head = nn.Sequential(
										BasicLayer(64, 64, 1, padding=0),
										BasicLayer(64, 64, 1, padding=0),
										nn.Conv2d (64, 1, 1),
										nn.Sigmoid()
									)


		self.keypoint_head = nn.Sequential(
										BasicLayer(64, 64, 1, padding=0),
										BasicLayer(64, 64, 1, padding=0),
										BasicLayer(64, 64, 1, padding=0),
										nn.Conv2d (64, 65, 1),
									)


  		########### ⬇️ Fine Matcher MLP ⬇️ ###########

		self.fine_matcher =  nn.Sequential(
											nn.Linear(128, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 512),
											nn.BatchNorm1d(512, affine=False),
									  		nn.ReLU(inplace = True),
											nn.Linear(512, 64),
										)

	def _unfold2d(self, x, ws = 2):
		"""
			Unfolds tensor in 2D with desired ws (window size) and concat the channels
		"""
		B, C, H, W = x.shape
		x = x.unfold(2,  ws , ws).unfold(3, ws,ws)                             \
			.reshape(B, C, H//ws, W//ws, ws**2)
		return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws)


	def forward(self, x):
		"""
			input:
				x -> torch.Tensor(B, C, H, W) grayscale or rgb images
			return:
				feats     ->  torch.Tensor(B, 64, H/8, W/8) dense local features
				keypoints ->  torch.Tensor(B, 65, H/8, W/8) keypoint logit map
				heatmap   ->  torch.Tensor(B,  1, H/8, W/8) reliability map

		"""
		#dont backprop through normalization
		with torch.no_grad():
			x = x.mean(dim=1, keepdim = True)
			x = self.norm(x)

		#main backbone
		x1 = self.block1(x)
		x2 = self.block2(x1 + self.skip1(x))
		x3 = self.block3(x2)
		x4 = self.block4(x3)
		x5 = self.block5(x4)

		#pyramid fusion
		x4 = F.interpolate(x4, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
		x5 = F.interpolate(x5, (x3.shape[-2], x3.shape[-1]), mode='bilinear')
		feats = self.block_fusion( x3 + x4 + x5 )

		#heads
		heatmap = self.heatmap_head(feats) # Reliability map
		keypoints = self.keypoint_head(self._unfold2d(x, ws=8)) #Keypoint map logits

		return feats, keypoints, heatmap


In [73]:
import numpy as np

dev = torch.device ('cuda' if torch.cuda.is_available() else 'cpu')
net = XFeatModel().to(dev)
        
net.load_state_dict(torch.load('./weights/xfeat.pt'))

X_test = np.random.randn(20, 800, 608).astype(np.float32)

dummy_input = torch.from_numpy(X_test[0].reshape(1, 1, 800, 608)).float().to(dev)  # Batch size of 1, with an extra channel dimension
dummy_output = net(dummy_input)
print(dummy_output)

# Export to ONNX format
torch.onnx.export(net, dummy_input, './weights/model_simple.onnx', input_names=['test_input'], output_names=['test_output'])

  net.load_state_dict(torch.load('./weights/xfeat.pt'))


(tensor([[[[ 2.0592,  1.7883,  2.2113,  ...,  2.4185,  2.8530,  3.3155],
          [ 2.2933,  2.0149,  3.3666,  ...,  3.6051,  3.7774,  4.2822],
          [ 2.6828,  2.7297,  4.0402,  ...,  3.7708,  4.0118,  4.8345],
          ...,
          [ 2.1845,  3.3963,  4.0702,  ...,  3.0705,  3.2902,  2.9739],
          [ 2.1938,  3.1471,  3.3755,  ...,  3.0994,  3.1083,  2.9355],
          [ 2.3504,  2.9607,  3.6166,  ...,  3.6936,  3.1457,  3.6846]],

         [[-2.0440, -1.8958, -2.1538,  ..., -2.0935, -2.0237, -1.6681],
          [-2.1088, -2.0416, -2.2958,  ..., -2.3279, -2.5747, -1.8126],
          [-2.0275, -1.7703, -1.9649,  ..., -2.8184, -2.9898, -1.8230],
          ...,
          [-1.5847, -1.5397, -1.2872,  ..., -2.1622, -2.1957, -1.9384],
          [-1.9627, -2.0216, -1.7677,  ..., -1.7896, -1.6343, -1.5834],
          [-0.7774, -1.5525, -1.4862,  ..., -0.7645, -0.5759, -0.8638]],

         [[-0.5703, -1.5295, -1.9162,  ..., -1.5360, -0.9567, -1.0760],
          [-1.7525, -3.1745, 



In [74]:
# Load the model and its weights
net = XFeatModel().to(dev)
net.load_state_dict(torch.load('./weights/xfeat.pt'))

# Create a dummy input with the shape (1, 3, 480, 640)
dummy_input = torch.randn(1, 3, 480, 640, device=dev)

# Optionally, you can check the output
dummy_output = net(dummy_input)
print(dummy_output)

# Export the model to ONNX format
torch.onnx.export(
    net, 
    dummy_input, 
    './weights/model_simple.onnx',
    input_names=['test_input'],
    output_names=['test_output'],
    export_params=True,  # Export the trained parameters
    opset_version=12,    # ONNX opset version (you can choose according to your needs)
    do_constant_folding=False,  # Whether to fold constants for optimization
    dynamic_axes={
        'test_input': {0: 'batch_size'},  # Variable batch size
        'test_output': {0: 'batch_size'}  # Variable batch size
    }
)

  net.load_state_dict(torch.load('./weights/xfeat.pt'))


(tensor([[[[ 2.2627,  2.3550,  2.5949,  ...,  2.9316,  3.3071,  3.6963],
          [ 2.6049,  2.7892,  3.4781,  ...,  3.4213,  3.9088,  4.5148],
          [ 2.6001,  2.8053,  3.7432,  ...,  3.1684,  3.2790,  3.7884],
          ...,
          [ 2.3211,  3.1391,  3.8005,  ...,  2.3073,  2.5605,  3.1727],
          [ 1.9682,  2.5246,  3.1685,  ...,  1.8944,  2.5613,  2.8464],
          [ 1.9913,  2.7007,  3.1047,  ...,  2.6669,  2.4257,  2.9477]],

         [[-1.7510, -1.6659, -1.8959,  ..., -1.5952, -1.5452, -1.5763],
          [-1.7307, -2.1544, -2.0905,  ..., -1.8369, -1.9975, -1.9546],
          [-1.7782, -2.0312, -2.3405,  ..., -2.6400, -2.6114, -2.4001],
          ...,
          [-1.0651, -0.9902, -1.9190,  ..., -1.6860, -0.8849, -1.3166],
          [-1.6521, -1.4353, -1.9476,  ..., -2.6026, -1.8172, -1.4598],
          [-0.7969, -0.8434, -1.6171,  ..., -1.6216, -0.6549, -0.6310]],

         [[-0.7820, -1.1914, -1.5915,  ..., -1.4582, -0.6227, -0.7254],
          [-1.7494, -2.1337, 

In [75]:
# Load ONNX model and convert to TensorFlow format
model_onnx = onnx.load('./weights/model_simple.onnx')

tf_rep = prepare(model_onnx)
# Export model as .pb file
tf_rep.export_graph('./weights/model_simple.pb')

INFO:absl:Function `__call__` contains input name(s) x with unsupported characters which will be renamed to transpose_84_x in the SavedModel.
INFO:absl:Found untraced functions such as gen_tensor_dict while saving (showing 1 of 1). These functions will not be directly callable after loading.


INFO:tensorflow:Assets written to: ./weights/model_simple.pb/assets


INFO:tensorflow:Assets written to: ./weights/model_simple.pb/assets
INFO:absl:Writing fingerprint to ./weights/model_simple.pb/fingerprint.pb


In [76]:
model_dir = './weights/model_simple.pb'  # Path to the saved model directory
model = tf.saved_model.load(model_dir)


In [77]:
print("Available functions:")
for func in model.signatures:
    print(func)

Available functions:
serving_default


In [78]:
# Access the default serving function
infer = model.signatures['serving_default']

# Print input and output details
print("Inputs:")
for input_tensor in infer.inputs:
    print(f"{input_tensor.name}: {input_tensor.shape}")

print("Outputs:")
for output_tensor in infer.outputs:
    print(f"{output_tensor.name}: {output_tensor.shape}")

Inputs:
test_input:0: (None, 3, 480, 640)
unknown:0: (4, 1, 3, 3)
unknown_0:0: (4,)
unknown_1:0: (4,)
unknown_2:0: (8, 4, 3, 3)
unknown_3:0: (8,)
unknown_4:0: (8,)
unknown_5:0: (8, 8, 3, 3)
unknown_6:0: (8,)
unknown_7:0: (8,)
unknown_8:0: (24, 8, 3, 3)
unknown_9:0: (24,)
unknown_10:0: (24,)
unknown_11:0: (24, 1, 1, 1)
unknown_12:0: (24,)
unknown_13:0: (24, 24, 3, 3)
unknown_14:0: (24,)
unknown_15:0: (24,)
unknown_16:0: (24, 24, 3, 3)
unknown_17:0: (24,)
unknown_18:0: (24,)
unknown_19:0: (64, 24, 3, 3)
unknown_20:0: (64,)
unknown_21:0: (64,)
unknown_22:0: (64, 64, 3, 3)
unknown_23:0: (64,)
unknown_24:0: (64,)
unknown_25:0: (64, 64, 1, 1)
unknown_26:0: (64,)
unknown_27:0: (64,)
unknown_28:0: (64, 64, 3, 3)
unknown_29:0: (64,)
unknown_30:0: (64,)
unknown_31:0: (64, 64, 3, 3)
unknown_32:0: (64,)
unknown_33:0: (64,)
unknown_34:0: (64, 64, 3, 3)
unknown_35:0: (64,)
unknown_36:0: (64,)
unknown_37:0: (128, 64, 3, 3)
unknown_38:0: (128,)
unknown_39:0: (128,)
unknown_40:0: (128, 128, 3, 3)
unkno

In [79]:
# Get the concrete function from the model
concrete_func = model.signatures['serving_default']

# Print graph details
print("Concrete function details:")
print(concrete_func.graph.as_graph_def())


Concrete function details:
node {
  name: "test_input"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: -1
        }
        dim {
          size: 3
        }
        dim {
          size: 480
        }
        dim {
          size: 640
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_user_specified_name"
    value {
      s: "test_input"
    }
  }
}
node {
  name: "unknown"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 4
        }
        dim {
          size: 1
        }
        dim {
          size: 3
        }
        dim {
          size: 3
        }
      }
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "unknown_0"
  op: "Placeholder"
  attr {
    key: "shape"
    value {
      shape {
        dim {
          size: 4
        }
      }
    }
  }
  attr {
    

In [80]:
infer = model.signatures['serving_default']

np.random.seed(42)
# Create a dummy input using NumPy
dummy_input_np = np.random.uniform(low=0.0, high=1.0, size=(1, 3, 480, 640)).astype(np.float32)
# Convert the NumPy array to a PyTorch tensor
dummy_input_tf = tf.convert_to_tensor(dummy_input_np)

#print(dummy_input)
# Run inference
output = infer(dummy_input_tf)

print("Model output:", output)


2024-09-13 23:34:50.323404: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: StatefulPartitionedCall/assert_equal_1/Assert/AssertGuard/branch_executed/_91


Model output: {'1166': <tf.Tensor: shape=(1, 65, 60, 80), dtype=float32, numpy=
array([[[[-1.8495512e-01,  6.3961363e-01, -8.6854255e-01, ...,
          -1.7785213e+00, -1.1314076e-01, -3.9815478e+00],
         [-5.1498395e-01, -1.0131242e+00, -7.4150646e-01, ...,
          -2.0155451e+00, -2.4077456e+00, -2.2520568e+00],
         [-5.9490645e-01,  4.8376939e-01,  2.3018152e-02, ...,
           1.1579248e+00, -2.0847485e+00, -9.8239386e-01],
         ...,
         [-2.1725285e-01, -2.4981592e+00,  4.5723805e-01, ...,
          -1.5516578e+00, -2.4716577e-01, -2.0903511e+00],
         [-1.0920081e+00,  1.4583084e-01, -5.0347838e+00, ...,
          -9.0519607e-02, -2.7775352e+00, -2.0052803e+00],
         [-4.7640264e-02, -2.1327050e+00, -2.7427179e-01, ...,
          -2.6697397e+00, -2.3282218e-01, -1.4435974e+00]],

        [[-5.1657019e+00,  2.8134158e-01,  1.6140938e-03, ...,
          -1.6933240e+00,  7.7473152e-01, -1.2232771e+00],
         [-9.6807170e-01, -2.3879664e+00,  6.96547

In [81]:
x = torch.randn(1,3,480,640)
net = XFeatModel().to(dev)
summary(net, input_size=(3, 480, 640))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
    InstanceNorm2d-1          [-1, 1, 480, 640]               0
            Conv2d-2          [-1, 4, 480, 640]              36
       BatchNorm2d-3          [-1, 4, 480, 640]               0
              ReLU-4          [-1, 4, 480, 640]               0
        BasicLayer-5          [-1, 4, 480, 640]               0
            Conv2d-6          [-1, 8, 240, 320]             288
       BatchNorm2d-7          [-1, 8, 240, 320]               0
              ReLU-8          [-1, 8, 240, 320]               0
        BasicLayer-9          [-1, 8, 240, 320]               0
           Conv2d-10          [-1, 8, 240, 320]             576
      BatchNorm2d-11          [-1, 8, 240, 320]               0
             ReLU-12          [-1, 8, 240, 320]               0
       BasicLayer-13          [-1, 8, 240, 320]               0
           Conv2d-14         [-1, 24, 1

In [82]:
net = XFeatModel().to(dev)
net.load_state_dict(torch.load('./weights/xfeat.pt',weights_only=True))
dummy_input = torch.tensor(dummy_input_np).to(dev)

output = net(dummy_input)
print(output)

Output is a tuple with elements:
Output 0 shape: torch.Size([1, 64, 60, 80])
Output 1 shape: torch.Size([1, 65, 60, 80])
Output 2 shape: torch.Size([1, 1, 60, 80])
(tensor([[[[ 2.0339,  2.0162,  2.4294,  ...,  2.9459,  2.7474,  3.0622],
          [ 2.2621,  2.2193,  3.5295,  ...,  3.5592,  3.6374,  4.0465],
          [ 2.5319,  2.6700,  3.6643,  ...,  3.3428,  3.3196,  3.8285],
          ...,
          [ 2.6655,  3.5698,  3.8302,  ...,  2.4511,  2.7391,  3.1545],
          [ 2.4450,  3.1618,  3.6899,  ...,  3.0513,  3.0640,  2.9057],
          [ 1.8802,  2.9788,  4.0835,  ...,  3.5376,  2.8929,  3.3399]],

         [[-1.7989, -1.6022, -1.8334,  ..., -2.0971, -1.8097, -1.1508],
          [-2.0277, -1.5544, -1.4291,  ..., -2.2789, -2.1056, -1.4540],
          [-1.9461, -1.8042, -1.7633,  ..., -2.5108, -2.5897, -1.8712],
          ...,
          [-0.8095, -1.4622, -1.1222,  ..., -1.7730, -1.3524, -1.5609],
          [-0.9608, -1.8014, -1.6007,  ..., -1.3224, -1.3463, -1.2188],
          [

In [83]:
onnx_model = onnx.load("./weights/model_simple.onnx")  # load onnx model
output = prepare(onnx_model).run(dummy_input_np)  # run the loaded model
print(output)

2024-09-13 23:34:56.346845: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:933] Skipping loop optimization for Merge node with control input: assert_equal_1/Assert/AssertGuard/branch_executed/_9


Outputs(test_output=array([[[[ 2.79719472e+00,  2.94649315e+00,  3.58056498e+00, ...,
           4.00976849e+00,  3.92117834e+00,  4.97884226e+00],
         [ 2.45618010e+00,  2.55628228e+00,  2.81903362e+00, ...,
           4.61351538e+00,  4.48172665e+00,  5.53187609e+00],
         [ 2.61257505e+00,  2.32753158e+00,  2.86259460e+00, ...,
           4.70882988e+00,  4.75821733e+00,  5.79021931e+00],
         ...,
         [ 1.32151520e+00,  8.84162903e-01,  4.79141951e-01, ...,
           3.19680500e+00,  2.99817538e+00,  3.58570814e+00],
         [ 9.50574100e-01,  6.72974706e-01,  1.10827088e+00, ...,
           2.90448570e+00,  3.12899160e+00,  3.60073423e+00],
         [ 8.10312986e-01,  8.11572671e-01,  6.38979316e-01, ...,
           3.18787384e+00,  3.44675064e+00,  3.58374286e+00]],

        [[-1.20778728e+00, -2.29394913e+00, -1.90345061e+00, ...,
          -1.62944341e+00, -1.19536495e+00,  3.74163985e-02],
         [-8.99906635e-01, -1.85234022e+00, -1.17065918e+00, ...,
  

In [85]:
# Converting a SavedModel to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_saved_model('./weights/model_simple.pb')
converter.target_spec.supported_ops = [
  tf.lite.OpsSet.TFLITE_BUILTINS, # enable LiteRT ops.
  tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()

2024-09-13 23:50:50.028316: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2024-09-13 23:50:50.028331: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2024-09-13 23:50:50.028535: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: ./weights/model_simple.pb
2024-09-13 23:50:50.039322: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2024-09-13 23:50:50.039343: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: ./weights/model_simple.pb
2024-09-13 23:50:50.073804: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2024-09-13 23:50:50.266555: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: ./weights/model_simple.pb
2024-09-13 23:50:50.450242: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status

In [86]:
# Save the model.
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)

In [88]:
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Get input and output tensors details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Step 3: Prepare the dummy input
np.random.seed(42)
dummy_input_np = np.random.uniform(low=0.0, high=1.0, size=(1, 3, 480, 640)).astype(np.float32)  # Shape in HWC format for TFLite
interpreter.set_tensor(input_details[0]['index'], dummy_input_np)

# Step 4: Run the inference
interpreter.invoke()

# Step 5: Get the model output
output = interpreter.get_tensor(output_details[0]['index'])
print("Model output:", output)

Model output: [[[[-1.85004115e-01  6.39580488e-01 -8.68525684e-01 ... -1.77848315e+00
    -1.13055885e-01 -3.98147464e+00]
   [-5.14963090e-01 -1.01311195e+00 -7.41571069e-01 ... -2.01561451e+00
    -2.40777946e+00 -2.25202298e+00]
   [-5.94912291e-01  4.83763605e-01  2.30179727e-02 ...  1.15798008e+00
    -2.08469105e+00 -9.82368588e-01]
   ...
   [-2.17238337e-01 -2.49818254e+00  4.57290083e-01 ... -1.55164504e+00
    -2.47141048e-01 -2.09037662e+00]
   [-1.09198701e+00  1.45878285e-01 -5.03475094e+00 ... -9.05269086e-02
    -2.77753496e+00 -2.00529695e+00]
   [-4.76371348e-02 -2.13268757e+00 -2.74254918e-01 ... -2.66967702e+00
    -2.32838243e-01 -1.44356191e+00]]

  [[-5.16571474e+00  2.81328410e-01  1.63233280e-03 ... -1.69331014e+00
     7.74804950e-01 -1.22330713e+00]
   [-9.68074679e-01 -2.38796091e+00  6.96097016e-02 ... -9.54555035e-01
    -2.14839530e+00 -2.35194969e+00]
   [-2.14650297e+00  5.46005726e-01 -1.13176823e+00 ...  1.62752593e+00
    -2.64638233e+00 -1.72669792e+