# Data preparation

In [None]:
from datasets import load_from_disk
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
import evaluate
import numpy as np
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer, EarlyStoppingCallback

In [None]:
def new_column(example):
    """Copy the row's "labels" value into a new "ner_tags" field.

    The row is modified in place and also returned, so the function can be
    passed straight to ``Dataset.map``.
    """
    tags = example["labels"]
    example["ner_tags"] = tags
    return example

data = load_from_disk("dataset.hf")
id_column = range(data.num_rows)
data = data.add_column("id", id_column)
# Copy "labels" into "ner_tags": tokenize_and_align_labels reads word-level
# labels from "ner_tags" and writes sub-token labels back into "labels".
data = data.map(new_column)

# Seed both splits so they are identical on every run. Without a seed the
# held-out test split is re-drawn each time this cell runs, so evaluating a
# previously saved checkpoint could score it on examples it was trained on.
SPLIT_SEED = 42

# Split up the data: hold out 10% as the final test set, then split the
# remainder 80/20 into train / validation (kept as data["train"] / data["test"]).
data = data.train_test_split(test_size=0.1, seed=SPLIT_SEED)
test_data = data["test"]
data = data["train"].train_test_split(test_size=0.2, seed=SPLIT_SEED)


In [None]:
# Tokenizer of the pretrained DistilBERT checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# Collator that batches and pads tokenized examples for the Trainer.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align word labels to sub-tokens.

    Meant for ``Dataset.map(..., batched=True)``: ``examples["tokens"]`` holds
    one list of words per sentence and ``examples["ner_tags"]`` the matching
    per-word label ids.

    Every sub-token of a word receives that word's label (not only the first
    sub-token). Positions that map to no word (``word_ids`` yields ``None``,
    e.g. special tokens) get -100, the placeholder label that evaluation skips.

    Returns the tokenizer output with an added ``"labels"`` entry.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, word_labels in enumerate(examples["ner_tags"]):
        # Map each token position back to the index of the word it came from.
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        labels.append(
            [-100 if word_idx is None else word_labels[word_idx] for word_idx in word_ids]
        )

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Apply the same tokenization/label alignment to the train+validation
# DatasetDict and to the held-out test split.
tokenized_data, tokenized_test_data = (
    split.map(tokenize_and_align_labels, batched=True) for split in (data, test_data)
)

In [None]:
# To get a feel for what the data looks like: print the DatasetDict, the first
# tokenized training example, and the sub-tokens its words are split into.

print(data)

example = data["train"][0]
# Tokenize the raw words again (without truncation, unlike
# tokenize_and_align_labels) just to display the sub-token split.
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])

print(tokenized_data["train"][0])
print("The input will be tokenized as:", tokens)

## Evaluation method

In [None]:
def filter_and_group_lists(first_list, second_list):
    """Collapse runs of equal labels and collect the paired items of each run.

    Walks ``first_list`` (label ids) and ``second_list`` (paired values, e.g.
    predicted ids) in lockstep, stopping at the shorter of the two. Positions
    whose label is -100 or 0 are dropped and also terminate the current run.
    Each maximal run of consecutive, identical remaining labels contributes
    one entry to the first result (the label) and one list to the second
    (the paired items of that run, in order). Runs are formed purely by label
    value, so adjacent words sharing a label merge into one group.

    Example::

        filter_and_group_lists([-100, 2, 0, 1, 1, 0, 2], list("ABCDEFG"))
        -> ([2, 1, 2], [["B"], ["D", "E"], ["G"]])
    """
    kept_labels = []
    groups = []
    run = []          # paired items of the run currently being collected
    last_seen = None  # label at the previous position, kept or not

    for key, item in zip(first_list, second_list):
        if key in (-100, 0):
            # Ignored position: skip it, but remember it so the next kept
            # label always starts a fresh run.
            last_seen = key
            continue
        if key == last_seen:
            run.append(item)
        else:
            if run:
                groups.append(run)
            run = [item]
            kept_labels.append(key)
        last_seen = key

    if run:
        groups.append(run)

    return kept_labels, groups

# As an example of how this function works:
# first_list = [-100, 2, 0, 1, 1, 0, 0, 2, 0, 1, 2, -100]
# second_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']

# filtered_first, grouped_second = filter_and_group_lists(first_list, second_list)
# print("Filtered First List:", filtered_first)
# print("Grouped Second List:", grouped_second)

In [None]:
def getFinalPrediction(predictions):
    """Majority-vote a single class id from a group of per-token predictions.

    Predictions equal to 0 ("O") are ignored. The most frequent remaining id
    is returned, with ties going to the smallest id. If every prediction is
    0, the result is 0.
    """
    votes = np.asarray(predictions)
    tally = np.bincount(votes[votes != 0])
    return np.argmax(tally) if tally.size else 0

In [None]:
seqeval = evaluate.load("seqeval")

# Class names indexed by label id: label_list[i] is the name of class i.
# "O" (id 0) marks tokens without a thematic role. The order must agree with
# the label2id / id2label mapping passed to the model.
label_list = [
    "O",
    "Theme",
    "Agent",
    "Patient",
    "Experiencer",
    "Co-Theme",
    "Stimulus",
    "Location",
    "Destination",
]

