In [None]:
%%capture
!pip install git+https://github.com/openai/CLIP.git
!pip install loraclip
!pip install faiss-gpu

In [21]:
import os
import argparse
import torch
import random
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
import torch.nn.functional as F
import clip
import loraclip
from torchvision import datasets, transforms
from transformers import ViTImageProcessor, ViTForImageClassification
from datasets import load_dataset # HuggingFace dedicated lib
from encoder_utils import build_faiss_index, predict_with_faiss

In [29]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the CLIP backbone together with the image transform it was trained with;
# the weights are placed on `device`.
clip_model, clip_preprocess = clip.load("ViT-B/16", device=device)

In [30]:
# Show CLIP's image transform; the CIFAR-10 datasets below reuse it so inputs match what the encoder expects.
clip_preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x1501daef0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [28]:
# Download/cache root for torchvision datasets (used by the CIFAR-10 cell below).
# exist_ok keeps this cell safe to re-run.
path_data = "./data"
os.makedirs(name=path_data, exist_ok=True)

In [48]:
BATCH_SIZE = 64

# Both CIFAR-10 splits go through CLIP's own transform so pixels match the encoder's training input.
cifar10_dataset_train, cifar10_dataset_test = (
    datasets.CIFAR10(root=path_data, train=is_train, download=True, transform=clip_preprocess)
    for is_train in (True, False)
)

# The train split is shuffled; the test split keeps dataset order.
cifar10_loader_train = DataLoader(cifar10_dataset_train, batch_size=BATCH_SIZE, shuffle=True)
cifar10_loader_test = DataLoader(cifar10_dataset_test, batch_size=BATCH_SIZE, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [52]:
# Summaries of both splits (size, root, transform), a blank line, then the concrete dataset class.
print(cifar10_dataset_train, cifar10_dataset_test, sep="\n")
print(f"\n{type(cifar10_dataset_train)}")

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
               CenterCrop(size=(224, 224))
               <function _convert_image_to_rgb at 0x1501daef0>
               ToTensor()
               Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
               CenterCrop(size=(224, 224))
               <function _convert_image_to_rgb at 0x1501daef0>
               ToTensor()
               Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
           )

<class 'torchvision.datasets.cifar.CIFAR10'>


In [53]:
def clip_preprocess_fn(imgs, model=None):
    """
    Encode a batch of images with CLIP's image encoder and L2-normalize the result,
    so that inner product between embeddings equals cosine similarity.

    Args:
        imgs: batch of images already transformed by `clip_preprocess`
            (presumably [B, 3, 224, 224]; assumes the caller has moved it to the
            model's device — TODO confirm against encoder_utils).
        model: any CLIP-like model exposing `eval()` and `encode_image()`.
            Defaults to the global `clip_model` loaded earlier in the notebook.

    Returns:
        float32 tensor of unit-norm embeddings, shape [B, D]
        (D = 512 for the ViT-B/16 backbone).
    """
    if model is None:
        model = clip_model
    model.eval()
    with torch.no_grad():
        features = model.encode_image(imgs)  # Shape: [Batch, D]
        # CLIP weights are fp16 on CUDA; FAISS only accepts float32 vectors, and
        # normalizing in float32 is also more precise.
        features = features.float()
        features = F.normalize(features, p=2, dim=-1)  # L2 normalization
    return features

In [None]:
# Build a FAISS index over CLIP embeddings of the training split
# (build_faiss_index returns the stored labels and the index).
faiss_labels, faiss_index = build_faiss_index(
    dataloader=cifar10_loader_train, preprocess_fn=clip_preprocess_fn, device=device
)

# Query the index with the test split, using the same embedding function,
# retrieving the top 5 neighbours per query.
clip_results = predict_with_faiss(
    dataloader=cifar10_loader_test,
    faiss_index=faiss_index,
    faiss_labels=faiss_labels,
    preprocess_fn=clip_preprocess_fn,
    top_k=5,
    device=device,
    distractor_classes={},
)