# In [None]:
import os

import transformers as ts
from transformers import Trainer, TrainingArguments

def train_model(
    dataset,
    tokenizer,
    model_path,
    save_path,
    num_epochs=3,
    batch_size=48,
    *,
    mlm_probability=0.15,
    learning_rate=5e-5,
    warmup_steps=5000,
):
    """
    Train a masked language model using the Hugging Face Trainer.

    Checkpoints are written to ``<save_path>/checkpoints`` and the final
    model and tokenizer to ``<save_path>/final/model``. Paths are joined
    with ``os.path.join``, so ``save_path`` works with or without a
    trailing separator and may also be a ``pathlib.Path``.

    Args:
        dataset: The tokenized training dataset.
        tokenizer: The tokenizer used by the MLM data collator. It is also
            saved next to the final model so that directory is
            self-contained.
        model_path (str): Path to the pretrained model, passed to
            ``AutoModelForMaskedLM.from_pretrained``.
        save_path (str | os.PathLike): Root directory for checkpoints and
            the final model.
        num_epochs (int): Number of training epochs.
        batch_size (int): Per-device training batch size.
        mlm_probability (float): Probability that the collator selects a
            token for masking.
        learning_rate (float): Peak learning rate of the linear schedule.
        warmup_steps (int): Number of linear warm-up steps.

    Returns:
        transformers.Trainer: The trainer after training has finished.
    """
    checkpoint_dir = os.path.join(save_path, "checkpoints")
    final_dir = os.path.join(save_path, "final", "model")

    print(f"Initializing model from {model_path}...")
    model = ts.AutoModelForMaskedLM.from_pretrained(model_path)

    # Count only parameters that will receive gradient updates.
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params / 1e6:.2f}M")

    # The collator masks tokens on the fly, so each epoch sees a different
    # masking of the same examples.
    data_collator = ts.DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=mlm_probability,
        return_tensors="pt",
    )

    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        logging_steps=250,
        overwrite_output_dir=True,
        save_steps=2500,
        num_train_epochs=num_epochs,
        learning_rate=learning_rate,
        lr_scheduler_type="linear",
        # NOTE(review): if the run has fewer total steps than warmup_steps,
        # the learning rate never reaches its peak value.
        warmup_steps=warmup_steps,
        per_device_train_batch_size=batch_size,
        weight_decay=1e-4,
        save_total_limit=5,  # keep only the 5 most recent checkpoints
        remove_unused_columns=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {final_dir}...")
    trainer.save_model(final_dir)
    # Save the tokenizer alongside the weights so the directory can be
    # reloaded with both AutoModelForMaskedLM and AutoTokenizer.
    tokenizer.save_pretrained(final_dir)

    return trainer
