In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight

import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding, EarlyStoppingCallback

from data_utils import DataUtils

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
LEARNING_RATE = 2e-5
EPOCHS = 10
BATCH_SIZE = 16
DS_SPLIT = 0.2
MAX_SEQ_LEN = 256
MIN_SPECIALITY_THRESHOLD = 100
DATASET_PATH = '../data/mtsamples.csv'
MODEL_NAME = 'emilyalsentzer/Bio_ClinicalBERT'

In [3]:
df = pd.read_csv(DATASET_PATH, usecols=['transcription', 'medical_specialty'])    
df.head()

Unnamed: 0,medical_specialty,transcription
0,Allergy / Immunology,"SUBJECTIVE:, This 23-year-old white female pr..."
1,Bariatrics,"PAST MEDICAL HISTORY:, He has difficulty climb..."
2,Bariatrics,"HISTORY OF PRESENT ILLNESS: , I have seen ABC ..."
3,Cardiovascular / Pulmonary,"2-D M-MODE: , ,1. Left atrial enlargement wit..."
4,Cardiovascular / Pulmonary,1. The left ventricular cavity size and wall ...


In [4]:
utils = DataUtils()
df = utils.handle_nulls(df)
df = utils.handle_duplicates(df)

counts = df['medical_specialty'].value_counts()
others = [k for k,v in counts.items() if v < MIN_SPECIALITY_THRESHOLD]
for each_spec in others:
    df.loc[df['medical_specialty']==each_spec,'medical_specialty']=' others' 

counts = df['medical_specialty'].value_counts()
num_classes = len(counts)
print(counts)

===== Null Summary =====
medical_specialty     0
transcription        33
dtype: int64
Dropping rows with missing values...
===== Duplicate Summary =====
Count: 2
Dropping duplicate rows...
medical_specialty
Surgery                          1088
others                           1070
Consult - History and Phy.        516
Cardiovascular / Pulmonary        371
Orthopedic                        355
Radiology                         273
General Medicine                  259
Gastroenterology                  224
Neurology                         223
SOAP / Chart / Progress Notes     166
Urology                           156
Obstetrics / Gynecology           155
Discharge Summary                 108
Name: count, dtype: int64


In [5]:
def preprocess_text(text):
    text = utils.clean_text(text)
    return text

df['text'] = df['transcription'].map(preprocess_text)

le = LabelEncoder()
df["labels"] = le.fit_transform(df["medical_specialty"])

train_df, test_df = train_test_split(df, test_size=DS_SPLIT, stratify=df['labels'], random_state=42)

In [6]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding=True, max_length=MAX_SEQ_LEN)

train_ds = Dataset.from_pandas(train_df)
test_ds = Dataset.from_pandas(test_df)

train_ds = train_ds.map(tokenize, batched=True)
test_ds = test_ds.map(tokenize, batched=True)

train_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
test_ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

Map:   0%|          | 0/3971 [00:00<?, ? examples/s]

Map:   0%|          | 0/993 [00:00<?, ? examples/s]

In [7]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_classes,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1_weighted": f1_score(labels, preds, average="weighted")
    }

In [9]:
class_weights = compute_class_weight(
    "balanced",
    classes=np.arange(num_classes),
    y=train_df["labels"]
)
class_weights = torch.tensor(class_weights, dtype=torch.float32)

In [13]:
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fn(logits, labels.to(logits.device))
        return (loss, outputs) if return_outputs else loss

In [11]:
training_args = TrainingArguments(
    output_dir="./results",         
    load_best_model_at_end=True,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="f1_weighted",
    greater_is_better=True,
    learning_rate=LEARNING_RATE,             
    per_device_train_batch_size=BATCH_SIZE, 
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,             
    weight_decay=0.01,
    logging_dir=None,           
    logging_steps=-1
)

In [14]:
trainer = WeightedTrainer(
    model=model,       
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

trainer.train()

  0%|          | 0/2490 [00:00<?, ?it/s]

  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 1.544062614440918, 'eval_accuracy': 0.36555891238670696, 'eval_f1_weighted': 0.30380607637302576, 'eval_runtime': 8.4259, 'eval_samples_per_second': 117.851, 'eval_steps_per_second': 7.477, 'epoch': 1.0}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 1.373167634010315, 'eval_accuracy': 0.35850956696878145, 'eval_f1_weighted': 0.3086646000200894, 'eval_runtime': 8.4013, 'eval_samples_per_second': 118.196, 'eval_steps_per_second': 7.499, 'epoch': 2.0}


KeyboardInterrupt: 