In [None]:
from transformers import (AutoTokenizer,
                          AutoModelForSequenceClassification, 
                          TrainingArguments, 
                          Trainer)
from datasets import load_dataset

from pathlib import Path
import numpy as np

In [None]:
# Hugging Face Hub checkpoint used for both the tokenizer and the model.
ckpt = "bert-base-uncased"
# Directory holding the pre-split parquet files.
data_dir = Path("../data/processed/")

In [None]:
# One parquet file per split; keys become the DatasetDict split names.
data_files = dict(
    train=str(data_dir / "wndp-api-data-train.parquet"),
    val=str(data_dir / "wndp-api-data-val.parquet"),
    test=str(data_dir / "wndp-api-data-test.parquet"),
)

ds = load_dataset("parquet", data_files=data_files)
# Return torch tensors when rows/columns are indexed (mutates ds in place).
ds.set_format("torch")
ds

In [None]:
# Class names for the multi-label head; list position is the class id.
labels = [
    'clinically_healthy',
    'dermatologic_disease',
    'gastrointestinal_disease',
    'hematologic_disease',
    'neurologic_disease',
    'nonspecific',
    'nutritional_disease',
    'ocular_disease',
    'physical_injury',
    'respiratory_disease',
    'urogenital_disease'
]
# Bidirectional lookups between class id and class name.
id2label = dict(enumerate(labels))
label2id = {name: class_id for class_id, name in id2label.items()}

In [None]:
# Width of the label vector stored with each training example.
num_labels = len(ds["train"][0]["labels"])
# The hard-coded class list must describe exactly that many classes; otherwise
# id2label/label2id would not line up with the model's output layer.
assert num_labels == len(labels), (
    f"dataset has {num_labels} label entries but `labels` lists {len(labels)} classes"
)
tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=True)

In [None]:
sample = ds["train"][0]
sample.keys()

In [None]:
# Map the stored token ids back to text to sanity-check the preprocessing.
tokenizer.decode(sample["input_ids"])

In [None]:
sample["labels"]

In [None]:
# Names of the classes switched on (value 1.0) in this example's target vector.
[
    id2label[position]
    for position, flag in enumerate(sample['labels'])
    if flag == 1.0
]

In [None]:
# Pretrained encoder plus a classification head with one output per class.
# problem_type marks the task as multi-label (classes are not mutually exclusive),
# and the id/name maps are stored in the model config.
model = AutoModelForSequenceClassification.from_pretrained(
    ckpt,
    num_labels=num_labels,
    problem_type="multi_label_classification",
    id2label=id2label,
    label2id=label2id,
)

In [None]:
# Metric used to pick the best checkpoint; must match a key returned by the
# metrics function ('f1').
metric_name = "f1"
# Per-device batch size for both training and evaluation.
batch_size = 32

In [None]:
# Evaluate and checkpoint once per epoch so that, at the end of training, the
# checkpoint with the best `metric_name` score can be restored.
args = TrainingArguments(
    output_dir="wndp-exp",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-4,
    num_train_epochs=5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

In [None]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    """Score multi-label logits with micro-F1, micro ROC-AUC and subset accuracy.

    Args:
        predictions: raw model logits, shape (batch_size, num_labels).
        labels: ground-truth 0/1 indicator matrix of the same shape.
        threshold: probability at or above which a class counts as predicted.

    Returns:
        dict with keys 'f1', 'roc_auc' and 'accuracy'.
    """
    # Logits -> independent per-class probabilities.
    probs = torch.sigmoid(torch.as_tensor(predictions, dtype=torch.float32)).numpy()
    # Hard 0/1 decisions for the threshold-based metrics.
    y_pred = (probs >= threshold).astype(int)
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # ROC-AUC ranks continuous scores; passing thresholded 0/1 predictions
    # collapses the curve to a single operating point, so score probabilities.
    roc_auc = roc_auc_score(y_true, probs, average='micro')
    # For indicator matrices this is subset accuracy (all labels must match).
    accuracy = accuracy_score(y_true, y_pred)
    return {'f1': f1_micro_average,
            'roc_auc': roc_auc,
            'accuracy': accuracy}

def compute_metrics(p: EvalPrediction):
    """Trainer hook: pull the logits out of an EvalPrediction and score them."""
    logits = p.predictions
    # Some models return a tuple of outputs; the logits are its first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    return multi_label_metrics(predictions=logits, labels=p.label_ids)

In [None]:
ds["train"][0]["labels"]

In [None]:
ds["train"]["input_ids"][0]

In [None]:
# Smoke test before training: run one training example through the model to
# check that inputs, labels and the classification head fit together.
first_row = ds["train"][0]  # one row; avoids loading whole columns
model_inputs = {
    name: first_row[name].unsqueeze(0)  # add a batch dimension of 1
    for name in ("input_ids", "attention_mask", "token_type_ids", "labels")
    if name in first_row  # only pass the fields this dataset actually has
}
with torch.no_grad():  # inspection only; no autograd graph needed
    outputs = model(**model_inputs)
outputs.logits

In [None]:
# Wire together the model, training arguments, the train/val splits and the
# metric hook that scores each evaluation on the validation split.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune; evaluation and checkpointing happen once per epoch (see `args`).
trainer.train()

In [None]:
# Score the validation split with compute_metrics. Because
# load_best_model_at_end=True, this presumably uses the checkpoint selected
# by metric_name — confirm against the Trainer docs for the installed version.
trainer.evaluate()

In [None]:
sample = "found on the ground by window - breathing hard, eyes not open, couldn't stand up, ants covering him, some spazmotic movements of leg, wing, seemed better today. emaciated fledgling with torticollis. Neurologic: torticollis Legs / Feet / Hocks: not using legs. poor prognosis given age, emaciation, and degree of debilitation"

In [None]:
# Tokenise the note into a batch of one PyTorch tensor. truncation=True cuts
# notes longer than the tokenizer's maximum length instead of producing an
# input the model cannot accept.
enc = tokenizer(sample, truncation=True, return_tensors="pt")

In [None]:
# Move every input tensor onto whichever device the trained model sits on.
enc = dict(
    (key, tensor.to(trainer.model.device)) for key, tensor in enc.items()
)

In [None]:
# Inspect the tokenizer outputs after the device move.
enc

In [None]:
# Inference: disable dropout explicitly instead of relying on the mode a
# previous cell left the model in, and skip building an autograd graph.
trainer.model.eval()
with torch.no_grad():
    outputs = trainer.model(**enc)

In [None]:
# Raw model output object; the per-class logits are in `outputs.logits`.
outputs

In [None]:
import torch.nn.functional as F

In [None]:
# Per-class probabilities: drop size-1 axes, detach from the graph, move to
# CPU, then apply an element-wise sigmoid to the logits.
probs = outputs.logits.squeeze().detach().cpu().sigmoid()

In [None]:
# One sigmoid probability per class for the sample note.
probs

In [None]:
# Same decision rule as multi_label_metrics (probability >= 0.5), so this
# manual check agrees with the validation metrics exactly at the boundary.
preds = (probs >= 0.5).int()

In [None]:
# Class names whose 0/1 prediction flag is set.
predicted_labels = [
    id2label[class_id]
    for class_id, is_on in enumerate(preds.tolist())
    if is_on == 1
]

In [None]:
# Final predicted class names for the sample note.
predicted_labels