In [1]:
import pandas as pd
import numpy as np
import timm
import torchvision.transforms as T

from wildlife_tools.features import SuperPointExtractor, SiftExtractor, DiskExtractor, AlikedExtractor, DeepFeatures
from wildlife_tools.data import ImageDataset

# Keyword arguments for ImageDataset: the metadata table loaded from CSV and
# the 'root' path of the test dataset directory.
metadata = dict(
    metadata=pd.read_csv('../tests/TestDataset/metadata.csv'),
    root='../tests/TestDataset',
)

# Resize each image to 224x224 and then convert it to a tensor.
transform = T.Compose([
    T.Resize([224, 224]),
    T.ToTensor(),
])

dataset = ImageDataset(transform=transform, **metadata)


# Extract local features

In [10]:
# Run SiftExtractor over the dataset and sanity-check the result.
extractor = SiftExtractor()
output = extractor(dataset)

# One output entry is expected per dataset item.
assert len(output) == len(dataset)

# Check the array shapes of the first item, in the same order as before:
# keypoints first, then descriptors.
first_item = output.features[0]
for key, expected_shape in (('keypoints', (256, 2)), ('descriptors', (256, 128))):
    assert tuple(first_item[key].shape) == expected_shape

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.45it/s]


In [11]:
# Run SuperPointExtractor over the dataset and sanity-check the result.
extractor = SuperPointExtractor()
output = extractor(dataset)

# One output entry is expected per dataset item.
assert len(output) == len(dataset)

# Check the array shapes of the first item, in the same order as before:
# keypoints first, then descriptors (256 columns for this extractor).
first_item = output.features[0]
for key, expected_shape in (('keypoints', (256, 2)), ('descriptors', (256, 256))):
    assert tuple(first_item[key].shape) == expected_shape

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  4.46it/s]


In [12]:
# Run AlikedExtractor over the dataset and sanity-check the result.
extractor = AlikedExtractor()
output = extractor(dataset)

# One output entry is expected per dataset item.
assert len(output) == len(dataset)

# Check the array shapes of the first item, in the same order as before:
# keypoints first, then descriptors.
first_item = output.features[0]
for key, expected_shape in (('keypoints', (256, 2)), ('descriptors', (256, 128))):
    assert tuple(first_item[key].shape) == expected_shape

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  5.84it/s]


In [13]:
# Run DiskExtractor over the dataset and sanity-check the result.
extractor = DiskExtractor()
output = extractor(dataset)

# One output entry is expected per dataset item.
assert len(output) == len(dataset)

# Check the array shapes of the first item, in the same order as before:
# keypoints first, then descriptors.
first_item = output.features[0]
for key, expected_shape in (('keypoints', (256, 2)), ('descriptors', (256, 128))):
    assert tuple(first_item[key].shape) == expected_shape

100%|█████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11.04it/s]


# Extract deep features

In [15]:
# Load the pretrained MegaDescriptor-T-224 model from the Hugging Face hub.
# num_classes=0 is passed to timm so the model is created without a
# classifier head (per the timm documentation) — the backbone output is then
# used directly as the feature vector.
backbone = timm.create_model(
    'hf-hub:BVRA/MegaDescriptor-T-224',
    pretrained=True,
    num_classes=0,
)
extractor = DeepFeatures(backbone)
output = extractor(dataset)

# One output entry is expected per dataset item.
assert len(output) == len(dataset)

# The feature matrix should have one 768-wide row per dataset item.
expected_shape = (len(dataset), 768)
assert tuple(output.features.shape) == expected_shape

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (BVRA/MegaDescriptor-T-224)
100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.13s/it]
