In [None]:
import torch
from torch.utils.data import DataLoader, Dataset, ConcatDataset, random_split
from transformers import XLMRobertaTokenizer, XLMRobertaModel
from datasets import Dataset
import os
from google.colab import drive
import wandb
import pandas as pd
from tqdm import tqdm

# Locations inside Google Drive: checkpoints are written to MODEL_DIR,
# the pre-split distillation parquet chunks are read from DATA_DIR.
MODEL_DIR = "/content/drive/MyDrive/KD-EE-XLMR/models/l12"
DATA_DIR = "/content/drive/MyDrive/KD-EE-XLMR/minimal_data/distillation_data_split"

# Drive has to be mounted before anything below /content/drive can be touched,
# so the checkpoint directory is only created afterwards.
drive.mount('/content/drive')
os.makedirs(MODEL_DIR, exist_ok=True)

Mounted at /content/drive


In [None]:
# Initialize tokenizer
# Loaded from the "xlm-roberta-large" checkpoint; the training cell saves it
# next to every best-student checkpoint so each saved directory is self-contained.
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-large")

# Define languages
# Keys are the language codes used to build the per-language chunk file names
# ("<lang>_<chunk_idx>.parquet") when the training loop iterates over this dict.
# The numeric values are not read anywhere in the visible notebook.
# NOTE(review): the values look like per-language corpus sizes (e.g. GiB of
# training text) — confirm the unit and source before relying on them.
languages = {'af': 1.3, 'am': 0.8, 'ar': 28.0, 'as': 0.1, 'az': 6.5,
       'be': 4.3, 'bg': 57.5, 'bn': 8.4, 'bn_rom': 0.5, 'br': 0.1,
       'bs': 0.1, 'ca': 10.1, 'cs': 16.3, 'cy': 0.8, 'da': 45.6,
       'de': 66.6, 'el': 46.9, 'en': 300.8, 'eo': 0.9, 'es': 53.3,
       'et': 6.1, 'eu': 2.0, 'fa': 111.6, 'fi': 54.3, 'fr': 56.8,
       'fy': 0.2, 'ga': 0.5, 'gd': 0.1, 'gl': 2.9, 'gu': 1.9, 'ha': 0.3,
       'he': 31.6, 'hi': 20.2, 'hi_rom': 0.5, 'hr': 20.5, 'hu': 58.4,
       'hy': 5.5, 'id': 148.3, 'is': 3.2, 'it': 30.2, 'ja': 69.3,
       'jv': 0.2, 'ka': 9.1, 'kk': 6.4, 'km': 1.5, 'kn': 3.3, 'ko': 54.2,
       'ku': 0.4, 'ky': 1.2, 'la': 2.5, 'lo': 0.6, 'lt': 13.7, 'lv': 8.8,
       'mg': 0.2, 'mk': 4.8, 'ml': 7.6, 'mn': 3.0, 'mr': 2.8, 'ms': 8.5,
       'my': 0.4, 'my_zaw': 1.6, 'ne': 3.8, 'nl': 29.3, 'no': 49.0,
       'om': 0.1, 'or': 0.6, 'pa': 0.8, 'pl': 44.6, 'ps': 0.7,
       'pt': 49.1, 'ro': 61.4, 'ru': 278.0, 'sa': 0.3, 'sd': 0.4,
       'si': 3.6, 'sk': 23.2, 'sl': 10.3, 'so': 0.4, 'sq': 5.4,
       'sr': 9.1, 'su': 0.1, 'sv': 12.1, 'sw': 1.6, 'ta': 12.2,
       'ta_rom': 0.3, 'te': 4.7, 'te_rom': 0.3, 'th': 71.7, 'tl': 3.1,
       'tr': 20.9, 'ug': 0.4, 'uk': 84.6, 'ur': 5.7, 'ur_rom': 0.5,
       'uz': 0.7, 'vi': 137.3, 'xh': 0.1, 'yi': 0.3, 'zh-Hans': 46.9,
       'zh-Hant': 16.6
}

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

In [None]:
# Function to load preprocessed dataset
def load_preprocessed(dataset_path):
    """Read one preprocessed parquet file and expose it as torch tensors.

    Args:
        dataset_path: Path of the ``.parquet`` file to read.

    Returns:
        A Hugging Face ``Dataset`` built from the file, formatted so that
        indexing yields torch tensors for the ``input_ids`` and
        ``attention_mask`` columns.
    """
    tensor_columns = ["input_ids", "attention_mask"]

    # parquet -> pandas -> Hugging Face Dataset
    hf_dataset = Dataset.from_pandas(pd.read_parquet(dataset_path))

    # Only the two model inputs are surfaced, and they come back as torch tensors
    hf_dataset.set_format(type="torch", columns=tensor_columns)
    return hf_dataset

