In [None]:
# Restart the workers of an already-connected Dask client so that a re-run of
# the notebook starts from a clean state. `client` is only created in the
# setup cell further down, so on a fresh kernel there is nothing to restart
# yet; skip the call in that case instead of failing with NameError.
if 'client' in globals():
    client.restart()

In [None]:
# Load the memory_profiler IPython extension; it provides the %%memit cell
# magic used by the profiling cells later in this notebook.
%load_ext memory_profiler

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import itertools
import logging
import shutil
import time
from http.client import HTTPException
from multiprocessing import Pool
from urllib import request

import dask.array as da
import matplotlib.pyplot as plt
import netCDF4
import numpy as np
from dask import delayed
from dask.distributed import Client
from dask.dot import dot_graph


# Connect to the Dask scheduler used by the computations in this notebook.
client = Client('scheduler:8786')
#client = Client(processes=False)

# Directory into which download_file writes the fetched files.
download_location = '/temp'
# Base URL the dataset paths are built on. An alternative endpoint used to be
# assigned here and was immediately overridden by the line below; it is kept
# as a comment so it can be re-enabled explicitly.
# data_url = 'http://172.22.0.1:8080'
data_url = 'http://nasanex.s3.amazonaws.com'
# Upper bound on the number of attempts made by download_file.
max_download_attempts = 5

# Model names that appear in the dataset file names (one file per model,
# variable and year).
all_models = [
    'ACCESS1-0', 'BNU-ESM', 'CCSM4', 'CESM1-BGC', 'CNRM-CM5', 'CSIRO-Mk3-6-0',
    'CanESM2', 'GFDL-CM3', 'GFDL-ESM2G', 'GFDL-ESM2M', 'IPSL-CM5A-LR',
    'IPSL-CM5A-MR', 'MIROC-ESM-CHEM', 'MIROC-ESM', 'MIROC5', 'MPI-ESM-LR',
    'MPI-ESM-MR', 'MRI-CGCM3', 'NorESM1-M', 'bcc-csm1-1', 'inmcm4',
]
# all_models = ['ACCESS1-0', 'BNU-ESM'] 
# Variable names used both in the URL path and as the NetCDF variable key.
all_vars = ['tasmax', 'pr']
# Years available per scenario (upper bound of range() is exclusive).
all_years = {
     # 'historical': list(range(1971, 1976))
    'historical': list(range(1971, 2001))
}

def get_dataset_url(variable, scenario, model, year, prefix = data_url):
    """Build the URL of one dataset file.

    The directory part is
    ``<prefix>/NEX-GDDP/BCSD/<scenario>/day/atmos/<variable>/r1i1p1/v1.0`` and
    the file name is
    ``<variable>_day_BCSD_<scenario>_r1i1p1_<model>_<year>.nc``.

    ``year`` is converted with ``str``; every other argument must already be
    a string because the pieces are combined with ``str.join``.
    """
    name_parts = (variable, 'day', 'BCSD', scenario, 'r1i1p1', model, str(year) + '.nc')
    path_parts = (
        prefix, 'NEX-GDDP', 'BCSD', scenario, 'day', 'atmos', variable,
        'r1i1p1', 'v1.0', '_'.join(name_parts),
    )
    return '/'.join(path_parts)

def get_context(year, **kwargs):
    """List every ``[variable, scenario, model, year]`` combination for ``year``.

    Keyword arguments:
        variable: restrict the result to this variable; when missing or falsy
            all entries of ``all_vars`` are used.
        model: restrict the result to this model; when missing or falsy all
            entries of ``all_models`` are used.

    The scenario is always ``'historical'``. Combinations are ordered with the
    variable varying slowest and the model fastest.
    """
    chosen_variable = kwargs.get('variable')
    if chosen_variable:
        variables = [chosen_variable]
    else:
        variables = all_vars

    scenarios = ['historical']

    chosen_model = kwargs.get('model')
    if chosen_model:
        models = [chosen_model]
    else:
        models = all_models

    return [
        [variable, scenario, model, year]
        for variable, scenario, model in itertools.product(variables, scenarios, models)
    ]

def get_year_ensemble(year, variable = 'tasmax'):
    """Return the dataset URLs of ``variable`` for ``year``, one per model.

    The combinations come from ``get_context`` (called without a model
    filter) and each one is turned into a URL string by ``get_dataset_url``.
    """
    return [
        str(get_dataset_url(*entry))
        for entry in get_context(year, variable = variable)
    ]

