In [1]:
import syft as sy
from syft.store.blob_storage import BlobStorageConfig, BlobStorageClientConfig
from syft.store.blob_storage.seaweedfs import SeaweedFSClient, SeaweedFSClientConfig
from syft import ActionObject
from syft.service.action.action_data_empty import ActionFileData
from syft.service.queue.zmq_queue import ZMQQueueConfig, ZMQClientConfig
from collections import defaultdict
from syft.types.blob_storage import BlobFile



In [2]:
# Launch a local dev-mode domain with 4 queue consumers and a producer, then log in.
node = sy.orchestra.launch(
    name="test-domain-helm2",
    dev_mode=True,
    reset=True,
    n_consumers=4,
    create_producer=True,
)
# NOTE(review): hard-coded login credentials — acceptable for a local dev node only.
client = node.login(email="info@openmined.org", password="changethis")

Staging Protocol Changes...
Data Migrated to latest version !!!
Logged into <test-domain-helm2: High side Domain> as <info@openmined.org>


In [3]:
# Register an additional test user account on the domain.
client.register(
    name="A",
    email="a@b.org",
    password="b",
    password_verify="b",
)

You can also run this with SeaweedFS as the blob storage backend. In that case you need to start the SeaweedFS container manually and connect to it by uncommenting the two configuration cells below:

```bash
docker run --entrypoint /bin/sh -p 8333:8333 -p 8888:8888 chrislusf/seaweedfs -c "echo 's3.configure -access_key admin -secret_key admin -user iam -actions Read,Write,List,Tagging,Admin -apply' | weed shell > /dev/null 2>&1 & weed server -s3 -s3.port=8333 -master.volumeSizeLimitMB=2048"
```

In [4]:
# Optional: SeaweedFS blob storage configuration. Uncomment this cell and the next
# one to connect to a manually started SeaweedFS container (the docker command
# above exposes its S3 endpoint on port 8333 with admin/admin keys).
# blob_config = BlobStorageConfig(client_type=SeaweedFSClient,
#                                 client_config=SeaweedFSClientConfig(host="http://0.0.0.0",
#                                                                     port="8333",
#                                                                     access_key="admin",
#                                                                     secret_key="admin",
#                                                                     bucket_name="test_bucket",
#                                                                     region="us-east-1"
#                                                                    )
# )

In [5]:
# Optional: attach the SeaweedFS config from the previous cell to the running node.
# node.python_node.init_blob_storage(blob_config)

# Inputs

In [6]:
# TODO: fix way we send list of files
# Upload the scenario JSONL to blob storage, wrap the resulting BlobFile in a
# one-element list ActionObject, and send that object to the domain.
scenario_blob_file = BlobFile.upload_from_path("scenario_data.jsonl", client)
scenario_objs = ActionObject.from_obj([scenario_blob_file])

scenario_files_ptr = scenario_objs.send(client)

In [7]:
# Upload the input-document JSONL and wrap it as a one-element list ActionObject.
input_blob_file = BlobFile.upload_from_path("short_input.jsonl", client)
input_files = ActionObject.from_obj([input_blob_file])

In [8]:
# Send the wrapped input-file list to the domain; the returned pointer is used as
# the data of the "helm train data" asset.
input_files_ptr = input_files.send(client)

In [9]:
# Two private assets with empty mocks:
#   - "helm train data": the documents scanned for overlap (input files)
#   - "helm test data":  the benchmark scenarios checked against them
helm_train_asset = sy.Asset(
    name="helm train data",
    data=input_files_ptr,
    mock=sy.ActionObject.empty(),
)
helm_test_asset = sy.Asset(
    name="helm test data",
    data=scenario_files_ptr,
    mock=sy.ActionObject.empty(),
)
input_files_dataset = sy.Dataset(
    name="Helm dataset",
    asset_list=[helm_train_asset, helm_test_asset],
)

In [10]:
# Upload the dataset (both assets) to the domain.
client.upload_dataset(input_files_dataset)

 50%|████████████████████████████████████████████████████▌                                                    | 1/2 [00:00<00:00,  5.61it/s]

