BioClinicalBERT is a BERT model pretrained on:

- PubMed biomedical articles

- MIMIC-III clinical notes (real clinical text)

This makes it especially suitable for symptomâ€“disease classification, often outperforming general-purpose models like DistilBERT.

In [None]:
!pip install -U transformers accelerate

In [2]:
import pandas as pd
import numpy as np
import torch

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)


2026-01-23 18:11:32.800682: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769191893.111505      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769191893.205313      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769191893.980762      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769191893.980815      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769191893.980818      55 computation_placer.cc:177] computation placer alr

In [3]:
df = pd.read_csv(
    "/kaggle/input/preprocessed-dataset/preprocessed_symptom_disease_dataset.csv"
)

print(df.shape)
df.head()


(1988, 3)


Unnamed: 0,symptoms,disease,symptoms_clean
0,i have been experiencing a skin rash on my arm...,psoriasis,i have been experiencing a skin rash on my arm...
1,"my skin has been peeling, especially on my kne...",psoriasis,"my skin has been peeling, especially on my kne..."
2,i have been experiencing joint pain in my fing...,psoriasis,i have been experiencing joint pain in my fing...
3,"there is a silver like dusting on my skin, esp...",psoriasis,"there is a silver like dusting on my skin, esp..."
4,"my nails have small dents or pits in them, and...",psoriasis,"my nails have small dents or pits in them, and..."


In [4]:
# Encode labels
label_encoder = LabelEncoder()
df["label_id"] = label_encoder.fit_transform(df["disease"])

num_labels = df["label_id"].nunique()
print("Number of classes:", num_labels)


Number of classes: 24


In [5]:
# Train Test split
X_train, X_test, y_train, y_test = train_test_split(
    df["symptoms_clean"],
    df["label_id"],
    test_size=0.2,
    random_state=42,
    stratify=df["label_id"]
)


In [6]:
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

In [7]:
MAX_LEN = 64

def tokenize_text(texts):
    return tokenizer(
        texts.tolist(),
        padding=True,
        truncation=True,
        max_length=MAX_LEN
    )

train_encodings = tokenize_text(X_train)
test_encodings = tokenize_text(X_test)


In [8]:
class SymptomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels.reset_index(drop=True)

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


In [9]:
train_dataset = SymptomDataset(train_encodings, y_train)
test_dataset = SymptomDataset(test_encodings, y_test)


In [10]:
# Load BioClinicalBERT model
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels
)


pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Handle class imbalance (weighted loss)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_train),
    y=y_train
)

class_weights = torch.tensor(class_weights, dtype=torch.float)


In [12]:
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits

        loss_fct = torch.nn.CrossEntropyLoss(
            weight=class_weights.to(logits.device)
        )
        loss = loss_fct(logits, labels)

        return (loss, outputs) if return_outputs else loss


In [13]:
training_args = TrainingArguments(
    output_dir="./bioclinicalbert_results",
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_strategy="steps",
    logging_steps=50,
    report_to="none"
)


In [14]:
trainer = WeightedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=tokenizer
)

trainer.train()




Step,Training Loss,Validation Loss
500,0.6065,0.524369


TrainOutput(global_step=796, training_loss=1.1974157292639191, metrics={'train_runtime': 2026.8848, 'train_samples_per_second': 3.138, 'train_steps_per_second': 0.393, 'total_flos': 209214606827520.0, 'train_loss': 1.1974157292639191, 'epoch': 4.0})

In [15]:
predictions = trainer.predict(test_dataset)
y_pred = predictions.predictions.argmax(axis=1)

print(classification_report(
    y_test,
    y_pred,
    target_names=label_encoder.classes_
))


                                 precision    recall  f1-score   support

                           acne       1.00      1.00      1.00         9
                        allergy       0.94      0.94      0.94        18
                      arthritis       0.94      1.00      0.97        17
               bronchial asthma       1.00      1.00      1.00        18
           cervical spondylosis       1.00      1.00      1.00        17
                    chicken pox       0.89      0.89      0.89        18
                    common cold       0.94      0.94      0.94        17
                         dengue       1.00      0.88      0.94        17
                       diabetes       1.00      0.94      0.97        18
          dimorphic hemorrhoids       1.00      1.00      1.00         8
                  drug reaction       0.94      0.94      0.94        18
               fungal infection       1.00      1.00      1.00        17
gastroesophageal reflux disease       1.00      1.

In [16]:
SAVE_DIR = "./bioclinicalbert_chatbot"

trainer.model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)


('./bioclinicalbert_chatbot/tokenizer_config.json',
 './bioclinicalbert_chatbot/special_tokens_map.json',
 './bioclinicalbert_chatbot/vocab.txt',
 './bioclinicalbert_chatbot/added_tokens.json',
 './bioclinicalbert_chatbot/tokenizer.json')

In [17]:
import joblib

joblib.dump(label_encoder, f"{SAVE_DIR}/label_encoder.pkl")


['./bioclinicalbert_chatbot/label_encoder.pkl']