# Setup

## Logging

In [1]:
from datasets import logging as ds_log
from transformers import logging as trans_log
import warnings

ds_log.set_verbosity_error()
ds_log.disable_progress_bar()
trans_log.set_verbosity_error()
warnings.filterwarnings("ignore")

# Switch val and test set

In [2]:
use_test_dataset = False

# Data

## Load datasets

In [3]:
from datasets import load_dataset, Features, Sequence, Value


def read_annotations_from_file(path: str, file: str):
    features = Features(
        {
            "PTC": Sequence(feature=Value(dtype="string", id=None), length=-1, id=None),
            "Evidence": Sequence(
                feature=Value(dtype="string", id=None), length=-1, id=None
            ),
            "Medium": Sequence(
                feature=Value(dtype="string", id=None), length=-1, id=None
            ),
            "Topic": Sequence(
                feature=Value(dtype="string", id=None), length=-1, id=None
            ),
            "Cue": Sequence(feature=Value(dtype="string", id=None), length=-1, id=None),
            "Addr": Sequence(
                feature=Value(dtype="string", id=None), length=-1, id=None
            ),
            "Message": Sequence(
                feature=Value(dtype="string", id=None), length=-1, id=None
            ),
            "Source": Sequence(
                feature=Value(dtype="string", id=None), length=-1, id=None
            ),
        }
    )
    ds = load_dataset(
        "json",
        data_files=os.path.join(path, file),
        field="Annotations",
        split="train",
        features=features,
    )
    ds = ds.add_column("FileName", [file] * len(ds))
    return ds

In [4]:
def read_sentences_from_file(path: str, file: str):
    ds = load_dataset(
        "json", data_files=os.path.join(path, file), field="Sentences", split="train"
    )
    ds = ds.add_column("FileName", [file] * len(ds))
    ds = ds.add_column("Sentence", [" ".join(t) for t in ds["Tokens"]])
    return ds

In [5]:
from datasets import concatenate_datasets
import os
from tqdm import tqdm


def read_annotations_from_path(path: str):
    dataset = None

    for file in tqdm(sorted(os.listdir(path))):
        if not dataset:
            dataset = read_annotations_from_file(path, file)
        else:
            dataset = concatenate_datasets(
                [dataset, read_annotations_from_file(path, file)]
            )

    return dataset

In [6]:
def read_sentences_from_path(path: str):
    dataset = None

    for file in tqdm(sorted(os.listdir(path))):
        if not dataset:
            dataset = read_sentences_from_file(path, file)
        else:
            dataset = concatenate_datasets(
                [dataset, read_sentences_from_file(path, file)]
            )

    dataset = dataset.add_column("id", range(len(dataset)))
    return dataset

In [7]:
from datasets import load_from_disk


def read_annotations_train_dataset():
    path_to_train_dataset = "../../data/transformed_datasets/train/annotations"

    if os.path.isdir(path_to_train_dataset):
        result = load_from_disk(path_to_train_dataset)
    else:
        result = read_annotations_from_path("../../data/train/")
        os.makedirs(path_to_train_dataset, exist_ok=True)
        result.save_to_disk(path_to_train_dataset)

    return result

In [8]:
def read_sentences_train_dataset():
    path_to_train_dataset = "../../data/transformed_datasets/train/sentences"

    if os.path.isdir(path_to_train_dataset):
        result = load_from_disk(path_to_train_dataset)
    else:
        result = read_sentences_from_path("../../data/train/")
        os.makedirs(path_to_train_dataset, exist_ok=True)
        result.save_to_disk(path_to_train_dataset)

    return result

In [9]:
def read_annotations_val_dataset():
    path_to_val_dataset = "../../data/transformed_datasets/val/annotations"

    if os.path.isdir(path_to_val_dataset):
        return load_from_disk(path_to_val_dataset)

    result = read_annotations_from_path("../../data/dev/")
    os.makedirs(path_to_val_dataset, exist_ok=True)
    result.save_to_disk(path_to_val_dataset)
    return result

