In [2]:
from transformers import (AutoTokenizer,
                          AutoModelForSequenceClassification, 
                          TrainingArguments, 
                          Trainer)
from datasets import load_dataset

from pathlib import Path
import numpy as np

In [3]:
data_dir = Path("../data/processed/")
ckpt = "distilbert-base-uncased"

In [4]:
data_files = {
    "train": str(data_dir/"wndp-api-data-train.parquet"),
    "val": str(data_dir/"wndp-api-data-val.parquet"),
    "test": str(data_dir/"wndp-api-data-test.parquet"),
}

ds = load_dataset("parquet", data_files=data_files)
ds.set_format("torch")
ds

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9547
    })
    val: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2387
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2106
    })
})

In [5]:
labels = [
    'clinically_healthy',
    'dermatologic_disease',
    'gastrointestinal_disease',
    'hematologic_disease',
    'neurologic_disease',
    'nonspecific',
    'nutritional_disease',
    'ocular_disease',
    'physical_injury',
    'respiratory_disease',
    'urogenital_disease'
]
id2label = {idx:label for idx,label in enumerate(labels)}
label2id = {label:idx for idx,label in enumerate(labels)}

In [6]:
num_labels = len(ds["train"][0]["labels"])
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=True)

In [7]:
sample = ds["train"][0]
sample.keys()

dict_keys(['labels', 'input_ids', 'token_type_ids', 'attention_mask'])

In [8]:
tokenizer.decode(sample["input_ids"])

'[CLS] collision - car ( 2 separate calls in the afternoon about an owl on the side of the road just before the weigh station on the id side of the pass ( though still in wy ). bc and ca went and picked up the bird at about 3 : 15p. found readily. ). trauma - head ; eye trauma - uveitis ( l ). eyes / ears / mouth / nares : left pupil is enlarged and left ear has fresh blood coming out of it. blood in the mouth that appears to be coming from the sinus area. neurologic : head trauma. left pupil is enlarged and left [SEP]'

In [9]:
sample["labels"]

tensor([0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0.])

In [10]:
[id2label[idx] for idx, label in enumerate(sample['labels']) if label == 1.0]

['neurologic_disease', 'ocular_disease', 'physical_injury']

In [11]:
model = AutoModelForSequenceClassification.from_pretrained(
            ckpt,
            num_labels=num_labels,
            problem_type="multi_label_classification",
            id2label=id2label,
            label2id=label2id
        )

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
batch_size = 64
metric_name = "f1"

In [13]:
args = TrainingArguments(
    f"wndp-exp",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-4,
    num_train_epochs=10,
    weight_decay=1e-2,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name
)

In [14]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

In [15]:
ds["train"][0]["labels"]

tensor([0., 0., 0., 0., 1., 0., 0., 1., 1., 0., 0.])

In [16]:
ds["train"]["input_ids"][0]

tensor([  101, 12365,  1011,  2482,  1006,  1016,  3584,  4455,  1999,  1996,
         5027,  2055,  2019, 13547,  2006,  1996,  2217,  1997,  1996,  2346,
         2074,  2077,  1996, 17042,  2276,  2006,  1996,  8909,  2217,  1997,
         1996,  3413,  1006,  2295,  2145,  1999,  1059,  2100,  1007,  1012,
         4647,  1998,  6187,  2253,  1998,  3856,  2039,  1996,  4743,  2012,
         2055,  1017,  1024,  2321,  2361,  1012,  2179, 12192,  1012,  1007,
         1012, 12603,  1011,  2132,  1025,  3239, 12603,  1011, 23068, 20175,
         2483,  1006,  1048,  1007,  1012,  2159,  1013,  5551,  1013,  2677,
         1013,  6583,  6072,  1024,  2187, 11136,  2003, 11792,  1998,  2187,
         4540,  2038,  4840,  2668,  2746,  2041,  1997,  2009,  1012,  2668,
         1999,  1996,  2677,  2008,  3544,  2000,  2022,  2746,  2013,  1996,
         8254,  2271,  2181,  1012, 11265, 10976, 27179,  1024,  2132, 12603,
         1012,  2187, 11136,  2003, 11792,  1998,  2187,   102])

