In [24]:
import pandas as pd
import numpy as np
import timm
import torchvision.transforms as T

import sys
sys.path.append('../wildlife_tools/')

from features.local import SuperPointExtractor, SiftExtractor, AlikedExtractor, DiskExtractor
from similarity.cosine import CosineSimilarity
from similarity.pairwise.loftr import MatchLOFTR
from similarity.pairwise.lightglue import MatchLightGlue
from similarity.pairwise.collectors import CollectCounts, CollectCountsRansac, CollectAll

from data.dataset import WildlifeDataset


metadata = {'metadata':  pd.read_csv('TestDataset/metadata.csv'), 'root': 'TestDataset'}
transform = T.Compose([T.Resize([224, 224]), T.ToTensor()])
dataset = WildlifeDataset(**metadata, transform=transform)


# LightGlue

In [4]:
extractor = SuperPointExtractor()
dataset.transform = T.Compose([T.Resize([224, 224]), T.ToTensor()])

features_query = extractor(dataset)
features_database = extractor(dataset)

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  8.00it/s]
100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  7.92it/s]


In [6]:
matcher = MatchLightGlue(features='superpoint')
output = matcher(features_query, features_database)
output

100%|█████████████████████████████████████████████████████████████████| 5/5 [00:36<00:00,  7.28s/it]


array([[198.,   0.,   2.],
       [  0., 239.,   0.],
       [  2.,   0., 198.]], dtype=float16)

In [7]:
collector = CollectCounts(thresholds=[0.2, 0.5, 0.8])
matcher = MatchLightGlue(features='superpoint', collector=collector)
output = matcher(features_query, features_database)
output

100%|█████████████████████████████████████████████████████████████████| 5/5 [00:35<00:00,  7.20s/it]


{0.2: array([[229.,   2.,  16.],
        [  2., 247.,   0.],
        [ 16.,   0., 238.]], dtype=float16),
 0.5: array([[198.,   0.,   2.],
        [  0., 239.,   0.],
        [  2.,   0., 198.]], dtype=float16),
 0.8: array([[157.,   0.,   1.],
        [  0., 202.,   0.],
        [  1.,   0., 117.]], dtype=float16)}

In [15]:
collector = CollectCountsRansac(ransacReprojThreshold=1.0, maxIters=100)
matcher = MatchLightGlue(features='superpoint', collector=collector)
output = matcher(features_query, features_database)
output

100%|█████████████████████████████████████████████████████████████████| 2/2 [00:12<00:00,  6.13s/it]


array([[238.,   0.,  20.],
       [  0., 248.,   0.],
       [ 21.,   0., 248.]], dtype=float16)

In [25]:
collector = CollectAll()
matcher = MatchLightGlue(features='superpoint', collector=collector)
output = matcher(features_query, features_database)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00,  7.10s/it]


[{'idx0': 0,
  'idx1': 0,
  'kpts0': array([[119.5, 122.5],
         [131.5,  77.5],
         [115.5, 143.5],
         [ 82.5, 156.5],
         [ 95.5, 157.5],
         [156.5, 156.5],
         [ 76.5, 101.5],
         [ 79.5,  85.5],
         [164.5, 133.5],
         [ 90.5,  72.5],
         [ 73.5, 195.5],
         [ 50.5, 173.5],
         [144.5, 126.5],
         [ 98.5, 140.5],
         [151.5,  85.5],
         [ 67.5, 117.5],
         [113.5,  64.5],
         [ 64.5, 107.5],
         [202.5, 118.5],
         [128.5,  87.5],
         [143.5, 116.5],
         [154.5, 137.5],
         [159.5, 111.5],
         [ 43.5, 124.5],
         [140.5,  80.5],
         [174.5,  31.5],
         [ 79.5, 165.5],
         [145.5,  99.5],
         [134.5, 154.5],
         [110.5,  79.5],
         [129.5,  95.5],
         [ 99.5,  96.5],
         [170.5,  28.5],
         [ 88.5, 116.5],
         [144.5, 154.5],
         [ 91.5,  86.5],
         [207.5, 161.5],
         [138.5, 116.5],
         [ 79.5

In [26]:
len(output)

9

# LoFTR

In [19]:
dataset.transform = T.Compose([T.Resize([256, 256]), T.Grayscale(), T.ToTensor()])

matcher = MatchLOFTR()
output = matcher(dataset, dataset)
output

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.20s/it]


array([[784.,   5.,  29.],
       [  7., 784.,   6.],
       [ 27.,   8., 774.]], dtype=float16)