# In [None]:
import torch
from transformers import BertForSequenceClassification, BertConfig
from transformers import AdamW, get_linear_schedule_with_warmup

# ## BERT model initialization
#
# We now load a pretrained BERT model with a single linear
# classification layer added on top.

print("Initializing BertForSequenceClassification")

# Load the pretrained BERT weights with a classification layer on top that
# produces num_labels=20 outputs.
model = BertForSequenceClassification.from_pretrained(
    BERTMODEL, cache_dir=CACHE_DIR, num_labels=20
)
# Place the model on the same `device` that train() and evaluate() copy every
# batch to. A hard-coded .cuda() would crash on CPU-only machines and could
# disagree with `device` (assumed to be defined earlier in the notebook —
# TODO confirm).
model.to(device)


# We set the remaining hyperparameters needed for fine-tuning the
# pretrained model:
#   * EPOCHS: the number of training epochs in fine-tuning
#     (recommended values between 2 and 4)
#   * WEIGHT_DECAY: weight decay for the AdamW optimizer
#   * LR: learning rate for the AdamW optimizer (2e-5 to 5e-5 recommended)
#   * WARMUP_STEPS: number of warmup steps over which the learning rate is
#     increased linearly up to LR (here: 20% of the batches in one epoch)
#
# We also split the model's parameters into two optimizer groups: biases and
# LayerNorm weights are excluded from weight decay, while all other
# parameters use WEIGHT_DECAY.

EPOCHS = 4
WEIGHT_DECAY = 0.01
LR = 2e-5
# Warm up over 20% of one epoch's batches; with the total step count below
# (len(train_dataloader) * EPOCHS) that is 5% of all steps when EPOCHS == 4.
WARMUP_STEPS = int(0.2 * len(train_dataloader))

# Parameters whose name contains any of these substrings get no weight decay.
no_decay = ["bias", "LayerNorm.weight"]

# Single pass over the parameters, splitting them into the two groups while
# preserving their original order within each group.
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

optimizer_grouped_parameters = [
    {"params": decay_params, "weight_decay": WEIGHT_DECAY},
    {"params": no_decay_params, "weight_decay": 0.0},
]

# transformers.AdamW is deprecated in favour of PyTorch's own implementation.
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LR, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=len(train_dataloader) * EPOCHS,
)


# ## Learning
#
# Let's now define functions to train() and evaluate() the model:


def train(epoch, loss_vector=None, log_interval=200, *, max_grad_norm=None):
    """Run one epoch of fine-tuning over ``train_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``scheduler``,
    ``train_dataloader`` and ``device``.

    Args:
        epoch: Epoch number; used only in the progress log line.
        loss_vector: Optional list; if given, the scalar loss of every batch
            is appended to it.
        log_interval: Print a progress line every ``log_interval`` batches.
        max_grad_norm: If not ``None``, clip the total gradient norm to this
            value before each optimizer step. ``None`` (the default) disables
            clipping.
    """
    # Set model to training mode
    model.train()

    n_batches = len(train_dataloader)
    n_samples = len(train_dataloader.dataset)
    # Number of samples processed before the current batch; counted exactly
    # so the progress line stays correct even for a smaller final batch.
    n_seen = 0

    for step, batch in enumerate(train_dataloader):
        # Copy each tensor to the target device and unpack; batches are
        # expected to be (input_ids, attention_mask, labels).
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        # Zero gradient buffers
        optimizer.zero_grad()

        # Forward pass; with labels supplied, outputs[0] is used as the loss.
        outputs = model(
            b_input_ids,
            token_type_ids=None,
            attention_mask=b_input_mask,
            labels=b_labels,
        )
        loss = outputs[0]
        if loss_vector is not None:
            loss_vector.append(loss.item())

        # Backward pass (optionally clipping gradients), then update weights
        # and advance the learning-rate schedule.
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        if step % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    n_seen,
                    n_samples,
                    100.0 * step / n_batches,
                    loss.item(),
                )
            )
        n_seen += len(b_input_ids)


def evaluate(loader):
    """Print and return the classification accuracy of ``model`` on ``loader``.

    Uses the module-level ``model`` and ``device``.

    Args:
        loader: Iterable of batches, each unpacked below as
            ``(input_ids, attention_mask, labels)``.

    Returns:
        The accuracy (correct / total) as a float in [0, 1], or ``None`` if
        ``loader`` yielded no samples.
    """
    model.eval()

    n_correct, n_all = 0, 0

    # No gradients are needed anywhere during evaluation.
    with torch.no_grad():
        for batch in loader:
            b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

            outputs = model(
                b_input_ids, token_type_ids=None, attention_mask=b_input_mask
            )
            # Without labels, outputs[0] holds the logits.
            logits = outputs[0].cpu().numpy()
            predictions = np.argmax(logits, axis=1)

            labels = b_labels.cpu().numpy()
            n_correct += int(np.sum(predictions == labels))
            n_all += len(labels)

    # Guard against an empty loader instead of dividing by zero.
    if n_all == 0:
        print("Accuracy: no samples to evaluate\n")
        return None

    accuracy = n_correct / n_all
    print("Accuracy: [{}/{}] {:.4f}\n".format(n_correct, n_all, accuracy))
    return accuracy


# Now we are ready to train our model with the train() function. After each
# epoch, we measure the model's accuracy on the validation set with
# evaluate().

# Per-batch training losses; train() appends to this list in place.
train_lossv = []
for epoch_idx in range(EPOCHS):
    # Epochs are numbered from 1 in the progress output.
    epoch = epoch_idx + 1
    train(epoch, loss_vector=train_lossv)
    print("\nValidation set:")
    evaluate(validation_dataloader)


# ## Inference
#
# For a better measure of the model's quality, let's also check its accuracy
# on the held-out test set.

# Final accuracy report on the test split.
print("Test set:")
evaluate(loader=test_dataloader)
