In [None]:
import os

import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    set_seed,
)

In [None]:
def create_dataset(train_path="", dev_path="", test_path="", as_df=True, sliding_window=""):
    """Load the train/dev/test CSV splits and add an ``idx`` row-number column to each.

    All three CSV files are always read, even when only the DataFrames are
    requested.

    Args:
        train_path (str): path to the training CSV.
        dev_path (str): path to the dev CSV.
        test_path (str): path to the test CSV.
        as_df (bool, optional): if True return pandas DataFrames, otherwise a
            ``datasets.DatasetDict``. Defaults to True.
        sliding_window (str, optional): name of a token-limit avoiding approach.
            Accepted so existing calls that pass it keep working; it is
            currently ignored. Defaults to "".

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: ``(train, dev)`` when ``as_df`` is
        True (the test split is read but NOT returned), otherwise
        datasets.DatasetDict: with ``"train"``, ``"dev"`` and ``"test"`` splits.
    """

    def _read_split(path):
        # Read one CSV split and number its rows 0..n-1 in an ``idx`` column.
        frame = pd.read_csv(path)
        frame["idx"] = range(len(frame))
        return frame

    train_data = _read_split(train_path)
    dev_data = _read_split(dev_path)
    test_data = _read_split(test_path)

    if as_df:
        return train_data, dev_data

    return DatasetDict(
        {
            "train": Dataset.from_pandas(train_data),
            "dev": Dataset.from_pandas(dev_data),
            "test": Dataset.from_pandas(test_data),
        }
    )

In [None]:
def tokenize_function(example):
    """Tokenize a batch of ``input``/``output`` text pairs.

    Relies on the notebook-level ``tokenizer``. Each pair is truncated so the
    encoding is at most 512 tokens long.

    Args:
        example (dict): batch passed by ``Dataset.map(..., batched=True)``; must
            provide ``"input"`` and ``"output"`` columns.

    Returns:
        The tokenizer's encoding of the text pairs.
    """
    first_texts = example["input"]
    second_texts = example["output"]
    encoded = tokenizer(first_texts, second_texts, max_length=512, truncation=True)
    return encoded

