# Knowledge Distillation using Contrastive Learning (CLIP)

In this notebook, we explore **Knowledge Distillation** from a large monolingual model into a smaller multilingual model using **Contrastive Learning**, specifically leveraging the **CLIP** (Contrastive Language-Image Pretraining) loss.

We use a small parallel **English-Persian** dataset: each English sentence and its Persian translation form a positive pair for the loss. The goal is to demonstrate the mechanics of using the CLIP loss to align two embedding spaces.

### Overview

**CLIP** bridges the gap between modalities (or, in this case, languages) by aligning corresponding embeddings in a shared space. It trains with a **contrastive loss** so that:
- **Positive pairs** (e.g., an English sentence and its Persian translation) have high similarity.
- **Negative pairs** (unmatched samples) have low similarity.
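The two bullet points above translate directly into a symmetric cross-entropy over a batch similarity matrix whose diagonal holds the positive pairs. The sketch below is a minimal illustration on random toy vectors (the tensors `english` and `persian` are hypothetical stand-ins for real embeddings); it shows that the loss is small when row *i* of one side matches row *i* of the other, and large when the pairing is shifted.

```python
import torch
import torch.nn.functional as F

def symmetric_contrastive_loss(a, b, temperature=0.07):
    """Symmetric InfoNCE: row i of `a` should match row i of `b` and nothing else."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    logits = a @ b.T / temperature                    # (B, B) cosine similarities divided by tau
    targets = torch.arange(a.size(0))                 # positives sit on the diagonal
    loss_a_to_b = F.cross_entropy(logits, targets)    # each row must pick its partner
    loss_b_to_a = F.cross_entropy(logits.T, targets)  # each column must pick its partner
    return (loss_a_to_b + loss_b_to_a) / 2

torch.manual_seed(0)
english = torch.randn(8, 16)                   # toy "English" embeddings
persian = english + 0.01 * torch.randn(8, 16)  # toy "translations": near-copies
mismatched = persian.roll(shifts=1, dims=0)    # every row now pairs with the wrong partner

loss_aligned = symmetric_contrastive_loss(english, persian).item()
loss_mismatched = symmetric_contrastive_loss(english, mismatched).item()
print(f"aligned: {loss_aligned:.4f}, mismatched: {loss_mismatched:.4f}")
```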

**Knowledge Distillation** transfers knowledge from a large "teacher" model to a smaller "student" model. Here:
- **Teacher:** `EVA02-E-14-plus` (Large, Pretrained).
- **Student:** `setu4993/smaller-LaBSE` (Small, Multilingual).

### Challenges
CLIP was trained with a very large batch size (32,768), so every positive pair was contrasted against tens of thousands of in-batch negatives. Due to resource constraints (Colab or a local GPU), we use a much smaller batch size to demonstrate the *procedure* and training dynamics rather than to reach state-of-the-art performance.
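With a batch of size $B$, each positive competes against $B-1$ in-batch negatives. A useful reference point when reading the training logs: if the model assigns every pair the same similarity, the cross-entropy equals $\ln B$ and top-1 accuracy is $1/B$. The sketch below checks this for the notebook's `batch_size` of 128.

```python
import math
import torch
import torch.nn.functional as F

batch_size = 128                              # matches configs["batch_size"]
logits = torch.zeros(batch_size, batch_size)  # untrained model: every pair looks equally similar
targets = torch.arange(batch_size)
chance_loss = F.cross_entropy(logits, targets).item()

print(f"negatives per positive: {batch_size - 1}")
print(f"chance-level loss: {chance_loss:.3f} (ln B = {math.log(batch_size):.3f})")
print(f"chance-level accuracy: {1 / batch_size:.4f}")
```

A loss that stays near $\ln 128 \approx 4.85$ means the student has not yet learned to separate positives from negatives.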



## Setup


In [None]:
# Download the training and validation data
# !gdown "https://drive.google.com/uc?id=1MVx_gIkX4tQ8ya2OsHt0mqLmw1Pf2CcK" -O train.csv
# !gdown "https://drive.google.com/uc?id=1Co-dwJfWw-C_ral0hoAS_X94wN-_vbCj" -O val.csv

Downloading...
From: https://drive.google.com/uc?id=1MVx_gIkX4tQ8ya2OsHt0mqLmw1Pf2CcK
To: /kaggle/working/train.csv
100%|███████████████████████████████████████| 7.35M/7.35M [00:00<00:00, 151MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Co-dwJfWw-C_ral0hoAS_X94wN-_vbCj
To: /kaggle/working/val.csv
100%|███████████████████████████████████████| 2.45M/2.45M [00:00<00:00, 198MB/s]


In [None]:
# Install (optional) and import required packages
import sys
import subprocess
import re
import string
from importlib import metadata

REQUIRED_PACKAGES = [
    'open_clip_torch',
    'gdown',
    'pandas',
    'numpy',
    'matplotlib',
    'transformers',
    'tqdm',
    'torch',
    'datasets',
]

def install_package(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Uncomment to install missing packages (this must run before the imports below)
# for package in REQUIRED_PACKAGES:
#     try:
#         metadata.version(package)
#     except metadata.PackageNotFoundError:
#         install_package(package)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, AutoModel
import open_clip
from open_clip import model as TE
from tqdm import tqdm

print("Imports complete.")





open_clip-torch is NOT installed
Collecting open_clip-torch
  Downloading open_clip_torch-2.29.0-py3-none-any.whl.metadata (31 kB)
Collecting ftfy (from open_clip-torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading open_clip_torch-2.29.0-py3-none-any.whl (1.5 MB)
   1.5/1.5 MB 24.8 MB/s eta 0:00:00
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
   44.8/44.8 kB 3.3 MB/s eta 0:00:00
Installing collected packages: ftfy, open_clip-torch
Successfully installed ftfy-6.3.1 open_clip-torch-2.29.0
open_clip-torch was successfully installed.
pandas (2.2.3) is installed
numpy (1.26.4) is installed
matplotlib (3.7.5) is installed
transformers (4.46.3) is installed
tqdm (4.66.4) is installed
torch (2.4.0) is installed
datasets (3.1.0) is installed


In [None]:
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        print("Warning: no GPU backend (CUDA or MPS) found, using CPU.")
        return torch.device("cpu")

device = get_device()
print(f"Using device: {device}")

## Configuration
Settings for the teacher model (OpenCLIP) and student model (LaBSE), along with training hyperparameters.

In [None]:
configs = {
    "device": device,
    "reference_checkPoint" : "EVA02-E-14-plus",                # Teacher: OpenCLIP model
    "candidate_checkpoint" : "setu4993/smaller-LaBSE",         # Student: Multilingual BERT-like model
    "train_path" : "train.csv",
    "val_path" : "val.csv",
    "save_path" : "./best-model.pth",
    "english" : "en",
    "persian" : "fa",
    "batch_size": 128,          
    "lr": 1e-4,
    "epochs": 3,
    "tok_percentile" : 99,
    "temperature": 0.07,
    "dropout": 0.05,
    "unfreezed_layers" : 10,   # Number of top student encoder layers to fine-tune
    "weight_decay": 1e-5,
    "patience": 1,
    "factor" : 0.8,
    # Teacher Model Specs (EVA02-E-14-plus)
    "reference_embedding": 1024,
    "reference_context_length" : 77,
    "reference_vocab_size" : 49408,
    "reference_heads" : 20,
    "reference_width" : 1280,
    "reference_layers" : 32,
    "cls_token_index" : 0,
    "project_to" : 1024,       # Student projection size; must equal reference_embedding
}

## Key Concepts

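**1. Temperature in Contrastive Learning**
Temperature ($\tau$) divides the cosine-similarity logits before the softmax, $p_{ij} = \exp(s_{ij}/\tau) / \sum_k \exp(s_{ik}/\tau)$, where $s_{ij}$ is the similarity between sample $i$ and candidate $j$.
- **Low $\tau$:** Sharpens the distribution, making the model more confident about the positive pair and penalizing hard negatives more heavily.
- **High $\tau$:** Smooths the distribution, so the loss separates the positive from the negatives less sharply.
We initialize $\tau$ at 0.07 (CLIP's initial value) to encourage discriminative embeddings, and learn it jointly with the student during training.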
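To see the sharpening effect, the sketch below applies the softmax from the formula above to one hypothetical row of cosine similarities (a positive at 0.80 and two negatives at 0.70 and 0.20) at $\tau = 1$ and $\tau = 0.07$.

```python
import torch

# Hypothetical cosine similarities: positive first, then two negatives
similarities = torch.tensor([0.80, 0.70, 0.20])

probs = {tau: torch.softmax(similarities / tau, dim=0) for tau in (1.0, 0.07)}
for tau, p in probs.items():
    print(f"tau={tau}: {[round(x, 3) for x in p.tolist()]}")
```

At $\tau = 1$ the three candidates get similar probabilities; at $\tau = 0.07$ most of the mass moves to the positive, and the small gap to the 0.70 negative is what the loss now focuses on.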

**2. Freezing Layers**
We freeze the student's embedding layer and lower encoder layers, and fine-tune only the top `unfreezed_layers` encoder layers (10 in this configuration) together with the projection head. This:
- Preserves the language understanding acquired during pre-training.
- Reduces computational cost and the risk of overfitting on our small dataset.
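The freezing scheme can be sketched on a toy module that mimics the BERT-style layout `CandidateModel` expects (an embedding table plus a stack of layers). `ToyEncoder` and its sizes are hypothetical; the point is how to freeze everything, unfreeze the top *k* layers, and count what remains trainable.

```python
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for a BERT-style model: embeddings plus a stack of layers."""
    def __init__(self, vocab=1000, hidden=32, num_layers=12):
        super().__init__()
        self.embeddings = nn.Embedding(vocab, hidden)
        self.layer = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))

def unfreeze_top_layers(model, k):
    for p in model.parameters():          # freeze everything first
        p.requires_grad = False
    for block in model.layer[max(0, len(model.layer) - k):]:
        for p in block.parameters():      # then re-enable only the top k layers
            p.requires_grad = True

def count(model, trainable):
    return sum(p.numel() for p in model.parameters() if p.requires_grad == trainable)

model = ToyEncoder()
unfreeze_top_layers(model, k=10)
print("trainable:", count(model, True), "frozen:", count(model, False))
```

Even with 10 of 12 layers trainable, the frozen embedding table here outweighs them; in models with large vocabularies the embedding layer can likewise hold a large share of the parameters.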

**3. Token Percentile**
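`tok_percentile` sets the maximum Persian (student) sequence length to the 99th percentile of tokenized lengths in the training set, plus one token. Padding and truncating to this length instead of the longest sentence saves memory while truncating at most about 1% of sentences. The English (teacher) side always uses the OpenCLIP tokenizer's fixed context length of 77 tokens.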
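The percentile rule mirrors `calc_percentile_tokens` (percentile plus a one-token `threshold`). The sketch below applies it to hypothetical token counts with a single long outlier, showing that the outlier no longer dictates the padding length.

```python
import numpy as np

# Hypothetical token counts: 95 short sentences, 5 medium ones, and one 500-token outlier
token_lengths = [10] * 95 + [20] * 5 + [500]
percentile, threshold = 99, 1

max_length = int(np.percentile(token_lengths, percentile)) + threshold
covered = sum(1 for n in token_lengths if n <= max_length)
print(f"max_length={max_length}, untruncated: {covered}/{len(token_lengths)}")
```

Padding to 21 tokens instead of 500 cuts memory per batch by more than an order of magnitude, at the cost of truncating the single outlier.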

## Data Preprocessing
Functions to load CSVs, normalize text (Persian/English), and create DataLoaders.

In [None]:
def get_datasets_csv(prev_en_col, prev_fa_col, new_en_col, new_fa_col, train_path, val_path):
    # Load the CSVs; fall back to a tiny dummy dataset if the files are missing
    try:
        df = pd.read_csv(train_path)
        df_val = pd.read_csv(val_path)
    except FileNotFoundError:
        print("Dataset files not found. Please ensure train.csv and val.csv are present.")
        # Create dummy data for demonstration if files are missing
        df = pd.DataFrame({prev_en_col: ["Hello", "World"], prev_fa_col: ["سلام", "جهان"]})
        df_val = pd.DataFrame({prev_en_col: ["Test", "Data"], prev_fa_col: ["تست", "داده"]})

    if df.empty or df_val.empty:
        raise ValueError("Dataset is empty")

    df_train = df.loc[:, [prev_en_col, prev_fa_col]].rename(columns={prev_en_col: new_en_col, prev_fa_col: new_fa_col})
    df_val = df_val.loc[:, [prev_en_col, prev_fa_col]].rename(columns={prev_en_col: new_en_col, prev_fa_col: new_fa_col})

    dataset_train = Dataset.from_pandas(df_train)
    dataset_val = Dataset.from_pandas(df_val)

    return dataset_train, dataset_val

def get_ds_by_lang(persian_col, english_col):
    def get_persian_ds(dataset):
        return dataset[persian_col]

    def get_english_ds(dataset):
        return dataset[english_col]

    return get_persian_ds, get_english_ds

class TextNormalizer():
    def __init__(self):
        # Persian normalization table
        translation_src = ' ىكي“”0123456789%إأآئيؤةك'
        translation_dst = ' یکی""۰۱۲۳۴۵۶۷۸۹٪اااییوهک'
        self.translations = str.maketrans(translation_src, translation_dst)

        patterns = [
            (r' {2,}', ' '),       # extra spaces
            (r'\n+', ' '),         # newlines
            (r'\u200c+', ' '),     # ZWNJs
            (r'[ـ\r]', '')         # keshide, carriage returns
        ]
        self.character_refinement_patterns = [(re.compile(p), r) for p, r in patterns]

    def normalize_fa(self, text):
        text = text.lower().translate(self.translations)
        text = re.sub('[^a-zA-Z۰-۹آ-ی ]', ' ', text)
        for pattern, repl in self.character_refinement_patterns:
            text = pattern.sub(repl, text)
        return text.strip()

    def normalize_en(self, text):
        text = text.lower()
        text = text.translate(str.maketrans('', '', string.punctuation))
        return text.strip()

def apply_preprocess(datasets, configs=configs):
    normalizer = TextNormalizer()
    
    def apply_row_normalization(example):
        example[configs['persian']] = normalizer.normalize_fa(example[configs['persian']])
        example[configs['english']] = normalizer.normalize_en(example[configs['english']])
        return example

    new_datasets = []
    for dataset in datasets:
        new_datasets.append(dataset.map(apply_row_normalization))

    return new_datasets
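`TextNormalizer.normalize_fa` maps Arabic code points to their Persian equivalents, converts digits to Persian digits, and replaces everything outside Latin letters, Persian letters, and Persian digits with spaces. The sketch below is a reduced version using only a subset of that translation table (Arabic yeh/kaf and ASCII digits); the sample string is hypothetical and written with escapes so the look-alike characters are unambiguous.

```python
import re

# Subset of TextNormalizer's table: Arabic yeh/kaf -> Persian yeh/kaf, ASCII digits -> Persian digits
PERSIAN_DIGITS = "".join(chr(0x06F0 + d) for d in range(10))
TABLE = str.maketrans("\u064a\u0643" + "0123456789", "\u06cc\u06a9" + PERSIAN_DIGITS)

def normalize_fa_sketch(text):
    text = text.lower().translate(TABLE)
    # Anything outside Latin letters, Persian digits and the Persian letter range becomes a space
    text = re.sub(r"[^a-zA-Z\u06f0-\u06f9\u0622-\u06cc ]", " ", text)
    text = re.sub(r" {2,}", " ", text)  # collapse repeated spaces
    return text.strip()

# "كتاب، سال 1402!" written with an Arabic kaf, an Arabic comma and ASCII digits
sample = "\u0643\u062a\u0627\u0628\u060c \u0633\u0627\u0644 1402!"
print(normalize_fa_sketch(sample))
```

Without this step, the same word typed with the Arabic kaf and with the Persian kaf would be tokenized differently by the student tokenizer.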

## Utilities
Helper functions for token handling and metrics.

In [None]:
def get_cls_token(tensor, configs=configs):
    """
    Extracts the classification (CLS) token from the input tensor.
    """
    cls_id = configs["cls_token_index"]
    return tensor[:, cls_id, :].unsqueeze(1)

def flatten_middle(tensor):
    """
    Flattens the middle dimension (sequence length) of the input tensor.
    """
    return tensor.view(tensor.size(0), -1)

def freeze_model(model, freeze=True):
    """
    Freeze or unfreeze all parameters of a given model.
    """
    for param in model.parameters():
        param.requires_grad = not freeze
    return model

def plot_metric(metric_data, metric_name):
    if metric_name is None or metric_name not in metric_data:
        raise ValueError("No such metric")
    metric_data[metric_name].plot()
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.title(f'Plot of {metric_name}')
    plt.show()

def calc_percentile_tokens(dataset, tokenizer, field, percentile=configs["tok_percentile"], threshold=1):
    """
    Calculate the token length at a specific percentile for a dataset field.
    """
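    tokenized = tokenizer(dataset[field])

    if isinstance(tokenized, torch.Tensor):
        # OpenCLIP tokenizers return a padded LongTensor; count the non-padding (non-zero) ids
        token_lengths = [t.nonzero().size(0) for t in tokenized]
    elif hasattr(tokenized, "keys") and "input_ids" in tokenized:
        # Hugging Face tokenizers return a BatchEncoding (a UserDict, not a dict subclass)
        token_lengths = [len(x) for x in tokenized["input_ids"]]
    else:
        # Fallback for a plain list of token-id lists
        token_lengths = [len(x) for x in tokenized]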

    if not token_lengths:
        return 64 # Default fallback

    percentile_length = np.percentile(token_lengths, percentile)
    return int(percentile_length) + threshold

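# Text Encoder Factory (Teacher)
class _CLIPTextTower(nn.Module):
    """Wraps a pretrained OpenCLIP model so that forward(tokens) returns text embeddings."""
    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, tokens):
        return self.clip_model.encode_text(tokens)

def TextEncoder(configs, pretrained="laion2b_s9b_b144k"):
    # Building TE.TextTransformer from the config alone yields randomly initialized weights,
    # so load the pretrained checkpoint instead (this downloads the full image-text model).
    # Check available tags with open_clip.list_pretrained_tags_by_model(configs["reference_checkPoint"]).
    clip_model = open_clip.create_model(configs["reference_checkPoint"], pretrained=pretrained)
    clip_model.visual = None  # the image tower is not needed for text distillation
    return _CLIPTextTower(clip_model).eval()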

In [None]:
class Swish(nn.Module):
    """
    Swish activation function: x * sigmoid(beta * x)
    """
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

class LinearProjection(nn.Module):
    """
    Projection Head: Projects embeddings to the shared space.
    - Linear -> Swish -> BatchNorm -> Linear -> Dropout -> LayerNorm (Residual)
    """
    def __init__(self, embedding_dim, projection_dim=configs['project_to'], dropout=configs['dropout']):
        super(LinearProjection, self).__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.swish = Swish(beta=1.0)
        self.batch_norm = nn.BatchNorm1d(projection_dim)
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        out = self.swish(projected)
        out = self.batch_norm(out)
        out = self.fc(out)
        out = self.dropout(out)
        # Residual connection
        return self.layer_norm(out + projected)

class CandidateModel(nn.Module):
    """
    Student Model: Wraps the multilingual base model (e.g., LaBSE) and a projection head.
    """
    def __init__(self, model_name, unfreeze_layers, trainable=True):
        super().__init__()
        self.configs = AutoConfig.from_pretrained(model_name)
        # BERT-style configs expose hidden_size; fall back to 768 otherwise
        hidden_size = getattr(self.configs, 'hidden_size', 768)

        self.candidateProjection = LinearProjection(embedding_dim=hidden_size, projection_dim=configs['project_to'])

        self.model = AutoModel.from_pretrained(model_name)

        # Freeze the whole base model, then (if trainable) unfreeze the top `unfreeze_layers` encoder layers
        for param in self.model.parameters():
            param.requires_grad = False

        if trainable:
            # BERT-like models keep their layers in model.encoder.layer; DistilBERT-like ones in model.transformer.layer
            if hasattr(self.model, 'encoder') and hasattr(self.model.encoder, 'layer'):
                layers = self.model.encoder.layer
            elif hasattr(self.model, 'transformer') and hasattr(self.model.transformer, 'layer'):
                layers = self.model.transformer.layer
            else:
                layers = []
            for layer in layers[max(0, len(layers) - unfreeze_layers):]:
                for param in layer.parameters():
                    param.requires_grad = True
            # The pooler stays frozen: forward() reads the CLS token from last_hidden_state, not pooler_output

        # Normalize each of the hidden_size CLS features over the batch
        # (nn.BatchNorm1d's second positional argument is eps, not a feature size)
        self.batchNorm = nn.BatchNorm1d(hidden_size)
        self.targetTokenIdx = configs["cls_token_index"]

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Extract the CLS token: (batch, hidden_size)
        cls_embed = output.last_hidden_state[:, self.targetTokenIdx, :]
        cls_embed = self.batchNorm(cls_embed)
        # Project into the teacher's embedding space: (batch, project_to)
        return self.candidateProjection(cls_embed)
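The student normalizes its CLS embedding with `nn.BatchNorm1d` before projecting it. PyTorch's signature is `BatchNorm1d(num_features, eps=1e-05, momentum=0.1, ...)`, so a second positional argument sets `eps` rather than a feature size. The quick check below compares the two forms on a hypothetical batch of four 768-dimensional CLS vectors.

```python
import torch
import torch.nn as nn

bn_positional = nn.BatchNorm1d(1, 768)  # second positional argument is eps, not a size
bn_features = nn.BatchNorm1d(768)       # normalizes each of 768 features over the batch

cls = torch.randn(4, 768)               # (batch, hidden_size) CLS embeddings
out = bn_features(cls)

print(bn_positional.num_features, bn_positional.eps)
print(bn_features.num_features, bn_features.eps, tuple(out.shape))
```

An `eps` of 768 added to the variance would shrink the normalized features towards zero, which is why the feature count belongs in the first argument.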

## Training Loop
Implementation of the Symmetric Contrastive Loss (CLIP Loss) and the training/validation loops.
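Besides the loss, `calc_loss` reports an accuracy: the fraction of Persian (student) rows whose highest-scoring English (teacher) column is the diagonal one, i.e. top-1 in-batch retrieval accuracy. The sketch below computes it on a hypothetical 3x3 similarity matrix.

```python
import torch

# Hypothetical similarities: rows = Persian (student), columns = English (teacher)
logits = torch.tensor([[0.9, 0.1, 0.2],
                       [0.3, 0.8, 0.1],
                       [0.7, 0.2, 0.4]])
targets = torch.arange(logits.size(0))  # row i's translation is column i
preds = logits.argmax(dim=1)            # best-matching English sentence per Persian row
corrects = (preds == targets).sum().item()
accuracy = corrects / logits.size(0)
print(preds.tolist(), corrects, accuracy)
```

Because the candidates are the other sentences in the batch, this accuracy depends on the batch size: chance level is $1/B$, and larger batches make the task harder.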

In [None]:
def calc_loss(batch, reference_model, candidate_model, temperature):
    """
    Compute Symmetric Contrastive Loss:
    Loss = (CrossEntropy(sim(A,B)) + CrossEntropy(sim(B,A))) / 2
    """
    candidate_tokenized = batch["candidate"].to(configs["device"])
    reference_tokenized = batch["reference"].to(configs["device"])

    # Forward pass
    reference_embeds = reference_model(reference_tokenized) # Teacher embeddings
    candidate_embeds = candidate_model(
        input_ids=candidate_tokenized["input_ids"],
        attention_mask=candidate_tokenized["attention_mask"]
    ) # Student embeddings

    # Normalize
    reference_embeds = F.normalize(reference_embeds, p=2, dim=-1)
    candidate_embeds = F.normalize(candidate_embeds, p=2, dim=-1)

    # Similarity matrix
    logits = torch.matmul(candidate_embeds, reference_embeds.T) / temperature
    
    # Targets are diagonal (0, 1, 2, ...) since pairs are aligned in the batch
    targets = torch.arange(logits.size(0)).to(configs["device"])
    
    loss = (
        F.cross_entropy(logits, targets) +
        F.cross_entropy(logits.T, targets)
    ) / 2
    
    preds = torch.argmax(logits, dim=1)
    corrects = (preds == targets).sum().item()

    return loss, corrects

def train_loop(dataloader, models, reference_tokenizer, candidate_tokenizer, optimizer, temperature):
    models['candidateModel'].train()

    total_loss = 0.0
    total_corrects = 0
    total_samples = 0

    print("Training...")
    for (index, pairs) in tqdm(enumerate(dataloader), total=len(dataloader)):
        # `pairs` is one DataLoader batch from the Hugging Face dataset:
        # a dict mapping each language column to a list of sentences
        
        # Tokenize Persian (Student Candidate)
        candidate_tokenized = candidate_tokenizer(
            pairs[configs['persian']], 
            padding='max_length', 
            truncation=True, 
            return_tensors="pt", 
            max_length=configs["fa_tok_percentile"]
        )
        
        # Tokenize English (Teacher Reference)
        reference_text_tokenized = reference_tokenizer(
            pairs[configs['english']]
        )

        batch = {
            "candidate" : candidate_tokenized,
            "reference" : reference_text_tokenized
        }

        loss, corrects = calc_loss(batch, models['referenceModel'], models['candidateModel'], temperature)
        
        total_corrects += corrects
        total_loss += loss.item() * len(pairs[configs['persian']]) # Sum up loss
        total_samples += len(pairs[configs['persian']])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = total_loss / total_samples
    avg_accuracy = total_corrects / total_samples

    print(f"Train Loss: {avg_loss:.4f} | Train Acc: {avg_accuracy:.4f}")
    return avg_loss, avg_accuracy

def val_loop(dataloader, models, reference_tokenizer, candidate_tokenizer, temperature):
    models['candidateModel'].eval()

    print("Validating...")
    total_loss = 0.0
    total_corrects = 0
    total_samples = 0

    with torch.no_grad():
        for (index, pairs) in tqdm(enumerate(dataloader), total=len(dataloader)):
            candidate_tokenized = candidate_tokenizer(
                pairs[configs['persian']], 
                padding='max_length', 
                truncation=True, 
                return_tensors="pt", 
                max_length=configs["fa_tok_percentile"]
            )
            
            # The OpenCLIP tokenizer returns a LongTensor of shape (batch, context_length)
            reference_text_tokenized = reference_tokenizer(
                pairs[configs['english']]
            )

            batch = {
                "candidate" : candidate_tokenized,
                "reference" : reference_text_tokenized
            }

            loss, corrects = calc_loss(batch, models['referenceModel'], models['candidateModel'], temperature)

            total_corrects += corrects
            total_loss += loss.item() * len(pairs[configs['persian']])
            total_samples += len(pairs[configs['persian']])

    avg_loss = total_loss / total_samples
    avg_accuracy = total_corrects / total_samples

    print(f"Val Loss: {avg_loss:.4f} | Val Acc: {avg_accuracy:.4f}")
    return avg_loss, avg_accuracy

In [None]:
# Preprocessing Pipeline
dataset_train, dataset_val = get_datasets_csv(
    "en", "fa", 
    configs["english"], configs["persian"], 
    configs["train_path"], configs["val_path"]
)

# Apply Normalization
dataset_train, dataset_val = apply_preprocess([dataset_train, dataset_val], configs)

print("Example processed row:", dataset_train[0])

# Dataloaders
train_dataloader = DataLoader(dataset_train, batch_size=configs['batch_size'], shuffle=True)
val_dataloader = DataLoader(dataset_val, batch_size=configs['batch_size'], shuffle=False)

## Model Initialization
Initialize tokenizers and models, and calculate dynamic token percentiles for efficient padding.

In [None]:
# 1. Tokenizers
reference_tokenizer = open_clip.get_tokenizer(configs["reference_checkPoint"])
candidate_tokenizer = AutoTokenizer.from_pretrained(configs["candidate_checkpoint"])

# Dynamically update config with student's hidden size
candidate_config_obj = AutoConfig.from_pretrained(configs["candidate_checkpoint"])
configs["candidate_embedding"] = candidate_config_obj.hidden_size

# 2. Calculate Token Percentiles
fa_token_len = calc_percentile_tokens(dataset_train, candidate_tokenizer, configs["persian"])
en_token_len = calc_percentile_tokens(dataset_train, reference_tokenizer, configs["english"])

configs["en_tok_percentile"] = en_token_len
configs["fa_tok_percentile"] = fa_token_len

print(f"Calculated Sequence Lengths - FA: {fa_token_len}, EN: {en_token_len}")

# 3. Models
reference_model = TextEncoder(configs).to(configs["device"])
candidate_model = CandidateModel(
    model_name=configs["candidate_checkpoint"], 
    unfreeze_layers=configs["unfreezed_layers"]
).to(configs["device"])

# Freeze the teacher model completely
reference_model = freeze_model(reference_model, freeze=True)

models = {
    "referenceModel" : reference_model,
    "candidateModel" : candidate_model
}

print(f"Model {configs['candidate_checkpoint']} initialized.")

## Execution
Start the training process.
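The cell that follows learns `temperature` directly as a parameter, which leaves nothing to stop it from drifting towards zero or below. CLIP instead learns a logit scale $\log(1/\tau)$, initialized so that $\tau = 0.07$, and clips it so logits are never scaled by more than 100. The sketch below shows that alternative; it is not what this notebook runs, and adopting it would mean multiplying the logits in `calc_loss` by `logit_scale.exp()` instead of dividing by `temperature`.

```python
import math
import torch
import torch.nn.functional as F

init_tau = 0.07
# Learn log(1/tau) instead of tau: exp() keeps the scale positive for any parameter value
logit_scale = torch.nn.Parameter(torch.tensor(math.log(1 / init_tau)))

def scaled_logits(student_embeds, teacher_embeds):
    student_embeds = F.normalize(student_embeds, dim=-1)
    teacher_embeds = F.normalize(teacher_embeds, dim=-1)
    return logit_scale.exp() * student_embeds @ teacher_embeds.T

# After each optimizer step, cap the scale at 100 (i.e. tau >= 0.01)
with torch.no_grad():
    logit_scale.clamp_(0, math.log(100))

logits = scaled_logits(torch.randn(4, 8), torch.randn(4, 8))
print(logit_scale.exp().item(), tuple(logits.shape))
```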

In [None]:
# Optimizer
# Temperature is learned
temperature = torch.nn.Parameter(torch.tensor(configs['temperature']).to(configs['device']).float())

optimizer = torch.optim.AdamW(
    list(models['candidateModel'].parameters()) + [temperature], 
    weight_decay=configs["weight_decay"], 
    lr=configs['lr']
)

lr_scheduler = ReduceLROnPlateau(
    optimizer, 'max', 
    patience=configs['patience'], 
    factor=configs['factor']
)

best_val_acc = float('-inf')
metrics = pd.DataFrame(columns=["Avg-train-loss", "Avg-train-accuracy", "Avg-val-loss", "Avg-val-accuracy"])

for t in range(configs['epochs']):
    print(f"\nEpoch {t+1}/{configs['epochs']}")
    train_loss, train_acc = train_loop(
        train_dataloader, models, reference_tokenizer, candidate_tokenizer, optimizer, temperature
    )
    val_loss, val_acc = val_loop(
        val_dataloader, models, reference_tokenizer, candidate_tokenizer, temperature
    )

    metrics.loc[t+1] = [train_loss, train_acc, val_loss, val_acc]
    
    # Save best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(models['candidateModel'].state_dict(), configs['save_path'])
        print(f"New best model saved with Acc: {val_acc:.4f}")

    lr_scheduler.step(val_acc)
    print(f"Temperature: {temperature.item():.4f}")

print(f'\nTraining Complete. Best Validation Accuracy: {best_val_acc:.4f}')
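With `mode='max'`, `patience=1` and `factor=0.8`, `ReduceLROnPlateau` multiplies the learning rate by 0.8 only after more than `patience` (that is, two) consecutive epochs without improvement in validation accuracy. The simulation below feeds the scheduler hypothetical accuracies, using a single dummy parameter in place of the student model.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter standing in for the model
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode="max", patience=1, factor=0.8)

lrs = []
for val_acc in [0.50, 0.40, 0.40, 0.60]:  # hypothetical validation accuracies
    scheduler.step(val_acc)
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)
```

With `epochs: 3`, the earliest possible reduction happens at the scheduler step after the third epoch, so in this configuration the scheduler cannot change the learning rate used for training; it only matters for longer runs.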

In [None]:
# Plot Metrics
plot_metric(metrics, "Avg-val-accuracy")
print(metrics.tail(3))