In [12]:
# Ensure kagglehub is available and give a clear message if not
try:
    import kagglehub
except ModuleNotFoundError:
    raise ModuleNotFoundError("kagglehub is not installed. Install with `pip install kagglehub` or use the Kaggle API instead.")

# Download latest version
path = kagglehub.dataset_download("tboyle10/medicaltranscriptions")

print("Path to dataset files:", path)


Using Colab cache for faster access to the 'medicaltranscriptions' dataset.
Path to dataset files: /kaggle/input/medicaltranscriptions


In [13]:
#preview dataset
import pandas as pd
orig_df = pd.read_csv(path + "/mtsamples.csv")
orig_df.head()
print("Number of rows in dataset:", len(orig_df))
orig_df.shape

Number of rows in dataset: 4999


(4999, 6)

In [14]:
# 1. Check class distribution
print(orig_df['medical_specialty'].value_counts())


medical_specialty
Surgery                          1103
Consult - History and Phy.        516
Cardiovascular / Pulmonary        372
Orthopedic                        355
Radiology                         273
General Medicine                  259
Gastroenterology                  230
Neurology                         223
SOAP / Chart / Progress Notes     166
Obstetrics / Gynecology           160
Urology                           158
Discharge Summary                 108
ENT - Otolaryngology               98
Neurosurgery                       94
Hematology - Oncology              90
Ophthalmology                      83
Nephrology                         81
Emergency Room Reports             75
Pediatrics - Neonatal              70
Pain Management                    62
Psychiatry / Psychology            53
Office Notes                       51
Podiatry                           47
Dermatology                        29
Dentistry                          27
Cosmetic / Plastic Surgery      

In [15]:
#clean dataset first
#create new df (df) that is subset of orig_df
#drop any empty rows
#drop Unnamed:0 (index col), sample_name?,
#description seems not too informative, so maybe we can drop that as well?
df = orig_df.drop(['Unnamed: 0', 'sample_name', 'description'], axis=1)
df = df[df['transcription'].notna() & df['keywords'].notna()]

#reorder to move medical_specialty col to the right
df = df[['transcription', 'keywords', 'medical_specialty']]
df.head()

# 1. Check class distribution after cleaning
print(df['medical_specialty'].value_counts())

#print num rows
print("Number of rows in dataset:", len(df))

# Combine multiple text fields
df['combined_text'] = df['transcription'].fillna('') + ' ' + df['keywords'].fillna('')
X = df['combined_text']

medical_specialty
Surgery                          1021
Orthopedic                        303
Cardiovascular / Pulmonary        280
Radiology                         251
Consult - History and Phy.        234
Gastroenterology                  195
Neurology                         168
General Medicine                  146
SOAP / Chart / Progress Notes     142
Urology                           140
Obstetrics / Gynecology           130
ENT - Otolaryngology               84
Neurosurgery                       81
Ophthalmology                      79
Discharge Summary                  77
Nephrology                         63
Hematology - Oncology              62
Pain Management                    58
Office Notes                       44
Pediatrics - Neonatal              42
Podiatry                           42
Emergency Room Reports             31
Dentistry                          25
Cosmetic / Plastic Surgery         25
Dermatology                        25
Letters                         

In [16]:
#do some relabeling
RELABEL_RULES = {
    "Cardiovascular/Pulmonary": [
        "troponin", "acute coronary", "ecg", "ekg",
        "cardiac catheterization", "stent", "angiogram"
    ],
    "Orthopedics": [
        "fracture", "tibia", "femur", "cast",
        "weight bearing", "ligament tear"
    ],
    "Neurology": [
        "seizure", "stroke", "cva", "tia",
        "parkinson", "brain mri"
    ],
    "Gastroenterology": [
        "colonoscopy", "gi bleed", "melena",
        "pancreatitis", "cirrhosis"
    ],
    "Surgery": [
        "post operative", "incision",
        "laparoscopic", "surgical repair"
    ]
}

