In [2]:
import timm
import pandas as pd
import torchvision.transforms as T

from wildlife_tools.data import WildlifeDataset, SplitMetadata
from wildlife_tools.features import DeepFeatures
from wildlife_tools.similarity import CosineSimilarity
from wildlife_tools.inference import KnnClassifier

# Prepare dataset
- Load the metadata from a CSV file into a pandas dataframe.
- Create two datasets using the split information stored in the metadata:
    - query - created from the test split
    - database - created from the train split

In [3]:
# Folder that holds both the images and the metadata table.
image_root = 'ExampleDataset'
metadata = pd.read_csv(f'{image_root}/metadata.csv')

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the commonly used ImageNet mean/std values.
transform = T.Compose([
    T.Resize([224, 224]),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Arguments shared by the database and query datasets; only the split differs.
shared_kwargs = dict(metadata=metadata, root=image_root, transform=transform)

# Database: rows whose 'split' column equals 'train'.
dataset_database = WildlifeDataset(split=SplitMetadata('split', 'train'), **shared_kwargs)

# Query: rows whose 'split' column equals 'test'.
dataset_query = WildlifeDataset(split=SplitMetadata('split', 'test'), **shared_kwargs)

# Extract features
- Extract features using MegaDescriptor-Tiny (https://huggingface.co/BVRA/MegaDescriptor-T-224)

- The input to the extractor is a WildlifeDataset object.
- The output is a numpy array with shape (n_images, dim_embeddings).

In [4]:
# `timm` and `DeepFeatures` are already imported in the notebook's first cell,
# so they are not imported again here.

# Load the pretrained MegaDescriptor-T-224 weights from the Hugging Face hub.
# num_classes=0 asks timm for a model without a classification head, so the
# network returns embeddings instead of class logits.
backbone = timm.create_model('hf-hub:BVRA/MegaDescriptor-T-224', num_classes=0, pretrained=True)
extractor = DeepFeatures(backbone)

# Run the extractor over each dataset (query first, then database).
query = extractor(dataset_query)
database = extractor(dataset_database)

print(f'Query features shape: {query.shape}, Database features shape: {database.shape}')

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.35it/s]
100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.47it/s]

Query features shape: (1, 768), Database features shape: (2, 768)





## Similarity and k-nn classification
- Calculate the cosine similarity between query and database features.
    - Inputs are the arrays with query and database features.
    - Output is a dictionary; its `'cosine'` entry is a matrix with shape (n_query, n_database).

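- Use the similarity matrix as input to the k-NN classifier.
    - Output is an array of length n_query.
    - Each value is the label of the nearest match in the database. Because `database_labels` is passed to the classifier here, predictions are returned as string labels; otherwise they use ordinal encoding (column indexes of the similarity matrix).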

In [5]:
# Cosine similarity between deep features.
# The similarity object returns a dict; the 'cosine' entry is the matrix
# whose shape is printed below.
similarity = CosineSimilarity()
sim = similarity(query, database)['cosine']
print("Similarity matrix: \n", sim.shape)


# Nearest neighbour classifier using the similarity.
# NOTE(review): predictions are decoded by indexing `database_labels` with the
# column index of the nearest database image, so the array needs one label per
# database image. `labels_string` provides that; `labels_map` (used previously)
# appears to list each identity only once and lines up with the columns only
# when every identity has a single database image - confirm against the
# wildlife-tools KnnClassifier documentation.
classifier = KnnClassifier(k=1, database_labels=dataset_database.labels_string)
preds = classifier(sim)
print("Prediction \t", preds)
print("Ground truth \t", dataset_query.labels_string)

# Fraction of query images whose predicted label equals the ground-truth label.
acc = sum(preds == dataset_query.labels_string) / len(dataset_query.labels_string)
print('\n Accuracy: ', acc)

Similarity matrix: 
 (1, 2)
Prediction 	 ['a']
Ground truth 	 ['a']

 Accuracy:  1.0
