In [1]:
import os
import numpy as np
import torch
from datasets import load_from_disk
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, BertConfig, AutoModel

In [2]:
# Define a BERT configuration
def setup_bert_config(
    vocab_size: int,
    hidden_size: int = 768,
    num_hidden_layers: int = 12,
    num_attention_heads: int = 12,
    intermediate_size: int = 3072,
    max_position_embeddings: int = 512,
) -> BertConfig:
    """Build a ``BertConfig`` from the given architecture sizes.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Passed through as ``hidden_size``.
        num_hidden_layers: Passed through as ``num_hidden_layers``.
        num_attention_heads: Passed through as ``num_attention_heads``.
        intermediate_size: Passed through as ``intermediate_size``.
        max_position_embeddings: Passed through as ``max_position_embeddings``.

    Returns:
        A ``BertConfig`` with the sizes above and ``output_hidden_states=True``.
    """
    architecture = {
        "vocab_size": vocab_size,
        "hidden_size": hidden_size,
        "num_hidden_layers": num_hidden_layers,
        "num_attention_heads": num_attention_heads,
        "intermediate_size": intermediate_size,
        "max_position_embeddings": max_position_embeddings,
    }
    # Hidden states are always requested; this is not configurable by the caller.
    return BertConfig(output_hidden_states=True, **architecture)


In [3]:
# Model identifiers. `student_model_name` is only a label here; it is not
# referenced by the other visible cells.
teacher_model_name = "bert-base-uncased"
student_model_name = "phonetic_bert"

# The teacher tokenizer is loaded by model name, the student tokenizer from a
# local directory.
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(
    "/home/toure215/BERT_phonetic/tokenizers/tokenizer_phonetic_WordPiece"
)



In [4]:
# Teacher: pretrained encoder loaded by name.
teacher = AutoModel.from_pretrained(teacher_model_name)

# Student: freshly initialised BERT MLM model sized for the student tokenizer.
# `len(tokenizer)` also counts added tokens, whereas `tokenizer.vocab_size`
# covers only the base vocabulary; sizing the embedding table from the latter
# can leave token ids without an embedding row.
config = setup_bert_config(vocab_size=len(student_tokenizer))
student = BertForMaskedLM(config=config)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [5]:
# Use the GPU when one is available and fall back to CPU otherwise, instead of
# failing on a hard-coded "cuda" device string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher.to(device)
student.to(device)

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [None]:
from transformers import Trainer

