### Imports

In [1]:
import numpy as np 
import pandas 
import open_clip
import torch 
import torch.nn as nn
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
import torch.nn.functional as Func
from sklearn.linear_model import LogisticRegression
import torch.optim as optim
from torch.utils.data import random_split, DataLoader
from tqdm import tqdm
import copy 
device = "cuda" if torch.cuda.is_available() else "cpu"

### Model Definition

In [2]:
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k') # pretrained on pretrained on LAION-2B
                                                                                                         # preprocess: this is the image processing pipeline
tokenizer = open_clip.get_tokenizer('ViT-B-32')
model = model.to(device=device).eval() # we put it on eval mode as we want to investigate zero-shot mode. 

### Dataset

In [3]:
train_dataset, val_dataset = random_split(CIFAR10(root='./data', train=True, download=True, transform=preprocess), [45000, 5000])
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=preprocess)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset,   batch_size=256, shuffle=False)
test_loader = DataLoader(test_dataset,  batch_size=256, shuffle=False, num_workers=4)
class_names = test_dataset.classes  # CIFAR-10 names

Files already downloaded and verified
Files already downloaded and verified


### Zero-Shot Inference

In [22]:
def compute_text_embeddings(classnames, templates):
    texts = []
    for t in templates:
        for c in classnames:
            # texts.append(t.format(classname=c))
            texts.append(t.format(c) if "{}" in t else t.format(classname=c))
    tokens = tokenizer(texts).to(device)
    with torch.no_grad():
        text_emb = model.encode_text(tokens) # currently, its shape is: (#templates * #classes, Embedding_shape)
    # average across templates per class
    Embedding_shape = text_emb.shape[-1]
    text_emb = text_emb.reshape(len(templates), len(classnames), Embedding_shape).mean(dim=0) # (#templates * #classes, Embedding_shape) -> 
                                                                                              # (#templates, #classes, Embedding_shape) -> (#classes, Embedding_shape)
    text_emb = Func.normalize(text_emb, dim=-1) # we normalize our embeddings 
    return text_emb

# templates (it's adjustable)
templates = [
    "a photo of a {classname}.",
    "this is a {classname}.", 
    "this image is {classname}"
]
text_emb = compute_text_embeddings(class_names, templates)

# classify our test images
correct, total = 0, 0
with torch.no_grad():
    for imgs, labels in tqdm(test_loader):
        imgs = imgs.to(device)
        image_emb = model.encode_image(imgs)
        image_emb = Func.normalize(image_emb, dim=-1)
        logits = image_emb @ text_emb.T  # cosine similarity
        preds = logits.argmax(dim=1).cpu().numpy()
        correct += (preds == labels.numpy()).sum()
        total += labels.size(0)
print("Zero-shot accuracy:", correct/total)


100%|██████████| 40/40 [00:44<00:00,  1.10s/it]

Zero-shot accuracy: 0.9362





### Inference just based on obtained features from images 

#### 1. By Logistic Regression

In [9]:
def img_feature_extractor(loader):
    feats = [] # contains model-generated features for each batch
    labs  = [] # contains labels for each batch
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            feature = model.encode_image(imgs)
            feature = feature.cpu().numpy()
            feats.append(feature)
            labs.append(labels.numpy())
    feats = np.concatenate(feats, axis=0)
    labs  = np.concatenate(labs, axis=0)
    return feats, labs

train_features, train_labels = img_feature_extractor(train_loader)
val_features, val_labels = img_feature_extractor(val_loader)
test_features, test_labels = img_feature_extractor(test_loader)

clf = LogisticRegression(max_iter=1000, C=1.0)
clf.fit(train_features, train_labels)
acc = clf.score(test_features, test_labels)
print("Linear probe (logreg) accuracy:", acc)

Linear probe (logreg) accuracy: 0.9639


#### 1. By an MLP 

In [6]:
# Convert features and labels to tensors``
train_feats_tensor = torch.tensor(train_features, dtype=torch.float32)
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
val_feats_tensor = torch.tensor(val_features, dtype=torch.float32)
val_labels_tensor = torch.tensor(val_labels, dtype=torch.long)
test_feats_tensor  = torch.tensor(test_features, dtype=torch.float32)
test_labels_tensor  = torch.tensor(test_labels, dtype=torch.long)

# Create datasets and loaders
train_dataset = TensorDataset(train_feats_tensor, train_labels_tensor)
val_dataset   = TensorDataset(val_feats_tensor, val_labels_tensor)
test_dataset  = TensorDataset(test_feats_tensor, test_labels_tensor)

train_loader_mlp = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader_mlp   = DataLoader(val_dataset, batch_size=64)
test_loader_mlp  = DataLoader(test_dataset, batch_size=64)

# Define MLP 
num_features = train_features.shape[1]
num_classes  = len(np.unique(train_labels))
MLP_probe = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        ).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(MLP_probe.parameters(), lr=1e-3)

# Training loop
NUM_EPOCHS = 20
best_val_acc = 0.0
best_state_dict = None

