In [15]:
import pandas as pd
import json
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import matplotlib.pyplot as plt
import os
from datetime import datetime

# Select the compute device once: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Report the outcome so the choice is visible in the cell output.
if device.type == "cuda":
    print(f"GPU found, using: {torch.cuda.get_device_name(0)}")
else:
    print("GPU not found")

GPU found, using: NVIDIA GeForce RTX 3070


In [16]:
# Checkpoint to fine-tune. Other checkpoints that were tried are kept here for reference:
#   'distilbert-base-uncased'
#   'emilyalsentzer/Bio_ClinicalBERT'
#   'microsoft/BiomedNLP-BiomedBERT-large-uncased-abstract'   (large variant, crashed the GPU)
MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'

# Input files: the specialty dataset and its label <-> id mapping.
CSV = '../../Data/Specialty-Data/specialty_data.csv'
MAPPINGS = '../../Data/Specialty-Data/specialty_data_label_mappings.json'

# Every run gets its own timestamped folder beneath the model's name.
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
ROOT_OUTPUT_DIR = f"./Saved-Models/{MODEL_NAME}/training_run_{current_time}"

# Sub-folders for trainer checkpoints/results, the exported final model, and logs.
TRAINING_OUTPUT_DIRECTORY, MODEL_FINAL_DIRECTORY, LOGGING_DIRECTORY = (
    os.path.join(ROOT_OUTPUT_DIR, leaf) for leaf in ('results', 'final_model', 'logs')
)

# Create the folders up front; exist_ok makes re-running the cell harmless.
for _run_dir in (TRAINING_OUTPUT_DIRECTORY, MODEL_FINAL_DIRECTORY, LOGGING_DIRECTORY):
    os.makedirs(_run_dir, exist_ok=True)

print(f"All outputs will be saved to: {ROOT_OUTPUT_DIR}")

def tokenize_function(examples, tokenizer, max_length=512, padding="max_length"):
    """Tokenize the 'transcription' field of a batch of examples.

    Args:
        examples: Mapping with a 'transcription' entry holding the text(s) to tokenize.
        tokenizer: Callable tokenizer; invoked with the text(s) plus the
            ``padding``, ``truncation`` and ``max_length`` keyword arguments.
        max_length: Length passed to the tokenizer as ``max_length``. Defaults to 512,
            the value that was previously hard-coded.
        padding: Padding strategy passed to the tokenizer. Defaults to "max_length",
            the value that was previously hard-coded.

    Returns:
        Whatever the tokenizer returns for the given text(s).
    """
    return tokenizer(
        examples['transcription'],
        padding=padding,
        truncation=True,  # truncation is always enabled, as before
        max_length=max_length,
    )

def compute_metrics(eval_pred):
    """Score one evaluation pass.

    Args:
        eval_pred: A pair ``(logits, labels)``; it is unpacked positionally.

    Returns:
        dict: ``{"accuracy": ..., "f1": ...}`` where the F1 score uses
        weighted averaging.
    """
    scores, y_true = eval_pred
    # The predicted class is the index of the largest logit along the last axis.
    y_pred = np.argmax(scores, axis=-1)

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average="weighted"),
    }

All outputs will be saved to: ./Saved-Models/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext/training_run_2025-11-08_18-29-45


In [17]:
# Load the preprocessed dataset and the label <-> id mapping.
# Only a missing file is translated into the "run preprocessing first" hint; any other
# failure (malformed CSV/JSON, permissions, ...) propagates unchanged so it is not
# misreported. Raising here stops the cell instead of continuing into a NameError on
# the lines below.
try:
    raw_df = pd.read_csv(CSV)
    with open(MAPPINGS, 'r') as f:
        specialty_and_id_map = json.load(f)
except FileNotFoundError as err:
    raise FileNotFoundError(
        "Data not found, make sure to run the specialty_data_preprocessing.ipynb file in its entirety to retrieve the data"
    ) from err

# Retrieve labels
label_to_id = specialty_and_id_map['label_to_id']
# JSON object keys are always strings, so convert the ids back to int.
id_to_label = {int(k): v for k, v in specialty_and_id_map['id_to_label'].items()}

total_specialties = len(label_to_id)

# Format dataframe for model: keep only the columns used downstream and drop rows
# with a missing value in any of them.
df = raw_df[['transcription', 'medical_specialty', 'label']].dropna()

df