class TeacherStudentTrainer(Trainer):
    """Trainer that adds a [CLS]-embedding distillation term to the MLM loss.

    For each batch the teacher encodes the teacher-side inputs (without
    gradients), the model being trained encodes the student-side inputs, and
    the loss is::

        total_loss = mlm_loss + lmd * MSE(student_cls, teacher_cls)

    where both ``*_cls`` tensors are the last-layer hidden state at position 0.

    Args:
        teacher: Model called with ``input_ids`` / ``attention_mask`` whose
            output exposes ``last_hidden_state``.
        student: Stored on ``self.student``. The model that is optimised is the
            one handed to ``compute_loss`` by the base ``Trainer``.
        lmd: Weight applied to the MSE term.
        *args: Forwarded to ``Trainer.__init__``.
        **kwargs: Forwarded to ``Trainer.__init__``.
    """

    def __init__(self, teacher=None, student=None, lmd: float = 0.1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher
        self.student = student
        self.lmd = lmd
        self.mse = torch.nn.MSELoss()
        if self.teacher is not None:
            # The teacher is only used for inference; eval mode keeps its
            # targets deterministic (no dropout) whatever state it arrived in.
            self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the combined MLM + distillation loss for one batch.

        Args:
            model: The model being trained; called with the student inputs.
            inputs: Mapping with keys ``teacher_input_ids``,
                ``teacher_attention_mask``, ``student_input_ids``,
                ``student_attention_mask`` and ``labels``. It is read but not
                modified.
            return_outputs: When true, return ``(loss, student_outputs)``.
            num_items_in_batch: Accepted so the method matches the signature
                used by ``Trainer`` versions that pass this keyword; unused.

        Returns:
            The scalar loss, or ``(loss, student_outputs)`` if
            ``return_outputs`` is true.
        """
        # Teacher forward pass: no gradients are tracked for the teacher.
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids=inputs["teacher_input_ids"],
                attention_mask=inputs["teacher_attention_mask"],
            )
            teacher_cls_embeddings = teacher_outputs.last_hidden_state[:, 0, :]  # [CLS]

        # Student forward pass. Hidden states are requested explicitly so the
        # [CLS] lookup below does not depend on how the model config was built.
        student_outputs = model(
            input_ids=inputs["student_input_ids"],
            attention_mask=inputs["student_attention_mask"],
            labels=inputs["labels"],
            output_hidden_states=True,
        )
        student_cls_embeddings = student_outputs.hidden_states[-1][:, 0, :]  # [CLS]

        # Losses
        mlm_loss = student_outputs.loss
        mse_loss = self.mse(student_cls_embeddings, teacher_cls_embeddings)

        # Total loss
        total_loss = mlm_loss + self.lmd * mse_loss
        return (total_loss, student_outputs) if return_outputs else total_loss


In [7]:
# Masked-language-modelling collator bound to the student tokenizer, with a
# masking probability of 0.15.
mlm_data_collator = DataCollatorForLanguageModeling(
    tokenizer=student_tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

def custom_data_collator(batch, max_length: int = 64):
    """Collate raw examples into paired teacher / student model inputs.

    Args:
        batch: Sequence of examples, each providing ``"original_text"`` (fed to
            the teacher tokenizer) and ``"text"`` (fed to the student tokenizer).
        max_length: Truncation length applied to both tokenizers.

    Returns:
        Dict with ``teacher_input_ids``, ``teacher_attention_mask``,
        ``student_input_ids`` (after MLM masking), ``student_attention_mask``
        and ``labels`` (MLM labels).
    """
    # Extract teacher and student texts
    original_texts = [example["original_text"] for example in batch]
    phonetic_texts = [example["text"] for example in batch]

    # Tokenize teacher inputs
    teacher_inputs = teacher_tokenizer(
        original_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Tokenize student inputs
    student_inputs = student_tokenizer(
        phonetic_texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

    # Apply MLM masking to the (already padded) student ids.
    masked_inputs = mlm_data_collator(
        [{"input_ids": input_ids} for input_ids in student_inputs["input_ids"]]
    )

    return {
        "teacher_input_ids": teacher_inputs["input_ids"],
        "teacher_attention_mask": teacher_inputs["attention_mask"],
        "student_input_ids": masked_inputs["input_ids"],
        # Use the tokenizer's own mask. The MLM collator only receives
        # equal-length `input_ids`, so any mask it rebuilds cannot tell real
        # tokens from padding.
        "student_attention_mask": student_inputs["attention_mask"],
        "labels": masked_inputs["labels"],  # MLM labels
    }


In [8]:
# Load the dataset previously saved to disk; later cells index it with the
# "train" and "validation" keys.
dataset = load_from_disk(
    dataset_path="/home/toure215/BERT_phonetic/DATASETS/phonetic_wikitext"
)

In [9]:
training_args = TrainingArguments(
    output_dir="./models/phonetic_bert",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    # The collator reads the raw "original_text" / "text" columns, so the
    # Trainer must not drop columns the model's forward() does not name.
    remove_unused_columns=False,
)

trainer = TeacherStudentTrainer(
    teacher=teacher,
    student=student,
    # Pass the already-built student directly. `model_init` is meant to return
    # a fresh model on every `train()` call; a lambda that returns the same
    # instance does not do that.
    model=student,
    args=training_args,
    data_collator=custom_data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
)

In [10]:
# Run the training loop; the call's return value is displayed as the cell output.
trainer.train() 

  0%|          | 0/605225 [00:00<?, ?it/s]

Could not estimate the number of tokens of the input, floating-point operations will not be computed


KeyboardInterrupt: 