# Retraining CLIP

## Set Up

Load in all our packages

In [None]:
# Install necessary packages
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install scipy

import clip
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
from torchvision import datasets
from PIL import Image
import os
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import shutil
import packaging
import kagglehub
import pandas as pd

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-ecsgolh2
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-ecsgolh2
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->clip==1.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting

In [None]:
# Check the installed PyTorch version; this notebook expects something newer than 1.7.0.
# `import packaging` on its own does not load the `packaging.version` submodule, so import it
# explicitly here rather than relying on another library having imported it as a side effect.
import packaging.version

version = packaging.version.parse(torch.__version__)
if version > packaging.version.parse('1.7.0'):
    print("Pytorch version is above 1.7.0")
    print("It is version:", version)
else:
    print("PyTorch version is not above 1.7.0. Please Upgrade")

Pytorch version is above 1.7.0
It is version: 2.6.0+cu124


Get the CLIP model

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# `clip.load` returns the network together with the image transform it expects.
# `.float()` casts every parameter to float32 so the later fine-tuning runs in full precision.
model, preprocess = clip.load("ViT-B/32", device=device)
model = model.float()

100%|███████████████████████████████████████| 338M/338M [00:06<00:00, 50.9MiB/s]


### Unfreeze more layers from CLIP

When fine-tuning a pre-trained model such as CLIP, it is common to *freeze* some of its layers so their weights don't get updated during training. Freezing preserves the features learned during the original pre-training. If we want the model to adapt to our new data, the relevant layers must instead be trainable ("unfrozen"). Here we mark the parameters of CLIP's visual (image) encoder as trainable. Only those parameters, plus the new classifier head, are handed to the optimizer later.

In [None]:
# Decide, parameter by parameter, which parts of CLIP may be updated during training.
# Parameters are trainable (requires_grad=True) by default and nothing earlier in this notebook
# changes that, so merely switching the visual ones "on" would be a no-op. Instead set the flag
# explicitly for every parameter:
#   * visual (image) encoder -> trainable
#   * everything else (the text encoder etc.) -> frozen
for name, param in model.named_parameters():
    # `name` is a string describing which layer the parameter belongs to.
    # `param` is the actual parameter tensor (a PyTorch object containing weights).
    # CLIP has two main parts: a visual encoder (for images) and a text encoder (for text).
    # We only want to modify the visual encoder.
    param.requires_grad = "visual" in name
    # requires_grad=True tells PyTorch: "Yes, this parameter should be updated during training."
    # Parameters with requires_grad=False get no gradients during backpropagation.

###  Define linear classification head

This is a very simple classification head: a single fully connected (linear) layer. It maps the output of CLIP's image encoder to one score per class. Think of it as the final decision layer that says: "I think this image is class X."

