In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import os
import dask
import dask.delayed as delayed
from dask.distributed import Client
from glob import glob
import track_analysis
import multiprocessing

In [None]:
# DIRECTORIES
dir_input = '../STORM'   # Root directory of the STORM track files
dir_output = './output'  # Directory the NetCDF results are written to

# List of models to analyse. Model data read as {dir_input}/{model_name}/*{basin}*
model_list = ['IBTRACS','CMCC','HADGEM','ECEARTH','CNRM']

# STORM basin from which to read input files
basin = 'NA' # NA, SP, WP, EI or WI

# ANALYSIS GRID SETTINGS
resolution = 1/4 # Grid resolution in degrees
lonmin = -107  # Grid minimum longitude
lonmax = 2  # Grid maximum longitude (exclusive: the grid is built with np.arange)
latmin = 0  # Grid minimum latitude
latmax = 63  # Grid maximum latitude (exclusive: the grid is built with np.arange)
margin = 3  # Tracks further than this from the grid box are skipped (degrees)

# Example (lonmin, lonmax, latmin, latmax) for basins
# NA: (-107, 2, 0, 63)
# SP: ()  TODO: not yet defined
# WP: ()  TODO: not yet defined
# EI: ()  TODO: not yet defined
# WI: ()  TODO: not yet defined

# ANALYSIS SETTINGS
radius = 200  # Radius in km within which a passing storm is counted for a grid cell

In [None]:
def multiple_analysis( df_idx_list, margin, lonmin, 
                       lonmax, latmin, latmax, resolution, 
                       radius):
    '''
    Count, per grid cell, how many storms in a batch were assigned each category.

    Every track in ``df_idx_list`` is passed to ``single_analysis``. Tracks for
    which it returns None (further than ``margin`` from the box) are skipped.
    The remaining per-storm grids are stacked along a new ``storm`` dimension.
    For each category 0-5, the number of storms whose cell value equals that
    category is then summed.

    Parameters
    ----------
    df_idx_list : sequence of pandas.DataFrame
        One track DataFrame per storm event.
    margin, lonmin, lonmax, latmin, latmax, resolution, radius
        Forwarded unchanged to ``single_analysis``.

    Returns
    -------
    xarray.Dataset
        Variables ``category_0`` ... ``category_5`` on the (y, x) grid. If no
        track in the batch is within range (or the batch is empty), every count
        is zero instead of raising.
    '''
    categories = [0, 1, 2, 3, 4, 5]

    storm_grids = []
    for df_event in df_idx_list:
        ds_storm = single_analysis( df_event, margin,
                                    lonmin, lonmax, latmin, latmax, resolution,
                                    radius = radius )
        # None means the track never came within `margin` of the analysis box
        if ds_storm is not None:
            storm_grids.append(ds_storm)

    ds_out = xr.Dataset()

    if not storm_grids:
        # xr.concat raises on an empty list, which used to crash the whole batch
        # whenever every storm in it was out of range. Return zero counts on the
        # same grid that single_analysis builds, so batches still sum together.
        grid_lon = np.arange(lonmin, lonmax, resolution)
        grid_lat = np.arange(latmin, latmax, resolution)
        for cat in categories:
            ds_out[f'category_{cat}'] = xr.DataArray(
                np.zeros((grid_lat.size, grid_lon.size), dtype=int),
                dims=('y', 'x'),
                coords={'y': grid_lat, 'x': grid_lon} )
        return ds_out

    ds_concat = xr.concat( storm_grids, dim='storm' )
    for cat in categories:
        # Boolean mask per storm, summed over storms -> integer count per cell
        ds_out[f'category_{cat}'] = ( ds_concat['data'] == cat ).sum( dim='storm' )

    return ds_out