def compute_metrics(pred):
    """Compute evaluation metrics for ``Trainer`` from an evaluation prediction.

    Args:
        pred: object exposing ``label_ids`` (gold labels) and ``predictions``
            (per-class scores; the argmax over the last axis is taken as the
            predicted class).

    Returns:
        dict: ``{"f1_score": macro F1, "accuracy": accuracy}``. ``"f1_score"``
        is the key this notebook selects via ``metric_for_best_model``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # sklearn metrics take (y_true, y_pred) in that order. No global state is
    # read here, so the function works for any eval set, not just "dev".
    return {
        "f1_score": f1_score(labels, preds, average="macro"),
        "accuracy": accuracy_score(labels, preds),
    }

def combine_splitted_samples(dataset):
    """Group rows that share the same ``idx_complete`` into one small dataset each.

    Groups appear in the order their ``idx_complete`` is first seen. Each group
    keeps the columns of its first row.

    Args:
        dataset (huggingface datasets): dataset whose rows carry an
            ``idx_complete`` column.

    Returns:
        list: one hugginface ``Dataset`` per distinct ``idx_complete`` value.
    """
    grouped = {}
    for row in dataset:
        group_key = row["idx_complete"]
        if group_key not in grouped:
            grouped[group_key] = {column: [] for column in row.keys()}
        columns = grouped[group_key]
        for column in columns:
            columns[column].append(row[column])

    return [Dataset.from_dict(columns) for columns in grouped.values()]


def predict_on_dataset(list_dataset, model_trainer=None):
    """Predict each group of sub samples and aggregate them into one prediction per group.

    For every entry of ``list_dataset``, the per-row prediction scores are
    averaged. The argmax of that average is the final prediction for the
    complete sample.

    Args:
        list_dataset (list): list of small datasets, each containing all sub
            samples of one complete sample (e.g. the output of
            ``combine_splitted_samples``).
        model_trainer (optional): object with a ``predict`` method whose result
            has ``predictions`` and ``label_ids``. Defaults to the
            notebook-level ``trainer``.

    Returns:
        tuple: ``(preds, gt)`` as two lists of ints. ``gt`` holds the first
        label of each group and is only filled for groups whose prediction
        output carries labels. For unlabeled data (e.g. a test split) it stays
        empty.
    """
    active_trainer = trainer if model_trainer is None else model_trainer
    preds = []
    gt = []
    for sub_dataset in list_dataset:
        sub_prediction = active_trainer.predict(sub_dataset)
        # Average the scores over all sub samples, then pick the best class.
        mean_scores = sub_prediction.predictions.mean(axis=0)
        preds.append(int(mean_scores.argmax()))
        label_ids = sub_prediction.label_ids
        if label_ids is not None and len(label_ids) > 0:
            # Presumably every sub sample carries the label of its complete
            # sample; the first one is taken as the group's ground truth.
            gt.append(int(np.asarray(label_ids).reshape(-1)[0]))
    return preds, gt

In [None]:
# Load the train/dev/test CSV splits as a Hugging Face DatasetDict.
# No sliding_window argument is passed: create_dataset does not implement one.
legal_ds = create_dataset(
    train_path="data/train.csv",
    dev_path="data/dev.csv",
    test_path="data/final_test.csv",
    as_df=False,
)

In [None]:
# Pretrained Legal-BERT checkpoint; used below for both the tokenizer and the model.
checkpoint = "nlpaueb/legal-bert-base-uncased"

In [None]:
# Tokenizer for the chosen checkpoint. tokenize_function reads this global,
# so it must be defined before the .map() call below.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Pads each batch at collation time (padding=True: pad to the longest sequence in the batch).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, padding=True)

# Tokenize every split of legal_ds in batches.
tokenized_datasets = legal_ds.map(tokenize_function, batched=True)
# First 8 tokenized training inputs, presumably kept for manual inspection.
# NOTE(review): not used anywhere else in the visible notebook — consider removing.
samples = tokenized_datasets["train"]["input_ids"][:8]

In [None]:
# Fine-tune the checkpoint on train, keep the epoch with the best dev macro F1,
# then score the dev set at the level of complete (re-combined) samples.
training_args = TrainingArguments(
    output_dir=checkpoint + "_finetuneFFF2",
    group_by_length=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=100,  # upper bound; early stopping usually ends training sooner
    bf16=True,
    learning_rate=2.2493504069942526e-05,
    warmup_steps=10,
    weight_decay=0.015631376305494087,
    seed=7854,
    load_best_model_at_end=True,
    metric_for_best_model="f1_score",
    greater_is_better=True,
)

# Seed before building the model so the freshly initialised classification head
# is reproducible; the TrainingArguments seed is only applied once the Trainer
# is constructed, i.e. after the model already exists.
set_seed(training_args.seed)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["dev"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    # Threshold 0.0 (the default): any improvement resets patience. The former
    # negative threshold had the same effect but read as if it tolerated drops.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10, early_stopping_threshold=0.0)],
)

trainer.train()

# Eval: combine all sub samples with the same idx_complete into one group and
# aggregate their predictions into one final prediction per complete sample.
dev_dataset_list = combine_splitted_samples(tokenized_datasets["dev"])
preds, gt = predict_on_dataset(dev_dataset_list)

# Score dev set. sklearn metrics take (y_true, y_pred) in that order.
dev_f1_macro = f1_score(gt, preds, average="macro")
dev_f1_binary = f1_score(gt, preds, average="binary")
dev_accuracy = accuracy_score(gt, preds)
print("F1 Score (Macro)", dev_f1_macro)
print("F1 Score (binary)", dev_f1_binary)
print("Accuracy", dev_accuracy)

# Defined in this cell so it also works on a fresh kernel (Restart & Run All).
results = [(dev_f1_macro, dev_f1_binary, dev_accuracy)]

In [None]:
# Group the tokenized test rows by idx_complete so each complete sample is predicted as one unit.
pred_dataset_list = combine_splitted_samples(dataset=tokenized_datasets["test"])

In [None]:
# One aggregated prediction per complete test sample; gt may be empty for an unlabeled test split.
# NOTE(review): this overwrites the dev-set preds/gt computed during training.
preds, gt = predict_on_dataset(list_dataset=pred_dataset_list)