def relabel_specialty(text, current_label):
    text = text.lower()
    for specialty, keywords in RELABEL_RULES.items():
        if any(k in text for k in keywords):
            return specialty
    return current_label

df['specialty_refined'] = df.apply(
    lambda row: relabel_specialty(row['combined_text'], row['medical_specialty']), axis=1
)

df["specialty_refined"].value_counts()


Unnamed: 0_level_0,count
specialty_refined,Unnamed: 1_level_1
Neurology,958
Cardiovascular/Pulmonary,859
Surgery,654
Orthopedics,369
Surgery,163
Consult - History and Phy.,89
Gastroenterology,84
Radiology,84
SOAP / Chart / Progress Notes,76
Cardiovascular / Pulmonary,75


In [17]:
#combine some categories

mapping = {
    'Cosmetic / Plastic Surgery': 'Surgery',
    'Neurosurgery': 'Surgery',
    'Surgery': 'Surgery',
    'ENT - Otolaryngology': 'Surgery',

    'Orthopedic': 'Orthopedics',
    'Podiatry': 'Orthopedics',
    'Physical Medicine - Rehab': 'Orthopedics',
    'Chiropractic': 'Orthopedics',
    'Rheumatology': 'Orthopedics',

    'Cardiovascular / Pulmonary': 'Cardiovascular/Pulmonary',

    'Gastroenterology': 'Gastroenterology',
    'Bariatrics': 'Gastroenterology',

    'Neurology': 'Neurology',
    'Psychiatry / Psychology': 'Neurology',
    'Pain Management': 'Neurology',
    'Sleep Medicine': 'Neurology',

    'Obstetrics / Gynecology': 'Women/Men\'s Reproductive Health',
    'Urology': 'Women/Men\'s Reproductive Health',

    'Hematology - Oncology': 'Kidney & Blood/Oncology',
    'Nephrology': 'Kidney & Blood/Oncology',

    'Radiology': 'Radiology & Diagnostics',
    'Lab Medicine - Pathology': 'Radiology & Diagnostics',

    'General Medicine': 'General Medicine',
    'Consult - History and Phy.': 'General Medicine',
    'SOAP / Chart / Progress Notes': 'General Medicine',
    'Discharge Summary': 'General Medicine',
    'Office Notes': 'General Medicine',
    'Letters': 'General Medicine',
    'Hospice - Palliative Care': 'General Medicine',
    'IME-QME-Work Comp etc.': 'General Medicine',
    'Emergency Room Reports': 'General Medicine',

    'Ophthalmology': 'Other Specialties',
    'Dermatology': 'Other Specialties',
    'Pediatrics - Neonatal': 'Other Specialties',
    'Dentistry': 'Other Specialties',
    'Speech - Language': 'Other Specialties',
    'Endocrinology': 'Other Specialties',
    'Diets and Nutritions': 'Other Specialties',
    'Allergy / Immunology': 'Other Specialties',

}

In [18]:
df['medical_specialty'] = df['medical_specialty'].str.strip()
df['specialty_final'] = df['specialty_refined'].map(mapping)

# Keep existing labels when mapping fails
df['specialty_final'] = df['specialty_final'].fillna(df['specialty_refined'])

# Check results
print(df['specialty_final'].value_counts())

specialty_final
Neurology                         958
Cardiovascular/Pulmonary          859
Surgery                           654
Orthopedics                       369
 Surgery                          163
 Consult - History and Phy.        89
Gastroenterology                   84
 Radiology                         84
 SOAP / Chart / Progress Notes     76
 Cardiovascular / Pulmonary        75
 General Medicine                  59
 Gastroenterology                  47
 Urology                           37
 Orthopedic                        34
 ENT - Otolaryngology              32
 Obstetrics / Gynecology           31
 Pain Management                   28
 Neurology                         24
 Pediatrics - Neonatal             20
 Office Notes                      19
 Ophthalmology                     19
 Nephrology                        18
 Discharge Summary                 16
 Hematology - Oncology             12
 Emergency Room Reports            12
 Dermatology                      

