In [1]:
import ast
import json
import cattrs
import pandas as pd
from nltk import ngrams
from collections import defaultdict

from data_overlap_spec import DataOverlapStats, EntryOverlapNgrams
from compute_data_overlap_metrics import load_light_scenarios_from_jsonl
from common.util import get_tokenizer
from common.general import asdict_without_nones

In [2]:
from dataclasses import dataclass
from typing import List, Tuple
from light_scenario import LightScenarioKey


@dataclass(frozen=True)
class OverlapProtocolSpec:
    """Specification for how we compute overlap.

    Declared frozen, so instances cannot be mutated after construction.
    """

    # n-gram length (the "n" in n-gram) that the overlap was computed with
    n: int


@dataclass(frozen=True)
class DataOverlapStatsKey:
    """Key identifying one set of overlap stats: a light scenario paired with an overlap protocol.

    This notebook uses instances of this class as dictionary keys.
    """

    # identifies the scenario the stats belong to
    light_scenario_key: LightScenarioKey

    # how the overlap was computed (its only field is the n-gram length)
    overlap_protocol_spec: OverlapProtocolSpec


# NOTE(review): the first cell also imports a DataOverlapStats from data_overlap_spec;
# this definition rebinds the name, so every later cell uses the class below.
@dataclass(frozen=True)
class DataOverlapStats:
    """Dataclass that represents output data overlap stats"""

    # scenario + overlap protocol these stats were computed for
    data_overlap_stats_key: DataOverlapStatsKey

    # total number of instances; used as the denominator of the overlap ratios
    num_instances: int

    # ids of the instances whose input overlaps
    instance_ids_with_overlapping_input: List[str]

    # ids of the instances whose references overlap
    instance_ids_with_overlapping_reference: List[str]


@dataclass(frozen=True)
class EntryDataOverlapKey:
    """Unique key representing either the input or references of a single instance in a scenario."""

    # scenario + overlap protocol the entry belongs to
    stats_key: DataOverlapStatsKey
    # code later in this notebook compares `part` against the strings 'input' and 'references'
    part: str
    """Either PART_INPUT or PART_REF"""
    # id used to look the instance up in its scenario
    instance_id: str


# NOTE(review): the first cell also imports an EntryOverlapNgrams from data_overlap_spec;
# this definition rebinds the name, so every later cell uses the class below.
@dataclass(frozen=True)
class EntryOverlapNgrams:
    """Overlapping n-grams, with their counts, found for one entry (input or references of an instance)."""

    # identifies the scenario, the part and the instance the n-grams were found in
    entry_data_overlap_key: EntryDataOverlapKey

    # (ngram, count) pairs; the ngram string is parsed with ast.literal_eval when it is consumed
    overlapping_ngram_counts: List[Tuple[str, int]]

@dataclass(frozen=True)
class AnnotatedEntryOverlap:
    """
    Dataclass annotates a given scenario entry with overlaps.

    Instances are built from the values returned by annotate_with_ngrams and are
    serialised one per line by append_entry_overlap_to_jsonl.
    """

    # identifies the scenario, the part and the instance that was annotated
    entry_data_overlap_key: EntryDataOverlapKey

    # number of n-gram start positions whose overlap count is non-zero
    counts: int

    # sum of the overlap counts over all n-gram start positions
    weighted_counts: int

    # per-token annotation; the notebook truncates it to MAX_LEN pairs before storing it
    annotated_entry_overlap: List[Tuple[str, int]]
    """list of (word, count) where (word, count) is the 13-gram that starts with word"""




In [3]:
def data_overlap_stats_to_cols(data_overlap_stats, N, num_padding_cols=62):
    """Flatten one overlap-stats record into a list of column values.

    Args:
        data_overlap_stats: record exposing ``data_overlap_stats_key``,
            ``num_instances``, ``instance_ids_with_overlapping_input`` and
            ``instance_ids_with_overlapping_reference``.
        N: n-gram length to keep; records computed with any other length are skipped.
        num_padding_cols: number of empty-list columns appended after the
            eleven data columns (defaults to 62, the value previously hard-coded).

    Returns:
        ``None`` when the record's n-gram length differs from ``N``; otherwise the list
        ``[class_name, args, split, n, input_overlap_ratio, reference_overlap_ratio,
        num_instances, num_overlapping_inputs, num_overlapping_references,
        sorted_input_ids, sorted_reference_ids]`` followed by the padding columns.
        Both ratios are ``0.0`` when the record has no instances.
    """
    data_overlap_stats_key = data_overlap_stats.data_overlap_stats_key
    n = data_overlap_stats_key.overlap_protocol_spec.n
    # Filter first so records for other n-gram lengths cost nothing.
    if n != N:
        return None

    light_scenario_key = data_overlap_stats_key.light_scenario_key
    scenario_spec = light_scenario_key.scenario_spec
    # Keep only the last two dotted components of the class path (module.Class).
    class_name = '.'.join(scenario_spec.class_name.split('.')[-2:])

    overlapping_input_ids = sorted(data_overlap_stats.instance_ids_with_overlapping_input)
    overlapping_reference_ids = sorted(data_overlap_stats.instance_ids_with_overlapping_reference)
    num_instances = data_overlap_stats.num_instances
    num_overlapping_inputs = len(overlapping_input_ids)
    num_overlapping_references = len(overlapping_reference_ids)

    # A scenario without instances has no overlap; avoid dividing by zero.
    if num_instances:
        input_overlap_percent = num_overlapping_inputs / num_instances
        reference_overlap_percent = num_overlapping_references / num_instances
    else:
        input_overlap_percent = 0.0
        reference_overlap_percent = 0.0

    cols = [
        class_name,
        scenario_spec.args,
        light_scenario_key.split,
        n,
        input_overlap_percent,
        reference_overlap_percent,
        num_instances,
        num_overlapping_inputs,
        num_overlapping_references,
        overlapping_input_ids,
        overlapping_reference_ids,
    ]
    # Each padding column gets its own list object so they can be filled independently.
    cols.extend([] for _ in range(num_padding_cols))
    return cols

