
# Knowledge Distillation: CLIP to MobileNetV3
This notebook implements knowledge distillation from the CLIP visual encoder to a MobileNetV3 model using PyTorch.

**Objective:**
Transfer knowledge from the CLIP model (teacher) to a smaller, lightweight vision model (student), enabling efficient deployment while maintaining robust performance.

**Models:**
- Teacher: CLIP Vision Encoder (`openai/clip-vit-base-patch32`)
- Student: MobileNetV3 (from `torchvision.models`)

**Dataset:**
For demonstration, we use a subset of ImageNet or similar image classification datasets compatible with both models.

**Steps:**
1. Load and preprocess the dataset.
2. Load pre-trained teacher (CLIP) and student (MobileNetV3) models.
3. Implement the distillation training loop.
4. Evaluate the performance of the student model.


In [1]:
!pip install transformers
!pip install ipywidgets

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from transformers import CLIPModel

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [3]:

# Dataset and transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load a dataset (e.g., ImageNet or a subset)
# For demonstration, we use CIFAR-10 as a placeholder
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

val_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [4]:

# Load pre-trained teacher model (CLIP)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
teacher_model = clip_model.vision_model.to(device)
teacher_model.eval()

# Load pre-trained student model (MobileNetV3)
student_model = models.mobilenet_v3_small(pretrained=True)

# Replace the final classification layer for CIFAR-10
num_features = student_model.classifier[-1].in_features  # Get input features of the last layer
student_model.classifier[-1] = nn.Linear(num_features, 768)  # CIFAR-10 has 10 classes

student_model = student_model.to(device)


# Print model structures
print("Teacher Model (CLIP):")
print(teacher_model)

print("Student Model (MobileNetV3):")
print(student_model)


Teacher Model (CLIP):
CLIPVisionTransformer(
  (embeddings): CLIPVisionEmbeddings(
    (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (position_embedding): Embedding(50, 768)
  )
  (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (encoder): CLIPEncoder(
    (layers): ModuleList(
      (0-11): 12 x CLIPEncoderLayer(
        (self_attn): CLIPSdpaAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): CLIPMLP(
          (activation_fn): QuickGELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_feature



In [5]:

# Loss function for knowledge distillation
criterion = nn.KLDivLoss(reduction="batchmean")

# Optimizer for student model
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

# Training loop
from tqdm import tqdm

def train_one_epoch(teacher, student, dataloader, optimizer, device):
    teacher.eval()  # Teacher is frozen
    student.train()
    total_loss = 0

    # Wrap dataloader with tqdm
    pbar = tqdm(dataloader, desc="Training", leave=False)
    for images, _ in pbar:
        images = images.to(device)

        # Get teacher logits
        with torch.no_grad():
            teacher_outputs = teacher(images).last_hidden_state[:, 0, :]

        # Get student logits
        student_outputs = student(images)

        # Compute distillation loss
        loss = criterion(
            nn.functional.log_softmax(student_outputs / 1.0, dim=1),
            nn.functional.softmax(teacher_outputs / 1.0, dim=1)
        )

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pbar.set_postfix({"Loss": loss.item()})  # Update progress bar

    avg_loss = total_loss / len(dataloader)
    print(f"Training Loss: {avg_loss:.4f}")
    return avg_loss


In [6]:

from torch.nn.functional import normalize

def evaluate(student, teacher, dataloader, device):
    student.eval()
    teacher.eval()
    similarities = []

    with torch.no_grad():
        for images, _ in tqdm(dataloader, desc="Evaluation"):
            images = images.to(device)

            # Get embeddings from teacher and student
            teacher_outputs = teacher(images).last_hidden_state[:, 0, :]  # [CLS] token embeddings
            student_outputs = student(images)

            # Normalize embeddings
            teacher_norm = normalize(teacher_outputs, p=2, dim=1)
            student_norm = normalize(student_outputs, p=2, dim=1)

            # Compute cosine similarity
            similarity = (teacher_norm * student_norm).sum(dim=1).mean().item()
            similarities.append(similarity)

    avg_similarity = sum(similarities) / len(similarities)
    print(f"Average Cosine Similarity: {avg_similarity:.4f}")
    return avg_similarity



In [7]:

import time

def benchmark_model(model, device, input_shape=(1, 3, 224, 224), runs=100):
    model = model.to(device)
    model.eval()
    input_tensor = torch.randn(input_shape, device=device)

    # Warm-up runs
    for _ in range(10):
        _ = model(input_tensor)

    # Benchmark
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.time()

    for _ in range(runs):
        _ = model(input_tensor)

    if device.type == 'cuda':
        torch.cuda.synchronize()
    elapsed_time = (time.time() - start_time) / runs
    return elapsed_time * 1000  # Convert to milliseconds

# Benchmark CLIP and MobileNetV3 before training
print("Benchmarking before training...")
clip_time_cpu = benchmark_model(teacher_model, torch.device('cpu'))
mobilenet_time_cpu = benchmark_model(student_model, torch.device('cpu'))

if torch.cuda.is_available():
    clip_time_cuda = benchmark_model(teacher_model, torch.device('cuda'))
    mobilenet_time_cuda = benchmark_model(student_model, torch.device('cuda'))
else:
    clip_time_cuda = None
    mobilenet_time_cuda = None

print(f"CLIP Inference Time (CPU): {clip_time_cpu:.2f} ms")
print(f"MobileNetV3 Inference Time (CPU): {mobilenet_time_cpu:.2f} ms")
if clip_time_cuda and mobilenet_time_cuda:
    print(f"CLIP Inference Time (CUDA): {clip_time_cuda:.2f} ms")
    print(f"MobileNetV3 Inference Time (CUDA): {mobilenet_time_cuda:.2f} ms")


Benchmarking before training...
CLIP Inference Time (CPU): 51.98 ms
MobileNetV3 Inference Time (CPU): 11.15 ms
CLIP Inference Time (CUDA): 7.60 ms
MobileNetV3 Inference Time (CUDA): 7.06 ms


In [9]:

# Training and evaluation
num_epochs = 10
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train_loss = train_one_epoch(teacher_model, student_model, train_loader, optimizer, device)
    scheduler.step()  # Step the scheduler
    val_accuracy = evaluate(student_model, teacher_model, val_loader, device)

Epoch 1/10


                                                                         

Training Loss: 0.0104


Evaluation: 100%|██████████| 79/79 [00:18<00:00,  4.30it/s]


Average Cosine Similarity: 0.8510
Epoch 2/10


                                                                         

Training Loss: 0.0090


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.07it/s]


Average Cosine Similarity: 0.8672
Epoch 3/10


                                                                         

Training Loss: 0.0082


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.06it/s]


Average Cosine Similarity: 0.8770
Epoch 4/10


                                                                         

Training Loss: 0.0078


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.05it/s]


Average Cosine Similarity: 0.8787
Epoch 5/10


                                                                         

Training Loss: 0.0078


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.09it/s]


Average Cosine Similarity: 0.8795
Epoch 6/10


                                                                         

Training Loss: 0.0077


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.09it/s]


Average Cosine Similarity: 0.8803
Epoch 7/10


                                                                         

Training Loss: 0.0077


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.07it/s]


