## 1. Import Modules and Data
BERT can be fine-tuned on the Stanford Sentiment Treebank-2 (SST-2) dataset for a text classification task. More info about SST-2 can be found [here](https://huggingface.co/datasets/stanfordnlp/sst2).

In [1]:
import torch
from data import load_data
from modules.bert import BertForSequenceClassification
import config 

# Build the SST-2 tokenizer together with the train / validation dataloaders.
tokenizer, train_dataloader, valid_dataloader = load_data(
    name="sst2",
    splits=["train", "validation"],  # request the train and validation splits
    loading_ratio=1,  # load 100% of the SST-2 data
    num_proc=4,  # use 4 processes
)

## 2. Build Model and Load from Pre-trained
Build a BERT text classification model that inherits from the BERT class and adds a binary linear classification layer at the end of the architecture.

In [2]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU
# so that the notebook still runs (slowly) on machines without CUDA instead of
# failing at `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model from `config.pretrained_path` and move it to the
# selected device.
model = BertForSequenceClassification.from_pretrained(
    model_name_or_path=config.pretrained_path
).to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Number of trainable parameters: 109.48M


## 3. Train Model

In [3]:
import torch
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

import config


@torch.no_grad()
def evaluate(model, dataloader):
    """Score `model` on every batch of `dataloader` with gradients disabled.

    Relies on the module-level `device` and `tokenizer` globals.

    Args:
        model: Callable as `model(input_ids, attention_mask=..., labels=...)`
            and returning a `(loss, logits)` pair.
        dataloader: Iterable of `(input_ids, labels)` batches that also
            supports `len()`.

    Returns:
        A `(avg_loss, accuracy)` tuple: the mean of the per-batch losses and
        the `accuracy_score` of the argmax predictions against the labels.
    """
    model.eval()
    running_loss = 0
    predictions, targets = [], []

    for input_ids, labels in tqdm(dataloader, desc="Evaluating"):
        input_ids, labels = input_ids.to(device), labels.to(device)
        # True for every position that is not the tokenizer's padding token.
        pad_mask = (input_ids != tokenizer.pad_token_id).bool()

        batch_loss, batch_logits = model(
            input_ids, attention_mask=pad_mask, labels=labels
        )
        running_loss += batch_loss.item()

        # Predicted class = index of the largest logit along dim 1.
        predictions.extend(batch_logits.argmax(dim=1).cpu().numpy())
        targets.extend(labels.cpu().numpy())

    return running_loss / len(dataloader), accuracy_score(targets, predictions)

In [4]:
def train(epoch, model, optimizer, scheduler, dataloader):
    """Run one training epoch and return the mean per-batch loss.

    Relies on the module-level `device` and `tokenizer` globals.

    Args:
        epoch: Epoch number, used only in the progress-bar description.
        model: Callable as `model(input_ids, attention_mask=..., labels=...)`
            and returning a `(loss, logits)` pair.
        optimizer: Optimizer that is stepped once per batch.
        scheduler: Learning-rate scheduler that is stepped once per batch.
        dataloader: Iterable of `(input_ids, labels)` batches that also
            supports `len()`.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc=f"Training Epoch {epoch}"):
        input_ids, labels = batch
        input_ids = input_ids.to(device)
        labels = labels.to(device)
        # True for every position that is not the tokenizer's padding token.
        attention_mask = (input_ids != tokenizer.pad_token_id).bool()

        # Gradients accumulate across backward() calls, so they must be
        # cleared on every step; otherwise each update would also contain the
        # gradients of all previous batches.
        optimizer.zero_grad()

        loss, _ = model(input_ids, attention_mask=attention_mask, labels=labels)

        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        scheduler.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    return avg_loss

In [None]:
def training_loop(
    model,
    train_dataloader,
    valid_dataloader,
    optimizer,
    scheduler,
    num_epochs,
):
    """Train for `num_epochs` epochs, validating and checkpointing after each.

    For every epoch this calls `train` on `train_dataloader`, then `evaluate`
    on `valid_dataloader`, prints a one-line summary, and saves the epoch
    number plus the model / optimizer / scheduler state dicts to
    `config.checkpoint_dir / "bert_clf_<epoch>.pth"`.

    Args:
        model: Model to optimize; passed through to `train` and `evaluate`.
        train_dataloader: Batches used for training.
        valid_dataloader: Batches used for validation.
        optimizer: Optimizer passed through to `train`.
        scheduler: Learning-rate scheduler passed through to `train`.
        num_epochs: Number of epochs to run.
    """
    # Epochs are numbered from 1 in progress bars, logs and checkpoint names.
    for epoch_num in range(1, num_epochs + 1):
        train_loss = train(epoch_num, model, optimizer, scheduler, train_dataloader)
        valid_loss, valid_acc = evaluate(model, valid_dataloader)

        print(
            f"Epoch {epoch_num}/{num_epochs}, Training Loss: {train_loss:.4f},",
            f"Validation Loss: {valid_loss:.4f}, Accuracy: {valid_acc * 100:.2f}",
        )

        # Snapshot everything needed to resume training from this epoch.
        checkpoint = {
            "epoch": epoch_num,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        torch.save(checkpoint, config.checkpoint_dir / f"bert_clf_{epoch_num}.pth")


# Optimizer over all model parameters, configured from FinetuningConfig.
optimizer = AdamW(
    params=model.parameters(),
    lr=config.FinetuningConfig.lr,
    weight_decay=config.FinetuningConfig.weight_decay,
)

# Linear warmup schedule; the total step count is one step per training batch
# for every configured epoch.
scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.FinetuningConfig.warmup_steps,
    num_training_steps=len(train_dataloader) * config.FinetuningConfig.n_epoch,
)

# Kick off fine-tuning.
training_loop(
    model=model,
    train_dataloader=train_dataloader,
    valid_dataloader=valid_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=config.FinetuningConfig.n_epoch,
)