# Knowledge distillation

### installing dependencies

In [None]:
%pip install datasets \
             transformers \
             torch \
             sentence-transformers \
             numpy \
             tqdm \
             scikit-learn \
             hazm \
             google.drive

### mounting google drive

In [None]:
from google.colab import drive #type: ignore
# Force unmount and remount with full access
drive.flush_and_unmount()
drive.mount('/content/drive', force_remount=True)
!ls /content/drive/MyDrive/embed-distill/  # Check files in the directory

Mounted at /content/drive


## 1. Imports and defaults

In [3]:
import os
import re
import numpy as np
import torch
from tqdm.auto import tqdm
from datasets import load_dataset
from hazm import SentenceTokenizer
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from IPython.core.getipython import get_ipython

# --- Paths 
BASE_DIR = '/content/drive/MyDrive/colab/embed-distill' if 'google.colab' in str(get_ipython()) else os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, 'data')
EMB_DIR  = os.path.join(BASE_DIR, 'embeddings')
MODEL_DIR = os.path.join(BASE_DIR, 'models')

for d in [DATA_DIR, EMB_DIR, MODEL_DIR]:
    os.makedirs(d, exist_ok=True)
    
print("Base directory:", BASE_DIR)

# --- Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
if device == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('No GPU found — make sure you are on Colab with GPU runtime for teacher inference')

Base directory: /content/drive/MyDrive/colab/embed-distill
Device: cuda
GPU: Tesla T4
VRAM: 15.6 GB


## 2. Data loading

In [31]:
print("Loading Persian Wikipedia (streaming)...")
dataset = load_dataset(
    'wikimedia/wikipedia',
    '20231101.fa',
    split='train',
    streaming=True,
    trust_remote_code=True
)

sent_tokenizer = SentenceTokenizer()

