In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchinfo import summary
import transformers
from sklearn.preprocessing import LabelEncoder
import kagglehub
import pandas as pd
import numpy as np
import math
import matplotlib.pyplot as plt
from pathlib import Path
from os import cpu_count

In [2]:
# Print tensors in full instead of summarising large ones with "...".
torch.set_printoptions(threshold=math.inf)
# Widen pandas' column display so long conversation strings are not cut off.
pd.options.display.max_colwidth = 1000

Set a manual random seed for reproducibility (see the `SEED` cell that follows the dataset preview below).

In [3]:
# Download the Kaggle dataset (kagglehub returns the local directory) and
# point at the single CSV file it contains.
path = Path(
    kagglehub.dataset_download("artemminiailo/medicalconversations2disease"),
    "medical_conversations.csv",
)
# Load the raw table and preview its first rows.
df = pd.read_csv(path)
df.head()

Unnamed: 0,conversations,disease
0,"User: I’ve been sneezing a lot today and my nose feels congested. </s> Bot: That sounds like it could be an allergy. Do you know what might be triggering it? </s> User: I’m not sure. Maybe pollen? </s> Bot: Pollen is a common allergen. Have you had any other symptoms, like itchy eyes or a sore throat? </s> User: Yes, my eyes are itchy too. </s> Bot: Sounds like classic hay fever. Over-the-counter antihistamines could help relieve the symptoms. </s>",allergy
1,"User: I’ve developed a rash after eating some strawberries. </s> Bot: That could be an allergic reaction. Have you had strawberries before without any issues? </s> User: Yes, but this time I ate a lot of them. </s> Bot: It could be a food allergy. If the rash worsens or if you have difficulty breathing, it’s important to seek medical attention immediately. </s>",allergy
2,"User: My eyes are swollen and itchy, and I can’t stop sneezing. </s> Bot: That could be an allergy, possibly triggered by dust or pollen. Do you have a history of allergies? </s> User: Yes, I usually get allergies in the spring. </s> Bot: Spring allergies are often due to pollen. You might want to take an antihistamine to help with the symptoms. </s>",allergy
3,User: I’ve been getting headaches and a stuffy nose for a few days. </s> Bot: Do you think it might be an allergy or possibly a cold? </s> User: I think it’s allergies because I’ve been around pets a lot. </s> Bot: Pet allergies are common. You could try using allergy medication to relieve your symptoms. </s>,allergy
4,"User: Every time I eat nuts, my mouth itches. </s> Bot: That sounds like a food allergy. Have you ever had a more severe reaction? </s> User: No, it’s just the itching. </s> Bot: It might still be a good idea to avoid nuts, especially if you start experiencing more serious symptoms, like swelling or difficulty breathing. </s>",allergy


In [4]:
# Seed every random number generator the notebook draws from so that the
# train/test split and the training run are repeatable across kernel restarts.
SEED = 42
np.random.seed(SEED)  # get_datasets() samples the split with np.random.choice
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently ignored when CUDA is unavailable

In [5]:
# Directory that save_model() writes trained weights into.
MODELS_BASEDIR = Path("saved-models")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force CPU execution, assign torch.device("cpu") here instead.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face checkpoint used for both the tokenizer and the model.
MODEL_NAME = "bert-base-uncased"

# Training hyper-parameters: EPOCHS is only an upper bound because train()
# also applies early stopping.
EPOCHS = 3_000
BATCH_SIZE = 32
LEARNING_RATE = 5e-4

DEVICE.type

'cuda'

In [6]:
# Pretrained encoder and its matching tokenizer, both loaded from MODEL_NAME.
bert = transformers.BertModel.from_pretrained(MODEL_NAME)
tokenizer = transformers.BertTokenizer.from_pretrained(MODEL_NAME)
# CSV consumed by get_datasets(); kagglehub returns the local dataset directory.
dataset_path = Path(
    kagglehub.dataset_download("artemminiailo/medicalconversations2disease"),
    "medical_conversations.csv",
)

2024-12-30 09:32:53.055712: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-30 09:32:53.067311: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1735543973.087228  517103 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1735543973.093396  517103 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Custom Dataset that returns tokenized text and labels.

In [7]:
class LabeledText(Dataset):
    """Dataset of tokenized texts paired with integer-encoded labels.

    The whole frame is tokenized once at construction time and both the token
    tensors and the label tensor are moved to ``DEVICE``.

    Args:
        df: Frame with a ``'text'`` column (strings) and a ``'label'`` column.
        tokenizer: Callable tokenizer returning a mapping with the keys
            ``'input_ids'``, ``'token_type_ids'`` and ``'attention_mask'``.
    """

    # Order in which the token tensors are returned by __getitem__.
    _FEATURE_KEYS = ('input_ids', 'token_type_ids', 'attention_mask')

    def __init__(self, df, tokenizer):
        texts = df['text'].tolist()
        encoded_texts = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        self.tokens = encoded_texts.to(DEVICE)

        # Map the label values of this frame onto consecutive integer ids.
        self.label_encoder = LabelEncoder()
        label_ids = self.label_encoder.fit_transform(df['label'])
        self.labels = torch.tensor(label_ids).to(DEVICE)
        self.classes = self.label_encoder.classes_

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.tokens['input_ids'])

    def __getitem__(self, index):
        """Return ``((input_ids, token_type_ids, attention_mask), label)`` for one example."""
        features = tuple(self.tokens[key][index] for key in self._FEATURE_KEYS)
        return features, self.labels[index]

