In [None]:
import torch
from datasets import load_from_disk

In [None]:
from torch import cuda
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# Prefer the GPU when PyTorch reports one as available; otherwise use the CPU.
if cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load the previously saved dataset from disk; the path is relative to the
# current working directory. Later cells index it by the split names
# "train_dataset" and "val_dataset".
ds = load_from_disk("bld/python/TrainTest/TrainTest_data/")

## Main Part

### Fine-tuning

There are two ways to implement multi-label classification:

- Creating a custom BERT model that overrides the `forward` method
- Creating a custom `Trainer` that overrides the `compute_loss` method

This notebook takes the second approach.

In [None]:
# Checkpoint name shared by the tokenizer created here and the model loaded
# further down, so both come from the same pretrained weights/vocabulary.
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

In [None]:
def tokenize_and_encode(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook-level tokenizer.

    Truncation is enabled (``truncation=True``); no padding is requested here.

    Args:
        examples: Mapping that provides the raw text under the key ``"text"``.

    Returns:
        Whatever the tokenizer call returns for the given text.
    """
    raw_text = examples["text"]
    encoded = tokenizer(raw_text, truncation=True)
    return encoded

In [None]:
# Tokenize the dataset and drop every original column except "label", so the
# encoded dataset keeps only the tokenizer outputs plus the label column.
# NOTE(review): the column list is read from the "train_dataset" split and
# passed to a single `map` call on `ds` — this assumes every split shares the
# same columns; confirm against the saved dataset.
cols = ds["train_dataset"].column_names
cols.remove("label")  # raises ValueError if there is no "label" column
ds_enc = ds.map(tokenize_and_encode, batched=True, remove_columns=cols)
ds_enc

In [None]:
class MultilabelTrainer(Trainer):
    """Trainer variant that scores logits with a multi-label loss.

    ``compute_loss`` is overridden so the model's logits are compared to the
    targets with ``torch.nn.BCEWithLogitsLoss`` rather than relying on a loss
    computed inside the model's own forward pass.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the binary cross-entropy loss for one batch.

        Args:
            model: The model to run; it is called as ``model(**inputs)`` after
                the targets have been removed from ``inputs``.
            inputs: Mapping of batch tensors. The targets are popped from it
                under the key ``"labels"`` (or ``"label"`` as a fallback).
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted so the override stays compatible with
                Trainer versions that pass this argument; it is not used here.

        Returns:
            The loss tensor, or the tuple ``(loss, outputs)`` when
            ``return_outputs`` is true.

        Raises:
            KeyError: If ``inputs`` contains neither ``"labels"`` nor ``"label"``.
        """
        # The padding/default data collators in transformers rename a "label"
        # column to "labels", so look for the renamed key first and fall back
        # to the raw column name for custom collators that keep it.
        if "labels" in inputs:
            labels = inputs.pop("labels")
        elif "label" in inputs:
            labels = inputs.pop("label")
        else:
            raise KeyError("batch contains neither 'labels' nor 'label'")

        # Targets were popped above so the forward pass only produces logits.
        outputs = model(**inputs)
        logits = outputs.logits

        num_labels = self.model.config.num_labels
        loss_fct = torch.nn.BCEWithLogitsLoss()
        # BCEWithLogitsLoss needs float targets with the same shape as the logits.
        loss = loss_fct(
            logits.view(-1, num_labels),
            labels.float().view(-1, num_labels),
        )
        return (loss, outputs) if return_outputs else loss

In [None]:
# Load the pretrained checkpoint configured for three output labels, then move
# it to the device selected earlier in the notebook.
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=3)
model = model.to(device)

In [None]:
# not important first
def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    """Element-wise accuracy of thresholded predictions against boolean targets.

    Args:
        y_pred: NumPy array of prediction scores (logits when ``sigmoid`` is true).
        y_true: NumPy array of targets; converted to booleans before comparison.
        thresh: A score strictly greater than this value counts as a positive.
        sigmoid: When true, apply a sigmoid to ``y_pred`` before thresholding.

    Returns:
        The fraction of elements where the thresholded prediction equals the
        target, as a Python float.
    """
    scores = torch.from_numpy(y_pred)
    targets = torch.from_numpy(y_true).bool()
    if sigmoid:
        scores = torch.sigmoid(scores)
    matches = (scores > thresh) == targets
    return matches.float().mean().item()

In [None]:
# not important first
def compute_metrics(eval_pred):
    """Compute the thresholded accuracy metric for an evaluation result.

    Args:
        eval_pred: An object that unpacks into ``(predictions, labels)``.

    Returns:
        A dict with the single key ``"accuracy_thresh"`` holding the value
        returned by ``accuracy_thresh(predictions, labels)``.
    """
    predictions, labels = eval_pred
    # Pass the unpacked `labels`; the previous `label` name was never defined.
    return {"accuracy_thresh": accuracy_thresh(predictions, labels)}

In [None]:
# Single batch size reused for both the training and evaluation loaders below.
batch_size = 8

# Training configuration: one epoch, evaluation at the end of each epoch, and
# outputs written under the "jigsaw" directory.
args = TrainingArguments(
    output_dir="jigsaw",
    # NOTE(review): newer transformers releases name this argument
    # `eval_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=1,
    weight_decay=0.01,
)

In [None]:
# Build the custom trainer from the model, the training arguments, and the
# encoded train/validation splits.
multi_trainer = MultilabelTrainer(
    model,
    args,
    train_dataset=ds_enc["train_dataset"],
    eval_dataset=ds_enc["val_dataset"],
    compute_metrics=compute_metrics,  # reports the thresholded accuracy metric
    # NOTE(review): presumably the tokenizer is what makes the trainer pad each
    # batch; newer transformers releases may expect `processing_class` instead
    # of `tokenizer` — confirm against the installed version.
    tokenizer=tokenizer,
)

In [None]:
# Evaluate once before any training has run; the cell displays the returned value.
multi_trainer.evaluate()

In [None]:
# Run fine-tuning with the settings configured in `args`.
multi_trainer.train()