def clean_and_split(article_text):
    """Clean wiki markup and split into sentences."""
    # Remove wiki artifacts
    text = re.sub(r'http\S+', '', article_text)
    text = re.sub(r'\{\{.*?\}\}', '', text, flags=re.DOTALL)
    text = re.sub(r'\[\[(?:[^|\]]*\|)?([^\]]*)\]\]', r'\1', text)
    text = re.sub(r'[=\*#\|<>{}\[\]]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()

    # Split into sentences
    sentences = sent_tokenizer.tokenize(text)
    return sentences

def is_valid(sentence):
    """Quality filter."""
    if len(sentence) < 20 or len(sentence) > 512:
        return False
    persian_ratio = len(re.findall(r'[\u0600-\u06FF]', sentence)) / len(sentence)
    if persian_ratio < 0.4:           # at least 40% Persian characters
        return False
    return True

# --- Extract sentences
TARGET = 1_200_000   # collect a bit more than 1m to have buffer
sentences = []

for article in tqdm(dataset, desc='Scanning articles'):
    sents = clean_and_split(article['text'])
    sentences.extend([s for s in sents if is_valid(s)])
    if len(sentences) >= TARGET:
        break

# Deduplicate
sentences = list(dict.fromkeys(sentences))[:1_000_000]
print(f'\nFinal sentence count: {len(sentences):,}')

# Save raw sentences
sentences_path = os.path.join(DATA_DIR, 'persian_sentences.txt')
with open(sentences_path, 'w', encoding='utf-8') as f:
    f.write('\n'.join(sentences))

print(f'Saved to {sentences_path}')
print(f'Sample sentences:')
for s in sentences[:3]:
    print(f'  → {s}')

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'wikimedia/wikipedia' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.
ERROR:datasets.load:`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'wikimedia/wikipedia' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading Persian Wikipedia (streaming)...


Scanning articles: 0it [00:00, ?it/s]


Final sentence count: 1,000,000
Saved to /content/drive/MyDrive/colab/embed-distill/data/persian_sentences.txt
Sample sentences:
  → مقاله‌های برگزیده – مقالهٔ امروز بیشتر… امروز: ، میلادی برابر هجری خورشیدی و (UTC) → روز قبل – روز بعد ←یادبودهای – یادبودهای بیشتر… بایگانی – نگاره‌های برگزیدهٔ بیشتر
  → ویکی‌پدیا یک دانشنامه برخط چندزبانه مبتنی بر وب با محتوای آزاد و همکاری باز است که با همکاری افراد داوطلب نوشته می‌شود و هر کسی که به اینترنت و وب دسترسی داشته باشد می‌تواند مقالات آن را ببیند و ویرایش کند.
  → نام ویکی‌پدیا از پیوند واژه «ویکی» (به معنی وبگاه مشارکتی) با «پدیا» (گرفته‌شده از پسوند واژه encyclopedia به معنی دانشنامه یا دائرةالمعارف) ایجاد شده‌است.


## 3. Data pipeline

In [39]:
# Load saved sentences
sentences_path = os.path.join(DATA_DIR, 'persian_sentences.txt')
with open(sentences_path, 'r', encoding='utf-8') as f:
    sentences = [line.strip() for line in f if line.strip()]

print(f'Total sentences loaded: {len(sentences):,}')

# Train/Val split (95/5)
train_sentences, val_sentences = train_test_split(
    sentences,
    test_size=0.05,
    random_state=42
)

print(f'Train: {len(train_sentences):,}')
print(f'Val:   {len(val_sentences):,}')

# Save splits
for name, split in [('train', train_sentences), ('val', val_sentences)]:
    path = os.path.join(DATA_DIR, f'{name}_sentences.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(split))
    print(f'Saved {name} split → {path}')

Total sentences loaded: 1,000,000
Train: 950,000
Val:   50,000
Saved train split → /content/drive/MyDrive/colab/embed-distill/data/train_sentences.txt
Saved val split → /content/drive/MyDrive/colab/embed-distill/data/val_sentences.txt


## 4. Teacher

In [None]:
# --- Config
BATCH_SIZE = 64        # reduce to 16 if you get OOM
SAVE_EVERY = 10_000    # checkpoint every 10K sentences

# --- Load teacher
print("Loading Jina-v3 teacher model...")
teacher = SentenceTransformer(
    'jinaai/jina-embeddings-v3',
    trust_remote_code=True,
    device=device
)
print(f"Teacher embedding dim: {teacher.get_sentence_embedding_dimension()}")

# --- Load train sentences
with open(os.path.join(DATA_DIR, 'train_sentences.txt'), 'r', encoding='utf-8') as f:
    train_sentences = [l.strip() for l in f if l.strip()]
print(f"Sentences to embed: {len(train_sentences):,}")

# --- Resume support: check how many already embedded
emb_path    = os.path.join(EMB_DIR, 'train_embeddings.npy')
index_path  = os.path.join(EMB_DIR, 'train_embedded_count.txt')

if os.path.exists(index_path):
    with open(index_path) as f:
        start_idx = int(f.read().strip())
    embeddings = list(np.load(emb_path))
    print(f"Resuming from sentence {start_idx:,}")
else:
    start_idx  = 0
    embeddings = []
    print("Starting fresh")

# --- Generate embeddings in batches with checkpointing
teacher.eval()
with torch.no_grad():
    for i in tqdm(range(start_idx, len(train_sentences), BATCH_SIZE), desc='Embedding'):
        batch = train_sentences[i : i + BATCH_SIZE]
        embs  = teacher.encode(
            batch,
            task='text-matching',     # Jina-v3 is task-aware, this is correct task for STS
            batch_size=BATCH_SIZE,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True
        )
        embeddings.extend(embs)

        # Checkpoint
        if len(embeddings) % SAVE_EVERY == 0:
            np.save(emb_path, np.array(embeddings))
            with open(index_path, 'w') as f:
                f.write(str(len(embeddings)))

# --- Final save
embeddings = np.array(embeddings)
np.save(emb_path, embeddings)
print(f"\n✅ Done!")
print(f"Embeddings shape: {embeddings.shape}")
print(f"Saved to: {emb_path}")

Loading Jina-v3 teacher model...


modules.json:   0%|          | 0.00/378 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/464 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

custom_st.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/jina-embeddings-v3:
- custom_st.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


config.json: 0.00B [00:00, ?B/s]

configuration_xlm_roberta.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- configuration_xlm_roberta.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_lora.py: 0.00B [00:00, ?B/s]

modeling_xlm_roberta.py: 0.00B [00:00, ?B/s]

embedding.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- embedding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


rotary.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- rotary.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


mha.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- mha.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


xlm_padding.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- xlm_padding.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


block.py: 0.00B [00:00, ?B/s]

stochastic_depth.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- stochastic_depth.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


mlp.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- mlp.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- block.py
- stochastic_depth.py
- mlp.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/jinaai/xlm-roberta-flash-implementation:
- modeling_xlm_roberta.py
- embedding.py
- rotary.py
- mha.py
- xlm_padding.py
- block.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downl

model.safetensors:   0%|          | 0.00/1.14G [00:00<?, ?B/s]



tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/192 [00:00<?, ?B/s]

Teacher embedding dim: 1024
Sentences to embed: 950,000
Starting fresh


Embedding:   0%|          | 0/29688 [00:00<?, ?it/s]

: 