In [1]:
import os
import argparse
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
import clip
import loraclip
from torchvision import datasets, transforms
from transformers import ViTImageProcessor, ViTForImageClassification
from datasets import load_dataset # HuggingFace dedicated lib
from encoder_utils import build_faiss_index, predict_with_faiss, compute_topk_accuracy
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/16", device=device)

In [3]:
clip_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7fbe90d57d90>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [4]:
path_data = "./data"
os.makedirs(path_data, exist_ok=True)

In [5]:
cifar10_dataset_train = datasets.CIFAR10(root=path_data, train=True, download=True, transform=clip_preprocess)
cifar10_dataset_test = datasets.CIFAR10(root=path_data, train=False, download=True, transform=clip_preprocess)

cifar10_loader_train = torch.utils.data.DataLoader(cifar10_dataset_train, batch_size=64, shuffle=True)
cifar10_loader_test = torch.utils.data.DataLoader(cifar10_dataset_test, batch_size=64, shuffle=False)


In [6]:
for imgs, file_ids in tqdm(cifar10_loader_test, desc="Predicting with FAISS"):
    print(file_ids)
    print(imgs)
    break

Predicting with FAISS:   0%|          | 0/157 [00:00<?, ?it/s]

tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3])
tensor([[[[ 0.5289,  0.5289,  0.5289,  ..., -0.1134, -0.1134, -0.1134],
          [ 0.5289,  0.5289,  0.5289,  ..., -0.1134, -0.1134, -0.1134],
          [ 0.5143,  0.5143,  0.5143,  ..., -0.1134, -0.1134, -0.1134],
          ...,
          [-1.0039, -1.0039, -1.0039,  ..., -1.5003, -1.5003, -1.5149],
          [-1.0185, -1.0185, -1.0185,  ..., -1.5003, -1.5003, -1.5149],
          [-1.0185, -1.0185, -1.0185,  ..., -1.5003, -1.5003, -1.5149]],

         [[-0.0712, -0.0712, -0.0712,  ..., -0.4764, -0.4914, -0.4914],
          [-0.0712, -0.0712, -0.0712,  ..., -0.4764, -0.4914, -0.4914],
          [-0.0712, -0.0712, -0.0712,  ..., -0.4764, -0.4914, -0.4914],
          ...,
          [-0.1613, -0.1613, -0.1613,  ..., -0.7616, -0.7766, -0.7766],
          [-0.1613, -0.1613, -0




In [7]:
print(cifar10_dataset_train)
print(cifar10_dataset_test)
print()
print(type(cifar10_dataset_train))

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
               CenterCrop(size=(224, 224))
               <function _convert_image_to_rgb at 0x7fbe90d57d90>
               ToTensor()
               Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
               CenterCrop(size=(224, 224))
               <function _convert_image_to_rgb at 0x7fbe90d57d90>
               ToTensor()
               Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
           )

<class 'torchvision.datasets.cifar.CIFAR10'>


In [8]:
def clip_preprocess_fn(imgs):
    """
    Preprocess CLIP embeddings - normalize for cosine similarity.
    """
    device = next(clip_model.parameters()).device
    imgs = imgs.to(device)
    
    clip_model.eval()
    with torch.no_grad():
        features = clip_model.encode_image(imgs)
        features = F.normalize(features, p=2, dim=-1)
    
    return features

In [9]:
faiss_labels, faiss_index = build_faiss_index(
    dataloader=cifar10_loader_train,
    preprocess_fn=clip_preprocess_fn,
    device=device
)

Building FAISS Index: 100%|██████████| 782/782 [05:59<00:00,  2.18it/s]


FAISS index built with 50000 entries.


In [12]:
import numpy as np
def predict_with_faiss(dataloader, preprocess_fn, faiss_index, faiss_labels,
                              device="cuda", top_k=5, distractor_classes=None):
    """
    Predict top-k classes using FAISS with duplicate and distractor handling.

    Args:
        dataloader (DataLoader): DataLoader for test data.
        preprocess_fn (function): Preprocessing function for embeddings.
        faiss_index (faiss.Index): Prebuilt FAISS index.
        faiss_labels (np.ndarray): Labels corresponding to FAISS index.
        device (str): Device to run computations ("cuda" or "cpu").
        top_k (int): Number of predictions to return.
        distractor_classes (set): Classes treated as distractors.

    Returns:
        ground_truth (list): List of true labels.
        results (list): List of top-k predictions.
    """
    assert faiss_index is not None, "FAISS index is not built. Call build_faiss_index_global() first."

    if distractor_classes is None:
        distractor_classes = {}

    results = []
    ground_truth = []

    with torch.no_grad():
        for imgs, lbls in tqdm(dataloader, desc="Predicting with FAISS"):
            imgs = imgs.to(device)
            embeddings = preprocess_fn(imgs)

            features = np.ascontiguousarray(embeddings.cpu().numpy(), dtype=np.float32)
            distances, indices = faiss_index.search(features, top_k * 2)

            for i in range(len(features)):
                top_classes = faiss_labels[indices[i]].tolist()

                seen = set()
                filtered_classes = []
                for cls in top_classes:
                    if cls not in seen:
                        filtered_classes.append(cls)
                        seen.add(cls)
                    if len(filtered_classes) == top_k:
                        break

                predictions = []
                if filtered_classes and filtered_classes[0] in distractor_classes:
                    predictions.append(filtered_classes[0])
                    predictions.append(-1)
                    predictions.extend(filtered_classes[1:])
                else:
                    predictions = filtered_classes

                predictions = predictions[:top_k]

                if len(predictions) < top_k:
                    predictions += [-1] * (top_k - len(predictions))

                results.append(predictions)

            ground_truth.extend(lbls.cpu().numpy().tolist())

    return ground_truth, results

In [13]:
ground_truth, predictions = predict_with_faiss(
    dataloader=cifar10_loader_test,
    preprocess_fn=clip_preprocess_fn,
    faiss_index=faiss_index,
    faiss_labels=faiss_labels,
    device=device,
    top_k=5,
    distractor_classes=None
)

Predicting with FAISS: 100%|██████████| 157/157 [02:12<00:00,  1.19it/s]


In [14]:
accuracy_top1 = compute_topk_accuracy(ground_truth, predictions, top_k=1)
print('top 1 accuracy', accuracy_top1)

top 1 accuracy 0.9354


In [15]:
accuracy_top2 = compute_topk_accuracy(ground_truth, predictions, top_k=2)
print('top 2 accuracy', accuracy_top2)

top 2 accuracy 0.9747


In [16]:
accuracy_top5 = compute_topk_accuracy(ground_truth, predictions, top_k=5)
print('top 5 accuracy', accuracy_top5)

top 5 accuracy 0.988
