In [1]:
import syft as sy
from syft.store.blob_storage import BlobStorageConfig, BlobStorageClientConfig
from syft.store.blob_storage.seaweedfs import SeaweedFSClient, SeaweedFSClientConfig
from syft import ActionObject
from syft.service.action.action_data_empty import ActionFileData
from syft.service.queue.zmq_queue import ZMQQueueConfig, ZMQClientConfig
from collections import defaultdict



In [2]:
node = sy.orchestra.launch(name="test-domain-helm2", dev_mode=True,
                           reset=True,
                           n_consumers=4,
                           create_producer=True)
client = node.login(email="info@openmined.org", password="changethis")

Staging Protocol Changes...
Logged into <test-domain-helm2: High side Domain> as <info@openmined.org>


```bash
docker run --entrypoint /bin/sh -p 8333:8333 -p 8888:8888 chrislusf/seaweedfs -c "echo 's3.configure -access_key admin -secret_key admin -user iam -actions Read,Write,List,Tagging,Admin -apply' | weed shell > /dev/null 2>&1 & weed server -s3 -s3.port=8333 -master.volumeSizeLimitMB=2048"
```

In [3]:
blob_config = BlobStorageConfig(client_type=SeaweedFSClient,
                                client_config=SeaweedFSClientConfig(host="http://0.0.0.0",
                                                                    port="8333",
                                                                    access_key="admin",
                                                                    secret_key="admin",
                                                                    default_bucket_name="test_bucket",
                                                                    region="us-east-1")
)

In [4]:
node.python_node.init_blob_storage(blob_config)

# Inputs

In [5]:
# TODO: fix way we send list of files
scenario_obj = ActionObject.from_obj([
    sy.ActionObject.from_path(path="short_new_scenario.jsonl").send(client).syft_action_data for i in range(1)])

scenario_files_ptr = scenario_obj.send(client)

# scenario_obj = ActionObject.from_obj([
#     sy.ActionObject.from_path(path="scenario_data.jsonl").send(client).syft_action_data for i in range(2)])

# scenario_files_ptr = scenario_obj.send(client)

In [6]:
# TODO: fix way we send list of files
input_obj = ActionObject.from_obj([
    sy.ActionObject.from_path("short_input.jsonl").send(client).syft_action_data for i in range(1)])
input_files_ptr = input_obj.send(client)

In [7]:
# for line in input_files_ptr.syft_action_data[0].iter_lines():
#     print(line)

# Syft functions

In [8]:
@sy.syft_function()
def compute_document_data_overlap(domain, scenario_file, input_files, n):
    print("starting overlap computation")
    from nltk import ngrams
    from collections import defaultdict
    from string import punctuation
    import re, json

    r = re.compile(r"[\s{}]+".format(re.escape(punctuation)))
    
    def create_ngram_index(light_scenarios, n_values, stats_key_counts):
        ngram_index = {n:{}  for n in n_values}
        for scenario in light_scenarios:
            for n in n_values:
                stats_key = scenario['scenario_key'] + '_' + str(n)
                stats_key_counts[stats_key] = len(scenario['instances'])
                for instance in scenario['instances']:
                    id = instance['id']                    
                    input_tokens = r.split(instance['input'].lower())
                    for input_ngram in ngrams(input_tokens, n):
                        if input_ngram not in ngram_index[n]:
                            ngram_index[n][input_ngram] = set()
                        ngram_index[n][input_ngram].add(stats_key + '+' + id + '+' + 'input')

                    # compute reference ngrams
                    for reference in instance['references']:
                        reference_unigrams = r.split(reference.lower())
                        for reference_ngram in ngrams(reference_unigrams, n):
                            if reference_ngram not in ngram_index[n]:
                                ngram_index[n][reference_ngram] = set()
                            ngram_index[n][reference_ngram].add(stats_key + '+' + id + '+' + 'references')
        return ngram_index
    
    # # SETUP
    print("preparing scenarios and creating indexes")
    light_scenarios = []
    for light_scenario_json in scenario_file.iter_lines():
        light_scenario_dict: dict = json.loads(light_scenario_json)

        light_scenario_key_dict: dict = light_scenario_dict["scenario_key"]
        scenario_spec = str(light_scenario_key_dict["scenario_spec"])
        light_scenario_key = scenario_spec + '_' + light_scenario_key_dict["split"]
        light_instances = [
            {
                'input': instance_dict['input'], 
                'references': instance_dict['references'], 
                'id': instance_dict["id"]
            }
            for instance_dict in light_scenario_dict["instances"]
        ]
        light_scenarios.append({'scenario_key': light_scenario_key, 'instances': light_instances})
        
    stats_key_counts = defaultdict(int)
    
    ngram_index = create_ngram_index(
        light_scenarios=light_scenarios, n_values=[n], stats_key_counts=stats_key_counts
    )
    
    r = re.compile(r"[\s{}]+".format(re.escape(punctuation)))
    stats_key_to_input_ids = defaultdict(set)
    stats_key_to_reference_ids = defaultdict(set)
    entry_overlap_key_to_ngram_counts = {}
    print("computing overlap")
    
