## Evaluate models based on nearest-neighbor

In [16]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F

import matplotlib.pyplot as plt

from sklearn.neighbors import NearestNeighbors

import seaborn as sns
sns.set_theme(color_codes=True)
import os
import sys

import transformers
from transformers import AutoModelForImageClassification, AutoConfig, AutoFeatureExtractor
from transformers.utils import logging
from transformers import DefaultDataCollator
import scipy.spatial.distance as distance
from transformers import TrainerCallback

logging.set_verbosity(transformers.logging.ERROR) 
logging.disable_progress_bar() 

p = os.path.abspath('../')
sys.path.insert(1, p)

from functools import partial 
from torchtext.vocab import build_vocab_from_iterator
from transformers import DefaultDataCollator

from torchtext.data.utils import get_tokenizer
from torch.utils.data import Dataset

from datasets import Dataset

from torchvision.io import read_image
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

import evaluate
from src.utils.utils import *

import math
import copy

from collections import defaultdict

import random
import torchvision
import torchvision.transforms as transforms

from datasets import Image, Features, Value

from src.wordnet_ontology.wordnet_ontology import WordnetOntology
from src.contrastive_transformers.collators import ImageCollator

import os
from datasets import load_dataset 

seed=7631
n_excluded_classes = int(556 * 0.05)
N_EXAMPLES = 32
batch_size = 16

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

%load_ext autoreload
%autoreload 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Preparing datasets

In [2]:
mapping_filename = './data/external/imagenet/LOC_synset_mapping.txt'
wn = WordnetOntology(mapping_filename)

vocab = torch.load('./models/vocab.pt')

sketch = load_dataset("imagenet_sketch", split='train', cache_dir='./cache/')
sketch = sketch.map(lambda x: {
    'label': vocab[wn.hypernym(wn.class_for_index[x['label']])],
})

Found cached dataset imagenet_sketch (/mnt/HDD/kevinds/sketch/./cache/imagenet_sketch/default/0.0.0/9bbda26372327ae1daa792112c8bbd2545a91b9f397ea6f285576add0a70ab6e)
Loading cached processed dataset at /mnt/HDD/kevinds/sketch/./cache/imagenet_sketch/default/0.0.0/9bbda26372327ae1daa792112c8bbd2545a91b9f397ea6f285576add0a70ab6e/cache-f827821a6a95155f.arrow


In [3]:
imagenet_classes_folder = './data/external/imagenet/ILSVRC/Data/CLS-LOC/train'

image_labels = [] 
image_paths = []

N_IMAGENET_EXAMPLES = N_EXAMPLES
imagenet_classes = sorted(os.listdir(imagenet_classes_folder))
for img_class in imagenet_classes:
    all_imgs = os.listdir(f"{imagenet_classes_folder}/{img_class}/")
    img_names = [random.choice(all_imgs) for _ in range(0, N_IMAGENET_EXAMPLES)]
                              
    image_paths.extend([f"{imagenet_classes_folder}/{img_class}/{name}" for name in img_names])
    image_labels.extend([img_class] * len(img_names))

In [4]:
_classes = list(set(sketch['label']))
excluded_classes = [random.choice(_classes) for i in range(n_excluded_classes)]
dt = train_test_split(sketch, excluded_labels=excluded_classes)
train, test = dt['train'], dt['test']

## Evaluate the model

In [17]:
torch.hub.set_dir('./cache')
model_path = f'./models/contrastive-pretraining-{seed}/last-checkpoint'
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = AutoModel.from_pretrained(model_path)
model.eval()
model.to(device)

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0): ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
      

In [18]:
test_transforms = Compose([
    transforms.Resize((feature_extractor.size, feature_extractor.size)), 
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
])

def get_pixel_values(example):
    return test_transforms(example.convert("RGB"))

image_classes = pd.concat([
    pd.DataFrame({
        'image': [img['path'] for img in train.cast_column('image', Image(decode=False))['image']], 
        'label': train['label']
    }),
    pd.DataFrame({
        'image': image_paths, 
        'label': [vocab[wn.hypernym(lb)] for lb in image_labels]
})], axis=0)

