In [None]:
import torch
import os

import numpy as np

from imagebind import data
from imagebind.models import imagebind_model
from src.imagenet_labels import lab_dict
from tqdm.notebook import tqdm
from imagebind.models.imagebind_model import ModalityType

In [None]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Build the ImageBind "huge" model with pretrained weights, switch it to eval mode,
# and move it to the selected device. The final call is left as the cell's last
# expression, so the notebook displays its return value.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

In [None]:
# One text prompt per entry of ../data/imagenet: the entry name is looked up in
# lab_dict, underscores become spaces, and the result is prefixed with "a ".
# The position of a prompt in this list is used as the class index downstream.
text_list = [
    f"a {lab_dict[cls_dir].replace('_', ' ')}"
    for cls_dir in os.listdir('../data/imagenet')
]

In [None]:
def get_acc(gt, preds=None):
    """Return top-1 accuracy of ``preds`` against the ground-truth labels ``gt``.

    Args:
        gt: Ground-truth class indices, one per row of ``preds``. May be a tensor
            (on any device) or anything ``torch.as_tensor`` accepts, e.g. a list.
        preds: 2-D tensor of per-class scores; the predicted class of each row is
            its ``argmax`` along dimension 1. Despite the ``None`` default (kept
            for signature compatibility) this argument is required.

    Returns:
        A 0-d NumPy array holding the fraction of rows whose argmax equals ``gt``.

    Raises:
        ValueError: if ``preds`` is ``None``.
    """
    if preds is None:
        # The original fell through to ``None.argmax`` here; fail with a clear message.
        raise ValueError("get_acc requires `preds`; got None")

    # Put the labels on the same device as the predictions so the element-wise
    # comparison also works when predictions are on the GPU and labels on the CPU.
    gt = torch.as_tensor(gt, device=preds.device)
    n_correct = (preds.argmax(1) == gt).sum()
    return (n_correct / len(preds)).cpu().numpy()


def compute(model, text, images, labels, device):
    """Embed ``text`` and ``images`` with ``model`` and return the top-1 accuracy.

    Args:
        model: Callable taking a dict keyed by ``ModalityType`` and returning a
            dict of embeddings under the same keys.
        text: Text prompts, passed to ``data.load_and_transform_text``.
        images: Image inputs, passed to ``data.load_and_transform_vision_data``
            (presumably file paths; verify against the caller).
        labels: Ground-truth index into ``text`` for each image.
        device: Device handed to both data loaders.

    Returns:
        Whatever ``get_acc(labels, probs)`` returns, where ``probs`` is the
        softmax over the last dimension of ``vision_emb @ text_emb.T``.
    """
    text_batch = data.load_and_transform_text(text, device)
    vision_batch = data.load_and_transform_vision_data(images, device)

    # Inference only: no gradients are needed for evaluation.
    with torch.no_grad():
        embeddings = model({
            ModalityType.TEXT: text_batch,
            ModalityType.VISION: vision_batch,
        })

    vision_emb = embeddings[ModalityType.VISION]
    text_emb = embeddings[ModalityType.TEXT]
    # Similarity of every image to every prompt, normalised across prompts.
    probs = torch.softmax(vision_emb @ text_emb.T, dim=-1)
    return get_acc(labels, probs)
    
def get_image_paths(root):
    """Map each class name to the paths of the files in its sub-directory of ``root``.

    Every entry of ``root`` is treated as a class directory. Its name is
    translated through ``lab_dict`` (underscores replaced by spaces) to form the
    dictionary key, and the value is the list of ``root/<entry>/<file>`` paths.

    Args:
        root: Directory containing one sub-directory per class.

    Returns:
        dict mapping class name (str) to a list of file paths (str).
    """
    paths_by_class = {}
    for class_dir in tqdm(os.listdir(root)):
        class_root = os.path.join(root, class_dir)
        files = [os.path.join(class_root, fname) for fname in os.listdir(class_root)]
        paths_by_class[lab_dict[class_dir].replace('_', ' ')] = files
    return paths_by_class

