In [None]:
import os
import pickle
import re
import shutil
import time
from collections import defaultdict

import numpy as np
import xarray as xr
import netCDF4 as nc
import pyvtk
from toolz import first

import dask
from dask_jobqueue import SLURMCluster
from dask.distributed import Client, fire_and_forget, futures_of, wait
import dask.array as da
import dask.bag as db
from dask import config as cfg

## Set up dask cluster

In [None]:
# Resources requested for each SLURM job (one Dask worker process per job).
num_cores = 25
num_mem = num_cores*8          # GB per job, i.e. 8 GB per core
run_time = "06:00:00"
extra_args = ["--output= WHERE DO YOU WANT SLURM OUTPUT TO GO"]

# Dask configuration is set BEFORE the cluster is created: SLURMCluster starts
# its scheduler during construction, so values set afterwards would not be
# seen by that scheduler.
# NOTE(review): workers run in separate SLURM jobs and read their own Dask
# config; confirm whether these settings also need to reach them.
# disable worker heartbeat
cfg.set({'distributed.scheduler.worker-ttl': None})
# extend some time limits
cfg.set({'distributed.comm.timeouts.connect': 300})
cfg.set({'distributed.comm.timeouts.tcp': 300})

cluster = SLURMCluster(cores=num_cores,memory=str(num_mem)+"GB",processes=1,walltime= run_time, 
                       job_extra_directives=extra_args)

In [None]:
# Number of workers to request from the cluster; the same value is later
# passed to client.wait_for_workers.
num_node = 20
cluster.scale(num_node)

In [None]:
# Connect a client to the cluster's scheduler; the bare expression displays
# the client summary as the cell output.
client = Client(cluster)
client

In [None]:
# Block until all num_node requested workers have connected.
client.wait_for_workers(num_node)

--------

## Input parameters

### File Locations

In [None]:
# Data dir (location for filtered data).
# Every file in this directory whose name contains 'filtered' is loaded.
# Placeholder: must be set to a real path before running.
data_dir = ""

# Where to save files (the VTK frames and NOF.txt are written here).
# Placeholder: must be set to a non-empty path, otherwise os.makedirs("")
# raises FileNotFoundError.
dest_dir = ""
os.makedirs(dest_dir, exist_ok=True)


## Data selection

In [None]:
# wave_channel is only used to name the VTK scalar field ('U' + wave_channel).
# phi is converted to radians IN PLACE by a later cell (phi = np.radians(phi));
# re-run this cell before re-running that one, or the conversion is applied twice.
wave_channel = 'T'     # Channel of the wavefield: one of R, T, Z
phi = 45                # Azimuth of the physical slice in degrees

## Time step

In [None]:
# Every (frame_rate / dt)-th time sample is written out as a VTK frame.
dt = 0.1                 # sampling interval (time step) of the simulation output
frame_rate = 0.5         # interval between saved frames, in seconds (a period, despite the name)

### Undulation

In [None]:
# When True, point coordinates are shifted by the per-element dZ values
# loaded from the pickle file below.
undulated = True
#### Directory containing the files used to compute dZ.
#### Include a trailing '/' so the path to undFile is valid.
fileDir =''
### Undulation geometry file name (pickled dict, looked up by integer element tag):
undFile = 'dz_dict.pkl'
### Set Nr as used in simulation
### NOTE(review): Nr is not referenced elsewhere in the visible notebook;
### presumably 5 matches the 5x5 GLL layout assumed by undIdx - confirm.
Nr = 5

## Element points selection

In [None]:
# Which points of each element are present in the data. Choose one of:
# "center"  (1 point per element)
# "ccm"     (center, corners, midpoints: 9 points per element)
element_points = "center"

---------------------------

## Load data

In [None]:
# Collect the filtered NetCDF files. os.listdir returns entries in arbitrary
# order, so sort them to make the concatenation order reproducible across runs.
nc_fnames = sorted(
    os.path.join(data_dir, fname)
    for fname in os.listdir(data_dir)
    if 'filtered' in fname
)
if not nc_fnames:
    raise FileNotFoundError("No file containing 'filtered' found in %r" % data_dir)

### List of distributed or serial tasks for loading

#### parallelize over number of nc data sets, i.e. MPI ranks

In [None]:
def open_datasets(nc_fname):
    '''
    Open one NetCDF file as an xarray dataset.

    nc_fname: path of the file to open.
    Returns the dataset opened with the netcdf4 engine; chunks={} makes
    the arrays inside Dask-backed (lazy) rather than loaded eagerly.
    '''
    dataset = xr.open_dataset(nc_fname, engine="netcdf4", chunks={})
    return dataset

def load_coords(nc_file):
    '''
    Return the 'sz' coordinate variable of an opened NC dataset.

    The entry is read directly from the dataset's 'variables' mapping,
    so it stays lazy (Dask-backed when the dataset was opened with chunks).
    '''
    variables = nc_file.variables
    return variables['sz']

