In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os
sys.path.append(os.path.join(os.path.abspath(os.getcwd()), ".."))
import evaluate
import pandas as pd
import numpy as np
import wandb
import argparse

import torch
from typing import get_args
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
from datasets import Dataset
from certainty import load_file, CACHE_DIR, seed_everything, RANDOM_SEED, EventType, extract_events, TRAIN_FILENAME, TEST_FILENAME, DEV_FILENAME, load_events
OUTPUT_DIR = "../models/blabla"


2025-03-30 16:02:39.104915: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-30 16:02:39.248190: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from filelock import FileLock
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    MBartTokenizer,
    default_data_collator,
    set_seed
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process

In [4]:
class FactualitySchema:
    def __init__(self, type_list):
        self.type_list = type_list
        self.role_list = []
        self.type_role_dict = {}

In [5]:
task = "factuality"
text_column = "text"
summary_column = "factuality"
max_source_length = 256
max_target_length = 128
pad_to_max_length = True
source_lang = "en"
ignore_pad_token_for_loss = True
source_prefix = "factuality:"
decoding_format = "tree"
model_name='t5-small'
prefix="Pairs of Factualities and triggers: "


In [6]:
train = load_file('en_train.json')

In [7]:
train = pd.DataFrame(train).drop_duplicates('text').drop_duplicates('events').to_dict("records")
train[0]

{'sent_id': 'bc/CNN_IP_20030329.1600.02/001',
 'text': 'It was in northern Iraq today that an eight artillery round hit the site occupied by Kurdish fighters near Chamchamal',
 'events': [{'event_type': 'Attack',
   'event_polarity': 'Positive',
   'event_genericity': 'Specific',
   'event_modality': 'Asserted',
   'trigger': [['hit'], ['60:63']],
   'arguments': [[['northern Iraq'], ['10:23'], 'Place'],
    [['today'], ['24:29'], 'Time-Within'],
    [['an eight artillery round'], ['35:59'], 'Instrument'],
    [['the site occupied by Kurdish fighters near Chamchamal'],
     ['64:117'],
     'Target']]}]}

In [8]:
for i, sample in enumerate(train):
    triggers = list(map(lambda e: e['trigger'][0][0], sample['events']))
    trigger_set = set(triggers)

    if len(triggers) != len(trigger_set):
        print("Duplicate: " + str(i))
        print(triggers)
        print(trigger_set)
        print(sample['text'])
        break
        

Duplicate: 11
['war', 'war']
{'war'}
That Europe is against the war on humanitarian and moral grounds , but the U.S. is for the war because it wants to profit for the oil companies


In [9]:
new = []
for sample in train:
    pairs = "("
    for i, event in enumerate(sample['events']):
        pair = "(" + event['event_modality'] + " " + event['trigger'][0][0] + ")"
        if i > 0:
            pairs += " "
        pairs += str(pair)
    pairs += ")"
    new.append({'factuality': pairs,
                 'text': sample['text']})

In [10]:
new[2]

{'factuality': '((Asserted pummeled) (Asserted retreat))',
 'text': "That 's because coalition fighter jets pummeled this Iraqi position on the hills above Chamchamal and Iraqi troops made a hasty retreat"}

In [11]:
dataset = Dataset.from_pandas(pd.DataFrame(new))

In [12]:
config = AutoConfig.from_pretrained(
    model_name, 
    cache_dir=CACHE_DIR, 
    #local_files_only=True, 
    trust_remote_code=True
)

In [13]:
config.max_length = max_target_length

In [14]:
tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        cache_dir=CACHE_DIR,
        use_fast=True
)

In [15]:
to_remove_token_list = list()
if tokenizer.bos_token:
    to_remove_token_list += [tokenizer.bos_token]
if tokenizer.eos_token:
    to_remove_token_list += [tokenizer.eos_token]
if tokenizer.pad_token:
    to_remove_token_list += [tokenizer.pad_token]

In [16]:
model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        config=config,
        cache_dir=CACHE_DIR
)

In [17]:
if model.config.decoder_start_token_id is None and isinstance(tokenizer, MBartTokenizer):
    model.config.decoder_start_token_id = tokenizer.lang_code_to_id['en']
