In [None]:
from google.colab import drive
drive.mount('/content/drive')
!pip install transformers torch torchvision
from transformers import CLIPProcessor, CLIPModel
import torch
import torchvision.transforms as transforms
import os
import urllib
import json
from PIL import Image
from copy import deepcopy
import numpy as np
from torch.optim import Adam
from torch.nn import BatchNorm1d
from tqdm import tqdm

# Names this cell needs that the import cell does not reliably provide:
# ``ast`` is used to parse the label file, and ``urllib.request`` must be
# imported explicitly because ``import urllib`` alone does not load it.
import ast
import urllib.request

# Use the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# init the processor and model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
model.eval()  # Evaluation mode

# dataset and labels paths
dataset_path = '/content/drive/MyDrive/imagenetv2-matched-frequency/imagenetv2-matched-frequency-format-val'
labels_path = '/content/drive/MyDrive/imagenet1000_clsidx_to_labels.txt'

if os.path.exists(dataset_path):
    print(f"Dataset folder {dataset_path} found")
else:
    print(f"Dataset folder {dataset_path} not found")

# Read the class-index -> label mapping. The file holds a Python literal
# parsed with ``ast.literal_eval``; presumably a dict keyed by integer class
# index (it is indexed with ``int(class_folder)`` later) -- verify.
with open(labels_path, 'r', encoding='utf-8') as f:
    clsidx_to_labels = ast.literal_eval(f.read())

# Download ImageNet labels (a JSON list used to build the CLIP prompts).
# The ``with`` block closes the HTTP response once it has been read.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
with urllib.request.urlopen(LABELS_URL) as response:
    imagenet_labels = json.loads(response.read().decode())

# Generate CLIP-compatible text labels
def generate_clip_labels(imagenet_labels, template="a photo of a {}"):
    """Turn class names into natural-language prompts for CLIP.

    Args:
        imagenet_labels: iterable of class-name strings.
        template: ``str.format`` template with one ``{}`` slot for the class
            name. The default reproduces the original fixed prompt; other
            templates allow prompt engineering or prompt ensembling.

    Returns:
        A list with one prompt per label, in input order.
    """
    # The label is passed as a format *argument*, so braces inside a label
    # are inserted verbatim rather than being interpreted.
    return [template.format(label) for label in imagenet_labels]

# Embed one prompt per ImageNet class once, up front. The resulting text
# embeddings serve as the fixed zero-shot classifier for every image.
clip_labels = generate_clip_labels(imagenet_labels)
_tokenized_prompts = processor(text=clip_labels, return_tensors="pt", padding=True)
text_inputs = _tokenized_prompts.to('cuda')

with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)

# Per-channel RGB mean / std used by CLIP's image preprocessing.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Random test-time augmentation pipeline: random resized 224px crop,
# horizontal flip, color jitter, then tensor conversion and CLIP normalization.
_augmentation_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(_CLIP_MEAN, _CLIP_STD),
]
augmentations = transforms.Compose(_augmentation_steps)

# apply augmentations
def apply_augmentations(image, num_augmentations, transform=None):
    """Return ``num_augmentations`` independently augmented copies of ``image``.

    Args:
        image: input passed unchanged to ``transform`` on every call (a PIL
            image for the default pipeline).
        num_augmentations: number of copies to produce.
        transform: callable mapping ``image`` to a tensor. Defaults to the
            module-level ``augmentations`` pipeline.

    Returns:
        The transformed copies stacked along a new leading dimension.
    """
    if transform is None:
        transform = augmentations
    # The transform is random, so each call yields a different view.
    augmented_images = [transform(image) for _ in range(num_augmentations)]
    return torch.stack(augmented_images)

# compute entropy loss
def entropy_loss(preds):
    """Return the mean Shannon entropy (in nats) of the rows of ``preds``.

    Each row along dim 1 is treated as a probability distribution. A tiny
    epsilon inside the log keeps zero probabilities from producing
    ``log(0) = -inf``.
    """
    per_row = torch.sum(preds * (preds + 1e-10).log(), dim=1)
    return -torch.mean(per_row)