In [None]:
class LinearClassifier(nn.Module):
    """Linear classification head for CLIP image embeddings.

    A single fully connected layer that turns one feature vector into one raw score (logit)
    per class. No activation or softmax is applied here.

    Args:
        input_dim: length of each incoming feature vector (512 for the ViT-B/32 CLIP embeddings used here).
        num_classes: number of categories to classify, i.e. the length of each output vector.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        # Keep the attribute name `fc` so saved state_dict keys remain "fc.weight" / "fc.bias".
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, image_features):
        """Forward propagation: map features (..., input_dim) to raw class logits (..., num_classes)."""
        logits = self.fc(image_features)
        return logits
        # The output is a set of raw scores (logits) for each class.


## **DATA: This is the part you edit**

To run this retraining procedure, this is the **only** part you need to edit; all necessary changes can be made here. Changes elsewhere may affect the model and make results difficult to compare.

### Now we apply this to the GTSRB (German Traffic Sign Recognition Benchmark) dataset

Load data

In [None]:
# Download (or reuse the cached copy of) the dataset; `path` is its local root folder.
path = kagglehub.dataset_download("meowmeowmeowmeowmeow/gtsrb-german-traffic-sign") # Change this to your data set
print(os.listdir(path))

# Annotation CSVs for the train and test splits (change the file names for your data set).
# NOTE: despite the `_dir` suffix these are paths to CSV *files*, not folders.
train_dir, test_dir = (os.path.join(path, csv_name) for csv_name in ('Train.csv', 'Test.csv'))

train_df, test_df = pd.read_csv(train_dir), pd.read_csv(test_dir)

['Meta', 'meta', 'Meta.csv', 'Train.csv', 'Test.csv', 'Test', 'test', 'Train', 'train']


Define a dataset class that pairs each image file with its class label

In [None]:
class GTSRBDataset(Dataset):
    """Map-style dataset yielding (image, integer label) pairs described by a dataframe.

    Each row of the dataframe must provide:
      * 'Path'    - image location relative to `root_dir`
      * 'ClassId' - integer class label
    Change these two column names in `__getitem__` if your CSV uses different ones.
    """

    def __init__(self, dataframe, root_dir, transform=None):
        """
        Args:
            dataframe (pd.DataFrame): Rows of image paths and labels. Rows are accessed by
                position, so a shuffled or split dataframe with a non-contiguous index is fine.
            root_dir (str): Directory the 'Path' entries are relative to.
            transform (callable, optional): Applied to each PIL image before it is returned.
        """
        self.data = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        row = self.data.iloc[idx]  # positional lookup of the idx-th row

        # 'Path' is the image file column, change to the column containing file names in csv
        img_path = os.path.join(self.root_dir, row['Path'])
        image = Image.open(img_path).convert("RGB")

        # 'ClassId' is the label column, change to the column containing class labels in csv
        label = int(row['ClassId'])

        return (self.transform(image) if self.transform else image), label

Split out validation set

In [None]:
# Hold out part of the labelled training rows as a validation set.
# The fixed seed keeps the split (and therefore the validation accuracy) reproducible.
VAL_FRACTION = 0.2
SPLIT_SEED = 42
train_data, val_data = train_test_split(train_df, test_size=VAL_FRACTION, random_state=SPLIT_SEED)

Build datasets and data loaders (images are transformed with CLIP's `preprocess`)

In [None]:
# Wrap every split in GTSRBDataset; each image goes through the `preprocess` transform returned by clip.load.
train_dataset = GTSRBDataset(train_data, path, transform=preprocess)
val_dataset = GTSRBDataset(val_data, path, transform=preprocess)
test_dataset = GTSRBDataset(test_df, path, transform=preprocess)

# Batch the splits. Only training data is shuffled; validation and test keep a fixed order.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=False) for split in (val_dataset, test_dataset)
)

Get classes

In [None]:
# ClassId to human-readable traffic sign names
# This data only has numeric labels so I added text labels
# No need to do this if your data already has text labels
class_id_to_name = {
    0: "speed limit 20 km/h",
    1: "speed limit 30 km/h",
    2: "speed limit 50 km/h",
    3: "speed limit 60 km/h",
    4: "speed limit 70 km/h",
    5: "speed limit 80 km/h",
    6: "end of speed limit 80 km/h",
    7: "speed limit 100 km/h",
    8: "speed limit 120 km/h",
    9: "no passing",
    10: "no passing for vehicles over 3.5 metric tons",
    11: "right-of-way at the next intersection",
    12: "priority road",
    13: "yield",
    14: "stop",
    15: "no vehicles",
    16: "vehicles over 3.5 metric tons prohibited",
    17: "no entry",
    18: "general caution",
    19: "dangerous curve to the left",
    20: "dangerous curve to the right",
    21: "double curve",
    22: "bumpy road",
    23: "slippery road",
    24: "road narrows on the right",
    25: "road work",
    26: "traffic signals",
    27: "pedestrians",
    28: "children crossing",
    29: "bicycles crossing",
    30: "beware of ice/snow",
    31: "wild animals crossing",
    32: "end of all speed and passing limits",
    33: "turn right ahead",
    34: "turn left ahead",
    35: "ahead only",
    36: "go straight or right",
    37: "go straight or left",
    38: "keep right",
    39: "keep left",
    40: "roundabout mandatory",
    41: "end of no passing",
    42: "end of no passing by vehicles over 3.5 metric tons"
}

class_ids = sorted(train_df['ClassId'].unique()) # 'ClassId' holds the class labels, change to the column containing class labels in csv
class_names = [class_id_to_name[i] for i in class_ids] # Don't do this if your data is already in text format
# If your labels are in text format already, change the class_ids variable name to class_names and delete the second line
# `class_names` fixes the classifier's output size (one logit per entry) in the training cell.

# NOTE: the CLIP text prompts and text features used for training and validation are built in the
# following cell (from the test-set class names); edit the prompt there. Anything computed here
# would be overwritten by that cell before training starts.

Set up the test set for later

In [None]:
# Class names and CLIP text features for the test split.
# `test_dataset` / `test_loader` are already created together with the train/validation loaders,
# so they are not rebuilt here.

class_ids = sorted(test_df['ClassId'].unique()) # 'ClassId' holds the class labels, change to the column containing class labels in csv
class_names_test = [class_id_to_name[i] for i in class_ids] # Don't do this if your data is already in text format
print("Class names:", class_names_test)
# If your labels are in text format already, change the class_ids variable name to class_names and delete the second line

# The text features below are blended with the classifier during training/validation: column j of
# the CLIP similarities is added to column j of the classifier logits (one per entry of `class_names`).
# Both lists must therefore name the same classes in the same order.
assert class_names_test == class_names, (
    "Test-set classes differ from training classes; the CLIP text features would not line up "
    f"with the classifier outputs.\nTrain: {class_names}\nTest:  {class_names_test}"
)

all_texts = [f"A photo of a road sign showing {classname}" for classname in class_names_test] # Feel free to change prompts here (these are the prompts used in training and validation)
tokenized_texts = clip.tokenize(all_texts).to(device)
with torch.no_grad():
    text_features_all = model.encode_text(tokenized_texts)  # Shape: (num_classes, 512)
    text_features_all = F.normalize(text_features_all, dim=-1).float()

Class names: ['speed limit 20 km/h', 'speed limit 30 km/h', 'speed limit 50 km/h', 'speed limit 60 km/h', 'speed limit 70 km/h', 'speed limit 80 km/h', 'end of speed limit 80 km/h', 'speed limit 100 km/h', 'speed limit 120 km/h', 'no passing', 'no passing for vehicles over 3.5 metric tons', 'right-of-way at the next intersection', 'priority road', 'yield', 'stop', 'no vehicles', 'vehicles over 3.5 metric tons prohibited', 'no entry', 'general caution', 'dangerous curve to the left', 'dangerous curve to the right', 'double curve', 'bumpy road', 'slippery road', 'road narrows on the right', 'road work', 'traffic signals', 'pedestrians', 'children crossing', 'bicycles crossing', 'beware of ice/snow', 'wild animals crossing', 'end of all speed and passing limits', 'turn right ahead', 'turn left ahead', 'ahead only', 'go straight or right', 'go straight or left', 'keep right', 'keep left', 'roundabout mandatory', 'end of no passing', 'end of no passing by vehicles over 3.5 metric tons']


## **OK, STOP EDITING HERE**

The rest of this file should work just fine without edits if you didn't change any variable names

## Let's Retrain

In [None]:
# Fine-tune CLIP's visual encoder jointly with a new linear classification head.
#
# Every prediction (for the training loss and for validation accuracy) is an equal-weight blend of
#   * the linear classifier's logits, and
#   * CLIP's image/text similarities against `text_features_all` (one text embedding per class),
# each rescaled to unit L2 norm first so neither term dominates.

num_epochs = 10
# how many times we loop through the whole dataset (also the length of the cosine LR schedule)

classifier = LinearClassifier(input_dim=text_features_all.shape[-1], num_classes=len(class_names)).to(device)
# Initializes the classifier we defined earlier. Its input size must equal CLIP's embedding size.
# Image and text embeddings are multiplied together below, so they share that size; read it from
# `text_features_all` instead of hard-coding 512 (the ViT-B/32 value).

optimizer = torch.optim.AdamW([
    {"params": model.visual.parameters(), "lr": 1e-6},
    {"params": classifier.parameters(), "lr": 1e-4}
], weight_decay=1e-4)
# We're training two parts:
# 1) model.visual: The vision encoder from CLIP — we fine-tune it very gently using a small learning rate (1e-6)
# 2) classifier: Our new linear layer — it starts from scratch, so we train it more aggressively (1e-4)
# AdamW is a common optimizer

criterion = nn.CrossEntropyLoss()
# Cross-entropy compares the predicted scores (logits) against the true label and penalizes wrong guesses.

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# This slowly reduces the learning rate over the planned epochs in a smooth cosine curve;
# a common trick to make training more stable and avoid overshooting the minimum loss.

best_val_acc = 0
# keeps track of the best accuracy we've seen so far

patience = 3
# For early stopping — we stop training if validation accuracy doesn't improve for 3 straight epochs
# This trains more efficiently and prevents overfitting

epochs_no_improve = 0
# how many times we've failed to beat our best accuracy

best_classifier_state = None
best_visual_state = None
# In-memory (CPU) copies of the weights from the best validation epoch. The visual encoder is
# fine-tuned too, so the classifier alone is only meaningful together with the encoder weights
# from the same epoch. Both are restored after the loop.


def snapshot_state(module):
    """Return a detached CPU copy of `module.state_dict()` that further training will not change."""
    return {key: tensor.detach().cpu().clone() for key, tensor in module.state_dict().items()}


def encode_images(images):
    """Encode a batch with CLIP's image encoder and L2-normalise each feature vector (unit length),
    so dot products behave like cosine similarity."""
    return F.normalize(model.encode_image(images).float(), dim=-1)


def blended_logits_for(image_features):
    """Average of the CLIP-similarity logits and the classifier logits, each rescaled to unit L2 norm.

    Only the classifier branch sends gradients into the visual encoder: the CLIP-similarity branch
    is computed under no_grad.

    NOTE(review): after the rescaling the blended logits have L2 norm <= 1, so the softmax inside
    CrossEntropyLoss stays close to uniform and the loss moves only in a narrow band. A temperature
    (like CLIP's logit scale) would sharpen it; it is left unchanged so runs stay comparable.
    """
    with torch.no_grad():
        clip_logits = image_features @ text_features_all.T  # (B, num_classes)
    classifier_logits = classifier(image_features)
    clip_logits = clip_logits / clip_logits.norm(dim=-1, keepdim=True)
    classifier_logits = classifier_logits / classifier_logits.norm(dim=-1, keepdim=True)
    return 0.5 * classifier_logits + 0.5 * clip_logits


for epoch in range(num_epochs):

    classifier.train()
    # classifier.train() puts the classifier in training mode

    total_loss, correct, total = 0, 0, 0

    # Training ##################################################

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Train)"):
        images, labels = images.to(device), labels.to(device)

        blended_logits = blended_logits_for(encode_images(images))
        loss = criterion(blended_logits, labels)
        # Calculate the loss from the blended prediction

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Clear old gradients, backpropagate new ones, and take an optimizer step

        total_loss += loss.item()
        correct += (blended_logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    train_acc = 100 * correct / total
    # Report the mean per-batch loss; a sum would scale with the number of batches.
    mean_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Train Loss = {mean_train_loss:.4f} (mean per batch), Train Acc = {train_acc:.2f}%")

    # Validation ################################################

    classifier.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} (Val)"):
            images, labels = images.to(device), labels.to(device)
            logits = blended_logits_for(encode_images(images))
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    val_acc = 100 * correct / total
    print(f"Epoch {epoch+1}: Val Acc = {val_acc:.2f}%")

    # Early Stopping #############################################

    if val_acc > best_val_acc:
        # New best: save the classifier to disk (as before) and snapshot both trained parts.
        best_val_acc = val_acc
        epochs_no_improve = 0
        torch.save(classifier.state_dict(), 'best_linear_classifier.pth')
        best_classifier_state = snapshot_state(classifier)
        best_visual_state = snapshot_state(model.visual)
        print("Improved validation accuracy. Saved model.")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping.")
            break
        # If we've gone `patience` epochs with no improvement, stop training early

    scheduler.step()
    # Move along the cosine schedule — lower the learning rate a bit

# Evaluate what we selected: put back the weights from the best validation epoch
# instead of whatever the last (possibly worse) epoch left behind.
if best_visual_state is not None:
    model.visual.load_state_dict(best_visual_state)
    classifier.load_state_dict(best_classifier_state)
    print(f"Restored weights from the best validation epoch (Val Acc = {best_val_acc:.2f}%).")
    # Move along the cosine schedule — lower the learning rate a bit


Epoch 1/10 (Train): 100%|██████████| 981/981 [06:04<00:00,  2.69it/s]


Epoch 1: Train Loss = 3276.9511, Train Acc = 89.30%


Epoch 1/10 (Val): 100%|██████████| 246/246 [01:16<00:00,  3.22it/s]


Epoch 1: Val Acc = 99.50%
Improved validation accuracy. Saved model.


Epoch 2/10 (Train): 100%|██████████| 981/981 [02:23<00:00,  6.83it/s]


Epoch 2: Train Loss = 3201.0141, Train Acc = 99.77%


Epoch 2/10 (Val): 100%|██████████| 246/246 [00:25<00:00,  9.63it/s]


Epoch 2: Val Acc = 99.76%
Improved validation accuracy. Saved model.


Epoch 3/10 (Train): 100%|██████████| 981/981 [02:23<00:00,  6.83it/s]


Epoch 3: Train Loss = 3196.6524, Train Acc = 99.97%


Epoch 3/10 (Val): 100%|██████████| 246/246 [00:25<00:00,  9.74it/s]


Epoch 3: Val Acc = 99.89%
Improved validation accuracy. Saved model.


Epoch 4/10 (Train): 100%|██████████| 981/981 [02:23<00:00,  6.82it/s]


Epoch 4: Train Loss = 3195.4229, Train Acc = 99.97%


Epoch 4/10 (Val): 100%|██████████| 246/246 [00:25<00:00,  9.58it/s]


Epoch 4: Val Acc = 99.81%


Epoch 5/10 (Train): 100%|██████████| 981/981 [02:23<00:00,  6.81it/s]


Epoch 5: Train Loss = 3194.8895, Train Acc = 99.98%


Epoch 5/10 (Val): 100%|██████████| 246/246 [00:25<00:00,  9.61it/s]


Epoch 5: Val Acc = 99.78%


Epoch 6/10 (Train): 100%|██████████| 981/981 [02:22<00:00,  6.86it/s]


Epoch 6: Train Loss = 3194.4884, Train Acc = 100.00%


Epoch 6/10 (Val): 100%|██████████| 246/246 [00:25<00:00,  9.66it/s]

Epoch 6: Val Acc = 99.83%
Early stopping.





## Compute Accuracy with Newly Trained Model

In [None]:
def compute_topk_accuracy(logits, labels, topk=(1, 3, 5)):
    """Top-k accuracy (in percent) for one batch of predictions.

    Args:
        logits: (batch, num_classes) tensor of scores; higher means more likely.
        labels: (batch,) tensor of integer class indices.
        topk: the k values to report.

    Returns:
        dict mapping "top{k}" -> percentage of samples whose label is among the k
        highest-scoring classes. When k >= num_classes every label is trivially
        included, so that entry is 100.
    """
    num_classes = logits.size(1)
    # `topk` raises if k exceeds the number of classes. Clamp so datasets with fewer than
    # max(topk) classes still work.
    max_k = min(max(topk), num_classes)
    batch_size = labels.size(0)

    _, pred = logits.topk(max_k, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (max_k, batch): row r holds each sample's r-th best class
    correct = pred.eq(labels.view(1, -1).expand_as(pred))

    topk_accs = {}
    for k in topk:
        # Slicing past max_k simply keeps all rows, i.e. every candidate class.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        topk_accs[f"top{k}"] = (correct_k / batch_size).item() * 100.0

    return topk_accs

# Evaluate the fine-tuned encoder + linear head on the held-out test split.
# NOTE(review): this scores the linear head alone, whereas training and validation (model
# selection) used the blended classifier + CLIP-text logits, so the numbers are not directly comparable.
classifier.eval()
top1_total = top3_total = top5_total = total_samples = 0

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating Fine-tuned Classifier"):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.size(0)

        image_features = F.normalize(model.encode_image(images).float(), dim=-1)
        logits = classifier(image_features)
        accs = compute_topk_accuracy(logits, labels)

        # compute_topk_accuracy returns per-batch percentages; weight each by its batch size
        # so the final figures are per-sample averages.
        top1_total += accs['top1'] * batch_size
        top3_total += accs['top3'] * batch_size
        top5_total += accs['top5'] * batch_size
        total_samples += batch_size

print("\nFine-tuned Classifier Accuracy:")
print(f"Top-1: {top1_total / total_samples:.2f}%")
print(f"Top-3: {top3_total / total_samples:.2f}%")
print(f"Top-5: {top5_total / total_samples:.2f}%")


Evaluating Fine-tuned Classifier: 100%|██████████| 395/395 [00:42<00:00,  9.38it/s]


Fine-tuned Classifier Accuracy:
Top-1: 98.63%
Top-3: 99.58%
Top-5: 99.75%





## Compare To Zero Shot Accuracy

In [None]:
# Zero-shot baseline: a fresh copy of the pre-trained CLIP (the `model` above has been fine-tuned),
# with one generic text prompt per test class.
# NOTE(review): this overwrites the global `text_features_all` used by the training cell; re-run the
# data cells before re-running training.
baseline_prompts = [f"A photo of a {classname}" for classname in class_names_test]
original_model = clip.load("ViT-B/32", device=device)[0].float().eval()

with torch.no_grad():
    tokenized_texts = clip.tokenize(baseline_prompts).to(device)
    text_features_all = F.normalize(original_model.encode_text(tokenized_texts), dim=-1).float()

def compute_zero_shot_topk_accuracy(model, image_loader, text_features_all, device):
    """Evaluate CLIP as a zero-shot classifier.

    Each image is scored against every class by the dot product of its L2-normalised image
    embedding with the L2-normalised class text embeddings. Column i of those scores is
    compared against label i.

    Args:
        model: CLIP model providing `encode_image`; switched to eval mode here.
        image_loader: iterable of (images, labels) batches.
        text_features_all: one text embedding per class (re-normalised here).
        device: device that images and labels are moved to before encoding.

    Returns:
        dict with 'top1', 'top3' and 'top5' sample-weighted accuracies in percent.
    """
    model.eval()
    class_embeddings = F.normalize(text_features_all, dim=-1)

    weighted_sums = dict.fromkeys(("top1", "top3", "top5"), 0)
    seen = 0

    with torch.no_grad():
        for images, labels in tqdm(image_loader, desc="Evaluating Zero-Shot CLIP"):
            images, labels = images.to(device), labels.to(device)
            image_embeddings = F.normalize(model.encode_image(images).float(), dim=-1)

            batch_accs = compute_topk_accuracy(image_embeddings @ class_embeddings.T, labels)
            n = images.size(0)
            # Per-batch percentages, weighted by batch size for a per-sample average.
            for key in weighted_sums:
                weighted_sums[key] += batch_accs[key] * n
            seen += n

    return {key: weighted_sums[key] / seen for key in ("top1", "top3", "top5")}

# Score the untouched pre-trained CLIP on the same test split as the fine-tuned model.
zero_shot_results = compute_zero_shot_topk_accuracy(
    original_model, test_loader, text_features_all, device
)

print("\nZero-Shot CLIP Accuracy:")
for result_key, label in (("top1", "Top-1"), ("top3", "Top-3"), ("top5", "Top-5")):
    print(f"{label}: {zero_shot_results[result_key]:.2f}%")

Evaluating Zero-Shot CLIP: 100%|██████████| 395/395 [00:42<00:00,  9.28it/s]


Zero-Shot CLIP Accuracy:
Top-1: 28.61%
Top-3: 39.03%
Top-5: 44.17%