if model.config.decoder_start_token_id is None:
    raise ValueError(
        "Make sure that `config.decoder_start_token_id` is correctly defined")

In [18]:
column_names = dataset.column_names

In [19]:
decoding_type_schema = FactualitySchema(['Asserted', 'Other'])

In [20]:
padding = "max_length" if pad_to_max_length else False

In [21]:
def preprocess_function(examples):
    inputs = examples[text_column]
    targets = examples[summary_column]
    inputs = [prefix + inp for inp in inputs]
    model_inputs = tokenizer(
        inputs, max_length=1024, padding=padding, truncation=True, return_tensors="pt"
    )

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True, return_tensors="pt")

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length" and ignore_pad_token_for_loss:
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs



In [22]:
print(dataset[2])

{'factuality': '((Asserted pummeled) (Asserted retreat))', 'text': "That 's because coalition fighter jets pummeled this Iraqi position on the hills above Chamchamal and Iraqi troops made a hasty retreat"}


In [23]:
train_dataset = dataset

In [24]:
train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=train_dataset.column_names
        )

Map:   0%|          | 0/3246 [00:00<?, ? examples/s]



In [25]:
train_dataset.set_format("pt", columns=["input_ids", "labels", "attention_mask"], output_all_columns=True)

In [26]:
label_pad_token_id = - \
    100 if True else tokenizer.pad_token_id

In [27]:
data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if True else None,
)

In [28]:
from transformers import AutoTokenizer
from typing import Dict

def get_label_name_tree(label_name_list, tokenizer, end_symbol='<end>'):
    sub_token_tree = dict()

    label_tree = dict()
    for typename in label_name_list:
        after_tokenized = tokenizer.encode(typename, add_special_tokens=False)
        label_tree[typename] = after_tokenized

    for _, sub_label_seq in label_tree.items():
        parent = sub_token_tree
        for value in sub_label_seq:
            if value not in parent:
                parent[value] = dict()
            parent = parent[value]

        parent[end_symbol] = None  # Mark end of valid sequence

    return sub_token_tree


class PrefixTree:
    def __init__(self, label_name_list, tokenizer, end_symbol='<end>'):
        self.label_name_list = label_name_list
        self._tokenizer = tokenizer
        self.label_name_tree = get_label_name_tree(label_name_list, tokenizer, end_symbol)
        self._end_symbol = end_symbol

    def is_end_of_tree(self, tree: Dict):
        return len(tree) == 1 and self._end_symbol in tree


if __name__ == "__main__":
    factuality_labels = ["Asserted", "NotAsserted"]  # Replace event types with factuality labels

    test_tokenizer = AutoTokenizer.from_pretrained('t5-small')

    factuality_tree = get_label_name_tree(factuality_labels, test_tokenizer)

    # Function to print the tree structure (if needed)
    def print_tree(tree, indent=0):
        for key, value in tree.items():
            print("  " * indent + str(key))
            if isinstance(value, dict):
                print_tree(value, indent + 1)

    print_tree(factuality_tree)  # Visualize the tree

282
  7
    49
      1054
        <end>
933
  188
    7
      7
        49
          1054
            <end>


In [29]:

from typing import Union, List, Callable, Dict, Tuple, Any, Optional

def match_sublist(the_list, to_match):
    """

    :param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5]
    :param to_match: [1, 2]
    :return:
        [(0, 1), (6, 7)]
    """
    len_to_match = len(to_match)
    matched_list = list()
    for index in range(len(the_list) - len_to_match + 1):
        if to_match == the_list[index:index + len_to_match]:
            matched_list += [(index, index + len_to_match - 1)]
    return matched_list


def find_bracket_position(generated_text, _type_start, _type_end):
    bracket_position = {_type_start: list(), _type_end: list()}
    for index, char in enumerate(generated_text):
        if char in bracket_position:
            bracket_position[char] += [index]
    return bracket_position


def generated_search_src_sequence(generated, src_sequence, end_sequence_search_tokens=None):

    if len(generated) == 0:
        # It has not been generated yet. All SRC are valid.
        return src_sequence

    matched_tuples = match_sublist(the_list=src_sequence, to_match=generated)

    valid_token = list()
    for _, end in matched_tuples:
        next_index = end + 1
        if next_index < len(src_sequence):
            valid_token += [src_sequence[next_index]]

    if end_sequence_search_tokens:
        valid_token += end_sequence_search_tokens

    return valid_token


