In [1]:
import json
import os
from itertools import islice
import math

In [2]:
def num_templates_split(dataset):
    """
    Group document ids by how many templates each document has.

    Args:
        dataset: mapping of document id -> example dict with a 'templates' list.

    Returns:
        dict {number of templates: [document ids, in dataset iteration order]}.
    """
    buckets = {}
    for doc_id, example in dataset.items():
        buckets.setdefault(len(example['templates']), []).append(doc_id)
    return buckets

def doc_len_split(dataset):
    """
    Group document ids by the length of their 'doctext' field.

    Args:
        dataset: mapping of document id -> example dict with a 'doctext' entry
            (presumably the raw document string, so length is in characters — confirm).

    Returns:
        dict {len(doctext): [document ids, in dataset iteration order]}.
    """
    by_length = {}
    for doc_id, example in dataset.items():
        length = len(example['doctext'])
        if length in by_length:
            by_length[length].append(doc_id)
        else:
            by_length[length] = [doc_id]
    return by_length

def argument_spread_split(dataset):
    """
    Group templates by how spread out their argument mentions are.

    "Spread" is the population variance of the second element of every mention
    (presumably a document offset — confirm) across all roles of a template except
    'incident_type' and 'Triggers'. Templates with no argument mentions are omitted.

    Args:
        dataset: mapping of document id -> example dict with a 'templates' list; each
            template maps role -> list of entities, each entity a list of mentions.

    Returns:
        dict {variance: [[document id, template index], ...]}.
    """
    excluded_roles = ('incident_type', 'Triggers')
    spreads = {}
    for doc_id, example in dataset.items():
        for template_ind, template in enumerate(example['templates']):
            offsets = [
                mention[1]
                for role, entities in template.items()
                if role not in excluded_roles
                for coref_mentions in entities
                for mention in coref_mentions
            ]
            if not offsets:
                continue

            mean = sum(offsets) / len(offsets)
            variance = sum((offset - mean) ** 2 for offset in offsets) / len(offsets)
            spreads.setdefault(variance, []).append([doc_id, template_ind])
    return spreads

def template_ordering_split(dataset):
    """
    Group templates by their rank within their document.

    Within each document, templates are ranked by the mean of the second element of
    their argument mentions (roles other than 'incident_type' and 'Triggers'). A
    template with no argument mentions is positioned at len(doctext), which puts it
    after the others assuming mention offsets are < len(doctext) — confirm. Ties keep
    template-index order because sorting is stable.

    Args:
        dataset: mapping of document id -> example dict with 'templates' (and
            'doctext', read only when a template has no arguments).

    Returns:
        dict {rank: [[document id, template index], ...]}.
    """
    excluded_roles = ('incident_type', 'Triggers')
    orderings = {}
    for doc_id, example in dataset.items():
        positioned = []
        for template_ind, template in enumerate(example['templates']):
            offsets = [
                mention[1]
                for role, entities in template.items()
                if role not in excluded_roles
                for coref_mentions in entities
                for mention in coref_mentions
            ]
            if offsets:
                position = sum(offsets) / len(offsets)
            else:
                position = len(example['doctext'])
            positioned.append((template_ind, position))

        ranked = sorted(positioned, key=lambda pair: pair[1])
        for rank, (template_ind, _position) in enumerate(ranked):
            orderings.setdefault(rank, []).append([doc_id, template_ind])
    return orderings

def num_fillers_split(dataset):
    """
    Group templates by their number of role fillers.

    The filler count is the total length of every role's list, skipping
    'incident_type' and 'Triggers'.

    Args:
        dataset: mapping of document id -> example dict with a 'templates' list.

    Returns:
        dict {number of fillers: [[document id, template index], ...]}.
    """
    excluded_roles = ('incident_type', 'Triggers')
    fillers_map = {}
    for doc_id, example in dataset.items():
        for template_ind, template in enumerate(example['templates']):
            count = 0
            for role, fillers in template.items():
                if role in excluded_roles:
                    continue
                count += len(fillers)
            fillers_map.setdefault(count, []).append([doc_id, template_ind])
    return fillers_map

def num_entities_split(dataset):
    """
    Group templates by their number of distinct entities.

    An entity is a coreference list of mentions; two lists count as the same entity
    when they contain the same mentions once sorted by each mention's second element.
    Entities are deduplicated across all roles except 'incident_type' and 'Triggers'.

    Args:
        dataset: mapping of document id -> example dict with a 'templates' list.

    Returns:
        dict {number of distinct entities: [[document id, template index], ...]}.
    """
    excluded_roles = ('incident_type', 'Triggers')
    entities_map = {}
    for doc_id, example in dataset.items():
        for template_ind, template in enumerate(example['templates']):
            # JSON-encode the sorted mention list so it can be stored in a set.
            distinct = {
                json.dumps(sorted(coref_lst, key=lambda mention: mention[1]))
                for role, entity_lst in template.items()
                if role not in excluded_roles
                for coref_lst in entity_lst
            }
            entities_map.setdefault(len(distinct), []).append([doc_id, template_ind])
    return entities_map