def download_file(url):
    """Download ``url`` into ``download_location`` and return the local path.

    The local file name is the last path component of ``url``. At most
    ``max_download_attempts`` attempts are made. A failed attempt is reported
    and retried after sleeping ``2 ** attempt`` seconds; the first attempt
    starts without any delay. The payload is streamed to disk in chunks, and
    both the HTTP response and the output file are closed even when an
    attempt fails.

    Returns:
        The path of the downloaded file, or ``""`` when
        ``max_download_attempts`` is not positive (no attempt is made).

    Raises:
        OSError or http.client.HTTPException: the error of the last attempt
        when every attempt failed (``urllib.error.URLError`` and
        ``HTTPError`` are ``OSError`` subclasses).
    """
    print("url: " + url)
    filename = '/'.join([download_location, str(url.split('/')[-1])])
    last_error = None
    for attempt in range(max_download_attempts):
        if attempt:
            # Exponential backoff, applied only before a retry.
            time.sleep(2 ** attempt)
        print("Downloading file at " + filename)
        try:
            # 'wb' truncates, so a partial file from a failed attempt is overwritten.
            with request.urlopen(url) as response, open(filename, 'wb') as out_file:
                shutil.copyfileobj(response, out_file)
            return filename
        except (OSError, HTTPException) as error:
            last_error = error
            print("Download attempt " + str(attempt + 1) + " failed: " + str(error))
    if last_error is not None:
        raise last_error
    # Only reachable when max_download_attempts <= 0: nothing was attempted,
    # so (as before) an empty file name is returned.
    return ""

def download_file_list(url_list):
    """Download every URL in ``url_list`` with a process pool.

    ``download_file`` is mapped over the URLs using a ``multiprocessing.Pool``
    with the default number of workers. The pool is used as a context manager
    so its worker processes are terminated even when a download raises.

    Returns:
        The list of local file names, in the same order as ``url_list``.

    Raises:
        Whatever ``download_file`` raised in a worker; ``Pool.map`` re-raises
        it in the calling process.
    """
    print("Starting download pool")
    with Pool() as pool:
        # map() blocks until every download has finished (or one has failed).
        res = pool.map(download_file, url_list)
        print("Jobs sent")
        pool.close()
        pool.join()
    print("Downloads finished")
    print(res)
    return res

# Last expression of the cell: shown as the cell output once all the
# definitions above have been executed.
"OK"

In [None]:
def download_and_stack(year, variable):
    """Download all files of ``variable`` for ``year`` and stack them.

    The URLs come from ``get_year_ensemble``; the files are fetched with
    ``download_file_list`` and opened with ``netCDF4.Dataset``. Each file's
    ``variable`` entry is wrapped in a dask array with chunks
    ``(366, 144, 144)``, and the arrays are stacked along a new axis 0 (one
    entry per downloaded file).

    The datasets are intentionally left open: the returned dask array reads
    from them when it is computed.
    """
    urls = get_year_ensemble(year, variable = variable)
    local_paths = download_file_list(urls)
    opened = [netCDF4.Dataset(path) for path in local_paths]
    per_file_arrays = [
        da.from_array(dataset[str(variable)], chunks= (366, 144, 144))
        for dataset in opened
    ]
    return da.stack(per_file_arrays, axis = 0)

def avg_over_first_axis(darray):
    """Return the unweighted average of ``darray`` taken along axis 0."""
    first_axis = 0
    return np.average(darray, axis=first_axis)

# %timeit stack_1971 = download_and_stack(1971, variable='tasmax')
# Download the 1971 'pr' files of every model and stack them as a dask array
# along a new first axis (nothing is computed yet).
stack_1971 = download_and_stack(1971, variable='pr')
#avg_stack_1971 = avg_over_first_axis(stack_1971)
#avg_stack_1971

In [None]:
# Display the dask array summary of the stack; this does not compute any values.
stack_1971

In [None]:
def get_stacks_mod_avg(a, chunksize):
    """Cut a 4-axis array into tiles of ``chunksize`` along its last two axes.

    ``a`` must have exactly four axes (named models, time, lat, lon by the
    original author). The first two axes are kept whole; axes 2 and 3 are
    split into ``ceil(size / chunksize)`` blocks each, so tiles on the far
    edge may be smaller. Tiles are returned in row-major order: the axis-2
    block index varies slowest, the axis-3 block index fastest.

    The block indices, the slice bounds and each tile are printed as they are
    produced.
    """
    _, _, n_lat, n_lon = a.shape
    lat_blocks = int(np.ceil(n_lat / chunksize))
    lon_blocks = int(np.ceil(n_lon / chunksize))

    tiles = []
    for flat_index in range(lat_blocks * lon_blocks):
        # Row-major walk over the (lat block, lon block) grid.
        i, j = divmod(flat_index, lon_blocks)
        lat_lo, lat_hi = i * chunksize, (i + 1) * chunksize
        lon_lo, lon_hi = j * chunksize, (j + 1) * chunksize
        print(i, j, '~>', lat_lo, lat_hi, lon_lo, lon_hi)
        tile = a[:, :, lat_lo:lat_hi, lon_lo:lon_hi]
        print(tile)
        tiles.append(tile)
    return tiles