#     domain.init_progress(len(input_files))
    
    for input_file in input_files:
        for line in input_file.iter_lines():
            document = json.loads(line)["text"]
            document_tokens = r.split(document.lower())
            for n in ngram_index.keys():
                for document_ngram in ngrams(document_tokens, n):
                    if document_ngram in ngram_index[n]:
                        for entry_overlap_key in ngram_index[n][document_ngram]:
                            stats_key, id, part = entry_overlap_key.split("+")
                            if part == "input":
                                stats_key_to_input_ids[stats_key].add(id)
                            elif part == "references":
                                stats_key_to_reference_ids[stats_key].add(id)
                            if entry_overlap_key in entry_overlap_key_to_ngram_counts:
                                if document_ngram not in entry_overlap_key_to_ngram_counts[entry_overlap_key]:
                                    entry_overlap_key_to_ngram_counts[entry_overlap_key][document_ngram] = 0
                            else:
                                entry_overlap_key_to_ngram_counts[entry_overlap_key] = {}
                                entry_overlap_key_to_ngram_counts[entry_overlap_key][document_ngram] = 0
                            entry_overlap_key_to_ngram_counts[entry_overlap_key][document_ngram] += 1
        domain.update_progress(1)
    print("done")
    
    return stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts, entry_overlap_key_to_ngram_counts

In [9]:
client.code.submit(compute_document_data_overlap)

In [10]:
@sy.syft_function_single_use(input_files=input_files_ptr, scenario_files=scenario_files_ptr)
def main_function(domain, input_files, scenario_files):
    N = [5, 9, 13]
    jobs = []
    for n in N[:1]:
        for scenario_file in scenario_files:
            batch_job = domain.launch_job(
                compute_document_data_overlap,
                scenario_file=scenario_file,
                input_files=input_files,
                n=n
            )
            jobs.append(batch_job)

    return None


In [11]:
client.code.request_code_execution(main_function)
client.requests[-1].approve()

Request approved for domain test-domain-helm2


In [12]:
job = client.code.main_function(input_files=input_files_ptr, scenario_files=scenario_files_ptr, blocking=False)

# Get results

In [13]:
job

```python
class Job:
    id: UID = 5e2c95f2745749f59aa4a283d94e5a7d
    status: JobStatus.CREATED
    has_parent: False
    result: None
    logs:

0 
    
```

In [14]:
job.wait()

LAUNCHING JOB compute_document_data_overlap


FUNCTION LOG (6157a9c18e744c448adec26361244a3e): starting overlap computation
FUNCTION LOG (6157a9c18e744c448adec26361244a3e): preparing scenarios and creating indexes
FUNCTION LOG (6157a9c18e744c448adec26361244a3e): computing overlap
FUNCTION LOG (6157a9c18e744c448adec26361244a3e): done


Pointer:
None

In [15]:
job.subjobs[0]

