In [None]:
import torch
from catalyst import dl
from src.runners import DistilMLMRunner
from src.models import DistilbertStudentModel, BertForMLM
from catalyst.core import MetricAggregationCallback
from torch.utils.data import DataLoader
from src.callbacks import (
    CosineLossCallback,
    KLDivLossCallback,
    MaskedLanguageModelCallback,
    MSELossCallback,
    PerplexityMetricCallbackDistillation,
)
import pandas as pd

In [None]:
# Directory expected to contain `train.csv` and `valid.csv`; point it at your data.
PATH_TO_YOUR_DATASET = "./data"

# Read both splits in one statement (train first, then valid).
train_df, valid_df = (
    pd.read_csv(f"{PATH_TO_YOUR_DATASET}/train.csv"),
    pd.read_csv(f"{PATH_TO_YOUR_DATASET}/valid.csv"),
)

In [None]:
from src.data import LanguageModelingDataset
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorForLanguageModeling

# Single source of truth for the pretrained checkpoint. The tokenizer below and
# every later use of `model_name` must refer to the same checkpoint, so the name
# is spelled once here ("bert-base-uncased", not "bert-based-uncased").
model_name = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build one dataset per split from the "text" column of the loaded frames.
train_dataset = LanguageModelingDataset(train_df["text"], tokenizer)
valid_dataset = LanguageModelingDataset(valid_df["text"], tokenizer)

# Use the collator's `collate_batch` method when it exists (unchanged behaviour);
# otherwise fall back to the collator object itself.
# NOTE(review): newer transformers releases appear to expose the collator as a
# callable instead of `.collate_batch` — confirm against the pinned version.
collator = DataCollatorForLanguageModeling(tokenizer)
collate_fn = getattr(collator, "collate_batch", collator)

train_dataloader = DataLoader(
    train_dataset, collate_fn=collate_fn, batch_size=2
)
valid_dataloader = DataLoader(
    valid_dataset, collate_fn=collate_fn, batch_size=2
)

# Loaders keyed by the names "train" and "valid"; passed to the runner later.
loaders = {"train": train_dataloader, "valid": valid_dataloader}

In [None]:
# Teacher and student are both built from the same checkpoint name.
teacher = BertForMLM(model_name)
student = DistilbertStudentModel(model_name)

# Bundle the two networks in one module, addressable by string key.
model = torch.nn.ModuleDict(dict(teacher=teacher, student=student))

# Callback registry. The "loss" entry aggregates the four individual loss
# metrics in "weighted_sum" mode, each with weight 1.0.
callbacks = dict(
    masked_lm_loss=MaskedLanguageModelCallback(),
    mse_loss=MSELossCallback(),
    cosine_loss=CosineLossCallback(),
    kl_div_loss=KLDivLossCallback(),
    loss=MetricAggregationCallback(
        prefix="loss",
        mode="weighted_sum",
        metrics=dict.fromkeys(
            ("cosine_loss", "masked_lm_loss", "kl_div_loss", "mse_loss"),
            1.0,
        ),
    ),
    optimizer=dl.OptimizerCallback(),
    perplexity=PerplexityMetricCallbackDistillation(),
)

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU so this
# cell still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
runner = DistilMLMRunner(device=device)

# NOTE(review): `model.parameters()` also yields the teacher's parameters —
# verify the runner keeps the teacher frozen if that is the intent.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    verbose=True,
    # NOTE(review): presumably catalyst's quick pipeline-check mode rather than
    # a full training run — confirm, and set to False for real training.
    check=True,
    callbacks=callbacks,
)