In [None]:
import syft as sy
import os
from syft import ActionObject
from collections import defaultdict

Before running this notebook, start the local Kubernetes dev cluster and deploy the stack with:

```
tox -e dev.k8s.start
tox -e dev.k8s.deploy
```

In [None]:
# Log in to the domain. Credentials are read from the environment so they are not
# baked into the notebook; the fallbacks are the values this notebook previously
# hardcoded for the local dev deployment.
client = sy.login(
    url=os.environ.get("SYFT_URL", "http://localhost:8080"),
    email=os.environ.get("SYFT_EMAIL", "info@openmined.org"),
    password=os.environ.get("SYFT_PASSWORD", "changethis"),
)

# Mount storage container with Helm azure container

# Start workers

In [None]:
# Number of workers to start: all CPUs but two.
# os.cpu_count() may return None, and on machines with <= 2 CPUs the plain
# subtraction would give 0 or a negative count, so clamp to at least one worker.
worker_count = max(1, (os.cpu_count() or 1) - 2)
worker_count

In [None]:
# Start `worker_count` workers on the domain; jobs launched later run on them.
client.worker.start_workers(n=worker_count)
# For a quick single-worker debug run, use instead:
# client.worker.start_workers(n=1)

In [None]:
# Display the workers currently registered on the domain.
client.worker.list()

# Create Dataset

In [None]:
# split.py — split a large JSONL file into chunk files before uploading.
def split_file(file_path, num_chunks):
    """Split a line-oriented file (e.g. JSONL) into ``num_chunks`` chunk files.

    Chunk ``k`` is written next to the input as
    ``<stem>-chunk-<k zero-padded to len(str(num_chunks))>.jsonl``. Every chunk
    gets ``total_lines // num_chunks`` lines and the last chunk additionally
    receives the remaining lines.

    This used to be a commented-out standalone script. The ``sys.argv`` entry
    point is intentionally omitted: inside a Jupyter kernel ``sys.argv`` holds the
    kernel's own arguments, so the old ``len(sys.argv) != 2`` check would call
    ``sys.exit``. Call it directly, e.g. ``split_file(path, os.cpu_count() or 1)``.

    Args:
        file_path: Path of the file to split.
        num_chunks: Number of chunk files to create; must be >= 1.

    Returns:
        List of the created chunk file paths, in chunk order.

    Raises:
        ValueError: If ``num_chunks`` is less than 1 (would divide by zero).
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be >= 1, got {num_chunks}")

    # First pass: count lines so they can be distributed evenly.
    with open(file_path, 'r', encoding='utf-8') as src_file:
        total_lines = sum(1 for _ in src_file)

    lines_per_chunk = total_lines // num_chunks
    stem = os.path.splitext(file_path)[0]
    width = len(str(num_chunks))
    chunk_paths = []

    # Second pass: stream lines into the chunk files.
    with open(file_path, 'r', encoding='utf-8') as src_file:
        for chunk in range(num_chunks):
            print(f"Creating chunk {chunk}")
            chunk_file_name = f"{stem}-chunk-{str(chunk).zfill(width)}.jsonl"
            with open(chunk_file_name, 'w', encoding='utf-8') as chunk_file:
                for _ in range(lines_per_chunk):
                    line = src_file.readline()
                    if not line:
                        break
                    chunk_file.write(line)

                # Integer division leaves up to num_chunks - 1 lines over;
                # the last chunk takes them.
                if chunk == num_chunks - 1:
                    for line in src_file:
                        chunk_file.write(line)
            chunk_paths.append(chunk_file_name)
    return chunk_paths


In [None]:
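# Prerequisites (run outside this notebook):
# 1. Download filtered_scenario_data_new.jsonl from Azure.
# 2. Download train-00.jsonl from Azure.
# 3. Run `python split.py train-00.jsonl` to split it into chunk files.
# WARNING: uploading files of around 2GB currently fails, so keep each chunk to roughly 200MB.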

In [None]:
# Directory holding filtered_scenario_data_new.jsonl and the train-00 chunk files.
# Override with the HELM_DATA_DIR environment variable instead of editing this
# absolute path. os.path.join(..., "") guarantees a trailing separator, because
# paths are built from data_dir by plain string concatenation.
data_dir = os.path.join(os.environ.get("HELM_DATA_DIR", "/Users/madhavajay/dev/data/"), "")

In [None]:
# Path of the HELM scenario (test) file. os.path.join works whether or not
# data_dir ends with a path separator.
scenario_path = os.path.join(data_dir, "filtered_scenario_data_new.jsonl")

In [None]:
# Chunk files produced by split.py are named like train-00-chunk-00.jsonl

In [None]:
# Collect (name, path) pairs for every train-00 chunk in data_dir.
# os.listdir returns entries in arbitrary order; sorting makes the upload order
# (and therefore the order of the "helm train data" asset) reproducible.
# `os` is already imported at the top of the notebook.
split_files = []
for file in sorted(os.listdir(data_dir)):
    if file.startswith("train-00-chunk") and file.endswith(".jsonl"):
        name = file.split(".")[0]
        path = os.path.join(data_dir, file)
        split_files.append((name, path))
split_files


In [None]:
# Upload the scenario (test) file to the domain and keep the server-side handle.
scenario_action_obj = sy.ActionObject.from_path(path=scenario_path)
scenario_file = scenario_action_obj.send(client).syft_action_data

In [None]:
# Upload every training chunk, in split_files order, keeping the server-side handles.
train_files = []
for split_entry in split_files:
    chunk_path = split_entry[1]
    chunk_action_obj = sy.ActionObject.from_path(chunk_path)
    train_files.append(chunk_action_obj.send(client).syft_action_data)
    print("Added ", split_entry)

In [None]:
# Two assets: the list of training chunks and the (single) scenario file.
# Mocks are empty placeholders.
train_asset = sy.Asset(
    name="helm train data",
    data=sy.ActionObject.from_obj(train_files),
    mock=sy.ActionObject.empty(),
)
test_asset = sy.Asset(
    name="helm test data",
    data=sy.ActionObject.from_obj([scenario_file]),
    mock=sy.ActionObject.empty(),
)
helm_dataset = sy.Dataset(name="Helm Dataset", asset_list=[train_asset, test_asset])

In [None]:
# Upload the dataset (both assets) to the domain.
client.upload_dataset(helm_dataset)

In [None]:
# Fetch the uploaded dataset back from the domain; these asset handles are what
# main_function is bound to below.
helm_ds = client.datasets["Helm Dataset"]
helm_assets = helm_ds.assets
helm_train_files = helm_assets["helm train data"]
helm_test_files = helm_assets["helm test data"]

# Syft functions

In [None]:
@sy.syft_function()
def compute_document_data_overlap(domain, scenario_file, input_files, n):
    """Compute n-gram overlap between HELM scenario instances and training documents.

    Indexes the n-grams of every instance input and reference in
    ``scenario_file``. It then streams each file in ``input_files`` and records
    which instances share at least one n-gram with a training document.

    Args:
        domain: Context passed in by syft; used here for progress reporting.
        scenario_file: Uploaded file whose ``iter_lines(progress=True)`` yields
            ``(bytes_read, line)``. Each line is a JSON light scenario with
            ``scenario_key`` (``scenario_spec``, ``split``) and ``instances``
            (``input``, ``references``, ``id``).
        input_files: Uploaded training files with the same ``iter_lines`` /
            ``file_size`` interface; each line is JSON with a ``text`` field.
        n: N-gram size.

    Returns:
        ``(stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts)``.
        Each is keyed by ``"<scenario_spec>_<split>_<n>"``. The first two map to
        sets of instance ids whose input / references overlap a training document.
        The third maps to the number of instances of that scenario.
    """
    print("starting overlap computation")

    # Imports stay local: this function body is what gets submitted to and run
    # on the domain.
    from nltk import ngrams
    from collections import defaultdict
    from string import punctuation
    import re, json
    import time

    # Debug knob: set to an int to stop after that many lines per input file.
    # None processes every line so the overlap statistics are complete.
    MAX_LINES_PER_FILE = None

    # Tokenize on runs of whitespace and punctuation (compiled once, used for
    # both the scenario index and the training documents).
    r = re.compile(r"[\s{}]+".format(re.escape(punctuation)))

    def create_ngram_index(light_scenarios, n_values, stats_key_counts):
        # Returns {ngram_n: {ngram_tuple: {(stats_key, instance_id, part), ...}}}.
        # Entries are tuples rather than "+"-joined strings. Stringified scenario
        # specs or ids containing "+" made the later split("+") unpack the wrong
        # number of fields.
        ngram_index = {ngram_n: {} for ngram_n in n_values}
        for i, scenario in enumerate(light_scenarios):
            if i % 20 == 0:
                print(f"n_gram indexing progress: {(i/len(light_scenarios))*100:.2f}%")
            for ngram_n in n_values:
                stats_key = scenario['scenario_key'] + '_' + str(ngram_n)
                stats_key_counts[stats_key] = len(scenario['instances'])
                index_for_n = ngram_index[ngram_n]
                for instance in scenario['instances']:
                    instance_id = instance['id']
                    input_tokens = r.split(instance['input'].lower())
                    for input_ngram in ngrams(input_tokens, ngram_n):
                        index_for_n.setdefault(input_ngram, set()).add(
                            (stats_key, instance_id, 'input')
                        )

                    # compute reference ngrams
                    for reference in instance['references']:
                        reference_tokens = r.split(reference.lower())
                        for reference_ngram in ngrams(reference_tokens, ngram_n):
                            index_for_n.setdefault(reference_ngram, set()).add(
                                (stats_key, instance_id, 'references')
                            )
        return ngram_index

    # SETUP: read the scenario file into lightweight scenario dicts.
    print("preparing scenarios and creating indexes")
    start = time.time()
    light_scenarios = []
    for i, (bytes_read, light_scenario_json) in enumerate(scenario_file.iter_lines(progress=True)):
        if i % 20 == 0:
            print(f"scenario creation progress: {(bytes_read/scenario_file.file_size)*100:.2f}%")

        light_scenario_dict: dict = json.loads(light_scenario_json)

        light_scenario_key_dict: dict = light_scenario_dict["scenario_key"]
        # The spec is stringified so it can be embedded in the stats key.
        scenario_spec = str(light_scenario_key_dict["scenario_spec"])

        light_scenario_key = scenario_spec + '_' + light_scenario_key_dict["split"]
        light_instances = [
            {
                'input': instance_dict['input'],
                'references': instance_dict['references'],
                'id': instance_dict["id"],
            }
            for instance_dict in light_scenario_dict["instances"]
        ]
        light_scenarios.append({'scenario_key': light_scenario_key, 'instances': light_instances})
    print(f"Finished creating scenarios ({time.time()-start}s)")

    print("Creating indexes")

    start = time.time()
    stats_key_counts = defaultdict(int)
    ngram_index = create_ngram_index(
        light_scenarios=light_scenarios, n_values=[n], stats_key_counts=stats_key_counts
    )
    print(f"Finished creating indexes ({time.time()-start}s)")

    stats_key_to_input_ids = defaultdict(set)
    stats_key_to_reference_ids = defaultdict(set)
    print("computing overlap")
    start = time.time()

    # Report progress over all input files combined. Previously only the first
    # file's size was used, and progress restarted at 0 for every file.
    total_bytes = sum(input_file.file_size for input_file in input_files)
    if total_bytes:
        domain.init_progress(total_bytes)
    bytes_done = 0  # bytes in the input files already fully processed

    for input_file in input_files:
        for i, (bytes_read, line) in enumerate(input_file.iter_lines(progress=True)):
            if i % 1000 == 0:
                print(f"computing overlap progress: {(bytes_read / input_file.file_size) * 100:.2f}%")
                domain.set_progress(bytes_done + bytes_read)
            if MAX_LINES_PER_FILE is not None and i >= MAX_LINES_PER_FILE:
                break
            document = json.loads(line)["text"]
            document_tokens = r.split(document.lower())
            for ngram_n, index_for_n in ngram_index.items():
                for document_ngram in ngrams(document_tokens, ngram_n):
                    entries = index_for_n.get(document_ngram)
                    if not entries:
                        continue
                    for stats_key, instance_id, part in entries:
                        if part == "input":
                            stats_key_to_input_ids[stats_key].add(instance_id)
                        elif part == "references":
                            stats_key_to_reference_ids[stats_key].add(instance_id)
        bytes_done += input_file.file_size
    print(f"Finished computing overlap ({time.time()-start}s)")
    print("done")

    return stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts

In [None]:
# Register the syft function on the domain so main_function can launch it as nested jobs.
client.code.submit(compute_document_data_overlap)

In [None]:
@sy.syft_function_single_use(input_files=helm_train_files, scenario_files=helm_test_files)
def main_function(domain, input_files, scenario_files):
    """Launch one overlap job per (n-gram size, scenario file, train file) combination.

    Only the first entry of ``N`` (n=5) is used for now. Each job receives a single
    training file. The jobs are launched on the domain and not awaited here.
    """
    N = [5, 9, 13]
    ngram_sizes = N[:1]  # restricted to the first size; widen to run 9 and 13 too
    jobs = []
    for ngram_size in ngram_sizes:
        for scenario in scenario_files:
            for train_chunk in input_files:
                launched = domain.launch_job(
                    compute_document_data_overlap,
                    scenario_file=scenario,
                    input_files=[train_chunk],
                    n=ngram_size,
                )
                jobs.append(launched)

    return None


In [None]:
# Submit an execution request for main_function on the assets it is bound to.
client.code.request_code_execution(main_function)

In [None]:
# Approve the most recent request; approve_nested=True also approves nested code it uses.
# NOTE(review): assumes the last request is the one created by request_code_execution.
client.requests[-1].approve(approve_nested=True)

In [None]:
# Run main_function without blocking; `job` is the parent job whose subjobs are
# the per-chunk overlap computations.
job = client.code.main_function(input_files=helm_train_files,
                                scenario_files=helm_test_files,
                                blocking=False)

# Inspect Jobs and get results

In [None]:
# Display the parent job.
job

In [None]:
# Display the subjobs launched by main_function.
job.subjobs

In [None]:
# job.wait().get()

In [None]:
# Show the logs (progress prints) of the first subjob.
job.subjobs[0].logs()

In [None]:
# Block until each subjob finishes and collect its return value, the tuple
# (stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts).
results = []
for subjob in job.subjobs:
    results.append(subjob.wait().get())

In [None]:
# Each element is (stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts)
results

In [None]:
# results[0]

# Aggregate

In [None]:
# Merge the per-subjob results into HELM-style data-overlap stats.
# main_function launches one subjob per train-file chunk against the same
# scenario file(s), so:
#   * overlapping instance ids are unioned across subjobs, and
#   * num_instances is a property of the scenario that every subjob reports
#     identically, so it must not be summed (that multiplied it by the number
#     of chunks).
if results:
    stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts = zip(*results)
else:
    # zip(*[]) is empty and the three-way unpack above would raise ValueError.
    stats_key_to_input_ids = stats_key_to_reference_ids = stats_key_counts = ()

total_input_ids = defaultdict(set)
total_reference_ids = defaultdict(set)
total_stats_key_counts = defaultdict(int)

for d in stats_key_counts:
    for key, val in d.items():
        # max rather than +=: every subjob reports the same count for a key.
        total_stats_key_counts[key] = max(total_stats_key_counts[key], val)

# In-place update instead of rebuilding the accumulated set via union each time.
for d in stats_key_to_input_ids:
    for key, ids in d.items():
        total_input_ids[key].update(ids)

for d in stats_key_to_reference_ids:
    for key, ids in d.items():
        total_reference_ids[key].update(ids)

all_data_overlap_stats = []
for stats_key, count in total_stats_key_counts.items():
    # stats_key was built as "<scenario_spec>_<split>_<n>"; split from the right
    # because the stringified scenario_spec may itself contain underscores.
    subject, split, n_str = stats_key.rsplit('_', 2)
    data_overlap_stats = {
        'data_overlap_stats_key': {
            'light_scenario_key': {'scenario_spec': subject, 'split': split},
            'overlap_protocol_spec': {'n': int(n_str)},
        },
        'num_instances': count,
        'instance_ids_with_overlapping_input': sorted(total_input_ids[stats_key]),
        'instance_ids_with_overlapping_reference': sorted(total_reference_ids[stats_key]),
    }
    all_data_overlap_stats.append(data_overlap_stats)


In [None]:
# Pretty-print the aggregated overlap statistics.
import pprint

pprint.pprint(all_data_overlap_stats)