def load_chunk(lang, chunk_idx, chunk_dir=DATA_DIR):
    """Load the preprocessed chunk ``<lang>_<chunk_idx>.parquet`` from ``chunk_dir``.

    Args:
        lang: Language code forming the first part of the file name.
        chunk_idx: Chunk number forming the second part of the file name.
        chunk_dir: Directory holding the parquet chunks (defaults to ``DATA_DIR``).

    Returns:
        Whatever ``load_preprocessed`` returns for that file.
    """
    parquet_name = f"{lang}_{chunk_idx}.parquet"
    return load_preprocessed(os.path.join(chunk_dir, parquet_name))

# Load teacher and student models (both return the hidden states of every layer)
teacher_model = XLMRobertaModel.from_pretrained('xlm-roberta-large', output_hidden_states=True).cuda()
student_model = XLMRobertaModel.from_pretrained('xlm-roberta-base', output_hidden_states=True).cuda()

# The teacher only provides targets: keep it in inference mode (no dropout)
# and frozen so no gradients are ever tracked for it.
teacher_model.eval()
teacher_model.requires_grad_(False)

# Projection layer mapping student hidden states into the teacher's hidden size
# (768 -> 1024 for xlm-roberta-base -> xlm-roberta-large). The sizes are read
# from the model configs so a different student/teacher pair keeps working.
projection_layer = torch.nn.Linear(
    student_model.config.hidden_size,
    teacher_model.config.hidden_size,
).cuda()

# Define distillation losses
kl_loss_fn = torch.nn.KLDivLoss(reduction='batchmean')
mse_loss_fn = torch.nn.MSELoss()

# Optimizer
# The projection layer is randomly initialised and sits in front of every
# distillation loss term, so it has to be optimised together with the student;
# otherwise the student is forced to match the teacher through a fixed random map.
trainable_params = list(student_model.parameters()) + list(projection_layer.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=3e-5, weight_decay=0.01)

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

In [None]:
# WandB for logging
wandb.init(project="xlmr-distillation", name="distillation-l12")

# Define which layers to distill: {student hidden-state index: teacher hidden-state index}.
# `hidden_states` is a tuple whose element 0 is the embedding output, so index k
# is the output of encoder layer k (see the Transformers model-output docs).
# layer_mapping = {1: 2, 3: 6, 5: 10, 7: 14, 9: 18, 11: 22}  # L7
layer_mapping = {1: 2, 2: 4, 3: 6, 4: 8, 5: 10, 6: 12, 7: 14, 8: 16, 9: 18, 10: 20, 11: 22}  # L12

# Training hyperparameters
batch_size = 128
num_epochs = 5
eval_every = 3000          # run validation every this many global steps
patience = 3               # evaluations without improvement before stopping
min_delta = 0.01           # minimum drop in validation loss that counts as improvement
best_val_loss = float('inf')
no_improve_count = 0
global_step = 0
stop_training = False      # set by early stopping so the epoch loop exits too
split_seed = 42            # makes the train/val split reproducible
scaler = torch.amp.GradScaler("cuda")  # Enables mixed precision (loss scaling)
gradient_accumulation_steps = 4  # Adjust based on GPU memory

# Loss hyperparameters
temperature = 2.0
lambda_final = 1.0
lambda_intermediate = 0.5


def compute_distillation_loss(student_hidden, teacher_hidden):
    """Return the combined distillation loss for one batch.

    The loss is ``lambda_final`` times a temperature-scaled KL term between the
    projected last student hidden state and the last teacher hidden state
    (softmax taken over the last dimension), plus ``lambda_intermediate`` times
    the mean over ``layer_mapping`` of the MSE between projected student layers
    and their mapped teacher layers.

    The same function is used for training and validation so both losses are
    on the same scale.

    Args:
        student_hidden: ``hidden_states`` tuple returned by the student model.
        teacher_hidden: ``hidden_states`` tuple returned by the teacher model.

    Returns:
        A scalar tensor holding the weighted sum of both terms.
    """
    student_last = projection_layer(student_hidden[-1])
    teacher_last = teacher_hidden[-1]
    final_loss = kl_loss_fn(
        (student_last / temperature).log_softmax(dim=-1),
        (teacher_last / temperature).softmax(dim=-1),
    ) * (temperature ** 2)  # usual T^2 factor keeps gradient magnitude comparable across temperatures
    total = lambda_final * final_loss

    for student_layer, teacher_layer in layer_mapping.items():
        student_rep = projection_layer(student_hidden[student_layer])
        teacher_rep = teacher_hidden[teacher_layer]
        # Averaged over the mapped layers so the weight of this term does not
        # depend on how many layers are distilled.
        total = total + lambda_intermediate * mse_loss_fn(student_rep, teacher_rep) / len(layer_mapping)

    return total


