In [1]:
import ast
import json
import cattrs
import pandas as pd

from data_overlap_spec import DataOverlapStats, EntryOverlapNgrams

In [2]:
def data_overlap_stats_to_cols(data_overlap_stats, N, num_display_instances=10):
    """Flatten one DataOverlapStats record into a report row, or return None if its n != N.

    Row layout (73 entries with the default num_display_instances=10):
      0-10  class name (last two dotted components of the scenario class path),
            scenario args, split, n, input overlap ratio, reference overlap ratio,
            num_instances, number of overlapping inputs, number of overlapping
            references, sorted overlapping input ids, sorted overlapping reference ids;
      11-12 empty lists that later receive (instance_id, ngrams) pairs for
            inputs / references;
      then 3 empty placeholder slots per displayed input instance followed by
      3 per displayed reference instance, which a later cell fills in.

    Args:
        data_overlap_stats: a DataOverlapStats record.
        N: only records whose overlap protocol n equals N are converted.
        num_display_instances: number of input (and of reference) instances that
            get 3 placeholder slots each. Defaults to 10.

    Returns:
        The row as a list, or None when the record's n differs from N.
        Both overlap ratios are 0.0 when the record has no instances.
    """
    data_overlap_stats_key = data_overlap_stats.data_overlap_stats_key
    n = data_overlap_stats_key.overlap_protocol_spec.n
    # Filter first so records for other n values cost nothing.
    if n != N:
        return None

    light_scenario_key = data_overlap_stats_key.light_scenario_key
    scenario_spec = light_scenario_key.scenario_spec
    # Keep only "module.ClassName" from the fully qualified class path.
    class_name = '.'.join(scenario_spec.class_name.split('.')[-2:])
    args = scenario_spec.args
    split = light_scenario_key.split

    num_instances = data_overlap_stats.num_instances
    overlapping_input_ids = sorted(data_overlap_stats.instance_ids_with_overlapping_input)
    overlapping_reference_ids = sorted(data_overlap_stats.instance_ids_with_overlapping_reference)
    num_overlapping_inputs = len(overlapping_input_ids)
    num_overlapping_references = len(overlapping_reference_ids)
    # An empty scenario would otherwise raise ZeroDivisionError.
    if num_instances:
        input_overlap_ratio = num_overlapping_inputs / num_instances
        reference_overlap_ratio = num_overlapping_references / num_instances
    else:
        input_overlap_ratio = 0.0
        reference_overlap_ratio = 0.0

    cols = [
        class_name,
        args,
        split,
        n,
        input_overlap_ratio,
        reference_overlap_ratio,
        num_instances,
        num_overlapping_inputs,
        num_overlapping_references,
        overlapping_input_ids,
        overlapping_reference_ids,
    ]
    # 2 n-gram list columns + 3 slots per displayed instance, for inputs and references.
    num_placeholders = 2 + 2 * 3 * num_display_instances
    # Each placeholder must be its own list object: the first two are appended to later.
    cols.extend([] for _ in range(num_placeholders))
    return cols

In [29]:
# Input: JSONL, one DataOverlapStats record per line (aggregate stats per scenario/split/n).
output_path = 'output_stats_pile_all'
# Input: JSONL, one EntryOverlapNgrams record per line (overlapping n-grams per instance part).
ngram_path = 'output_stats_pile_ngrams_all2'
# Input: JSONL, one LightScenario per line (the instances the stats refer to).
path = 'scenario_data_new'
# Output: CSV report written by the last cell.
outpath = 'the_pile_overlap_stats_ngrams_all_tokenized_and_raw2.csv'

# Alternative inputs used in earlier runs:
#   path = './data/xa/xaa'
#   output_path = 'output_stats_pile_new2_xaa'
#   output_path = 'output_stats_pile_new3_ngram_xad'
#   ngram_path = 'output_stats_pi..gram_xaa_ngrams'
#   ngram_path = 'output_stats_pi..gram_xad_ngrams'


In [4]:

# Load the aggregate overlap statistics: one JSON-encoded DataOverlapStats per line.
# The `with` block closes the file; blank lines (e.g. a trailing empty line) are skipped
# instead of crashing json.loads.
with open(output_path, "r") as output_stats_file:
    data_overlap_stats_list = [
        cattrs.structure(json.loads(line), DataOverlapStats)
        for line in output_stats_file
        if line.strip()
    ]

# Maps DataOverlapStatsKey -> report row (see data_overlap_stats_to_cols);
# populated by the following cells.
full_stats_dict = dict()

In [5]:
# Build one report row per stats record with n=13; the converter returns None otherwise.
for data_overlap_stats in data_overlap_stats_list:
    row = data_overlap_stats_to_cols(data_overlap_stats, 13)
    if not row:
        continue
    full_stats_dict[data_overlap_stats.data_overlap_stats_key] = row


In [6]:

