In [1]:
import timm
import pandas as pd
import torchvision.transforms as T

from wildlife_tools.data import WildlifeDataset, SplitMetadata
from wildlife_tools.features import SIFTFeatures
from wildlife_tools.similarity import MatchDescriptors
from wildlife_tools.inference import KnnClassifier

# Prepare dataset
- Load the metadata from a CSV file into a pandas DataFrame.
- Create two datasets using the split information in the metadata:
    - query – created from the test split,
    - database – created from the train split.

In [2]:
# Metadata table for every image; its 'split' column is used below to divide
# the images into a database (train) part and a query (test) part.
metadata = pd.read_csv('ExampleDataset/metadata.csv')
image_root = 'ExampleDataset'

# Every image is resized to 224x224 and converted to grayscale before the
# SIFT extractor sees it.
transform = T.Compose([T.Resize([224, 224]), T.Grayscale()])


def _dataset_for_split(split_value):
    """Build a WildlifeDataset over the shared metadata/root/transform.

    The subset is chosen with ``SplitMetadata('split', split_value)`` --
    presumably the rows whose 'split' column equals ``split_value``.
    """
    return WildlifeDataset(
        metadata=metadata,
        root=image_root,
        split=SplitMetadata('split', split_value),
        transform=transform,
    )


dataset_database = _dataset_for_split('train')
dataset_query = _dataset_for_split('test')

# SIFT Feature Extraction

This process involves extracting a set of SIFT descriptors for each image in the dataset. 

The extractor takes a WildlifeDataset object as input and produces a list of arrays, where each array corresponds to an image. These arrays are 2D with a shape of (n_descriptors x 128), where "n_descriptors" varies depending on the number of SIFT descriptors extracted for the respective image.

In [3]:
# A single SIFT extractor is reused for both datasets; each call yields a list
# with one (n_descriptors x 128) descriptor array per image.
extractor = SIFTFeatures()
query = extractor(dataset_query)
database = extractor(dataset_database)

# Shapes of the descriptor arrays for (at most) the first five images of each set.
query_shapes = [descriptors.shape for descriptors in query[:5]]
database_shapes = [descriptors.shape for descriptors in database[:5]]
print(f'First 5 query features shape: {query_shapes}')
print(f'First 5 database features shape: {database_shapes}')

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.67it/s]
100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 35.56it/s]

First 5 query features shape: [(238, 128)]
First 5 database features shape: [(182, 128), (190, 128)]





## Similarity and k-nn classification
- Calculate the similarity between query and database images as the number of SIFT correspondences that survive Lowe's ratio test.
    - The output is a matrix with shape n_query x n_database.
- Use this similarity in a k-NN classifier.
    - The output is an array of length n_query.
    - Each value is the label of the nearest database image. The classifier works with ordinal encoding (column indices of the similarity matrix); because `database_labels` is passed, these indices are mapped back to the label values shown below.

In [4]:
# Length of each SIFT descriptor (the arrays from the extraction step are n_descriptors x 128).
SIFT_DESCRIPTOR_DIM = 128
# Ratio-test threshold. It is defined once because it is used both as the
# `thresholds` argument and as the key that selects the resulting matrix;
# editing only one of the two would raise KeyError or print a stale value.
RATIO_THRESHOLD = 0.8

similarity = MatchDescriptors(descriptor_dim=SIFT_DESCRIPTOR_DIM, thresholds=[RATIO_THRESHOLD])
# The matcher returns its results keyed by threshold; pick the one requested above.
sim = similarity(query, database)[RATIO_THRESHOLD]

print(f"Number of SIFT correspondences after {RATIO_THRESHOLD} ratio test threshold: \n", sim)

100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 37.49it/s]

Number of SIFT correspondences after 0.8 ratio test threshold: 
 [[34. 28.]]





# k-NN classification

In [5]:
# Nearest-neighbour classifier over the similarity matrix. With k=1 each query is
# assigned the label of its most similar database image; database_labels maps
# the column indices back to label values.
classifier = KnnClassifier(k=1, database_labels=dataset_database.labels_map)
preds = classifier(sim)
ground_truth = dataset_query.labels_string
print("Prediction \t", preds)
print("Ground truth \t", ground_truth)

# Count matches pair by pair. A bare `preds == ground_truth` is only element-wise
# when one side is a numpy array; with plain lists it compares the whole sequences
# and `sum()` of the resulting bool raises TypeError.
assert len(preds) == len(ground_truth), (
    f"got {len(preds)} predictions for {len(ground_truth)} query images"
)
n_correct = sum(pred == true for pred, true in zip(preds, ground_truth))
# An empty query split has no defined accuracy; report NaN instead of dividing by zero.
acc = n_correct / len(ground_truth) if len(ground_truth) else float('nan')
print('\n Accuracy: ', acc)

Prediction 	 ['a']
Ground truth 	 ['a']

 Accuracy:  1.0