def evaluate(loader):
    """Return the mean distillation loss per batch over ``loader``.

    The student is switched to eval mode for the pass and back to train mode
    before returning. ``loader`` must yield at least one batch.
    """
    student_model.eval()
    total = 0.0
    with torch.no_grad():
        for val_batch in tqdm(loader, desc="Validating"):
            val_input_ids = val_batch["input_ids"].cuda()
            val_attention_mask = val_batch["attention_mask"].cuda()

            val_student = student_model(val_input_ids, attention_mask=val_attention_mask).hidden_states
            val_teacher = teacher_model(val_input_ids, attention_mask=val_attention_mask).hidden_states

            total += compute_distillation_loss(val_student, val_teacher).item()
    student_model.train()
    return total / len(loader)


# Training loop
for epoch in range(num_epochs):
    print(f"\n========== Epoch {epoch+1} ==========")
    chunk_idx = epoch % 5
    print(f"Loading chunk {chunk_idx} from distillation data...")
    # Dynamically load the current chunk for each language
    lang_datasets = [load_chunk(lang, chunk_idx) for lang in languages]
    mixed_dataset = ConcatDataset(lang_datasets)

    # Re-split train/val (seeded per epoch so a re-run sees the same split).
    # NOTE(review): every epoch validates on a different chunk, so validation
    # losses compared by early stopping come from different validation sets.
    train_size = int(0.999 * len(mixed_dataset))
    val_size = len(mixed_dataset) - train_size
    train_dataset, val_dataset = random_split(
        mixed_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(split_seed + epoch),
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    epoch_steps = len(train_loader)
    student_model.train()
    optimizer.zero_grad()  # Outside the loop to accumulate gradients

    for step, batch in enumerate(tqdm(train_loader, desc="Training")):
        input_ids = batch["input_ids"].cuda()
        attention_mask = batch["attention_mask"].cuda()

        # Teacher targets never need gradients
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask).hidden_states

        with torch.autocast(device_type="cuda"):  # Mixed precision training
            student_outputs = student_model(input_ids, attention_mask=attention_mask).hidden_states
            loss = compute_distillation_loss(student_outputs, teacher_outputs)

        # Scale loss for mixed precision. Dividing by the accumulation factor makes
        # the accumulated gradient an average over the micro-batches rather than
        # a sum (a shorter final group at the end of an epoch is slightly
        # under-weighted, which is the usual trade-off).
        scaler.scale(loss / gradient_accumulation_steps).backward()

        # Gradient accumulation: Only update weights every gradient_accumulation_steps
        is_last_step = (step + 1) == epoch_steps
        if (step + 1) % gradient_accumulation_steps == 0 or is_last_step:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        global_step += 1
        wandb.log({"Batch Loss": loss.item(), "Global Step": global_step})

        # Trigger validation every eval_every steps and at the end of every epoch.
        # The end-of-epoch check uses the per-epoch step, because global_step keeps
        # growing across epochs. An empty validation split is skipped.
        if len(val_loader) > 0 and (global_step % eval_every == 0 or is_last_step):
            avg_val_loss = evaluate(val_loader)
            wandb.log({"Validation Loss": avg_val_loss})
            print(f"\n>>> Step {global_step}: Validation Loss = {avg_val_loss:.4f}")

            # Early stopping check
            if best_val_loss - avg_val_loss >= min_delta:
                best_val_loss = avg_val_loss
                no_improve_count = 0
                # Save best model on every improvement, so the checkpoint on disk
                # always corresponds to best_val_loss.
                student_model.save_pretrained(f"{MODEL_DIR}/best_step_{global_step}")
                tokenizer.save_pretrained(f"{MODEL_DIR}/best_step_{global_step}")
                print(f"✓ Validation improved. Model saved at step {global_step}.")
                # NOTE(review): hard-coded cut-off kept from the original cell; it only
                # ends the current epoch, and only if step 6000 improved — confirm intent.
                if global_step == 6000:
                    break
            else:
                no_improve_count += 1
                print(f"✗ No improvement. Patience counter: {no_improve_count}/{patience}")

            if no_improve_count >= patience:
                print("🔥 Early stopping triggered due to no improvement.")
                stop_training = True
                break

    # A bare `break` above only leaves the batch loop; propagate early stopping
    # so no further epochs are started.
    if stop_training:
        break

wandb.finish()

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdamianxml[0m ([33mdamianxml-uppsala-universitet[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  scaler = torch.cuda.amp.GradScaler()  # Enables mixed precision



Loading chunk 0 from distillation data...


  with torch.cuda.amp.autocast():  # Mixed percision training
Training:  97%|█████████▋| 37/38 [00:51<00:01,  1.34s/it]
Validating:   0%|          | 0/1 [00:00<?, ?it/s][A
Validating: 100%|██████████| 1/1 [00:00<00:00,  2.61it/s]
Training: 100%|██████████| 38/38 [00:52<00:00,  1.39s/it]



>>> Step 38: Validation Loss = 39.3992

Loading chunk 1 from distillation data...


KeyboardInterrupt: 