In [10]:
def read_sentences_val_dataset():
    path_to_val_dataset = "../../data/transformed_datasets/val/sentences"

    if os.path.isdir(path_to_val_dataset):
        return load_from_disk(path_to_val_dataset)

    result = read_sentences_from_path("../../data/dev/")
    os.makedirs(path_to_val_dataset, exist_ok=True)
    result.save_to_disk(path_to_val_dataset)
    return result

In [11]:
def read_sentences_test_dataset():
    path_to_test_dataset = "../../data/transformed_datasets/test/sentences"

    if os.path.isdir(path_to_test_dataset):
        return load_from_disk(path_to_test_dataset)

    result = read_sentences_from_path("../../data/test/")
    os.makedirs(path_to_test_dataset, exist_ok=True)
    result.save_to_disk(path_to_test_dataset)
    return result

In [12]:
train_sentences_dataset = read_sentences_train_dataset()
val_sentences_dataset = read_sentences_val_dataset()
test_sentences_dataset = read_sentences_test_dataset()
train_annotations_dataset = read_annotations_train_dataset()
val_annotations_dataset = read_annotations_val_dataset()

## Format datasets for usage in langchain

In [13]:
def get_text_from_label(train_sentences_dataset, row, annotations):
    tokens = []
    for anno in annotations:
        if int(anno.split(":")[0]) == row["SentenceId"]:
            tokens.append(row["Tokens"][int(anno.split(":")[1])])
    return tokens