In [12]:
# JSONL file of overlap stats; each line is parsed into a DataOverlapStats record
output_path = 'output_stats_pile_all'
# JSONL file of per-entry overlapping n-grams; each line is parsed into an EntryOverlapNgrams record
ngram_path =  'output_stats_pile_ngrams_all2'
# NOTE(review): presumably the CSV the results get written to; it is not read by any cell shown here, confirm
outpath = 'the_pile_overlap_stats_ngrams_all_tokenized_and_raw2.csv'
# location handed to load_light_scenarios_from_jsonl to load the scenario instances
path = 'scenario_data_new'

# output_path = 'output_stats_pile_new2_xaa'
# ngram_path =  'output_stats_pi..gram_xaa_ngrams'
# outpath = 'test_outpath_xaa.csv'
# path = './data/xa/xaa'

# path = './data/xa/xaa'
# output_path = 'output_stats_pile_new2_xaa'
# output_path = 'output_stats_pile_new3_ngram_xad'

# ngram_path = 'output_stats_pi..gram_xaa_ngrams'
# ngram_path = 'output_stats_pi..gram_xad_ngrams'


In [13]:

# Read the overlap stats file (one JSON record per line). The context manager
# guarantees the handle is closed once the lines are in memory.
with open(output_path, "r", encoding="utf-8") as output_stats_file:
    output_stats_jsons = output_stats_file.readlines()

# dict of DataOverlapStatsKey -> DataOverlapStats, filled from data_overlap_stats_list
full_stats_dict = dict()

# Structure every non-blank line into a DataOverlapStats record; blank lines
# (for example a trailing newline at the end of the file) are skipped instead of
# crashing json.loads.
data_overlap_stats_list = [
    cattrs.structure(json.loads(output_stats_json), DataOverlapStats)
    for output_stats_json in output_stats_jsons
    if output_stats_json.strip()
]

In [14]:
# Index every parsed stats record by its DataOverlapStatsKey; a later record
# with the same key replaces the earlier one.
full_stats_dict.update(
    (data_overlap_stats.data_overlap_stats_key, data_overlap_stats)
    for data_overlap_stats in data_overlap_stats_list
)


In [15]:

# Read the per-entry n-gram file (one JSON record per line). The context manager
# guarantees the handle is closed once the lines are in memory.
with open(ngram_path, "r", encoding="utf-8") as ngram_file:
    ngram_jsons = ngram_file.readlines()

# Structure every non-blank line into an EntryOverlapNgrams record; blank lines
# (for example a trailing newline at the end of the file) are skipped instead of
# crashing json.loads.
entry_overlap_ngrams_list = [
    cattrs.structure(json.loads(ngram_json), EntryOverlapNgrams)
    for ngram_json in ngram_jsons
    if ngram_json.strip()
]

In [16]:
# Group the EntryOverlapNgrams records by DataOverlapStatsKey
# (DataOverlapStatsKey -> list of EntryOverlapNgrams). Records whose overlap
# protocol used 5-grams or 9-grams are left out.
entry_overlap_ngrams_dict = defaultdict(list)
for entry_overlap_ngrams in entry_overlap_ngrams_list:
    stats_key = entry_overlap_ngrams.entry_data_overlap_key.stats_key
    if stats_key.overlap_protocol_spec.n not in (5, 9):
        entry_overlap_ngrams_dict[stats_key].append(entry_overlap_ngrams)
        

In [17]:
# Load the scenarios and build a two-level index:
# scenario key -> {instance id -> instance}.
light_scenarios = load_light_scenarios_from_jsonl(path)
light_scenario_instance_dict = {
    light_scenario.scenario_key: {
        instance.id: instance for instance in light_scenario.instances
    }
    for light_scenario in light_scenarios
}

