In [22]:
import json
import ast
from transformers import AutoTokenizer
from copy import deepcopy

In [47]:
def extract_exs(in_lines):
    buffer = []
    c = 0
    while c < len(in_lines):
        line = in_lines[c]
        if line.startswith('id'):
            docid = line[3:].strip()
            next_line = in_lines[c+1]
            trigger_info = ast.literal_eval(next_line[9:])
            buffer.append([docid, trigger_info])
        
        c += 1

    buffers = {tup[0] : {'outputs': tup[1]} for tup in buffer}

    return buffers

def align_exs(ref_tanls, ref_ogs):
    buffers = []
    for tanl_ex in ref_tanls:
        ref_og = list(filter(lambda info : info['docid'] == tanl_ex['id'], ref_ogs.values()))[0]
        buffers.append((tanl_ex, ref_og))
    
    return buffers

tokenizer = AutoTokenizer.from_pretrained('t5-base')

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [68]:
def span_in_offsets(coref_span, offsets):
    coref_span, start_ind = coref_span
    return any(tup[0] <= start_ind <= tup[1] for tup in offsets) and any(tup[0] <= start_ind + len(coref_span) <= tup[1] for tup in offsets)

for dataset in ['MUC']:
    for trigger_source in ['human']:
        for split in ['train', 'dev', 'test']:
            with open(f'datasets/{dataset}/{trigger_source}/{split}.json', 'r') as f:
                gtt_refs = json.loads(f.read())
            
            for ex in gtt_refs.values():
                offsets = tokenizer(
                    ex['doctext'],
                    truncation=True,
                    max_length=512,
                    return_tensors='pt',
                    return_offsets_mapping=True
                )['offset_mapping'][0]

                truncated_templates = []
                for og_template in ex['templates']:
                    include_template = len(og_template['Triggers']) == 0
                    for trigger_corefs in og_template['Triggers']:
                        if any(span_in_offsets(trigger_tup, offsets) for trigger_tup in trigger_corefs):
                            include_template = True

                    if include_template:
                        template_copy = {
                            'incident_type': og_template['incident_type']
                        }
                        for role, entity_lst in og_template.items():
                            if role != 'incident_type':
                                entities = []
                                for coref_lst in entity_lst:
                                    filtered_coref_lst = []
                                    for coref_span, start_ind in coref_lst:
                                        if span_in_offsets([coref_span, start_ind], offsets):
                                            filtered_coref_lst.append([coref_span, start_ind])
                                    
                                    if len(filtered_coref_lst):
                                        entities.append(filtered_coref_lst)
                                
                                template_copy[role] = entities
                        truncated_templates.append(template_copy)

                ex['templates'] = truncated_templates
            
            with open(f'datasets/{dataset}/{trigger_source}/{split}_trimmed.json', 'w') as f:
                f.write(json.dumps(gtt_refs))