# Zero-Shot Image Classification on CIFAR-10 using ViT-Small (Image Encoder) + BERT (Text Encoder)

This notebook implements a CLIP-like zero-shot classifier on the **CIFAR-10** dataset (10 classes, 60,000 32×32 color images).

- **Image encoder**: Frozen ViT-Small (`WinKawaks/vit-small-patch16-224`, 384-dim)
- **Text encoder**: Frozen BERT-base-uncased (768-dim)
- **Trainable part**: Simple linear projection to align image embeddings to text space
- **Setup**: Standard split (50,000 train, 10,000 test)
- **Steps**:
  1. Random-weight baseline (~10% accuracy)
  2. Quick training of the projection head on the full training set (or on a subset for faster runs)
  3. Final accuracy after training

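Note: The two encoders were not contrastively pre-trained together, so their embedding spaces are unrelated and accuracy with an untrained projection sits at chance (~10%). The linear adapter is trained on CIFAR-10 labels to learn the alignment; in the run recorded here, five epochs lift test accuracy only from 10.0% to 15.7%.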

In [1]:
import torch
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NO GPU")

!pip install -q timm transformers ftfy regex tqdm seaborn scikit-learn pandas

GPU: Tesla T4

In [2]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision import transforms
from transformers import AutoImageProcessor, ViTModel, AutoTokenizer, AutoModel
from tqdm.notebook import tqdm
import numpy as np
from sklearn.metrics import accuracy_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [3]:
# Cell 2: Load Encoders (Frozen)
# Image Encoder: ViT-Small
vit_model_name = "WinKawaks/vit-small-patch16-224"
vit_processor = AutoImageProcessor.from_pretrained(vit_model_name)
vit_model = ViTModel.from_pretrained(vit_model_name).to(device)
vit_model.eval()
vit_model.requires_grad_(False)

# Text Encoder: BERT-base-uncased
bert_model_name = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name).to(device)
bert_model.eval()
bert_model.requires_grad_(False)

print("ViT-Small dim:", vit_model.config.hidden_size)  # 384
print("BERT dim:", bert_model.config.hidden_size)       # 768

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT-Small dim: 384
BERT dim: 768


In [4]:
# Cell 3: CIFAR-10 Classes and Text Prompts
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(classes)

prompts = [f"a photo of a {c}." for c in classes]  # CLIP-style prompt template, one prompt per class

# Compute text embeddings (BERT CLS tokens)
inputs = bert_tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = bert_model(**inputs)
    text_embeds = outputs.last_hidden_state[:, 0, :]  # [10, 768]

# L2 normalize (essential for cosine similarity)
text_embeds = F.normalize(text_embeds, dim=-1)

print("text_embeds shape:", text_embeds.shape)

text_embeds shape: torch.Size([10, 768])
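
The normalization step matters because the projection head later scores images with a plain dot product. As a minimal, dependency-free sketch (toy 3-dimensional vectors stand in for the 768-dimensional BERT embeddings, and `l2_normalize`, `dot`, and `cosine` are illustrative helpers rather than part of this notebook), the dot product of two unit vectors equals their cosine similarity:

```python
import math

def l2_normalize(v):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def dot(u, v):
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def cosine(u, v):
    """Cosine similarity computed directly from its definition."""
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

# Toy stand-ins for two prompt embeddings (hypothetical values).
a = [3.0, 4.0, 0.0]
b = [1.0, 2.0, 2.0]

unit_a, unit_b = l2_normalize(a), l2_normalize(b)

print("cosine(a, b)          =", cosine(a, b))
print("dot(unit_a, unit_b)   =", dot(unit_a, unit_b))
```

On the real embeddings, the analogous check would be `text_embeds @ text_embeds.T`, a 10×10 matrix of prompt-to-prompt cosine similarities. If its off-diagonal entries turn out to be close to 1, the class prompts are hard to tell apart in BERT's space, which would make the alignment task harder.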


In [5]:
# Cell 4: Define Trainable Projection Head
class ProjectionHead(nn.Module):
    def __init__(self, img_dim=384, text_dim=768, temperature=10.0):
        super().__init__()
        self.proj = nn.Linear(img_dim, text_dim)
        self.temperature = temperature  # Fixed logit scale (acts as an inverse temperature: logits = scale * cosine similarity)

    def forward(self, img_embeds, text_embeds):
        img_proj = self.proj(img_embeds)           # [B, 768]
        img_proj = F.normalize(img_proj, dim=-1)
        logits = img_proj @ text_embeds.T * self.temperature  # [B, 10]
        return logits

model = ProjectionHead().to(device)
print(model)

ProjectionHead(
  (proj): Linear(in_features=384, out_features=768, bias=True)
)
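
Because `temperature` multiplies the cosine similarities, it behaves as a logit scale: larger values sharpen the softmax, smaller values flatten it. The sketch below illustrates this with hypothetical similarity values (not measured from this notebook) in which the correct class is only slightly ahead of the other nine. `softmax` and `cross_entropy` are small illustrative helpers.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log-probability assigned to the target class."""
    return -math.log(softmax(logits)[target])

# Hypothetical cosine similarities for one image against 10 prompts:
# the correct class (index 0) is only slightly ahead of the rest.
cos_sims = [0.32] + [0.30] * 9

for scale in (1.0, 10.0, 100.0):
    logits = [scale * s for s in cos_sims]
    p_correct = softmax(logits)[0]
    loss = cross_entropy(logits, 0)
    print(f"scale={scale:>5}: p(correct)={p_correct:.3f}, loss={loss:.4f}")
```

With these made-up numbers, a scale of 10 gives the correct class a probability of roughly 0.12, while a scale of 100 gives roughly 0.45. If the real similarities cluster this tightly, a fixed scale of 10 leaves little gradient signal, so a larger or learnable scale may be worth trying.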


In [6]:
# Cell 5: Load CIFAR-10 Datasets (ToTensor → tensors in [0,1])
transform = transforms.ToTensor()  # Images become [3, 32, 32] float tensors

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset  = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Optional: Use a subset of training data for faster experimentation
# Uncomment the lines below for quick runs (~few minutes)
# indices = torch.randperm(len(train_dataset))[:10000]
# train_dataset = Subset(train_dataset, indices)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

print(f"Train size: {len(train_dataset)} | Test size: {len(test_dataset)}")

Train size: 50000 | Test size: 10000




In [7]:
# Cell 6: Evaluation Function (Batched, uses tensors directly)
@torch.no_grad()
def evaluate(loader):
    all_preds = []
    all_labels = []

    model.eval()
    for imgs, labels in tqdm(loader, desc="Evaluating"):
        imgs = imgs.to(device)
        labels = labels.to(device)

        inputs = vit_processor(images=imgs, return_tensors="pt").to(device)
        outputs = vit_model(**inputs)
        img_embeds = outputs.last_hidden_state[:, 0, :]  # [B, 384]

        logits = model(img_embeds, text_embeds)
        preds = logits.argmax(dim=-1)

        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    return acc

In [8]:
# Cell 7: Zero-Shot with Random Weights
print("Running zero-shot with random projection...")
acc_random = evaluate(test_loader)
print(f"Zero-shot accuracy (random weights): {acc_random:.3%}")
# Expected: ~10% (1/10 chance)

Running zero-shot with random projection...


It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Zero-shot accuracy (random weights): 10.000%
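
The rescaling warning above deserves attention. `ToTensor()` already maps pixels to [0, 1], and the processor's default rescaling multiplies them by 1/255 a second time before normalizing. The sketch below reproduces that arithmetic for a single pixel. `preprocess_pixel` is a hypothetical helper, and the mean and standard deviation of 0.5 are assumptions; the real values can be read from `vit_processor.image_mean` and `vit_processor.image_std`.

```python
def preprocess_pixel(value, rescale_factor=1 / 255, mean=0.5, std=0.5, do_rescale=True):
    """Mimic rescale-then-normalize for a single pixel value.

    mean=0.5 and std=0.5 are assumed for illustration only.
    """
    if do_rescale:
        value = value * rescale_factor
    return (value - mean) / std

# ToTensor() yields values in [0, 1]; take the darkest and brightest pixel.
darkest, brightest = 0.0, 1.0

# Default behaviour: the [0, 1] input is divided by 255 again.
double_scaled = (preprocess_pixel(darkest), preprocess_pixel(brightest))

# With rescaling disabled, the [0, 1] input is normalized directly.
single_scaled = (preprocess_pixel(darkest, do_rescale=False),
                 preprocess_pixel(brightest, do_rescale=False))

print("rescaled twice:", double_scaled)
print("rescaled once: ", single_scaled)
```

Under these assumptions, every double-rescaled pixel lands between -1.0 and about -0.992, so all images look nearly identical to the ViT. With rescaling disabled the full [-1, 1] range is used. The warning's own suggestion corresponds to passing `do_rescale=False` to the `vit_processor(...)` calls in `evaluate` and in the training loop; the accuracies recorded in this notebook were produced without that change. An accuracy of exactly 10.000% is also what predicting a single class for every image would score on CIFAR-10's balanced test set (1,000 images per class), which would be consistent with near-identical inputs.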


In [10]:
# Cell 8: Train the Projection Head
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model.train()

print("Starting training (5 epochs)...")
for epoch in range(5):
    losses = []
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/5"):
        imgs = imgs.to(device)
        labels = labels.to(device)

        inputs = vit_processor(images=imgs, return_tensors="pt").to(device)
        outputs = vit_model(**inputs)
        img_embeds = outputs.last_hidden_state[:, 0, :]  # [B, 384]

        logits = model(img_embeds, text_embeds)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    print(f"Epoch {epoch+1} - Avg loss: {np.mean(losses):.4f}")

# Final evaluation
acc_trained = evaluate(test_loader)
print(f"Accuracy after training: {acc_trained:.3%}")
# Chance level is 10%; accuracy well above that would indicate the projection has learned a useful alignment

Starting training (5 epochs)...
Epoch 1 - Avg loss: 2.3005
Epoch 2 - Avg loss: 2.2964
Epoch 3 - Avg loss: 2.2817
Epoch 4 - Avg loss: 2.2665
Epoch 5 - Avg loss: 2.2508
Accuracy after training: 15.700%
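
One way to read the training log is to compare each epoch's average loss with the cross-entropy of a uniform guess over 10 classes, which is ln(10). The short calculation below does this for the losses recorded above. It is a reading aid only and does not re-run any training.

```python
import math

num_classes = 10
chance_loss = math.log(num_classes)  # cross-entropy of a uniform prediction

# Average training losses copied from the log above.
epoch_losses = [2.3005, 2.2964, 2.2817, 2.2665, 2.2508]

print(f"Chance-level loss: {chance_loss:.4f}")
for epoch, loss in enumerate(epoch_losses, start=1):
    gap = chance_loss - loss
    print(f"Epoch {epoch}: loss={loss:.4f}, below chance level by {gap:.4f}")
```

The loss starts almost exactly at the chance level and falls by only about 0.05 over five epochs, which matches the modest 15.7% test accuracy: the projection is learning, but slowly.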
