In [16]:
# dataset and labels are in the drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
!pip install transformers torch torchvision




In [9]:
from transformers import CLIPProcessor, CLIPModel
import torch

# Init model and processor
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to('cuda')
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
model.eval()  # evaluation mode

CLIPModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 512)
      (position_embedding): Embedding(77, 512)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=512, out_features=2048, bias=True)
            (fc2): Linear(in_features=2048, out_features=512, bias=True)
          )
          (layer_norm2): LayerNorm((512,), eps=1e-05,

In [10]:
import os
import ast

# dataset folder path + check
dataset_path = '/content/drive/MyDrive/imagenetv2-matched-frequency/imagenetv2-matched-frequency-format-val'
if os.path.exists(dataset_path):
    print(f"Dataset folder {dataset_path} found")
else:
    print(f"Dataset folder {dataset_path} not found")


# dataset labels path
file_path = '/content/drive/MyDrive/imagenet1000_clsidx_to_labels.txt'


# reading and creating dictionary
with open(file_path, 'r') as f:
    file_content = f.read()

clsidx_to_labels = ast.literal_eval(file_content)



La cartella /content/drive/MyDrive/imagenetv2-matched-frequency/imagenetv2-matched-frequency-format-val è presente.


In [11]:
import torchvision.transforms as transforms

# standard augmentation of CLIP processor
def augmentations(image, processor):
    return processor(images=image, return_tensors="pt")["pixel_values"].squeeze()

In [12]:
import torch.nn.functional as F

def entropy_loss(preds):
    p_log_p = preds * torch.log(preds + 1e-10)
    return -p_log_p.sum(dim=1).mean()


In [15]:
from torch.optim import Adam
from torch.nn import BatchNorm1d
from PIL import Image


def memo_adaptation(model, processor, image, num_augmentations=5, lr=1e-5):
    model.train()
    optimizer = Adam(model.parameters(), lr=lr)

    # augmentations generation and batch processing
    augmented_images = [augmentations(image, processor) for _ in range(num_augmentations)]
    augmented_images = torch.stack(augmented_images).to('cuda')

    batch_norm = BatchNorm1d(512).to('cuda')
    optimizer.zero_grad()

    # compute predictions
    outputs = model.get_image_features(augmented_images)
    outputs = batch_norm(outputs)
    print(outputs.shape)

    # compute marginal distributions
    outputs = F.softmax(outputs, dim=1)
    marginal_output = outputs.mean(dim=0, keepdim=True)

    # compute entropy loss
    loss = entropy_loss(marginal_output)
    loss.backward()
    optimizer.step()

    return model

In [14]:
import os
import numpy as np

results = []
ground_truth_labels = []

# cycle on dataset folders
for folder_counter, class_folder in enumerate(os.listdir(dataset_path)):
    if folder_counter >= 1:  # to limit the number of folders
        break

    class_folder_path = os.path.join(dataset_path, class_folder)
    if not os.path.isdir(class_folder_path):
        continue

    ground_truth_label_idx = int(class_folder)  # the name of the folders is the class name (see labels file)
    ground_truth_label = clsidx_to_labels[ground_truth_label_idx]

    image_counter = 0
    for img_name in os.listdir(class_folder_path):
        if image_counter >= 1:  # to limit the number of images
          break
        image_counter+=1
        img_path = os.path.join(class_folder_path, img_name)
        image = Image.open(img_path).convert('RGB')

        # MEMO
        adapted_model = memo_adaptation(model, processor, image)

        # new evaluation
        with torch.no_grad():
            inputs = processor(images=image, return_tensors="pt").to('cuda')
            outputs = adapted_model.get_image_features(inputs['pixel_values'])
            pred_label_idx = torch.argmax(outputs, dim=1).item()
            pred_label = clsidx_to_labels[pred_label_idx]
            results.append(pred_label)
            ground_truth_labels.append(ground_truth_label)

        # reset of the weights
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16").to('cuda')

# accuracy
accuracy = np.mean([1 if pred == true else 0 for pred, true in zip(results, ground_truth_labels)])
print(f"Accuracy: {accuracy * 100:.2f}%")

torch.Size([5, 512])
Accuracy: 0.00%