In [17]:
outputs = model(
            input_ids=ds["train"]["input_ids"][0].unsqueeze(0),
            labels=ds["train"][0]["labels"].unsqueeze(0)            
)
outputs.logits

tensor([[-0.0525,  0.1150,  0.0513, -0.1181,  0.0271,  0.1413, -0.0987,  0.0019,
         -0.0383, -0.0723, -0.0325]], grad_fn=<AddmmBackward0>)

In [18]:
trainer = Trainer(
    model,
    args,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [19]:
%%time
trainer.train()

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,No log,0.131074,0.783668,0.84642,0.641391
2,No log,0.103688,0.832922,0.898346,0.702556
3,No log,0.098942,0.849139,0.914972,0.724759
4,0.133900,0.102042,0.852691,0.917778,0.732719
5,0.133900,0.106306,0.852102,0.908435,0.736489
6,0.133900,0.111819,0.857694,0.920284,0.741935
7,0.038300,0.118829,0.855172,0.919756,0.738165
8,0.038300,0.1238,0.859615,0.922573,0.743192
9,0.038300,0.125703,0.857648,0.920145,0.741935
10,0.013200,0.127081,0.855297,0.917393,0.738165


CPU times: total: 1h 16min 37s
Wall time: 1h 51min 12s


TrainOutput(global_step=1500, training_loss=0.061807768185933434, metrics={'train_runtime': 6672.0794, 'train_samples_per_second': 14.309, 'train_steps_per_second': 0.225, 'total_flos': 3162173091786240.0, 'train_loss': 0.061807768185933434, 'epoch': 10.0})

In [20]:
trainer.evaluate()

{'eval_loss': 0.12380031496286392,
 'eval_f1': 0.8596153846153847,
 'eval_roc_auc': 0.9225732951924468,
 'eval_accuracy': 0.7431922915793884,
 'eval_runtime': 32.9476,
 'eval_samples_per_second': 72.448,
 'eval_steps_per_second': 1.153,
 'epoch': 10.0}

In [32]:
sample = "found on the ground by window - breathing hard, eyes not open, couldn't stand up, ants covering him, some spazmotic movements of leg, wing, seemed better today. emaciated fledgling with torticollis. Neurologic: torticollis Legs / Feet / Hocks: not using legs. poor prognosis given age, emaciation, and degree of debilitation"

### example text

sample = "found on the ground by window - breathing hard, eyes not open, couldn't stand up, ants covering him, some spazmotic movements of leg, wing, seemed better today. emaciated fledgling with torticollis. Neurologic: torticollis Legs / Feet / Hocks: not using legs. poor prognosis given age, emaciation, and degree of debilitation"

In [22]:
enc = tokenizer(sample, return_tensors="pt")

In [23]:
enc = {k: v.to(trainer.model.device) for k,v in enc.items()}

In [24]:
outputs = trainer.model(**enc)

In [25]:
outputs

SequenceClassifierOutput(loss=None, logits=tensor([[-7.6655, -7.0422, -6.6422, -7.1536,  5.8147, -8.2998,  3.2121, -5.6077,
         -4.8006, -5.8661, -7.4474]], device='cuda:0',
       grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [26]:
import torch.nn.functional as F

In [27]:
probs = F.sigmoid(outputs.logits.squeeze().detach().cpu())

In [28]:
probs

tensor([4.6848e-04, 8.7343e-04, 1.3025e-03, 7.8141e-04, 9.9703e-01, 2.4851e-04,
        9.6129e-01, 3.6562e-03, 8.1578e-03, 2.8259e-03, 5.8260e-04])

In [29]:
preds = (probs > 0.5).int()

In [30]:
predicted_labels = [id2label[idx] for idx, label in enumerate(preds) if label == 1.0]

In [31]:
predicted_labels

['neurologic_disease', 'nutritional_disease']