# Retraining CLIP

## **IMPORTANT NOTE**

This dataset was too large to train on in full. I used a random sample of 50,000 images for training and 10,000 for testing (both drawn from the COCO 2017 training split), plus the roughly 5,000 COCO 2017 validation images for validation. The full dataset has about 312,000 images.

## Set Up

Load all the packages we need.

In [None]:
# Install necessary packages
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install scipy

In [None]:
import random
import clip
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
from torchvision import datasets
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import shutil
import packaging.version  # `import packaging` alone does not guarantee that `packaging.version` is loaded
import kagglehub
import pandas as pd
import json

In [None]:
# Check the PyTorch version
# (the CLIP repository recommends PyTorch 1.7.1 or later)
version = packaging.version.parse(torch.__version__)
if version > packaging.version.parse('1.7.0'):
    print("Pytorch version is above 1.7.0")
    print("It is version:", version)
else:
    print("PyTorch version is not above 1.7.0. Please upgrade.")

Pytorch version is above 1.7.0
It is version: 2.6.0+cu124


Load the CLIP model

In [None]:
# Load CLIP model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, preprocess = clip.load("ViT-B/32", device=device)
model = model.float()  # on GPU, clip.load returns fp16 weights; cast to fp32 for stable fine-tuning

100%|███████████████████████████████████████| 338M/338M [00:03<00:00, 93.8MiB/s]


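### Unfreeze the visual encoder

Freezing a layer (setting `requires_grad = False` on its parameters) means its weights are not updated during training, which preserves the features learned during pre-training. If we want the model to adapt to our new data, the relevant layers must be trainable, or "unfrozen". A model loaded with `clip.load` already has `requires_grad = True` on every parameter, so the loop below mainly makes the choice explicit: it keeps the visual encoder trainable and freezes the text encoder. Which weights actually change also depends on which parameters are passed to the optimizer.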

In [None]:
for name, param in model.named_parameters():
    # This loop goes through every parameter (weight/bias) in the CLIP model.
    # `name` is a string describing which layer the parameter belongs to.
    # `param` is the actual parameter tensor (a PyTorch object containing weights).

    if "visual" in name:
        # Keep layers in the "visual" part of the model trainable.
        # CLIP has two main parts: a visual encoder (for images) and a text encoder (for text).
        # We only want to modify the visual encoder.

        param.requires_grad = True
        # This tells PyTorch: "Yes, this parameter should be updated during training."

    else:
        param.requires_grad = False
        # Freeze everything else (the text encoder). Any parameter with
        # `requires_grad = False` receives no gradient during backpropagation.

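One way to confirm which parameters will actually be trained is to count those with `requires_grad=True`. The sketch below uses a tiny hypothetical `ToyDualEncoder` whose top-level names mimic CLIP's (`visual` for the image tower, `transformer` and `text_projection` for the text side), so it runs without downloading any weights. The helpers `train_only` and `count_params` are illustrative, not part of the original notebook.

```python
import torch
import torch.nn as nn


class ToyDualEncoder(nn.Module):
    """Hypothetical stand-in that reuses CLIP's top-level parameter names."""

    def __init__(self):
        super().__init__()
        self.visual = nn.Linear(8, 4)        # image tower: 8*4 + 4 = 36 parameters
        self.transformer = nn.Linear(8, 4)   # text tower: 36 parameters
        self.text_projection = nn.Parameter(torch.zeros(4, 4))  # 16 parameters


def train_only(model, keyword="visual"):
    # Same rule as the loop above: trainable iff the parameter name contains `keyword`.
    for name, param in model.named_parameters():
        param.requires_grad = keyword in name


def count_params(model):
    # Returns (number of trainable parameters, total number of parameters).
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


toy = ToyDualEncoder()
before = count_params(toy)  # freshly built modules: everything is trainable
train_only(toy)
after = count_params(toy)   # only the `visual` parameters remain trainable
```

On the real notebook, calling `count_params(model)` after the loop would give the same kind of check for CLIP.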

### Define the linear classification head

This is a minimal classification head: a single fully connected (linear) layer, rather than a multi-layer perceptron. It maps the output of CLIP's image encoder to one score per class. Think of it as the final decision layer that says: "I think this image is class X."

In [None]:
class LinearClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        # Constructor for the class. Called when we create an instance of LinearClassifier.
        # `input_dim` is the size of the input features (from the CLIP image encoder: 512).
        # `num_classes` is the number of categories we want to classify

        super(LinearClassifier, self).__init__()

        self.fc = nn.Linear(input_dim, num_classes)
        # This creates the linear (fully connected) layer.
        # It takes a vector of size `input_dim`
        # and outputs a vector of size `num_classes` containing one raw score
        # (logit) per class; a higher score means the model favors that class more.

    def forward(self, image_features):
        # This function defines how the data flows through the model during forward propagation.
        # It runs whenever the module is called, e.g. `classifier(image_features)`, in training and inference.

        return self.fc(image_features)
        # The output is a set of raw scores (logits) for each class.
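To make the input and output shapes concrete, the sketch below feeds a batch of four random 512-dimensional vectors through the same kind of layer that `LinearClassifier` wraps (`nn.Linear(512, 80)`). The random, normalized vectors are an assumption standing in for real CLIP image features. Softmax turns the raw logits into probabilities without changing which class scores highest.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
head = nn.Linear(512, 80)  # same layer LinearClassifier wraps: 512 features -> 80 classes

# Stand-in for a batch of 4 normalized CLIP image embeddings.
features = F.normalize(torch.randn(4, 512), dim=-1)

with torch.no_grad():
    logits = head(features)            # raw scores, shape (4, 80)
    probs = logits.softmax(dim=-1)     # probabilities; each row sums to 1
    predicted = probs.argmax(dim=-1)   # predicted class index for each image
```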


## **DATA: This is the part you edit**

To rerun this retraining procedure (for example, on a different dataset), this is the **only** section you should need to edit. All necessary changes can be made here; changes elsewhere may affect the model and make results difficult to compare across runs.

### Now apply this to the MS COCO dataset

Load data

In [None]:
# Download latest version
path = kagglehub.dataset_download("sabahesaraki/2017-2017")
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/sabahesaraki/2017-2017?dataset_version_number=1...


100%|██████████| 43.7G/43.7G [05:05<00:00, 154MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/sabahesaraki/2017-2017/versions/1


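Set the image directories. Only the train and validation folders are used below: the COCO 2017 test images have no public annotations, so the test set is sampled from the training images instead.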

In [None]:
dataset_root = os.path.join(path, 'train2017', 'train2017')
val_root = os.path.join(path, 'val2017', 'val2017')
test_root = os.path.join(path, 'test2017', 'test2017')

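Labels: COCO category IDs mapped to class names. There are 80 classes, but the IDs run from 1 to 90 with gaps, so they cannot be used directly as class indices.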

In [None]:
idx_to_class = {
    "1": "person",
    "2": "bicycle",
    "3": "car",
    "4": "motorcycle",
    "5": "airplane",
    "6": "bus",
    "7": "train",
    "8": "truck",
    "9": "boat",
    "10": "traffic light",
    "11": "fire hydrant",
    "13": "stop sign",
    "14": "parking meter",
    "15": "bench",
    "16": "bird",
    "17": "cat",
    "18": "dog",
    "19": "horse",
    "20": "sheep",
    "21": "cow",
    "22": "elephant",
    "23": "bear",
    "24": "zebra",
    "25": "giraffe",
    "27": "backpack",
    "28": "umbrella",
    "31": "handbag",
    "32": "tie",
    "33": "suitcase",
    "34": "frisbee",
    "35": "skis",
    "36": "snowboard",
    "37": "sports ball",
    "38": "kite",
    "39": "baseball bat",
    "40": "baseball glove",
    "41": "skateboard",
    "42": "surfboard",
    "43": "tennis racket",
    "44": "bottle",
    "46": "wine glass",
    "47": "cup",
    "48": "fork",
    "49": "knife",
    "50": "spoon",
    "51": "bowl",
    "52": "banana",
    "53": "apple",
    "54": "sandwich",
    "55": "orange",
    "56": "broccoli",
    "57": "carrot",
    "58": "hot dog",
    "59": "pizza",
    "60": "donut",
    "61": "cake",
    "62": "chair",
    "63": "couch",
    "64": "potted plant",
    "65": "bed",
    "67": "dining table",
    "70": "toilet",
    "72": "tv",
    "73": "laptop",
    "74": "mouse",
    "75": "remote",
    "76": "keyboard",
    "77": "cell phone",
    "78": "microwave",
    "79": "oven",
    "80": "toaster",
    "81": "sink",
    "82": "refrigerator",
    "84": "book",
    "85": "clock",
    "86": "vase",
    "87": "scissors",
    "88": "teddy bear",
    "89": "hair drier",
    "90": "toothbrush"
}

Build the class-index mappings

In [None]:
coco_categories = sorted(idx_to_class.items(), key=lambda x: int(x[0]))  # sort by category ID
idx_to_class_list = [name for _, name in coco_categories]  # 0-based index to name
class_to_idx = {name: i for i, name in enumerate(idx_to_class_list)}     # name -> index
catid_to_idx = {int(coco_id): i for i, (coco_id, name) in enumerate(coco_categories)}  # category_id -> index
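Sorting by `int` matters here. Sorting the string keys lexicographically would put `"13"` before `"2"` and silently reorder the classes. A small sketch with a hypothetical four-class subset of the table shows the difference, along with the non-contiguous-ID to 0-based-index mapping the loss function needs.

```python
# A small subset of the COCO id -> name table (the ids are non-contiguous).
subset = {"1": "person", "2": "bicycle", "13": "stop sign", "90": "toothbrush"}

by_int = sorted(subset.items(), key=lambda x: int(x[0]))  # numeric order: 1, 2, 13, 90
by_str = sorted(subset.items())                           # lexicographic: "1", "13", "2", "90"

names_by_int = [name for _, name in by_int]
names_by_str = [name for _, name in by_str]

# COCO category id -> contiguous 0-based class index (what CrossEntropyLoss expects).
subset_catid_to_idx = {int(cid): i for i, (cid, _) in enumerate(by_int)}
```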

In [None]:
def load_annotations(annotation_path):
    with open(annotation_path, 'r') as f:
        return json.load(f)

class CustomCOCODataset(Dataset):
    def __init__(self, image_dir, annotations, transform=preprocess, image_filenames=None, catid_to_idx=None, idx_to_class=None):
        self.image_dir = image_dir
        self.annotations = annotations  # {filename: category_id}
        self.transform = transform
        self.image_filenames = image_filenames if image_filenames else list(self.annotations.keys())
        self.catid_to_idx = catid_to_idx
        self.idx_to_class = idx_to_class

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir, image_filename)
        image = Image.open(image_path).convert("RGB")

        category_id = int(self.annotations[image_filename])  # original COCO ID
        label_idx = self.catid_to_idx[category_id]  # 0-based index

        if self.transform:
            image = self.transform(image)

        return image, label_idx

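# Map each image filename to a single label (its category ID, stored as a string to match idx_to_class).
# COCO images usually contain several annotated objects. Because later annotations overwrite
# earlier ones, each image keeps the category of its *last* annotation.
# Images with no annotations are left out entirely.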
def build_filename_to_label(ann_json):
    images = {img['id']: img['file_name'] for img in ann_json['images']}
    annotations = ann_json['annotations']

    filename_to_label = {}
    for ann in annotations:
        image_id = ann['image_id']
        category_id = ann['category_id']
        filename = images[image_id]

        # Convert category_id to string to match idx_to_class keys
        filename_to_label[filename] = str(category_id)
    return filename_to_label
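Because `build_filename_to_label` keeps whichever annotation comes last, an image's label depends on annotation order rather than on what dominates the picture. One possible alternative is to pick the category with the largest total annotated area. The sketch below assumes each annotation carries COCO's `area` field, as the `instances_*.json` files do. `build_filename_to_label_by_area` is a hypothetical helper, not part of the original notebook, and switching to it would change the labels and therefore the results. The toy annotation file is invented for illustration.

```python
from collections import defaultdict


def build_filename_to_label_last(ann_json):
    # Same rule as build_filename_to_label above: later annotations overwrite earlier ones.
    id_to_file = {img["id"]: img["file_name"] for img in ann_json["images"]}
    labels = {}
    for ann in ann_json["annotations"]:
        labels[id_to_file[ann["image_id"]]] = str(ann["category_id"])
    return labels


def build_filename_to_label_by_area(ann_json):
    # Hypothetical alternative: label = category with the largest summed `area` in the image.
    id_to_file = {img["id"]: img["file_name"] for img in ann_json["images"]}
    area_per_image = defaultdict(lambda: defaultdict(float))
    for ann in ann_json["annotations"]:
        area_per_image[ann["image_id"]][ann["category_id"]] += ann["area"]
    return {
        id_to_file[image_id]: str(max(cat_areas, key=cat_areas.get))
        for image_id, cat_areas in area_per_image.items()
    }


# Toy annotation file: a.jpg has two people and a small cup, b.jpg one dog, c.jpg nothing.
toy_json = {
    "images": [
        {"id": 1, "file_name": "a.jpg"},
        {"id": 2, "file_name": "b.jpg"},
        {"id": 3, "file_name": "c.jpg"},
    ],
    "annotations": [
        {"image_id": 1, "category_id": 1, "area": 5000.0},
        {"image_id": 1, "category_id": 1, "area": 4000.0},
        {"image_id": 1, "category_id": 47, "area": 300.0},
        {"image_id": 2, "category_id": 18, "area": 2000.0},
    ],
}

last_labels = build_filename_to_label_last(toy_json)     # a.jpg -> "47" (cup)
area_labels = build_filename_to_label_by_area(toy_json)  # a.jpg -> "1" (person)
```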

In [None]:
# Paths to the annotation files
train_annotation_path = os.path.join(path, 'annotations_trainval2017', 'annotations', 'instances_train2017.json')
val_annotation_path = os.path.join(path, 'annotations_trainval2017', 'annotations', 'instances_val2017.json')

# Load the JSON files
train_json = load_annotations(train_annotation_path)
val_json = load_annotations(val_annotation_path)

train_annotations = build_filename_to_label(train_json)
val_annotations = build_filename_to_label(val_json)

# Get image filenames
train_image_filenames = list(train_annotations.keys())
val_image_filenames = list(val_annotations.keys())

In [None]:
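# Split (seeded so the same images are chosen on every run; the seed value is arbitrary):
random.seed(42)
train_sample = random.sample(train_image_filenames, 50000)                # 50,000 train images
remaining_train = sorted(set(train_image_filenames) - set(train_sample))  # leftovers for test, sorted for a stable order
test_sample = random.sample(remaining_train, 10000)                       # 10,000 test images, disjoint from train

val_sample = val_image_filenames  # every validation image with at least one annotation (slightly under 5,000)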

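An alternative way to build the train/test split is to sort the filenames, shuffle them once with a private, seeded `random.Random`, and slice. This makes the split reproducible and disjoint by construction, and it leaves the global `random` state untouched. The helper name `split_filenames`, the seed, and the toy filenames below are illustrative assumptions.

```python
import random


def split_filenames(filenames, n_train, n_test, seed=42):
    """Hypothetical helper: reproducible, disjoint train/test split."""
    if n_train + n_test > len(filenames):
        raise ValueError("not enough filenames for the requested split")
    rng = random.Random(seed)  # private RNG, so the global `random` state is untouched
    pool = sorted(filenames)   # fixed starting order, independent of dict/set ordering
    rng.shuffle(pool)
    return pool[:n_train], pool[n_train:n_train + n_test]


names = [f"{i:012d}.jpg" for i in range(1_000)]
train_part, test_part = split_filenames(names, 600, 200)
```

In the notebook, this would correspond to `split_filenames(train_image_filenames, 50000, 10000)`.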
In [None]:
# Dataset construction
train_dataset = CustomCOCODataset(
    image_dir=dataset_root,
    annotations=train_annotations,
    transform=preprocess,
    image_filenames=train_sample,
    catid_to_idx=catid_to_idx,
    idx_to_class=idx_to_class_list
)

val_dataset = CustomCOCODataset(
    image_dir=val_root,
    annotations=val_annotations,
    transform=preprocess,
    image_filenames=val_sample,
    catid_to_idx=catid_to_idx,
    idx_to_class=idx_to_class_list
)

test_dataset = CustomCOCODataset(
    image_dir=dataset_root,
    annotations=train_annotations,
    transform=preprocess,
    image_filenames=test_sample,
    catid_to_idx=catid_to_idx,
    idx_to_class=idx_to_class_list
)

# Dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
with torch.no_grad():
    all_text_prompts = [f"A photo of a {classname}" for classname in idx_to_class_list]
    tokenized_texts = clip.tokenize(all_text_prompts).to(device)
    text_features_all = model.encode_text(tokenized_texts)
    text_features_all = F.normalize(text_features_all, dim=-1).float()

Build text features for each class prompt (reused later for the test set)

In [None]:
# COCO-style version

# Get the class names from the test dataset (these are in correct order)
class_names_test = idx_to_class_list

print("Class names:", class_names_test)

# Build natural language prompts for each class
all_texts = [f"A photo of a {classname}" for classname in class_names_test]  # Feel free to customize these prompts

# Tokenize and encode with CLIP
with torch.no_grad():
    tokenized_texts = clip.tokenize(all_texts).to(device)
    text_features_all = model.encode_text(tokenized_texts)
    text_features_all = F.normalize(text_features_all, dim=-1).float()

Class names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
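The template `"A photo of a {classname}"` produces ungrammatical prompts for classes such as "airplane" or "orange" ("a airplane"). If you customize the prompts, a small hypothetical helper like `make_prompt` could choose the article. It uses a simple first-letter vowel rule. That is an assumption which happens to suit the 80 COCO names but not English in general (e.g. "an hour"). Whether this changes accuracy is untested here.

```python
def make_prompt(classname, template="A photo of {article} {classname}"):
    # Hypothetical helper: pick "an" before a vowel letter, "a" otherwise.
    # Simple heuristic based on spelling, not pronunciation.
    article = "an" if classname and classname[0].lower() in "aeiou" else "a"
    return template.format(article=article, classname=classname)


prompts = [make_prompt(c) for c in ["person", "airplane", "orange", "umbrella"]]
```

In the notebook, this would replace the list comprehension, e.g. `all_texts = [make_prompt(c) for c in class_names_test]`.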


## **OK, STOP EDITING HERE**

The rest of this file should work without edits, as long as you didn't change any variable names.

## Let's Retrain

In [None]:
classifier = LinearClassifier(input_dim=512, num_classes=len(class_names_test)).to(device)
# Initializes the classifier we defined earlier (512 is the CLIP ViT-B/32 embedding size)

optimizer = torch.optim.AdamW([
    {"params": model.visual.parameters(), "lr": 1e-6},
    {"params": classifier.parameters(), "lr": 1e-4}
], weight_decay=1e-4)
# We're training two parts:
# 1) model.visual: the vision encoder from CLIP. We fine-tune it very gently using a small learning rate (1e-6)
# 2) classifier: our new linear layer. It starts from scratch, so we train it more aggressively (1e-4)
# AdamW (Adam with decoupled weight decay) is a common optimizer for fine-tuning

criterion = nn.CrossEntropyLoss()
# Cross-entropy compares the predicted scores (logits) against the true label and penalizes wrong guesses.

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
# This smoothly lowers the learning rate along a cosine curve over T_max=10 epochs (matching num_epochs).
# This is a common trick to make training more stable and avoid overshooting the minimum loss.


num_epochs = 10
# how many times we loop through the whole dataset

best_val_acc = 0
# keeps track of the best accuracy we've seen so far

patience = 3
# For early stopping: we stop training if validation accuracy doesn't improve for 3 straight epochs.
# This saves training time and helps prevent overfitting.

epochs_no_improve = 0
# how many times we've failed to beat our best accuracy

for epoch in range(num_epochs):

    classifier.train()
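    # Put the classifier head in training mode. The CLIP model stays in the eval mode
    # that clip.load sets; ViT-B/32 has no dropout or batch-norm layers, so this does
    # not change its outputs, and its weights are still updated by the optimizer.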

    total_loss, correct, total = 0, 0, 0

    # Training ##################################################

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)"):

        images, labels = images.to(device), labels.to(device)
        image_features = model.encode_image(images).float()
        # Use CLIP’s vision model to encode the images into 512-dimension feature vectors
        image_features = F.normalize(image_features, dim=-1)
        # Normalize them (unit length) so comparisons (dot products) behave like cosine similarity

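        with torch.no_grad():
            clip_logits = image_features @ text_features_all.T  # (B, num_classes)
        # Dot product between image and text features: cosine-similarity scores
        # (logits) between each image and every class prompt. Because this runs under
        # no_grad, this zero-shot branch contributes no gradient; the visual encoder
        # is trained only through the classifier branch.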

        classifier_logits = classifier(image_features)
        # our classifier’s own guess — based on its trained weights

        clip_logits = clip_logits / clip_logits.norm(dim=-1, keepdim=True)
        classifier_logits = classifier_logits / classifier_logits.norm(dim=-1, keepdim=True)
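        # Rescale each logit vector to unit L2 norm so both branches are on the same scale.
        # This also squeezes every logit into [-1, 1], so the softmax inside the
        # cross-entropy stays fairly flat and the loss cannot get close to 0.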

        blended_logits = 0.5 * classifier_logits + 0.5 * clip_logits
        # average the scores from CLIP and our linear classifier

        loss = criterion(blended_logits, labels)
        # Calculate the loss from the blended prediction

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Clear old gradients, backpropagate new ones, and take an optimizer step

        total_loss += loss.item()
        correct += (blended_logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
        # Accumulate the loss (summed over batches, not averaged) and count correct predictions

    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1}: Train Loss = {total_loss:.4f}, Train Acc = {train_acc:.2f}%")

    # Validation ################################################

    classifier.eval()
    # Switch the classifier head to evaluation mode

    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Val)"):
            images, labels = images.to(device), labels.to(device)
            image_features = model.encode_image(images).float()
            image_features = F.normalize(image_features, dim=-1)

            clip_logits = image_features @ text_features_all.T  # already inside torch.no_grad()
            classifier_logits = classifier(image_features)
            clip_logits = clip_logits / clip_logits.norm(dim=-1, keepdim=True)
            classifier_logits = classifier_logits / classifier_logits.norm(dim=-1, keepdim=True)
            logits = 0.5 * classifier_logits + 0.5 * clip_logits
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    val_acc = 100 * correct / total
    print(f"Epoch {epoch+1}: Val Acc = {val_acc:.2f}%")
    # Count correct predictions to compute validation accuracy

    # Early Stopping #############################################

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        epochs_no_improve = 0
        torch.save(classifier.state_dict(), 'best_linear_classifier.pth')
        print("Improved validation accuracy. Saved model.")
    # If we beat our best validation accuracy, save the classifier head.
    # Only the head's weights are saved here; the fine-tuned visual encoder is not.

    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping.")
            break
    # If we’ve gone patience epochs with no improvement, stop training early

    scheduler.step()
    # Move along the cosine schedule: lower the learning rate a bit


Epoch 1/10 (Train): 100%|██████████| 1563/1563 [09:18<00:00,  2.80it/s]


Epoch 1: Train Loss = 6320.2912, Train Acc = 63.92%


Epoch 1/10 (Val): 100%|██████████| 155/155 [00:45<00:00,  3.38it/s]


Epoch 1: Val Acc = 52.26%
Improved validation accuracy. Saved model.


Epoch 2/10 (Train): 100%|██████████| 1563/1563 [09:09<00:00,  2.85it/s]


Epoch 2: Train Loss = 6225.4611, Train Acc = 73.71%


Epoch 2/10 (Val): 100%|██████████| 155/155 [00:45<00:00,  3.40it/s]


Epoch 2: Val Acc = 51.07%


Epoch 3/10 (Train): 100%|██████████| 1563/1563 [09:08<00:00,  2.85it/s]


Epoch 3: Train Loss = 6191.5421, Train Acc = 78.06%


Epoch 3/10 (Val): 100%|██████████| 155/155 [00:45<00:00,  3.41it/s]


Epoch 3: Val Acc = 51.60%


Epoch 4/10 (Train): 100%|██████████| 1563/1563 [09:08<00:00,  2.85it/s]


Epoch 4: Train Loss = 6168.4263, Train Acc = 80.86%


Epoch 4/10 (Val): 100%|██████████| 155/155 [00:46<00:00,  3.37it/s]

Epoch 4: Val Acc = 49.82%
Early stopping.
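The training losses above stay high because both logit vectors are rescaled to unit L2 norm before blending. The blend of two unit vectors has norm at most 1. For 80 classes, that caps how low the cross-entropy can go at about 3.4, compared with log(80) ≈ 4.38 for a uniform guess. The printed Train Loss is summed over 1,563 batches, so epoch 1 corresponds to roughly 6320.29 / 1563 ≈ 4.04 per batch. The sketch below computes the bound and checks it numerically; it is an illustration, not part of the original notebook. If you wanted sharper logits, one possible change (not tested here) is to multiply them by a temperature, for example CLIP's learned `logit_scale`.

```python
import math

import torch
import torch.nn.functional as F

K = 80  # number of classes

# Uniform guess: every logit equal, so the softmax is 1/K everywhere.
uniform_loss = math.log(K)

# Best case for logits z with ||z||_2 <= 1: true-class logit a, all others -b.
# Cross-entropy is log(1 + (K - 1) * exp(-(a + b))). Under a**2 + (K - 1) * b**2 = 1,
# Cauchy-Schwarz gives a + b <= sqrt(K / (K - 1)), reached by the a, b below.
a = math.sqrt((K - 1) / K)
b = a / (K - 1)
best_loss = math.log(1 + (K - 1) * math.exp(-(a + b)))

# Check the closed form against PyTorch on that exact unit-norm vector.
z = torch.full((1, K), -b, dtype=torch.float64)
z[0, 0] = a
torch_loss = F.cross_entropy(z, torch.tensor([0])).item()

# Random unit-norm logits never beat the bound.
torch.manual_seed(0)
random_z = F.normalize(torch.randn(10_000, K, dtype=torch.float64), dim=-1)
random_labels = torch.randint(0, K, (10_000,))
per_sample = F.cross_entropy(random_z, random_labels, reduction="none")
worst_random_gap = (per_sample - best_loss).min().item()

print(f"uniform loss: {uniform_loss:.3f}")
print(f"lower bound:  {best_loss:.3f}")
```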





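The loop above saves only `classifier.state_dict()`, and the test evaluation below runs on whatever weights the last epoch left behind. To evaluate the best epoch instead, one option is to checkpoint both trained parts and reload them after training. The sketch below uses small hypothetical `nn.Linear` stand-ins for `model.visual` and the head so it runs on its own. The file name and dictionary keys are arbitrary choices, and the 52.26 value is simply the epoch-1 validation accuracy from the log above.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-ins for model.visual and the LinearClassifier head.
visual = nn.Linear(16, 512)
classifier = nn.Linear(512, 80)


def save_checkpoint(path, visual, classifier, epoch, val_acc):
    # Save both trainable parts, since the optimizer updates both.
    torch.save(
        {
            "visual": visual.state_dict(),
            "classifier": classifier.state_dict(),
            "epoch": epoch,
            "val_acc": val_acc,
        },
        path,
    )


def load_checkpoint(path, visual, classifier, device="cpu"):
    # Copy the saved weights back into the existing modules in place.
    ckpt = torch.load(path, map_location=device)
    visual.load_state_dict(ckpt["visual"])
    classifier.load_state_dict(ckpt["classifier"])
    return ckpt["epoch"], ckpt["val_acc"]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_checkpoint.pth")
    save_checkpoint(path, visual, classifier, epoch=1, val_acc=52.26)
    reference = classifier.weight.detach().clone()
    with torch.no_grad():
        classifier.weight.add_(1.0)  # simulate later epochs that made things worse
    restored_epoch, restored_acc = load_checkpoint(path, visual, classifier)
```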
## Compute Accuracy with Newly Trained Model

In [None]:
def compute_topk_accuracy(logits, labels, topk=(1, 3, 5)):
    max_k = max(topk)
    batch_size = labels.size(0)

    _, pred = logits.topk(max_k, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    topk_accs = {}
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        topk_accs[f"top{k}"] = (correct_k / batch_size).item() * 100.0

    return topk_accs

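# Evaluate the fine-tuned model on the test set.
# Note: this uses the weights from the last epoch that ran (not the best saved checkpoint),
# and it scores with the classifier head alone rather than the 50/50 blend used in training and validation.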
classifier.eval()
top1_total, top3_total, top5_total, total_samples = 0, 0, 0, 0

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating Fine-tuned Classifier"):
        images, labels = images.to(device), labels.to(device)
        image_features = model.encode_image(images).float()
        image_features = F.normalize(image_features, dim=-1)

        logits = classifier(image_features)
        accs = compute_topk_accuracy(logits, labels)

        top1_total += accs['top1'] * images.size(0)
        top3_total += accs['top3'] * images.size(0)
        top5_total += accs['top5'] * images.size(0)
        total_samples += images.size(0)

print(f"\nFine-tuned Classifier Accuracy:")
print(f"Top-1: {top1_total / total_samples:.2f}%")
print(f"Top-3: {top3_total / total_samples:.2f}%")
print(f"Top-5: {top5_total / total_samples:.2f}%")


Evaluating Fine-tuned Classifier: 100%|██████████| 313/313 [01:34<00:00,  3.31it/s]


Fine-tuned Classifier Accuracy:
Top-1: 67.23%
Top-3: 80.79%
Top-5: 84.57%
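Top-k accuracy counts a prediction as correct if the true label is among the k highest-scoring classes. The tiny hand-made example below follows the same logic as `compute_topk_accuracy`, written per row. The logits are invented for illustration, not taken from the run above.

```python
import torch

# Three images, four classes. Rows are logits; `labels` holds the true classes.
logits = torch.tensor([
    [0.1, 0.7, 0.2, 0.0],  # true class 1 is ranked 1st
    [0.5, 0.1, 0.4, 0.3],  # true class 2 is ranked 2nd
    [0.2, 0.3, 0.1, 0.6],  # true class 0 is ranked 3rd
])
labels = torch.tensor([1, 2, 0])

topk_idx = logits.topk(3, dim=1).indices  # (batch, 3), highest score first
hits = topk_idx.eq(labels.unsqueeze(1))   # True where the true label appears


def topk_acc(k):
    # Percentage of rows whose true label is among the first k predictions.
    return hits[:, :k].any(dim=1).float().mean().item() * 100.0


accs = {f"top{k}": topk_acc(k) for k in (1, 2, 3)}
```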





## Compare to Zero-Shot Accuracy

In [None]:
original_model = clip.load("ViT-B/32", device=device)[0].float().eval()
# Encode the class prompts with the original (not fine-tuned) CLIP text encoder
with torch.no_grad():
    tokenized_texts = clip.tokenize(all_texts).to(device)
    text_features_all = original_model.encode_text(tokenized_texts)
    text_features_all = F.normalize(text_features_all, dim=-1).float()

def compute_zero_shot_topk_accuracy(model, image_loader, text_features_all, device):
    model.eval()
    text_features_all = F.normalize(text_features_all, dim=-1)

    top1_total, top3_total, top5_total, total_samples = 0, 0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(image_loader, desc="Evaluating Zero-Shot CLIP"):
            images, labels = images.to(device), labels.to(device)
            image_features = model.encode_image(images).float()
            image_features = F.normalize(image_features, dim=-1)

            logits = image_features @ text_features_all.T
            accs = compute_topk_accuracy(logits, labels)

            top1_total += accs['top1'] * images.size(0)
            top3_total += accs['top3'] * images.size(0)
            top5_total += accs['top5'] * images.size(0)
            total_samples += images.size(0)

    return {
        'top1': top1_total / total_samples,
        'top3': top3_total / total_samples,
        'top5': top5_total / total_samples,
    }

# Run zero-shot evaluation
zero_shot_results = compute_zero_shot_topk_accuracy(original_model, test_loader, text_features_all, device)

print("\nZero-Shot CLIP Accuracy:")
print(f"Top-1: {zero_shot_results['top1']:.2f}%")
print(f"Top-3: {zero_shot_results['top3']:.2f}%")
print(f"Top-5: {zero_shot_results['top5']:.2f}%")

Evaluating Zero-Shot CLIP: 100%|██████████| 313/313 [01:33<00:00,  3.35it/s]


Zero-Shot CLIP Accuracy:
Top-1: 36.94%
Top-3: 54.41%
Top-5: 64.12%
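Training and validation scored images with a 50/50 blend of classifier logits and CLIP text-similarity logits, but the fine-tuned test evaluation used the classifier head alone. To score the test set the same way as validation, a hypothetical helper that mirrors the blend in the training loop could look like the sketch below. Random tensors stand in for real CLIP features and text embeddings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def blended_logits(image_features, text_features, classifier, alpha=0.5):
    """Hypothetical helper mirroring the training loop's blend:
    alpha * (unit-norm classifier logits) + (1 - alpha) * (unit-norm CLIP logits)."""
    image_features = F.normalize(image_features, dim=-1)
    clip_logits = image_features @ text_features.T
    classifier_logits = classifier(image_features)
    clip_logits = clip_logits / clip_logits.norm(dim=-1, keepdim=True)
    classifier_logits = classifier_logits / classifier_logits.norm(dim=-1, keepdim=True)
    return alpha * classifier_logits + (1 - alpha) * clip_logits


torch.manual_seed(0)
head = nn.Linear(512, 80)                                 # stand-in for the trained head
text_features = F.normalize(torch.randn(80, 512), dim=-1)  # stand-in for text_features_all
features = torch.randn(4, 512)                             # stand-in for encode_image output

with torch.no_grad():
    mixed = blended_logits(features, text_features, head)
```

Inside the test loop, this would be called as `blended_logits(model.encode_image(images).float(), text_features_all, classifier)`, with the result passed to `compute_topk_accuracy`.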