Uploading: helm train data
Uploading: helm test data


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.84it/s]


In [11]:
# Asset 0 is "helm train data" (the input documents).
# NOTE(review): positional lookup assumes the server preserves asset_list order.
input_files_asset = client.datasets["Helm dataset"].assets[0]

In [12]:
# Asset 1 is "helm test data" (the scenario files).
# NOTE(review): positional lookup assumes the server preserves asset_list order.
scenario_files_asset = client.datasets["Helm dataset"].assets[1]

# Syft functions

In [13]:
@sy.syft_function()
def compute_document_data_overlap(domain, scenario_file, input_files, n):
    """Find scenario instances whose input or reference n-grams occur in the documents.

    Args:
        domain: job context; only ``init_progress`` / ``set_progress`` are used here.
        scenario_file: file object whose ``iter_lines()`` yields JSON lines, each a
            light scenario with ``scenario_key`` (``scenario_spec.args.subject`` and
            ``split``) and ``instances`` (each with ``input``, ``references``, ``id``).
        input_files: sequence of file objects; ``iter_lines(progress=True)`` yields
            ``(bytes_read, line)`` pairs, each line JSON with a ``text`` field.
        n: n-gram size.

    Returns:
        ``(stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts)``.
        The first two map ``"<subject>_<split>_<n>"`` to the set of instance ids whose
        input (resp. references) share at least one n-gram with some document. The
        third maps the same key to the number of instances in that scenario.
    """
    print("starting overlap computation")
    # NOTE: imports stay inside the body — presumably required because syft runs the
    # submitted function on the domain as a self-contained job.
    from nltk import ngrams
    from collections import defaultdict
    from string import punctuation
    from time import sleep
    import re, json

    # Tokenizer: split on any run of whitespace or ASCII punctuation. Compiled once
    # and shared by the scenario index and the document scan.
    r = re.compile(r"[\s{}]+".format(re.escape(punctuation)))

    def create_ngram_index(light_scenarios, n_values, stats_key_counts):
        """Build ``{n: {ngram tuple: {(stats_key, instance_id, part), ...}}}``.

        Also stores in ``stats_key_counts`` the instance count for every
        ``"<scenario_key>_<n>"`` stats key.
        """
        ngram_index = {n_value: defaultdict(set) for n_value in n_values}
        for scenario in light_scenarios:
            for n_value in n_values:
                stats_key = scenario['scenario_key'] + '_' + str(n_value)
                stats_key_counts[stats_key] = len(scenario['instances'])
                index_for_n = ngram_index[n_value]
                for instance in scenario['instances']:
                    instance_id = instance['id']
                    # Tuple entries (not "key+id+part" strings) so a '+' inside a
                    # subject or id cannot break the unpacking during the scan.
                    input_entry = (stats_key, instance_id, 'input')
                    input_tokens = r.split(instance['input'].lower())
                    for input_ngram in ngrams(input_tokens, n_value):
                        index_for_n[input_ngram].add(input_entry)

                    # compute reference ngrams
                    reference_entry = (stats_key, instance_id, 'references')
                    for reference in instance['references']:
                        reference_tokens = r.split(reference.lower())
                        for reference_ngram in ngrams(reference_tokens, n_value):
                            index_for_n[reference_ngram].add(reference_entry)
        return ngram_index

    # # SETUP
    print("preparing scenarios and creating indexes")
    light_scenarios = []
    for light_scenario_json in scenario_file.iter_lines():
        light_scenario_dict: dict = json.loads(light_scenario_json)

        light_scenario_key_dict: dict = light_scenario_dict["scenario_key"]
        subject_spec = light_scenario_key_dict["scenario_spec"]['args']['subject']
        light_scenario_key = subject_spec + '_' + light_scenario_key_dict["split"]
        light_instances = [
            {
                'input': instance_dict['input'],
                'references': instance_dict['references'],
                'id': instance_dict["id"],
            }
            for instance_dict in light_scenario_dict["instances"]
        ]
        light_scenarios.append({'scenario_key': light_scenario_key, 'instances': light_instances})

    stats_key_counts = defaultdict(int)
    ngram_index = create_ngram_index(
        light_scenarios=light_scenarios, n_values=[n], stats_key_counts=stats_key_counts
    )

    stats_key_to_input_ids = defaultdict(set)
    stats_key_to_reference_ids = defaultdict(set)
    print("computing overlap")
    # NOTE(review): these sleeps (1 s up front, 1 s per input line) only slow the
    # job down; they look like demo throttling to make progress visible.
    sleep(1)

    # NOTE(review): progress is initialised from the first file's size only; with
    # several input files the reported progress is presumably not cumulative.
    domain.init_progress(input_files[0].file_size)

    for input_file in input_files:
        for bytes_read, line in input_file.iter_lines(progress=True):
            sleep(1)
            document = json.loads(line)["text"]
            document_tokens = r.split(document.lower())
            for ngram_size, index_for_n in ngram_index.items():
                for document_ngram in ngrams(document_tokens, ngram_size):
                    # .get() so lookups never insert into the defaultdict index.
                    entries = index_for_n.get(document_ngram)
                    if entries is None:
                        continue
                    for stats_key, instance_id, part in entries:
                        if part == "input":
                            stats_key_to_input_ids[stats_key].add(instance_id)
                        elif part == "references":
                            stats_key_to_reference_ids[stats_key].add(instance_id)
            domain.set_progress(bytes_read)
    print("Finished overlap computation")

    return stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts

In [14]:
@sy.syft_function()
def aggregate(batch_results):
    """Merge per-batch overlap results into one record per stats key.

    Args:
        batch_results: iterable of ``(stats_key_to_input_ids,
            stats_key_to_reference_ids, stats_key_counts)`` triples, as returned by
            the ``compute_document_data_overlap`` jobs. Stats keys have the form
            ``"<subject>_<split>_<n>"``.

    Returns:
        A list with one dict per stats key, in first-seen order of the count
        mappings. Each dict holds:
          - ``data_overlap_stats_key``: subject, split and n parsed back out of the key
          - ``num_instances``: counts summed over batches
          - the sorted instance ids with overlapping input
          - the sorted instance ids with overlapping references
        An empty ``batch_results`` yields an empty list.
    """
    print("Starting aggregation")
    from collections import defaultdict

    total_input_ids = defaultdict(set)
    total_reference_ids = defaultdict(set)
    total_stats_key_counts = defaultdict(int)

    # Unpack each batch directly instead of via zip(*batch_results), which raised
    # ValueError ("not enough values to unpack") when there were no batches.
    for input_ids, reference_ids, key_counts in batch_results:
        for key, val in key_counts.items():
            total_stats_key_counts[key] += val
        # defaultdict(set) makes the "key not seen yet" case implicit; update()
        # unions into our own set without touching the batch's sets.
        for key, ids in input_ids.items():
            total_input_ids[key].update(ids)
        for key, ids in reference_ids.items():
            total_reference_ids[key].update(ids)

    all_data_overlap_stats = []
    for stats_key, count in total_stats_key_counts.items():
        # rsplit with maxsplit=2 keeps any underscores inside the subject name.
        subject, split, n_str = stats_key.rsplit('_', 2)
        all_data_overlap_stats.append({
            'data_overlap_stats_key': {
                'light_scenario_key': {'scenario_spec': subject, 'split': split},
                'overlap_protocol_spec': {'n': int(n_str)},
            },
            'num_instances': count,
            'instance_ids_with_overlapping_input': sorted(total_input_ids[stats_key]),
            'instance_ids_with_overlapping_reference': sorted(total_reference_ids[stats_key]),
        })
    print("Finished aggregation")
    return all_data_overlap_stats


In [15]:
# Submit the overlap function so main_function can launch it as a nested job.
client.code.submit(compute_document_data_overlap)

In [16]:
# Submit the aggregation function so main_function can launch it as a nested job.
client.code.submit(aggregate)

