<a href="https://colab.research.google.com/github/Mubashir714/Efficient-CLIP-Distillation/blob/main/Efficient_CLIP_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# INSTALL DEPENDENCIES

In [None]:
!pip install open_clip_torch
!pip install datasets
!pip install transformers
!pip install timm
!pip install accelerate
!pip install ftfy regex tqdm


# IMPORTS

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import load_dataset
import open_clip
import random
import numpy as np
from tqdm import tqdm


# LOAD TEACHER MODEL (CLIP ViT-B/32)

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

teacher_model, teacher_preprocess_train, teacher_preprocess_val = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='openai'
)

teacher_tokenizer = open_clip.get_tokenizer('ViT-B-32')

teacher_model = teacher_model.to(device)
teacher_model.eval()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


open_clip_model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]



CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine

# LOAD STUDENT MODEL (CLIP ViT-Tiny)

In [5]:
student_model, student_preprocess_train, student_preprocess_val = open_clip.create_model_and_transforms(
    'MobileCLIP-S2',
    pretrained=None   # training from scratch
)

student_tokenizer = open_clip.get_tokenizer('MobileCLIP-S2')

student_model = student_model.to(device)
student_model.train()




CustomTextCLIP(
  (visual): TimmModel(
    (trunk): FastVit(
      (stem): Sequential(
        (0): MobileOneBlock(
          (se): Identity()
          (conv_kxk): ModuleList(
            (0): ConvNormAct(
              (conv): Conv2d(3, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
              (bn): BatchNormAct2d(
                80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
                (drop): Identity()
                (act): Identity()
              )
            )
          )
          (conv_scale): ConvNormAct(
            (conv): Conv2d(3, 80, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (bn): BatchNormAct2d(
              80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
              (drop): Identity()
              (act): Identity()
            )
          )
          (act): GELU(approximate='none')
        )
        (1): MobileOneBlock(
          (se): Identity()
          (conv_kxk): ModuleList

# LOAD DATASET (CIFAR-10 with Synthetic Captions)


In [34]:
from datasets import load_dataset
from PIL import Image

dataset = load_dataset("cifar10", split="train")
val_dataset = load_dataset("cifar10", split="test")
from torchvision import transforms
import torch

# CIFAR-10 class names
label_to_caption = {
    0: "a photo of an airplane",
    1: "a photo of an automobile",
    2: "a photo of a bird",
    3: "a photo of a cat",
    4: "a photo of a deer",
    5: "a photo of a dog",
    6: "a photo of a frog",
    7: "a photo of a horse",
    8: "a photo of a ship",
    9: "a photo of a truck"
}

image_transform = transforms.Compose([
    transforms.Resize((224,224)),  # smaller images → faster
    transforms.ToTensor(),
])

def collate_fn(batch):
    images = []
    texts = []
    labels = []

    for item in batch:
        img = image_transform(item["img"])
        images.append(img)

        caption = label_to_caption[item["label"]]
        texts.append(caption)
        labels.append(item["label"])  # needed for student evaluation

    text_tokens = teacher_tokenizer(texts)  # for teacher model

    return {
        "images": torch.stack(images),
        "texts": text_tokens,
        "labels": labels
    }

# train_loader = DataLoader(dataset.with_transform(preprocess_batch), batch_size=32, shuffle=True)
# val_loader  = DataLoader(val_dataset.with_transform(preprocess_batch), batch_size=32, shuffle=False)
# train_dataset_tf = dataset.with_transform(preprocess_sample)
# train_loader = DataLoader(train_dataset_tf, batch_size=32, shuffle=True)


from torch.utils.data import DataLoader

val_loader = DataLoader(
    val_dataset,
    batch_size=8,       # same as training
    shuffle=False,
    collate_fn=collate_fn
)


print("Train samples:", len(dataset))
print("Val samples:", len(val_dataset))


Train samples: 50000
Val samples: 10000


# DISTILLATION LOSS FUNCTIONS

In [26]:
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    logits = (image_embeds @ text_embeds.t()) / temperature
    labels = torch.arange(len(logits), device=device)
    return (F.cross_entropy(logits, labels) +
            F.cross_entropy(logits.t(), labels)) / 2


def distillation_loss(img_t, txt_t, img_s, txt_s):
    img_loss = F.mse_loss(img_s, img_t.detach())
    txt_loss = F.mse_loss(txt_s, txt_t.detach())
    cont_loss = contrastive_loss(img_s, txt_s)
    return img_loss + txt_loss + 0.5 * cont_loss


# TRAINING LOOP

In [28]:
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-4)
EPOCHS = 3
teacher_model.eval()  # teacher is frozen

for epoch in range(EPOCHS):
    total_loss = 0

    for batch in tqdm(train_loader):

        # Move tensors to GPU
        images = batch["images"].to(device)
        texts  = batch["texts"].to(device)

        # --- Teacher embeddings (frozen) ---
        with torch.no_grad():
            img_t = teacher_model.encode_image(images)
            txt_t = teacher_model.encode_text(texts)

        # --- Student embeddings (trainable) ---
        img_s = student_model.encode_image(images)
        txt_s = student_model.encode_text(texts)

        # Compute distillation loss
        loss = distillation_loss(img_t, txt_t, img_s, txt_s)

        # Backprop for student model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss/len(train_loader):.4f}")


100%|██████████| 6250/6250 [38:48<00:00,  2.68it/s]


Epoch 1 Loss: 2.5412


100%|██████████| 6250/6250 [38:47<00:00,  2.68it/s]


Epoch 2 Loss: 0.9549


100%|██████████| 6250/6250 [38:45<00:00,  2.69it/s]

Epoch 3 Loss: 0.6968





# EVALUATION (RECALL@1)

In [35]:
from tqdm import tqdm
import torch

def compute_recall_at_1(model, loader, tokenizer_model="teacher"):
    """
    Compute Recall@1 for image-text matching.

    tokenizer_model: "teacher" or "student"
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            images = batch["images"].to(device)

            if tokenizer_model == "teacher":
                # Teacher uses pre-tokenized text tensor
                tokens = batch["texts"].to(device)
            else:
                # Student uses labels to recreate captions and tokenize
                captions = [label_to_caption[label] for label in batch["labels"]]
                tokens = student_tokenizer(captions).to(device)

            # Encode embeddings
            img_embeds = model.encode_image(images)
            txt_embeds = model.encode_text(tokens)

            # Compute similarity and recall@1
            logits = img_embeds @ txt_embeds.t()
            preds = logits.argmax(dim=1)
            labels = torch.arange(len(preds), device=device)

            correct += (preds == labels).sum().item()
            total += len(labels)

    return correct / total
teacher_acc = compute_recall_at_1(teacher_model, val_loader, tokenizer_model="teacher")
student_acc = compute_recall_at_1(student_model, val_loader, tokenizer_model="student")

print(f"Teacher Recall@1: {teacher_acc:.4f}")
print(f"Student Recall@1: {student_acc:.4f}")


Evaluating: 100%|██████████| 1250/1250 [01:01<00:00, 20.41it/s]
Evaluating: 100%|██████████| 1250/1250 [02:04<00:00, 10.01it/s]

Teacher Recall@1: 0.5036
Student Recall@1: 0.5231





# MODEL SIZE & SPEED

In [36]:
import time
import os
import torch

def model_size(model):
    torch.save(model.state_dict(), "temp.pth")
    size = os.path.getsize("temp.pth") / 1e6  # MB
    os.remove("temp.pth")
    return size

def speed_test(model, tokenizer_model="teacher"):
    model.eval()
    dummy_image = torch.randn(1, 3, 224, 224).to(device)

    # Create dummy text
    if tokenizer_model == "teacher":
        dummy_text = teacher_tokenizer(["a photo of a cat"]).to(device)
    else:
        dummy_text = student_tokenizer(["a photo of a cat"]).to(device)

    start = time.time()
    for _ in range(20):
        with torch.no_grad():
            _ = model.encode_image(dummy_image)
            _ = model.encode_text(dummy_text)
    end = time.time()

    return (end - start) / 20  # average time per inference

print("Teacher size:", model_size(teacher_model), "MB")
print("Student size:", model_size(student_model), "MB")
print("Teacher speed:", speed_test(teacher_model, tokenizer_model="teacher"), "sec/inference")
print("Student speed:", speed_test(student_model, tokenizer_model="student"), "sec/inference")


Teacher size: 605.205191 MB
Student size: 398.098845 MB
Teacher speed: 0.02715606689453125 sec/inference
Student speed: 0.048318469524383546 sec/inference


# SAVE STUDENT MODEL


In [38]:

# Save the trained student model's weights
torch.save(student_model.state_dict(), "efficient_clip_student.pth")