def load_elem_tag(nc_file):
    '''
    Return the 'tags' variable (element tags) of an opened NC dataset.
    '''
    variables = nc_file.variables
    return variables['tags']

def load_fdata(nc_file):
    '''
    Return the 'filtered_data' variable (filtered wavefield) of an
    opened NC dataset.
    '''
    variables = nc_file.variables
    return variables['filtered_data']




### submit tasks for loading

In [None]:
# One task per NC file: open the dataset on a worker.
nc_file_futures = client.map(open_datasets,nc_fnames)

# For each quantity, extract the variable from every dataset, gather the
# (still lazy) per-file results to the client, and concatenate them along
# axis 0 in the same file order, so the three arrays stay aligned.
load_coords_futures = client.map(load_coords,nc_file_futures)
all_coords_sz = da.concatenate(client.gather(load_coords_futures),axis=0)

elem_tag_futures = client.map(load_elem_tag,nc_file_futures)
elem_tag_all = da.concatenate(client.gather(elem_tag_futures),axis=0)

load_fdata_futures = client.map(load_fdata,nc_file_futures)
fdata = da.concatenate(client.gather(load_fdata_futures),axis=0)

# Total number of elements over all files (one tag per element).
nelem = len(elem_tag_all)

## Generate wavefield animation on a physical slice

### Note: several parameters to tune below

In [None]:
# GLL layout in one element:
# 4 - 9 - 14 - 19 - 24
# |   |    |    |    |
# 3 - 8 - 13 - 18 - 23
# |   |    |    |    |
# 2 - 7 - 12 - 17 - 22
# |   |    |    |    | 
# 1 - 6 - 11 - 16 - 21
# |   |    |    |    |
# 0 - 5 - 10 - 15 - 20

# Spatial downsampling:
# corners+center+midpoint:
# 2--5--8
# |  |  |
# 1--4--7
# |  |  |
# 0--3--6
# This corresponds to 
# 4-14-24
# |  |  |
# 2-12-22
# |  |  |
# 0-10-20
# in the dZ array

# Connectivity array (shared by all slices and time frames).
# undIdx: indices into an element's dZ array (GLL numbering above) of the
#         points kept per element.
# ngll:   number of points kept per element; defined in BOTH branches because
#         later cells use it regardless of element_points.
if element_points == "ccm":
    # corners+midpoints+center
    undIdx = [0,2,4,10,12,14,20,22,24]
    ngll = len(undIdx)
    # Four quads per element, expressed in the local 0..8 numbering above.
    quad_offsets = np.array([[0, 1, 4, 3],
                             [1, 2, 5, 4],
                             [3, 4, 7, 6],
                             [4, 5, 8, 7]])
    elem_starts = np.arange(nelem) * ngll
    # Broadcast (nelem,1,1) + (4,4) -> (nelem,4,4); flatten element-major so
    # the quads of element 0 come first, in the order listed above.
    connectivity = (elem_starts[:, None, None] + quad_offsets).reshape(-1, 4).tolist()

elif element_points == "center":
    # center only: one single-point cell per element
    undIdx = [12]
    ngll = len(undIdx)
    connectivity = (np.arange(nelem) * ngll).reshape(-1, 1).tolist()

else:
    raise ValueError(
        'element_points must be "center" or "ccm", got %r' % (element_points,)
    )
    

### Function to convert coordinates

In [None]:
def sz2rtheta(s,z):
    '''
    Convert cylindrical s,z to spherical r, theta.

    s, z: scalars or arrays of matching shape.
    Returns (r, theta): r is the distance from the origin and theta the
    angle measured from the +z axis, in radians within [0, pi].

    theta is computed with arctan2 rather than arccos(z/r) so that a point
    at the origin (r == 0) yields theta == 0 instead of NaN from 0/0.
    '''
    r = np.hypot(s, z)
    # abs(s) keeps theta in [0, pi] for any sign of s, matching arccos(z/r).
    theta = np.arctan2(np.abs(s), z)
    return r,theta


### Get wave on the particular physical slice and channel
### Get coordinates for the physical slice

In [None]:
# Materialize coordinates and element tags as in-memory arrays.
# dask.compute evaluates both in a single call and passes non-dask inputs
# through unchanged, so re-running this cell after the arrays have already
# been computed does not fail.
all_coords_sz, elem_tag_all = dask.compute(all_coords_sz, elem_tag_all)

In [None]:
start_time = time.time()


### Undulated geometry
# WARNING: this cell modifies all_coords_sz in place. Re-running it without
# first re-running the loading cells applies the undulation twice.
if undulated:
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # geometry files from a trusted source.
    with open(os.path.join(fileDir, undFile),"rb") as p:
        dict_data_dz = pickle.load(p)

    # Number of stored points per element: 1 for "center", 9 for "ccm".
    pts_per_elem = len(undIdx)

    for ielem in range(nelem):
        elemTag = int(elem_tag_all[ielem])
        if elemTag not in dict_data_dz:
            # No undulation recorded for this element.
            continue
        start = ielem * pts_per_elem
        stop = start + pts_per_elem
        # dZ at the retained points of this element.
        dz = np.asarray(dict_data_dz[elemTag])[undIdx]
        # theta is taken from the coordinates BEFORE they are shifted, and
        # the same theta is used for both components.
        r,theta = sz2rtheta(all_coords_sz[start:stop, 0],all_coords_sz[start:stop, 1])
        # S
        all_coords_sz[start:stop, 0] += dz*np.sin(theta)
        # Z
        all_coords_sz[start:stop, 1] += dz*np.cos(theta)