In [19]:
#relabeling might cause issues
counts = df['specialty_final'].value_counts()
rare_labels = counts[counts < 2].index

df.loc[df['specialty_final'].isin(rare_labels), 'specialty_final'] = 'General Medicine'
df['specialty_final'].value_counts().sort_values()

Unnamed: 0_level_0,count
specialty_final,Unnamed: 1_level_1
Physical Medicine - Rehab,2
Lab Medicine - Pathology,2
Allergy / Immunology,3
Bariatrics,3
Endocrinology,3
Rheumatology,4
General Medicine,5
Letters,5
Psychiatry / Psychology,6
Dentistry,6


In [20]:
#split into train and test sets
from sklearn.model_selection import train_test_split

X = df['combined_text']
y = df['specialty_final']

# Correct assignment order: X_train, X_test, y_train, y_test
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# quick sanity-check shapes
print('Shapes ->', X_train.shape, X_test.shape, y_train.shape, y_test.shape)


Shapes -> (3118,) (780,) (3118,) (780,)


In [21]:
#TF-IDF Vectorization
from sklearn.feature_extraction.text import TfidfVectorizer

vectorizer = TfidfVectorizer(
    max_features=10000,
    ngram_range=(1,3),
    stop_words='english'
)

X_train_tfidf = vectorizer.fit_transform(X_train)
X_test_tfidf = vectorizer.transform(X_test)

In [22]:
#BERT model training
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# Ecode labels because BERT cannot handle string labels

label_encoder = LabelEncoder()
df['label'] = label_encoder.fit_transform(df['specialty_final'])

num_labels = len(label_encoder.classes_)
print(num_labels, label_encoder.classes_)

#tokenize
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

train_encodings = tokenizer(list(X_train), truncation=True, padding=True, max_length=512)
test_encodings = tokenizer(list(X_test), truncation=True, padding=True, max_length=512)

#map labels to encoded labels
y_train_encoded = label_encoder.transform(y_train)
y_test_encoded = label_encoder.transform(y_test)

39 [' Allergy / Immunology' ' Bariatrics' ' Cardiovascular / Pulmonary'
 ' Consult - History and Phy.' ' Dentistry' ' Dermatology'
 ' Diets and Nutritions' ' Discharge Summary' ' ENT - Otolaryngology'
 ' Emergency Room Reports' ' Endocrinology' ' Gastroenterology'
 ' General Medicine' ' Hematology - Oncology' ' Lab Medicine - Pathology'
 ' Letters' ' Nephrology' ' Neurology' ' Neurosurgery'
 ' Obstetrics / Gynecology' ' Office Notes' ' Ophthalmology' ' Orthopedic'
 ' Pain Management' ' Pediatrics - Neonatal' ' Physical Medicine - Rehab'
 ' Podiatry' ' Psychiatry / Psychology' ' Radiology' ' Rheumatology'
 ' SOAP / Chart / Progress Notes' ' Surgery' ' Urology'
 'Cardiovascular/Pulmonary' 'Gastroenterology' 'General Medicine'
 'Neurology' 'Orthopedics' 'Surgery']


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [23]:
#create pytorch dataset
class MedicalDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(int(self.labels[idx]))
        return item

    def __len__(self):
        return len(self.labels)


train_dataset = MedicalDataset(train_encodings, y_train_encoded)
test_dataset = MedicalDataset(test_encodings, y_test_encoded)

In [24]:
#load BERT model
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np

#define metrics
def compute_metrics(pred):
    logits, labels = pred
    predictions = np.argmax(logits, axis=-1)
    acc = accuracy_score(labels, predictions)
    return {'accuracy': acc}

    # labels = pred.label_ids
    # preds = pred.predictions.argmax(-1)
    # precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    # acc = accuracy_score(labels, preds)
    # return {
    #     'accuracy': acc,
    #     'f1': f1,
    #     'precision': precision,
    #     'recall': recall
    # }

