### This is a sample code snippet for training a smaller BERT model for spell checking using Neuspell.

In [None]:
# Install libs
!pip install -q pytorch_pretrained_bert
!pip install -q transformers
!pip install -q sentencepiece

In [None]:
%%writefile train_KD.py


# Lib versions
# pytorch_version='1.10'
# py_version='py38'


# Headers
from pytorch_pretrained_bert import BertAdam
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import os
import time
import gc
from typing import List, Dict, Union
from corrector_subwordbert import BertChecker
from commons import DEFAULT_TRAINTEST_DATA_PATH
from corrector import Corrector
from helpers import bert_tokenize_for_valid_examples
from helpers import load_data, load_vocab_dict, save_vocab_dict
from helpers import load_data, load_vocab_dict, save_vocab_dict
from helpers import train_validation_split, batch_iter, labelize, progressBar, batch_accuracy_func
from subwordbert import create_model, load_pretrained, model_predictions, model_inference


# Load the parallel (clean, noisy) sentence files and split them into a
# training set and a validation set.
data_dir = "PATH_TO_DATA_DIR"
clean_file = "CORRECT_SENTENCES"
corrupt_file = "NOISY_SENTENCES"
# Fraction of the loaded data held out for validation.
validation_split = 0.2
train_data = load_data(data_dir, clean_file, corrupt_file)
# The ratio handed to train_validation_split is the share kept for training
# (0.8 for the default validation_split, i.e. the value that used to be
# hard-coded here). The fixed seed keeps the split reproducible across runs.
train_data, valid_data = train_validation_split(train_data, 1.0 - validation_split, seed=11690)
print("len of train and validation data: ", len(train_data), len(valid_data))


# Build the teacher and student checkers, then place both underlying models on
# the training device (first CUDA device when available, CPU otherwise).
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Teacher: restored from a previously trained checkpoint plus its vocab files.
teacher_checker = BertChecker("bert-base-cased")
teacher_checker.from_pretrained(
    ckpt_path="PATH_TO_PRETRAINED_MODEL",
    vocab_path="PATH_TO_VOCABS",
)

# Student: a small BERT initialised from the HuggingFace hub, paired with the
# vocab files given below.
# (another small candidate: nreimers/TinyBERT_L-4_H-312_v2)
student_checker = BertChecker("google/bert_uncased_L-2_H-256_A-4")
student_checker.from_huggingface(
    bert_pretrained_name_or_path="google/bert_uncased_L-2_H-256_A-4",
    vocab="PATH_TO_VOCABS",
)

# NOTE(review): presumably is_model_ready() validates that a model has been
# loaded -- confirm against the Corrector base class.
for _checker in (teacher_checker, student_checker):
    _checker.is_model_ready()

t_model = teacher_checker.get_model()
s_model = student_checker.get_model()

for _model in (t_model, s_model):
    _model.to(DEVICE)


# Training hyper-parameters
n_epochs = 4
TRAIN_BATCH_SIZE = 2
VALID_BATCH_SIZE = 2
# Gradient-accumulation factor used by the training loop when deciding how
# often to call optimizer.step().
GRADIENT_ACC = 512
# Epoch range iterated by the training loop: [START_EPOCH, N_EPOCHS)
START_EPOCH = 0
N_EPOCHS = n_epochs
# Placeholder -- replace with the directory the student checkpoints go to.
CHECKPOINT_PATH = "PATH_TO_HECKPOINT_DIR"
print("CHECKPOINT_PATH: {}".format(CHECKPOINT_PATH))


# Build the BertAdam optimizer over the student's parameters. Parameters whose
# name contains one of the `no_decay` substrings get weight_decay 0.0; all the
# others get weight_decay 0.01.
param_optimizer = list(s_model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

decay_params, no_decay_params = [], []
for param_name, param in param_optimizer:
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0},
]

# Total number of optimizer steps over the whole run (one step per
# GRADIENT_ACC mini-batches), clamped to at least 1.
t_total = max(1, int(len(train_data) / TRAIN_BATCH_SIZE / GRADIENT_ACC * N_EPOCHS))
optimizer = BertAdam(optimizer_grouped_parameters, lr=1e-3, warmup=0.1, t_total=t_total)


# The teacher is put in eval mode and the student in train mode: the optimizer
# is built over the student's parameters, so the student is the model being
# trained.
t_model.eval()
s_model.train()