In [14]:
def build_complete_dataset(sentences_dataset, annotations_dataset, dataset_name):
    path_to_dataset = (
        "../../data/transformed_datasets/" + dataset_name + "/complete-ext-2"
    )
    if os.path.isdir(path_to_dataset):
        return load_from_disk(path_to_dataset)

    ptc, ptc_temp, ptc_mapped, ptc_mapped_temp = [], [], [], []
    evidence, evidence_temp, evidence_mapped, evidence_mapped_temp = [], [], [], []
    medium, medium_temp, medium_mapped, medium_mapped_temp = [], [], [], []
    topic, topic_temp, topic_mapped, topic_mapped_temp = [], [], [], []
    cue, cue_temp, cue_mapped, cue_mapped_temp = [], [], [], []
    addr, addr_temp, addr_mapped, addr_mapped_temp = [], [], [], []
    message, message_temp, message_mapped, message_mapped_temp = [], [], [], []
    source, source_temp, source_mapped, source_mapped_temp = [], [], [], []
    (
        sentence_extended,
        tokens_extended,
        sentence_extended_ids,
    ) = (
        [],
        [],
        [],
    )

    index_in_anno_ds = 0

    for i, row in tqdm(enumerate(sentences_dataset)):
        context = row["Sentence"]
        tokens = row["Tokens"]
        ids = [row["SentenceId"]] * len(row["Tokens"])
        if (
            i + 1 < len(sentences_dataset)
            and sentences_dataset[i + 1]["FileName"] == row["FileName"]
        ):
            context = context + " " + sentences_dataset[i + 1]["Sentence"]
            tokens.extend(sentences_dataset[i + 1]["Tokens"])
            ids.extend(
                [sentences_dataset[i + 1]["SentenceId"]]
                * len(sentences_dataset[i + 1]["Tokens"])
            )
        if (
            i + 2 < len(sentences_dataset)
            and sentences_dataset[i + 2]["FileName"] == row["FileName"]
        ):
            context = context + " " + sentences_dataset[i + 2]["Sentence"]
            tokens.extend(sentences_dataset[i + 2]["Tokens"])
            ids.extend(
                [sentences_dataset[i + 2]["SentenceId"]]
                * len(sentences_dataset[i + 2]["Tokens"])
            )
        sentence_extended.append(context)
        tokens_extended.append(tokens)
        sentence_extended_ids.append(ids)

        if annotations_dataset is not None:
            id_of_next_sentence_with_annotation = (
                int(annotations_dataset[index_in_anno_ds]["Cue"][0].split(":")[0])
                if index_in_anno_ds != len(annotations_dataset)
                else -1
            )

            if row["SentenceId"] != id_of_next_sentence_with_annotation:
                ptc.append([])
                ptc_mapped.append([])
                evidence.append([])
                evidence_mapped.append([])
                medium.append([])
                medium_mapped.append([])
                topic.append([])
                topic_mapped.append([])
                cue.append([])
                cue_mapped.append([])
                addr.append([])
                addr_mapped.append([])
                message.append([])
                message_mapped.append([])
                source.append([])
                source_mapped.append([])
                continue

            while row["SentenceId"] == id_of_next_sentence_with_annotation:
                ptc_temp.append(annotations_dataset[index_in_anno_ds]["PTC"])
                evidence_temp.append(annotations_dataset[index_in_anno_ds]["Evidence"])
                medium_temp.append(annotations_dataset[index_in_anno_ds]["Medium"])
                topic_temp.append(annotations_dataset[index_in_anno_ds]["Topic"])
                cue_temp.append(annotations_dataset[index_in_anno_ds]["Cue"])
                addr_temp.append(annotations_dataset[index_in_anno_ds]["Addr"])
                message_temp.append(annotations_dataset[index_in_anno_ds]["Message"])
                source_temp.append(annotations_dataset[index_in_anno_ds]["Source"])

                ptc_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, ptc_temp[-1])
                )
                evidence_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, evidence_temp[-1])
                )
                medium_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, medium_temp[-1])
                )
                topic_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, topic_temp[-1])
                )
                cue_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, cue_temp[-1])
                )
                addr_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, addr_temp[-1])
                )
                message_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, message_temp[-1])
                )
                source_mapped_temp.append(
                    get_text_from_label(sentences_dataset, row, source_temp[-1])
                )

                index_in_anno_ds += 1
                if index_in_anno_ds == len(annotations_dataset):
                    break
                id_of_next_sentence_with_annotation = int(
                    annotations_dataset[index_in_anno_ds]["Cue"][0].split(":")[0]
                )

            ptc.append(ptc_temp)
            ptc_mapped.append(ptc_mapped_temp)
            evidence.append(evidence_temp)
            evidence_mapped.append(evidence_mapped_temp)
            medium.append(medium_temp)
            medium_mapped.append(medium_mapped_temp)
            topic.append(topic_temp)
            topic_mapped.append(topic_mapped_temp)
            cue.append(cue_temp)
            cue_mapped.append(cue_mapped_temp)
            addr.append(addr_temp)
            addr_mapped.append(addr_mapped_temp)
            message.append(message_temp)
            message_mapped.append(message_mapped_temp)
            source.append(source_temp)
            source_mapped.append(source_mapped_temp)

            ptc_temp, ptc_mapped_temp = [], []
            evidence_temp, evidence_mapped_temp = [], []
            medium_temp, medium_mapped_temp = [], []
            topic_temp, topic_mapped_temp = [], []
            cue_temp, cue_mapped_temp = [], []
            addr_temp, addr_mapped_temp = [], []
            message_temp, message_mapped_temp = [], []
            source_temp, source_mapped_temp = [], []

    res = sentences_dataset.add_column("sentence_extended", sentence_extended)
    res = res.add_column("tokens_extended", tokens_extended)
    res = res.add_column("sentence_extended_ids", sentence_extended_ids)

    if annotations_dataset is not None:
        res = res.add_column("ptc", ptc)
        res = res.add_column("ptc_mapped", ptc_mapped)
        res = res.add_column("evidence", evidence)
        res = res.add_column("evidence_mapped", evidence_mapped)
        res = res.add_column("medium", medium)
        res = res.add_column("medium_mapped", medium_mapped)
        res = res.add_column("topic", topic)
        res = res.add_column("topic_mapped", topic_mapped)
        res = res.add_column("cue", cue)
        res = res.add_column("cue_mapped", cue_mapped)
        res = res.add_column("addr", addr)
        res = res.add_column("addr_mapped", addr_mapped)
        res = res.add_column("message", message)
        res = res.add_column("message_mapped", message_mapped)
        res = res.add_column("source", source)
        res = res.add_column("source_mapped", source_mapped)

    os.makedirs(path_to_dataset, exist_ok=True)
    res.save_to_disk(path_to_dataset)

    return res

In [15]:
train_ds = build_complete_dataset(
    train_sentences_dataset, train_annotations_dataset, "train"
)
val_ds = build_complete_dataset(val_sentences_dataset, val_annotations_dataset, "val")
test_ds = build_complete_dataset(test_sentences_dataset, None, "test")

