In [None]:
import os
import numpy as np
import xarray as xr
import netCDF4 as nc
from toolz import first
from collections import defaultdict
import shutil
from scipy import signal

import dask
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, fire_and_forget, futures_of, wait
import dask.array as da
import dask.bag as db
from dask import config as cfg

--------

## Input parameters

### File Locations

In [None]:
# Directory holding the original AxiSEM3D output files (axisem3d_synthetics.nc.<rank>)
data_dir = ""

# Where to save the filtered files (created if missing)
dest_dir = ""

# Fail early with a clear message: with empty placeholders, os.makedirs("") / os.listdir("")
# would otherwise raise an unhelpful FileNotFoundError further down.
if not data_dir or not dest_dir:
    raise ValueError("Set data_dir and dest_dir before running this notebook")
os.makedirs(dest_dir, exist_ok=True)


### Data selection

In [None]:
# Wavefield channel to extract: one of "R", "T", "Z"
wave_channel = 'T'
# Position of that channel within "RTZ"; used as the channel index into the wavefield variable
wave_dim = 'RTZ'.index(wave_channel)

# nag: NaG number (presumably matches the "__NaG=<n>" suffix of the wavefield variable — confirm)
# islice: index along the wavefield's slice axis
nag, islice = 1, 0

---------------------------

## Look for data

In [None]:
# All original AxiSEM3D output files in data_dir (names contain 'axisem3d_synthetics.nc')
nc_fnames = [fname for fname in os.listdir(data_dir) if 'axisem3d_synthetics.nc' in fname]

# Ranks already processed. Outputs are written as filtered.nc.<rank>, so only files with that
# prefix are counted; unrelated files in dest_dir must not be mistaken for finished ranks.
present_ranks = [fname.split('.')[-1] for fname in os.listdir(dest_dir)
                 if fname.startswith('filtered.nc.')]

done_fnames = ['axisem3d_synthetics.nc' + '.' + rank for rank in present_ranks]
# sorted() makes the work list reproducible between runs; a bare set difference has arbitrary order.
todo_fnames = sorted(set(nc_fnames) - set(done_fnames))

# Optionally cap how many files this run handles (e.g. 300 for a first pass); None = all.
max_files = None
if max_files is not None:
    todo_fnames = todo_fnames[:max_files]

todo_fnames = [os.path.join(data_dir, fname) for fname in todo_fnames]

print("%d NC files to consider" % len(todo_fnames))

## Set up dask cluster


In [None]:
num_cores = 25                  # cores per SLURM job (one worker process per job: processes=1)
num_mem = num_cores*8           # GB requested per job: 8 GB per core
run_time = "9:00:00"            # walltime of each SLURM job
extra_args = ["--output=SLURM OUTPUT DIRECTORY"]  # extra SLURM directives; TODO: set a real log path

# Apply these settings BEFORE creating the cluster. The scheduler is built inside
# SLURMCluster(...) and reads its configuration (e.g. worker-ttl) at that point, so
# values set afterwards do not reach it.
# NOTE(review): this config lives in the notebook process; SLURM-launched workers read
# their own Dask config — confirm whether they also need the longer timeouts.
cfg.set({
    'distributed.scheduler.worker-ttl': None,      # disable the worker heartbeat timeout
    'distributed.comm.timeouts.connect': 300,      # extended connect timeout
    'distributed.comm.timeouts.tcp': 300,          # extended TCP timeout
})

cluster = SLURMCluster(cores=num_cores, memory=str(num_mem)+"GB", processes=1, walltime=run_time,
                       job_extra_directives=extra_args)