cl = image_classes.groupby('label').sample(n=N_EXAMPLES, replace=True)
#cl = image_classes#.drop_duplicates(['label']).reset_index(drop=True)
classes_dataset = Dataset.from_pandas(cl.reset_index(drop=True), features=Features({'image': Image(decode=True), 'label': Value('int64')}))

data_collator = DefaultDataCollator()

accuracyk = evaluate.load("KevinSpaghetti/accuracyk")

seen, unseen = get_seen_unseen_split(train, test, label_col='label')
results= {}

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/13 [00:00<?, ?ba/s]

In [19]:
def encode_classes(model, classes, num_classes, label_col='label'):
    class_encodings = torch.zeros((num_classes, model.config.hidden_size), device=device)
    labels, counts = np.unique(classes[label_col], return_counts=True)
    with torch.no_grad():
        for idx in tqdm(range(0, len(classes), batch_size), total=len(classes) // batch_size):
            batch_start = idx
            batch_end = min(len(classes), batch_start+batch_size)
            imgs = torch.stack([
                test_transforms(cl.convert("RGB")) for cl in classes[batch_start: batch_end]['image']
            ])
            model_output = model(pixel_values=imgs.to(device, non_blocking=True), output_hidden_states=True)
            embedding = model_output.pooler_output
            class_encodings[classes[batch_start:batch_end]['label']] += embedding
    
    class_encodings = class_encodings.to('cpu', non_blocking=True)
    class_encodings /= counts[:, None]
    return class_encodings

In [20]:
def encode_samples(model, samples):
    with torch.no_grad():
        sample_encodings = torch.zeros((len(samples), model.config.hidden_size), device=device)
        for idx in tqdm(range(0, len(samples), batch_size), total=len(samples) // batch_size):
            batch_start = idx
            batch_end = min(len(samples), batch_start+batch_size)
            imgs = torch.stack([
                test_transforms(sample.convert("RGB")) for sample in samples[batch_start: batch_end]['image']
            ])
            model_output = model(pixel_values=imgs.to(device, non_blocking=True), output_hidden_states=True)
            model_prediction = model_output.pooler_output
            sample_encodings[batch_start:batch_end] = model_prediction
    return sample_encodings

In [21]:
class_encodings = encode_classes(model, classes_dataset, len(vocab))

100%|███████████████████████████████████████████| 4448/4448 [07:00<00:00, 10.57it/s]


In [22]:
embedding_index = NearestNeighbors(n_neighbors=5, 
                                   metric='cosine', 
                                   algorithm='brute', 
                                   n_jobs=-1).fit(class_encodings.to('cpu'))

encoded_samples = encode_samples(model, test)
_, predictions = embedding_index.kneighbors(encoded_samples.cpu())
results['complete'] = {
    'top1': accuracyk.compute(predictions=predictions[:, 0][:, None], references=test['label'])['accuracy'],
    'top5': accuracyk.compute(predictions=predictions, references=test['label'])['accuracy']
}

encoded_samples = encode_samples(model, seen)
_, predictions = embedding_index.kneighbors(encoded_samples.cpu())
results['seen'] = {
    'top1': accuracyk.compute(predictions=predictions[:, 0][:, None], references=seen['label'])['accuracy'],
    'top5': accuracyk.compute(predictions=predictions, references=seen['label'])['accuracy']
}

encoded_samples = encode_samples(model, unseen)
_, predictions = embedding_index.kneighbors(encoded_samples.cpu())
results['unseen'] = {
    'top1': accuracyk.compute(predictions=predictions[:, 0][:, None], references=unseen['label'])['accuracy'],
    'top5': accuracyk.compute(predictions=predictions, references=unseen['label'])['accuracy']
}

3065it [01:59, 25.55it/s]                                                           
2545it [01:39, 25.58it/s]                                                           
100%|█████████████████████████████████████████████| 520/520 [00:21<00:00, 24.66it/s]


In [23]:
results

{'complete': {'top1': 0.7157542628701966, 'top5': 0.9423186750428326},
 'seen': {'top1': 0.723887196619829, 'top5': 0.9443843961874816},
 'unseen': {'top1': 0.6759615384615385, 'top5': 0.9322115384615385}}