In [None]:
import getpass
import os
import pathlib
import pickle
import time

import awkward as ak
import numpy as np
import torch
from distributed import Client
#Can use ship_env and the .triton_env with LPCCondorCluster, but here's an alternative that should work for other cluster types
from distributed.diagnostics.plugin import UploadDirectory
from lpcjobqueue import LPCCondorCluster

from utils.mlbench import SimpleWorkLog
from utils.mlbench import process_function, create_local_pnmodel, get_triton_client, run_inference_pnmodel, generate_pseudodata_from_seed

## Creating a dask cluster
Creating a dask cluster lets us scale work out across multiple networked servers. These servers typically have only CPU resources (no GPU/FPGA co-processors).

In [None]:
# Request HTCondor jobs with 2 cores, 7.5GB memory and 4GB disk as dask workers.
# Job logs go to a per-user directory under the 3DayLifetime scratch area.
# getpass.getuser() is used because os.getlogin needs a call (the original
# `str(os.getlogin)` stringified the function object) and can raise OSError
# when there is no controlling terminal.
cluster = LPCCondorCluster(cores=2,
                           memory="7.5GB",
                           disk="4GB",
                           log_directory=os.path.join('/uscmst1b_scratch/lpc1/3DayLifetime/', getpass.getuser()),
                           # Optional settings, left disabled:
                           #ship_env=False,
                           #death_timeout=240,
                           #schedule_options={"dashboard_address": f":{__get_port():d}"},
                          )

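The ```adapt``` method sets the minimum and maximum number of workers the cluster requests; setting both to the same value keeps the cluster at a fixed size.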

In [None]:
# Equal lower and upper bounds keep the adaptive cluster at a fixed worker count.
n_requested_workers = 50
cluster.adapt(minimum=n_requested_workers, maximum=n_requested_workers)

In [None]:
# Display the cluster's `workers` mapping as the cell output.
# Presumably this holds one entry per requested worker; check it against the adapt() bounds.
cluster.workers

## The client
The dashboard address listed in the client summary is not reachable for monitoring if we connect through an SSH tunnel with port forwarding. Instead, point your browser to the forwarded port:
http://localhost:8787/

In [None]:
# Connect a client to the cluster; the bare name on the last line renders its summary.
client = Client(address=cluster)
client

## Uploading directories to access python packages and data
The remote workers are not guaranteed to have access to the python packages from this repo, and there are multiple ways to rectify this.
A general solution for a directory of files is the ```UploadDirectory``` worker plugin. Set ```update_path=True``` so the workers can import from the uploaded directory, and set ```restart=True``` so the
workers restart and pick it up.

In [None]:
# Ship ../utils to every worker (via the nanny) and restart the workers so it is importable.
utils_uploader = UploadDirectory("../utils", restart=True, update_path=True)
client.register_worker_plugin(utils_uploader, nanny=True)

In [None]:
# Ship ../models to every worker (via the nanny) and restart the workers so it is importable.
models_uploader = UploadDirectory("../models", restart=True, update_path=True)
client.register_worker_plugin(models_uploader, nanny=True)

## Testing client workers
To test that all the workers created have access to the packages needed to run the benchmark, we map a function (```test_workers```) across the workers. It uses try/except clauses to test a few key features
and records the outcome of each in the returned dictionary.

In [None]:
def test_structure(x):
    """Simple test that the uploaded directories are where we expect, and we can import from them"""
    import os
    import sys
    import pathlib
    test = pathlib.Path("/srv/utils/")
    success = False
    try:
        from utils.mlbench import SimpleWorkLog
        success = True
    except:
        pass
    
    return os.environ, sys.path, success, list(test.iterdir())

def test_triton_dask(worker):
    """Test function for instantiating a triton client on a remote dask worker"""
    x = get_triton_client()
    if x is not None:
        return "success"
    else:
        return type(x)

def print_cluster_info(cluster):
    """Print ``cluster.scheduler_info`` locally with light manual pretty-printing.

    Top-level entries are printed as ``key value``. The ``"workers"`` entry is
    expanded: each worker address goes on its own line, followed by that
    worker's details with the ``=`` signs aligned.

    Parameters
    ----------
    cluster : object
        Anything with a ``scheduler_info`` mapping. ``scheduler_info["workers"]``
        is expected to map worker addresses to dicts of details.
    """
    for key, value in cluster.scheduler_info.items():
        if key != "workers":
            print(key, value)
            continue
        print(key)
        for address, details in value.items():
            print("\t", address)
            # default=0: a worker with no details must not crash max() on an empty sequence.
            width = max((len(dkey) for dkey in details), default=0)
            for dkey, dval in details.items():
                print("\t\t", dkey, " " * (width - len(dkey)) + "  =\t", dval)
def test_workers(x):
    """Fully test the core elements for this benchmark: Being able to identify the client machine (pid/hostname),
    import the worklog function, import and run the local and remote model functions"""
    results = {}
    try:
        import os
        results["pid"] = os.getpid()
    except:
        results["pid"] = False
        
    import socket
    try:
        import socket
        results["hostname"] = socket.gethostname()
    except:
        results["hostname"] = False
        
    try:
        from utils.mlbench import SimpleWorkLog
        results["utils"] = True
    except:
        results["utils"] = False
        
    try:
        from utils.mlbench import get_triton_client
        _ = get_triton_client()
        results["triton"] = True
    except:
        results["triton"] = False
        
    try:
        from utils.mlbench import create_local_pnmodel
        _ = create_local_pnmodel()
        results["local"] = True
    except:
        results["local"] = False
    
    return results

Here we test the workers, using ```client.map``` to distribute the ```test_workers``` tasks and ```client.gather``` to collect their results.

In [None]:
#Test the workers can perform basic functions
# One probe task per worker the cluster currently tracks; note that tasks are
# not guaranteed to land on distinct workers (the next cell de-duplicates).
n_probe_tasks = len(cluster.workers)
test = client.gather(client.map(test_workers, range(n_probe_tasks)))

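We need to know how many unique workers are accessible and ready to receive work. ```client.map``` does not guarantee one task per worker, so the results are de-duplicated by hostname and process id. If all four numbers match, we have ```n_workers``` workers ready and able to proceed.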

In [None]:
# De-duplicate probe results per worker process. A (hostname, pid) tuple key avoids
# collisions from string concatenation ("node1"+"23" == "node12"+"3") and still
# works when a probe returned False instead of a hostname.
unique_workers = {(r['hostname'], r['pid']): r for r in test}
n_workers = len(unique_workers)
# Each probe flag is a bool, so summing counts the workers where it succeeded.
n_utils_imports = sum(r['utils'] for r in unique_workers.values())
n_triton_functioning = sum(r['triton'] for r in unique_workers.values())
n_local_functioning = sum(r['local'] for r in unique_workers.values())
n_workers, n_utils_imports, n_triton_functioning, n_local_functioning

## Creating work parcels
Here we create work parcels to map to the ```process_function``` defined in ```utils.mlbench```

For Triton inference, we have ```workargstriton``` and ```workargstritonlong```. The latter is a longer stress test: each of its parcels processes ~10x as many pseudo-events, and there are 10x as many parcels. ```workargslocal``` is for local CPU inference, which runs roughly 50x slower than Triton inference.
Each work-argument list holds 4 sublists, which give (in this order):
- the seed for random number generation;
- the number of pseudo-events to create;
- the batch size for inference (CPU-only workers have limited memory and need smaller batches);
- whether the inference should go through Triton.

When ```client.map``` is called, the first parcel takes the first element of each of the 4 sublists, the second parcel takes the second element, and so on through the last parcel, which takes the last element of each sublist. ```client.submit``` can also be used to distribute the work parcels individually.

In [None]:
# Number of work parcels for the standard trials. This is deliberately not called
# n_workers, so the unique-worker count computed above is not overwritten.
n_parcels = 100
long_multiplier = 10
n_parcels_long = n_parcels * long_multiplier
# Each argument list holds 4 parallel sequences, one entry per work parcel:
#seeds, #pseudo-events, batchsize, use triton (True/False)
workargstriton = [range(n_parcels), [1000]*n_parcels, [1000]*n_parcels, [True]*n_parcels]
workargslocal = [range(n_parcels), [1000]*n_parcels, [250]*n_parcels, [False]*n_parcels]
# Stress test: long_multiplier times as many parcels, each with ~10x the pseudo-events.
workargstritonlong = [range(n_parcels_long),
                      [9999]*n_parcels_long,
                      [1000]*n_parcels_long,
                      [True]*n_parcels_long]

## Test
Here we test the ```run_inference_pnmodel``` function, which is one of the core steps of ```process_function```.
It takes as input the (pseudo)data, either a local ParticleNet model or a Triton client configured for the same model, and a worklog function. The worklog collects some dask client metrics, such as the time spent in the various functions, the hostname, the number of bytes sent over the network for inference, and so on.

In [None]:
# Local smoke test of one Triton inference pass before distributing work.
# Seed 983; the 1000 is presumably the number of pseudo-events.
pseudodata = generate_pseudodata_from_seed(983, 1000)
triton_client = get_triton_client()
with_outputs, inf_worklogs, errors = run_inference_pnmodel(
    pseudodata, triton_client, batchsize=1000, triton=True, worklog=SimpleWorkLog
)

## Test Triton workers
For 100 work parcels of 1000 pseudo-events each, this can run in a few minutes.

In [None]:
# Triton, N workers trial: wall-clock the whole map + gather round trip.
# pure=False so dask treats every call as a distinct task.
print(f"time {time.time()}")
pft = time.perf_counter()
futurestriton = client.map(process_function, *workargstriton, pure=False)
resulttriton = client.gather(futurestriton)
print(f"runtime(s) {time.perf_counter() - pft}")
print(f"time {time.time()}")

### Save the worklogs which are returned

In [None]:
# Persist the returned worklogs of the Triton trial for offline analysis.
with pathlib.Path("wm_triton_benchmark00.pickle").open("wb") as output_file:
    pickle.dump(resulttriton, output_file)

## Stress Test Triton workers
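Each work parcel here processes ~10x as many pseudo-events as in the test above, so each parcel should take roughly 10 times as long if the inference calls are the majority of the processing time. There are also 10x as many parcels, so the total runtime can be up to ~100x longer for the same number of workers.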

In [None]:
# Triton, N workers trial long: wall-clock the whole map + gather round trip.
# pure=False so dask treats every call as a distinct task.
print(f"time {time.time()}")
pft = time.perf_counter()
futurestlong = client.map(process_function, *workargstritonlong, pure=False)
resulttlong = client.gather(futurestlong)
print(f"runtime(s) {time.perf_counter() - pft}")
print(f"time {time.time()}")

In [None]:
# Persist the returned worklogs of the long Triton trial for offline analysis.
with pathlib.Path("wm_tritonlong_benchmark00.pickle").open("wb") as output_file:
    pickle.dump(resulttlong, output_file)

## Test Local workers
This is the analogue of the ```Test Triton workers``` cell, but it uses each dask worker's local CPU to run inference.
This can take ~50x longer to run, so the number of generated events and work parcels should be balanced carefully.

In [None]:
# Local, N workers trial: wall-clock the whole map + gather round trip.
# pure=False so dask treats every call as a distinct task.
print(f"time {time.time()}")
pfl = time.perf_counter()
futureslocal = client.map(process_function, *workargslocal, pure=False)
resultlocal = client.gather(futureslocal)
print(f"runtime(s) {time.perf_counter() - pfl}")
print(f"time {time.time()}")

In [None]:
# Persist the returned worklogs of the local-CPU trial for offline analysis.
with pathlib.Path("wm_local_benchmark00.pickle").open("wb") as output_file:
    pickle.dump(resultlocal, output_file)

## Close the cluster

In [None]:
# Close the client before the cluster it is connected to, so the client is not
# left pointing at a scheduler that has been shut down.
client.close()
cluster.close()