# MEMO adaptation function
def memo_adaptation(model, image, num_augmentations=5, num_steps=1, lr=1e-5,
                    class_text_features=None):
    """Adapt ``model`` in place to one test image by marginal entropy minimization.

    The image is augmented ``num_augmentations`` times. Each copy is classified
    zero-shot against the class text embeddings, and the per-copy class
    probabilities are averaged into one marginal distribution. Then
    ``num_steps`` Adam steps minimize the entropy of that marginal.

    Args:
        model: CLIP model; its parameters are updated in place.
        image: PIL image passed to ``apply_augmentations``.
        num_augmentations: augmented copies per optimization step.
        num_steps: number of optimizer steps.
        lr: Adam learning rate.
        class_text_features: text embeddings with one row per class prompt.
            Defaults to the module-level ``text_features``.

    Returns:
        The same ``model`` object, after adaptation.
    """
    if class_text_features is None:
        class_text_features = text_features

    model.eval()  # evaluation mode
    # Follow the model's device instead of hard-coding one.
    device = next(model.parameters()).device

    # generate custom augmentations
    augmented_images = apply_augmentations(image, num_augmentations).to(device)

    # The text side is a fixed classifier: L2-normalize it once so that the
    # dot products below are CLIP's cosine similarities.
    class_embeddings = class_text_features.detach().to(device)
    class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)

    optimizer = Adam(model.parameters(), lr=lr)

    for _ in range(num_steps):
        optimizer.zero_grad()

        image_embeddings = model.get_image_features(pixel_values=augmented_images)
        image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)

        # The entropy must be taken over *class* probabilities, not over the
        # embedding dimensions. The temperature is detached so the loss
        # cannot be lowered just by sharpening the softmax.
        logit_scale = model.logit_scale.exp().detach()
        logits = logit_scale * image_embeddings @ class_embeddings.T
        probs = torch.softmax(logits, dim=1)

        # Marginal distribution averaged over the augmented copies.
        marginal_output = probs.mean(dim=0, keepdim=True)

        loss = entropy_loss(marginal_output)
        loss.backward()
        optimizer.step()

    return model

MAX_CLASS_FOLDERS = 9999  # limit folders for faster testing
MAX_IMAGES_PER_CLASS = 5  # limit the images for faster testing

results = []               # predicted label strings, one per evaluated image
ground_truth_labels = []   # ground-truth label strings, aligned with results
predicted_indices = []     # predicted class indices, aligned with results
ground_truth_indices = []  # ground-truth class indices, aligned with results

# Snapshot the pre-adaptation weights once (kept on the CPU to save GPU
# memory). Each image then starts from the original model without reloading
# the checkpoint from disk.
initial_state = {
    name: tensor.detach().cpu().clone()
    for name, tensor in model.state_dict().items()
}

# L2-normalized class embeddings, so similarity is CLIP's cosine similarity.
text_embeddings = text_features / text_features.norm(dim=-1, keepdim=True)

for folders_counter, class_folder in enumerate(sorted(os.listdir(dataset_path))):
    if folders_counter >= MAX_CLASS_FOLDERS:
        break

    class_folder_path = os.path.join(dataset_path, class_folder)

    # Class folders are named by their integer class index. Skip stray files
    # and non-class directories (e.g. notebook checkpoint folders).
    if not os.path.isdir(class_folder_path) or not class_folder.isdigit():
        continue

    ground_truth_label_idx = int(class_folder)  # class index is the folder name
    ground_truth_label = clsidx_to_labels[ground_truth_label_idx]

    image_names = sorted(os.listdir(class_folder_path))[:MAX_IMAGES_PER_CLASS]
    for img_name in image_names:
        img_path = os.path.join(class_folder_path, img_name)
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        # apply MEMO (updates ``model`` in place)
        adapted_model = memo_adaptation(model, image)

        # evaluate the adapted model
        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to(adapted_model.device)
            image_embedding = adapted_model.get_image_features(pixel_values=inputs['pixel_values'])
            image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)
            similarities = image_embedding @ text_embeddings.to(image_embedding.device).T
            pred_label_idx = torch.argmax(similarities, dim=1).item()

        results.append(imagenet_labels[pred_label_idx])
        ground_truth_labels.append(ground_truth_label)
        predicted_indices.append(pred_label_idx)
        ground_truth_indices.append(ground_truth_label_idx)

        # Reset the weights so the next image adapts from the original model.
        model.load_state_dict(initial_state)

# Compute accuracy on class indices, not label strings: ``clsidx_to_labels``
# and ``imagenet_labels`` come from two different label files, so their
# strings for the same class generally differ.
if predicted_indices:
    accuracy = np.mean(np.array(predicted_indices) == np.array(ground_truth_indices))
    print(f"Accuracy: {accuracy * 100:.2f}%")
else:
    accuracy = float('nan')
    print("No images were evaluated; accuracy is undefined.")