def single_analysis(df_ii, margin, 
                    lonmin, lonmax, latmin, latmax, 
                    resolution, dir_tmp = './tmp', 
                    radius=200):
    '''
    Rasterise one storm track onto the analysis grid as a per-cell category.

    Parameters
    ----------
    df_ii : pandas.DataFrame
        A single storm track with 'longitude', 'latitude' and 'category' columns.
    margin : float
        The track is skipped if ``track_analysis.track_distance_to_box``
        returns a value greater than this.
    lonmin, lonmax, latmin, latmax : float
        Grid bounds. The upper bounds are exclusive (grid built with np.arange).
    resolution : float
        Grid spacing in degrees.
    dir_tmp : str, optional
        Unused. Kept only so existing callers passing it keep working.
    radius : float, optional
        Forwarded to ``track_analysis.track_distance_to_grid`` (km, per the
        output metadata written by this notebook).

    Returns
    -------
    xarray.Dataset or None
        None if the track is too far from the box. Otherwise a Dataset with
        coordinates ``x`` (grid longitudes) and ``y`` (grid latitudes) and a
        variable ``data`` on (y, x): the largest index 0-5 of the six layers
        returned by ``track_distance_to_grid`` that is non-zero in the cell,
        or -1 if none is.
    '''
    prdist = track_analysis.track_distance_to_box( df_ii, lonmin, lonmax, latmin, latmax )
    if prdist > margin:
        # Too far from the box: skip the (expensive) interpolation and gridding
        return None

    grid_lon = np.arange(lonmin, lonmax, resolution)
    grid_lat = np.arange(latmin, latmax, resolution)
    lon2, lat2 = np.meshgrid(grid_lon, grid_lat)
    n_r, n_c = lon2.shape

    df_track = df_ii[['longitude','latitude','category']]
    df_interp = track_analysis.interpolate_track( df_track, delta=1/3 )
    df_interp['category'] = df_interp['category'].round(0).astype(int)

    hits = track_analysis.track_distance_to_grid( df_interp,
                                                  lon2.flatten(), lat2.flatten(),
                                                  radius=radius )
    # Six layers, presumably one per category 0-5 -- TODO confirm in track_analysis
    hits = np.reshape(hits, (6, n_r, n_c))

    # Weight layer k by k+1 and take the max, so the highest non-zero layer wins;
    # subtracting 1 maps it back to 0-5 and "no layer" to -1. The multiply
    # allocates a new array: the previous in-place `t[ii] = t[ii]*(ii+1)` would
    # have cast the weights back to True if `hits` were a bool array, collapsing
    # every category to 0.
    # NOTE(review): assumes each layer holds 0/1 indicators -- confirm.
    weights = np.arange(1, 7).reshape(6, 1, 1)
    t = np.max(hits * weights, axis=0) - 1

    ds_out = xr.Dataset()
    ds_out['x'] = (['x'], grid_lon)
    ds_out['y'] = (['y'], grid_lat)
    ds_out['data'] = (['y','x'], t)
    return ds_out

In [None]:
# Loop over the models in the specified list, writing one NetCDF file per model.
os.makedirs(dir_output, exist_ok=True)

# Number of track events handled by each delayed task
batch_size = 50

# Simulated years represented by one model's STORM files; converts total counts
# into annual rates.
# NOTE(review): hard-coded; assumes every model's files together span 10,000 years -- TODO confirm.
n_years = 10000

for model_name in model_list:

    # Get list of inputs (sorted for reproducible ordering) and read into one dataframe
    fp_inputs = sorted(glob(f'{dir_input}/{model_name}/*{basin}*'))
    if not fp_inputs:
        # pd.concat([]) would raise; report and move on to the next model
        print(f'No input files found for {model_name} in basin {basin}; skipping.')
        continue
    tracks = pd.concat([ track_analysis.read_STORM( fp ) for fp in fp_inputs ]).reset_index()

    # Separate tracks into individual events
    tracks_events = track_analysis.separate_events( tracks )
    n_events = len(tracks_events)
    if n_events == 0:
        print(f'No track events found for {model_name} in basin {basin}; skipping.')
        continue

    # One delayed task per batch, covering ALL events. Previously only the first
    # 250 events were processed, while the result was still divided by n_years.
    analyse_batch = dask.delayed(multiple_analysis)
    compute_list = [
        analyse_batch( tracks_events[idx0:idx0 + batch_size], margin,
                       lonmin, lonmax, latmin, latmax,
                       resolution, radius )
        for idx0 in range(0, n_events, batch_size)
    ]

    # Use the multiprocessing scheduler directly. The previous per-model
    # dask.distributed Client was never used (scheduler='processes' overrides
    # it), so it only spawned an idle cluster alongside the process pool.
    batch_counts = dask.compute(*compute_list, scheduler='processes')

    # Sum the per-batch category counts into a single total per grid cell
    ds_total_count = xr.concat(batch_counts, dim='storm').sum(dim='storm').compute()
    ds_count_per_year = ds_total_count / n_years

    attrs = {'title': f'Average number of tropical cyclones of varying categories passing within {radius}km per year.',
             'radius': f'{radius}km',
             'years_analysed': f'{n_years}',
             'track_dataset':'STORM',
             'model':f'{model_name}'}
    ds_count_per_year.attrs = attrs

    # The radius in the file name follows the `radius` setting (it used to be
    # hard-coded as 200km).
    fp_out = os.path.join( dir_output, f'annual_probabilities_STORM_{model_name}_{basin}_{radius}km.nc' )
    ds_count_per_year.to_netcdf(fp_out)

    # Release the per-model intermediates before the next model is loaded
    del tracks, tracks_events, compute_list, batch_counts, ds_total_count, ds_count_per_year