In [1]:
import ast
import json
import cattrs
import pandas as pd

from data_overlap_spec import DataOverlapStats, EntryOverlapNgrams

In [2]:
def data_overlap_stats_to_cols(data_overlap_stats, N):
    data_overlap_stats_key = data_overlap_stats.data_overlap_stats_key
    light_scenario_key = data_overlap_stats_key.light_scenario_key
    scenario_spec = light_scenario_key.scenario_spec
    class_name = scenario_spec.class_name
    class_name = class_name.split('.')[-2:]
    class_name = '.'.join(class_name)
    args = scenario_spec.args
    split = light_scenario_key.split
    n = data_overlap_stats_key.overlap_protocol_spec.n
    if n != N:
        return None
    num_instances = data_overlap_stats.num_instances
    num_overlapping_inputs = len(data_overlap_stats.instance_ids_with_overlapping_input)
    num_overlapping_references = len(data_overlap_stats.instance_ids_with_overlapping_reference)
    input_overlap_percent = num_overlapping_inputs / num_instances
    reference_overlap_percent = num_overlapping_references / num_instances
    cols = [class_name, args, split, n, input_overlap_percent, reference_overlap_percent, num_instances, num_overlapping_inputs, num_overlapping_references, sorted(data_overlap_stats.instance_ids_with_overlapping_input), sorted(data_overlap_stats.instance_ids_with_overlapping_reference)]
    for i in range(22):
        cols.append([])
    return cols

In [3]:
# output_path = 'output_stats_pile_all'
output_path = 'output_stats_pile_new2_xaa'
# output_path = 'output_stats_pile_new3_ngram_xad'
output_stats_jsons = open(output_path, "r").readlines()

# create dict of DataOverlapStatsKey -> [aggregate_stats, n_grams]
full_stats_dict = dict()

data_overlap_stats_list = []
for output_stats_json in output_stats_jsons:
    output_stats_dict = json.loads(output_stats_json)
    data_overlap_stats_list.append(cattrs.structure(output_stats_dict, DataOverlapStats))

In [4]:
for data_overlap_stats in data_overlap_stats_list:
    data_overlap_stats_key = data_overlap_stats.data_overlap_stats_key
    cols = data_overlap_stats_to_cols(data_overlap_stats, 13)
    if cols:
        full_stats_dict[data_overlap_stats_key] = cols


In [5]:
# ngram_path =  'output_stats_pile_ngrams_all2'
ngram_path = 'output_stats_pi..gram_xaa_ngrams'
# ngram_path = 'output_stats_pi..gram_xad_ngrams'
ngram_jsons = open(ngram_path, "r").readlines()
entry_overlap_ngrams_list = []
for ngram_json in ngram_jsons:
    entry_overlap_ngrams = json.loads(ngram_json)
    entry_overlap_ngrams_list.append(cattrs.structure(entry_overlap_ngrams, EntryOverlapNgrams))
#     print(entry_overlap_ngrams)
#     break

In [6]:
MAX_NGRAMS = 20
MAX_IDS = 20
for entry_overlap_ngrams in entry_overlap_ngrams_list:
    entry_data_overlap_key = entry_overlap_ngrams.entry_data_overlap_key
    overlapping_ngram_counts = entry_overlap_ngrams.overlapping_ngram_counts
    ngram_count = entry_data_overlap_key.stats_key.overlap_protocol_spec.n
    if ngram_count in [5, 9]:
        continue
    overlapping_ngram_counts_list = list()
    for overlapping_ngram_count in overlapping_ngram_counts:
        overlapping_ngram = ast.literal_eval(overlapping_ngram_count[0])
        str_ngram_count = ' '.join(overlapping_ngram).replace(',', '|'), overlapping_ngram_count[1]
#         str_ngram_count[0] = str_ngram_count[0]
        overlapping_ngram_counts_list.append(str_ngram_count)
        if len(overlapping_ngram_counts_list) >= MAX_NGRAMS:
            break
    data_overlap_stats_key = entry_data_overlap_key.stats_key
    part = entry_data_overlap_key.part
    instance_id = entry_data_overlap_key.instance_id
    if part == 'input':
#         for i in range(11, 21):
#             if not full_stats_dict[data_overlap_stats_key][i]:
#                 full_stats_dict[data_overlap_stats_key][i] = (instance_id, overlapping_ngram_counts_list)
#                 break
        if len(full_stats_dict[data_overlap_stats_key][11]) >= MAX_IDS:
            continue
        full_stats_dict[data_overlap_stats_key][11].append((instance_id, overlapping_ngram_counts_list))
    else:
        for i in range(21, 31):
            if not full_stats_dict[data_overlap_stats_key][i]:
                full_stats_dict[data_overlap_stats_key][i] = (instance_id, overlapping_ngram_counts_list)
                break
        if len(full_stats_dict[data_overlap_stats_key][12]) >= MAX_IDS:
            continue
        full_stats_dict[data_overlap_stats_key][12].append((entry_data_overlap_key.instance_id, overlapping_ngram_counts_list))

In [7]:
data_overlap_stats_rows = []
for data_overlap_stats_key, data_overlap_stats_row in full_stats_dict.items():
    if data_overlap_stats_row:
        data_overlap_stats_row[11].sort()
        data_overlap_stats_row[12].sort()
        input_count = 0
        for i in range(min(len(data_overlap_stats_row[11]),10)):
            data_overlap_stats_row[13+i]=data_overlap_stats_row[11][i]
        for i in range(min(len(data_overlap_stats_row[12]),10)):
            data_overlap_stats_row[23+i]=data_overlap_stats_row[12][i]
        data_overlap_stats_rows.append(data_overlap_stats_row)