```python
class Job:
    id: UID = 6157a9c18e744c448adec26361244a3e
    status: JobStatus.COMPLETED
    has_parent: True
    result: ActionDataEmpty UID: 68c00a1093504413bf7d2d84e6fe866b <None>
    logs:

0 starting overlap computation
1 preparing scenarios and creating indexes
2 computing overlap
3 done
JOB COMPLETED
    
```

In [16]:
job.subjobs[0].logs()

starting overlap computation
preparing scenarios and creating indexes
computing overlap
done



In [17]:
results = [j.wait().get() for j in job.subjobs]

# Aggregate

In [18]:
stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts, entry_overlap_key_to_ngram_counts = zip(*results)

total_input_ids = defaultdict(set)
total_reference_ids = defaultdict(set)
total_stats_key_counts = defaultdict(int)
total_entry_overlap_key_to_ngram_counts = {}

for d in stats_key_counts:
    for key, val in d.items():
        total_stats_key_counts[key] += val


for d in stats_key_to_input_ids:
    for key in d:
        new_set = set()
        if key in total_input_ids:
            new_set = total_input_ids[key]
        new_set = new_set.union(d[key])
        total_input_ids[key] = new_set

for d in stats_key_to_reference_ids:
    for key in d:
        new_set = set()
        if key in total_reference_ids:
            new_set = total_reference_ids[key]
        new_set = total_reference_ids[key].union(d[key])
        total_reference_ids[key] = new_set
        
        
for d in entry_overlap_key_to_ngram_counts:
    for entry_overlap_key in d:
        if entry_overlap_key not in total_entry_overlap_key_to_ngram_counts:
            total_entry_overlap_key_to_ngram_counts[entry_overlap_key] = {}
        for ngram in d[entry_overlap_key]:
            if ngram not in total_entry_overlap_key_to_ngram_counts[entry_overlap_key]:
                total_entry_overlap_key_to_ngram_counts[entry_overlap_key][ngram] = 0
            k = total_entry_overlap_key_to_ngram_counts[entry_overlap_key][ngram]
            # total_entry_overlap_key_to_ngram_counts[entry_overlap_key][ngram] = max(k,d[entry_overlap_key][ngram])
            total_entry_overlap_key_to_ngram_counts[entry_overlap_key][ngram] = k + d[entry_overlap_key][ngram]
        
import json

all_data_overlap_stats = []
for stats_key, count in total_stats_key_counts.items():
    data_overlap_stats = {
        'data_overlap_stats_key': None,
        'num_instances': count,
        'instance_ids_with_overlapping_input': sorted(total_input_ids[stats_key]),
        'instance_ids_with_overlapping_reference': sorted(total_reference_ids[stats_key]),
    }
    scenario_spec, split, n_str = stats_key.rsplit('_', 2)

    scenario_spec = eval(scenario_spec)
    data_overlap_stats['data_overlap_stats_key'] = {
        'light_scenario_key': {'scenario_spec': scenario_spec, 'split': split},
        'overlap_protocol_spec': {'n': int(n_str)}
    }
    all_data_overlap_stats.append(data_overlap_stats)




In [19]:
from dataclasses import asdict, is_dataclass

def asdict_without_nones(obj):
    if not is_dataclass(obj):
        raise ValueError(f"Expected dataclass, got '{obj}'")
    return asdict(obj, dict_factory=lambda x: {k: v for (k, v) in x if v is not None})


all_entry_overlap_ngrams = []


with open(f"test_output_ngrams", "w") as f:
    for entry_overlap_key in total_entry_overlap_key_to_ngram_counts:
        ngram_counts = [
            ngram_count for ngram_count in total_entry_overlap_key_to_ngram_counts[entry_overlap_key].items()
        ]
        args, id, part = entry_overlap_key.rsplit('+', 2)
        dic, split, n = args.rsplit('_', 2)
        new_entry_overlap_key = {
            'stats_key': {
                "light_scenario_key": {
                    'scenario_spec': eval(dic),
                    'split': split
                },
                "overlap_protocol_spec": {'n': 5}
            },
            "part": part,
            "instance_id": id
        }
        entry_overlap_ngrams = {
            'entry_data_overlap_key': new_entry_overlap_key, 
            'overlapping_ngram_counts': ngram_counts
        }
        all_entry_overlap_ngrams.append(entry_overlap_ngrams)
        f.write(f"{json.dumps(entry_overlap_ngrams)}\n")

