In [4]:
from argparse import ArgumentParser
import pandas as pd
from urllib.request import urlopen
from PIL import Image
import timm
import torch
import numpy as np
from transformers import AutoModel, AutoImageProcessor

In [5]:
def load_class_mapping(class_list_file):
    """Read a class list file and index its lines by position.

    Args:
        class_list_file: Path of a text file with one entry per line.

    Returns:
        dict[int, str]: ``{zero_based_line_number: line_with_whitespace_stripped}``.
        Blank lines are kept (as empty strings) so indices stay aligned with
        line positions.
    """
    index_to_name = {}
    with open(class_list_file) as handle:
        for position, raw_line in enumerate(handle):
            index_to_name[position] = raw_line.strip()
    return index_to_name


def load_species_mapping(species_map_file):
    """Build a ``species_id -> species`` lookup from a ';'-separated file.

    Args:
        species_map_file: Path of a delimited text file with at least the
            columns ``species_id`` and ``species``.

    Returns:
        dict[str, object]: species id (kept as a string, so leading zeros
        survive) mapped to the ``species`` column value.
    """
    # quoting=1 is csv.QUOTE_ALL; species_id is forced to str so ids are not
    # parsed as integers.
    return (
        pd.read_csv(species_map_file, sep=';', quoting=1, dtype={'species_id': str})
        .set_index('species_id')['species']
        .to_dict()
    )

In [13]:
# Hugging Face DINOv2 backbone + its image processor. The timm classifier
# defined in the next cell does not use these. They are bound to their own
# names so that (re-)running this cell can never clobber the classifier held
# in `model`.
model_name = "facebook/dinov2-base"
hf_model = AutoModel.from_pretrained(model_name, output_attentions=True)
processor = AutoImageProcessor.from_pretrained(model_name)

In [6]:
# Class index -> species id, and species id -> species name.
cid_to_spid = load_class_mapping('models/pretrained_models/class_mapping.txt')
spid_to_sp = load_species_mapping("models/pretrained_models/species_id_to_name.txt")

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use forward slashes: a backslash separator only resolves on Windows, and
# "\m" is an invalid string escape (SyntaxWarning on Python 3.12+).
CHECKPOINT_PATH = (
    "models/pretrained_models/"
    "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all/"
    "model_best.pth.tar"
)

# pretrained=False: weights are taken from CHECKPOINT_PATH instead of the
# default pretrained weights. The head is sized to the number of entries in
# class_mapping.txt so its output indices line up with cid_to_spid.
model = timm.create_model('vit_base_patch14_reg4_dinov2.lvd142m',
                          pretrained=False,
                          num_classes=len(cid_to_spid),
                          checkpoint_path=CHECKPOINT_PATH)
model = model.to(device)
model = model.eval()

In [7]:
# Derive the evaluation-time preprocessing from the classifier's own data
# config (as resolved by timm), so images are prepared the way the model expects.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(is_training=False, **data_config)

In [11]:

# Classify one image and print the top-5 species with their probabilities (%).
IMAGE_PATH = "models/pretrained_models/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg"

# Image.open raises on failure (it never returns None), so no None-check is
# needed. The context manager closes the file. convert("RGB") guarantees three
# channels even for grayscale, palette or RGBA inputs.
with Image.open(IMAGE_PATH) as raw_image:
    rgb_image = raw_image.convert("RGB")

img = transforms(rgb_image).unsqueeze(0)  # unsqueeze single image into batch of 1
img = img.to(device)
with torch.no_grad():
    output = model(img)

# Softmax over the class dimension, scaled to percent. k is clamped in case
# the head has fewer than 5 classes.
k = min(5, output.shape[1])
top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=k)
top5_probabilities = top5_probabilities.cpu().detach().numpy()
top5_class_indices = top5_class_indices.cpu().detach().numpy()

for proba, cid in zip(top5_probabilities[0], top5_class_indices[0]):
    species_id = cid_to_spid[int(cid)]
    species = spid_to_sp[species_id]
    print(species_id, species, proba)

1361687 Orchis simia Lam. 43.589096
1361678 Orchis italica Poir. 12.680444
1628935 Orchis × bergonii Nanteuil 8.529679
1628933 Orchis × angusticruris Franch. 2.163737
1394279 Orchis militaris L. 1.8919474