In [None]:
def get_datasets(path, tokenizer, train_set_size=0.9, seed=None):
    """Build class-stratified train and test datasets from the conversations CSV.

    Args:
        path: Location of the CSV file. Its last two columns must be the
            conversation text and the disease label, in that order.
        tokenizer: Tokenizer forwarded to ``LabeledText``.
        train_set_size: Fraction of every class assigned to the training
            split; the per-class count is rounded down.
        seed: Optional integer seed for the split. ``None`` keeps drawing from
            NumPy's global random state.

    Returns:
        ``(train_set, test_set)``: two ``LabeledText`` datasets that share a
        single label encoding.

    Raises:
        ValueError: If either split would end up empty.
    """
    # The file carries a leading unnamed row-index column before the text and
    # label columns. Select the last two columns by position instead of forcing
    # two names onto three fields, which would shift the data by one column.
    df = pd.read_csv(path).iloc[:, -2:].copy()
    df.columns = ['text', 'label']

    rng = np.random if seed is None else np.random.RandomState(seed)

    # Sample the same fraction from every class so both splits keep the class distribution.
    train_set_idxs = []
    test_set_idxs = []
    for idxs in df.groupby('label').groups.values():
        n_train = math.floor(len(idxs) * train_set_size)
        train_idxs = rng.choice(idxs, n_train, replace=False)
        test_idxs = np.setdiff1d(idxs, train_idxs)
        train_set_idxs.extend(train_idxs.tolist())
        test_set_idxs.extend(test_idxs.tolist())

    if not train_set_idxs or not test_set_idxs:
        raise ValueError(
            "train/test split is empty; check the dataset size and train_set_size"
        )

    # groupby(...).groups holds index labels, so select rows by label.
    train_df = df.loc[train_set_idxs]
    test_df = df.loc[test_set_idxs]
    train_set = LabeledText(train_df, tokenizer)
    test_set = LabeledText(test_df, tokenizer)

    # LabeledText fits a separate encoder per split, so a class that is missing
    # from one split would shift the integer ids. Re-encode both splits with one
    # encoder fitted on the full label column so that ids always agree.
    shared_encoder = LabelEncoder().fit(df['label'])
    for split, frame in ((train_set, train_df), (test_set, test_df)):
        split.label_encoder = shared_encoder
        split.classes = shared_encoder.classes_
        split.labels = torch.tensor(shared_encoder.transform(frame['label'])).to(DEVICE)

    return train_set, test_set

In [None]:
# Smoke-test the split: build both datasets, keep only the held-out one and
# display how many examples it contains.
_, test_set = get_datasets(dataset_path, tokenizer)
len(test_set)

In [9]:
class BertClassifier(nn.Module):
    """Frozen BERT encoder followed by dropout and one linear classification layer.

    Args:
        bert: Encoder module exposing ``pooler.dense.out_features`` and
            returning a mapping with a ``'pooler_output'`` entry.
        num_classes: Number of output logits.
    """

    def __init__(self, bert, num_classes):
        super().__init__()
        self.bert = bert
        # Freeze the encoder: only the linear head below receives gradients.
        for weight in self.bert.parameters():
            weight.requires_grad = False

        hidden_size = self.bert.pooler.dense.out_features
        self.dropout = nn.Dropout(0.3)
        self.output = nn.Linear(in_features=hidden_size, out_features=num_classes, bias=True)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return the class logits computed from the encoder's pooled output."""
        encoded = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        pooled = encoded['pooler_output']
        return self.output(self.dropout(pooled))

Class that stops training when the validation loss stops decreasing.

In [10]:
class EarlyStopper:
    """Track the best validation loss and signal when training should stop.

    Args:
        patience: Number of non-improving checks tolerated before stopping.
        min_delta: Margin above the best loss that a value must exceed to
            count as a non-improving check.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float("inf")

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True once patience is exhausted."""
        # A new best loss resets the patience counter.
        if validation_loss < self.min_validation_loss:
            self.counter = 0
            self.min_validation_loss = validation_loss
            return False

        # Values within min_delta of the best loss neither reset nor consume patience.
        tolerance_limit = self.min_validation_loss + self.min_delta
        if not (validation_loss > tolerance_limit):
            return False

        self.counter += 1
        print('Patience left:', self.patience - self.counter)
        if self.counter >= self.patience:
            return True
        return False

