# AI: Building a Patient Priority Classification Using BERT and Transformers

**Author**: Alan Arantes - Enterprise & System Architect  
**Reading Time**: 6 min  

## TL;DR
A practical guide to implementing an automated medical triage system with BERT and the Hugging Face Transformers library, helping medical staff prioritize patients efficiently through deep learning.

## Introduction
In busy medical environments, quick and accurate patient prioritization can mean the difference between life and death. While experienced medical professionals excel at this task, automating the initial triage process can help streamline operations and ensure consistent evaluation.

We'll use BERT (Bidirectional Encoder Representations from Transformers) to create a system that can understand and prioritize patient descriptions. The implementation uses PyTorch and the Transformers library from Hugging Face.

## The Role of Self-Attention in Medical Triage
BERT's self-attention mechanism can serve as a tool for accurate patient prioritization. It enables the model to process multiple aspects of a patient's condition simultaneously, similar to how a medical professional assesses a situation:

1. **Symptom Relationships**: Self-attention weighs symptom combinations (e.g., "chest pain" with "shortness of breath" indicating higher urgency than either alone).
2. **Contextual Understanding**: Words are interpreted in relation to each other, distinguishing between scenarios like "severe acute pain" and "mild chronic pain".
3. **Demographic Consideration**: The attention mechanism connects patient demographics with symptoms, recognizing that similar symptoms might indicate different priorities based on age or history.

The mechanism computes attention scores that determine how strongly each part of a patient's description influences every other part, giving the classifier context-aware representations on which to base its priority predictions.

### Self-Attention Process
The self-attention process works by computing three vectors for each element in the input:
- **Query (Q)**: What information the element is looking for
- **Key (K)**: What information the element can provide
- **Value (V)**: The actual content to be passed along
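
As a rough illustration of how the three vectors interact, here is a minimal, dependency-free sketch of scaled dot-product attention for a single head. The toy 2-dimensional vectors and the helper names (`softmax`, `scaled_dot_product_attention`) are made up for this example; BERT itself uses learned linear projections, multiple heads, and 768-dimensional hidden states.

```python
import math

def softmax(scores):
    """Turn raw scores into weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(queries, keys, values):
    """Single-head attention over lists of equal-length vectors."""
    d_k = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        # Score each key against this query, scaled by sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # Output is the attention-weighted sum of the value vectors
        out = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights

# Toy inputs: three "tokens" with hypothetical 2-d query/key/value vectors
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

attended, attn_weights = scaled_dot_product_attention(Q, K, V)
for row in attn_weights:
    print([round(w, 3) for w in row])
```

Each printed row shows how much one token attends to every token in the sequence; the rows always sum to 1.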

This mechanism is particularly powerful for classification tasks because it allows the model to:
- Identify and weigh crucial symptoms appropriately
- Consider multiple factors simultaneously
- Learn complex patterns in medical presentations
- Adapt to various description formats and lengths

## The Building Blocks: Understanding Each Component

### Resources and Dependencies
To run this code, you'll need:
- Python 3.6+
- PyTorch
- Transformers library
- CUDA-capable GPU (recommended)
- Pandas and NumPy
- Scikit-learn

## Understanding the Essential Imports for BERT-based Patient Classification
### Core Deep Learning Libraries

***PyTorch Framework***

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

- **torch:** The main PyTorch library for deep learning operations
- **Dataset:** Base class for creating custom datasets
- **DataLoader:** Handles batch processing and data loading during training

***Transformers Components***

In [2]:
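from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import get_linear_schedule_with_warmup
# transformers.AdamW is deprecated; torch.optim.AdamW is the recommended replacement
from torch.optim import AdamW

- **AutoTokenizer:** Handles text tokenization for BERT
- **AutoModelForSequenceClassification:** Pre-trained BERT model adapted for classification
- **AdamW:** Adam optimizer with decoupled weight decay, the usual choice for fine-tuning transformer models (imported from PyTorch)
- **get_linear_schedule_with_warmup:** Learning rate scheduler that warms up and then linearly decays the learning rate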

### Data Processing Libraries
***Scientific Computing***

In [3]:
import numpy as np

- **numpy:** Essential for numerical operations and array manipulations

### Machine Learning Tools

In [4]:
from sklearn.model_selection import train_test_split

- **train_test_split:** Splits data into training and validation sets

### Data Manipulation

In [5]:
import pandas as pd

### Key Points

This combination of imports provides all necessary tools for:

- Deep learning model implementation
- Text processing
- Data handling
- Model training and evaluation

The focus is on transformer-based architectures with proper data management support.

## Custom Dataset Implementation

This custom dataset class is the foundation of our system. It inherits from PyTorch's Dataset class and handles:

* Text tokenization for BERT processing
* Label encoding for priority levels
* Proper formatting of input data
* Batch processing preparation

The class converts raw patient descriptions into the tensor format required by BERT, handling all necessary padding and truncation automatically.

In [6]:

class PatientDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
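
To make the padding and truncation step concrete, here is a small sketch of what `padding='max_length'` and `truncation=True` produce conceptually. The helper `pad_or_truncate` is hypothetical and the token IDs in the example are made up; the special-token IDs (101 for `[CLS]`, 102 for `[SEP]`, 0 for `[PAD]`) are the ones I would expect for `bert-base-uncased`, but treat them as assumptions and let the real tokenizer handle this in practice.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed IDs for bert-base-uncased

def pad_or_truncate(token_ids, max_length):
    """Mimic fixed-length encoding: add special tokens, truncate, then pad."""
    # Reserve two positions for [CLS] and [SEP]
    body = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # 1 marks real tokens, 0 marks padding the model should ignore
    attention_mask = [1] * len(input_ids)
    pad_count = max_length - len(input_ids)
    input_ids += [PAD_ID] * pad_count
    attention_mask += [0] * pad_count
    return input_ids, attention_mask

# Made-up word-piece IDs standing in for a short patient description
short_ids, short_mask = pad_or_truncate([5001, 5002, 5003], max_length=8)
long_ids, long_mask = pad_or_truncate(list(range(6000, 6020)), max_length=8)

print(short_ids)   # padded with PAD_ID up to length 8
print(short_mask)
print(long_ids)    # truncated to fit, still ending in SEP_ID
```

Every item therefore has the same length, which is what lets the `DataLoader` stack them into a single batch tensor.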


## Model Training Architecture

Below, I'll break down each main section of the training function and explain exactly what it does:

### Data Preparation and Model Setup Flow

#### Sample Data Construction

The journey begins with preparing our sample patient data, structured as clear medical descriptions. Each entry contains vital information like symptoms, gender, and age:

In [7]:
patient_data = [
    "Severe chest pain, shortness of breath, male, age 65",
    "Mild headache, female, age 25",
    "High fever, cough, difficulty breathing, male, age 45",
]

For each patient description, we assign priority levels (0: low, 1: medium, 2: high) to train our model:

In [8]:
labels = [2, 0, 1]

## BERT Model Initialization

We then initialize our BERT model, using the 'bert-base-uncased' variant. This process involves two key components, the tokenizer and the classification model:


In [9]:
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=3
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The tokenizer will process our text inputs, while the model is configured for three-level classification (low, medium, high priority).

## Data Organization

The data is split into training and validation sets using scikit-learn's `train_test_split`, with 20% reserved for validation. With only three samples, this yields two training examples and a single validation example:

In [10]:
X_train, X_val, y_train, y_val = train_test_split(
    patient_data, labels, test_size=0.2, random_state=42
)

These splits are then transformed into PyTorch datasets:

In [11]:
train_dataset = PatientDataset(X_train, y_train, tokenizer)
val_dataset = PatientDataset(X_val, y_val, tokenizer)

## Training Setup

DataLoaders are created to handle batching and shuffling during training:

In [12]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32)

Finally, we ensure our model can utilize available GPU acceleration if present:

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

This setup creates a complete pipeline from raw patient descriptions to a model ready for training on the appropriate hardware.

# Training Process Breakdown

## Optimization Setup

We begin by configuring the AdamW optimizer, the standard choice for fine-tuning transformer models. The learning rate is set to 2e-5, a value known to work well for BERT fine-tuning:

In [14]:
epochs = 3
optimizer = AdamW(model.parameters(), lr=2e-5)
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)
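
The effect of the scheduler is easier to see with numbers. The sketch below reproduces, in plain Python, the multiplier that a linear warmup-then-decay schedule applies to the base learning rate. The function name `linear_schedule_multiplier` is hypothetical, and the step count assumes the toy setup above (two training samples in one batch, three epochs, so three optimizer steps).

```python
def linear_schedule_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Ramp up linearly from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Then decay linearly from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))

base_lr = 2e-5
total_steps = 3  # assumed: 1 batch per epoch * 3 epochs

for step in range(total_steps + 1):
    lr = base_lr * linear_schedule_multiplier(step, 0, total_steps)
    print(f"step {step}: lr = {lr:.2e}")
```

With `num_warmup_steps=0`, as in the cell above, the learning rate starts at its full value and falls linearly to zero by the final step.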



## Training Loop

The training process unfolds across multiple epochs, where each epoch represents a complete pass through the dataset. Within each epoch:

In [15]:
for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        model.train()
        total_train_loss = 0

### Batch Processing
#### Each batch undergoes a series of transformations and computations:
        for batch in train_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

### The model processes these inputs to generate predictions and calculate losses:
            model.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

### Optimization Step
#### The backpropagation and optimization process occurs after each batch:
            loss = outputs.loss
            total_train_loss += loss.item()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

        avg_train_loss = total_train_loss / len(train_dataloader)
        print(f'Average training loss: {avg_train_loss}')

## Validation Phase
#### After each epoch's training pass, the model enters evaluation mode to assess its performance on held-out data:
        model.eval()
        total_val_loss = 0
        predictions = []
        true_labels = []

        for batch in val_dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

### Validation Processing
#### During validation, we process batches without computing gradients:
            with torch.no_grad():
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

            loss = outputs.loss
            total_val_loss += loss.item()

            predictions.extend(outputs.logits.argmax(dim=1).cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

### Performance Metrics
#### Finally, we calculate and display key performance metrics:
        avg_val_loss = total_val_loss / len(val_dataloader)
        accuracy = np.mean(np.array(predictions) == np.array(true_labels))
        print(f'Average validation loss: {avg_val_loss}')
        print(f'Validation accuracy: {accuracy}')



Epoch 1/3
Average training loss: 1.8512301445007324
Average validation loss: 0.6676399111747742
Validation accuracy: 1.0
Epoch 2/3
Average training loss: 1.2944049835205078
Average validation loss: 0.8853601813316345
Validation accuracy: 1.0
Epoch 3/3
Average training loss: 1.3907421827316284
Average validation loss: 1.1244556903839111
Validation accuracy: 0.0
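
The call to `clip_grad_norm_` in the loop above caps the overall size of the gradient before each optimizer step. As a sketch of the idea only (not PyTorch's actual implementation), the plain-Python function below computes the global L2 norm across all gradient values and rescales them when that norm exceeds `max_norm`. The name `clip_by_global_norm` and the example gradients are invented for illustration.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale a list of gradient vectors so their combined L2 norm <= max_norm."""
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Only shrink gradients; never scale them up
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm

# Two hypothetical parameter gradients with a combined norm of 5.0
grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

print(norm_before)
print(round(norm_after, 4))
```

The direction of the gradient is preserved; only its magnitude is limited, which helps keep individual updates from destabilizing fine-tuning.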


## Real-World Application
Finally, the trained model is put to practical use. New patient cases are processed:

In [16]:
# Example of prioritizing new patients
new_patients = [
    "Severe abdominal pain, vomiting, female, age 35",
    "Minor cuts and bruises, male, age 28",
]

# Understanding the Priority Prediction Function

## Function Overview

The `predict_priority` function serves as our prediction pipeline, taking a trained model and new patient data to determine medical priority levels. Let's break down its implementation:

## Data Preparation

First, we prepare our data for prediction by creating a dataset instance:

In [17]:
model.eval()

# Encode the new cases (not the training `patient_data`); the zero labels are placeholders
dataset = PatientDataset(new_patients, [0]*len(new_patients), tokenizer)
dataloader = DataLoader(dataset, batch_size=32)
priorities = []


The function creates temporary labels (zeros) since we only need the text processing capability of our dataset class, not actual labels.

## Prediction Process

The core prediction happens within a no-gradient context, ensuring efficiency:

In [18]:
with torch.no_grad():
  for batch in dataloader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)

### Model Inference
#### The model processes each batch and generates probability distributions for priority levels:
    outputs = model(
      input_ids=input_ids,
      attention_mask=attention_mask
    )

    priorities.extend(outputs.logits.softmax(dim=1).cpu().numpy())

Here, we use softmax to convert model outputs into probability distributions, making our predictions more interpretable. The results are moved to CPU and converted to numpy arrays for easier post-processing.
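
For intuition, the conversion from logits to probabilities can be sketched in plain Python. The logits below are made-up numbers, not outputs from the trained model, and `softmax` here is a hand-written stand-in for the tensor method used in the cell above.

```python
import math

def softmax(logits):
    """Convert raw model scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

LEVELS = ["Low", "Medium", "High"]

# Hypothetical logits for one patient (one score per priority level)
logits = [0.2, 1.1, 2.3]
probs = softmax(logits)
predicted = LEVELS[probs.index(max(probs))]

print([round(p, 3) for p in probs])
print(predicted)
```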

The function efficiently processes multiple patients in batches and returns their priority predictions, making it suitable for real-time applications in a medical setting.

# Results Processing and Display

## Priority List Creation

The final stage of our medical triage system involves organizing and presenting the predictions in a meaningful way:


In [19]:
priority_list = [
    (patient, priority)
    for patient, priority in zip(new_patients, priorities)
]


This creates paired tuples of patients and their predicted priority scores, combining our input data with model predictions.

## Priority Sorting

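The patients are then sorted by predicted priority level, ensuring urgent cases appear first:

In [20]:
# Sort by predicted class (2 = High first), breaking ties by the probability of "High"
priority_list.sort(key=lambda x: (x[1].argmax(), x[1][2]), reverse=True)

The sort key is the predicted class index followed by the probability assigned to the "High" class, and `reverse=True` gives a most-urgent-first ordering. Sorting on `x[1].max()` alone would rank patients by model confidence, so a confidently "Low" case could be listed above an uncertain "High" one.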
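
The choice of sort key matters when the goal is an urgency ordering. The following sketch uses made-up probability vectors (plain lists rather than NumPy arrays) to compare ranking by the largest probability with ranking by predicted class, using the probability of "High" as a tie-breaker. The patient names and numbers are purely illustrative.

```python
# (description, [P(Low), P(Medium), P(High)]) with invented probabilities
cases = [
    ("Patient A", [0.90, 0.07, 0.03]),  # confidently Low
    ("Patient B", [0.30, 0.25, 0.45]),  # uncertain, but High is most likely
    ("Patient C", [0.20, 0.60, 0.20]),  # Medium
]

def predicted_class(probs):
    """Index of the most likely priority level (0=Low, 1=Medium, 2=High)."""
    return probs.index(max(probs))

# Ranking by confidence puts the confident Low case first
by_confidence = sorted(cases, key=lambda c: max(c[1]), reverse=True)

# Ranking by predicted class, then P(High), puts the most urgent case first
by_urgency = sorted(cases, key=lambda c: (predicted_class(c[1]), c[1][2]), reverse=True)

print([name for name, _ in by_confidence])
print([name for name, _ in by_urgency])
```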

## Results Presentation

Finally, we present the results in a clear, human-readable format:


In [21]:
print("\nPatient Priority List:")
for patient, priority in priority_list:
    priority_level = ["Low", "Medium", "High"][priority.argmax()]
    print(f"Patient: {patient}")
    print(f"Priority Level: {priority_level}\n")



Patient Priority List:
Patient: Minor cuts and bruises, male, age 28
Priority Level: Low

Patient: Severe abdominal pain, vomiting, female, age 35
Priority Level: Low



The predicted probabilities are converted to labels ("Low", "Medium", "High") by using `argmax()` to select the most likely priority level. The result is a list of patients ordered from most to least urgent, in a format that medical staff can scan quickly to identify who needs immediate attention.

The entire process flows from data preparation through model training to practical application, creating a complete pipeline for medical triage automation.

## Conclusion

This implementation demonstrates how modern NLP techniques can be applied to real-world healthcare challenges. The system provides:

* Automated initial triage assessment
* Consistent patient prioritization
* Scalable processing of patient descriptions
* Real-time priority predictions

While this system serves as a valuable tool for medical staff, it's important to note that it should be used as a support system rather than a replacement for professional medical judgment.

## Next Steps

To take this further within a production system, consider:

1. Incorporating additional patient metadata
2. Implementing multi-lingual support
3. Adding explainability features
4. Developing a user interface for medical staff
5. Expanding the priority levels for finer-grained triage