class ConstraintDecoder:
    def __init__(self, tokenizer, source_prefix):
        self.tokenizer = tokenizer
        self.source_prefix = source_prefix
        self.source_prefix_tokenized = tokenizer.encode(source_prefix,
                                                        add_special_tokens=False) if source_prefix else []

    def get_state_valid_tokens(self, src_sentence: List[str], tgt_generated: List[str]) -> List[str]:
        pass

    def constraint_decoding(self, src_sentence, tgt_generated):
        if self.source_prefix_tokenized:
            # Remove Source Prefix for Generation
            src_sentence = src_sentence[len(self.source_prefix_tokenized):]

        valid_token_ids = self.get_state_valid_tokens(
            src_sentence.tolist(),
            tgt_generated.tolist()
        )

        # return self.tokenizer.convert_tokens_to_ids(valid_tokens)
        return valid_token_ids



class TreeConstraintDecoder(ConstraintDecoder):
    def __init__(self, tokenizer, type_schema, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.tree_end = '<tree-end>'
        self.type_tree = get_label_name_tree(
            type_schema.type_list, self.tokenizer, end_symbol=self.tree_end)
        self.role_tree = get_label_name_tree(
            type_schema.role_list, self.tokenizer, end_symbol=self.tree_end)
        self.type_start = self.tokenizer.convert_tokens_to_ids(["("])[0]
        self.type_end = self.tokenizer.convert_tokens_to_ids([")"])[0]

    def check_state(self, tgt_generated):
        if tgt_generated[-1] == self.tokenizer.pad_token_id:
            return 'start', -1

        special_token_set = {self.type_start, self.type_end}
        special_index_token = list(
            filter(lambda x: x[1] in special_token_set, list(enumerate(tgt_generated))))

        last_special_index, last_special_token = special_index_token[-1]

        if len(special_index_token) == 1:
            if last_special_token != self.type_start:
                return 'error', 0

        bracket_position = find_bracket_position(
            tgt_generated, _type_start=self.type_start, _type_end=self.type_end)
        start_number, end_number = len(bracket_position[self.type_start]), len(
            bracket_position[self.type_end])

        if start_number == end_number:
            return 'end_generate', -1
        if start_number == end_number + 1:
            state = 'start_first_generation'
        elif start_number == end_number + 2:
            state = 'generate_trigger'
        elif start_number == end_number + 3:
            state = 'generate_role'
        else:
            state = 'error'
        return state, last_special_index

    def search_prefix_tree_and_sequence(self, generated: List[str], prefix_tree: Dict, src_sentence: List[str],
                                        end_sequence_search_tokens: List[str] = None):
        """
        Generate Type Name + Text Span
        :param generated:
        :param prefix_tree:
        :param src_sentence:
        :param end_sequence_search_tokens:
        :return:
        """
        tree = prefix_tree
        for index, token in enumerate(generated):
            tree = tree[token]
            is_tree_end = len(tree) == 1 and self.tree_end in tree

            if is_tree_end:
                valid_token = generated_search_src_sequence(
                    generated=generated[index + 1:],
                    src_sequence=src_sentence,
                    end_sequence_search_tokens=end_sequence_search_tokens,
                )
                return valid_token

            if self.tree_end in tree:
                try:
                    valid_token = generated_search_src_sequence(
                        generated=generated[index + 1:],
                        src_sequence=src_sentence,
                        end_sequence_search_tokens=end_sequence_search_tokens,
                    )
                    return valid_token
                except IndexError:
                    # Still search tree
                    continue

        valid_token = list(tree.keys())
        return valid_token

    def get_state_valid_tokens(self, src_sentence, tgt_generated):
        """

        :param src_sentence:
        :param tgt_generated:
        :return:
            List[str], valid token list
        """
        if self.tokenizer.eos_token_id in src_sentence:
            src_sentence = src_sentence[:src_sentence.index(
                self.tokenizer.eos_token_id)]

        state, index = self.check_state(tgt_generated)


        if state == 'error':
            print("Error:")
            print("Src:", src_sentence)
            print("Tgt:", tgt_generated)
            valid_tokens = [self.tokenizer.eos_token_id]

        elif state == 'start':
            valid_tokens = [self.type_start]

        elif state == 'start_first_generation':
            valid_tokens = [self.type_start, self.type_end]

        elif state == 'generate_trigger':

            if tgt_generated[-1] == self.type_start:
                # Start Event Label
                return list(self.type_tree.keys())

            elif tgt_generated[-1] == self.type_end:
                # EVENT_TYPE_LEFT: Start a new role
                # EVENT_TYPE_RIGHT: End this event
                return [self.type_start, self.type_end]
            else:
                valid_tokens = self.search_prefix_tree_and_sequence(
                    generated=tgt_generated[index + 1:],
                    prefix_tree=self.type_tree,
                    src_sentence=src_sentence,
                    end_sequence_search_tokens=[self.type_start, self.type_end]
                )

        elif state == 'generate_role':

            if tgt_generated[-1] == self.type_start:
                # Start Role Label
                return list(self.role_tree.keys())

            generated = tgt_generated[index + 1:]
            valid_tokens = self.search_prefix_tree_and_sequence(
                generated=generated,
                prefix_tree=self.role_tree,
                src_sentence=src_sentence,
                end_sequence_search_tokens=[self.type_end]
            )

        elif state == 'end_generate':
            valid_tokens = [self.tokenizer.eos_token_id]

        else:
            raise NotImplementedError(
                'State `%s` for %s is not implemented.' % (state, self.__class__))

        return valid_tokens


In [30]:
def get_constraint_decoder(tokenizer, type_schema, decoding_schema, source_prefix=None):
    return TreeConstraintDecoder(tokenizer=tokenizer, type_schema=type_schema, source_prefix=source_prefix)

In [31]:
from transformers import Seq2SeqTrainingArguments
from dataclasses import dataclass, field

@dataclass
class ConstraintSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
    """
    Parameters:
        constraint_decoding (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use Constraint Decoding
        structure_weight (:obj:`float`, `optional`, defaults to :obj:`None`):
    """
    constraint_decoding: bool = field(default=False, metadata={"help": "Whether to Constraint Decoding or not."})
    label_smoothing_sum: bool = field(default=False,
                                      metadata={"help": "Whether to use sum token loss for label smoothing"})


In [32]:

training_args = ConstraintSeq2SeqTrainingArguments(
        output_dir=OUTPUT_DIR,
        learning_rate=5e-5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        constraint_decoding=True,
        num_train_epochs=10,
        predict_with_generate=True,
        label_smoothing_factor=0.1,  # Correct parameter
        eval_strategy="epoch",
        save_strategy="no",
)

In [33]:
from transformers import (
    PreTrainedTokenizer,
    EvalPrediction,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    default_data_collator,
)
import torch.nn as nn
from typing import Union, List, Callable, Dict, Tuple, Any, Optional


class ConstraintSeq2SeqTrainer(Seq2SeqTrainer):
    def __init__(self, decoding_type_schema=None, decoding_format='tree', source_prefix=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.decoding_format = decoding_format
        self.decoding_type_schema = decoding_type_schema
        print(self.decoding_format)
        print(self.decoding_type_schema)

        # Label smoothing by sum token loss, different from different Label smootheing

        if self.args.label_smoothing_sum and self.args.label_smoothing_factor != 0:
            self.label_smoother = SumLabelSmoother(epsilon=self.args.label_smoothing_factor)
            print('Using %s' % self.label_smoother)
        elif self.args.label_smoothing_factor != 0:
            print('Using %s' % self.label_smoother)
        else:
            self.label_smoother = None

        print(self.label_smoother)
        if self.args.constraint_decoding:
            self.constraint_decoder = get_constraint_decoder(tokenizer=self.tokenizer,
                                                             type_schema=self.decoding_type_schema,
                                                             decoding_schema=self.decoding_format,
                                                             source_prefix=source_prefix)
            print(self.constraint_decoder)
        else:
            self.constraint_decoder = None
        print("Trainer initialized! Training will use constraint decoding?", self.args.constraint_decoding)

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        def prefix_allowed_tokens_fn(batch_id, sent):
            # print(self.tokenizer.convert_ids_to_tokens(inputs['labels'][batch_id]))
            src_sentence = inputs['input_ids'][batch_id]
            return self.constraint_decoder.constraint_decoding(src_sentence=src_sentence,
                                                               tgt_generated=sent)

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model=model,
                inputs=inputs,
                prediction_loss_only=prediction_loss_only,
                ignore_keys=ignore_keys,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn if self.constraint_decoder else None,
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        gen_kwargs = {
            "max_length": self.model.config.max_length,
            "num_beams": self.model.config.num_beams,
            "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn if self.constraint_decoder else None,
        }

        generated_tokens = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )

        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                if self.label_smoother is not None:
                    loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
                else:
                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return loss, None, None

        labels = inputs["labels"]
        if labels.shape[-1] < gen_kwargs["max_length"]:
            labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])

        return loss, generated_tokens, labels



