## This notebook fine-tunes a local copy of the pretrained BERT masked language model (MLM), saves it, and evaluates it on a test set

In [None]:
import torch
import transformers

In [None]:
from syfertext.data.metas.language_modeling import TextDatasetMeta
from syfertext.data.readers.language_modeling import TextReader
from syfertext.data.iterators.bert_loader import BERTIterator
from syfertext.encoders.bert_encoder import BERTEncoder

In [None]:
# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
# The result is always bound to `device`, which the later cells rely on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.cuda.get_device_properties() only accepts CUDA devices, so it is
# queried on the GPU path only; on CPU a short notice is printed instead.
if device.type == "cuda":
    print(torch.cuda.get_device_properties(device))
else:
    print("CUDA is not available; running on CPU")

In [None]:
# Encoder built with its default arguments; the same instance is handed to
# every TextReader (train / valid / test) created in this notebook.
encoder = BERTEncoder()

In [None]:
# Load the pretrained "bert-base-uncased" masked-LM weights and place the
# model on the selected device (Module.to returns the module itself).
model = transformers.BertForMaskedLM.from_pretrained("bert-base-uncased").to(device)

# Prints an empty line (presumably to keep the notebook cell output short).
print("")

In [None]:
# transformers.AdamW is deprecated and missing from recent transformers
# releases; torch.optim.AdamW is the maintained replacement.
# weight_decay=0.0 keeps the old transformers.AdamW default (torch's own
# default is 0.01), so the optimisation setup is not silently changed.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    eps=1e-8,
    weight_decay=0.0,
)

In [None]:
# File the fine-tuned weights are written to once training is done.
model_save_path = "./mlm_model.pt"

# Locations of the three dataset splits; replace the placeholders with
# real paths before running the notebook.
meta = TextDatasetMeta(
    train_path="PATH TO TRAIN DATA",
    valid_path="PATH TO VALIDATION DATA",
    test_path="PATH TO TEST DATA",
)

In [None]:
# Number of full passes over the training split.
num_epochs = 3

# Training split: a reader in 'train' mode feeding a batch iterator that is
# then loaded from the dataset metadata.
train_reader = TextReader(mode='train', encoder=encoder)
train_loader = BERTIterator(
    dataset_reader=train_reader,
    batch_size=20,
    sentence_len=35,
)
train_loader.load(meta)

In [None]:
# The training loop calls scheduler.step() once per batch, so the schedule
# length has to be counted in batches (optimizer steps), not in examples.
# Counting examples stretches the linear decay by roughly the batch size and
# the learning rate never gets close to zero.
# NOTE(review): assumes train_loader.num_batches is the number of batches per
# epoch, as the training loop's progress log uses it - confirm.
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=train_loader.num_batches * num_epochs,
)

In [None]:
# Validation split: a reader in 'valid' mode feeding a batch iterator that
# is then loaded from the same dataset metadata as the training split.
val_reader = TextReader(mode='valid', encoder=encoder)
val_loader = BERTIterator(
    dataset_reader=val_reader,
    batch_size=10,
    sentence_len=35,
)
val_loader.load(meta)

In [None]:
def evaluate(loader, model):
    """Return the example-weighted average loss of ``model`` over ``loader``.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards. Batches are moved to
    the device that holds the model's parameters, so the function does not
    depend on any notebook-level global.

    Args:
        loader: Iterable yielding mappings with ``"input_ids"`` and
            ``"labels"`` tensors; must expose a ``num_examples`` attribute.
        model: ``torch.nn.Module`` with at least one parameter, called as
            ``model(input_ids=..., labels=...)``; its output must expose a
            ``loss`` attribute.

    Returns:
        The sum over batches of ``len(batch) * batch_loss`` divided by
        ``loader.num_examples``.

    Raises:
        ValueError: If ``loader.num_examples`` is zero.
    """
    if not loader.num_examples:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("Cannot evaluate on a loader with zero examples")

    # Follow the model rather than relying on a global `device` variable.
    model_device = next(model.parameters()).device

    # Evaluate with dropout etc. disabled, then put the model back the way
    # the caller left it.
    was_training = model.training
    model.eval()

    total_loss = 0.0
    try:
        with torch.no_grad():
            for batch in loader:
                inputs = batch["input_ids"].to(model_device)
                labels = batch["labels"].to(model_device)

                outputs = model(input_ids=inputs, labels=labels)
                # Weight each batch loss by the batch length so batches of
                # different sizes contribute proportionally.
                total_loss += len(inputs) * outputs.loss.item()
    finally:
        model.train(was_training)

    return total_loss / loader.num_examples

In [None]:
# Fine-tuning loop: for each epoch, run one pass over train_loader, then
# report the validation loss; after the last epoch the weights are saved.

# Fix torch's RNG seed so runs are repeatable.
torch.manual_seed(42)

# Only used in the progress message printed during training.
total_batches = train_loader.num_batches

# Print the training loss every `log_interval` batches; lower it for more
# frequent updates, raise it for quieter output.
log_interval = 200

for epoch in range(1, num_epochs + 1):
    # Back to training mode (the end of each epoch switches to eval mode).
    model.train()
    print(f"=========EPOCH {epoch}=========")

    for batch_num, data in enumerate(train_loader):
        inputs = data["input_ids"].to(device)
        labels = data["labels"].to(device)

        # Clear gradients left over from the previous batch.
        model.zero_grad()

        # Passing `labels` makes the model return the loss on its output.
        outputs = model(input_ids=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()

        # One optimizer update and one learning-rate schedule step per batch.
        optimizer.step()
        scheduler.step()

        # batch_num starts at 0, so the first batch of every epoch is logged.
        if (batch_num % log_interval == 0):
            print(f"Batch {batch_num}/{total_batches} | Loss: {loss.item()}")

    # End of epoch: switch to eval mode and score the validation split.
    model.eval()
    val_loss = evaluate(val_loader, model)
    print("-------------------")
    print(f"Val Loss for Epoch {epoch}: {val_loss}")
    print("-------------------")

# Only the state_dict (the weights) is written, not the whole model object.
print(f"Done training! Saving model to {model_save_path}")
torch.save(model.state_dict(), model_save_path)

In [None]:
# Rebuild the architecture from the pretrained checkpoint, then overwrite its
# weights with the fine-tuned state saved after training.
pred_model = transformers.BertForMaskedLM.from_pretrained("bert-base-uncased")
print("Base model loaded")

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU still loads in a CPU-only session.
pred_model.load_state_dict(torch.load(model_save_path, map_location=device))

# Inference only: evaluation mode, on the selected device.
pred_model.eval().to(device)
print("Trained state initialized")

In [None]:
# Test split: a reader in 'test' mode feeding a batch iterator that is then
# loaded from the dataset metadata.
test_reader = TextReader(mode='test', encoder=encoder)
test_loader = BERTIterator(
    dataset_reader=test_reader,
    batch_size=10,
    sentence_len=35,
)
test_loader.load(meta)

In [None]:
# Score the reloaded model on the test split and report the resulting loss.
test_loss = evaluate(loader=test_loader, model=pred_model)
print(f"Test Loss: {test_loss}")