def group_keys(input_dict, num_groups):
    """
    Merge the entries of a split mapping into at most ``num_groups`` buckets.

    Keys are sorted and partitioned into consecutive runs of
    ``ceil(len(input_dict) / num_groups)`` keys. Buckets therefore hold equal numbers of
    distinct keys, not equal numbers of items. Each bucket is labelled
    ``"<first key>-<last key>"`` and holds the concatenation, in key order, of the value
    lists of its keys.

    Args:
        input_dict: mapping from mutually sortable keys to lists (e.g. the output of the
            *_split functions in this notebook).
        num_groups: maximum number of buckets; must be positive.

    Returns:
        dict mapping range labels to merged lists; empty when ``input_dict`` is empty.

    Raises:
        ValueError: if ``num_groups`` is not positive.
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be positive, got {num_groups}")
    # An empty mapping would give a group size of 0, and range() rejects a zero step.
    if not input_dict:
        return {}

    sorted_keys = sorted(input_dict)
    group_size = math.ceil(len(sorted_keys) / num_groups)

    grouped_dict = {}
    for start in range(0, len(sorted_keys), group_size):
        current_group = sorted_keys[start:start + group_size]
        range_key = f"{current_group[0]}-{current_group[-1]}"
        grouped_dict[range_key] = [value for key in current_group for value in input_dict[key]]
    return grouped_dict

In [6]:
# Generate raw and bucketed splits for each dataset. Reads <source dir>/human/{split}.json
# and writes splits/raw/<dataset>/<split_name>/{split}.json plus the grouped version
# under splits/bucketed/.
dirs = ['WikiEvent/event_roles_in_template', 'MUC']
num_groups = 5
dataset_splits = ['train', 'dev', 'test']

# NOTE(review): num_fillers_split is defined above but not included here — confirm
# whether that omission is intentional.
splitters = {
    'num_templates': num_templates_split,
    'doc_len': doc_len_split,
    'argument_spread': argument_spread_split,
    'template_ordering': template_ordering_split,
    'num_entities': num_entities_split,
}

# `source_dir` (not `dir`) so the builtin dir() is not shadowed for the rest of the notebook.
for source_dir in dirs:
    dataset_name = 'MUC' if source_dir == 'MUC' else 'WikiEvent'
    raw_root = os.path.join('splits', 'raw', dataset_name)
    bucketed_root = os.path.join('splits', 'bucketed', dataset_name)

    # Load each split once; the splitters only read the data, so it can be shared.
    datasets = {}
    for dataset_split in dataset_splits:
        with open(os.path.join(source_dir, 'human', f'{dataset_split}.json'), 'r') as f:
            datasets[dataset_split] = json.load(f)

    for split_name, splitter in splitters.items():
        output_dir_raw = os.path.join(raw_root, split_name)
        output_dir_bucketed = os.path.join(bucketed_root, split_name)
        # makedirs(exist_ok=True) creates missing parents and keeps the cell re-runnable.
        os.makedirs(output_dir_raw, exist_ok=True)
        os.makedirs(output_dir_bucketed, exist_ok=True)

        for dataset_split, dataset in datasets.items():
            splits = splitter(dataset)
            with open(os.path.join(output_dir_raw, f'{dataset_split}.json'), 'w') as f:
                json.dump(splits, f)

            grouped_splits = group_keys(splits, num_groups)
            with open(os.path.join(output_dir_bucketed, f'{dataset_split}.json'), 'w') as f:
                json.dump(grouped_splits, f)

In [None]:
# `splits` is the last mapping computed by the generation loop above (num_entities_split
# on the MUC test split). Show the size of each bucket instead of dumping every
# [document id, template index] pair.
{bucket: len(members) for bucket, members in splits.items()}

In [59]:
dataset_dir = 'MUC'
trigger_sources = ['human', 'llm', 'keyword']

# Collect each template's JSON from every trigger source so the next cell can check that
# the sources agree on everything except 'Triggers'. The dict is built once across all
# splits (the split is part of the key), so the check covers train, dev and test rather
# than only the last split processed.
aggregates = {}
for split in ['train', 'dev', 'test']:
    for trigger_source in trigger_sources:
        with open(f'{dataset_dir}/{trigger_source}/{split}.json') as f:
            examples = json.load(f)

        for doc_id, info in examples.items():
            for template_id, template in enumerate(info['templates']):
                # Exclude 'Triggers' without mutating the loaded data and without failing if
                # it is absent. Sort keys so that role ordering in the files cannot cause
                # false mismatches.
                comparable = {role: value for role, value in template.items() if role != 'Triggers'}
                key = (split, doc_id, template_id)
                aggregates.setdefault(key, []).append(json.dumps(comparable, sort_keys=True))

In [60]:
# Each aggregated template must have been collected once from each of the three trigger
# sources (human, llm, keyword) and must be identical across them.
expected_sources = 3
for key, aggregate in aggregates.items():
    assert len(aggregate) == expected_sources, (
        f"{key}: found in {len(aggregate)} of {expected_sources} trigger sources"
    )
    assert len(set(aggregate)) == 1, f"{key}: template contents differ across trigger sources"