In [34]:
def annotate_with_ngrams(instance_str, overlapping_ngram_counts, tokenizer, n: int = 13):
    """
    Tokenize instance str and annotate every token with the overlap count of the n-gram starting at it.

    Args:
        instance_str: text to annotate; it is split with ``tokenizer.tokenize``.
        overlapping_ngram_counts: iterable of ``(ngram, count)`` pairs, where ``ngram`` is the
            string form of a token sequence (parsed with ``ast.literal_eval``).
        tokenizer: object exposing ``tokenize(str)``.
        n: n-gram length. Defaults to 13, the length that used to be hard-coded.

    Returns:
        ``(annotated_input, counts, weighted_counts)`` where ``annotated_input`` holds one
        ``(token, count)`` pair per token, ``counts`` is the number of n-gram start positions
        with a non-zero count and ``weighted_counts`` is the sum of those counts. Tokens that
        cannot start a full n-gram (the last ``n - 1`` tokens, or every token of a text shorter
        than ``n``) are annotated with 0.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    # Materialise the tokens so they can be sliced even if the tokenizer returns an iterator.
    tokens = list(tokenizer.tokenize(instance_str))
    # A plain dict with .get avoids inserting a key for every n-gram that is looked up.
    ngram_counts = {
        tuple(ast.literal_eval(ngram)): count for ngram, count in overlapping_ngram_counts
    }

    # Number of positions at which a full n-gram starts; zero for short texts.
    num_ngrams = max(len(tokens) - n + 1, 0)

    annotated_input = []
    counts = 0
    weighted_counts = 0
    for start in range(num_ngrams):
        count = ngram_counts.get(tuple(tokens[start:start + n]), 0)
        annotated_input.append((tokens[start], count))
        if count != 0:
            counts += 1
        weighted_counts += count

    # Remaining tokens start no n-gram. Taking them from `tokens` (rather than from the
    # last n-gram seen) keeps the annotation correct when the text has fewer than n tokens.
    annotated_input.extend((token, 0) for token in tokens[num_ngrams:])
    return annotated_input, counts, weighted_counts

import os
def append_entry_overlap_to_jsonl(annotated_entry_overlap_list, filename: str):
    """Append the given annotated entries to ``filename``, one JSON object per line.

    Args:
        annotated_entry_overlap_list: entries to serialise; each is converted with
            ``asdict_without_nones`` before being dumped.
        filename: target file. Missing parent directories are created. The file is
            opened in append mode, so existing content is kept.
    """
    # os.makedirs("") raises, so only create a directory when the path has one.
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # ensure_ascii=False emits raw non-ASCII characters, so pin the encoding instead of
    # relying on the platform default.
    with open(filename, "a", encoding="utf-8") as f:
        for annotated_entry_overlap in annotated_entry_overlap_list:
            f.write(json.dumps(asdict_without_nones(annotated_entry_overlap), ensure_ascii=False) + "\n")

# Annotate the overlapping entries of every scenario and append them to
# ngram_data/<class_name>_<split>.jsonl.
# NOTE(review): the output files are opened in append mode, so re-running this cell
# adds the same lines again; clear ngram_data/ first if that is not wanted.
MAX_IDS = 10  # max annotated entries kept per part ('input' / 'references') for each stats key
MAX_LEN = 1000  # max (token, count) pairs stored per annotated entry
tokenizer = get_tokenizer('default')
annotated_entry_overlap_list = []
iterations = 1000  # not read by this cell
# The per-key list gets its own name so the global entry_overlap_ngrams_list built
# when the n-gram file was loaded is not overwritten by this loop.
for data_overlap_stats_key, scenario_entry_overlap_ngrams in entry_overlap_ngrams_dict.items():
    light_scenario_key = data_overlap_stats_key.light_scenario_key
    split = light_scenario_key.split
    # Keep only the last two dotted components of the class path (module.Class).
    class_name = '.'.join(light_scenario_key.scenario_spec.class_name.split('.')[-2:])
    instance_dict = light_scenario_instance_dict[light_scenario_key]

    # Number of annotated entries kept so far for each part of this stats key.
    kept_per_part = {'input': 0, 'references': 0}
    for entry_overlap_ngrams in scenario_entry_overlap_ngrams:
        # Stop as soon as both quotas are full; nothing more can be kept.
        if all(kept >= MAX_IDS for kept in kept_per_part.values()):
            break

        entry_data_overlap_key = entry_overlap_ngrams.entry_data_overlap_key
        part = entry_data_overlap_key.part
        # Skip unknown parts and parts whose quota is already full before tokenizing.
        if kept_per_part.get(part, MAX_IDS) >= MAX_IDS:
            continue

        instance = instance_dict[entry_data_overlap_key.instance_id]
        if part == 'input':
            text = instance.input
        else:
            text = ' '.join(instance.references)

        annotated_overlap, counts, weighted_counts = annotate_with_ngrams(
            text, entry_overlap_ngrams.overlapping_ngram_counts, tokenizer
        )
        # Only entries with at least one overlapping n-gram are written out.
        if counts > 0:
            annotated_entry_overlap_list.append(
                AnnotatedEntryOverlap(
                    entry_data_overlap_key=entry_data_overlap_key,
                    annotated_entry_overlap=annotated_overlap[:MAX_LEN],
                    counts=counts,
                    weighted_counts=weighted_counts,
                )
            )
            kept_per_part[part] += 1

    append_entry_overlap_to_jsonl(annotated_entry_overlap_list, f'ngram_data/{class_name}_{split}.jsonl')
    annotated_entry_overlap_list = []
    

            

