# MegaDescriptor-S-224
- Run inference with MegaDescriptor-S-224 (https://huggingface.co/BVRA/MegaDescriptor-S-224)

In [5]:
import pandas as pd
from torchvision import transforms as T
from timm import create_model
import torch
import torch.nn as nn
from pathlib import Path
import os

from wildlife_tools.features.deep import DeepFeatures
from wildlife_tools.data.dataset import WildlifeDataset
from wildlife_tools.data.split import SplitMetadata
from wildlife_tools.similarity.cosine import CosineSimilarity
from wildlife_tools.evaluation.classifier import KnnClassifier

In [18]:
import numpy as np

In [13]:
# root of model paths
models_root = Path('/Users/fmb/GitHub/764WildlifeReID/megadescriptor/models/')
root_images = Path('/Users/fmb/GitHub/764WildlifeReID/megadescriptor/data/images/size-256')
root_metadata = Path('/Users/fmb/GitHub/764WildlifeReID/megadescriptor/metadata/datasets')

In [15]:
# store results for all inferencing here
results = {}

## Baseline Inference

In [8]:
model_baseline = torch.load(models_root/'md_baseline.pth', map_location='cpu')

  model_baseline = torch.load(models_root/'md_baseline.pth', map_location='cpu')


In [9]:
extractor_baseline = DeepFeatures(model_baseline, device='cpu')

In [11]:
# defined for all four datasets
transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

In [14]:
# prepare dataset

#***
metadata_baseline = pd.read_csv(root_metadata/'fca_base/metadata.csv', index_col=0)
images_path = root_images/'fca_base'

#***
database_baseline = WildlifeDataset(
    metadata=metadata_baseline,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'train'),
)

#***
query_baseline = WildlifeDataset(
    metadata=metadata_baseline,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'test'),
)

In [None]:
matcher = CosineSimilarity()

#***
similarity_baseline = matcher(query=extractor_baseline(query_baseline), 
                     database=extractor_baseline(database_baseline))

In [21]:
similarity_baseline.keys()

dict_keys(['default'])

In [29]:
classifier = KnnClassifier(k=1)
predictions_baseline = classifier(similarity_baseline['default'], labels=database_baseline.labels_string)

In [None]:
predictions_baseline

In [31]:
#***
acc_baseline = sum(predictions_baseline == query_baseline.labels_string) / len(predictions_baseline)

#***
print('fca_base', acc_baseline)
results['fca_base'] = acc_baseline

fca_base 0.9606299212598425


## Backbone Inference

In [54]:
model_backbone = torch.load(models_root/'backbone.pth', map_location='cpu')

  model_backbone = torch.load(models_root/'backbone.pth', map_location='cpu')


In [55]:
extractor_backbone = DeepFeatures(model_backbone, device='cpu')

In [56]:
# prepare dataset

#***
metadata_backbone = pd.read_csv(root_metadata/'fca_base/metadata.csv', index_col=0)
images_path = root_images/'fca_base'

#***
database_backbone = WildlifeDataset(
    metadata=metadata_backbone,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'train'),
)

#***
query_backbone = WildlifeDataset(
    metadata=metadata_backbone,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'test'),
)

In [None]:
matcher = CosineSimilarity()

#***
similarity_backbone = matcher(query=extractor_backbone(query_backbone), 
                     database=extractor_backbone(database_backbone))

In [58]:
classifier = KnnClassifier(k=1)
predictions_backbone = classifier(similarity_backbone['default'], labels=database_backbone.labels_string)

In [59]:
#***
acc_backbone = sum(predictions_backbone == query_backbone.labels_string) / len(predictions_backbone)

#***
print('backbone', acc_backbone)
results['backbone'] = acc_backbone

backbone 0.9501312335958005


## maxim Inference

In [32]:
model_maxim = torch.load(models_root/'md_maxim.pth', map_location='cpu')

  model_maxim = torch.load(models_root/'md_maxim.pth', map_location='cpu')


In [33]:
extractor_maxim = DeepFeatures(model_maxim, device='cpu')

In [35]:
# prepare dataset

#***
metadata_maxim = pd.read_csv(root_metadata/'FeralCatsAkl_maxim/metadata.csv', index_col=0)
images_path = root_images/'FeralCatsAkl_maxim'

#***
database_maxim = WildlifeDataset(
    metadata=metadata_maxim,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'train'),
)

#***
query_maxim = WildlifeDataset(
    metadata=metadata_maxim,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'test'),
)

In [None]:
matcher = CosineSimilarity()

#***
similarity_maxim = matcher(query=extractor_maxim(query_maxim), 
                     database=extractor_maxim(database_maxim))

In [38]:
classifier = KnnClassifier(k=1)
predictions_maxim = classifier(similarity_maxim['default'], labels=database_maxim.labels_string)

