In [4]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

import os
import sys
sys.path.append('../../')
import ndrought.drought_network as dnet

import dask

from queue import Queue
import gc
import pickle
sys.setrecursionlimit(int(1e4))

from dask.distributed import Client

In [5]:
from dask.distributed import LocalCluster, SSHCluster

In [6]:
cluster = LocalCluster()
cluster

Perhaps you already have a cluster running?
Hosting the HTTP server on port 46777 instead


Tab(children=(HTML(value='<div class="jp-RenderedHTMLCommon jp-RenderedHTML jp-mod-trusted jp-OutputArea-outpu…

In [3]:
cluster = SSHCluster(["localhost", "128.95.29.62:9815"])

HostKeyNotVerifiable: Host key is not trusted

In [5]:
#client = Client(n_workers=40, threads_per_worker=1)
client = Client('127.0.0.1:40283')
client

OSError: Timed out during handshake while connecting to tcp://127.0.0.1:40283 after 30 s

Load in data

In [2]:
print('Loading Data')
dm_path = '/pool0/home/steinadi/data/drought/drought_impact/data/drought_measures'

usdm = xr.open_dataset(f'{dm_path}/usdm/USDM_CONUS_105W_20000104_20220412.nc').load()
print('... USDM loaded')

Loading Data
... USDM loaded


In [3]:
def to_y(y, y_meta):
    y_min, y_max, y_spacing = y_meta
    return ((y_min-y_max)/y_spacing)*(y)+y_max

def to_x(x, x_meta):
    x_min, x_max, x_spacing = x_meta
    return ((x_max-x_min)/x_spacing)*(x)+x_min

def to_xy(coord, coord_meta):
    y_min, y_max, y_spacing, x_min, x_max, x_spacing = coord_meta

    y_meta = (y_min, y_max, y_spacing)
    x_meta = (x_min, x_max, x_spacing)

    y, x = coord
    return (to_x(x, x_meta), to_y(y, y_meta))

def collect_drought_track(args):
    x_list = []
    y_list = []
    u_list = []
    v_list = []
    t_list = []
    alpha_list = []

    origin, net_adj_dict, net_centroids, s_thresh, cmap = args

    q = Queue()
    q.put(origin.id)
    thread_ids = [origin.id]

    while not q.empty():
        
        current_id = q.get()

        for future_id in net_adj_dict[current_id]:
            if not future_id in thread_ids:
                q.put(future_id)
                thread_ids.append(future_id)
                        
                x, y, t, s = net_centroids[current_id]
                
                if s > s_thresh:
                    u, v, __, s_f = net_centroids[future_id]

                    x_list.append(x)
                    y_list.append(y)
                    u_list.append(u-x)
                    v_list.append(v-y)
                    t_list.append(t)

                    alpha_list.append(np.min((s_f/s, s/s_f)))

    if len(t_list) > 0:
        t_min = np.min(t_list)
        t_max = np.max(t_list)
        color_list = [cmap((t-t_min)/(t_max-t_min))[:-1] for t in t_list]

    return x_list, y_list, u_list, v_list, t_list, color_list, alpha_list

def extract_drought_tracks(net, coord_meta, client, cmap=plt.cm.get_cmap('viridis'), s_thresh=0):

    net_centroids = {node.id:(*to_xy(node.coords.mean(axis=0), coord_meta), node.time, len(node.coords)) for node in net.nodes}

    x_tracks = []
    y_tracks = []
    u_tracks = []
    v_tracks = []
    color_tracks = []
    alpha_tracks = []

    args = []
    print('Collecting Valid Origins')
    for origin in net.origins:

        # the ones that are one-off events I don't want to plot
        if len(origin.future) > 0:
            args.append([origin, net.adj_dict, net_centroids, s_thresh, cmap])
            
    #results = pool.map(collect_drought_track, tqdm(args, desc=f"Extracting Tracks from {net.name}"))    
    #results = pool.map(collect_drought_track, args)    

    print(f'Extracting Tracks {len(args)} from {net.name}')
    big_futures = client.scatter(args)
    futures = client.submit(collect_drought_track, big_futures)
    results = client.gather(futures)
        
    for result in tqdm(results, desc='Reshaping and Packaging'):
        x_list, y_list, u_list, v_list, color_list, alpha_list = result

        x_tracks.append(x_list)
        y_tracks.append(y_list)
        u_tracks.append(u_list)
        v_tracks.append(v_list)
        color_tracks.append(color_list)
        alpha_tracks.append(alpha_list)


    return x_tracks, y_tracks, u_tracks, v_tracks, color_tracks, alpha_tracks

In [None]:
dm_vars_expanded = {
    'usdm':['USDM'],
    #'spi':[f'spi_{interval}' for interval in intervals],
    #'spei':[f'spei_{interval}' for interval in intervals],
    #'eddi':[f'eddi_{interval}' for interval in intervals],
    #'pdsi':['pdsi'],
    #'grace':grace_vars
}

all_dm_ds = {
    'usdm':usdm,
    #'spi':spi,
    #'spei':spei,
    #'eddi':eddi,
    #'pdsi':pdsi,
    #'grace':grace
}

In [None]:
import logging
logger = logging.getLogger("distributed.utils_perf")
logger.setLevel(logging.ERROR)

In [None]:
# compute drought networks if not already made
for var in dm_vars_expanded.keys():
    for var_exp in dm_vars_expanded[var]:
        dnet_path = f'{dm_path}/ndrought_products/CONUS_105W/individual_dnet/{var_exp}_net.pickle'

        track_path = f'{dm_path}/ndrought_products/CONUS_105W/drought_tracks/{var_exp}_tracks.pickle'        

        if os.path.exists(dnet_path):
            var_dnet = dnet.DroughtNetwork.unpickle(dnet_path)            
        else:
            var_dnet = dnet.DroughtNetwork(all_dm_ds[var][var_exp].values, name=f'{var_exp.upper()} Drought Network')
            var_dnet.pickle(dnet_path)            

        if not os.path.exists(track_path):

            x_coords = all_dm_ds[var][var_exp].x.values
            y_coords = all_dm_ds[var][var_exp].y.values

            coord_meta = (
                np.min(y_coords), np.max(y_coords), len(y_coords),
                np.min(x_coords), np.max(x_coords), len(x_coords)
            )

            var_tracks = extract_drought_tracks(
                net=var_dnet,
                coord_meta=coord_meta,
                client=client
            )

            f = open(track_path, 'wb')
            pickle.dump(var_tracks, f, pickle.HIGHEST_PROTOCOL)
            f.close()

            x_coords = None
            y_coords = None
            coord_meta = None
            var_tracks = None
            f = None

        var_dnet = None

        gc.collect()

Collecting Valid Origins
Extracting Tracks 319 from USDM Drought Network




CommClosedError: in <TCP (closed) ConnectionPool.scatter local=tcp://127.0.0.1:56536 remote=tcp://127.0.0.1:34803>: Stream is closed

Future exception was never retrieved
future: <Future finished exception=CommClosedError('in <TCP (closed) ConnectionPool.update_data local=tcp://127.0.0.1:55552 remote=tcp://127.0.0.1:42017>: BrokenPipeError: [Errno 32] Broken pipe')>
Traceback (most recent call last):
  File "/pool0/data/steinadi/.conda/envs/sedi/lib/python3.10/site-packages/tornado/iostream.py", line 971, in _handle_write
    num_bytes = self.write_to_fd(self._write_buffer.peek(size))
  File "/pool0/data/steinadi/.conda/envs/sedi/lib/python3.10/site-packages/tornado/iostream.py", line 1148, in write_to_fd
    return self.socket.send(data)  # type: ignore
BrokenPipeError: [Errno 32] Broken pipe

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/pool0/data/steinadi/.conda/envs/sedi/lib/python3.10/site-packages/tornado/gen.py", line 769, in run
    yielded = self.gen.throw(*exc_info)  # type: ignore
  File "/pool0/data/steinadi/.conda/envs/sedi/lib/python3.