Unnamed: 0,transcription,medical_specialty,label
0,"2-D M-MODE: , ,1. Left atrial enlargement wit...",Cardiovascular / Pulmonary,0
1,1. The left ventricular cavity size and wall ...,Cardiovascular / Pulmonary,0
2,"2-D ECHOCARDIOGRAM,Multiple views of the heart...",Cardiovascular / Pulmonary,0
3,"DESCRIPTION:,1. Normal cardiac chambers size....",Cardiovascular / Pulmonary,0
4,"2-D STUDY,1. Mild aortic stenosis, widely calc...",Cardiovascular / Pulmonary,0
...,...,...,...
1258,"EXAM: , Left heart cath, selective coronary an...",Cardiovascular / Pulmonary,0
1259,"INDICATION:, Acute coronary syndrome.,CONSENT...",Cardiovascular / Pulmonary,0
1260,"ANGINA, is chest pain due to a lack of oxygen ...",Cardiovascular / Pulmonary,0
1261,"INDICATION: , Chest pain.,TYPE OF TEST: , Aden...",Cardiovascular / Pulmonary,0


In [18]:
# 80% Train, 10% Validation, 10% Test
# Step 1: hold out a stratified 20% of the rows.
train_df, test_val_df = train_test_split(df, test_size=0.2, random_state=0, stratify=df['label'])

# Step 2: split that holdout in half (again stratified) into validation and test.
val_df, test_df = train_test_split(
    test_val_df, test_size=0.5, random_state=0, stratify=test_val_df['label']
)

# Wrap each split as a Dataset; the index is reset and dropped before conversion.
ds = DatasetDict({
    split_name: Dataset.from_pandas(split_df.reset_index(drop=True))
    for split_name, split_df in (
        ('train', train_df),
        ('validation', val_df),
        ('test', test_df),
    )
})

ds

DatasetDict({
    train: Dataset({
        features: ['transcription', 'medical_specialty', 'label'],
        num_rows: 1010
    })
    validation: Dataset({
        features: ['transcription', 'medical_specialty', 'label'],
        num_rows: 126
    })
    test: Dataset({
        features: ['transcription', 'medical_specialty', 'label'],
        num_rows: 127
    })
})

In [19]:
# Tokenizer that matches the selected checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Tokenize every split in batches, then drop the raw-text columns so that only
# the label and the tokenizer outputs remain.
tokenized_ds = ds.map(
    lambda batch: tokenize_function(batch, tokenizer),
    batched=True,
).remove_columns(['transcription', 'medical_specialty'])

tokenized_ds

Map:   0%|          | 0/1010 [00:00<?, ? examples/s]

Map:   0%|          | 0/126 [00:00<?, ? examples/s]

Map:   0%|          | 0/127 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1010
    })
    validation: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 126
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 127
    })
})

In [20]:
# Load the checkpoint as a sequence classifier with one output per specialty.
# use_safetensors=True forces safetensors weights rather than the insecure load path;
# it is needed for Bio_ClinicalBERT and may not be necessary on other models.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    use_safetensors=True,
    num_labels=total_specialties,
    id2label=id_to_label,
    label2id=label_to_id
)

training_args = TrainingArguments(
    output_dir=TRAINING_OUTPUT_DIRECTORY,
    num_train_epochs=7,
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Warm up over the first 10% of training rather than a fixed step count.
    # With the 1010-row train split, 7 epochs at batch size 16 is only about 450
    # optimizer steps on one device, so the previous warmup_steps=500 outlasted the
    # whole run and the learning rate never reached `learning_rate`.
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_dir=LOGGING_DIRECTORY,
    logging_steps=50,
    # Evaluate and checkpoint every epoch so the best epoch (by validation F1)
    # can be restored when training finishes.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    report_to="none",
)

# The validation split drives per-epoch evaluation; the test split is kept untouched here.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["validation"],
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

OSError: The paging file is too small for this operation to complete. (os error 1455)

In [None]:
print("Training Model")
trainer.train()
print("Training Complete")

print("-------------------------------------")

# Final scoring uses the test split, so the messages say "test" rather than
# "validation" (the validation split is what the trainer evaluates during training).
print("Evaluating on test dataset")
test_results = trainer.evaluate(tokenized_ds["test"])

print("Test results")
print(test_results)

# Persist the metrics alongside the trainer's other outputs for this run.
with open(os.path.join(TRAINING_OUTPUT_DIRECTORY, "test_results.json"), 'w') as f:
    json.dump(test_results, f, indent=4)

print(f"Saving final model to {MODEL_FINAL_DIRECTORY}")
trainer.save_model(MODEL_FINAL_DIRECTORY)
print("Model saved")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None}.


Training Model
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "c:\Users\spenc\anaconda3\envs\interview_env\Lib\site-packages\IPython\core\interactiveshell.py", line 2194, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\spenc\anaconda3\envs\interview_env\Lib\site-packages\IPython\core\ultratb.py", line 1179, in structured_traceback
    return FormattedTB.structured_traceback(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\spenc\anaconda3\envs\interview_env\Lib\site-packages\IPython\core\ultratb.py", line 1050, in structured_traceback
    return VerboseTB.structured_traceback(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\spenc\anaconda3\envs\interview_env\Lib\site-packages\IPython\core\ultratb.py", line 858, in structured_traceback
    formatted_exceptions: list[list[str]] = self.format_exception_as_a_whole(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File