In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer
import numpy as np
from torchmetrics import Accuracy, Precision, Recall, F1Score
from noisy_intents.data import DydaDA
from noisy_intents.training import train, autodetect_device
from noisy_intents.models import BERT
from noisy_intents.eval import compute_metrics

## Load data

In [None]:
# Tokenizer for the pretrained checkpoint used throughout this notebook.
PRETRAINED_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)

In [None]:
# Both splits are tokenized with the same tokenizer and the same max_len.
MAX_SEQ_LEN = 128
train_data, val_data = (
    DydaDA.from_hugging_face(split_name, tokenizer, max_len=MAX_SEQ_LEN)
    for split_name in ("train", "validation")
)

In [None]:
# Training batches are reshuffled every epoch (DataLoader defaults to shuffle=False, which would
# feed the model the exact same batch order each epoch). A seeded generator keeps the shuffle
# order reproducible under Restart & Run All. drop_last=True avoids a small trailing batch.
train_loader = DataLoader(
    train_data,
    batch_size=64,
    shuffle=True,
    num_workers=8,
    drop_last=True,
    generator=torch.Generator().manual_seed(0),
)
# Validation must see every sample: drop_last=True would silently exclude up to
# batch_size - 1 examples from all reported validation metrics.
val_loader = DataLoader(val_data, batch_size=64, shuffle=False, num_workers=8, drop_last=False)

## Load model

In [None]:
# Choose the compute device via the project's autodetect_device() helper; the model and the
# loss class weights are moved onto it in later cells.
device = autodetect_device()

In [None]:
# Seed torch's global RNG before building the model so that any random weight initialisation
# inside BERT(...) and the dropout masks used during training are reproducible.
torch.manual_seed(0)
model = BERT(num_classes=4, dropout=0.1)
model.to(device);  # trailing ';' suppresses the module repr

Freeze all layers apart from the classification head:

In [None]:
# Freeze every parameter whose name contains "l1" (presumably the pretrained backbone
# submodule of BERT -- confirm against noisy_intents.models) and train everything else.
for name, param in model.named_parameters():
    param.requires_grad = "l1" not in name

# Guard against a silent no-op: if the model's attribute names ever change, the substring
# match above could freeze nothing (full fine-tuning) or everything (nothing to train).
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
assert n_frozen > 0 and n_trainable > 0, f"unexpected freeze result: {n_frozen=} {n_trainable=}"

In [None]:
# Hand only the trainable (unfrozen) parameters to AdamW. This guarantees frozen weights are
# never touched, e.g. by a stale .grad left over from an earlier run of the notebook, which
# AdamW would otherwise use for an update and for its decoupled weight decay.
optimizer = torch.optim.AdamW(
    params=[p for p in model.parameters() if p.requires_grad],
    lr=7e-4,
)

Compute inverse class frequencies on the training labels and use them as class weights in the loss function, to counter class imbalance:

In [None]:
# Inverse-frequency class weights for the loss.
# minlength=4 guarantees one weight per model output class even if the highest label id never
# occurs in the training split; otherwise CrossEntropyLoss rejects the shorter weight vector.
counts = np.bincount(train_data.data[:]["Label"], minlength=4)
freq = counts / counts.sum()
# A class with zero training samples gets weight 0 instead of 1/0 = inf, which would make
# every normalised weight NaN. (Such a class never appears as a target, so its weight is unused.)
inv_freq = np.zeros_like(freq)
np.divide(1.0, freq, out=inv_freq, where=freq > 0)
# Normalise to sum to 1; with CrossEntropyLoss's default 'mean' reduction the overall scale
# of the weights does not affect the loss.
weights = torch.from_numpy(inv_freq / inv_freq.sum()).float().to(device)

Train the model. Progress can be monitored in TensorBoard, e.g. with `tensorboard --logdir ./logs`:

In [None]:
# Metrics passed to train() for TensorBoard logging: (metric class, extra constructor kwargs).
# The first accuracy uses torchmetrics' default averaging; the rest are macro-averaged.
training_metric_specs = [
    (Accuracy, {}),
    (Accuracy, {"average": "macro"}),
    (Precision, {"average": "macro"}),
    (Recall, {"average": "macro"}),
    (F1Score, {"average": "macro"}),
]
train(
    model=model,
    optimizer=optimizer,
    loss_fn=torch.nn.CrossEntropyLoss(weight=weights),  # class-weighted loss
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=10,
    val_metrics_freq=1,
    metrics=[
        metric_cls("multiclass", num_classes=4, **extra_kwargs)
        for metric_cls, extra_kwargs in training_metric_specs
    ],
    log_dir="./logs/bert_base_uncased",  # TensorBoard log directory
)

In [None]:
# Persist the trained weights as a state_dict in a .pt file (it is a binary pickle, not Python
# source, so it must not carry a .py extension). Unlike a pickled nn.Module, a state_dict also
# loads with torch.load's default weights_only=True in recent PyTorch. Reload with:
#   model = BERT(num_classes=4, dropout=0.1)
#   model.load_state_dict(torch.load("best_bert.pt"))
# NOTE(review): this is the model as left by train(); it is only the "best" checkpoint if
# train() restores the best-scoring weights -- confirm.
torch.save(model.state_dict(), "best_bert.pt")

## Evaluate the model on the validation set

In [None]:
# Validation metrics as (metric class, averaging mode); average=None yields per-class scores.
# Order matters: it is the order in which compute_metrics receives the metrics.
evaluation_metric_specs = [
    (Accuracy, "micro"),
    (Accuracy, "macro"),
    (Accuracy, None),
    (Precision, "macro"),
    (Recall, "macro"),
    (F1Score, "macro"),
    (F1Score, None),
]
compute_metrics(
    model,
    val_loader,
    device,
    metrics=[
        metric_cls("multiclass", num_classes=4, average=averaging)
        for metric_cls, averaging in evaluation_metric_specs
    ],
)