9093it [00:12, 707.50it/s]
927it [00:01, 708.27it/s]
3067it [00:02, 1276.82it/s]


# Role Prompting

## Extract Roles

In [16]:
file_name_prefix = "qlora-exp008g-llama2-70b-specialized-models-external-context-2"

In [17]:
def extract_roles_from_output(output_string: str):
    res = {
        "ptc": "",
        "evidence": "",
        "medium": "",
        "topic": "",
        "addr": "",
        "message": "",
        "source": "",
    }

    output_rows = [v.strip() for v in output_string.strip().split("\n")]
    error = False

    try:
        if not (
            output_rows[1].startswith("ptc: ")
            and output_rows[2].startswith("evidence: ")
            and output_rows[3].startswith("medium: ")
            and output_rows[4].startswith("topic: ")
            and output_rows[5].startswith("addr: ")
            and output_rows[6].startswith("message: ")
            and output_rows[7].startswith("source: ")
        ):
            error = True
    except IndexError:
        error = True

    try:
        if output_rows[1].startswith("ptc: "):
            res["ptc"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[1][4:].strip().split(",")
            ]
    except IndexError:
        pass
    try:
        if output_rows[2].startswith("evidence: "):
            res["evidence"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[2][9:].strip().split(",")
            ]
    except IndexError:
        pass
    try:
        if output_rows[3].startswith("medium: "):
            res["medium"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[3][7:].strip().split(",")
            ]
    except IndexError:
        pass
    try:
        if output_rows[4].startswith("topic: "):
            res["topic"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[4][6:].strip().split(",")
            ]
    except IndexError:
        pass
    try:
        if output_rows[5].startswith("addr: "):
            res["addr"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[5][5:].strip().split(",")
            ]
    except IndexError:
        pass
    try:
        if output_rows[6].startswith("message: "):
            res["message"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[6][8:].strip().split(",")
            ]
    except IndexError:
        pass
    try:
        if output_rows[7].startswith("source: "):
            res["source"] = [
                v.strip().split(" ")[0].strip()
                for v in output_rows[7][7:].strip().split(",")
            ]
    except IndexError:
        pass

    for key, value in res.items():
        if value == [""] or value == ["#UNK#"]:
            res[key] = ""
        while "#UNK#" in value:
            value.pop(value.index("#UNK#"))
        while type(value) == list and "" in value:
            value.pop(value.index(""))
        res[key] = value

    return res, error

In [18]:
import json


def extract_roles(output_filename_prefix):
    path = "./" + output_filename_prefix + "-outputs/"
    errors = {"no_roles_prefix": 0}

    for file in sorted(os.listdir(path)):
        if file.endswith(".zip"):
            continue
        file_content = {}

        with open(os.path.join(path, file), "r") as f:
            file_content = json.load(f)
            file_content["Outputs"]["Roles_text"] = {}

            for id, roles_for_sentence in file_content["Outputs"]["Roles"].items():
                file_content["Outputs"]["Roles_text"][id] = []

                if roles_for_sentence == []:
                    continue

                for roles_output in roles_for_sentence:
                    file_content["Outputs"]["Roles_text"][id].append([])

                    roles, error = extract_roles_from_output(roles_output[0])
                    if error:
                        errors["no_roles_prefix"] += 1
                    file_content["Outputs"]["Roles_text"][id][-1].append(roles)

        with open(os.path.join(path, file), "w", encoding="utf8") as outfile:
            json.dump(file_content, outfile, indent=3, ensure_ascii=False)

    with open(output_filename_prefix + "-errors.json", "r", encoding="utf8") as f:
        file_content = json.load(f)
        for key, value in file_content.items():
            errors[key] = value
    with open(output_filename_prefix + "-errors.json", "w", encoding="utf8") as outfile:
        json.dump(errors, outfile, indent=3, ensure_ascii=False)

In [19]:
extract_roles(file_name_prefix)

# Map model outputs

In [20]:
import Levenshtein