In [52]:
trainer = ConstraintSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset.select(range(200)),
    eval_dataset=train_dataset.select(range(50)),
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    decoding_type_schema=decoding_type_schema,
    decoding_format='tree',
    source_prefix=prefix,
)



  super().__init__(*args, **kwargs)
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


tree
<__main__.FactualitySchema object at 0x75873977caf0>
Using LabelSmoother(epsilon=0.1, ignore_index=-100)
LabelSmoother(epsilon=0.1, ignore_index=-100)
<__main__.TreeConstraintDecoder object at 0x7586042877c0>
Trainer initialized! Training will use constraint decoding? True


In [53]:
trainer.train()

Epoch,Training Loss,Validation Loss,Trigger Precision,Trigger Recall,Trigger F1,Discovered Other Recall,Discovered Other Precision,Discovered Other F1,Discovered Asserted Recall,Discovered Asserted Precision,Discovered Asserted F1,Asserted Precision,Asserted Recall,Asserted F1,Other Precision,Other Recall,Other F1
1,No log,1.585396,0.757576,0.769231,0.763359,0.0,0.0,0.0,1.0,0.76,0.863636,0.575758,0.745098,0.649573,0.0,0.0,0.0
2,No log,1.585131,0.666667,0.892308,0.763158,0.0,0.0,0.0,1.0,0.758621,0.862745,0.505747,0.862745,0.637681,0.0,0.0,0.0
3,No log,1.578724,1.0,0.830769,0.907563,0.0,0.0,0.0,1.0,0.777778,0.875,0.777778,0.823529,0.8,0.0,0.0,0.0
4,No log,1.558194,1.0,0.846154,0.916667,0.0,0.0,0.0,1.0,0.781818,0.877551,0.781818,0.843137,0.811321,0.0,0.0,0.0
5,1.643700,1.556314,0.797297,0.907692,0.848921,0.0,0.0,0.0,1.0,0.779661,0.87619,0.621622,0.901961,0.736,0.0,0.0,0.0
6,1.643700,1.558743,0.701149,0.938462,0.802632,0.0,0.0,0.0,1.0,0.786885,0.880734,0.551724,0.941176,0.695652,0.0,0.0,0.0
7,1.643700,1.550248,0.701149,0.938462,0.802632,0.0,0.0,0.0,1.0,0.786885,0.880734,0.551724,0.941176,0.695652,0.0,0.0,0.0
8,1.643700,1.543612,0.819444,0.907692,0.861314,0.0,0.0,0.0,1.0,0.779661,0.87619,0.638889,0.901961,0.747967,0.0,0.0,0.0
9,1.643700,1.547158,0.808219,0.907692,0.855072,0.0,0.0,0.0,1.0,0.779661,0.87619,0.630137,0.901961,0.741935,0.0,0.0,0.0
10,1.644500,1.5483,0.810811,0.923077,0.863309,0.0,0.0,0.0,1.0,0.783333,0.878505,0.635135,0.921569,0.752,0.0,0.0,0.0