def get_test_acc(image_paths, device):
    """Return the mean per-class top-1 accuracy over all classes in ``text_list``.

    Relies on the notebook-level ``model`` and ``text_list``. For the i-th prompt
    the images of the matching class are scored against *all* prompts via
    ``compute`` and counted correct when prompt ``i`` is the argmax.

    Args:
        image_paths: dict mapping class name to a list of image paths, as built
            by ``get_image_paths``.
        device: Device used for data loading and for the label tensor.

    Returns:
        ``np.mean`` of the per-class accuracies (each class weighs equally).
    """
    eval_acc = []
    for i in tqdm(range(len(text_list))):
        # Prompts look like "a <class name>"; dropping the 2-character "a " prefix
        # recovers the key used in ``image_paths``.
        class_images = image_paths[text_list[i][2:]]
        # One label per image actually present (instead of a hard-coded 50), created
        # on ``device`` so it can be compared with the model outputs directly.
        labels = torch.full((len(class_images),), i, dtype=torch.long, device=device)
        eval_acc.append(compute(model, text_list, class_images, labels, device))

    return np.mean(eval_acc)

In [None]:
# Root of the clean images; passed to get_image_paths, which expects one
# sub-directory per class.
path_to_imagenet = '../data/imagenet'
# Root of the corrupted images; the cells below join it with
# <corruption name>/<severity> before calling get_image_paths.
path_to_imagenet_c = '../data/imagenet-c'

In [None]:
# Baseline: accuracy on the uncorrupted images under path_to_imagenet.
image_paths = get_image_paths(path_to_imagenet)
clean_acc = get_test_acc(image_paths, device)

In [None]:
# Show the clean-image accuracy.
clean_acc

In [None]:
# Accuracy on the 'gaussian_noise' corruption: one entry per severity level 1-5.
gaussian_noise_acc = []
for sev in tqdm(range(1, 6)):
    image_paths = get_image_paths(
        os.path.join(path_to_imagenet_c, 'gaussian_noise', str(sev))
    )
    gaussian_noise_acc += [get_test_acc(image_paths, device)]

In [None]:
# Show the per-severity accuracies for gaussian noise.
gaussian_noise_acc

In [None]:
# Accuracy on the 'impulse_noise' corruption: one entry per severity level 1-5.
impulse_noise_acc = []
for sev in tqdm(range(1, 6)):
    image_paths = get_image_paths(
        os.path.join(path_to_imagenet_c, 'impulse_noise', str(sev))
    )
    impulse_noise_acc += [get_test_acc(image_paths, device)]

In [None]:
# Show the per-severity accuracies for impulse noise.
impulse_noise_acc

In [None]:
# Accuracy on the 'shot_noise' corruption: one entry per severity level 1-5.
shot_noise_acc = []
for sev in tqdm(range(1, 6)):
    image_paths = get_image_paths(
        os.path.join(path_to_imagenet_c, 'shot_noise', str(sev))
    )
    shot_noise_acc += [get_test_acc(image_paths, device)]

In [None]:
# Show the per-severity accuracies for shot noise.
shot_noise_acc

In [None]:
# Accuracy on the 'speckle_noise' corruption: one entry per severity level 1-5.
speckle_noise_acc = []
for sev in tqdm(range(1, 6)):
    image_paths = get_image_paths(
        os.path.join(path_to_imagenet_c, 'speckle_noise', str(sev))
    )
    speckle_noise_acc += [get_test_acc(image_paths, device)]

In [None]:
# Show the per-severity accuracies for speckle noise.
speckle_noise_acc

In [None]:
# Gather every result under one dict: four per-severity lists plus the clean baseline.
res = dict(
    gaussian_noise_acc=gaussian_noise_acc,
    impulse_noise_acc=impulse_noise_acc,
    shot_noise_acc=shot_noise_acc,
    speckle_noise_acc=speckle_noise_acc,
    clean_acc=clean_acc,
)

In [None]:
# Show the collected results.
res