
This notebook demonstrates CLIP for zero-shot image classification and for fine-tuning on the CIFAR-10 dataset. First, we use CLIP's pre-trained image and text embeddings for zero-shot classification: an image is classified without any additional training by matching it against a set of text labels. Next, we adapt CLIP to CIFAR-10 by freezing the pre-trained weights, adding a small 10-class classifier head on top of the 512-dimensional image features, and training that head with cross-entropy loss. The notebook covers dataset preparation, model training, saving the fine-tuned model, and prediction on a test image. By the end, you will have seen CLIP used both for zero-shot tasks and for supervised fine-tuning.



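**Zero-Shot Classification:** Demonstrates how pre-trained self-supervised learning (SSL) models can generalize to tasks they were not explicitly trained on.

**Fine-Tuning:** Shows how SSL models like CLIP can be adapted to a specific downstream task by training a small classifier head on frozen features.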
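
Before loading the real model, the zero-shot matching step can be spelled out as plain arithmetic: cosine similarity between one image embedding and each label embedding, scaled and passed through a softmax. The sketch below is illustrative only. The 3-dimensional toy embeddings and the names `cosine`, `zero_shot_probs`, and `toy_*` are made up for this example, and `logit_scale=100.0` is an assumed stand-in for CLIP's learned temperature.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def zero_shot_probs(image_emb, text_embs, logit_scale=100.0):
    """Softmax over scaled cosine similarities (one probability per label)."""
    logits = [logit_scale * cosine(image_emb, t) for t in text_embs]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Toy embeddings (hypothetical): each label points along its own axis.
toy_labels = ["cat", "dog", "ship"]
toy_text_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
toy_image_emb = [0.1, 0.9, 0.2]  # closest in direction to "dog"

toy_probs = zero_shot_probs(toy_image_emb, toy_text_embs)
toy_prediction = toy_labels[toy_probs.index(max(toy_probs))]
print(toy_prediction)
```

In the real model the embeddings come from CLIP's image and text encoders, and the scaled similarities are what the model returns as `logits_per_image`.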

In [None]:
# Install necessary libraries
!pip install torch torchvision transformers datasets




In [None]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel

# ---------------------------
# 1. Load CLIP Model and Processor
# ---------------------------
print("Loading CLIP model...")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Set device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ---------------------------
# 2. Define Class Names for Zero-Shot Prediction
# ---------------------------
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# ---------------------------
# 3. Load and Preprocess Image
# ---------------------------
image_path = "example.jpg"  # Path to your image file
image = Image.open(image_path).convert("RGB")  # Ensure image is RGB

# ---------------------------
# 4. Perform Zero-Shot Prediction
# ---------------------------
print("Performing zero-shot prediction...")

# Preprocess the image and text inputs
inputs = processor(text=class_names, images=image, return_tensors="pt", padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Forward pass through CLIP model
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # Similarity scores
    probs = logits_per_image.softmax(dim=1)      # Convert to probabilities

# Get the predicted class
predicted_class_idx = probs.argmax(dim=1).item()
print(f"Predicted class: {class_names[predicted_class_idx]} (Probability: {probs[0][predicted_class_idx]:.4f})")


Loading CLIP model...
Performing zero-shot prediction...
Predicted class: dog (Probability: 0.9587)
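
The cell above passes bare class names as the text inputs. A common variation is to wrap each label in a short natural-language template such as "a photo of a ...", which is closer to the captions CLIP was trained on. The sketch below only builds the prompt strings. The helper `build_prompts` and the article heuristic are assumptions for illustration, and whether the template changes the prediction for a given image is not guaranteed.

```python
def build_prompts(labels, template="a photo of {article} {label}"):
    """Wrap each label in a caption-like prompt, choosing 'a' or 'an' by first letter."""
    prompts = []
    for label in labels:
        article = "an" if label[0].lower() in "aeiou" else "a"
        prompts.append(template.format(article=article, label=label))
    return prompts

cifar_labels = ["airplane", "automobile", "bird", "cat", "deer",
                "dog", "frog", "horse", "ship", "truck"]
prompts = build_prompts(cifar_labels)
print(prompts[:3])

# The prompts can replace `class_names` in the processor call, for example:
# inputs = processor(text=prompts, images=image, return_tensors="pt", padding=True)
# The predicted index still maps back to `cifar_labels`.
```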


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from PIL import Image
import random

# ---------------------------
# 1. Load CIFAR-10 Dataset and Preprocess
# ---------------------------
print("Loading CIFAR-10 dataset...")

# Load CIFAR-10 dataset
dataset = load_dataset("cifar10")

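# Transformations: Resize, ToTensor, Normalize
# Normalize with the channel-wise mean/std CLIP was pre-trained with
# (the same values CLIPProcessor applies by default), not a generic 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])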

# Preprocess function to ensure tensor consistency
def preprocess(example):
    if not isinstance(example["img"], Image.Image):
        example["img"] = Image.fromarray(example["img"])  # Convert to PIL Image
    tensor_image = transform(example["img"])
    example["image"] = tensor_image  # Ensure it's a tensor
    return example

# Subset selection
train_indices = random.sample(range(len(dataset["train"])), 1000)
test_indices = random.sample(range(len(dataset["test"])), 200)

train_dataset = dataset["train"].select(train_indices).map(preprocess, batched=False)
test_dataset = dataset["test"].select(test_indices).map(preprocess, batched=False)

# Collate function with strict tensor enforcement
def collate_fn(batch):
    images = []
    labels = []
    for item in batch:
        # Ensure 'image' is a tensor
        img = item["image"]
        if not torch.is_tensor(img):
            img = torch.tensor(img, dtype=torch.float32)
        images.append(img)
        labels.append(item["label"])
    return torch.stack(images), torch.tensor(labels, dtype=torch.long)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# ---------------------------
# 2. Load CLIP Model and Modify for Fine-Tuning
# ---------------------------
print("Loading CLIP model...")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Freeze all CLIP parameters and add a new classifier head
for param in model.parameters():
    param.requires_grad = False

num_classes = 10
model.classifier = nn.Sequential(
    nn.Linear(512, 256),  # Correct input size is 512
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)

# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# ---------------------------
# 3. Fine-Tuning Setup
# ---------------------------
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=5e-5)

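# Fine-tuning loop
print("Starting fine-tuning...")
model.eval()              # keep the frozen CLIP backbone in inference mode
model.classifier.train()  # enable dropout only in the trainable head
num_epochs = 3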

for epoch in range(num_epochs):
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass (the backbone is frozen, so skip building its autograd graph)
        with torch.no_grad():
            features = model.get_image_features(pixel_values=images)
        logits = model.classifier(features)
        loss = criterion(logits, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = logits.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}, "
          f"Train Accuracy: {100 * correct / total:.2f}%")

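# ---------------------------
# 4. Save Fine-Tuned Model
# ---------------------------
print("Saving fine-tuned model...")
model.save_pretrained("clip_finetuned_cifar10")
processor.save_pretrained("clip_finetuned_cifar10")
# The classifier head is not part of the CLIPModel architecture, so
# from_pretrained() will not restore it. Save its weights separately.
torch.save(model.classifier.state_dict(), "clip_finetuned_cifar10/classifier_head.pt")
print("Fine-tuned model saved!")

# ---------------------------
# 5. Predict on a Test Image
# ---------------------------
print("Performing prediction on a test image...")

# Reload the fine-tuned model, then re-create the head and load its weights
model = CLIPModel.from_pretrained("clip_finetuned_cifar10")
model.classifier = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, num_classes)
)
model.classifier.load_state_dict(
    torch.load("clip_finetuned_cifar10/classifier_head.pt", map_location=device)
)
model.to(device)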

# CIFAR-10 class names
cifar10_classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# Load and preprocess a test image
image_path = "example.jpg"  # Replace with your image path
image = Image.open(image_path).convert("RGB")
image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension

# Perform prediction
model.eval()
with torch.no_grad():
    features = model.get_image_features(image_tensor)
    logits = model.classifier(features)
    probabilities = torch.softmax(logits, dim=1)
    predicted_class_idx = torch.argmax(probabilities, dim=1).item()

print(f"Predicted class: {cifar10_classes[predicted_class_idx]} "
      f"(Probability: {probabilities[0][predicted_class_idx]:.4f})")


Loading CIFAR-10 dataset...



Loading CLIP model...
Starting fine-tuning...
Epoch [1/3], Loss: 2.2773, Train Accuracy: 16.40%
Epoch [2/3], Loss: 2.1925, Train Accuracy: 33.60%
Epoch [3/3], Loss: 2.1027, Train Accuracy: 50.90%
Saving fine-tuned model...
Fine-tuned model saved!
Performing prediction on a test image...


Some weights of the model checkpoint at clip_finetuned_cifar10 were not used when initializing CLIPModel: ['classifier.0.bias', 'classifier.0.weight', 'classifier.3.bias', 'classifier.3.weight']
- This IS expected if you are initializing CLIPModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


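AttributeError: 'CLIPModel' object has no attribute 'classifier'

This error occurs when the model is reloaded with `CLIPModel.from_pretrained` alone. That call rebuilds only the standard CLIP architecture, so the saved `classifier.*` weights are ignored (as the warning above reports) and the reloaded model has no `classifier` attribute. The head must be re-created and its weights loaded before `model.classifier` is called.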
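
The notebook builds a `test_loader` but never measures accuracy on it. A minimal sketch of a held-out evaluation loop follows. The `evaluate` helper and the `toy_*` stand-ins are hypothetical and exist only so the block runs on its own; they are not part of CLIP or of the cells above.

```python
import torch
from torch import nn

@torch.no_grad()
def evaluate(feature_fn, classifier, loader, device="cpu"):
    """Return top-1 accuracy (0..1) of classifier(feature_fn(images)) over a loader."""
    classifier.eval()  # disable dropout in the head during evaluation
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = classifier(feature_fn(images))
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / max(total, 1)

# Tiny self-contained check with stand-in components (not CLIP):
toy_features = lambda x: x.flatten(1)  # plays the role of get_image_features
toy_head = nn.Linear(4, 2)
with torch.no_grad():
    # Hand-set weights so class 0 reads feature 0 and class 1 reads feature 1
    toy_head.weight.copy_(torch.tensor([[1.0, 0.0, 0.0, 0.0],
                                        [0.0, 1.0, 0.0, 0.0]]))
    toy_head.bias.zero_()
toy_batches = [(
    torch.tensor([[[2.0, 0.0], [0.0, 0.0]],
                  [[0.0, 3.0], [0.0, 0.0]]]),  # two 2x2 "images"
    torch.tensor([0, 1]),
)]
toy_accuracy = evaluate(toy_features, toy_head, toy_batches)
print(f"Toy accuracy: {toy_accuracy:.2f}")
```

With the notebook's own objects, and assuming the classifier head is attached to the model, the call would look like `evaluate(lambda x: model.get_image_features(pixel_values=x), model.classifier, test_loader, device)`. Reporting this number alongside the training accuracy helps show whether the head is overfitting the 1,000-image training subset.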