In [33]:
from transformers import AutoImageProcessor, Dinov2ForImageClassification
import torch
import datasets
from tqdm import tqdm

# Single source of truth for the checkpoint, so the image processor and the
# classification model can never be loaded from different checkpoints.
MODEL_NAME = "facebook/dinov2-small-imagenet1k-1-layer"

image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
model = Dinov2ForImageClassification.from_pretrained(MODEL_NAME)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# This notebook only evaluates; make eval mode explicit so the intent is visible
# (the evaluation helpers below only wrap forward passes in torch.no_grad()).
model = model.to(DEVICE).eval()

In [47]:
from itertools import islice

def batched(iterable, n):
    """Split `iterable` into consecutive lists of length `n` (the last may be shorter).

    Similar to `itertools.batched` (Python 3.12+), but yields lists instead of tuples.

    Args:
        iterable: any iterable; it is consumed lazily, one batch at a time.
        n: batch size; must be >= 1.

    Returns:
        A generator of lists.

    Raises:
        ValueError: if `n` < 1. This is raised eagerly, at call time. Previously
            `n == 0` silently produced no batches at all.
    """
    if n < 1:
        raise ValueError(f"n must be at least one, got {n}")
    iterator = iter(iterable)

    def _generate():
        # Pull up to n items at a time until the iterator is exhausted.
        while batch := list(islice(iterator, n)):
            yield batch

    return _generate()

def inference(image):
    """Return the model's argmax class index for each input image.

    Args:
        image: a single image or a list of images, in any form accepted by
            `image_processor`.

    Returns:
        A CPU tensor of predicted class indices (argmax over the last dimension
        of the logits), presumably one entry per input image.
    """
    pixel_batch = image_processor(image, return_tensors="pt")
    pixel_batch = pixel_batch.to(DEVICE)
    # Gradients are never needed for evaluation.
    with torch.no_grad():
        outputs = model(**pixel_batch)
    scores = outputs.logits.cpu()
    return scores.argmax(dim=-1)

def classify_images(dataset:datasets.arrow_dataset.Dataset, batch_size=32):
    """Compute the top-1 accuracy of `model` on a labelled image dataset.

    Args:
        dataset: a `datasets` Dataset whose examples have an 'image' field (fed to
            `image_processor`) and an integer 'label' field. NOTE(review): labels are
            compared directly with the model's predicted class indices, so the
            dataset's label ids must follow the same class ordering as the model's
            classification head; confirm for each dataset.
        batch_size: number of examples per forward pass; must be >= 1.

    Returns:
        Fraction of examples whose predicted index equals the label, in [0, 1].

    Raises:
        ValueError: if `batch_size` < 1 or `dataset` is empty (accuracy undefined).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    num_examples = len(dataset)
    if num_examples == 0:
        raise ValueError("cannot compute accuracy of an empty dataset")

    # Ceil division, so the progress bar can show an ETA.
    num_batches = (num_examples + batch_size - 1) // batch_size
    correct = 0
    for batch in tqdm(batched(dataset, batch_size), total=num_batches):
        images = [example['image'] for example in batch]
        labels = torch.tensor([example['label'] for example in batch])
        # Reuse the single-call helper rather than duplicating preprocessing and forward logic.
        predicted_labels = inference(images)
        correct += (predicted_labels == labels).sum().item()
    return correct / num_examples

## Classification accuracy on ImageNet-tiny

In [49]:
# Load the local ImageNet-tiny image folder; the evaluation cell uses its 'validation' split.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
data_path = '/home/gyc/datasets/imagenet-tiny'
dataset = datasets.load_dataset('imagefolder', data_dir=data_path)
# Fail fast with a clear message instead of a bare KeyError in the evaluation cell.
assert 'validation' in dataset, f"no 'validation' split under {data_path}; found {list(dataset)}"

Resolving data files:   0%|          | 0/1999 [00:00<?, ?it/s]

In [37]:
# classification accuracy on the validation split with batch size 64
print('accuracy:', classify_images(dataset['validation'], batch_size=64))

32it [00:34,  1.08s/it]

accuracy: 0.8178178178178178





## Classification accuracy on ImageNet-C

In [52]:
# Load ImageNet-C, blur / defocus_blur corruption, sub-folder "1" (last path component).
# NOTE(review): absolute, machine-specific path; adjust for your environment.
c_data_path = '/home/gyc/datasets/imagenet-c/blur/defocus_blur/1'
c_dataset = datasets.load_dataset('imagefolder', data_dir=c_data_path)
# The evaluation cell reads c_dataset['validation']; NOTE(review): confirm which split
# name `datasets` infers for this folder. This check fails fast if it is not 'validation'.
assert 'validation' in c_dataset, f"no 'validation' split under {c_data_path}; found {list(c_dataset)}"

Resolving data files:   0%|          | 0/50000 [00:00<?, ?it/s]

In [60]:
# Top-1 accuracy on the corrupted validation images, 64 images per forward pass.
c_accuracy = classify_images(c_dataset['validation'], batch_size=64)
print('accuracy:', c_accuracy)

782it [11:01,  1.18it/s]

accuracy: 0.69592