# KL-divergence criterion used for the distillation (soft-target) loss
loss_functionKL = nn.KLDivLoss(reduction="batchmean")


# Knowledge-distillation hyper-parameters
alpha = 0.5        # weight of the student's own loss; (1 - alpha) weights the KL term
temperature = 4.0  # divisor applied to both models' logits before (log-)softmax


# Early-stopping bookkeeping: best validation accuracy so far, the epoch that
# produced it, and how many epochs without improvement are tolerated.
max_dev_acc = -1
argmax_dev_acc = -1
patience = 100


# Train and Eval
def _prepare_batch(raw_labels, raw_sentences):
    """Tokenize one raw batch and move its tensors to DEVICE.

    Returns None when bert_tokenize_for_valid_examples keeps no example of the
    batch; otherwise the tuple (bert_inp, bert_splits, labels, lengths) with
    bert_inp and labels already on DEVICE.
    """
    kept_labels, _, bert_inp, bert_splits = \
        bert_tokenize_for_valid_examples(raw_labels, raw_sentences)
    if len(kept_labels) == 0:
        return None
    bert_inp = {k: v.to(DEVICE) for k, v in bert_inp.items()}
    # Targets are indexed with the teacher's vocab for both models.
    labels, lengths = labelize(kept_labels, teacher_checker.get_vocab())
    return bert_inp, bert_splits, labels.to(DEVICE), lengths


# Create the checkpoint directory up front so a missing directory is reported
# before any training time is spent, not at the first torch.save.
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

