In [1]:
import torch

from FingerNetV3.model import ShallowUNet, FingerNetV3

In [13]:
# Build a FingerNetV3 instance and run a single forward pass on a random
# 1x1x512x512 tensor so the structure of the returned outputs can be inspected.
model = FingerNetV3()
model.eval()  # evaluation mode (affects dropout / batch-norm layers, if present)
dummy_input = torch.randn(1, 1, 512, 512)

with torch.no_grad():  # no autograd graph is needed just to look at shapes
    out = model(dummy_input)

# The model returns a dict: a 'segmentation' tensor plus a nested
# 'minutiae' dict of named tensors.
print(out['segmentation'].shape)

minutiae = out['minutiae']
[(head, minutiae[head].shape) for head in minutiae]

torch.Size([1, 1, 512, 512])


[('mnt_s_score', torch.Size([1, 1, 64, 64])),
 ('mnt_w_score', torch.Size([1, 8, 64, 64])),
 ('mnt_y_score', torch.Size([1, 8, 64, 64])),
 ('mnt_o_score', torch.Size([1, 180, 64, 64]))]

In [None]:
import time

# Benchmark settings: number of timed iterations and untimed warm-up iterations.
repeats = 10
warmup = 5

# Fresh model; no .to(device) call, so it runs on CPU.
model = FingerNetV3()
model.eval()
dummy_input = torch.randn(1, 1, 512, 512)

with torch.no_grad():
    # Warm-up passes are excluded from timing so one-off costs (lazy init,
    # allocator growth, caches) don't skew the average. Results are discarded
    # without binding `_`, which IPython reserves for the last cell output.
    for _ in range(warmup):
        model(dummy_input)

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can be adjusted and has coarser resolution on some OSes.
    start = time.perf_counter()
    for _ in range(repeats):
        model(dummy_input)
    elapsed = time.perf_counter() - start

fps = repeats / elapsed
print(f"Average FPS on CPU: {fps:.2f}")
print(f"Average Inference Time: {elapsed/repeats:.3f} sec")

Average FPS on CPU: 2.32
Average Inference Time: 0.432


In [21]:
# Export the `model` defined in an earlier cell to ONNX.
model.eval()

dummy_input = torch.randn(1, 1, 512, 512)

# ONNX output names are assigned positionally to the model's flattened outputs.
# NOTE(review): assumes the flattened order is 'segmentation' followed by the
# minutiae heads in their printed order (mnt_s, mnt_w, mnt_y, mnt_o), i.e.
# score / x_offset / y_offset / angle -- confirm before relying on these names.
output_names = ['segmentation', 'score', 'x_offset', 'y_offset', 'angle']


def _count_tensors(obj):
    """Count tensor leaves in a (possibly nested) dict/list/tuple of outputs."""
    if isinstance(obj, torch.Tensor):
        return 1
    if isinstance(obj, dict):
        return sum(_count_tensors(v) for v in obj.values())
    if isinstance(obj, (list, tuple)):
        return sum(_count_tensors(v) for v in obj)
    return 0


# Fail loudly if the model's output count no longer matches the name list,
# instead of exporting with silently mismatched or default output names.
with torch.no_grad():
    n_outputs = _count_tensors(model(dummy_input))
assert n_outputs == len(output_names), (
    f"model returns {n_outputs} tensors but {len(output_names)} output names were given"
)

torch.onnx.export(
    model,
    dummy_input,
    "fingernetv3.onnx",
    export_params=True,
    opset_version=19,
    do_constant_folding=True,
    input_names=['input'],
    output_names=output_names,
    # Only axis 0 (batch) is dynamic; the other dims are fixed by dummy_input.
    # Derived from output_names so the two lists cannot drift apart.
    dynamic_axes={name: {0: 'batch_size'} for name in ['input', *output_names]},
)

In [29]:
import onnxruntime as ort
import numpy as np
import time

# Benchmark settings: number of timed iterations and untimed warm-up iterations.
repeats = 10
warmup = 5

session = ort.InferenceSession("fingernetv3.onnx", providers=['CPUExecutionProvider'])
input_name = session.get_inputs()[0].name

# Seeded generator so the benchmark input is reproducible across runs. The model
# was exported with a float32 dummy input, so generate float32 directly.
rng = np.random.default_rng(0)
dummy_input = rng.standard_normal((1, 1, 512, 512), dtype=np.float32)
feed = {input_name: dummy_input}

# Warm-up runs are excluded from timing.
for _ in range(warmup):
    session.run(None, feed)

# perf_counter is monotonic and high-resolution, unlike time.time().
start = time.perf_counter()
for _ in range(repeats):
    session.run(None, feed)
elapsed = time.perf_counter() - start

fps = repeats / elapsed
print(f"Average FPS (ONNX Runtime, CPU): {fps:.2f}")
print(f"Average Inference Time: {elapsed/repeats:.3f} sec")


Average FPS (ONNX Runtime, CPU): 3.26
Average Inference Time: 0.306 sec