In [None]:
# One SLURM job per `num_cores` files (integer ceiling division), i.e. roughly a core per file.
# For a small test run, override with a fixed value, e.g. num_nodes = 5.
num_nodes = -(-len(todo_fnames) // num_cores)
print(f"{num_nodes} nodes requested")
cluster.scale(num_nodes)

In [None]:
# Connect this notebook to the SLURM cluster's scheduler; displaying `client` summarizes the connection.
client = Client(address=cluster)
client

In [None]:
# Block until the requested number of workers (one per SLURM job) has connected
client.wait_for_workers(n_workers=num_nodes)

## Make and distribute the filter

In [None]:
# Band-pass Butterworth filter from SciPy, in second-order-sections ('sos') form.
# sampling_frequency = 1 / sample interval — check it against the element-wise section of inparam.output.yaml!!!
sampling_frequency = 1/0.1
low_corner, high_corner = 1, 4.99   # pass-band corners, same units as sampling_frequency
filter_order = 8
butter_filter = signal.butter(filter_order, [low_corner, high_corner], btype='bandpass',
                              output='sos', fs=sampling_frequency)

# Send one copy of the filter coefficients to every worker up front
filter_futures = client.scatter(butter_filter, broadcast=True)

## Distribute data

In [None]:
def open_datasets(nc_fname):
    """Lazily open one NetCDF file with xarray.

    Returns a two-element list ``[dataset, nc_fname]`` so the file name travels
    together with its dataset through the Dask bag.
    """
    # chunks={} makes xarray back every variable with a Dask array instead of loading it eagerly
    dataset = xr.open_dataset(nc_fname, chunks={}, engine="netcdf4")
    return [dataset, nc_fname]

In [None]:
# One bag partition per file; each partition holds the opened (lazy) dataset and its file name.
data_bag = db.from_sequence(todo_fnames, npartitions=len(todo_fnames)).map(open_datasets)
data_bag = data_bag.persist()
wait(data_bag)

# Record which worker holds each partition, so the filtering task for a file can be pinned
# to the worker that already has its dataset open.
# who_has() may report keys either as the original (tuple) keys or as their string form,
# depending on the distributed version, so index every partition under both.
key_to_part_dict = {}
for part in futures_of(data_bag):
    key_to_part_dict[part.key] = part
    key_to_part_dict[str(part.key)] = part

who_has = client.who_has(data_bag)
worker_map = defaultdict(list)
for key, workers in who_has.items():
    worker_map[first(workers)].append(key_to_part_dict[key])

## Save filtered data

In [None]:
lscratch_dir = os.getenv('L_SCRATCH_JOB')  # value seen by the notebook process; used only as a fallback


def _worker_scratch_dir(rank):
    """Create and return the scratch directory for task `rank` on the worker running it."""
    import tempfile

    # Read the variable where the task runs. The module-level `lscratch_dir` was evaluated
    # in the notebook process, so it may be None or point at another node's local scratch.
    root = os.getenv('L_SCRATCH_JOB') or lscratch_dir or tempfile.gettempdir()
    save_dir = os.path.join(root, 'task' + str(rank))
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def _write_filtered_file(path, coords, tags, filtered):
    """Write element coordinates, element tags and filtered traces to a new NETCDF4 file.

    `filtered` is laid out (nelem, ntime). mode='w' clobbers any existing file at `path`.
    The `with` block guarantees the file is closed even if a write fails.
    """
    with nc.Dataset(path, mode='w', format='NETCDF4') as ncfile:
        # Dimensions
        ncfile.createDimension('time', None)                # time axis, unlimited
        ncfile.createDimension('coord', 2)                  # coord axis, dim 2 for s, z
        ncfile.createDimension('nelem', coords.shape[0])    # number of elements in this file
        # Variables
        filtered_data = ncfile.createVariable('filtered_data', np.float32, ('nelem', 'time'))
        coord_var = ncfile.createVariable('sz', np.float32, ('nelem', 'coord'))
        tag_var = ncfile.createVariable('tags', np.uint, ('nelem',))

        coord_var[:, :] = coords
        tag_var[:] = tags
        # Writing past the current length grows the unlimited 'time' dimension
        filtered_data[:, :filtered.shape[-1]] = filtered


def save_filtered_wavefield(butter_filter, data, **kwargs):
    """Band-pass filter one GLL point of one wavefield channel and save it as filtered.nc.<rank>.

    Parameters
    ----------
    butter_filter : array
        Filter in second-order-sections form, as from ``signal.butter(..., output='sos')``.
    data : list
        One bag partition, expected to be ``[[dataset, file_name]]``.
    **kwargs
        rank : int
            Task index; only used to name the scratch directory.
        nslice : int
            Index into the second axis of the wavefield variable.
        channel : int
            Index into the fourth axis of the wavefield variable (position in "RTZ").
        nag : int, optional
            Selects the variable ``data_wave__NaG=<nag>``; defaults to 1.

    The file is built in worker scratch space, then copied to
    ``<dest_dir>/filtered.nc.<original_rank>``. Returns None.
    """
    rank = kwargs["rank"]
    nslice = kwargs["nslice"]
    channel = kwargs["channel"]
    nag = kwargs.get("nag", 1)
    dataset, source_fname = data[0]

    save_dir = _worker_scratch_dir(rank)

    ### Load and select data
    # GLL point in each element whose trace is filtered
    elemGLL = 4
    # Materialize the selection once. Indexing the lazy (Dask-backed) array element by
    # element would trigger a separate read for every element.
    wave_var = dataset.variables['data_wave__NaG=%d' % nag]
    wave_component = np.asarray(wave_var[:, nslice, elemGLL, channel, :].data, dtype=np.float32)
    szInFile = np.asarray(dataset.variables['list_element_coords'][:, elemGLL, :].data)
    elemTag = np.asarray(dataset.variables['list_element_na'][:, 0].data)

    # Rank of the original file, taken from its name suffix (...nc.<rank>)
    original_rank = source_fname.split('/')[-1].split('.')[-1]
    saveFile = os.path.join(save_dir, "filtered.nc." + original_rank)
    destFile = os.path.join(dest_dir, "filtered.nc." + original_rank)

    ### Filter and save
    # sosfilt filters each row independently along the last (time) axis
    filtered = signal.sosfilt(butter_filter, wave_component, axis=-1)
    _write_filtered_file(saveFile, szInFile, elemTag, filtered)

    ### Copy back atomically
    # Copy under a temporary name, then rename. An interrupted copy then never leaves a
    # truncated filtered.nc.<rank> in dest_dir that a later run would skip as already done.
    partial = destFile + '.part'
    shutil.copyfile(saveFile, partial)
    os.replace(partial, destFile)
    os.remove(saveFile)  # the scratch copy is no longer needed


In [None]:
# One filtering task per file, pinned to the worker that already holds that file's dataset
# (from worker_map), so the opened dataset never has to move between workers.
# part.key[1] is the bag partition index; it is passed as `rank` to name the scratch directory.
f = []
for worker, parts in worker_map.items():
    for part in parts:
        f.append(client.submit(save_filtered_wavefield, filter_futures, part,
                               rank=part.key[1], nslice=islice, channel=wave_dim,
                               nag=nag, workers=worker))

In [None]:
# `f` holds one future per file. A per-status count is far easier to read than the raw list
# and makes failures visible ('error' entries). Re-run this cell to refresh it.
task_status = defaultdict(int)
for fut in f:
    task_status[fut.status] += 1
dict(task_status)