def count_neighbors(i, seen, skip_index):
    res = 0
    if i - 2 >= 0 and i - 2 != skip_index:
        res += 1 if seen[i - 2] else 0
    if i - 1 >= 0 and i - 1 != skip_index:
        res += 1 if seen[i - 1] else 0
    if i + 1 < len(seen) and i + 1 != skip_index:
        res += 1 if seen[i + 1] else 0
    if i + 2 < len(seen) and i + 2 != skip_index:
        res += 1 if seen[i + 2] else 0
    return res


def calculate_neighborhood_swap(seen, tokens):
    for i, v in enumerate(seen):
        if not v:
            continue

        neigh_c_v = count_neighbors(i, seen, -1)
        neigh = [
            j
            for j, t in enumerate(tokens)
            if seen[j] == False and Levenshtein.distance(t, tokens[i]) <= 1
        ]
        neigh_c_other = [count_neighbors(n, seen, i) for n in neigh]
        if len(neigh_c_other) > 0:
            neigh_c_other_max = max(neigh_c_other)
            if neigh_c_other_max > neigh_c_v:
                return i, neigh[neigh_c_other.index(neigh_c_other_max)]

    return -1, -1

In [21]:
def map_output_list(output_list: list, ids: list, tokens: list, seen_old=None):
    res = []
    seen = [False] * len(tokens)
    if seen_old == None:
        seen_old = [False] * len(tokens)
    error = False

    for output in output_list:
        indices = [
            i
            for i, v in enumerate(tokens)
            if v == output and seen[i] == False and seen_old[i] == False
        ]
        if len(indices) > 0:
            seen[indices[0]] = True
        if len(indices) == 0:
            indices = [
                i
                for i, v in enumerate(tokens)
                if seen[i] == False
                and seen_old[i] == False
                and Levenshtein.distance(output, v) <= 1
            ]
            if len(indices) > 0:
                seen[indices[0]] = True

    if sum(seen) != len(output_list):
        error = True

    changed = True
    while changed:
        changed = False
        i, j = calculate_neighborhood_swap(seen, tokens)
        while i != j:
            seen[i] = False
            seen[j] = True
            changed = True
            i, j = calculate_neighborhood_swap(seen, tokens)

        for i in range(len(seen)):
            if (
                seen[i] == False
                and i != 0
                and i != len(seen) - 1
                and seen[i - 1]
                and seen[i + 1]
                and (
                    tokens[i] == ","
                    or tokens[i] == ":"
                    or tokens[i] == ";"
                    or tokens[i] == "-"
                )
            ):
                seen[i] = True
                changed = True

    for i in range(len(seen)):
        if seen[i]:
            res.append(str(ids[i]) + ":" + str(i))

    return res, error, [v or seen_old[i] for i, v in enumerate(seen)]

