In [2]:
from transformers import DataCollatorForLanguageModeling, AutoTokenizer, Trainer, TrainingArguments

In [None]:
DATASET_PROC_PATH = "../../data/pretrain/base/wikitext-2-v1-base-proc"
TEST_SIZE = 0.2
VAL_SIZE = 0.2
SPLIT_SEED = 42069

TOKENIZER_NAME = "albert-base-v2"
MLM_PROBABILITY = 0.15

TRAINER_OUTPUT = "../../experiments/checkpoints/base/pretrain"
EPOCHS = 5
LOGGING_STEPS = 100
LOGGER_OUTPUT = "../../experiments/logs/base/pretrain"
SAVE_STEPS = 200
SAVE_LIMIT = 5

In [4]:
from datasets import load_from_disk

dataset = load_from_disk(DATASET_PROC_PATH)

In [5]:
split_train_test = dataset.train_test_split(test_size=TEST_SIZE, seed=SPLIT_SEED)
dataset_test = split_train_test["test"]

split_train_val = split_train_test["train"].train_test_split(test_size=VAL_SIZE, seed=SPLIT_SEED)
dataset_train = split_train_val["train"]
dataset_val = split_train_val["test"]

print(len(dataset_train), len(dataset_test), len(dataset_val))

128 40 32


In [6]:
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True, mlm_probability=MLM_PROBABILITY
)

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