# Load the per-instance overlapping n-gram records: one JSON-encoded EntryOverlapNgrams per line.
# The `with` block closes the file; blank lines are skipped instead of crashing json.loads.
with open(ngram_path, "r") as ngram_file:
    entry_overlap_ngrams_list = [
        cattrs.structure(json.loads(line), EntryOverlapNgrams)
        for line in ngram_file
        if line.strip()
    ]

In [7]:
MAX_NGRAMS = 20  # max overlapping n-grams kept per instance entry
MAX_IDS = 20  # max instance entries kept per stats key and part (input / references)


def format_overlapping_ngram_counts(overlapping_ngram_counts, max_ngrams=MAX_NGRAMS):
    """Convert up to `max_ngrams` overlapping n-gram counts into (text, count) pairs.

    Each item's first element is a string that `ast.literal_eval` turns into a
    sequence of tokens; the tokens are joined with spaces and every ',' is replaced
    with '|' (presumably so the text reads cleanly inside the CSV — confirm).
    The item's second element (the count) is passed through unchanged.
    """
    formatted = []
    for overlapping_ngram_count in overlapping_ngram_counts:
        if len(formatted) >= max_ngrams:
            break
        overlapping_ngram = ast.literal_eval(overlapping_ngram_count[0])
        ngram_text = ' '.join(overlapping_ngram).replace(',', '|')
        formatted.append((ngram_text, overlapping_ngram_count[1]))
    return formatted


# Attach each instance's overlapping n-grams to its stats row:
# column 11 collects input entries, column 12 reference entries.
for entry_overlap_ngrams in entry_overlap_ngrams_list:
    entry_data_overlap_key = entry_overlap_ngrams.entry_data_overlap_key
    # full_stats_dict only holds the rows kept above (n=13); entries whose stats key
    # has no row (e.g. n=5 or n=9) are skipped rather than raising KeyError.
    data_overlap_stats_row = full_stats_dict.get(entry_data_overlap_key.stats_key)
    if data_overlap_stats_row is None:
        continue
    column = 11 if entry_data_overlap_key.part == 'input' else 12
    # Check capacity before formatting so full rows cost nothing.
    if len(data_overlap_stats_row[column]) >= MAX_IDS:
        continue
    data_overlap_stats_row[column].append(
        (
            entry_data_overlap_key.instance_id,
            format_overlapping_ngram_counts(entry_overlap_ngrams.overlapping_ngram_counts),
        )
    )

In [8]:
import json
import os
import glob

from typing import List, Tuple, Set, DefaultDict
from nltk import ngrams
from typing import Dict
from tqdm import tqdm
from collections import defaultdict

from light_scenario import LightInstance, LightScenario, LightScenarioKey
from data_overlap_spec import (
    DataOverlapStats,
    DataOverlapStatsKey,
    OverlapProtocolSpec,
    EntryDataOverlapKey,
    EntryOverlapNgrams,
)
from light_tokenizer import LightTokenizer
from load_documents import get_document_iterator
from common.hierarchical_logger import hlog, htrack_block
from common.general import asdict_without_nones
from common.arguments import get_data_overlap_args
from common.util import get_tokenizer
from scenarios.scenario import ScenarioSpec


# Keys of the two instance parts in each scenario instance dict
# (read by load_light_scenarios_from_jsonl).
PART_INPUT: str = "input"
PART_REF: str = "references"

# The n values (n-gram sizes) for which overlap is computed.
# TODO: Pick the N values
N_VALUES: List[int] = [5, 9, 13]


# Type aliases for overlap-related data structures.
# (None of these aliases is referenced in the visible cells of this notebook.)
# A single n-gram: a tuple of string tokens.
Ngram = Tuple[str, ...]
# n -> (n-gram -> set of entry keys); presumably an index of n-grams to the entries containing them — confirm.
NgramIndex = Dict[int, Dict[Ngram, Set[EntryDataOverlapKey]]]
# entry key -> (n-gram -> integer count).
NgramCounter = Dict[EntryDataOverlapKey, Dict[Ngram, int]]


def load_light_scenarios_from_jsonl(path: str) -> List[LightScenario]:
    """
    Create a list of light scenarios from a jsonl file, where each json represents a LightScenario object.

    Input file format (one LightScenario JSON per line; blank lines are ignored):

    LightScenario JSON 1
    LightScenario JSON 2
    LightScenario JSON 3
    ...

    Each line must have a "scenario_key" object (with "scenario_spec" and "split")
    and an "instances" list whose items have "input", "references" and "id" keys.

    Args:
        path: path to the jsonl file.

    Returns:
        The light scenarios, in file order.
    """

    def create_light_instance_from_dict(instance_dict: dict) -> LightInstance:
        return LightInstance(
            input=instance_dict[PART_INPUT], references=instance_dict[PART_REF], id=instance_dict["id"]
        )

    light_scenarios: List[LightScenario] = []
    # `with` closes the file even if a line fails to parse.
    with open(path, "r") as light_scenario_file:
        for light_scenario_json in light_scenario_file:
            if not light_scenario_json.strip():
                continue
            light_scenario_dict: dict = json.loads(light_scenario_json)

            light_scenario_key_dict: dict = light_scenario_dict["scenario_key"]
            # if the light_scenarios are exported from helm, they will have a scenario_spec field
            scenario_spec = ScenarioSpec(**light_scenario_key_dict["scenario_spec"])
            light_scenario_key = LightScenarioKey(scenario_spec=scenario_spec, split=light_scenario_key_dict["split"])
            light_instances: List[LightInstance] = [
                create_light_instance_from_dict(instance_dict) for instance_dict in light_scenario_dict["instances"]
            ]
            light_scenarios.append(LightScenario(scenario_key=light_scenario_key, instances=light_instances))
    return light_scenarios
