from pathlib import Path

def read_data_split(file):
    texts = []
    with open(file,'r') as f:
        for line in f:
            texts.append(line.strip())
    return texts

def read_label_split(file):
    labels = []
    with open(file,'r') as f:
        for line in f:
            if line == '1\n':
                labels.append(1)
            else:
                labels.append(0)
    return labels

train_texts = read_data_split('../data/splits/train2')
train_labels = read_label_split('../data/splits/train_label2')
val_texts = read_data_split('../data/splits/valid2')
val_labels = read_label_split('../data/splits/valid_label2')
test_texts = read_data_split('../data/splits/test2')
test_labels = read_label_split('../data/splits/test_label2')

In [1]:
from datasets import load_dataset

In [2]:
input_dataset = load_dataset('text', data_files={'train': '../data/splits/train', 'valid': '../data/splits/valid', 'test': '../data/splits/test'})
label_dataset = load_dataset('text', data_files={'train': '../data/splits/train_label', 'valid': '../data/splits/valid_label', 'test': '../data/splits/test_label'})

Using custom data configuration default
Reusing dataset text (/home/ubuntu/.cache/huggingface/datasets/text/default-df6fdfd0ab6b35f0/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)
Using custom data configuration default
Reusing dataset text (/home/ubuntu/.cache/huggingface/datasets/text/default-1d8aa2beeec09988/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)


In [3]:
from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
#tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

In [4]:
def encode(examples):
     return tokenizer(examples['text'], truncation=True, padding='max_length', max_length=512)
    
input_dataset = input_dataset.map(encode, batched=True)

Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/text/default-df6fdfd0ab6b35f0/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-ce35e4dba61492b4.arrow
Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/text/default-df6fdfd0ab6b35f0/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-4315d97384efbefc.arrow
Loading cached processed dataset at /home/ubuntu/.cache/huggingface/datasets/text/default-df6fdfd0ab6b35f0/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab/cache-e242a4782e1c64f1.arrow


In [5]:
# training model on tokenized and split data
import torch

class Dataset(torch.utils.data.Dataset):
    def __init__(self, inputs, labels):
        self.inputs = inputs
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val) for key, val in self.inputs[idx].items() if key != 'text'}
        item['labels'] = torch.tensor(int(self.labels[idx]['text']))
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = Dataset(input_dataset['train'], label_dataset['train'])
val_dataset = Dataset(input_dataset['valid'], label_dataset['valid'])
test_dataset = Dataset(input_dataset['test'], label_dataset['test'])

In [6]:
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classi

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    roc = roc_auc_score(labels, pred.predictions[:,-1])
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'auroc': roc,
    }


training_args = TrainingArguments(
    output_dir='/home/ubuntu/data/results_lr4e-5',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=16,   # batch size for evaluation
    warmup_steps=10000,                # number of warmup steps for learning rate scheduler
    weight_decay=0.1,               # strength of weight decay
    logging_dir='/home/ubuntu/logs_lr4e-5',            # directory for storing logs
    logging_steps=100,
    evaluation_strategy='steps',
    learning_rate=4e-5,
    fp16=True,
    save_total_limit=5,
    eval_steps=2000,
    save_steps=2000,
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,             # evaluation dataset
    compute_metrics=compute_metrics,
)

trainer.train()

Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall,Auroc,Runtime,Samples Per Second
2000,0.5535,0.556295,0.715879,0.594966,0.662338,0.540035,0.766214,125.2654,142.426