In [22]:
def map_outputs(output_filename_prefix, ds):
    path = "./" + output_filename_prefix + "-outputs/"
    errors = {"cue_not_mappable": 0, "roles_not_mappable": 0}

    for file in sorted(os.listdir(path)):
        if file.endswith(".zip"):
            continue
        file_content = {}

        with open(os.path.join(path, file), "r") as f:
            file_content = json.load(f)
            file_content["Annotations"] = []

            for cues_text, roles_text in zip(
                file_content["Outputs"]["Cues_text"].items(),
                file_content["Outputs"]["Roles_text"].items(),
            ):
                id, cues = cues_text
                id, roles_list = roles_text

                if cues == []:
                    continue

                tokens = ds.filter(
                    lambda r: r["FileName"] == file and r["SentenceId"] == int(id)
                )[0]["tokens_extended"]
                ids = ds.filter(
                    lambda r: r["FileName"] == file and r["SentenceId"] == int(id)
                )[0]["sentence_extended_ids"]

                seen_cues = None
                for cue, roles in zip(cues, roles_list):
                    roles = roles[0]

                    cue, error, seen_cues = map_output_list(cue, ids, tokens, seen_cues)
                    if error:
                        errors["cue_not_mappable"] += 1

                    if cue != []:
                        addr, error, _ = map_output_list(
                            roles["addr"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        evidence, error, _ = map_output_list(
                            roles["evidence"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        medium, error, _ = map_output_list(
                            roles["medium"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        message, error, _ = map_output_list(
                            roles["message"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        source, error, _ = map_output_list(
                            roles["source"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        topic, error, _ = map_output_list(
                            roles["topic"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        ptc, error, _ = map_output_list(
                            roles["ptc"],
                            ids,
                            tokens,
                        )
                        if error:
                            errors["roles_not_mappable"] += 1

                        annotation = {
                            "Addr": addr,
                            "Evidence": evidence,
                            "Medium": medium,
                            "Message": message,
                            "Source": source,
                            "Topic": topic,
                            "Cue": cue,
                            "PTC": ptc,
                        }
                        file_content["Annotations"].append(annotation)

        with open(os.path.join(path, file), "w", encoding="utf8") as outfile:
            json.dump(file_content, outfile, indent=3, ensure_ascii=False)

    with open(output_filename_prefix + "-errors.json", "r", encoding="utf8") as f:
        file_content = json.load(f)
        for key, value in file_content.items():
            errors[key] = value
    with open(output_filename_prefix + "-errors.json", "w", encoding="utf8") as outfile:
        json.dump(errors, outfile, indent=3, ensure_ascii=False)

In [23]:
map_outputs(file_name_prefix, test_ds if use_test_dataset else val_ds)

# Prepare zip file

In [24]:
import shutil

if os.path.exists(file_name_prefix + "-outputs/" + file_name_prefix + ".zip"):
    os.remove(file_name_prefix + "-outputs/" + file_name_prefix + ".zip")
shutil.copytree(
    "./" + file_name_prefix + "-outputs", "./" + file_name_prefix + "-outputs/temp"
)
path = "./" + file_name_prefix + "-outputs/temp"

for file in sorted(os.listdir(path)):
    file_content = {}

    with open(os.path.join(path, file), "r") as f:
        file_content = json.load(f)
        file_content.pop("Outputs")

    with open(os.path.join(path, file), "w", encoding="utf8") as outfile:
        json.dump(file_content, outfile, indent=3, ensure_ascii=False)
shutil.make_archive(
    file_name_prefix + "-outputs/temp", "zip", "./" + file_name_prefix + "-outputs/temp"
)
shutil.move(
    file_name_prefix + "-outputs/temp.zip",
    file_name_prefix + "-outputs/" + file_name_prefix + ".zip",
)
shutil.rmtree(file_name_prefix + "-outputs/temp")

# Metric

In [25]:
def compute_metrics(output_filename_prefix):
    assert use_test_dataset == False

    path = "./" + output_filename_prefix + "-outputs/"

    result = {
        "f1": 0,
        "precision": 0,
        "recall": 0,
        "f1_cues": 0,
        "precision_cues": 0,
        "recall_cues": 0,
        "f1_roles": 0,
        "precision_roles": 0,
        "recall_roles": 0,
        # "count_gold_cues": 0,
        # "count_pred_cues": 0,
        # "count_exact_match": 0,
        # "count_partly_match": 0,
        # "count_no_match": 0,
    }

    tp = 0
    fp = 0
    fn = 0
    tp_cues = 0
    fp_cues = 0
    fn_cues = 0
    tp_roles = 0
    fp_roles = 0
    fn_roles = 0

    roles_names = ["Addr", "Evidence", "Medium", "Message", "Source", "Topic", "PTC"]

    for file in sorted(os.listdir(path)):
        if file.endswith(".zip"):
            continue
        file_content = {}

        with open(os.path.join(path, file), "r") as f:
            file_content = json.load(f)

            pred_cue = [val["Cue"] for val in file_content["Annotations"]]
            pred_roles = [
                [val[role] for val in file_content["Annotations"]]
                for role in roles_names
            ]
            gold_cue = [
                val
                for val in val_annotations_dataset.filter(
                    lambda row: row["FileName"] == file, load_from_cache_file=False
                )["Cue"]
            ]
            gold_roles = [
                [
                    val
                    for val in val_annotations_dataset.filter(
                        lambda row: row["FileName"] == file, load_from_cache_file=False
                    )[role]
                ]
                for role in roles_names
            ]

            # result["count_gold_cues"] += len(gold)
            # result["count_pred_cues"] += len(pred)
            # count_exact_match = len([p for p in pred if p in gold])
            # count_partly_match = len(
            #     [
            #         list(set(i) & set(j))
            #         for i in gold
            #         for j in pred
            #         if len(list(set(i) & set(j))) > 0
            #         and len(list(set(i) & set(j))) != len(i)
            #         and len(list(set(i) & set(j))) != len(j)
            #     ]
            # )
            # result["count_exact_match"] += count_exact_match
            # result["count_partly_match"] += count_partly_match
            # result["count_no_match"] += (
            #     len(gold) - count_exact_match - count_partly_match
            # )

            pred_cue_flattend = [v for val in pred_cue for v in val]
            pred_roles_flattend = [
                v for role in pred_roles for val in role for v in val
            ]
            gold_cue_flattend = [v for val in gold_cue for v in val]
            gold_roles_flattend = [
                v for role in gold_roles for val in role for v in val
            ]

            tp_cues += len(list(set(pred_cue_flattend) & set(gold_cue_flattend)))
            fp_cues += len(list(set(pred_cue_flattend) - set(gold_cue_flattend)))
            fn_cues += len(list(set(gold_cue_flattend) - set(pred_cue_flattend)))

            tp_roles += len(list(set(pred_roles_flattend) & set(gold_roles_flattend)))
            fp_roles += len(list(set(pred_roles_flattend) - set(gold_roles_flattend)))
            fn_roles += len(list(set(gold_roles_flattend) - set(pred_roles_flattend)))

    tp = tp_cues + tp_roles
    fp = fp_cues + fp_roles
    fn = fn_cues + fn_roles

    result["precision_cues"] = (
        tp_cues / (tp_cues + fp_cues) if (tp_cues + fp_cues) > 0 else 0
    )
    result["recall_cues"] = (
        tp_cues / (tp_cues + fn_cues) if (tp_cues + fp_cues) > 0 else 0
    )
    result["f1_cues"] = (
        (2 * result["precision_cues"] * result["recall_cues"])
        / (result["precision_cues"] + result["recall_cues"])
        if result["precision_cues"] + result["recall_cues"] > 0
        else 0
    )
    result["precision_roles"] = (
        tp_roles / (tp_roles + fp_roles) if (tp_roles + fp_roles) > 0 else 0
    )
    result["recall_roles"] = (
        tp_roles / (tp_roles + fn_roles) if (tp_roles + fp_roles) > 0 else 0
    )
    result["f1_roles"] = (
        (2 * result["precision_roles"] * result["recall_roles"])
        / (result["precision_roles"] + result["recall_roles"])
        if result["precision_roles"] + result["recall_roles"] > 0
        else 0
    )
    result["precision"] = tp / (tp + fp) if (tp + fp) > 0 else 0
    result["recall"] = tp / (tp + fn) if (tp + fp) > 0 else 0
    result["f1"] = (
        (2 * result["precision"] * result["recall"])
        / (result["precision"] + result["recall"])
        if result["precision"] + result["recall"] > 0
        else 0
    )

    return result

In [26]:
metrics = compute_metrics(file_name_prefix)

In [27]:
print("f1_cues:\t", metrics["f1_cues"])
print("precision_cues:\t", metrics["precision_cues"])
print("recall_cues:\t", metrics["recall_cues"])
print()
print("f1_roles:\t", metrics["f1_roles"])
print("precision_roles:", metrics["precision_roles"])
print("recall_roles:\t", metrics["recall_roles"])
print()
print("f1:\t\t", metrics["f1"])
print("precision:\t", metrics["precision"])
print("recall:\t\t", metrics["recall"])
# print("count_gold_cues:\t", metrics["count_gold_cues"])
# print("count_pred_cues:\t", metrics["count_pred_cues"])
# print("count_exact_match:\t", metrics["count_exact_match"])
# print("count_partly_match:\t", metrics["count_partly_match"])
# print("count_no_match:\t\t", metrics["count_no_match"])

f1_cues:	 0.8658865886588658
precision_cues:	 0.8635547576301615
recall_cues:	 0.868231046931408

f1_roles:	 0.8447427293064877
precision_roles: 0.8740740740740741
recall_roles:	 0.8173160173160173

f1:		 0.8477482088024564
precision:	 0.8725309454832763
recall:		 0.8243344115451605