class TripleDistillationTrainer(Trainer):
    def __init__(
        self,
        teacher_model,
        alpha_ce=0.5,
        alpha_kl=0.3,
        alpha_cos=0.2,
        temperature=2.0,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        self.teacher.eval()
        self.alpha_ce = alpha_ce
        self.alpha_kl = alpha_kl
        self.alpha_cos = alpha_cos
        self.temperature = temperature

        self.cos_loss_fct = nn.CosineEmbeddingLoss()
        self.proj = None  # projection if hidden dims differ

    def _maybe_build_proj(self, in_dim, out_dim, device):
        if self.proj is None and in_dim != out_dim:
            self.proj = nn.Linear(in_dim, out_dim, bias=False).to(device)
        return self.proj

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.get("labels")
        attention_mask = inputs.get("attention_mask")

        # --- Student ---
        outputs_student = model(
            **{k: v for k, v in inputs.items() if k != "labels"},
            output_hidden_states=True,
            output_attentions=False
        )
        student_logits = outputs_student.logits
        student_hidden = outputs_student.hidden_states[-1]

        # --- Teacher ---
        with torch.no_grad():
            outputs_teacher = self.teacher(
                **{k: v for k, v in inputs.items() if k != "labels"},
                output_hidden_states=True,
                output_attentions=False
            )
            teacher_logits = outputs_teacher.logits
            teacher_hidden = outputs_teacher.hidden_states[-1]

        # 1️⃣ CrossEntropy loss (MLM or classification)
        ce_loss = F.cross_entropy(
            student_logits.view(-1, student_logits.size(-1)),
            labels.view(-1),
            reduction="mean"
        )

        # 2️⃣ KL divergence (logits distillation)
        s_logits_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
        t_logits_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
        kl_loss = F.kl_div(s_logits_soft, t_logits_soft, reduction="batchmean") * (self.temperature ** 2)

        # 3️⃣ Cosine embedding loss (hidden states)
        if attention_mask is not None:
            mask = attention_mask.view(-1).bool()
            s_hidden = student_hidden.view(-1, student_hidden.size(-1))[mask]
            t_hidden = teacher_hidden.view(-1, teacher_hidden.size(-1))[mask]
        else:
            s_hidden = student_hidden.view(-1, student_hidden.size(-1))
            t_hidden = teacher_hidden.view(-1, teacher_hidden.size(-1))

        # project teacher hidden if sizes differ
        if t_hidden.size(-1) != s_hidden.size(-1):
            proj = self._maybe_build_proj(t_hidden.size(-1), s_hidden.size(-1), s_hidden.device)
            t_hidden = proj(t_hidden)

        cos_target = s_hidden.new_ones(s_hidden.size(0))
        cos_loss = self.cos_loss_fct(
            F.normalize(s_hidden, dim=-1),
            F.normalize(t_hidden.detach(), dim=-1),
            cos_target
        )

        # Final weighted sum
        loss = (self.alpha_ce * ce_loss) + (self.alpha_kl * kl_loss) + (self.alpha_cos * cos_loss)

        return (loss, outputs_student) if return_outputs else loss


In [8]:
from transformers import AlbertForMaskedLM
from heliumbert import HeliumbertForMaskedLM, HeliumbertConfig

albert_model = AlbertForMaskedLM.from_pretrained("albert-base-v2")

heliumbert_config = HeliumbertConfig(
    vocab_size=albert_model.config.vocab_size,
    embedding_size=albert_model.config.embedding_size,
    hidden_size=albert_model.config.hidden_size,
    num_hidden_layers=albert_model.config.num_hidden_layers // 2,
    num_hidden_groups=albert_model.config.num_hidden_groups,
    num_attention_heads=albert_model.config.num_attention_heads,
    intermediate_size=albert_model.config.intermediate_size,
    inner_group_num=albert_model.config.inner_group_num,
    hidden_act=albert_model.config.hidden_act,
    hidden_dropout_prob=albert_model.config.hidden_dropout_prob,
    attention_probs_dropout_prob=albert_model.config.attention_probs_dropout_prob,
    max_position_embeddings=albert_model.config.max_position_embeddings,
    type_vocab_size=albert_model.config.type_vocab_size,
    initializer_range=albert_model.config.initializer_range,
    layer_norm_eps=albert_model.config.layer_norm_eps,
    classifier_dropout_prob=albert_model.config.classifier_dropout_prob,
    position_embedding_type=albert_model.config.position_embedding_type,
    pad_token_id=albert_model.config.pad_token_id,
    bos_token_id=albert_model.config.bos_token_id,
    eos_token_id=albert_model.config.eos_token_id,
)

MODEL = HeliumbertForMaskedLM(heliumbert_config)


Some weights of the model checkpoint at albert-base-v2 were not used when initializing AlbertForMaskedLM: ['albert.pooler.bias', 'albert.pooler.weight']
- This IS expected if you are initializing AlbertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
albert_model.num_parameters()

11221680

In [10]:
MODEL.num_parameters()

11221424

In [11]:
training_args = TrainingArguments(
    output_dir=TRAINER_OUTPUT,
    overwrite_output_dir=True,

    num_train_epochs=EPOCHS,
    prediction_loss_only=False,

    per_device_train_batch_size=16,

    logging_strategy="steps",
    logging_steps=LOGGING_STEPS,
    logging_dir=LOGGER_OUTPUT,

    save_strategy="steps",
    save_steps=SAVE_STEPS,
    save_total_limit=SAVE_LIMIT,
)

In [21]:
trainer = TripleDistillationTrainer(
    teacher_model=albert_model,
    model=MODEL,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    data_collator=data_collator
)

In [22]:
trainer.train()



Step,Training Loss


KeyboardInterrupt: 

In [14]:
trainer.evaluate()



{'eval_loss': 362.9653625488281,
 'eval_runtime': 7.7303,
 'eval_samples_per_second': 4.14,
 'eval_steps_per_second': 0.517,
 'epoch': 1.0}

In [25]:
trainer.save_model(TRAINER_OUTPUT + "/full")

In [23]:
resumed_model = HeliumbertForMaskedLM.from_pretrained(TRAINER_OUTPUT + "/checkpoint-2")

trainer = TripleDistillationTrainer(
    teacher_model=albert_model,
    model=resumed_model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    data_collator=data_collator
)

In [24]:
trainer.train(resume_from_checkpoint=True)

There were missing keys in the checkpoint model loaded: ['predictions.decoder.weight', 'predictions.decoder.bias'].




Step,Training Loss
4,337.4442
6,301.0945
8,306.2599


TrainOutput(global_step=8, training_loss=236.19964599609375, metrics={'train_runtime': 61.8016, 'train_samples_per_second': 2.071, 'train_steps_per_second': 0.129, 'total_flos': 719181053952.0, 'train_loss': 236.19964599609375, 'epoch': 1.0})