Average Cosine Similarity: 0.8804
Epoch 8/10


                                                                         

Training Loss: 0.0077


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.08it/s]


Average Cosine Similarity: 0.8805
Epoch 9/10


                                                                         

Training Loss: 0.0077


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.07it/s]


Average Cosine Similarity: 0.8806
Epoch 10/10


                                                                         

Training Loss: 0.0077


Evaluation: 100%|██████████| 79/79 [00:19<00:00,  4.07it/s]

Average Cosine Similarity: 0.8806





In [10]:
# Benchmark CLIP and MobileNetV3 after training
print("Benchmarking after training...")
clip_time_cpu = benchmark_model(teacher_model, torch.device('cpu'))
mobilenet_time_cpu = benchmark_model(student_model, torch.device('cpu'))

if torch.cuda.is_available():
    clip_time_cuda = benchmark_model(teacher_model, torch.device('cuda'))
    mobilenet_time_cuda = benchmark_model(student_model, torch.device('cuda'))
else:
    clip_time_cuda = None
    mobilenet_time_cuda = None

print(f"CLIP Inference Time (CPU): {clip_time_cpu:.2f} ms")
print(f"MobileNetV3 Inference Time (CPU): {mobilenet_time_cpu:.2f} ms")
if clip_time_cuda and mobilenet_time_cuda:
    print(f"CLIP Inference Time (CUDA): {clip_time_cuda:.2f} ms")
    print(f"MobileNetV3 Inference Time (CUDA): {mobilenet_time_cuda:.2f} ms")


Benchmarking after training...
CLIP Inference Time (CPU): 52.63 ms
MobileNetV3 Inference Time (CPU): 11.24 ms
CLIP Inference Time (CUDA): 7.40 ms
MobileNetV3 Inference Time (CUDA): 7.12 ms


In [11]:
save_path = "./mobilenetv3_student_model.pth"
torch.save(student_model.state_dict(), save_path)
print(f"MobileNetV3 student model saved to {save_path}")

MobileNetV3 student model saved to ./mobilenetv3_student_model.pth
