Set up import packages

# Zero-shot Classification on CIFAR-10 using ViT-Small + BERT

This notebook implements a CLIP-style zero-shot classifier:
- Image encoder: ViT-Small (`WinKawaks/vit-small-patch16-224`)
- Text encoder: BERT (`bert-base-uncased`)
- We freeze both encoders and learn simple projection heads + temperature.

Note: This is not official CLIP (which uses its own trained encoders). Expect random-init accuracy ~10-15%, and after light training ~40-65% depending on prompts and epochs.

In [5]:
!pip install -q transformers datasets torch torchvision accelerate

In [1]:
import torch
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "NO GPU")

!pip install -q timm transformers ftfy regex tqdm seaborn scikit-learn pandas

GPU: Tesla T4
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
!pip install -q timm transformers peft scikit-learn

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    ViTImageProcessor,
    ViTModel,
    BertTokenizer,
    BertModel
)
from datasets import load_dataset

import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score

Set device and load models (Code cell)

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Image encoder: ViT-Small
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
vit_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224").to(device)

# Text encoder: BERT
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)

# Freeze both encoders
for p in vit_model.parameters():
    p.requires_grad = False
for p in bert_model.parameters():
    p.requires_grad = False

print("ViT-Small output dim:", vit_model.config.hidden_size)   # 384
print("BERT output dim:", bert_model.config.hidden_size)       # 768

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

ViT-Small output dim: 384
BERT output dim: 768


Define the projection model

In [8]:
class ZeroShotCLIP(nn.Module):
    def __init__(self, image_dim=384, text_dim=768, proj_dim=512):
        super().__init__()
        self.img_proj = nn.Linear(image_dim, proj_dim)
        self.txt_proj = nn.Linear(text_dim, proj_dim)
        self.temperature = nn.Parameter(torch.tensor(0.07))

    def forward(self, image_embeds, text_embeds):
        img_feat = self.img_proj(image_embeds)
        txt_feat = self.txt_proj(text_embeds)

        img_feat = F.normalize(img_feat, dim=-1)
        txt_feat = F.normalize(txt_feat, dim=-1)

        logits = (img_feat @ txt_feat.T) * self.temperature.exp()
        return logits

model = ZeroShotCLIP(proj_dim=512).to(device)
print(model)

ZeroShotCLIP(
  (img_proj): Linear(in_features=384, out_features=512, bias=True)
  (txt_proj): Linear(in_features=768, out_features=512, bias=True)
)


Define CIFAR-10 classes and prompts

In [9]:
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Simple prompt template
class_prompts = [f"a photo of a {cls}" for cls in cifar10_classes]
print(class_prompts)

['a photo of a airplane', 'a photo of a automobile', 'a photo of a bird', 'a photo of a cat', 'a photo of a deer', 'a photo of a dog', 'a photo of a frog', 'a photo of a horse', 'a photo of a ship', 'a photo of a truck']


Compute text embeddings (once)

In [10]:
@torch.no_grad()
def get_text_embeddings(prompts):
    inputs = bert_tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
    outputs = bert_model(**inputs)
    return outputs.last_hidden_state[:, 0, :]  # CLS token

print("Encoding class prompts...")
text_embeds = get_text_embeddings(class_prompts)
print("Text embeddings shape:", text_embeds.shape)  # [10, 768]

Encoding class prompts...
Text embeddings shape: torch.Size([10, 768])


Load CIFAR-10 test set

In [15]:
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Define a basic transform (adjust as needed)
transform = transforms.Compose([
    transforms.ToTensor(),
    # Optional: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) for normalized data
])

# Load the test split (train=False)
dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

print(dataset)
print(len(dataset))  # Should print 10000

100%|██████████| 170M/170M [00:02<00:00, 60.5MB/s]


Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )
10000


Zero-shot evaluation function


In [17]:
from torch.utils.data import DataLoader

