[![Colab Badge Link](https://img.shields.io/badge/open-in%20colab-blue)](https://colab.research.google.com/github/Glasgow-AI4BioMed/tutorials/blob/main/pytorch_training_loop_with_custom_hugging_face_model.ipynb)

## Example of PyTorch Training Loop with Custom Hugging Face Model

This Colab illustrates a hand-written PyTorch training loop for a custom transformer model, as an alternative to the Hugging Face Trainer. Writing your own loop gives you more control over how training works and over what gets reported along the way.

### Loading the dataset

We'll use part of the [Stanford IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb) to illustrate this. It is a dataset of movie reviews, each labelled as positive (1) or negative (0). We'll use Hugging Face's [datasets library](https://huggingface.co/docs/datasets/index) to download it. First, install the library:

In [1]:
!pip install -U datasets


Then load the imdb dataset:

In [2]:
from datasets import load_dataset

imdb = load_dataset("imdb")


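Now we'll pick a tiny part of it to use: just the texts and labels of the first 500 training examples. Be aware that the IMDB training split is sorted by label, so these first 500 reviews are all negative (label 0). That keeps the example small, but it means the classifier only ever sees one class, which is why the training and validation losses at the end quickly drop to zero. To get both classes, shuffle before selecting, e.g. `imdb['train'].shuffle(seed=42).select(range(500))`.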

In [3]:
texts = imdb['train'][:500]['text']
labels = imdb['train'][:500]['label']
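Because the IMDB training split is sorted by label (all negative reviews come first), slicing `[:500]` returns a single class. The sketch below shows the difference between slicing and random sampling using only the standard library. The 12,500/12,500 stand-in label list mirrors the IMDB train split layout, and `random_subset_indices` is a hypothetical helper; with the real dataset, the `datasets` calls `imdb['train'].shuffle(seed=42).select(range(500))` do the same job.

```python
import random

def random_subset_indices(num_rows, subset_size, seed=42):
    """Return `subset_size` distinct row indices drawn uniformly at random."""
    rng = random.Random(seed)
    return rng.sample(range(num_rows), subset_size)

# Stand-in for a label column sorted by class (assumed layout: 12,500 negatives, then 12,500 positives)
sorted_labels = [0] * 12500 + [1] * 12500

first_500 = sorted_labels[:500]  # what slicing [:500] gives you
indices = random_subset_indices(len(sorted_labels), 500)
sampled = [sorted_labels[i] for i in indices]

print("first 500 -> positives:", sum(first_500))  # first 500 -> positives: 0
print("random 500 -> positives:", sum(sampled))
```

The random sample should contain roughly 250 positives, so the model sees both classes.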

Here's an example:

In [4]:
labels[0], texts[0]

(0,
 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far betwee

Next, we'll split them into a training set (about two-thirds of the examples) and a validation set (about one-third).

In [5]:
from sklearn.model_selection import train_test_split

texts_train, texts_val, labels_train, labels_val = train_test_split(texts, labels, test_size=0.33, random_state=42)
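When the data does contain both classes, scikit-learn's `stratify` argument keeps the class proportions the same in both splits. A small sketch on hypothetical toy data (500 texts with alternating labels); note that stratification cannot help if the sample contains only one class in the first place.

```python
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 500 texts with alternating labels (250 of each class)
texts = [f"review {i}" for i in range(500)]
labels = [i % 2 for i in range(500)]

texts_train, texts_val, labels_train, labels_val = train_test_split(
    texts, labels, test_size=0.33, random_state=42,
    stratify=labels,  # keep the class proportions equal in both splits
)

print(len(texts_train), len(texts_val))  # 335 165
print(sum(labels_train), sum(labels_val))  # positives in each split, about half of each
```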

### Tokenizing the text

Next we need to preprocess the data. We'll use the tokenizer for the `bert-base-uncased` model to tokenize each text, truncating it to BERT's 512-token limit, and store its label alongside.

In [6]:
from transformers import AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize(text, label):
  tokenized = tokenizer(text, truncation=True, max_length=512, return_tensors='pt')
  tokenized['label'] = torch.tensor(label).reshape(1,1)
  return tokenized


In [7]:
from tqdm import tqdm

tokenized_train = [ tokenize(text,label) for text,label in tqdm(zip(texts_train,labels_train)) ]
tokenized_val = [ tokenize(text,label) for text,label in tqdm(zip(texts_val,labels_val)) ]

335it [00:01, 188.19it/s]
165it [00:00, 184.92it/s]
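Why the label is reshaped to `(1, 1)`: with `return_tensors='pt'`, the tokenizer gives each feature a leading batch dimension of size 1 (the `[[...]]` in the outputs shown later). Storing the label with the same leading dimension lets every feature be treated the same way when batching. A small shape sketch with two hypothetical labels:

```python
import torch

# Labels stored the way `tokenize` stores them: shape (1, 1), matching the
# (1, seq_len) shape of the tokenizer's tensors for a single text
label_a = torch.tensor(1).reshape(1, 1)
label_b = torch.tensor(0).reshape(1, 1)
print(label_a.shape)  # torch.Size([1, 1])

# Stacking per-sample labels gives shape (batch_size, 1) ...
batched = torch.cat([label_a, label_b])
print(batched.shape)  # torch.Size([2, 1])

# ... and reshape(-1) flattens it to (batch_size,), the shape CrossEntropyLoss expects
flat = batched.reshape(-1)
print(flat.tolist())  # [1, 0]
```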


### Setting up a custom model

We'll also use a custom model. It encodes the text with a BERT model, takes the vector for the `[CLS]` token (the first token of every sequence) from the last hidden layer, and passes it through one final linear layer to get two outputs, one score (logit) per class.

In [8]:
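from transformers import AutoModel
import torch
import torch.nn as nn

class ClassifierModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        # Pretrained BERT encoder
        self.bert_model = AutoModel.from_pretrained(model_name)

        # Linear layer mapping the [CLS] vector to two class logits
        self.linear = nn.Linear(self.bert_model.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        bert_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)

        # Vector for the first ([CLS]) token of each sequence: (batch_size, hidden_size)
        cls_vectors = bert_output.last_hidden_state[:, 0, :]

        # Raw logits of shape (batch_size, 2); CrossEntropyLoss applies softmax itself
        output = self.linear(cls_vectors)

        return output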
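The shape bookkeeping inside `forward` can be checked without downloading BERT. This sketch uses a random tensor as a stand-in for `last_hidden_state` (an assumption for illustration only) with BERT-base's hidden size of 768:

```python
import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 4, 10, 768  # 768 is BERT-base's hidden size

# Stand-in for bert_output.last_hidden_state: one vector per token
last_hidden_state = torch.randn(batch_size, seq_len, hidden_size)

# The [CLS] token is always at position 0, so this keeps one vector per text
cls_vectors = last_hidden_state[:, 0, :]
print(cls_vectors.shape)  # torch.Size([4, 768])

# A single linear layer maps each CLS vector to two class scores (logits)
linear = nn.Linear(hidden_size, 2)
logits = linear(cls_vectors)
print(logits.shape)  # torch.Size([4, 2])
```

Because padding is added at the end of each sequence, position 0 is still the `[CLS]` token for every sample in a padded batch.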

Create the model and send it to the GPU (falling back to the CPU if no GPU is available):

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = ClassifierModel('bert-base-uncased')
model = model.to(device)


### Figuring out batches of data

One of the fiddly parts of writing a training loop yourself is getting the data into batches. Training on one sample at a time is slow and gives very noisy gradient updates, so we'd like to process several samples together (8 or 16 are common batch sizes). But our samples have different lengths. For instance, here is a short example:

In [10]:
tokenized_train[11]

{'input_ids': tensor([[  101,  2017,  1005,  1040,  2488,  5454,  2703,  2310, 25032,  8913,
          8159,  1005,  1055,  2130,  2065,  2017,  2031,  3427,  2009,  1012,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'label': tensor([[0]])}

And here's a slightly longer one:

In [11]:
tokenized_train[27]

{'input_ids': tensor([[  101,  2023,  2001,  2019, 11757,  5236,  3185,  1012,  2009,  2001,
          4298,  1996,  5409,  3185,  1045,  1005,  2310,  2412,  2018,  1996,
         28606,  1997,  3564,  2083,  1012,  1045,  3685,  6638, 23393,  2129,
          2009,  6938,  1037,  5790,  1997,  1019,  2030,  1020,  1012,  1012,
          1012,  1012,  1012,  1012,  1012,  1012,  1012,  1012,  1012,  1012,
          1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]]), 'label': tensor([[0]])}

Across the whole dataset, lengths vary a lot. One way to handle this is to have the tokenizer pad every sample to the same length (e.g. 512, with `padding='max_length'`). That is the most straightforward approach, but you may then be storing, and running BERT over, lots of padding zeros.

An alternative is to pad on the fly: group samples into batches and pad each batch only to the length of its longest sample. That's what we'll do.

Let's say we've got a batch of 8 samples, as below:

In [12]:
batch = tokenized_train[0:8]

They've all got different sizes 😞

In [13]:
[ x['input_ids'].shape[1] for x in batch ]

[448, 345, 388, 140, 274, 374, 248, 397]
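To put numbers on the trade-off, here is the arithmetic for this one batch, using the eight lengths printed above. It compares padding every sample to 512 tokens with padding only to the longest sample; the savings for other batches depend on how much their lengths vary.

```python
# Sequence lengths of the eight samples printed above
lengths = [448, 345, 388, 140, 274, 374, 248, 397]

real_tokens = sum(lengths)                     # tokens that carry content
pad_to_512 = len(lengths) * 512                # padding every sample to max_length=512
pad_to_longest = len(lengths) * max(lengths)   # padding only to the longest in the batch

print(real_tokens, pad_to_longest, pad_to_512)  # 2614 3584 4096
print(f"padding share: {1 - real_tokens / pad_to_longest:.1%} vs {1 - real_tokens / pad_to_512:.1%}")
# padding share: 27.1% vs 36.2%
```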

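Here's a function that pads each feature to the length of the longest sample in the batch, using PyTorch's [pad_sequence](https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pad_sequence.html) function. It pads with 0, which works here because 0 is the `[PAD]` token id for `bert-base-uncased`, and a 0 in the attention mask tells BERT to ignore those positions.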

In [14]:
from torch.nn.utils.rnn import pad_sequence

def custom_collate(batch):
  output = {}
  feature_names = list(batch[0].keys())
  for feature_name in feature_names:
    # Each sample stores a (1, seq_len) tensor; drop the leading dimension
    combined = [ b[feature_name][0,:] for b in batch ]
    # Stack into (batch_size, longest_len), padding shorter rows with 0
    padded = pad_sequence(combined, batch_first=True)
    assert padded.shape[0] == len(batch)
    output[feature_name] = padded

  return output
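For a tokenizer whose padding id is not 0, the padding value needs to come from the tokenizer (its `pad_token_id` attribute). A sketch of that variant, where `collate_with_pad_values` is a hypothetical helper with a per-feature padding value; the `transformers` library also ships `DataCollatorWithPadding`, which handles this for you.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_with_pad_values(batch, pad_values=None):
    """Pad each feature to the longest sample in the batch.

    `pad_values` maps a feature name to its padding value (default 0).
    Hypothetical variant of `custom_collate` with configurable padding.
    """
    pad_values = pad_values or {}
    output = {}
    for name in batch[0].keys():
        # Each sample stores a (1, seq_len) tensor; drop the leading batch dimension
        sequences = [sample[name][0, :] for sample in batch]
        output[name] = pad_sequence(sequences, batch_first=True,
                                    padding_value=pad_values.get(name, 0))
    return output

# Two toy samples of different lengths, shaped like tokenizer output
batch = [
    {"input_ids": torch.tensor([[101, 7, 8, 102]]), "attention_mask": torch.tensor([[1, 1, 1, 1]])},
    {"input_ids": torch.tensor([[101, 9, 102]]), "attention_mask": torch.tensor([[1, 1, 1]])},
]
collated = collate_with_pad_values(batch, pad_values={"input_ids": 0})  # e.g. tokenizer.pad_token_id
print(collated["input_ids"].tolist())       # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(collated["attention_mask"].tolist())  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```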

Now if we run this on the batch, it stacks and pads each feature:

In [15]:
collated = custom_collate(batch)
collated

{'input_ids': tensor([[  101,  7483,  2001,  ...,  2617,  1012,   102],
         [  101,  5515,  1010,  ...,     0,     0,     0],
         [  101,  2043,  1045,  ...,     0,     0,     0],
         ...,
         [  101,  1024, 27594,  ...,     0,     0,     0],
         [  101,  1998,  2666,  ...,     0,     0,     0],
         [  101,  1996, 12610,  ...,     0,     0,     0]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'label': tensor([[0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0]])}

Each feature is now a single tensor with one row per sample, padded to the longest sample in the batch (448 tokens) and ready to be passed to BERT:

In [16]:
collated['input_ids'].shape

torch.Size([8, 448])

### Choosing some hyperparameters

We're almost ready to train. Let's pick a few hyperparameters. These can be tuned with a hyperparameter sweep, using Weights & Biases or an equivalent library.

In [17]:
batch_size = 8
learning_rate = 1e-4
num_epochs = 4
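A sweep boils down to trying combinations of values and comparing them on validation loss. A minimal grid-search sketch using only the standard library (it does not use Weights & Biases). The candidate values are the fine-tuning ranges suggested in the BERT paper (Devlin et al., 2019); note that the `learning_rate = 1e-4` chosen here is higher than that range.

```python
import itertools

# Candidate values from the BERT paper's suggested fine-tuning ranges
search_space = {
    "batch_size": [16, 32],
    "learning_rate": [5e-5, 3e-5, 2e-5],
    "num_epochs": [2, 3, 4],
}

# Every combination of the candidate values (a simple grid search)
configs = [dict(zip(search_space, values))
           for values in itertools.product(*search_space.values())]

print(len(configs))  # 18
print(configs[0])    # {'batch_size': 16, 'learning_rate': 5e-05, 'num_epochs': 2}
```

Each configuration would then be trained with the loop below and the one with the lowest validation loss kept.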

Next, let's create a DataLoader for each split and give it the `custom_collate` function from before. The DataLoader splits the data into batches and, for the training set, shuffles it every epoch, which is important for training.

In [18]:
from torch.utils.data import DataLoader

train_loader = DataLoader(tokenized_train, batch_size=batch_size, collate_fn=custom_collate, shuffle=True)
val_loader = DataLoader(tokenized_val, batch_size=batch_size, collate_fn=custom_collate, shuffle=False)
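The number of batches per epoch is `ceil(num_samples / batch_size)`, which is where the 42 training and 21 validation steps in the progress bars come from. A sketch with hypothetical stand-in data (integers instead of tokenized samples, and the default collate function instead of `custom_collate`):

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical stand-in data: 335 and 165 "samples" that are just their own index
train_samples = list(range(335))
val_samples = list(range(165))

# With no collate_fn, the default collate turns each batch of ints into a tensor
train_loader = DataLoader(train_samples, batch_size=8, shuffle=True)
val_loader = DataLoader(val_samples, batch_size=8, shuffle=False)

# Number of batches is ceil(num_samples / batch_size); the last batch is smaller
print(len(train_loader), len(val_loader))  # 42 21

batch_sizes = [len(batch) for batch in train_loader]
print(batch_sizes[-1])  # 7 (335 = 41 * 8 + 7)

# shuffle=True draws a new order every time the loader is iterated (i.e. each epoch)
first_epoch = torch.cat(list(train_loader))
second_epoch = torch.cat(list(train_loader))
print(torch.equal(first_epoch, second_epoch))  # False with overwhelming probability
```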

Next we set up the optimizer (with the learning rate) and the loss function that compares the model outputs to our targets.

In [19]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_func = torch.nn.CrossEntropyLoss()
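`CrossEntropyLoss` applies log-softmax internally, so the model should return raw logits (as `ClassifierModel` does), and the targets should be a 1-D tensor of class indices. That is why the loop flattens the `(batch_size, 1)` labels with `reshape(-1)`. A small worked example with hand-picked logits; by default the loss is averaged over the batch.

```python
import math
import torch

loss_func = torch.nn.CrossEntropyLoss()

# Raw model outputs (logits) for a batch of 3 texts and 2 classes
logits = torch.tensor([[0.0, 0.0],
                       [2.0, -1.0],
                       [-1.0, 2.0]])

# Labels collated as shape (batch_size, 1), as in the batches above
labels = torch.tensor([[0], [0], [1]])

# Class indices of shape (batch_size,) are required, hence reshape(-1)
loss = loss_func(logits, labels.reshape(-1))

# Same value by hand: mean over samples of -log(softmax probability of the true class)
expected = (math.log(2) + 2 * math.log(1 + math.exp(-3))) / 3
print(round(loss.item(), 4), round(expected, 4))  # 0.2634 0.2634
```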

### Training time

And here is the main training loop. In each epoch, it first iterates through the training set, computing the loss on each batch and updating the model. Then it switches the model to evaluation mode and computes the average loss on the validation set, without updating the model.
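The loop below reports only losses. A minimal sketch of also counting correct predictions, assuming the model returns logits of shape `(batch_size, 2)` as `ClassifierModel` does; `count_correct` is a hypothetical helper, not part of PyTorch.

```python
import torch

def count_correct(logits, labels):
    """Number of predictions (argmax over class logits) that match the labels."""
    predictions = logits.argmax(dim=-1)
    return (predictions == labels.reshape(-1)).sum().item()

# Toy example: 4 texts, 2 classes
logits = torch.tensor([[2.0, -1.0], [0.5, 1.5], [3.0, 0.0], [-2.0, 1.0]])
labels = torch.tensor([[0], [1], [1], [1]])
print(count_correct(logits, labels))  # 3
```

Inside the validation loop, one could accumulate `correct += count_correct(outputs, batch['label'])` and `total += len(batch['label'])`, then report `correct / total` alongside `val_loss`.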

In [20]:
print("Training...")
for epoch in range(num_epochs):
  model.train()
  train_loss = 0.0

  for batch in tqdm(train_loader):
    batch = { k:v.to(device) for k,v in batch.items() }

    # Zero the parameter gradients
    optimizer.zero_grad()

    # Forward pass
    outputs = model(batch['input_ids'], batch['attention_mask'])
    loss = loss_func(outputs, batch['label'].reshape(-1))

    # Backward pass and optimize
    loss.backward()
    optimizer.step()

    train_loss += loss.item()

  train_loss /= len(train_loader)

  # Validation after each epoch
  model.eval()
  val_loss = 0.0

  with torch.no_grad():
    for batch in tqdm(val_loader):
      batch = { k:v.to(device) for k,v in batch.items() }

      # Forward pass and compute loss
      outputs = model(batch['input_ids'], batch['attention_mask'])
      loss = loss_func(outputs, batch['label'].reshape(-1))
      val_loss += loss.item()

  val_loss /= len(val_loader)

  print(f"{epoch=} {train_loss=:.4f} {val_loss=:.4f}")

Training...


100%|██████████| 42/42 [00:25<00:00,  1.63it/s]
100%|██████████| 21/21 [00:03<00:00,  5.49it/s]


epoch=0 train_loss=0.0139 val_loss=0.0000


100%|██████████| 42/42 [00:26<00:00,  1.59it/s]
100%|██████████| 21/21 [00:03<00:00,  5.38it/s]


epoch=1 train_loss=0.0000 val_loss=0.0000


100%|██████████| 42/42 [00:27<00:00,  1.52it/s]
100%|██████████| 21/21 [00:03<00:00,  5.27it/s]


epoch=2 train_loss=0.0000 val_loss=0.0000


100%|██████████| 42/42 [00:27<00:00,  1.55it/s]
100%|██████████| 21/21 [00:04<00:00,  5.17it/s]

epoch=3 train_loss=0.0000 val_loss=0.0000
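Losses of 0.0000 on both splits after a single epoch are a warning sign rather than a success. Here the first 500 IMDB training reviews are all negative (the training split is sorted by label), so the model only has to learn to always predict label 0. A quick check worth running before training is the label distribution of each split; `label_distribution` below is a hypothetical helper, shown on hypothetical label lists, and could be called as `label_distribution(labels_train)`.

```python
from collections import Counter

def label_distribution(labels):
    """Fraction of examples per label, sorted by label."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: counts[label] / total for label in sorted(counts)}

# Hypothetical label lists: one single-class, one balanced
print(label_distribution([0] * 335))                  # {0: 1.0}
print(label_distribution([0, 1, 1, 0, 1, 0, 0, 1]))  # {0: 0.5, 1: 0.5}
```

A split with only one label means the loss cannot tell you whether the model has learned anything about sentiment.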



