In [None]:
# One-time setup: uncomment and run to upgrade `accelerate`, then restart the kernel.
# %pip install --upgrade "accelerate>=0.26.0"

Note: you may need to restart the kernel to use updated packages.


In [12]:
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import classification_report

In [13]:
# Report whether PyTorch can see a CUDA device (True is expected on a GPU machine).
cuda_available = torch.cuda.is_available()
print(cuda_available)

True


In [3]:
# Read the train and test splits from the CSV files under data/.
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ("data/medical_tc_train.csv", "data/medical_tc_test.csv")
)

In [4]:
# Load the tokenizer that belongs to the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
)

In [5]:
# Tokenize the abstracts of both splits with identical settings.
def _encode_abstracts(frame):
    """Tokenize the 'medical_abstract' column of ``frame``.

    Sequences are truncated to at most 128 tokens and padded to a common length.
    """
    return tokenizer(
        list(frame['medical_abstract']),
        truncation=True,
        padding=True,
        max_length=128,
    )


train_encodings = _encode_abstracts(train_df)
test_encodings = _encode_abstracts(test_df)

In [6]:
class MedicalDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with one label per example.

    Args:
        encodings: Mapping (anything exposing ``.items()``) from feature name to
            an indexable sequence holding one entry per example.
        labels: Indexable sequence holding one label per example.

    Raises:
        ValueError: If any encoding field has a different number of entries
            than ``labels``.
    """

    def __init__(self, encodings, labels):
        # __len__ relies on labels alone, so a length mismatch would otherwise
        # either silently ignore trailing encodings or raise IndexError in the
        # middle of iteration. Fail fast with a clear message instead.
        expected = len(labels)
        for key, val in encodings.items():
            if len(val) != expected:
                raise ValueError(
                    f"encodings[{key!r}] has {len(val)} entries but labels has {expected}"
                )
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return the example at ``idx`` as a dict of tensors.

        The dict holds one tensor per encoding field plus the label under
        the ``'labels'`` key.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, i.e. ``len(labels)``."""
        return len(self.labels)

# Pull the class labels out of each split as plain Python lists.
train_labels, test_labels = (
    frame['condition_label'].tolist() for frame in (train_df, test_df)
)

# Wrap encodings and labels so the Trainer can index individual examples.
train_dataset = MedicalDataset(encodings=train_encodings, labels=train_labels)
test_dataset = MedicalDataset(encodings=test_encodings, labels=test_labels)

In [7]:
# Size the classification head from the largest training label id.
# NOTE(review): the saved classification report lists classes 1-5, so the labels
# look 1-based and this allocates an extra class 0 that never occurs — confirm
# whether remapping to 0-based ids is wanted.
num_labels = 1 + max(train_labels)
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=num_labels,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Training configuration: 3 epochs, batch size 16, evaluate and checkpoint once
# per epoch, keep a single checkpoint and reload the best one when training ends.
# TrainingArguments is already imported in the imports cell at the top of the
# notebook, so it is not re-imported here.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy` — confirm against the installed version before upgrading.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    logging_dir='./logs',
    load_best_model_at_end=True
)



In [9]:
# Assemble the Trainer from the model, the training configuration and both splits.
# NOTE(review): the test split doubles as the evaluation set while training_args
# sets load_best_model_at_end=True, so checkpoint selection sees the test data —
# consider a separate validation split.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
trainer = Trainer(**trainer_kwargs)

In [11]:
# Fine-tune the model.
trainer.train()

# Run the trained model on the test split and pick the highest-scoring class
# for every example.
predictions = trainer.predict(test_dataset)
logits = torch.tensor(predictions.predictions)
preds = logits.argmax(dim=1)

# Print per-class precision / recall / F1 against the true test labels.
print(classification_report(y_true=test_labels, y_pred=preds.tolist()))

 23%|██▎       | 500/2166 [03:15<10:55,  2.54it/s]

{'loss': 1.0608, 'grad_norm': 10.147701263427734, 'learning_rate': 3.845798707294552e-05, 'epoch': 0.69}


                                                  
 33%|███▎      | 722/2166 [05:03<09:11,  2.62it/s]

{'eval_loss': 0.9103921055793762, 'eval_runtime': 20.3389, 'eval_samples_per_second': 141.994, 'eval_steps_per_second': 8.899, 'epoch': 1.0}


 46%|████▌     | 1000/2166 [06:54<07:40,  2.53it/s] 

{'loss': 0.8651, 'grad_norm': 4.523875713348389, 'learning_rate': 2.6915974145891044e-05, 'epoch': 1.39}


                                                   
 67%|██████▋   | 1444/2166 [10:10<04:36,  2.61it/s]

{'eval_loss': 0.8782744407653809, 'eval_runtime': 20.3648, 'eval_samples_per_second': 141.813, 'eval_steps_per_second': 8.888, 'epoch': 2.0}


 69%|██████▉   | 1500/2166 [10:34<04:23,  2.53it/s]  

{'loss': 0.7841, 'grad_norm': 7.045168399810791, 'learning_rate': 1.5373961218836565e-05, 'epoch': 2.08}


 92%|█████████▏| 2000/2166 [13:51<01:05,  2.53it/s]

{'loss': 0.6709, 'grad_norm': 4.823151111602783, 'learning_rate': 3.831948291782087e-06, 'epoch': 2.77}


                                                   
100%|██████████| 2166/2166 [15:19<00:00,  2.59it/s]

{'eval_loss': 0.8954036831855774, 'eval_runtime': 20.4585, 'eval_samples_per_second': 141.164, 'eval_steps_per_second': 8.847, 'epoch': 3.0}


100%|██████████| 2166/2166 [15:22<00:00,  2.35it/s]


{'train_runtime': 922.5237, 'train_samples_per_second': 37.56, 'train_steps_per_second': 2.348, 'train_loss': 0.8312384727809128, 'epoch': 3.0}


100%|██████████| 181/181 [00:20<00:00,  8.80it/s]


              precision    recall  f1-score   support

           1       0.72      0.77      0.74       633
           2       0.50      0.75      0.60       299
           3       0.60      0.63      0.61       385
           4       0.66      0.85      0.74       610
           5       0.65      0.38      0.48       961

    accuracy                           0.64      2888
   macro avg       0.62      0.68      0.64      2888
weighted avg       0.64      0.64      0.62      2888

