In [None]:
import torch
import torch.nn as nn
from custom_dataloaders import construct_dataloaders
from hf_trainer import infer, train
from roberta_classification_model import RobertaClsModel
from torch import cuda
from transformers import AutoModelForSequenceClassification, AutoTokenizer

Choose your dataset. Make sure that the number of classes in your model matches the number of different labels in that dataset.

In [None]:
# Default task: AG News, i.e. classifying news headlines into 4 classes.
dataset_name, dataset_num_labels = "ag_news", 4

# Alternative task: SST2 sentiment analysis (2 classes). Uncomment the last line of this cell to switch.
# NOTE: SST2 requires use_hf_sequence_classification = True, because the custom RoBERTa model
# is only defined for ag_news.
# NOTE: For SST2 to train well, you'll need to adjust the learning rate and weight decay in the
# hf_trainer file. A good place to start is lr=0.00001, weight_decay=0.001
# dataset_name, dataset_num_labels = "SetFit/sst2", 2

Choose your pre-trained model and set up the dataloaders.

By default, the HuggingFace Transformer models will provide the dense hidden states of the last layer, one vector for each token in the input. These vectors are not directly usable for our task of classification at the sequence level. While they can be combined using the "attention mechanism" into a single class-specific sequence-level representation, we opt for an easier solution here.

This can be done by adding a "classification head" (a linear projection layer, `nn.Linear`) on top of one of these token vectors in the output. For bi-directional encoder-only transformers such as BERT and RoBERTa, there is a special token at the beginning of the input, \[CLS\], that contains information about the entire document. This layer will be added on top of the vector of the \[CLS\] token. For decoder-only transformers such as GPT and OPT, this projection layer might be added to the last non-pad token in the sentence.

The HuggingFace Transformers library provides a convenient way to add this layer to your pre-trained model. For a wide range of base models including RoBERTa and OPT, you can load the pre-trained model with the projection layer added and initialized for you using the `AutoModelForSequenceClassification` class:

```python
model = AutoModelForSequenceClassification.from_pretrained("roberta-base")
```

To demonstrate how this useful abstraction works, we've manually added a classification head on top of a HuggingFace [**RoBERTa**](https://arxiv.org/abs/1907.11692) model in a custom torch.nn module. The RoBERTa model is very similar to the BERT model, with a few minor differences. For example the next-sentence prediction task was removed in pretraining of RoBERTa.

We encourage you to take a look at our implementation in *roberta_classification_model.py* and see whether the behavior differs from that of AutoModelForSequenceClassification. Note that there is also an implementation of the "decoder-only" style head in *gpt2_classification_model.py*.

Please note that if you need to experiment with a base model other than RoBERTa (for example, OPT), you will need to set `use_hf_sequence_classification = True` so that the HuggingFace `AutoModelForSequenceClassification` is used instead of the custom RoBERTa model.

In [None]:
# Base checkpoint to fine-tune.
hf_model_name = "roberta-base"

# True  -> build the classifier with HuggingFace's AutoModelForSequenceClassification.
# False -> use the custom RobertaClsModel instead.
# NOTE: The custom RoBERTa model is only defined for ag_news, so the SST2 dataset requires True.
use_hf_sequence_classification = True

# To use facebook/opt-125m as the base model, uncomment the line below.
# Note that OPT-125m requires use_hf_sequence_classification = True.
# hf_model_name, use_hf_sequence_classification = "facebook/opt-125m", True

In [None]:
# Load the tokenizer for the chosen pretrained checkpoint and cap each input at 512 tokens.
tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
tokenizer.model_max_length = 512

# Build the dataloaders for the train, validation, and test splits.
train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(
    dataset_name=dataset_name,
    tokenizer=tokenizer,
    batch_size=8,
    train_split_ratio=0.8,
)

Set up the different variables we'd like for training.

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Detected Device {device}")

# We provide two options for the classifier. The first is our own model built on top of the
# vanilla RoBERTa model. The second is HuggingFace's AutoModel class, which essentially does
# the same thing for RoBERTa, but with support for additional base models such as OPT and GPT-J.
if use_hf_sequence_classification:
    classifier_model = AutoModelForSequenceClassification.from_pretrained(
        hf_model_name, num_labels=dataset_num_labels
    )
else:
    classifier_model = RobertaClsModel()

# Training objective and schedule passed to the training loop.
loss_function = nn.CrossEntropyLoss()
n_training_epochs, n_training_steps = 1, 300

Train the model on the training dataset

In [None]:
# Fine-tune the classifier. Per the hf_trainer helper's description, `train` initiates an
# Adam optimizer and runs the training loop.
print("Begin Model Training...")
train(
    classifier_model, train_dataloader, val_dataloader,
    loss_function, device, n_training_epochs, n_training_steps,
)
print("Training Complete")

Once training is complete, we save the fine-tuned model to disk.

In [None]:
print("Saving model...")
# Keep only the part after the last "/" so hub-style ids such as "org/name" yield a flat file name.
hf_model_name_formatted = hf_model_name.rpartition("/")[2]
dataset_name_formatted = dataset_name.rpartition("/")[2]
output_model_file = f"./{hf_model_name_formatted}_{dataset_name_formatted}.bin"

# Serialize the whole model object (not just a state_dict) to the path built above.
torch.save(classifier_model, output_model_file)
print("Model saved to", output_model_file)

Next, we load the model saved above, perform inference on the test set and measure loss and accuracy.

In [None]:
print("Loading model...")
# The checkpoint was written with torch.save(classifier_model, ...), i.e. a fully pickled module
# rather than a state_dict. Since PyTorch 2.6, torch.load defaults to weights_only=True, which
# refuses such files, so we opt out explicitly.
# SECURITY: unpickling can execute arbitrary code; only load checkpoints you created yourself.
classifier_model = torch.load(output_model_file, weights_only=False)
print("Model loaded.")

# Run inference over the held-out test split and report the resulting loss and accuracy.
print("Evaluating model on test set...")
test_accuracy, test_loss = infer(classifier_model, loss_function, test_dataloader, device)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_accuracy}%")
print("Model evaluated.")