In [None]:
def train_step(model, loader, loss_fn, optimizer, set_size):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        model: Module called with ``input_ids``, ``token_type_ids``,
            ``attention_mask`` and ``labels``; its output must expose ``.loss``
            and ``.logits``.
        loader: Iterable of ``((input_ids, token_type_ids, attention_mask), y)``
            batches that supports ``len()``.
        loss_fn: Unused, because the model computes its own loss from
            ``labels``; kept so existing callers keep working.
        optimizer: Optimizer stepped once per batch.
        set_size: Total number of examples, used as the accuracy denominator.

    Returns:
        ``(train_loss, train_accuracy)`` as Python floats: the mean per-batch
        loss and the fraction of correctly classified examples.
    """
    model.train()
    total_loss, total_correct = 0.0, 0
    for (input_ids, token_type_ids, attention_masks), y in loader:
        outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_masks, labels=y)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Reduce on the tensor and convert to Python numbers once per batch;
        # the builtin sum() would iterate over the tensor element by element
        # and keep the running totals on the device.
        total_loss += loss.item()
        total_correct += (outputs.logits.argmax(dim=1) == y).sum().item()

    train_loss = total_loss / len(loader)
    train_accuracy = total_correct / set_size
    print(f"Train loss: {train_loss:.4f} | Train accuracy: {train_accuracy:.4f}")

    return train_loss, train_accuracy


def test_step(model, loader, loss_fn, set_size):
    model.eval()
    total_loss, total_correct = 0, 0

    with torch.inference_mode():
        for (input_ids, token_type_ids, attention_masks), y in loader:
            outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_masks, labels=y)
            total_loss += outputs.loss
            total_correct += sum(outputs.logits.argmax(dim=1) == y)
        
        test_loss = total_loss / len(loader)
        test_accuracy = total_correct / set_size
        print(f"Test loss: {test_loss:.4f} | Test accuracy: {test_accuracy:.4f}")

    return test_loss, test_accuracy

In [12]:
def save_model(model: nn.Module, model_name):
    """Write ``model``'s state dict to ``MODELS_BASEDIR / model_name``.

    The target directory is created first if it does not exist yet.
    """
    destination = MODELS_BASEDIR / model_name
    MODELS_BASEDIR.mkdir(parents=True, exist_ok=True)
    print(f"Saving model to: {destination}")
    weights = model.state_dict()
    torch.save(obj=weights, f=destination)

In [13]:
def train():
    """Fine-tune the classification head on top of a frozen BERT and save the weights.

    Builds the train/test datasets, freezes the base encoder of a
    ``BertForSequenceClassification`` model, trains for up to ``EPOCHS`` epochs
    with early stopping on the test loss, and finally saves the state dict via
    ``save_model``.
    """
    train_data, test_data = get_datasets(dataset_path, tokenizer)
    # get_datasets() emits the training indices class by class, so without
    # shuffling each batch would contain (almost) a single class.
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

    # Uses the Hugging Face sequence-classification head; BertClassifier is an
    # alternative head that is not used here.
    model = transformers.BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(train_data.classes))
    for param in model.base_model.parameters():
        param.requires_grad = False
    print(summary(model))

    # Move the model to its device before the optimizer captures the parameters.
    model.to(DEVICE)

    # Only the unfrozen parameters (the classification head) are optimised.
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    optimizer = Adam(trainable_params, lr=LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()
    loss_fn.to(DEVICE)
    early_stopper = EarlyStopper(patience=20)

    for epoch in range(1, EPOCHS + 1):
        print(f"Epoch {epoch}\n-------------------------------")
        train_loss, train_accuracy = train_step(model, train_loader, loss_fn, optimizer, len(train_data))
        test_loss, test_accuracy = test_step(model, test_loader, loss_fn, len(test_data))

        if early_stopper.early_stop(test_loss):
            print("Early stopping")
            break
        torch.cuda.empty_cache()

    save_model(model, "bert-disease-predictor2.pt")

In [None]:
# Start the full run: train() fits the model with early stopping and then
# saves the resulting weights.
train()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Layer (type:depth-idx)                                       Param #
BertForSequenceClassification                                --
├─BertModel: 1-1                                             --
│    └─BertEmbeddings: 2-1                                   --
│    │    └─Embedding: 3-1                                   (23,440,896)
│    │    └─Embedding: 3-2                                   (393,216)
│    │    └─Embedding: 3-3                                   (1,536)
│    │    └─LayerNorm: 3-4                                   (1,536)
│    │    └─Dropout: 3-5                                     --
│    └─BertEncoder: 2-2                                      --
│    │    └─ModuleList: 3-6                                  (85,054,464)
│    └─BertPooler: 2-3                                       --
│    │    └─Linear: 3-7                                      (590,592)
│    │    └─Tanh: 3-8                                        --
├─Dropout: 1-2                                         