In [8]:
data_overlap_stats_rows.sort(key=lambda x: x[4], reverse=True)
data_overlap_stats_rows

[['lsat_qa_scenario.LSATScenario',
  {'task': 'grouping'},
  'valid',
  13,
  0.6666666666666666,
  0.0,
  12,
  8,
  0,
  ['id271', 'id272', 'id273', 'id274', 'id275', 'id276', 'id277', 'id278'],
  [],
  [('id271',
    [('jazz opera pop rap and soul the store is having a sale on', 2),
     ('opera pop rap and soul the store is having a sale on some', 2),
     ('which one of the following could be a complete and accurate list of',
      4),
     ('one of the following could be a complete and accurate list of the', 4),
     ('new and used of each of jazz opera pop rap and soul the', 1),
     ('and used of each of jazz opera pop rap and soul the store', 1),
     ('used of each of jazz opera pop rap and soul the store is', 1),
     ('of each of jazz opera pop rap and soul the store is having', 1),
     ('each of jazz opera pop rap and soul the store is having a', 1),
     ('of jazz opera pop rap and soul the store is having a sale', 1),
     ('pop rap and soul the store is having a sale o

In [10]:
# columns = ['class_name', 'args', 'split', 'n',  'input_overlap_ratio', 'reference_overlap_ratio', 'num_instances', 'inputs_num_overlapping', 'references_num_overlapping', 'input_ids', 'reference_ids', 'input_ngrams', 'reference_ngrams']
columns = ['class_name', 'args', 'split', 'n',  'input_overlap_ratio', 'reference_overlap_ratio', 'num_instances', 'inputs_num_overlapping', 'references_num_overlapping', 'input_ids', 'reference_ids']
columns.append(f'input_ngrams')
columns.append(f'reference_ngrams')
for i in range(10):
    columns.append(f'input_ngrams{i}')
for i in range(10):
    columns.append(f'reference_ngrams{i}')
data_overlap_stats_df = pd.DataFrame(data_overlap_stats_rows, columns=columns)

In [11]:
data_overlap_stats_df

Unnamed: 0,class_name,args,split,n,input_overlap_ratio,reference_overlap_ratio,num_instances,inputs_num_overlapping,references_num_overlapping,input_ids,...,reference_ngrams0,reference_ngrams1,reference_ngrams2,reference_ngrams3,reference_ngrams4,reference_ngrams5,reference_ngrams6,reference_ngrams7,reference_ngrams8,reference_ngrams9
0,lsat_qa_scenario.LSATScenario,{'task': 'grouping'},valid,13,0.666667,0.0,12,8,0,"[id271, id272, id273, id274, id275, id276, id2...",...,[],[],[],[],[],[],[],[],[],[]
1,ice_scenario.ICEScenario,"{'subset': 'can', 'category': 'W1'}",test,13,0.34,0.0,50,17,0,"[id0, id1, id13, id27, id3, id37, id39, id4, i...",...,[],[],[],[],[],[],[],[],[],[]
2,ice_scenario.ICEScenario,"{'subset': 'hk', 'category': 'W1'}",test,13,0.28,0.0,50,14,0,"[id1, id10, id14, id15, id16, id19, id20, id24...",...,[],[],[],[],[],[],[],[],[],[]
3,ice_scenario.ICEScenario,"{'subset': 'usa', 'category': 'W2'}",test,13,0.266667,0.0,150,40,0,"[id1, id11, id110, id121, id123, id125, id127,...",...,[],[],[],[],[],[],[],[],[],[]
4,lsat_qa_scenario.LSATScenario,{'task': 'grouping'},train,13,0.169742,0.00369,271,46,1,"[id101, id102, id103, id104, id105, id106, id1...",...,"(id156, [(only two of the six kinds of birds a...",[],[],[],[],[],[],[],[],[]
5,lsat_qa_scenario.LSATScenario,{'task': 'miscellaneous'},train,13,0.156171,0.0,397,62,0,"[id112, id116, id160, id161, id162, id163, id1...",...,[],[],[],[],[],[],[],[],[],[]
6,ice_scenario.ICEScenario,"{'subset': 'ind', 'category': 'S2'}",test,13,0.15,0.0,120,18,0,"[id108, id111, id116, id12, id16, id17, id35, ...",...,[],[],[],[],[],[],[],[],[],[]
7,ice_scenario.ICEScenario,"{'subset': 'sin', 'category': 'W2'}",test,13,0.12,0.0,150,18,0,"[id10, id121, id13, id133, id142, id148, id15,...",...,[],[],[],[],[],[],[],[],[],[]
8,ice_scenario.ICEScenario,{'category': 'S'},test,13,0.089836,0.0,2371,213,0,"[id1, id1009, id1014, id1043, id1047, id1050, ...",...,[],[],[],[],[],[],[],[],[],[]
9,lsat_qa_scenario.LSATScenario,{'task': 'miscellaneous'},valid,13,0.068063,0.0,191,13,0,"[id421, id422, id423, id424, id425, id426, id5...",...,[],[],[],[],[],[],[],[],[],[]


In [12]:
# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_all_small2.csv', index=False)
# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_all_small3.csv', index=False)
# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_all_small4.csv', index=False)
# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_all_small5.csv', index=False)

# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_xaa_small2.csv', index=False)
# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_xad_small2.csv', index=False)
# data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_xaa_small4.csv', index=False)
data_overlap_stats_df.to_csv('the_pile_overlap_stats_ngrams_xaa_small5.csv', index=False)