# Zero-Shot Image Classification on Caltech-101 using ViT-Small (Image Encoder) + BERT (Text Encoder)

This notebook implements a CLIP-like zero-shot classifier on the **Caltech-101** dataset (101 object classes, excluding background).

- **Image encoder**: Frozen ViT-Small (`WinKawaks/vit-small-patch16-224`, 384-dim)
- **Text encoder**: Frozen BERT-base-uncased (768-dim)
- **Trainable part**: Simple linear projection to align image embeddings to text space
- **Setup**: Standard few-shot split (30 training images per class, rest for test)
- **Steps**:
  1. Random-weight baseline (~1% accuracy)
  2. Quick training of the projection head
  3. Final accuracy after training

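Note: The two encoders are **mismatched**: they were never contrastively pre-trained together, so their embedding spaces are unrelated and classification without any alignment stays near chance. The linear projection is trained on the 30 labelled images per class to align them, so the final number measures accuracy on held-out images of the same classes rather than strict zero-shot transfer.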
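
The scoring rule behind this setup can be sketched without any model: normalize the image and text vectors, take dot products (cosine similarities), and pick the best-matching class. The 3-dimensional vectors below are made up for illustration, and `l2_normalize` and `predict` are hypothetical helper names, not part of the notebook.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def predict(image_vec, text_vecs):
    """Return (index of the most similar text vector, list of cosine similarities)."""
    img = l2_normalize(image_vec)
    sims = [sum(a * b for a, b in zip(img, l2_normalize(t))) for t in text_vecs]
    best = max(range(len(sims)), key=sims.__getitem__)
    return best, sims

# Toy embeddings (assumed values, not real model outputs): one text vector per class.
text_vecs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
image_vec = [0.2, 0.9, 0.1]

best, sims = predict(image_vec, text_vecs)
print("predicted class index:", best)
print("similarities:", [round(s, 3) for s in sims])
```

In the notebook the same idea runs on 768-dimensional vectors, with the image vector first passed through the trainable projection.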

In [1]:
import torch
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NO GPU")

!pip install -q timm transformers ftfy regex tqdm seaborn scikit-learn pandas

GPU: Tesla T4

In [2]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.datasets import Caltech101
from transformers import AutoImageProcessor, ViTModel, AutoTokenizer, AutoModel
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [3]:
# Cell 2: Load Encoders (Frozen)
# Image Encoder: ViT-Small
vit_model_name = "WinKawaks/vit-small-patch16-224"
vit_processor = AutoImageProcessor.from_pretrained(vit_model_name)
vit_model = ViTModel.from_pretrained(vit_model_name).to(device)
vit_model.eval()
vit_model.requires_grad_(False)

# Text Encoder: BERT-base
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name).to(device)
bert_model.eval()
bert_model.requires_grad_(False)

print("ViT dim:", vit_model.config.hidden_size)  # 384
print("BERT dim:", bert_model.config.hidden_size)  # 768

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT dim: 384
BERT dim: 768


In [4]:
# Cell 3: Define the Caltech-101 class names (excluding BACKGROUND_Google)
# Note: this list has 100 entries; the dataset's 'kangaroo' category is not included.
classes = [
    'Faces', 'Faces_easy', 'Leopards', 'Motorbikes', 'accordion', 'airplanes', 'anchor', 'ant',
    'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha',
    'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair',
    'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile',
    'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly',
    'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'ferry', 'flamingo',
    'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill',
    'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree',
    'ketch', 'lamp', 'laptop', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
    'menorah', 'metronome', 'minaret', 'nautilus', 'octopus', 'okapi', 'pagoda',
    'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster',
    'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy',
    'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry',
    'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lily', 'wheelchair',
    'wild_cat', 'windsor_chair', 'wrench', 'yin_yang'
]
num_classes = len(classes)
print(f"Number of classes: {num_classes}")

Number of classes: 100
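
The printed count is 100, one short of the 101 object categories the dataset defines. A small sanity check can flag this kind of mismatch early. The sketch below uses a made-up list in place of `classes`; `check_class_list` is a hypothetical helper, not something the notebook defines.

```python
from collections import Counter

def check_class_list(names, expected_count):
    """Report the list length, how many names are missing, and any duplicates."""
    duplicates = sorted(n for n, k in Counter(names).items() if k > 1)
    return {
        "count": len(names),
        "expected": expected_count,
        "missing_count": expected_count - len(set(names)),
        "duplicates": duplicates,
    }

# Tiny made-up list standing in for `classes`.
toy = ["ant", "bass", "bass", "crab"]
report = check_class_list(toy, expected_count=5)
print(report)
```

In the notebook, one option would be to call it as `check_class_list(classes, 101)`, or to compare `classes` against the `categories` attribute of the loaded dataset.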


In [5]:
# Cell 4: Compute Text Embeddings (BERT CLS tokens)
prompts = [f"a photo of a {c.lower().replace('_', ' ')}." for c in classes]

inputs = bert_tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = bert_model(**inputs)
    text_embeds = outputs.last_hidden_state[:, 0, :]  # [CLS] token per prompt: [100, 768]

# L2 normalize (important for cosine similarity)
text_embeds = F.normalize(text_embeds, dim=-1)

print("text_embeds shape:", text_embeds.shape)

text_embeds shape: torch.Size([100, 768])
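
The prompt template above always uses the article "a", which yields strings such as "a photo of a accordion." A minimal sketch of a variant that picks "a" or "an" from the first letter is shown below. `make_prompt` is a hypothetical helper, and the vowel rule is only a heuristic (it would produce "an euphonium", for example). Whether this changes accuracy with a BERT text encoder is not tested here.

```python
def make_prompt(class_name):
    """Turn a dataset folder name into a prompt, choosing 'a' or 'an' by first letter."""
    words = class_name.lower().replace("_", " ")
    article = "an" if words[0] in "aeiou" else "a"
    return f"a photo of {article} {words}."

# A few class names taken from the list above.
for name in ["car_side", "accordion", "Leopards"]:
    print(make_prompt(name))
```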


In [6]:
# Cell 5: Define Trainable Projection Head
class ProjectionHead(nn.Module):
    def __init__(self, img_dim=384, text_dim=768, temperature=10.0):
        super().__init__()
        self.proj = nn.Linear(img_dim, text_dim)
        self.temperature = temperature

    def forward(self, img_embeds, text_embeds):
        img_proj = self.proj(img_embeds)           # [B, 768]
        img_proj = F.normalize(img_proj, dim=-1)   # L2 norm
        logits = img_proj @ text_embeds.T * self.temperature  # [B, num_classes]
        return logits

model = ProjectionHead().to(device)
print(model)

ProjectionHead(
  (proj): Linear(in_features=384, out_features=768, bias=True)
)
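
Because both sides are L2-normalized, the raw similarities lie in [-1, 1], and the `temperature` factor (which multiplies the logits here, so it acts as a logit scale) controls how peaked the softmax becomes. The sketch below illustrates the effect with made-up similarity values; `softmax` is written out by hand only to keep the example dependency-free.

```python
import math

def softmax(scores, scale=1.0):
    """Softmax over `scores` after multiplying them by `scale` (numerically stable)."""
    scaled = [s * scale for s in scores]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up cosine similarities between one image and three class prompts.
sims = [0.30, 0.25, 0.20]

p_scale_1 = softmax(sims, scale=1.0)
p_scale_10 = softmax(sims, scale=10.0)

print("scale 1: ", [round(p, 3) for p in p_scale_1])
print("scale 10:", [round(p, 3) for p in p_scale_10])
```

With a scale of 1 the three probabilities stay close to uniform, which gives cross-entropy little to work with. A larger scale separates them, which is presumably why the head uses 10.0.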


In [11]:
# Cell 6: Load and split Caltech-101, converting grayscale images to RGB
class Caltech101Filtered(Dataset):
    def __init__(self, root='./data', train=True, download=True):
        self.full_dataset = Caltech101(root=root, download=download)  # Returns PIL images
        self.train = train

        # Group sample indices by class name, keeping only names listed in `classes`.
        # Labels come from the dataset's `y` list, so no image is decoded in this loop.
        class_indices = {c: [] for c in classes}
        for idx, label in enumerate(self.full_dataset.y):
            class_name = self.full_dataset.categories[label]
            if class_name in class_indices:
                class_indices[class_name].append(idx)

        # Split: the first 30 shuffled images per class train, the rest test
        self.indices = []
        self.labels = []
        self.old_to_new = {classes[i]: i for i in range(num_classes)}

        random.seed(42)  # same seed for both splits, so train and test shuffle identically
        for class_name, idxs in class_indices.items():
            random.shuffle(idxs)
            selected = idxs[:30] if train else idxs[30:]
            self.indices.extend(selected)
            new_label = self.old_to_new[class_name]
            self.labels.extend([new_label] * len(selected))

        print(f"{'Train' if train else 'Test'} size: {len(self.indices)}")

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        orig_idx = self.indices[idx]
        img, _ = self.full_dataset[orig_idx]  # PIL image
        img = img.convert("RGB")               # Some images are grayscale; the ViT expects 3 channels
        label = self.labels[idx]
        return img, label

def collate_fn(batch):
    # Not defined in the cells shown here; this minimal version is an assumption.
    # It keeps the PIL images in a list for `vit_processor` and stacks the labels into a tensor.
    imgs, labels = zip(*batch)
    return list(imgs), torch.tensor(labels)

train_dataset = Caltech101Filtered(train=True, download=False)   # download=False if already downloaded
test_dataset = Caltech101Filtered(train=False, download=False)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

Train size: 2970
Test size: 5584
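
The split logic can be checked in isolation. The sketch below builds both splits in a single pass with a local random generator, so the two halves are disjoint by construction and do not depend on reseeding the global `random` module. `few_shot_split` and the toy index ranges are assumptions for illustration, not the notebook's code.

```python
import random

def few_shot_split(indices_by_class, shots=30, seed=42):
    """Shuffle each class's indices and send the first `shots` to train, the rest to test."""
    rng = random.Random(seed)  # local generator: no effect on global random state
    train, test = [], []
    for name in sorted(indices_by_class):
        idxs = list(indices_by_class[name])
        rng.shuffle(idxs)
        train.extend(idxs[:shots])
        test.extend(idxs[shots:])
    return train, test

# Two made-up classes with 42 and 73 samples.
toy = {"ant": list(range(0, 42)), "crab": list(range(42, 115))}
train, test = few_shot_split(toy)

print("train:", len(train), "test:", len(test))
print("overlap:", len(set(train) & set(test)))
```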


In [12]:
# Cell 7: Evaluation Function (Batched)
@torch.no_grad()
def evaluate(loader):
    all_preds = []
    all_labels = []

    for pil_imgs, labels in tqdm(loader, desc="Evaluating"):
        labels = labels.to(device)

        inputs = vit_processor(images=pil_imgs, return_tensors="pt").to(device)
        outputs = vit_model(**inputs)
        img_embeds = outputs.last_hidden_state[:, 0, :]  # [B, 384]

        logits = model(img_embeds, text_embeds)
        preds = logits.argmax(dim=-1).cpu().tolist()

        all_preds.extend(preds)
        all_labels.extend(labels.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    return acc

In [13]:
# Cell 8: Zero-Shot with Random Weights
model.eval()
print("Running zero-shot with random projection...")
acc_random = evaluate(test_loader)
print(f"Zero-shot accuracy (random weights): {acc_random:.3%}")
# Chance level is 1% (uniform guessing over the 100 listed classes)

Running zero-shot with random projection...


Zero-shot accuracy (random weights): 5.211%
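
The 5.2% above is higher than the 1% expected from uniform guessing. One possible contributor, which this notebook does not verify, is class imbalance in the test split: a classifier that happens to favour a frequent class scores above uniform chance. The sketch below contrasts the two baselines on made-up counts; both function names are hypothetical.

```python
def chance_accuracy(num_classes):
    """Expected accuracy when guessing uniformly at random among the classes."""
    return 1.0 / num_classes

def majority_accuracy(test_counts):
    """Accuracy obtained by always predicting the most frequent test class."""
    return max(test_counts.values()) / sum(test_counts.values())

# Hypothetical, deliberately imbalanced test-set counts (not the real Caltech-101 numbers).
toy_counts = {"class_a": 60, "class_b": 25, "class_c": 15}

print("uniform chance:", round(chance_accuracy(len(toy_counts)), 3))
print("majority class:", round(majority_accuracy(toy_counts), 3))
```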


In [14]:
# Cell 9: Train the Projection Head
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()

print("Starting training (5 epochs on ~30 images/class)...")
for epoch in range(5):
    losses = []
    for pil_imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/5"):
        labels = labels.to(device)

        # The ViT is frozen, so its forward pass needs no autograd graph.
        with torch.no_grad():
            inputs = vit_processor(images=pil_imgs, return_tensors="pt").to(device)
            outputs = vit_model(**inputs)
            img_embeds = outputs.last_hidden_state[:, 0, :]  # [B, 384]

        logits = model(img_embeds, text_embeds)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    print(f"Epoch {epoch+1} - Avg loss: {np.mean(losses):.4f}")

# Final evaluation
model.eval()
acc_trained = evaluate(test_loader)
print(f"Accuracy after training: {acc_trained:.3%}")
# Varies by run; this run reached 83.5%, up from 5.2% with the random projection.

Starting training (5 epochs on ~30 images/class)...
Epoch 1 - Avg loss: 3.8399
Epoch 2 - Avg loss: 3.3996
Epoch 3 - Avg loss: 3.3163
Epoch 4 - Avg loss: 3.2823
Epoch 5 - Avg loss: 3.2639

Accuracy after training: 83.506%
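
The figure above is overall accuracy, which weights classes by how many test images they have. Since the test split keeps every image beyond the first 30 per class, class sizes differ, and mean per-class accuracy may tell a different story. The sketch below computes both on made-up labels; `mean_per_class_accuracy` is a hypothetical helper, and `evaluate` would need to return its `all_labels` and `all_preds` lists for it to be applied here.

```python
from collections import defaultdict

def mean_per_class_accuracy(labels, preds):
    """Average the per-class accuracies so every class counts equally."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for y, p in zip(labels, preds):
        total[y] += 1
        correct[y] += int(y == p)
    per_class = {c: correct[c] / total[c] for c in total}
    return sum(per_class.values()) / len(per_class), per_class

# Made-up example: class 0 has four test images, class 1 has two.
labels = [0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 0, 1, 0]

overall = sum(int(y == p) for y, p in zip(labels, preds)) / len(labels)
mean_acc, per_class = mean_per_class_accuracy(labels, preds)

print("overall accuracy:", round(overall, 3))
print("mean per-class accuracy:", round(mean_acc, 3))
```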