def getTrueLabelsAndPredictions(labels, predictions):
    """Turn per-token gold ids and predicted ids into per-span label names.

    For each sentence, gold tokens are grouped into runs of the same non-"O"
    label (via filter_and_group_lists). Each run yields its gold label name
    and one predicted name, chosen by majority vote over the run's predicted
    ids (via getFinalPrediction).

    NOTE(review): tokens whose gold label is "O" (or -100) are dropped, so
    predictions made on them are never scored; this may inflate precision.

    Returns ``(true_labels, true_predictions)``: two lists of per-sentence
    lists of label names, aligned span by span.
    """
    true_labels, true_predictions = [], []
    for gold_row, pred_row in zip(labels, predictions):
        gold_ids, pred_groups = filter_and_group_lists(gold_row, pred_row)
        true_labels.append([label_list[gold_id] for gold_id in gold_ids])
        true_predictions.append(
            [label_list[getFinalPrediction(group)] for group in pred_groups]
        )
    return true_labels, true_predictions

def _to_iob_tags(tags):
    """Prefix every non-"O" tag with "B-" so seqeval reads it as one entity.

    seqeval treats the first character of a tag as its IOB prefix and the
    remainder as the entity type, so bare names such as "Stimulus" or
    "Experiencer" would be parsed as prefix "S"/"E" with a mangled type.
    Each tag here already represents one whole entity span, so "B-" is the
    correct encoding.
    """
    return [tag if tag == "O" else f"B-{tag}" for tag in tags]


def compute_metrics(p):
    """Compute span-level precision/recall/F1 and accuracy for the Trainer.

    ``p`` unpacks into ``(predictions, labels)``. Predictions are assumed to
    be logits of shape (batch, seq_len, num_labels) (argmaxed over axis 2
    here). Labels are the aligned gold ids, with -100 on ignored positions.
    Gold spans and their majority-vote predictions come from
    getTrueLabelsAndPredictions.
    """
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_labels, true_predictions = getTrueLabelsAndPredictions(labels, predictions)

    # Convert bare role names to IOB tags so seqeval does not misparse them.
    results = seqeval.compute(
        predictions=[_to_iob_tags(seq) for seq in true_predictions],
        references=[_to_iob_tags(seq) for seq in true_labels],
    )
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

# Training

In [None]:
# Thematic-role name -> class id; id 0 is reserved for "O" (no role).
mapping = {"Theme": 1, "Agent": 2, "Patient": 3, "Experiencer": 4, "Co-Theme": 5, "Stimulus": 6, "Location": 7, "Destination": 8}

label2id = {"O": 0, **mapping}
# Inverse lookup: class id -> role name.
id2label = dict(zip(label2id.values(), label2id.keys()))
print(id2label)
print(label2id)

In [None]:
# Fine-tune DistilBERT as a token classifier. The head size is derived from
# id2label so it can never drift out of sync with the label mapping.
model = AutoModelForTokenClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
# Training parameters
training_args = TrainingArguments(
    output_dir="thematic_role_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=50,
    weight_decay=0.01,
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Older checkpoints beyond this limit are deleted from output_dir.
    save_total_limit=4,
#     load_best_model_at_end=True,
    push_to_hub=False,
)

# eval_dataset is the validation split carved out of the training portion;
# the held-out test split (tokenized_test_data) is not touched during training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

# Test on a separate (held-out) part of the dataset

In [None]:
# Reload a saved checkpoint and evaluate it on the held-out test split.
# NOTE(review): "checkpoint-832" is hard-coded; confirm it still exists in
# thematic_role_model/ (save_total_limit may have rotated it away).
trained_model = AutoModelForTokenClassification.from_pretrained("thematic_role_model/checkpoint-832")
testing_args = TrainingArguments(
    output_dir="./eval_output",
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    eval_steps=100,  # Adjust as needed
)
trainer = Trainer(
    model=trained_model,
    args=testing_args,
    train_dataset=None,
    # Evaluate on the held-out test split, which was never used for training
    # or for the per-epoch validation.
    eval_dataset=tokenized_test_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.evaluate()

# Create confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Run the current trainer's model over the held-out test split; the raw
# logits are post-processed below.
raw_test_predictions = trainer.predict(test_dataset=tokenized_test_data)

In [None]:
def flatten(xss):
    """Concatenate an iterable of iterables into one flat list."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat

test_predictions = np.argmax(raw_test_predictions.predictions, axis=2)
# The label rows are unpadded, so the zip inside filter_and_group_lists trims
# each (padded) prediction row to the length of its label row.
true_test_labels, true_test_predictions = getTrueLabelsAndPredictions(tokenized_test_data['labels'], test_predictions)
# One entry per gold span, pooled across all test sentences.
true_test_labels = flatten(true_test_labels)
true_test_predictions = flatten(true_test_predictions)

# plt.subplots() creates the figure we draw into; clearing a separate
# "current" figure first would only leave a stray empty figure behind.
fig, ax = plt.subplots()

matrix = ConfusionMatrixDisplay.from_predictions(
    y_true=true_test_labels,
    y_pred=true_test_predictions,
    # Fixed class-id order, with every class shown even if it never occurs.
    labels=label_list,
    xticks_rotation='vertical',
    ax=ax,
)

fig.savefig("confusion.png", dpi=600, bbox_inches='tight')

# Inference
This block performs inference on a single sentence. Note that it returns one label per BERT sub-word token, not one per word. The tokenizer always adds a special token at the beginning and end of each sentence; the labels predicted for these two tokens are removed before printing. For most simple sentences, each word corresponds to a single BERT token.

In [None]:
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Fine-tuned checkpoint used for inference.
CHECKPOINT = "thematic_role_model/checkpoint-832"

# text = "I deserve to know the truth."
text = "Tom didn't know when Mary had come to Boston."

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForTokenClassification.from_pretrained(CHECKPOINT)

inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Best-scoring class id for each token of the single input sentence.
predictions = logits.argmax(dim=2)
predicted_token_class = [model.config.id2label[class_id.item()] for class_id in predictions[0]]

print(text)
# Drop the predictions for the special tokens at the start and end.
print(predicted_token_class[1:-1])