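# ImageNet Zero-Shot Evaluation Script
Measures zero-shot top-k accuracy on the ImageNet validation set, using translated (Arabic) class names as the text prompts. Modified from [the evaluation script by OpenAI](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Prompt_Engineering_for_ImageNet.ipynb).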

In [None]:
!pip install -q -U jax jaxlib
!pip install -q pandas
!pip install -q ipywidgets
!pip install -q -U flax
!pip install -q sentence-transformers
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q transformers
!pip install -q torch torchvision

In [None]:
# Imports and process-level setup shared by every cell below.
import os 
import sys
import json

import numpy as np
import pandas as pd

# Disable the Hugging Face `tokenizers` library's internal parallelism. Presumably this
# avoids its fork warning once the DataLoader below (num_workers=2) forks worker processes.
os.environ['TOKENIZERS_PARALLELISM'] = "false"

import transformers
from transformers import AutoTokenizer
from transformers import FlaxVisionTextDualEncoderModel, VisionTextDualEncoderModel

import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
# Notebook progress-bar widget (renders via ipywidgets, installed in the first cell).
from tqdm.notebook import tqdm

# NOTE(review): adds the working directory to the import path. Nothing visible in this
# notebook imports a local module, so this may be unnecessary — confirm.
sys.path.append('.')


# Choosing the model to evaluate

In [None]:
# Which encoder pair to evaluate: 'mClip' (multilingual CLIP text encoder + OpenAI CLIP
# image tower) or 'Arabic_clip' (Flax vision-text dual encoder). Uncomment to switch.
MODEL_TYPE = "mClip"
# MODEL_TYPE = "Arabic_clip"

# Loading the model

In [None]:
# Defines two callables used by the rest of the notebook:
#   language_model(list[str]) -> numpy array of text embeddings (one row per string)
#   image_model(torch batch of preprocessed images) -> numpy array of image embeddings
if MODEL_TYPE == 'mClip':
    from sentence_transformers import SentenceTransformer
    from transformers import CLIPVisionModelWithProjection

    # The multilingual CLIP model can only encode text. Image embeddings come from the
    # 'openai/clip-vit-base-patch32' vision tower with its projection head.
    se_language_model = SentenceTransformer('clip-ViT-B-32-multilingual-v1')
    se_image_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")

    def language_model(queries):
        """Encode a list of strings into a numpy array of text embeddings."""
        return se_language_model.encode(queries, convert_to_tensor=True, show_progress_bar=False).cpu().detach().numpy()

    def image_model(images):
        """Encode a batch of preprocessed image tensors into projected CLIP image embeddings."""
        # Inference only: without no_grad, autograd records the full forward pass of
        # every (1024-image) batch, wasting a large amount of memory.
        with torch.no_grad():
            # Element [0] of the model output is the projected image embedding.
            return se_image_model(images)[0].cpu().numpy()

elif MODEL_TYPE == 'Arabic_clip':
    import jax
    from jax import numpy as jnp

    TOKENIZER_NAME = "asafaya/bert-large-arabic"
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir=None, use_fast=True)
    # NOTE(review): "xx" is a placeholder checkpoint path — replace with the real Arabic CLIP checkpoint.
    model = FlaxVisionTextDualEncoderModel.from_pretrained("xx", logit_scale_init_value=1)
    # Writes a local copy of the loaded checkpoint to ./Arabic_clip.
    model.save_pretrained("Arabic_clip")

    def tokenize(texts):
        """Tokenize to fixed-width (len(texts), 96) numpy arrays: (input_ids, attention_mask)."""
        # truncation=True: padding="max_length" alone leaves longer texts unclipped, so the
        # batch could no longer be stacked into one fixed-width array.
        inputs = tokenizer(texts, max_length=96, padding="max_length", truncation=True, return_tensors="np")
        return inputs['input_ids'], inputs['attention_mask']

    def language_model(queries):
        """Encode a list of strings into a numpy array of projected text features."""
        return np.asarray(model.get_text_features(*tokenize(queries)))

    def image_model(images):
        """Encode a batch of preprocessed image tensors into projected image features."""
        # (N, C, H, W) torch batch -> (N, H, W, C) numpy array, i.e. channels-last for the Flax model.
        return np.asarray(model.get_image_features(images.permute(0, 2, 3, 1).numpy()))

else:
    # Fail here rather than with a NameError on `language_model` several cells later.
    raise ValueError(f"Unsupported MODEL_TYPE {MODEL_TYPE!r}; expected 'mClip' or 'Arabic_clip'.")

# Preparing the translated ImageNet labels

In [None]:
# Translated ImageNet class names. `/resolve/` serves the raw file; `/blob/` would fetch the HTML page.
# !wget -N -q https://huggingface.co/datasets/LinaAlhuri/ArabicImageNet/resolve/main/ArabicImageNet.csv
# NOTE(review): 'xx.csv' is a placeholder — point it at the downloaded CSV (e.g. 'ArabicImageNet.csv').
classes_df = pd.read_csv('xx.csv')
imagenet_classes = list(classes_df['Arabic_Query_Short'])
imagenet_templates = ['{}']

# Row i must be the name of ImageNet label index i: column i of the zero-shot weight
# matrix is scored against target i, so a missing or extra row silently corrupts accuracy.
assert len(imagenet_classes) == 1000, f"expected 1000 class names, got {len(imagenet_classes)}"
assert not classes_df['Arabic_Query_Short'].isna().any(), "missing class names in 'Arabic_Query_Short'"

print(f"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates")

In [None]:
# Sanity check: the class names should be held in a plain Python list.
print(imagenet_classes.__class__)

# Setting up the validation set

In [None]:
# CLIP evaluation preprocessing, applied in order:
#   1. resize the shorter image side to 224 with bicubic interpolation;
#   2. center-crop to 224x224;
#   3. convert to a float tensor in [0, 1];
#   4. normalize with CLIP's per-channel mean/std.
val_preprocess = transforms.Compose(
    [
        transforms.Resize([224], interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

In [None]:
# Fetch the ILSVRC2012 validation archive and devkit into the working directory.
# `wget -N` re-downloads only when the remote file is newer than the local copy.
# NOTE(review): image-net.org may require an authenticated account for these URLs. If the
# downloads fail, fetch the archives manually and place them in './'.
print('Downloading Imagenet validation set...')
!wget -N -q --show-progress https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
print('Downloading Imagenet devkit...')
!wget -N -q --show-progress https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
print('Done.')

# torchvision's ImageNet dataset reads the archives downloaded above from `root`
# and applies `val_preprocess` to every image.
images = torchvision.datasets.ImageNet(root='./', split='val', transform=val_preprocess)

# In-order batches over the validation set; the final partial batch is kept, so every
# image is scored exactly once.
loader = torch.utils.data.DataLoader(
    dataset=images,
    batch_size=1024,
    num_workers=2,
    persistent_workers=True,
    shuffle=False,
    drop_last=False,
)

# Creating zero-shot classifier weights

In [None]:
def zeroshot_classifier(classnames, templates, encode=None, show_progress=True):
    """Build a zero-shot classifier matrix from class names and prompt templates.

    For every class, each template is filled with the class name. The prompts are
    embedded and L2-normalised, averaged, and the average is L2-normalised again
    (prompt ensembling).

    Args:
        classnames: iterable of class-name strings, in label-index order.
        templates: iterable of format strings, each with one ``{}`` placeholder.
        encode: callable mapping a list of strings to an array with one embedding
            row per string. Defaults to the notebook-level ``language_model``
            (looked up at call time).
        show_progress: wrap the per-class loop in a ``tqdm`` progress bar.

    Returns:
        ``np.ndarray`` of shape ``(dim, num_classes)``. Column ``i`` is the unit-norm
        embedding of the ``i``-th class name.

    Raises:
        ValueError: if ``classnames`` or ``templates`` is empty.
    """
    classnames = list(classnames)
    templates = list(templates)
    if not classnames:
        raise ValueError("zeroshot_classifier needs at least one class name")
    if not templates:
        # An empty template list would average zero embeddings and yield NaN columns.
        raise ValueError("zeroshot_classifier needs at least one prompt template")
    if encode is None:
        encode = language_model
    class_iter = tqdm(classnames) if show_progress else classnames

    zeroshot_weights = []
    for classname in class_iter:
        texts = [template.format(classname) for template in templates]
        class_embeddings = np.asarray(encode(texts))
        # Normalise each prompt embedding so every template contributes equally to the mean.
        class_embeddings = class_embeddings / np.linalg.norm(class_embeddings, axis=-1, keepdims=True)
        class_embedding = class_embeddings.mean(axis=0)
        class_embedding = class_embedding / np.linalg.norm(class_embedding)
        zeroshot_weights.append(class_embedding)
    return np.stack(zeroshot_weights, axis=1)

# One unit-norm text embedding per class, stacked as columns: shape (embedding_dim, num_classes).
zeroshot_weights = zeroshot_classifier(classnames=imagenet_classes, templates=imagenet_templates)

# Zero-shot prediction

In [None]:
def accuracy(output, target, topk=(1,)):
    """Count top-k hits for one batch of class scores.

    Args:
        output: array-like of shape ``(batch, num_classes)`` holding class scores.
        target: array-like of shape ``(batch,)`` holding integer class indices.
        topk: iterable of ``k`` values to evaluate.

    Returns:
        List of floats, one per entry of ``topk``: the number of samples whose
        target is among the ``k`` highest-scoring classes (a count, not a rate).
        A ``k`` larger than ``num_classes`` behaves like ``k == num_classes``.
    """
    output = torch.from_numpy(np.asarray(output))
    target = torch.from_numpy(np.asarray(target))
    topk = tuple(topk)
    # Tensor.topk raises when k exceeds the class dimension, so clamp the largest k.
    maxk = min(max(topk), output.shape[1])
    # Shape (maxk, batch): row r holds each sample's r-th best class index.
    pred = output.topk(maxk, dim=1, largest=True, sorted=True).indices.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() returns a Python scalar directly. float() on a 1-element ndarray (the old
    # approach) is deprecated since NumPy 1.25.
    return [float(correct[:k].sum().item()) for k in topk]

In [None]:
# Zero-shot evaluation: score every validation image against every class embedding and
# accumulate top-k hit counts, then report accuracies in percent.
top_ns = [1, 5, 10, 100]
acc_counters = [0. for _ in top_ns]
n = 0.

# Inference only: no_grad stops PyTorch from recording an autograd graph for every
# 1024-image batch (it has no effect on the JAX model path).
with torch.no_grad():
    # `batch_images` (not `images`) so the dataset global defined earlier is not overwritten.
    for batch_images, target in tqdm(loader):
        target = target.numpy()
        # predict: unit-norm image embeddings times unit-norm class columns = cosine similarities
        image_features = image_model(batch_images)
        image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
        logits = 100. * image_features @ zeroshot_weights

        # measure accuracy: accumulate per-k hit counts
        accs = accuracy(logits, target, topk=top_ns)
        for j in range(len(top_ns)):
            acc_counters[j] += accs[j]
        n += batch_images.shape[0]

tops = {f'top{k}': hits / n * 100 for k, hits in zip(top_ns, acc_counters)}
print(tops)