for epoch_id in range(START_EPOCH, N_EPOCHS):
    # Early stopping: give up once `patience` epochs passed without improvement.
    if (epoch_id - argmax_dev_acc) > patience:
        print("patience count reached. early stopping initiated")
        print("max_dev_acc: {}, argmax_dev_acc: {}".format(max_dev_acc, argmax_dev_acc))
        break

    # ------------------------------------------------------------------ train
    train_loss = 0.
    train_acc = 0.
    train_acc_count = 0   # number of batches on which train accuracy was sampled
    n_train_batches = 0   # batches actually trained on (skipped ones excluded)
    batch_acc = 0.
    pending_grads = False  # True while gradients are accumulated but not yet applied
    print("train_data size: {}".format(len(train_data)))

    train_data_iter = batch_iter(train_data, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
    nbatches = int(np.ceil(len(train_data) / TRAIN_BATCH_SIZE))

    optimizer.zero_grad()

    for batch_id, (batch_labels, batch_sentences) in enumerate(train_data_iter):
        st_time = time.time()
        prepared = _prepare_batch(batch_labels, batch_sentences)
        if prepared is None:
            print("Not training the following lines due to pre-processing mismatch: \n")
            continue
        batch_bert_inp, batch_bert_splits, batch_labels, batch_lengths = prepared
        n_train_batches += 1

        # forward (student): logits plus the student's own loss on the targets
        outputs_student_logits, s_loss = s_model(
            batch_bert_inp, batch_bert_splits, targets=batch_labels, logits_flag=True)
        batch_loss = s_loss.cpu().detach().numpy()
        train_loss += batch_loss

        # forward (teacher): only its logits are used, no gradients needed
        with torch.no_grad():
            outputs_teacher_logits, _ = t_model(
                batch_bert_inp, batch_bert_splits, targets=batch_labels, logits_flag=True)

        # Distillation loss: KL divergence between the temperature-softened
        # student and teacher distributions, rescaled by temperature ** 2.
        # NOTE(review): reduction="batchmean" divides the summed KL by the size
        # of the first dimension only; if the logits are (batch, seq, vocab)
        # the term scales with sequence length -- confirm this is intended.
        loss_logits = loss_functionKL(
            F.log_softmax(outputs_student_logits / temperature, dim=-1),
            F.softmax(outputs_teacher_logits / temperature, dim=-1)) * (temperature ** 2)
        # Weighted mix of the student's own loss and the distillation loss.
        loss = alpha * s_loss + (1. - alpha) * loss_logits

        # backward; scale so that the accumulated gradient is an average
        if GRADIENT_ACC > 1:
            loss = loss / GRADIENT_ACC
        loss.backward()
        pending_grads = True

        # step once every GRADIENT_ACC batches and on the last batch
        if (batch_id + 1) % GRADIENT_ACC == 0 or batch_id >= nbatches - 1:
            optimizer.step()
            optimizer.zero_grad()
            pending_grads = False

        # Sample the train accuracy on the first trained batch and then on every
        # 10000th batch. Sampling on the first *trained* batch (rather than
        # batch 0) keeps batch_acc / train_acc_count valid even when batch 0 is
        # skipped by the pre-processing check.
        if train_acc_count == 0 or batch_id % 10000 == 0:
            train_acc_count += 1
            s_model.eval()
            with torch.no_grad():
                _, batch_predictions = s_model(batch_bert_inp, batch_bert_splits, targets=batch_labels)
            s_model.train()
            labels_np = batch_labels.cpu().detach().numpy()
            lengths_np = batch_lengths.cpu().detach().numpy()
            ncorr, ntotal = batch_accuracy_func(batch_predictions, labels_np, lengths_np)
            batch_acc = ncorr / ntotal
            train_acc += batch_acc

        # update progress
        progressBar(batch_id + 1,
                    nbatches,
                    ["batch_time", "batch_loss", "avg_batch_loss", "batch_acc", "avg_batch_acc"],
                    [time.time() - st_time, batch_loss, train_loss / n_train_batches, batch_acc,
                     train_acc / train_acc_count])

    # If the final batch was skipped, gradients accumulated since the last step
    # would otherwise be thrown away by the next epoch's zero_grad().
    if pending_grads:
        optimizer.step()
        optimizer.zero_grad()

    print(f"\nEpoch {epoch_id} train_loss: {train_loss / max(n_train_batches, 1)}")

    # ------------------------------------------------------------- validation
    valid_loss = 0.
    valid_acc = 0.
    n_valid_batches = 0   # batches actually evaluated (skipped ones excluded)
    print("valid_data size: {}".format(len(valid_data)))

    valid_data_iter = batch_iter(valid_data, batch_size=VALID_BATCH_SIZE, shuffle=False)
    n_valid_total = int(np.ceil(len(valid_data) / VALID_BATCH_SIZE))

    s_model.eval()
    for batch_id, (batch_labels, batch_sentences) in enumerate(valid_data_iter):
        st_time = time.time()
        prepared = _prepare_batch(batch_labels, batch_sentences)
        if prepared is None:
            print("Not validating the following lines due to pre-processing mismatch: \n")
            continue
        batch_bert_inp, batch_bert_splits, batch_labels, batch_lengths = prepared
        n_valid_batches += 1

        # forward
        with torch.no_grad():
            batch_loss, batch_predictions = s_model(batch_bert_inp, batch_bert_splits, targets=batch_labels)
        valid_loss += batch_loss

        # compute accuracy in numpy
        labels_np = batch_labels.cpu().detach().numpy()
        lengths_np = batch_lengths.cpu().detach().numpy()
        ncorr, ntotal = batch_accuracy_func(batch_predictions, labels_np, lengths_np)
        batch_acc = ncorr / ntotal
        valid_acc += batch_acc

        # update progress
        progressBar(batch_id + 1,
                    n_valid_total,
                    ["batch_time", "batch_loss", "avg_batch_loss", "batch_acc", "avg_batch_acc"],
                    [time.time() - st_time, batch_loss, valid_loss / n_valid_batches, batch_acc,
                     valid_acc / n_valid_batches])
    s_model.train()

    # Epoch-level means over the evaluated batches (0 when none was evaluated).
    mean_valid_loss = valid_loss / max(n_valid_batches, 1)
    mean_valid_acc = valid_acc / max(n_valid_batches, 1)
    print(f"\nEpoch {epoch_id} valid_loss: {mean_valid_loss}")

    # save the student whenever the mean validation accuracy did not get worse
    if mean_valid_acc >= max_dev_acc:
        print(f"validation accuracy improved from {max_dev_acc:.4f} to {mean_valid_acc:.4f}")
        name = "pytorch_model" + str(epoch_id) + ".bin"
        ckpt_file = os.path.join(CHECKPOINT_PATH, name)
        torch.save(s_model.state_dict(), ckpt_file)
        print("Model saved at {} in epoch {}".format(ckpt_file, epoch_id))

        # re-assign
        max_dev_acc, argmax_dev_acc = mean_valid_acc, epoch_id

    # release Python garbage and cached CUDA blocks before the next epoch
    gc.collect()
    torch.cuda.empty_cache()