[![Colab Badge Link](https://img.shields.io/badge/open-in%20colab-blue)](https://colab.research.google.com/github/Glasgow-AI4BioMed/tutorials/blob/main/pytorch_training_loop_with_custom_hugging_face_model.ipynb)

## Example of PyTorch Training Loop with Custom Hugging Face Model

This Colab illustrates a PyTorch training loop for training a custom transformer model, as an alternative to using the Hugging Face Trainer. Writing your own loop gives you more control over how training works and over any reporting that you want.

### Loading the dataset

We'll use part of the [Stanford IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb) to illustrate this. It is a dataset of movie reviews, each labelled as positive or negative. We'll use Hugging Face's [datasets library](https://huggingface.co/docs/datasets/index) to download it. First, install the library:

In [None]:
!pip install -U datasets

Then load the IMDB dataset:

In [None]:
from datasets import load_dataset

imdb = load_dataset("imdb")

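Now we'll pick a small part of it to use: the texts and labels for 500 examples. The train split is ordered by label (all negative reviews first, then all positive ones), so we shuffle it before slicing; otherwise every example we pick would be negative:

In [None]:
subset = imdb['train'].shuffle(seed=42)[:500]
texts = subset['text']
labels = subset['label']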
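The IMDB train split stores all negative (0) reviews before all positive (1) ones, so the first rows of the split all share one label. The toy sketch below uses made-up labels laid out the same way to show the effect, and how shuffling with a fixed seed before slicing gives a mix of both classes.

```python
import random
from collections import Counter

# Made-up labels ordered like the IMDB train split: all 0s, then all 1s
toy_labels = [0] * 1000 + [1] * 1000

# Taking the first 500 without shuffling yields a single class
first_500 = toy_labels[:500]
print(Counter(first_500))  # Counter({0: 500})

# Shuffling a copy with a fixed seed first yields both classes
rng = random.Random(42)
shuffled = toy_labels[:]
rng.shuffle(shuffled)
mixed_500 = shuffled[:500]
print(Counter(mixed_500))
```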

Here's an example:

In [None]:
labels[0], texts[0]

Then we split them into training and validation sets, holding out a third of the examples (`test_size=0.33`) for validation:

In [None]:
from sklearn.model_selection import train_test_split

texts_train, texts_val, labels_train, labels_val = train_test_split(texts, labels, test_size=0.33, random_state=42)
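With only a few hundred examples, a random split can leave the two classes unevenly represented. scikit-learn's `train_test_split` also accepts a `stratify` argument that keeps the class proportions the same in both parts. The notebook's call above does not use it; this is a sketch on made-up texts and labels showing what it does.

```python
from collections import Counter
from sklearn.model_selection import train_test_split

# Made-up placeholder texts with perfectly balanced labels
toy_texts = [f"review {i}" for i in range(10)]
toy_labels = [0] * 5 + [1] * 5

# stratify=toy_labels keeps the 50/50 class ratio in both splits
tr_texts, va_texts, tr_labels, va_labels = train_test_split(
    toy_texts, toy_labels, test_size=0.4, random_state=42, stratify=toy_labels
)
print(Counter(tr_labels), Counter(va_labels))
```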

### Tokenizing the text

Next we need to preprocess the data. We'll tokenize each text with the tokenizer for the `bert-base-uncased` model, keeping its label alongside the tokens. Truncating to 512 tokens keeps every input within BERT's maximum length.

In [None]:
from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize(text, label):
  tokenized = tokenizer(text, truncation=True, max_length=512, return_tensors='pt')
  tokenized['label'] = torch.tensor(label).reshape(1,1)
  return tokenized

In [None]:
from tqdm import tqdm

tokenized_train = [ tokenize(text,label) for text,label in tqdm(zip(texts_train,labels_train)) ]
tokenized_val = [ tokenize(text,label) for text,label in tqdm(zip(texts_val,labels_val)) ]

### Setting up a custom model

We'll also use a custom model. It encodes the text with a BERT model, takes the vector for the `[CLS]` token (the first token of each sequence), and passes it through one final linear layer to get two outputs, one per class.

In [None]:
from transformers import AutoModel
import torch.nn as nn

class ClassifierModel(nn.Module):
  def __init__(self, model_name):
    super().__init__()
    # Pretrained BERT encoder
    self.bert_model = AutoModel.from_pretrained(model_name)

    # Linear head mapping each [CLS] vector to two logits (negative, positive)
    self.linear = nn.Linear(self.bert_model.config.hidden_size, 2)

  def forward(self, input_ids, attention_mask):
    bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)

    # Take the vector for the first token ([CLS]) of each sequence
    cls_vectors = bert_output.last_hidden_state[:,0,:]

    output = self.linear(cls_vectors)

    return output
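To see what the `forward` method does to tensor shapes without downloading BERT, here is a sketch that substitutes a random tensor for `bert_output.last_hidden_state`. The batch size and sequence length are made up; the hidden size of 768 matches `bert-base-uncased`.

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 4, 10, 768

# Stand-in for bert_output.last_hidden_state: (batch, sequence, hidden)
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)

# The [CLS] token sits at position 0 of every sequence
cls_vectors = last_hidden_state[:, 0, :]
print(cls_vectors.shape)  # torch.Size([4, 768])

# One linear layer maps each CLS vector to two logits (one per class)
linear = nn.Linear(hidden_size, 2)
logits = linear(cls_vectors)
print(logits.shape)  # torch.Size([4, 2])
```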

Create the model and send it to the GPU (falling back to the CPU if no GPU is available):

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ClassifierModel('bert-base-uncased')
model = model.to(device)

### Figuring out batches of data

One of the fiddly bits of writing a training loop yourself is getting the data into nice batches. Training on one sample at a time is slow and makes each gradient update noisy, so we'd like to put several through together (8 or 16 are common batch sizes). But our samples have different lengths. For instance, here is a short example:

In [None]:
tokenized_train[11]

And here's a slightly longer one:

In [None]:
tokenized_train[27]

As you can imagine, the lengths vary a lot. One way to solve this is to have the tokenizer pad every sample to the same length (e.g. 512, with `padding='max_length'`). That may be the most straightforward way. However, you then store, and run BERT over, lots and lots of padding zeros.

An alternative is to group samples into batches and pad on the fly, so that each batch is only padded to the length of its longest sample. We'll do that.
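A back-of-the-envelope comparison of the two approaches, using made-up token counts for 8 samples split into two batches of 4: padding everything to 512 tokens versus padding each batch only to its own longest sample.

```python
# Made-up token counts for 8 samples, split into two batches of 4
lengths = [120, 340, 95, 260, 60, 210, 180, 75]
batches = [lengths[0:4], lengths[4:8]]

real_tokens = sum(lengths)

# Option 1: pad every sample to 512 tokens
fixed_total = 512 * len(lengths)

# Option 2: pad each batch only up to its longest sample
dynamic_total = sum(max(b) * len(b) for b in batches)

print(real_tokens, fixed_total, dynamic_total)  # 1340 4096 2200
```

Both options store far more positions than there are real tokens, but padding per batch roughly halves the total here.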

Let's say we've got a batch of 8 samples:

In [None]:
batch = tokenized_train[0:8]

They've all got different sizes 😞

In [None]:
[ x['input_ids'].shape[1] for x in batch ]

Here's a function that does some padding for us with the [pad_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html) function.

In [None]:
from torch.nn.utils.rnn import pad_sequence

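def custom_collate(batch):
  output = {}
  # Each sample is a dict of tensors with shape (1, length), e.g. input_ids, attention_mask, label
  feature_names = list(batch[0].keys())
  for feature_name in feature_names:
    # Drop the leading dimension of 1 to get a list of 1-D tensors
    combined = [ b[feature_name][0,:] for b in batch ]
    # Pad with zeros to the longest sample in this batch: shape (batch_size, max_length)
    padded = pad_sequence(combined,batch_first=True)
    assert padded.shape[0] == len(batch)
    output[feature_name] = padded

  return output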
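To see what `pad_sequence` does inside `custom_collate`, here it is on three toy 1-D tensors of token IDs (made-up values). By default it pads with 0, which for `bert-base-uncased` is also the ID of the `[PAD]` token, and padding the attention masks with 0 marks those positions as ones BERT should ignore. With `batch_first=True` the batch dimension comes first.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three made-up sequences of token IDs with different lengths
seqs = [torch.tensor([101, 2023, 102]),
        torch.tensor([101, 102]),
        torch.tensor([101, 2307, 3185, 102])]

# Pads with 0 (the default padding_value) up to the longest length, 4
padded = pad_sequence(seqs, batch_first=True)
print(padded.tolist())  # [[101, 2023, 102, 0], [101, 102, 0, 0], [101, 2307, 3185, 102]]

# Attention masks: 1 for real tokens, padded with 0 for the extra positions
masks = [torch.ones_like(s) for s in seqs]
padded_masks = pad_sequence(masks, batch_first=True)
print(padded_masks.tolist())  # [[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 1, 1]]
```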

Now if we run this on the batch, it stacks each feature (`input_ids`, `token_type_ids`, `attention_mask` and `label`) into one padded tensor.

In [None]:
collated = custom_collate(batch)
collated

Each feature is now a single tensor, ready to be passed to BERT. For example, `input_ids` has shape (batch size, length of the longest sample in the batch):

In [None]:
collated['input_ids'].shape

### Choosing some hyperparameters

We're getting close to training. Let's pick a few hyperparameters. These could be tuned with a hyperparameter search, e.g. using Weights & Biases or an equivalent library.

In [None]:
batch_size = 8
learning_rate = 1e-4
num_epochs = 4
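A quick sanity check on what these settings mean. With 500 examples and `test_size=0.33`, scikit-learn puts ceil(0.33 × 500) = 165 examples in the validation set and 335 in the training set. A DataLoader that keeps the incomplete final batch (its default) then yields this many batches per epoch:

```python
import math

n_train, n_val = 335, 165
batch_size = 8
num_epochs = 4

train_batches = math.ceil(n_train / batch_size)  # 42 (the last holds 7 samples)
val_batches = math.ceil(n_val / batch_size)      # 21 (the last holds 5 samples)
total_updates = train_batches * num_epochs       # 168 optimizer steps overall

print(train_batches, val_batches, total_updates)  # 42 21 168
```

Because the training loop divides the summed per-batch losses by the number of batches, the smaller final batch counts as much as a full one; with 42 batches the effect is small.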

Next, create a DataLoader for each set and give it the `custom_collate` function from before. The DataLoader groups samples into batches and, for the training set, shuffles the data every epoch, which is important for training.

In [None]:
from torch.utils.data import DataLoader

train_loader = DataLoader(tokenized_train, batch_size=batch_size, collate_fn=custom_collate, shuffle=True)
val_loader = DataLoader(tokenized_val, batch_size=batch_size, collate_fn=custom_collate, shuffle=False)
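Here is a self-contained sketch of the same DataLoader setup on made-up samples shaped like the tokenizer output (each feature a tensor of shape (1, length)), so you can see the batch shapes it produces without loading BERT. `toy_collate` is a hypothetical, compact copy of `custom_collate`; shuffling is turned off so the batches are predictable.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def toy_collate(batch):
    # Same logic as custom_collate: drop the leading dim of 1, then pad each feature
    return {name: pad_sequence([b[name][0, :] for b in batch], batch_first=True)
            for name in batch[0].keys()}

# Ten made-up samples with lengths 3..12 and alternating labels
samples = []
for i in range(10):
    length = 3 + i
    samples.append({
        'input_ids': torch.randint(1, 1000, (1, length)),
        'attention_mask': torch.ones(1, length, dtype=torch.long),
        'label': torch.tensor(i % 2).reshape(1, 1),
    })

loader = DataLoader(samples, batch_size=4, collate_fn=toy_collate, shuffle=False)
shapes = [(b['input_ids'].shape, b['label'].shape) for b in loader]
for s in shapes:
    print(s)
```

Each batch is padded only to its own longest sample, so the three batches have lengths 6, 10 and 12, and the last batch holds just 2 samples.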

Next we set up the optimizer (with our learning rate) and the loss function that compares the model outputs with our targets.

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_func = torch.nn.CrossEntropyLoss()
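`CrossEntropyLoss` takes raw logits of shape (batch, classes) and integer class indices of shape (batch,), which is why the training loop flattens the collated (batch, 1) label tensor with `reshape(-1)`. A small worked example with made-up values: when both logits are equal, the model assigns probability 0.5 to each class, so the loss is -ln(0.5) = ln 2 ≈ 0.6931.

```python
import math
import torch

loss_func = torch.nn.CrossEntropyLoss()

# Made-up logits for a batch of 3, and labels collated with shape (3, 1)
logits = torch.zeros(3, 2)
labels = torch.tensor([[0], [1], [1]])

# Flatten the labels to shape (3,) as the training loop does
loss = loss_func(logits, labels.reshape(-1))
print(round(loss.item(), 4), round(math.log(2), 4))  # 0.6931 0.6931
```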

### Training time

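And here is the big training loop. In each epoch, it first iterates through the training set, computing the loss and updating the model after every batch. Then it iterates through the validation set with gradients disabled and computes the loss there. Both losses are averaged over the number of batches.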

In [None]:
print("Training...")
for epoch in range(num_epochs):
  model.train()
  train_loss = 0.0

  for batch in tqdm(train_loader):
    batch = { k:v.to(device) for k,v in batch.items() }

    # Zero the parameter gradients
    optimizer.zero_grad()

    # Forward pass
    outputs = model(batch['input_ids'], batch['attention_mask'])
    loss = loss_func(outputs, batch['label'].reshape(-1))

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

    train_loss += loss.item()

  train_loss /= len(train_loader)

  # Validation after each epoch
  model.eval()
  val_loss = 0.0

  with torch.no_grad():
    for batch in tqdm(val_loader):
      batch = { k:v.to(device) for k,v in batch.items() }

      # Forward pass and compute loss
      outputs = model(batch['input_ids'], batch['attention_mask'])
      loss = loss_func(outputs, batch['label'].reshape(-1))
      val_loss += loss.item()

  val_loss /= len(val_loader)

  print(f"{epoch=} {train_loss=:.4f} {val_loss=:.4f}")
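Since one reason to write your own loop is control over reporting, here is a sketch of how validation accuracy could be tracked alongside the loss. To keep it runnable without BERT, it swaps in a tiny `nn.Linear` model on made-up 2-D data (all sizes and the learning rate are assumptions for this toy problem); in the notebook you would instead compare `outputs.argmax(dim=1)` with `batch['label'].reshape(-1)` inside the validation loop.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Made-up data: the class is 1 when the first feature is positive
x = torch.randn(200, 2)
y = (x[:, 0] > 0).long()
train_loader = DataLoader(TensorDataset(x[:150], y[:150]), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(x[150:], y[150:]), batch_size=16)

model = nn.Linear(2, 2)  # hypothetical stand-in for ClassifierModel
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_func = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        loss = loss_func(model(inputs), targets)
        loss.backward()
        optimizer.step()

    # Validation: weight each batch loss by its size so every sample counts equally
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            outputs = model(inputs)
            val_loss += loss_func(outputs, targets).item() * len(targets)
            correct += (outputs.argmax(dim=1) == targets).sum().item()
            total += len(targets)

    val_loss /= total
    val_acc = correct / total
    print(f"{epoch=} {val_loss=:.4f} {val_acc=:.3f}")
```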