In [26]:
#training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=8,   # Reduce if out of memory
    per_device_eval_batch_size=8,
    # warmup_steps=500,
    weight_decay=0.01,
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    load_best_model_at_end=True,
    metric_for_best_model='eval_accuracy',
)

In [27]:
#trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [28]:
#train BERT model
trainer.train()

#evaluate
trainer.evaluate()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 3


[34m[1mwandb[0m: You chose "Don't visualize my results"


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,2.050115,0.40641


{'eval_loss': 2.0501153469085693,
 'eval_accuracy': 0.4064102564102564,
 'eval_runtime': 24.0518,
 'eval_samples_per_second': 32.43,
 'eval_steps_per_second': 4.075,
 'epoch': 1.0}

In [29]:
#save BERT results
bert_save_path = "./bert_model"

trainer.save_model(bert_save_path)
tokenizer.save_pretrained(bert_save_path)

('./bert_model/tokenizer_config.json',
 './bert_model/special_tokens_map.json',
 './bert_model/vocab.txt',
 './bert_model/added_tokens.json',
 './bert_model/tokenizer.json')

In [30]:
#Logistic Regression Model
from sklearn.linear_model import LogisticRegression

model = LogisticRegression(
    max_iter=1000,
    class_weight='balanced',
    random_state=42
)

model.fit(X_train_tfidf, y_train)
print("Model training complete.")


Model training complete.


In [31]:
#save logreg results
import pickle
with open("logreg.pkl", "wb") as f:
    pickle.dump({
        "model": model,
        "vectorizer": vectorizer,
        "label_encoder": label_encoder
    }, f)


In [32]:
#ensemble predictions
from scipy.special import softmax

#BERT predictions
bert_preds = trainer.predict(test_dataset)
bert_probs = softmax(bert_preds.predictions, axis=1)

#LogReg predictions
logreg_probs = model.predict_proba(X_test_tfidf)

#average probs
# Note: Need to align class order between models
alpha = 0.6  # trust BERT slightly more
ensemble_probs = alpha * bert_probs + (1 - alpha) * logreg_probs
ensemble_preds = label_encoder.inverse_transform(np.argmax(ensemble_probs, axis=1))


In [33]:
# Evaluate ensemble
from sklearn.metrics import classification_report, accuracy_score
accuracy_ensemble = accuracy_score(y_test, ensemble_preds)
print("\n" + "="*50)
print("ENSEMBLE MODEL RESULTS")
print("="*50)
print(classification_report(y_test, ensemble_preds))
print(f"Accuracy: {accuracy_ensemble:.4f}")


ENSEMBLE MODEL RESULTS
                                precision    recall  f1-score   support

          Allergy / Immunology       0.00      0.00      0.00         1
                    Bariatrics       0.00      0.00      0.00         0
    Cardiovascular / Pulmonary       0.00      0.00      0.00        15
    Consult - History and Phy.       0.00      0.00      0.00        18
                     Dentistry       0.33      1.00      0.50         1
                   Dermatology       0.33      0.50      0.40         2
          Diets and Nutritions       0.00      0.00      0.00         2
             Discharge Summary       0.00      0.00      0.00         3
          ENT - Otolaryngology       0.50      0.17      0.25         6
        Emergency Room Reports       0.00      0.00      0.00         2
                 Endocrinology       0.00      0.00      0.00         1
              Gastroenterology       0.12      0.11      0.12         9
              General Medicine       0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [34]:
import pickle

with open("specialty_classifier.pkl", "wb") as f:
  pickle.dump(
      {
          "model": model,
          "vectorizer": vectorizer,
          "label_encoder": label_encoder,
      },
  f
)
  print("Model saved as specialty_classifier.pkl")

Model saved as specialty_classifier.pkl


In [None]:
!python app.py

2025-12-16 04:46:36.085642: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765860396.106722    3230 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765860396.112726    3230 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765860396.129134    3230 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765860396.129164    3230 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1765860396.129170    3230 computation_placer.cc:177] computation placer alr