In [1]:
import pandas as pd
import numpy as np
import timm
import torchvision.transforms as T

from wildlife_tools.features import SuperPointExtractor
from wildlife_tools.similarity import MatchLOFTR, MatchLightGlue
from wildlife_tools.similarity.pairwise.collectors import CollectCounts, CollectCountsRansac, CollectAll
from wildlife_tools.data import ImageDataset

# Keyword arguments for ImageDataset: the metadata table and the image root folder.
metadata = {
    'metadata': pd.read_csv('../tests/TestDataset/metadata.csv'),
    'root': '../tests/TestDataset',
}

# Resize every image to 224x224 and convert it to a tensor.
transform = T.Compose([T.Resize([224, 224]), T.ToTensor()])

dataset = ImageDataset(transform=transform, **metadata)

# LightGlue

In [5]:
# (Re)assign the 224x224 resize + tensor pipeline on the shared dataset
# before extracting features from it.
dataset.transform = T.Compose([T.Resize([224, 224]), T.ToTensor()])
extractor = SuperPointExtractor()

# The same dataset serves as both the query set and the database set.
features_query, features_database = extractor(dataset), extractor(dataset)

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 10.26it/s]
100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 16.27it/s]


### Default SuperPoint + LightGlue matching score
- The default score is the number of significant (confidence > 0.5) correspondences.

In [6]:
# No collector is passed here, so the matcher uses its default scoring.
matcher = MatchLightGlue(features='superpoint')
feature_pair = (features_query, features_database)
output = matcher(*feature_pair)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 22.73it/s]


array([[197.,   0.,   2.],
       [  0., 239.,   0.],
       [  2.,   0., 203.]], dtype=float16)

### Collect scores from multiple thresholds
- Useful for computing scores at several confidence thresholds at once; the result is a dict keyed by threshold.

In [7]:
# Confidence thresholds at which correspondences are counted.
thresholds = [0.2, 0.5, 0.8]

collector = CollectCounts(thresholds=thresholds)
matcher = MatchLightGlue(collector=collector, features='superpoint')
output = matcher(features_query, features_database)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 31.84it/s]


{0.2: array([[228.,   2.,  16.],
        [  2., 247.,   0.],
        [ 16.,   0., 236.]], dtype=float16),
 0.5: array([[197.,   0.,   2.],
        [  0., 239.,   0.],
        [  2.,   0., 203.]], dtype=float16),
 0.8: array([[156.,   0.,   1.],
        [  0., 202.,   0.],
        [  1.,   0., 121.]], dtype=float16)}

### Collect after RANSAC filtering

In [8]:
# RANSAC settings forwarded unchanged to the collector.
ransac_options = {'ransacReprojThreshold': 1.0, 'maxIters': 100}

collector = CollectCountsRansac(**ransac_options)
matcher = MatchLightGlue(collector=collector, features='superpoint')
output = matcher(features_query, features_database)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 26.46it/s]


array([[238.,   0.,  22.],
       [  0., 248.,   0.],
       [ 22.,   0., 247.]], dtype=float16)

### Collect all
- Collect all SuperPoint keypoints and descriptors as a list of dicts (one dict per query–database pair).

In [9]:
# With CollectAll the matcher output is a list rather than a score matrix.
collector = CollectAll()
matcher = MatchLightGlue(collector=collector, features='superpoint')
output = matcher(features_query, features_database)

n_pairs = len(output)  # = len(query) x len(database)
print(n_pairs)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 31.63it/s]

9





[{'idx0': 0,
  'idx1': 0,
  'kpts0': array([[119.5, 122.5],
         [131.5,  77.5],
         [115.5, 143.5],
         [ 82.5, 156.5],
         [ 95.5, 157.5],
         [156.5, 156.5],
         [ 76.5, 101.5],
         [ 79.5,  85.5],
         [164.5, 133.5],
         [ 90.5,  72.5],
         [ 73.5, 195.5],
         [ 50.5, 173.5],
         [144.5, 126.5],
         [ 98.5, 140.5],
         [151.5,  85.5],
         [ 67.5, 117.5],
         [113.5,  64.5],
         [ 64.5, 107.5],
         [202.5, 118.5],
         [128.5,  87.5],
         [143.5, 116.5],
         [154.5, 137.5],
         [159.5, 111.5],
         [ 43.5, 124.5],
         [140.5,  80.5],
         [174.5,  31.5],
         [ 79.5, 165.5],
         [134.5, 154.5],
         [145.5,  99.5],
         [110.5,  79.5],
         [129.5,  95.5],
         [ 99.5,  96.5],
         [170.5,  28.5],
         [ 88.5, 116.5],
         [144.5, 154.5],
         [ 91.5,  86.5],
         [207.5, 161.5],
         [138.5, 116.5],
         [ 79.5

# LoFTR
- The default score is the number of significant (confidence > 0.5) correspondences.

In [2]:
# LoFTR input pipeline: resize to 256x256, convert to grayscale, then to a tensor.
# NOTE: this overwrites the transform on the shared `dataset` object.
loftr_steps = [T.Resize([256, 256]), T.Grayscale(), T.ToTensor()]
dataset.transform = T.Compose(loftr_steps)

# MatchLOFTR is called on the datasets directly; the same dataset is used
# as both query and database.
matcher = MatchLOFTR()
output = matcher(dataset, dataset)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.50it/s]


array([[784.,   5.,  30.],
       [  7., 784.,   6.],
       [ 27.,   9., 774.]], dtype=float16)