In [None]:
%pip install accelerate evaluate numpy transformers[torch]

In [None]:
from datasets import ClassLabel, Features, load_dataset, TextClassification, Value
from os import sched_getaffinity
from torch import get_num_threads, set_num_threads
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer

import evaluate
import numpy as np

In [None]:
# Use the larger of PyTorch's current intra-op thread count and "every CPU this
# process is allowed to run on, minus one" (leaving a core free); never below 1.
allowed_cpus = sched_getaffinity(0)
num_threads = max(get_num_threads(), len(allowed_cpus) - 1, 1)
set_num_threads(num_threads)
num_threads

In [None]:
# Category names for the CSV's `meta_group` column (encoded as a ClassLabel below).
meta_groups = ['Criminal', 'Tax']

# Sentence-level target tags. The position of each name in this list is its
# integer class id once encoded by ClassLabel(names=labels), so do not reorder.
# NOTE(review): these look like LegalEval rhetorical-role tags — confirm.
labels = [
    'PREAMBLE',
    'FAC',
    'RLC',
    'ISSUE',
    'ARG_PETITIONER',
    'ARG_RESPONDENT',
    'ANALYSIS',
    'STA',
    'PRE_RELIED',
    'PRE_NOT_RELIED',
    'RATIO',
    'RPC',
    'NONE',
]

# Load the sentence-level CSVs with an explicit schema. dev.csv is exposed as the
# "test" split and is the held-out set for evaluation.
# The target column is already named `labels` and typed as ClassLabel (string tags
# are encoded to their index in `labels`), so no task template is needed; the
# deprecated `task=` argument was dropped (task templates no longer exist in
# datasets >= 3.0). Extra columns such as doc_id are removed by Trainer, which
# drops inputs the model's forward() does not accept (remove_unused_columns=True).
dataset = load_dataset('csv', data_files={
    'train': 'BUILD/train.csv',
    'test': 'BUILD/dev.csv',
}, features=Features({
    'doc_id': Value('uint32'),
    'doc_index': Value('uint16'),
    'sentence_index': Value('uint16'),
    'annotation_id': Value('string'),
    'text': Value('string'),
    'meta_group': ClassLabel(names=meta_groups),
    'labels': ClassLabel(names=labels),
}))

In [None]:
# One constant for the checkpoint so the tokenizer and model cannot drift apart.
MODEL_CHECKPOINT = "nlpaueb/legal-bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

# Classification head sized to our tag set. id2label/label2id use the same
# ordering as ClassLabel(names=labels), so the saved config and pipeline outputs
# report tag names instead of generic LABEL_0 .. LABEL_12.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(labels),
    id2label=dict(enumerate(labels)),
    label2id={name: idx for idx, name in enumerate(labels)},
)

In [None]:
def tokenize_batch(batch):
    """Tokenize a batch of sentences, truncating each to the tokenizer's max length.

    No padding here: the data collator pads each training batch to its own
    longest sequence.
    """
    return tokenizer(batch['text'], truncation=True)


# The raw text is no longer needed once input_ids / attention_mask exist.
tokenized_dataset = dataset.map(tokenize_batch, batched=True).remove_columns('text')

In [None]:
# Dynamic padding: each batch is padded only to its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer)

In [None]:
metric = evaluate.load("f1")


def compute_metrics(eval_pred, average="macro"):
    """Compute F1 over the evaluation set for the Trainer.

    The "f1" metric defaults to average="binary", which raises a ValueError for
    multiclass targets like this 13-tag label set, so the averaging mode is
    always passed explicitly.

    Args:
        eval_pred: (predictions, label_ids) pair supplied by Trainer. predictions
            are the model logits, or a tuple whose first element is the logits.
        average: averaging mode forwarded to the F1 metric ("macro", "micro",
            "weighted"). "macro" weights every class equally, so rare tags count
            as much as frequent ones.

    Returns:
        dict with a single "f1" key.
    """
    logits, references = eval_pred
    # Some models return extra outputs alongside the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=references, average=average)

In [None]:
training_args = TrainingArguments(
    output_dir="test_legalbert_model",
    # `eval_strategy` replaces `evaluation_strategy`, which recent transformers
    # releases no longer accept.
    eval_strategy="epoch",
    # Must match eval_strategy for load_best_model_at_end to work.
    save_strategy="epoch",
    num_train_epochs=1,
    # label_names lists the *input keys* holding the targets, not the class
    # names; the tokenized dataset carries its targets under "labels".
    label_names=["labels"],
    load_best_model_at_end=True,
    # Keep the checkpoint with the best dev-set F1 (key returned by
    # compute_metrics) instead of the default lowest loss.
    metric_for_best_model="f1",
    logging_dir='./logs',
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset['train'],
    # Evaluate on the held-out dev split ("test"). Evaluating on the training
    # data inflates scores and makes best-model selection reward overfitting.
    eval_dataset=tokenized_dataset['test'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()