# Split the 1971 stack into tiles of 360 along its last two axes and display
# the resulting list of tiles.
chunked_stacks = get_stacks_mod_avg(stack_1971, 360)
chunked_stacks

In [None]:
# Rebuild the tiles, then reduce every tile with np.mean over axis 0 (the
# stacking axis created by download_and_stack).
chunked_stacks = get_stacks_mod_avg(stack_1971, 360)
# Each tile is wrapped in dask.delayed and computed right away, so the tiles
# are evaluated one after another.
# NOTE(review): a single batched compute would let the scheduler run tiles in
# parallel, but would also hold several tiles in memory at once — confirm the
# sequential evaluation is intentional before changing it.
pr_mean_1971 = list(map(lambda x: delayed(np.mean)(x, axis=0).compute(), chunked_stacks))
pr_mean_1971


In [None]:
def restack(chunk_list, chunksize, nlat=720, nlon=1440):
    """Reassemble per-tile results into one ``(ndays, nlat, nlon)`` array.

    This is the inverse of the tiling done by ``get_stacks_mod_avg`` after the
    first axis of every tile has been reduced away: ``chunk_list`` must hold
    the tiles in row-major order (latitude block index slowest, longitude
    block index fastest).

    Args:
        chunk_list: sequence of arrays shaped ``(ndays, lat_size, lon_size)``.
            Tiles on the far edge of the grid may be smaller than
            ``chunksize`` when it does not divide the grid evenly.
        chunksize: edge length used when the grid was tiled.
        nlat: size of the latitude axis of the full grid (default 720).
        nlon: size of the longitude axis of the full grid (default 1440).

    Returns:
        A float64 ``numpy`` array of shape ``(ndays, nlat, nlon)`` where
        ``ndays`` is the length of the first axis of the first tile.

    Raises:
        ValueError: if the number of tiles does not match the number of
            blocks implied by ``chunksize``, ``nlat`` and ``nlon``. Without
            this check missing tiles would leave uninitialised values in the
            result.
    """
    chunk_list = list(chunk_list)
    # Round up, like get_stacks_mod_avg does, so partial edge blocks are counted.
    lat_blocks = int(np.ceil(nlat / chunksize))
    lon_blocks = int(np.ceil(nlon / chunksize))

    expected = lat_blocks * lon_blocks
    if len(chunk_list) != expected:
        raise ValueError(
            "expected %d chunks for a %dx%d grid with chunksize %s, got %d"
            % (expected, nlat, nlon, chunksize, len(chunk_list))
        )

    ndays = np.shape(chunk_list[0])[0]
    out_array = np.empty((ndays, nlat, nlon))

    positions = itertools.product(range(lat_blocks), range(lon_blocks))
    for (lat_index, lon_index), arr in zip(positions, chunk_list):
        lat_min = lat_index * chunksize
        lon_min = lon_index * chunksize
        # Slices running past the grid edge are clipped by numpy, which is
        # exactly the extent of a partial edge tile.
        out_array[:, lat_min:lat_min + chunksize, lon_min:lon_min + chunksize] = arr
    return out_array

# Reassemble the per-tile means into one array (same tile size as used for
# the split) and show its shape.
pr_arr1971 = restack(pr_mean_1971, 360)
pr_arr1971.shape

In [None]:
%%memit
# Memory profile of the per-tile 50th percentile taken over axis 0.
chunked_stacks = get_stacks_mod_avg(stack_1971, 360)
p50 = list(map(lambda x: np.percentile(x, 50, axis=0), chunked_stacks))
p50

In [None]:
%%memit
# Memory profile of the per-tile mean taken over axis 0, for comparison with
# the percentile cell.
chunked_stacks = get_stacks_mod_avg(stack_1971, 360)
avg = list(map(lambda x: np.mean(x, axis=0), chunked_stacks))
avg

In [None]:
# All values along axis 0 at grid index (40, 30).
pr_arr1971[:, 40, 30]

In [None]:
# The comparison produces a boolean mask, so this histogram counts the
# True/False entries of slice 100 rather than binning the data values.
# NOTE(review): confirm whether the values below 1e-1 were meant to be binned
# instead (i.e. indexing the slice with the mask).
np.histogram(pr_arr1971[100, :, :] < 1e-1, bins=100)

In [None]:
# Set every value greater than 1 to 0. This modifies pr_arr1971 in place, so
# all later cells (and any re-run of earlier ones) see the altered array.
pr_arr1971[pr_arr1971 > 1] = 0

In [None]:
# Show the first slice along axis 0 as an image.
plt.imshow(pr_arr1971[0,:,:])