In [17]:
@sy.syft_function_single_use(input_files=input_files_asset, scenario_files=scenario_files_asset)
def main_function(domain, input_files, scenario_files):
    """Launch one overlap job per (n, scenario file) pair, then aggregate them.

    Returns the result of the final ``aggregate`` job, which receives the list of
    per-batch results in launch order.
    """
    N = [5, 9, 13]
    # Only the first n-gram size is launched for now; N[1:] are skipped.
    ngram_sizes = N[:1]
    batch_results = []
    for n in ngram_sizes:
        for scenario_file in scenario_files:
            # Keep launch_job as a standalone call so the nested job is easy to spot.
            overlap_job = domain.launch_job(
                compute_document_data_overlap,
                scenario_file=scenario_file,
                input_files=input_files,
                n=n,
            )
            batch_results.append(overlap_job.result)

    aggregate_job = domain.launch_job(aggregate, batch_results=batch_results)
    print("Finished main function")
    return aggregate_job.result


In [18]:
# Request execution of main_function and approve it together with the nested
# functions it launches (approve_nested=True).
# NOTE(review): requests[-1] assumes the request just created is the most recent one.
client.code.request_code_execution(main_function)
client.requests[-1].approve(approve_nested=True)

Request approved for domain test-domain-helm2


In [19]:
# Start main_function without waiting for it (blocking=False); returns a Job handle.
job = client.code.main_function(input_files=input_files_asset,
                                scenario_files=scenario_files_asset,
                                blocking=False)

# Get results

In [20]:
# Show the job's status, result and logs.
job

20/12/23 13:41:09 FUNCTION LOG (d95fbe1f54f947ceae2fdb090d2449c8): Finished main function


```python
class Job:
    id: UID = d95fbe1f54f947ceae2fdb090d2449c8
    status: processing
    has_parent: False
    result: syft.service.action.action_data_empty.ObjectNotReady
    logs:

0 Finished main function
    
```

In [21]:
# Inspect the sub-jobs of the main job (the nested launch_job calls).
job.subjobs

20/12/23 13:41:11 FUNCTION LOG (47e67b4fdbe84f54a31684c8cf0917bc): starting overlap computation
20/12/23 13:41:12 FUNCTION LOG (47e67b4fdbe84f54a31684c8cf0917bc): preparing scenarios and creating indexes
20/12/23 13:41:12 FUNCTION LOG (47e67b4fdbe84f54a31684c8cf0917bc): computing overlap


In [22]:
# Wait for the job's result and fetch the aggregated overlap statistics.
res = job.result.wait().get()

20/12/23 13:41:24 FUNCTION LOG (47e67b4fdbe84f54a31684c8cf0917bc): Finished overlap computation
20/12/23 13:41:27 FUNCTION LOG (9aed743396c84e6fa3ded244132be58b): Starting aggregation
20/12/23 13:41:27 FUNCTION LOG (9aed743396c84e6fa3ded244132be58b): Finished aggregation


In [23]:
# Pretty-print the per-scenario overlap records.
from pprint import pprint
pprint(res)

[{'data_overlap_stats_key': {'light_scenario_key': {'scenario_spec': 'philosophy',
                                                    'split': 'train'},
                             'overlap_protocol_spec': {'n': 5}},
  'instance_ids_with_overlapping_input': [],
  'instance_ids_with_overlapping_reference': [],
  'num_instances': 5},
 {'data_overlap_stats_key': {'light_scenario_key': {'scenario_spec': 'philosophy',
                                                    'split': 'valid'},
                             'overlap_protocol_spec': {'n': 5}},
  'instance_ids_with_overlapping_input': ['id12'],
  'instance_ids_with_overlapping_reference': [],
  'num_instances': 34},
 {'data_overlap_stats_key': {'light_scenario_key': {'scenario_spec': 'philosophy',
                                                    'split': 'test'},
                             'overlap_protocol_spec': {'n': 5}},
  'instance_ids_with_overlapping_input': ['id328'],
  'instance_ids_with_overlapping_reference': [],
  