# MegaDescriptor-B-224
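- Run inference with [MegaDescriptor-B-224](https://huggingface.co/BVRA/MegaDescriptor-B-224) on each dataset, report the 1-nearest-neighbour identification accuracy, and save the results to a CSV file.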

In [None]:
import os

import pandas as pd
import torch
from timm import create_model
from torchvision import transforms as T

from wildlife_tools.data import WildlifeDataset
from wildlife_tools.features import DeepFeatures
from wildlife_tools.inference import KnnClassifier
from wildlife_tools.similarity import CosineSimilarity


# Names of the datasets to evaluate. Each name is used as a sub-directory
# under both `root_metadata` and `root_images` by the evaluation loop.
datasets = [
    'BirdIndividualID',
    'SealID',
    'FriesianCattle2015',
    'ATRW',
    'NDD20',
    'SMALST',
    'SeaTurtleIDHeads',
    'AAUZebraFish',
    'CZoo',
    'CTai',
    'Giraffes',
    'HyenaID2022',
    'MacaqueFaces',
    'OpenCows2020',
    'StripeSpotter',
    'AerialCattle2017',
    'GiraffeZebraID',
    'IPanda50',
    'WhaleSharkID',
    'FriesianCattle2017',
    'Cows2021',
    'LeopardID2022',
    'NOAARightWhale',
    'HappyWhale',
    'HumpbackWhaleID',
    'LionData',
    'NyalaData',
    'ZindiTurtleRecall',
    'BelugaID',
    ]

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pretrained MegaDescriptor-B-224 weights are fetched from the Hugging Face hub.
model = create_model("hf-hub:BVRA/MegaDescriptor-B-224", pretrained=True)
extractor = DeepFeatures(model, device=device)

# Root directories holding the images and the per-dataset `metadata.csv` files.
root_images = '../data/images/size-256'
root_metadata = '../data/metadata/datasets'

In [None]:
# Evaluate the extractor on every dataset: features of the "test" split
# (query) are matched against features of the "train" split (database) and
# each query image receives the label of its most similar database image.

# The preprocessing pipeline does not depend on the dataset, so build it once
# instead of on every loop iteration.
# NOTE(review): mean/std look like the usual ImageNet statistics - confirm
# they match what the model was trained with.
transform = T.Compose([
    T.Resize(size=(224, 224)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

results = {}
for name in datasets:
    metadata = pd.read_csv(f'{root_metadata}/{name}/metadata.csv', index_col=0)
    image_root = f'{root_images}/{name}/'

    # Rows with split == "train" form the reference database.
    database = WildlifeDataset(
        metadata=metadata.query('split == "train"'),
        root=image_root,
        transform=transform,
    )

    # Rows with split == "test" are the images to be identified.
    query = WildlifeDataset(
        metadata=metadata.query('split == "test"'),
        root=image_root,
        transform=transform,
    )

    matcher = CosineSimilarity()
    similarity = matcher(query=extractor(query), database=extractor(database))
    preds = KnnClassifier(k=1, database_labels=database.labels_string)(similarity['cosine'])

    # Fraction of query images whose predicted label equals the true label.
    acc = sum(preds == query.labels_string) / len(preds)
    print(name, acc)
    results[name] = acc


# Create the output directory first: without it `to_csv` raises OSError and
# the accuracies computed above would be lost after the whole run.
output_path = 'results/MegaDescriptor-B-224.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
pd.Series(results).to_csv(output_path)