@torch.no_grad()
def evaluate_zero_shot(batch_size=128):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    all_preds = []
    all_labels = []

    for imgs, labels in tqdm(loader, desc="Zero-shot evaluation"):
        inputs = vit_processor(images=imgs, return_tensors="pt").to(device)
        outputs = vit_model(**inputs)
        img_embeds = outputs.last_hidden_state[:, 0, :]  # [B, D]

        logits = model(img_embeds, text_embeds)  # assume model handles batched img_embeds
        preds = logits.argmax(dim=-1).cpu().tolist()

        all_preds.extend(preds)
        all_labels.extend(labels.tolist())

    acc = accuracy_score(all_labels, all_preds)
    return acc

In [18]:
@torch.no_grad()
def evaluate_zero_shot():
    all_preds = []
    all_labels = []

    for example in tqdm(dataset, desc="Zero-shot evaluation"):
        img, label = example          # ← Fixed: unpack the (tensor, int) tuple returned by CIFAR10

        inputs = vit_processor(images=img, return_tensors="pt").to(device)
        outputs = vit_model(**inputs)
        img_embed = outputs.last_hidden_state[:, 0, :]  # CLS token

        logits = model(img_embed, text_embeds)
        pred = logits.argmax(dim=-1).item()

        all_preds.append(pred)
        all_labels.append(label)

    acc = accuracy_score(all_labels, all_preds)
    return acc

# Run zero-shot (random projections)
print("Running zero-shot with random weights...")
acc_random = evaluate_zero_shot()
print(f"Zero-shot accuracy (random): {acc_random:.3%}")

Running zero-shot with random weights...


Zero-shot evaluation:   0%|          | 0/10000 [00:00<?, ?it/s]

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


Zero-shot accuracy (random): 10.000%


Quick training of projections

In [20]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms

# Load training set consistently with torchvision (same as your test set)
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())

optimizer = AdamW(model.parameters(), lr=5e-4)
model.train()

print("Starting quick training on a subset...")
for epoch in range(5):
    # Create a random subset of 10,000 examples (reshuffled each epoch)
    indices = torch.randperm(len(train_dataset))[:10000]
    subset = torch.utils.data.Subset(train_dataset, indices)

    # Batched DataLoader for much faster training
    loader = DataLoader(subset, batch_size=64, shuffle=False, pin_memory=True)

    losses = []
    for imgs, labels in tqdm(loader, desc=f"Epoch {epoch+1}/5"):
        # Move batch to device (imgs: [B, 3, 32, 32] float tensor in [0, 1])
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Process the entire batch at once
        inputs = vit_processor(images=imgs, return_tensors="pt").to(device)
        outputs = vit_model(**inputs)
        img_embeds = outputs.last_hidden_state[:, 0, :]  # [B, D] CLS tokens

        # Forward through your trainable model (assumes it handles batched img_embeds)
        logits = model(img_embeds, text_embeds)  # shape: [B, num_classes]

        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    print(f"Epoch {epoch+1} - avg loss: {np.mean(losses):.4f}")

# Switch back to eval mode and re-evaluate
model.eval()
acc_trained = evaluate_zero_shot()
print(f"Zero-shot accuracy after training: {acc_trained:.3%}")

Starting quick training on a subset...


Epoch 1/5: 100%|██████████| 157/157 [01:07<00:00,  2.34it/s]


Epoch 1 - avg loss: 2.3031


Epoch 2/5: 100%|██████████| 157/157 [01:07<00:00,  2.34it/s]


Epoch 2 - avg loss: 2.3028


Epoch 3/5: 100%|██████████| 157/157 [01:04<00:00,  2.43it/s]


Epoch 3 - avg loss: 2.3025


Epoch 4/5: 100%|██████████| 157/157 [01:10<00:00,  2.22it/s]


Epoch 4 - avg loss: 2.3027


Epoch 5/5: 100%|██████████| 157/157 [01:04<00:00,  2.42it/s]


Epoch 5 - avg loss: 2.3025


Zero-shot evaluation: 100%|██████████| 10000/10000 [01:46<00:00, 93.81it/s]

Zero-shot accuracy after training: 11.510%