TrainOutput(global_step=1000, training_loss=1.6441085205078125, metrics={'train_runtime': 296.9471, 'train_samples_per_second': 6.735, 'train_steps_per_second': 3.368, 'total_flos': 541367205888000.0, 'train_loss': 1.6441085205078125, 'epoch': 10.0})

In [51]:
import re

def parse_sample(sample):
    pairs = re.findall(r'\(?\s*(\w+)\s+((?:\w+\s*)+?)(?=\s*\)?(?:\s*\(|\)\))|\s*\)\))', sample)
    return [(typ, content.strip()) for typ, content in pairs]

def get_word_fact(parsed, is_true):
    word_fact = {}
    for sample in parsed:
        factuality = sample[0]
        word = sample[1]
        if is_true:
            polarity = sample[2]
            genericity = sample[3]
            e_type = sample[4]
            text = sample[5]
        if word in word_fact:
            if is_true:
                word_fact[word].append((factuality, polarity, genericity, e_type, text))
            else:
                word_fact[word].append((factuality,))
        else:
            if is_true:
                word_fact[word] = [(factuality, polarity, genericity, e_type, text)]
            else:
                word_fact[word] = [(factuality,)]
    return word_fact
    
eval_set = train
df = None

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(
        preds, skip_special_tokens=False)
    if True:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(
        labels, skip_special_tokens=False)

    def clean_str(x_str):
        for to_remove_token in to_remove_token_list:
            x_str = x_str.replace(to_remove_token, '')
        return x_str.strip()

    decoded_preds = [clean_str(x) for x in decoded_preds]
    decoded_labels = [clean_str(x) for x in decoded_labels]

    parsed_pred = list(map(parse_sample, decoded_preds))
    parsed_true = []
    for sample in eval_set[:len(parsed_pred)]:
        parsed_true.append([(event['event_modality'], 
                             event['trigger'][0][0], 
                             event['event_polarity'], 
                             event['event_genericity'], 
                             event['event_type'],
                     sample['text']) for event in sample["events"]])
        
    trues = [get_word_fact(sample, True) for sample in parsed_true]
    preds = [get_word_fact(sample, False) for sample in parsed_pred]
    spurious = []
    undiscovered = []
    discovered = []
    for true_wf, pred_wf in zip(trues, preds):
        for key, value in pred_wf.items():
            if key not in true_wf:
                spurious += [
                    {"true": "Other" if el[0] == 'Asserted' else "Asserted",
                     "pred": el[0],
                     "trigger": key,
                     "label": "spurious"
                    }
                    for el in value
                ]
            elif len(value) > len(true_wf[key]):
                # Key is in true, but length is longer for pred, we then have more spurious events
                spurious += [
                    {"true": "Other" if el[0] == 'Asserted' else "Asserted",
                     "pred": el[0],
                     "trigger": key,
                     "label": "spurious"
                    }
                    for el in value[len(true_wf[key]):]
                ]
            else:
                #key is in true, but length is shorter than for true, we have undiscovered events
                undiscovered += [
                    {"true": el[0],
                     "pred": "Other" if el[0] == 'Asserted' else "Asserted",
                     "trigger": key,
                     "polarity": el[1],
                     "genericity": el[2],
                     "type": el[3],
                     "text": el[4],
                     "label": "undiscovered"
                    }
                     for el in true_wf[key][len(value):]
                ]
            
        for key, value in true_wf.items():
            if key in pred_wf:
                discovered += [
                    {"true": t[0],
                     "pred": p[0],
                     "polarity": t[1],
                     "genericity": t[2],
                     "type": t[3],
                     "trigger": key,
                     "label": "discovered",
                     "text": t[4]}
                    for t, p in zip(value, pred_wf[key])
                ]
            else:
                undiscovered += [
                     {"true": t[0],
                      "pred": "Other" if t[0] == 'Asserted' else "Asserted",
                      "polarity": t[1],
                      "genericity": t[2],
                      "type": t[3],
                      "trigger": key,
                      "label": "undiscovered",
                      "text": t[4]}
                    for t in value
                ]
    df = pd.DataFrame(spurious + discovered + undiscovered)
    trigger_fp = len(spurious)
    trigger_fn = len(undiscovered)
    trigger_tp = len(discovered)
    trigger_precision = trigger_tp / (trigger_tp + trigger_fp)
    trigger_recall = trigger_tp / (trigger_tp + trigger_fn)
    trigger_f1 = (2*trigger_precision*trigger_recall) / (trigger_precision + trigger_recall)

    discovered_other_fp = len(df[(df['label'] == 'discovered') & (df['true'] == 'Asserted') & (df['pred'] == 'Other')])
    discovered_other_fn = len(df[(df['label'] == 'discovered') & (df['true'] == 'Other') & (df['pred'] == 'Asserted')])
    discovered_other_tp = len(df[(df['label'] == 'discovered') & (df['true'] == 'Other') & (df['pred'] == 'Other')])
    
    discovered_asserted_fp = len(df[(df['label'] == 'discovered') & (df['true'] == 'Other') & (df['pred'] == 'Asserted')])
    discovered_asserted_fn = len(df[(df['label'] == 'discovered') & (df['true'] == 'Asserted') & (df['pred'] == 'Other')])
    discovered_asserted_tp = len(df[(df['label'] == 'discovered') & (df['true'] == 'Asserted') & (df['pred'] == 'Asserted')])

    discovered_other_precision = discovered_other_tp / (discovered_other_tp + discovered_other_fp) if discovered_other_tp + discovered_other_fp > 0 else 0.0
    discovered_other_recall = discovered_other_tp / (discovered_other_tp + discovered_other_fn) if discovered_other_tp + discovered_other_fn > 0 else 0.0
    discovered_other_f1 = (2*discovered_other_precision * discovered_other_recall)/(discovered_other_recall + discovered_other_precision) if (discovered_other_recall + discovered_other_precision) else 0.0
    
    discovered_asserted_precision = discovered_asserted_tp / (discovered_asserted_tp + discovered_asserted_fp) if (discovered_asserted_tp + discovered_asserted_fp) > 0 else 0.0
    discovered_asserted_recall = discovered_asserted_tp / (discovered_asserted_tp + discovered_asserted_fn) if (discovered_asserted_tp + discovered_asserted_fn) > 0 else 0.0
    discovered_asserted_f1 = (2*discovered_asserted_precision * discovered_asserted_recall)/(discovered_asserted_recall + discovered_asserted_precision) if (discovered_asserted_recall + discovered_asserted_precision) > 0 else 0.0

    asserted_fp = len(df[(df['label'] == 'spurious') & (df['pred'] == 'Asserted')])
    asserted_fn = len(df[(df['label'] == 'undiscovered') & (df['true'] == 'Asserted')])
    other_fp = len(df[(df['label'] == 'spurious') & (df['pred'] == 'Other')])
    other_fn = len(df[(df['label'] == 'undiscovered') & (df['true'] == 'Other')])

    tot_as_fp = asserted_fp + discovered_asserted_fp
    tot_as_fn = asserted_fn + discovered_asserted_fn
    asserted_precision = discovered_asserted_tp / (discovered_asserted_tp + tot_as_fp) if (discovered_asserted_tp + tot_as_fp) > 0 else 0.0
    asserted_recall = discovered_asserted_tp / (discovered_asserted_tp + tot_as_fn) if (discovered_asserted_tp + tot_as_fn) > 0 else 0.0
    asserted_f1 = (2*asserted_precision * asserted_recall) / (asserted_precision + asserted_recall) if (asserted_precision + asserted_recall) > 0 else 0.0
    
    tot_ot_fp = other_fp + discovered_other_fp
    tot_ot_fn = other_fn + discovered_other_fn
    other_precision = discovered_other_tp / (discovered_other_tp + tot_ot_fp) if (discovered_other_tp + tot_ot_fp) > 0 else 0.0
    other_recall = discovered_other_tp / (discovered_other_tp + tot_ot_fn) if (discovered_other_tp + tot_ot_fn) > 0 else 0.0
    other_f1 = (2*other_precision * other_recall) / (other_precision + other_recall) if (other_precision + other_recall) > 0 else 0.0

    df.to_csv('../results/text2event_results.csv')
    return {
        "trigger_precision": trigger_precision,
        "trigger_recall": trigger_recall,
        "trigger_f1": trigger_f1,
        "discovered_other_recall": discovered_other_recall,
        "discovered_other_precision": discovered_other_precision,
        "discovered_other_f1": discovered_other_f1,
        "discovered_asserted_recall": discovered_asserted_recall,
        "discovered_asserted_precision": discovered_asserted_precision,
        "discovered_asserted_f1": discovered_asserted_f1,
        "asserted_precision": asserted_precision,
        "asserted_recall": asserted_recall,
        "asserted_f1": asserted_f1,
        "other_precision": other_precision,
        "other_recall": other_recall,
        "other_f1": other_f1
    }