for epoch in range(NUM_EPOCHS):
    MLP_probe.train()
    batch_bar = tqdm(train_loader_mlp, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [train]", leave=False)

    for x_batch, y_batch in batch_bar:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        logits = MLP_probe(x_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()

        batch_bar.set_postfix({"loss": loss.item()})

    # Validation
    MLP_probe.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x_batch, y_batch in val_loader_mlp:
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)

            logits = MLP_probe(x_batch)
            preds = torch.argmax(logits, dim=1)

            correct += (preds == y_batch).sum().item()
            total += y_batch.size(0)

    val_acc = correct / total
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Validation Accuracy: {val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_state_dict = MLP_probe.state_dict()

# Load the best model
MLP_probe.load_state_dict(best_state_dict)
print(f"Loaded best model (Val Acc = {best_val_acc:.4f})")

# Final Test Evaluation
MLP_probe.eval()
correct = 0
total = 0

with torch.no_grad():
    for x_batch, y_batch in test_loader_mlp:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)
        logits = MLP_probe(x_batch)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == y_batch).sum().item()
        total += y_batch.size(0)
test_acc = correct / total
print(f"Final Test Accuracy: {test_acc:.4f}")

                                                                                    

Epoch 1/20 - Validation Accuracy: 0.9594


                                                                                    

Epoch 2/20 - Validation Accuracy: 0.9590


                                                                                    

Epoch 3/20 - Validation Accuracy: 0.9644


                                                                                    

Epoch 4/20 - Validation Accuracy: 0.9676


                                                                                     

Epoch 5/20 - Validation Accuracy: 0.9630


                                                                                    

Epoch 6/20 - Validation Accuracy: 0.9668


                                                                                     

Epoch 7/20 - Validation Accuracy: 0.9646


                                                                                     

Epoch 8/20 - Validation Accuracy: 0.9680


                                                                                     

Epoch 9/20 - Validation Accuracy: 0.9660


                                                                                      

Epoch 10/20 - Validation Accuracy: 0.9654


                                                                                      

Epoch 11/20 - Validation Accuracy: 0.9628


                                                                                      

Epoch 12/20 - Validation Accuracy: 0.9646


                                                                                      

Epoch 13/20 - Validation Accuracy: 0.9660


                                                                                      

Epoch 14/20 - Validation Accuracy: 0.9654


                                                                                      

Epoch 15/20 - Validation Accuracy: 0.9642


                                                                                      

Epoch 16/20 - Validation Accuracy: 0.9654


                                                                                      

Epoch 17/20 - Validation Accuracy: 0.9638


                                                                                      

Epoch 18/20 - Validation Accuracy: 0.9644


                                                                                      

Epoch 19/20 - Validation Accuracy: 0.9680


                                                                                      

Epoch 20/20 - Validation Accuracy: 0.9662
Loaded best model (Val Acc = 0.9680)
Final Test Accuracy: 0.9680


### Inference based on obtained features from both images and texts

In [None]:
class ProjectionMLP(nn.Module):
    def __init__(self, input_dim=512, hidden=512, output_dim=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim)
        )
    def forward(self, x):
        return self.net(x)

# freeze CLIP model 
for p in model.parameters():
    p.requires_grad = False

img_head = ProjectionMLP(input_dim=512).to(device)
txt_head = ProjectionMLP(input_dim=512).to(device)
optimizer = torch.optim.Adam(list(img_head.parameters()) + list(txt_head.parameters()), lr=1e-3)
templates = ["a photo of a {}.", "a cropped photo of a {}."]
text_embeddings = compute_text_embeddings(class_names, templates).to(device)


epochs = 3
best_val_acc = 0.0
best_img_head_state = None
best_txt_head_state = None

for epoch in range(epochs):
    total_loss = 0
    # training loop
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.no_grad():
            img_emb = model.encode_image(imgs)
            img_emb = Func.normalize(img_emb, dim=-1)

        # pass through MLP heads
        img_proj = img_head(img_emb)
        txt_proj = txt_head(text_embeddings)
        img_proj = Func.normalize(img_proj, dim=-1)
        txt_proj = Func.normalize(txt_proj, dim=-1)

        logits = img_proj @ txt_proj.T
        loss = Func.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}, loss={total_loss:.4f}")

    # validation evaluation
    correct, total = 0, 0
    with torch.no_grad():
        txt_proj = txt_head(text_embeddings)
        txt_proj = Func.normalize(txt_proj, dim=-1)

        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            img_emb = model.encode_image(imgs)
            img_emb = Func.normalize(img_emb, dim=-1)

            img_proj = img_head(img_emb)
            img_proj = Func.normalize(img_proj, dim=-1)

            logits = img_proj @ txt_proj.T
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_acc = correct / total
    print(f"Validation accuracy: {val_acc:.4f}")

    # save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_img_head_state = copy.deepcopy(img_head.state_dict())
        best_txt_head_state = copy.deepcopy(txt_head.state_dict())

# load best model for testing 
img_head.load_state_dict(best_img_head_state)
txt_head.load_state_dict(best_txt_head_state)

# test evaluation 
correct, total = 0, 0
with torch.no_grad():
    txt_proj = txt_head(text_embeddings)
    txt_proj = Func.normalize(txt_proj, dim=-1)

    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        img_emb = model.encode_image(imgs)
        img_emb = Func.normalize(img_emb, dim=-1)

        img_proj = img_head(img_emb)
        img_proj = Func.normalize(img_proj, dim=-1)

        logits = img_proj @ txt_proj.T
        preds = logits.argmax(dim=1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)

print("Test accuracy by best validation model:", correct / total)

Epoch 1, loss=510.3517
Epoch 2, loss=494.9237
Epoch 3, loss=493.2440
Accuracy after training MLP heads: 0.9674