light_scenarios = load_light_scenarios_from_jsonl(path)
# Index scenarios by key so each stats row can look up its instances (later keys win on duplicates).
light_scenario_dict = {scenario.scenario_key: scenario for scenario in light_scenarios}

In [30]:
tokenizer = get_tokenizer('default')

# Row layout from data_overlap_stats_to_cols: columns 11/12 hold the (instance_id, ngrams)
# lists for inputs/references; each displayed instance then fills 3 consecutive slots:
# its (instance_id, ngrams) entry, its tokenized text, and its raw text.
INPUT_NGRAMS_COL = 11
REFERENCE_NGRAMS_COL = 12
SLOTS_PER_INSTANCE = 3
DISPLAYED_INSTANCES_PER_PART = 10
FIRST_INPUT_SLOT_COL = 13
FIRST_REFERENCE_SLOT_COL = FIRST_INPUT_SLOT_COL + SLOTS_PER_INSTANCE * DISPLAYED_INSTANCES_PER_PART  # 43
MAX_TOKENS = 3000  # tokens kept in the tokenized-text slot
MAX_INPUT_CHARS = 30000  # characters kept of a raw input
MAX_REFERENCE_CHARS = 20000  # characters kept of each raw reference

data_overlap_stats_rows = []
for data_overlap_stats_key, stats_row in full_stats_dict.items():
    light_scenario = light_scenario_dict[data_overlap_stats_key.light_scenario_key]
    instance_dict = {instance.id: instance for instance in light_scenario.instances}
    # Work on a copy so this cell does not mutate the rows cached in full_stats_dict.
    row = list(stats_row)
    row[INPUT_NGRAMS_COL] = sorted(row[INPUT_NGRAMS_COL])
    row[REFERENCE_NGRAMS_COL] = sorted(row[REFERENCE_NGRAMS_COL])

    for i, entry in enumerate(row[INPUT_NGRAMS_COL][:DISPLAYED_INSTANCES_PER_PART]):
        instance = instance_dict[entry[0]]
        col = FIRST_INPUT_SLOT_COL + SLOTS_PER_INSTANCE * i
        row[col] = entry
        row[col + 1] = ' '.join(tokenizer.tokenize(instance.input)[:MAX_TOKENS])
        row[col + 2] = instance.input[:MAX_INPUT_CHARS]

    for i, entry in enumerate(row[REFERENCE_NGRAMS_COL][:DISPLAYED_INSTANCES_PER_PART]):
        instance = instance_dict[entry[0]]
        col = FIRST_REFERENCE_SLOT_COL + SLOTS_PER_INSTANCE * i
        row[col] = entry
        row[col + 1] = [' '.join(tokenizer.tokenize(x)[:MAX_TOKENS]) for x in instance.references]
        row[col + 2] = [x[:MAX_REFERENCE_CHARS] for x in instance.references]

    data_overlap_stats_rows.append(row)

In [31]:
# Order rows by input overlap ratio (column 4), highest first.
data_overlap_stats_rows.sort(reverse=True, key=lambda row: row[4])

In [32]:
# Column names, in the row order produced by data_overlap_stats_to_cols and filled in above:
# 11 aggregate-stat columns, the two (instance_id, ngrams) lists, then 3 columns per
# displayed input instance followed by 3 per displayed reference instance.
# Each displayed instance's three columns are: (id, ngrams) entry, tokenized text, raw text.
columns = [
    'class_name', 'args', 'split', 'n',
    'input_overlap_ratio', 'reference_overlap_ratio', 'num_instances',
    'inputs_num_overlapping', 'references_num_overlapping',
    'input_ids', 'reference_ids',
    'input_ngrams', 'reference_ngrams',
]
num_display_instances = 10  # must match the number of instances filled per part above
for part in ('input', 'reference'):
    for i in range(num_display_instances):
        # Tokenized and raw text previously shared one name, producing duplicate CSV headers.
        columns.extend([f'{part}_ngrams{i}', f'{part}_instance_tokenized{i}', f'{part}_instance{i}'])
assert len(set(columns)) == len(columns), "duplicate column names"
data_overlap_stats_df = pd.DataFrame(data_overlap_stats_rows, columns=columns)

In [33]:
# data_overlap_stats_df

In [34]:
# Write the final report (earlier runs exported to other *_small*/xaa/xad CSV names).
data_overlap_stats_df.to_csv(outpath, index=False)