In [40]:
#***
acc_maxim = sum(predictions_maxim == query_maxim.labels_string) / len(predictions_maxim)

#***
print('FeralCatsAkl_maxim', acc_maxim)
results['FeralCatsAkl_maxim'] = acc_maxim

FeralCatsAkl_maxim 0.9658792650918635


## hidiff Inference

In [41]:
model_hidiff = torch.load(models_root/'md_hidiff.pth', map_location='cpu')

  model_hidiff = torch.load(models_root/'md_hidiff.pth', map_location='cpu')


In [42]:
extractor_hidiff = DeepFeatures(model_hidiff, device='cpu')

In [43]:
# prepare dataset

#***
metadata_hidiff = pd.read_csv(root_metadata/'FeralCatsAkl_HIDiff/metadata.csv', index_col=0)
images_path = root_images/'FeralCatsAkl_HIDiff'

#***
database_hidiff = WildlifeDataset(
    metadata=metadata_hidiff,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'train'),
)

#***
query_hidiff = WildlifeDataset(
    metadata=metadata_hidiff,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'test'),
)

In [None]:
matcher = CosineSimilarity()

#***
similarity_hidiff = matcher(query=extractor_hidiff(query_hidiff), 
                     database=extractor_hidiff(database_hidiff))

In [45]:
classifier = KnnClassifier(k=1)
predictions_hidiff = classifier(similarity_hidiff['default'], labels=database_hidiff.labels_string)

In [46]:
#***
acc_hidiff = sum(predictions_hidiff == query_hidiff.labels_string) / len(predictions_hidiff)

#***
print('FeralCatsAkl_hidiff', acc_hidiff)
results['FeralCatsAkl_hidiff'] = acc_hidiff

FeralCatsAkl_hidiff 0.9763779527559056


## srmnet Inference

In [47]:
model_srmnet = torch.load(models_root/'md_srmnet.pth', map_location='cpu')

  model_srmnet = torch.load(models_root/'md_srmnet.pth', map_location='cpu')


In [48]:
extractor_srmnet = DeepFeatures(model_srmnet, device='cpu')

In [49]:
# prepare dataset

#***
metadata_srmnet = pd.read_csv(root_metadata/'FeralCatsAkl_SRMNet/metadata.csv', index_col=0)
images_path = root_images/'FeralCatsAkl_SRMNet'

#***
database_srmnet = WildlifeDataset(
    metadata=metadata_srmnet,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'train'),
)

#***
query_srmnet = WildlifeDataset(
    metadata=metadata_srmnet,
    root=images_path,
    transform=transform,
    split=SplitMetadata('split', 'test'),
)

In [None]:
matcher = CosineSimilarity()

#***
similarity_srmnet = matcher(query=extractor_srmnet(query_srmnet), 
                     database=extractor_srmnet(database_srmnet))

In [51]:
classifier = KnnClassifier(k=1)
predictions_srmnet = classifier(similarity_srmnet['default'], labels=database_srmnet.labels_string)

In [52]:
#***
acc_srmnet = sum(predictions_srmnet == query_srmnet.labels_string) / len(predictions_srmnet)

#***
print('FeralCatsAkl_srmnet', acc_srmnet)
results['FeralCatsAkl_srmnet'] = acc_srmnet

FeralCatsAkl_srmnet 0.9553805774278216


## Results

In [60]:
pd.Series(results).to_csv('results/FCA_MegaDescriptor-S-224.csv')

## Reference cells

In [None]:
# model = create_model("hf-hub:BVRA/MegaDescriptor-S-224", pretrained=True)
# model = timm.create_model('swin_small_patch4_window7_224', num_classes=5, pretrained=True)
# extractor = DeepFeatures(model, device='cuda')

# root_images = '../data/images/size-256'
# root_metadata = '../data/metadata/datasets'

In [None]:
results = {}
for name in datasets:
    metadata = pd.read_csv(f'{root_metadata}/{name}/metadata.csv', index_col=0)

    transform = T.Compose([
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])

    database = WildlifeDataset(
        metadata=metadata,
        root=f'{root_images}/{name}/',
        transform=transform,
        split=SplitMetadata('split', 'train'),
    )

    query = WildlifeDataset(
        metadata=metadata,
        root=f'{root_images}/{name}/',
        transform=transform,
        split=SplitMetadata('split', 'test'),
    )

    matcher = CosineSimilarity()
    similarity = matcher(query=extractor(query), database=extractor(database))
    preds = KnnClassifier(k=1, database_labels=database.labels_string)(similarity['cosine'])
    
    acc = sum(preds == query.labels_string) / len(preds)
    print(name, acc)
    results[name] = acc


pd.Series(results).to_csv('results/MegaDescriptor-S-224.csv')