##### Add helm to Python_PATH

In [20]:
from util import get_metrics
metrics_list = get_metrics("test_output_ngrams", "short_new_scenario.jsonl", "new_metrics", '', 5)

In [26]:
metrics_list[0:10]

In [46]:
[m.entry_data_overlap_key.stats_key.overlap_protocol_spec.n for m in metrics_list[0:10]]

In [None]:
def merge_metrics(metrics_list):
    merged_metrics = []
    for i in range(len(metrics_list)//10):
        group_metrics = {
            "entry_data_overlap_key": metrics_list[10*i].entry_data_overlap_key,
            "metrics": [m.overlap_metric for m in metrics_list[10*i:10*(i+1)]]
        }
        merged_metrics.append(group_metrics)
    return merged_metrics

In [68]:
new_metrics = merge_metrics(metrics_list)

In [62]:
import pandas as pd

In [73]:
def metrics_to_cols(metrics):
    entry_data_overlap_key = metrics['entry_data_overlap_key']
    stats_key = entry_data_overlap_key.stats_key
    scenario_spec = stats_key.light_scenario_key.scenario_spec
    class_name = '.'.join(scenario_spec.class_name.split('.')[-2:])
    args = scenario_spec.args
    split = stats_key.light_scenario_key.split
    n = stats_key.overlap_protocol_spec.n
    binary_score = metrics['metrics'][0].metric_score
    jaccard_unweighted = metrics['metrics'][1].metric_score
    jaccard_weighted = metrics['metrics'][2].metric_score
    token_unweighted = metrics['metrics'][3].metric_score
    token_weighted = metrics['metrics'][4].metric_score
    return [class_name, args, split, n, binary_score, jaccard_unweighted, jaccard_weighted, token_unweighted, token_weighted]

In [74]:
columns = ['class_name', 'args', 'split', 'n', 'binary', 'jaccard_unweighted', 'jaccard_weighted', 'token_unweighted', 'token_weighted']
metrics_rows = [metrics_to_cols(metrics) for metrics in new_metrics]

In [75]:
metrics_df = pd.DataFrame(metrics_rows, columns=columns)

In [76]:
metrics_df

Unnamed: 0,class_name,args,split,n,binary,jaccard_unweighted,jaccard_weighted,token_unweighted,token_weighted
0,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.006667,0.006667,0.019737,0.019737
1,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.001916,0.001916,0.009506,0.009506
2,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",test,5,1,0.001912,0.001912,0.009488,0.009488
3,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.001969,0.001969,0.009766,0.009766
4,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.00271,0.00271,0.013405,0.013405
5,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.002066,0.002066,0.010246,0.010246
6,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.001901,0.001901,0.009434,0.009434
7,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.001946,0.001946,0.009653,0.009653
8,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",valid,5,1,0.003268,0.003268,0.016129,0.016129
9,summarization_scenario.SummarizationScenario,"{'dataset_name': 'xsum-sampled', 'sampling_min...",test,5,1,0.001919,0.001919,0.009524,0.009524


In [None]:
assert False

AssertionError: 

In [None]:
import ast
from nltk import ngrams

def compute_binary_overlap(instance_str, overlapping_ngram_counts, tokenizer, N, frequency = 0):
        """ 
        Compute  binary overlap
        If pass in frequency, include only the ngrams with count <= frequency
        """
        tokens = tokenizer.tokenize(instance_str)
        ngram_counts_dict = defaultdict(int)
        
        # construct a dict of ngram -> count
        for ngram, count in overlapping_ngram_counts:
            ngram = tuple(ast.literal_eval(ngram))
            ngram_counts_dict[ngram] = count

        metric_score = 0

        for ngram in ngrams(tokens, N):
            count = ngram_counts_dict[ngram]
            if frequency == 0 or count <= frequency:
                if count != 0:
                    metric_score = 1
                    break

        overlap_metric = {
            "metric_score": metric_score,
            "metric_protocol_spec": {
                "partial_overlap_spec": 0, #PartialOverlapSpec.binary,
                "frequency_spec": {
                    "filter_value": frequency,
                    "weighting": False
                }
            }
        }

        return overlap_metric

def compute_jaccard_overlap(instance_str, overlapping_ngram_counts, tokenizer, N, frequency = 0):
    """ 
    Compute weighted and unweighted jaccard overlap
    If pass in frequency, include only the ngrams with count <= frequency
    """
    tokens = tokenizer.tokenize(instance_str)
    ngram_counts_dict = defaultdict(int)
    
    # construct a dict of ngram -> count
    for ngram, count in overlapping_ngram_counts:
        ngram = tuple(ast.literal_eval(ngram))
        ngram_counts_dict[ngram] = count

    total_ngram_count = 0
    counts = 0
    weighted_score = 0

    for ngram in ngrams(tokens, N):
        count = ngram_counts_dict[ngram]
        if frequency == 0 or count <= frequency:
            if count != 0:
                counts += 1
                weighted_score += 1 / count
        total_ngram_count += 1

    unweighted_score = counts / total_ngram_count
    weighted_score = weighted_score / total_ngram_count

    unweighted_overlap_metric = {
        "metric_score": unweighted_score ,
        "metric_protocol_spec": {
            "partial_overlap_spec": 1, #PartialOverlapSpec.jaccard,
            "frequency_spec": {
                "filter_value": frequency,
                "weighting": False
            }
        }
    }

    weighted_overlap_metric = {
        "metric_score": weighted_score ,
        "metric_protocol_spec": {
            "partial_overlap_spec": 1, #PartialOverlapSpec.jaccard,
            "frequency_spec": {
                "filter_value": frequency,
                "weighting": True
            }
        }
    }

    return unweighted_overlap_metric, weighted_overlap_metric

# Token overlap
def compute_token_overlap(instance_str, overlapping_ngram_counts, tokenizer, N, frequency = 0):
    """ 
    Compute weighted and unweighted token overlap
    If pass in frequency, include only the ngrams with count <= frequency
    """
    tokens = tokenizer.tokenize(instance_str)
    ngram_counts_dict = defaultdict(int)
    
    # construct a dict of ngram -> count
    for ngram, count in overlapping_ngram_counts:
        ngram = tuple(ast.literal_eval(ngram))
        ngram_counts_dict[ngram] = count

    total_token_count = 0
    counts = 0
    weighted_score = 0
    weight = 0
    token_budget = 0

    for ngram in ngrams(tokens, N):
        curr_count = ngram_counts_dict[ngram]

        # either no frequency, or check current count is less than frequency
        # or a previous contiguous count (weight != 0) less than frequency
        if frequency == 0 or curr_count <= frequency or (weight != 0 and weight <= frequency):
            if curr_count != 0:
                token_budget = N
                if weight > 0:
                    weight = min(curr_count, weight)
                else:
                    weight = curr_count 

        if token_budget > 0:
            token_budget -= 1
            counts += 1
            weighted_score += 1 / weight
        else:
            weight = 0
        total_token_count += 1

    for token in ngram[1:]:
        if token_budget > 0:
            token_budget -= 1
            counts += 1
            weighted_score += 1 / weight
        total_token_count += 1

    unweighted_score = counts / total_token_count
    weighted_score = weighted_score / total_token_count

    unweighted_overlap_metric = {
        "metric_score": unweighted_score ,
        "metric_protocol_spec": {
            "partial_overlap_spec": 2, #PartialOverlapSpec.token,
            "frequency_spec": {
                "filter_value": frequency,
                "weighting": False
            }
        }
    }

    weighted_overlap_metric = {
        "metric_score": weighted_score ,
        "metric_protocol_spec": {
            "partial_overlap_spec": 2, #PartialOverlapSpec.token,
            "frequency_spec": {
                "filter_value": frequency,
                "weighting": True
            }
        }
    }

    return unweighted_overlap_metric, weighted_overlap_metric

def compute_and_add_metrics(instance_str, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, N, frequency = 0):

    overlap_metric = compute_binary_overlap(instance_str, overlapping_ngram_counts, tokenizer, N, frequency)
    binary_metric = {"entry_data_overlap_key": entry_data_overlap_key, "overlap_metric": overlap_metric}
    entry_overlap_metric_list.append(binary_metric)

    unweighted_overlap_metric, weighted_overlap_metric = compute_jaccard_overlap(instance_str, overlapping_ngram_counts, tokenizer, N, frequency)
    unweighted_jaccard = {"entry_data_overlap_key": entry_data_overlap_key, "overlap_metric": unweighted_overlap_metric}
    weighted_jaccard = {"entry_data_overlap_key": entry_data_overlap_key, "overlap_metric": weighted_overlap_metric}
    entry_overlap_metric_list.append(unweighted_jaccard)
    entry_overlap_metric_list.append(weighted_jaccard)

    unweighted_overlap_metric, weighted_overlap_metric = compute_token_overlap(instance_str, overlapping_ngram_counts, tokenizer, N, frequency)
    unweighted_token = {"entry_data_overlap_key": entry_data_overlap_key, "overlap_metric": unweighted_overlap_metric}
    weighted_token = {"entry_data_overlap_key": entry_data_overlap_key, "overlap_metric": weighted_overlap_metric}
    entry_overlap_metric_list.append(unweighted_token)
    entry_overlap_metric_list.append(weighted_token)

def save_metrics_to_jsonl(overlap_metrics, filename):
    with open(filename, "w") as f:
        for overlap_metric in overlap_metrics:
            f.write(json.dumps(asdict_without_nones(overlap_metric), ensure_ascii=False) + "\n")

In [None]:
scenario_path = 'short_new_scenario.jsonl'
ngrams_path = 'test_output_ngrams'
import cattrs
N = 5

# Read Ngrams
ngram_jsons = open(ngrams_path, "r").readlines()
entry_overlap_ngrams_list = []
for ngram_json in ngram_jsons:
    entry_overlap_ngrams = json.loads(ngram_json)
    scenario_spec = entry_overlap_ngrams["entry_data_overlap_key"]["stats_key"]["light_scenario_key"]["scenario_spec"]
    entry_overlap_ngrams_list.append(entry_overlap_ngrams)
            
    def merge_entries(entry_overlap_ngrams_list):
        overlapping_counts = defaultdict(int)
        for entry_overlap_ngrams in entry_overlap_ngrams_list:
            entry_data_overlap_key = entry_overlap_ngrams["entry_data_overlap_key"]
            overlapping_ngram_counts = entry_overlap_ngrams["overlapping_ngram_counts"]
            for ngram, count in overlapping_ngram_counts:
                overlapping_counts[ngram] += count
        overlapping_ngram_counts_list = []
        for ngram, count in overlapping_counts.items():
            overlapping_ngram_counts_list.append((ngram, count))
        return [{"entry_data_overlap_key": entry_data_overlap_key, "overlapping_ngram_counts": overlapping_ngram_counts_list}]

    # create entry_overlap_ngrams_dict, a dict of entry_data_overlap_key -> EntryOverlapNgrams
    entry_overlap_ngrams_dict = defaultdict(list)
    for entry_overlap_ngrams in entry_overlap_ngrams_list:
        entry_data_overlap_key = entry_overlap_ngrams["entry_data_overlap_key"]
        overlapping_ngram_counts = entry_overlap_ngrams["overlapping_ngram_counts"]
        ngram_count = entry_data_overlap_key["stats_key"]["overlap_protocol_spec"]["n"]
        if ngram_count not in [N]:
            continue
        entry_overlap_ngrams_dict[str(entry_data_overlap_key)].append(entry_overlap_ngrams)
        
        # We need to merge entries if sharded by training data, since there'll be redundancy
        # Can refactor to no list later
        if len(entry_overlap_ngrams_dict[str(entry_data_overlap_key)]) > 1:
            entry_overlap_ngrams_dict[entry_data_overlap_key] = merge_entries(entry_overlap_ngrams_dict[entry_data_overlap_key])

    # Read Scenarios
    light_scenarios = []
    light_scenario_jsons = open(scenario_path, "r").readlines()
    for light_scenario_json in light_scenario_jsons:
        light_scenario_dict: dict = json.loads(light_scenario_json)

        light_scenario_key_dict: dict = light_scenario_dict["scenario_key"]
        scenario_spec = str(light_scenario_key_dict["scenario_spec"])
        light_scenario_key = scenario_spec + '_' + light_scenario_key_dict["split"]
        light_instances = [
            {
                'input': instance_dict['input'], 
                'references': instance_dict['references'], 
                'id': instance_dict["id"]
            }
            for instance_dict in light_scenario_dict["instances"]
        ]
        light_scenarios.append({'scenario_key': light_scenario_key, 'instances': light_instances})
    light_scenario_instance_dict = dict()
    for light_scenario in light_scenarios:
        instances = light_scenario["instances"]
        instance_dict = dict()
        for instance in instances:
            instance_dict[instance["id"]] = instance
        light_scenario_instance_dict[light_scenario["scenario_key"]] = instance_dict


In [None]:
out_path = 'metrics'

entry_overlap_metric_list = []
import re
from string import punctuation


class LightTokenizer:
    """
    Tokenize texts by splitting on whitespaces.
    """

    def tokenize(self, text: str):
        return text.split()
    
class DefaultTokenizer(LightTokenizer):
    """
    Normalize and tokenize texts by converting all characters to the lower case and
    splitting on whitespaces and punctuations.
    """

    def __init__(self):
        super().__init__()
        self.r = re.compile(r"[\s{}]+".format(re.escape(punctuation)))

    def tokenize(self, text: str):
        return self.r.split(text.lower())

tokenizer = DefaultTokenizer()
for entry_data_overlap_key, entry_overlap_ngrams_list in entry_overlap_ngrams_dict.items():
    entry_data_overlap_key = eval(entry_data_overlap_key)
    data_overlap_stats_key = entry_data_overlap_key["stats_key"]
    light_scenario_key = data_overlap_stats_key["light_scenario_key"]
    instance_dict = light_scenario_instance_dict[str(light_scenario_key)]
    for entry_overlap_ngrams in entry_overlap_ngrams_list:
        entry_data_overlap_key = entry_overlap_ngrams["entry_data_overlap_key"]
        instance_id = entry_data_overlap_key["instance_id"]
        instance = instance_dict[instance_id]
        part = entry_data_overlap_key["part"]
        overlapping_ngram_counts = entry_overlap_ngrams["overlapping_ngram_counts"]
        if part == 'input':
            compute_and_add_metrics(instance["input"], overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, N)
            compute_and_add_metrics(instance["input"], overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, N, frequency=10)
        if part == 'references':
            reference = ' '.join(instance.references)
            compute_and_add_metrics(reference, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, N)
            compute_and_add_metrics(reference, overlapping_ngram_counts, tokenizer, entry_data_overlap_key, entry_overlap_metric_list, N, frequency=10)

save_metrics_to_jsonl(entry_overlap_metric_list, out_path)

KeyError: "{'scenario_spec': {'class_name': 'helm.benchmark.scenarios.summarization_scenario.SummarizationScenario', 'args': {'dataset_name': 'xsum-sampled', 'sampling_min_length': 50, 'sampling_max_length': 150, 'doc_max_length': 512}}, 'split': 'valid'}"