print("--- %s seconds ---" % (time.time() - start_time))

## Final preparation

In [None]:
# Job-local scratch directory as seen by the notebook process (may be None).
lscratch_dir = os.getenv('L_SCRATCH_JOB')
def make_and_save_vtk(tlist,**kwargs):
    '''
    Write one VTK file per time index in tlist and copy it to dest_dir.

    tlist:  iterable of integer time indices (columns of fdata).
    kwargs:
        rank    - identifier used to name the per-task scratch sub-directory
        coords  - array whose column 0 is s and column 1 is z
        connect - connectivity passed to pyvtk as the grid's vertex cells

    Reads the notebook globals phi (must already be in radians, since it is
    passed to np.cos/np.sin), fdata, wave_channel and dest_dir.
    Each file is first written under <scratch>/task<rank>/ and then copied
    to dest_dir. Returns None.
    '''
    import tempfile

    # Resolve the scratch directory where the task actually runs: the
    # environment of the worker can differ from the notebook's, and the
    # variable may be unset, in which case string concatenation with None
    # would fail. Fall back to the notebook's value, then the system temp dir.
    scratch_root = os.getenv('L_SCRATCH_JOB') or lscratch_dir or tempfile.gettempdir()

    ### MPI rank
    rank = str(kwargs["rank"])
    save_dir = scratch_root + '/task' +rank
    os.makedirs(save_dir, exist_ok=True)

    ### vtk mesh: rotate the (s, z) slice to azimuth phi in 3-D
    coords = kwargs["coords"]
    x = coords[:, 0] * np.cos(phi)
    y = coords[:, 0] * np.sin(phi)
    z = coords[:, 1]
    vtk_mesh = pyvtk.UnstructuredGrid(list(zip(x, y, z)),
                                      vertex=kwargs["connect"])

    scalar_name = 'U' + wave_channel
    for itime in tlist:
        # One frame: wavefield values of all points at time index itime.
        # NOTE(review): this calls .compute() on the global dask array from
        # inside a task - confirm this is the intended execution pattern.
        frame = fdata[:,itime].compute()
        vtk = pyvtk.VtkData(vtk_mesh, pyvtk.PointData(pyvtk.Scalars(frame, name=scalar_name)))
        saveFile = save_dir +"/fwave%d.vtk" % itime
        destFile = dest_dir + "/fwave%d.vtk" % itime
        vtk.tofile(saveFile, 'binary')
        shutil.copyfile(saveFile, destFile)
    return 




In [None]:
# Send the (possibly undulated) coordinates and the connectivity to every
# worker once, so each task receives a local copy.
coord_future = client.scatter(all_coords_sz,broadcast=True)
con = np.asarray(connectivity)
connect_future = client.scatter(con,broadcast=True)

# WARNING: converts phi from degrees to radians IN PLACE. Re-run the cell
# that defines phi before re-running this one, otherwise the conversion is
# applied twice.
phi = np.radians(phi)

ntime = fdata.shape[1]
# Number of samples between saved frames. Round before truncating, because
# floating-point division can land just below the intended integer (e.g.
# 0.3/0.1 evaluates to 2.9999999999999996), and never go below 1 so the
# step passed to np.arange is valid.
time_gap = max(1, int(round(frame_rate/dt)))
time_list = np.arange(0,ntime,time_gap)
time_bag = db.from_sequence(time_list,npartitions=100)
time_bag = time_bag.persist()
wait(time_bag)

# Write out a text file for the number of outputs to expect
with open(dest_dir+'/NOF.txt','w') as f:
    f.write("The number of expected filtered VTK files is: %d" % len(time_list)+'\n')

# Find out which worker has which piece of data, hence MPI rank assigned by Dask
key_to_part_dict = {str(part.key): part for part in futures_of(time_bag)}
who_has = client.who_has(time_bag)
# worker address -> list of bag partitions (futures) held by that worker
worker_map = defaultdict(list)
for key, workers in who_has.items():
    worker_map[first(workers)].append(key_to_part_dict[key])

# Submit tasks to write vtk files!

In [None]:
# One task per bag partition, pinned to the worker that already holds that
# partition; f collects the resulting futures.
f = []
for worker, list_of_parts in worker_map.items():
    for part in list_of_parts:
        # The partition index (second item of the key) serves as the rank.
        future = client.submit(
            make_and_save_vtk,
            part,
            rank=part.key[1],
            coords=coord_future,
            connect=connect_future,
            workers=worker,
        )
        f.append(future)
        


In [None]:
# Display the submitted futures to monitor their status.
f