In [None]:
import datetime
import numpy as np
import pandas as pd
import xarray as xr
import re
import ast
import pickle as pkl

from ahh import vis, ext
from sklearn import linear_model
from scipy.stats.stats import pearsonr, combine_pvalues
from sklearn.metrics import mean_squared_error as mse
from sklearn.preprocessing import MinMaxScaler
from keras.layers import Dense, LSTM
from keras.models import Sequential

In [None]:
# SETTINGS
# User-tunable switches for one experiment run. The functions below read these
# as module-level globals rather than taking them as arguments.

# True: use synthetic frames (values derived from the calendar month), print the
# train/valid 'sst' columns, skip CSV output and plotting, and stop the main
# loop after the first lead/year.
CONCEPTUAL = False
# True: fit separate models on rows whose lag-0 predictor is >= 0 and < 0.
POS_NEG = False
# True: persist fitted models and the scaler offset/scale arrays under models/.
SAVE = False
# Identifier embedded in the output CSV file names (see CSV_FP_FMT).
# Author's run log: 11 is good jk5, straight, all; 12 is jk1; 13 is half
# validation period; 15 is jk3; 16 is +-; 99 is test.
RUN_ID = 99
# Jack-knife width in years: validation rows are those whose year lies in
# (year - JK_INT, year]. 0 switches to a single fixed split at 2000-04-01.
JK_INT = 3  # jack knife (split if 0)
# Alternative setting: one model per calendar month.
# MONTHS = range(1, 13)
# 99 is a sentinel meaning "do not filter by calendar month" (labelled 'linked').
MONTHS = [99]
# Lead values passed to shift_by_lead_lag (number of rows the predictors are shifted).
LEADS = range(1, 12)
# 'monthly' selects monthly-mean predictors; any other value is treated as daily.
TIMESCALE = 'daily'

# Validation years to loop over when JK_INT != 0; None derives them from the data.
YEARS = None  # None will be automatic
# Number of additional lags (0..TIME_BEHIND) included for every predictor column.
TIME_BEHIND = 0
# True: sklearn LinearRegression; False: keras LSTM.
LR = True
# Predictor variables; the setup cell recognises 'sst', 'wwv' and 'wnd'.
VAR_LIST = ['sst', 'wwv']

# CONSTANTS BELOW

# Pickled input DataFrames, read with pd.read_pickle when CONCEPTUAL is False.
SST_FP = 'data/pkl/sst_k.pkl'
WWV_FP = 'data/pkl/wwv_m3.pkl'
WND_FP = 'data/pkl/wnd_ms.pkl'
# Output CSV path; {0} is the statistic name (or 'time'), {1} is RUN_ID.
CSV_FP_FMT = 'output/{0}_{1:03d}.csv'
# Subsets iterated over when POS_NEG is True.
POS_NEG_LIST = ['positive', 'negative']

# NOTE(review): NINO34_LIM is not referenced anywhere in the visible code.
NINO34_LIM = vis.get_region_latlim('nino34', w2e=True)
# Shared scaler; scale_df refits it on every call, so its fitted attributes
# always describe the most recently scaled frame.
SCALER = MinMaxScaler(feature_range=(-1, 1))
STATS = ['corr', 'pval', 'rmse']
# Column names of the per-statistic CSVs, in the order write_stat emits them.
KEYS = ['values', 'stat', 'lead', 'month',
        'year', 'timescale', 'var', 'lr', 'jk_int', 'time_behind']
# Column names of the time-series CSV written by write_timeseries.
TS_KEYS = ['valid_y', 'valid_z'] + STATS + KEYS[2:]

# File-name templates used when SAVE is True; all share the same placeholders.
LSTM_H5_SAVE_FMT = 'models/{lead}_{month}_{year}_{timescale}_{var}_{lr}_{jk_int}_{time_behind}.h5'
LSTM_JSON_SAVE_FMT = 'models/{lead}_{month}_{year}_{timescale}_{var}_{lr}_{jk_int}_{time_behind}.json'
LR_SAVE_FMT = 'models/{lead}_{month}_{year}_{timescale}_{var}_{lr}_{jk_int}_{time_behind}.pkl'
OFFSET_SCALE_SAVE_FMT = 'models/{lead}_{month}_{year}_{timescale}_{var}_{lr}_{jk_int}_{time_behind}.npy'

In [None]:
def get_timescale_settings(timescale):
    """Translate a timescale name into ``(daily, time_multiplier)``.

    Parameters
    ----------
    timescale : str
        ``'monthly'`` selects monthly predictors; every other value is
        treated as daily.

    Returns
    -------
    tuple of (bool, int)
        ``(False, 1)`` for monthly, otherwise ``(True, 31)`` (31 predictor
        slots per month).
    """
    is_monthly = timescale == 'monthly'
    return (False, 1) if is_monthly else (True, 31)


def convert_resolution(df, daily=True):
    """Reshape a single-column time series to one row per calendar month.

    Parameters
    ----------
    df : pandas.DataFrame
        Single-column frame with a DatetimeIndex. For the daily layout the
        rows of each month are taken in their existing order, so the frame is
        expected to hold one row per day in chronological order.
    daily : bool, default True
        True: return a wide frame with 31 columns per month named
        ``<var>_day01`` .. ``<var>_day31``; months shorter than 31 rows are
        forward-filled from their last available row.
        False: return the monthly means (``resample('1MS').mean()``).

    Returns
    -------
    pandas.DataFrame
        Indexed by the first day of each month.
    """
    var_name = df.columns[0]
    if not daily:
        return df.resample('1MS').mean()  # monthly

    month_frames = []
    # groupby iterates the (year, month) keys in sorted order.
    for (yr, mo), month_df in df.groupby([df.index.year, df.index.month]):
        month_df = month_df.reset_index(drop=True)
        # Number the rows 1..n so that row k lands in the "day k" slot.
        # Leaving the default 0-based index here would drop the first day of
        # every month and shift all remaining values one slot to the left.
        month_df.index = range(1, len(month_df) + 1)
        month_df.columns = [datetime.datetime(int(yr), int(mo), 1)]
        month_frames.append(month_df.reindex(range(1, 32), method='ffill'))

    converted_df = pd.concat(month_frames, axis=1).T
    converted_df.columns = ['{0}_day{1:02d}'.format(var_name, col) for col in converted_df.columns]
    return converted_df


def shift_by_lead_lag(df, lead, lag):
    """Build lagged copies of every column, shifted by ``lead + lag`` rows.

    For each column ``col`` and each ``alag`` in ``0..lag`` the result holds a
    column ``lag<alag>_lead<lead>_<col>`` equal to ``df[col].shift(lead + alag)``.
    The original (unshifted) columns are not included and the input frame is
    left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Source frame; not modified.
    lead : int
        Base number of rows to shift by.
    lag : int
        Largest additional shift; ``0`` yields one shifted copy per column.

    Returns
    -------
    pandas.DataFrame
        The shifted columns with every row containing a NaN removed.
    """
    # Collect everything first and build the frame once: this avoids adding
    # columns to the caller's frame and inserting columns one at a time.
    shifted = {
        'lag{0:02d}_lead{1:02d}_{2}'.format(alag, lead, col): df[col].shift(lead + alag)
        for col in df.columns
        for alag in range(lag + 1)
    }
    return pd.DataFrame(shifted, index=df.index).dropna()


def sort_df_cols(df, var_name, daily=True):
    """Reorder the shifted predictor columns.

    Daily layout: columns are selected by name, from the largest lag
    (``TIME_BEHIND``) down to lag 0 and, within each lag, day 01 to day 31.
    Monthly layout: the existing column order is simply reversed.

    NOTE(review): the daily branch reads the module-level names ``lead`` and
    ``TIME_BEHIND``; ``lead`` is the loop variable of the main experiment
    loop, so this branch only works while that loop has defined it.
    """
    if not daily:
        return df[df.columns[::-1]]

    template = 'lag{0:02d}_lead{1:02d}_{2}_day{3:02d}'
    ordered_cols = []
    for alag in range(TIME_BEHIND, -1, -1):
        for day in range(1, 32):
            ordered_cols.append(template.format(alag, lead, var_name, day))
    return df[ordered_cols]


def scale_df(df):
    """Fit the shared ``SCALER`` on ``df`` and return the scaled frame.

    Returns
    -------
    tuple
        ``(scaled_df, offset, scale)`` where ``scaled_df`` keeps the index and
        columns of ``df`` and ``offset`` / ``scale`` are the scaler's
        ``min_`` / ``scale_`` attributes after this fit.
    """
    # SCALER is refit on every call, so min_/scale_ describe this frame only.
    scaled_values = SCALER.fit_transform(df)
    scaled_df = pd.DataFrame(scaled_values, columns=df.columns, index=df.index)
    return scaled_df, SCALER.min_, SCALER.scale_


def divide_train_val(df):
    """Split ``df`` into training/validation arrays.

    The first column is treated as the target, the remaining columns as
    predictors.

    With ``JK_INT != 0`` the validation rows are those whose index year lies
    in ``(year - JK_INT, year]``; everything else is training data. With
    ``JK_INT == 0`` the split is at the fixed date ``'2000-04-01'``.

    NOTE(review): ``year`` is read from module scope (the main loop variable),
    as are ``JK_INT`` and ``CONCEPTUAL``.

    Returns
    -------
    tuple
        ``(train_x, train_y, valid_x, valid_y, valid_idx)``.
    """
    if JK_INT == 0:
        mid_date = '2000-04-01'
        train = df.loc[df.index < mid_date]
        # Boolean-array selection through .iloc, kept as in the original.
        valid = df.iloc[df.index >= mid_date]
    else:
        in_window = (df.index.year > year - JK_INT) & (df.index.year <= year)
        train = df.loc[~in_window]
        valid = df.loc[in_window]

    if CONCEPTUAL:
        print(train['sst'])
        print(valid['sst'])

    valid_idx = valid.index
    train_arr = train.values
    valid_arr = valid.values

    # Column 0 is the observation; the rest are predictors.
    train_x, train_y = train_arr[:, 1:], train_arr[:, 0]
    valid_x, valid_y = valid_arr[:, 1:], valid_arr[:, 0]
    return train_x, train_y, valid_x, valid_y, valid_idx


def reshape_x(fcst_dict):
    """Stack the per-variable predictor arrays into model input, in place.

    ``fcst_dict['train_x']`` and ``fcst_dict['valid_x']`` are expected to be
    lists with one 2-D array per predictor variable. They are replaced by

    * ``(n_samples, n_steps * n_vars)`` when ``LR`` is true, or
    * ``(n_samples, n_steps, n_vars)`` otherwise,

    where ``n_steps = time_multiplier * (TIME_BEHIND + 1)`` and
    ``n_vars = len(var_df_list)``.

    NOTE(review): ``time_multiplier`` and ``var_df_list`` are read from module
    scope. The function is not idempotent: calling it twice on the same dict
    re-stacks already reshaped arrays.
    """
    n_steps = time_multiplier * (TIME_BEHIND + 1)
    n_vars = len(var_df_list)

    for key in ('train_x', 'valid_x'):
        stacked = np.stack(fcst_dict[key])  # (n_vars, n_samples, n_features)
        n_samples = stacked.shape[1]
        if LR:
            fcst_dict[key] = np.rollaxis(stacked, 1).reshape(n_samples, n_steps * n_vars)
        else:
            fcst_dict[key] = np.rollaxis(stacked.T, 1).reshape(n_samples, n_steps, n_vars)

    return fcst_dict


def train_regr(train_mod, train_obs):
    """Fit and return an ordinary least-squares ``LinearRegression``.

    ``train_mod`` holds the predictors and ``train_obs`` the targets.
    """
    regressor = linear_model.LinearRegression()
    # fit() returns the estimator itself.
    return regressor.fit(train_mod, train_obs)


def train_lstm(train_mod, train_obs):
    """Build, compile and fit a two-layer LSTM regressor.

    Architecture: LSTM(50, returning sequences) -> LSTM(150) -> Dense(1),
    trained with MSE loss and the Adam optimizer for 8 epochs with batch size
    8 and shuffling. No random seed is set in this function.

    Parameters
    ----------
    train_mod : array with three dimensions
        Axes 1 and 2 define the input shape of the first layer.
    train_obs : array
        Training targets.
    """
    input_shape = (train_mod.shape[1], train_mod.shape[2])

    first_layer = LSTM(input_shape=input_shape, return_sequences=True, units=50)
    second_layer = LSTM(150, return_sequences=False)

    model = Sequential()
    model.add(first_layer)
    model.add(second_layer)
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')

    fit_options = dict(epochs=8, batch_size=8, verbose=0, shuffle=True)
    model.fit(train_mod, train_obs, **fit_options)
    return model


def unscale_df(df, offset, scale):
    """Undo the scaling using the first entry of ``offset`` and ``scale``.

    Computes ``(df - offset[0]) / scale[0]``; index 0 corresponds to the first
    column of the frame the scaler was fitted on.
    """
    first_offset = offset[0]
    first_scale = scale[0]
    return (df - first_offset) / first_scale


def get_prediction(fcst_dict, time_multiplier, offset, scale,
                   lead, month, year):
    """Train a model on the training split and predict the validation split.

    The model is a linear regression when ``LR`` is true, otherwise an LSTM.
    Both the prediction and the validation target are converted back with
    ``unscale_df``. When ``SAVE`` is true the fitted model is written to disk:
    LSTM weights (.h5) plus architecture (.json), or a pickled regression.

    Parameters
    ----------
    fcst_dict : dict
        Must provide ``train_x``, ``train_y``, ``valid_x`` and ``valid_y``.
    time_multiplier : int
        Not used inside this function.
    offset, scale : sequence
        Scaler parameters forwarded to ``unscale_df``.
    lead, month, year
        Only used to build the file names when saving.

    Returns
    -------
    tuple
        ``(valid_y, valid_z)``: unscaled observations and predictions.
    """
    train_x, train_y = fcst_dict['train_x'], fcst_dict['train_y']
    valid_x, valid_y = fcst_dict['valid_x'], fcst_dict['valid_y']

    trainer = train_regr if LR else train_lstm
    model = trainer(train_x, train_y)

    valid_z = unscale_df(model.predict(valid_x), offset, scale)
    valid_y = unscale_df(valid_y, offset, scale)

    if not LR:
        # Keep the last column of the 2-D LSTM output as a 1-D array.
        valid_z = valid_z[:, -1]

    if not SAVE:
        return valid_y, valid_z

    save_dict = dict(lead=lead, month=month,
                     year=year, timescale=TIMESCALE,
                     var='_'.join(VAR_LIST),
                     lr=LR, jk_int=JK_INT,
                     time_behind=TIME_BEHIND)
    if LR:
        lr_path = LR_SAVE_FMT.format(**save_dict)
        with open(lr_path, 'wb') as pkl_file:
            pkl.dump(model, pkl_file)
    else:
        weights_path = LSTM_H5_SAVE_FMT.format(**save_dict)
        model.save_weights(weights_path)

        arch_path = LSTM_JSON_SAVE_FMT.format(**save_dict)
        arch_json = model.to_json()
        with open(arch_path, 'w') as json_file:
            json_file.write(arch_json)

    return valid_y, valid_z


def get_corr_pval_rmse(arr1, arr2):
    """Return ``(correlation, p-value, RMSE)`` for two equally long sequences.

    Correlation and p-value come from ``pearsonr``; RMSE is the square root of
    the mean squared error between the two inputs.
    """
    first, second = np.array(arr1), np.array(arr2)
    corr, pval = pearsonr(first, second)
    rmse = np.sqrt(mse(first, second))
    return corr, pval, rmse


def write_stat(val, stat, lead, month, year):
    """Append one statistic record to ``output/<stat>_<RUN_ID>.csv``.

    The record holds the value followed by the run configuration, in the same
    order as ``KEYS``. Fields are separated by a comma and a space.

    Returns
    -------
    str
        Path of the CSV file written to.
    """
    fn = CSV_FP_FMT.format(stat, RUN_ID)
    record = dict(val=val, stat=stat,
                  lead=lead, month=month,
                  year=year, timescale=TIMESCALE,
                  var='_'.join(VAR_LIST), lr=LR, jk_int=JK_INT,
                  time_behind=TIME_BEHIND)
    row_fmt = '{val}, {stat}, {lead}, {month}, {year}, {timescale}, {var}, {lr}, {jk_int}, {time_behind}\n'
    with open(fn, 'a') as csv_file:
        csv_file.write(row_fmt.format(**record))
    return fn


def write_timeseries(df_ts_out):
    """Append the validation time series to ``output/time_<RUN_ID>.csv``.

    Columns are written in ``TS_KEYS`` order, with the index and without a
    header row.

    Returns
    -------
    str
        Path of the CSV file written to.
    """
    ordered = df_ts_out[TS_KEYS]
    out_path = CSV_FP_FMT.format('time', RUN_ID)
    with open(out_path, 'a') as csv_file:
        ordered.to_csv(csv_file, header=False)
    return out_path


def get_label(month):
    """Build the run label used for the plot title and file name.

    ``month == 99`` is labelled ``'linked'``, any other value ``'segmented'``.
    All other parts come from the module-level settings. Spaces are removed
    from the final string.
    """
    parts = {
        'jk_int': JK_INT,
        'month': 'linked' if month == 99 else 'segmented',
        'time_behind': TIME_BEHIND,
        'timescale': TIMESCALE,
        'lr': 'lr' if LR else 'lstm',
        'var': '_'.join(VAR_LIST),
    }
    template = 'jk{jk_int}_{month}_{time_behind}_{timescale}_{lr}_with_{var}'
    return template.format(**parts).replace(' ', '')


def get_cleansed_df(fn, time=False):
    """Read one of the output CSVs and tidy it up.

    Parameters
    ----------
    fn : str
        CSV path to read.
    time : bool, default False
        False: the file is a statistic CSV with ``KEYS`` columns. Rows with
        missing values are dropped, the index is renumbered from 0, the text
        columns are stripped of whitespace, and ``lr`` is parsed into a
        Python literal when it was read as text.
        True: the file is the time-series CSV with ``TS_KEYS`` columns. Rows
        with missing values are dropped, the index is converted to a
        ``DatetimeIndex``, and the cleaned frame is written back to
        ``output/time_<RUN_ID>.csv`` (overwriting it).

    Returns
    -------
    pandas.DataFrame
    """
    if time:
        ts_df = pd.read_csv(fn, header=None, names=TS_KEYS).dropna()
        # The write-back target is always the current run's time CSV,
        # regardless of the path that was read.
        fn = CSV_FP_FMT.format('time', RUN_ID)
        ts_df.index = pd.DatetimeIndex(ts_df.index)
        ts_df.to_csv(fn, header=False)
        return ts_df

    stat_df = pd.read_csv(fn, header=None, names=KEYS).dropna()
    stat_df.index = range(0, len(stat_df))
    for text_col in ('stat', 'timescale', 'var'):
        stat_df[text_col] = stat_df[text_col].str.strip()
    try:
        stat_df['lr'] = stat_df['lr'].str.strip().apply(ast.literal_eval)
    except AttributeError:
        # 'lr' was not read as text (no .str accessor): leave it as is.
        pass
    return stat_df


def filter_and_pivot(df, month, filter_only=False):
    """Keep the rows of the current configuration and pivot to lead x year.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame with at least the columns ``lr``, ``timescale``, ``var``,
        ``jk_int``, ``time_behind``, ``month``, ``lead`` and ``year``
        (plus ``values`` unless ``filter_only`` is true).
    month : int
        ``99`` keeps only the rows with ``month == 99``; any other value
        averages all rows with ``month != 99`` per ``(lead, year)``.
    filter_only : bool, default False
        True: return the filtered (and possibly averaged) frame without
        de-duplicating or pivoting.

    Returns
    -------
    pandas.DataFrame
        When pivoted: index ``lead``, columns ``year``, cells ``values``.
    """
    matches_run = (
        (df['lr'] == LR) &
        (df['timescale'] == TIMESCALE) &
        (df['var'] == '_'.join(VAR_LIST)) &
        (df['jk_int'] == JK_INT) &
        (df['time_behind'] == TIME_BEHIND)
    )
    df = df.loc[matches_run]

    if month != 99:
        # numeric_only=True: the frame still carries text columns ('stat',
        # 'timescale', 'var'); recent pandas raises when asked to average
        # them, older versions dropped them silently.
        df = (df.loc[df['month'] != 99]
                .groupby(['lead', 'year'])
                .mean(numeric_only=True)
                .reset_index())
    else:
        df = df.loc[df['month'] == 99]

    if not filter_only:
        # De-duplicate on everything except the measured value. Selecting the
        # key columns by name (instead of "all but the first column") keeps
        # 'lead' in the key even after the groupby branch has moved it to
        # the front of the frame.
        key_cols = [col for col in df.columns if col != 'values']
        df = df.drop_duplicates(key_cols)
        df = df.pivot(index='lead', columns='year', values='values')  # pivoted
    return df


def get_pivoted_csvs(corr_df, pval_df, rmse_df, month):
    """Apply ``filter_and_pivot`` to the correlation, p-value and RMSE frames.

    Returns
    -------
    tuple
        ``(corr_df, pval_df, rmse_df)``, each filtered and pivoted.
    """
    pivoted_corr = filter_and_pivot(corr_df, month)
    pivoted_pval = filter_and_pivot(pval_df, month)
    pivoted_rmse = filter_and_pivot(rmse_df, month)
    return pivoted_corr, pivoted_pval, pivoted_rmse


def neaten(input_str):
    """Turn an underscore-separated label into a title-cased phrase."""
    spaced = ' '.join(input_str.split('_'))
    return spaced.title()


def get_avg_corr_pval(corr_df, pval_df, axis=None):
    """Average correlations in Fisher z-space and combine their p-values.

    Correlations are transformed with ``arctanh``, averaged, and transformed
    back with ``tanh``. P-values are merged with ``combine_pvalues``.

    Parameters
    ----------
    corr_df : array-like, Series or DataFrame
        Correlation coefficients.
    pval_df : array-like, Series or DataFrame
        Matching p-values.
    axis : None or any other value
        ``None``: reduce everything to a single correlation and p-value.
        Otherwise: reduce each row (``axis=1``) of the two DataFrames.

    Returns
    -------
    tuple
        ``(avg_corr, avg_pval)``.
    """
    if axis is None:
        # The mean must be taken before applying tanh; tanh(arctanh(x)).mean()
        # would cancel the transform and yield a plain arithmetic mean.
        avg_corr = np.tanh(np.arctanh(corr_df).mean())
        avg_pval = combine_pvalues(pval_df)[1]
    else:
        avg_corr = np.tanh(np.arctanh(corr_df).mean(axis=1))
        avg_pval = list(zip(*pval_df.apply(combine_pvalues, axis=1)))[1]
    return avg_corr, avg_pval


def plot(corr_df, pval_df, rmse_df, label):
    """Draw correlation and RMSE against lead and save the figure.

    One line per column of ``corr_df`` / ``rmse_df`` is plotted (correlation
    in the first panel, RMSE at ``pos=2``), followed by the row-wise combined
    p-value, the row-wise mean correlation and the row-wise mean RMSE. The
    figure is saved to ``output/<label>.png``.

    Parameters
    ----------
    corr_df, pval_df, rmse_df : pandas.DataFrame
        Pivoted frames indexed by lead, as returned by ``get_pivoted_csvs``.
    label : str
        Used for the title (after ``neaten``) and for the output file name.

    All drawing goes through the ``ahh.vis`` helpers; their exact behavior is
    not visible in this file.
    """
    vis_dict = dict(figsize='na', xlabel='Lead', rows=2, linewidth=0.75, length_scale=False, xlim=(1, len(corr_df)))
    _ = vis.set_figsize(12, 8)
    clist = vis.get_color_list(vis.get_cmap('GMT_haxby', n=int(len(corr_df.columns)) + 1))
    for i, col in enumerate(corr_df):
        vis.plot_line(corr_df.index, corr_df[col],
                      label='{0}'.format(col),
                      color=clist[i],
                      ylabel='Correlation',
                      **vis_dict)
        vis.plot_line(rmse_df.index, rmse_df[col],
                      label='{0}'.format(col),
                      color=clist[i],
                      ylabel='RMSE',
                      pos=2,
                      **vis_dict)

    avg_corr_line, avg_pval_line = get_avg_corr_pval(corr_df, pval_df, axis=1)
    avg_rmse_line = rmse_df.mean(axis=1)

    ax = vis.plot_line(pval_df.index, avg_pval_line, label='PVal',
                       color='gray',
                       linestyle='--',
                       marker='x',
                       **vis_dict)

    ax = vis.plot_line(corr_df.index, avg_corr_line, label='Mean',
                       color='k',
                       marker='.',
                       ylim=(-0, 1),
                       ylabel='Correlation',
                       title=neaten(label),
                       **vis_dict)

    # The label is a constant; it previously went through a no-op .format()
    # call that required a module-level ``year`` to exist.
    ax2 = vis.plot_line(rmse_df.index, avg_rmse_line, label='Mean',
                        color='k',
                        marker='.',
                        ylim=(0, 2),
                        ylabel='RMSE',
                        pos=2,
                        **vis_dict)

    # Only the overall correlation is annotated; the combined p-value is unused.
    avg_corr, _ = get_avg_corr_pval(avg_corr_line, avg_pval_line)
    vis.set_axtext(ax, 'Avg Corr All Leads: {0:.1%}'.format(avg_corr), loc='bottom right')
    vis.set_axtext(ax2, 'Avg RMSE All Leads: {0:.3}'.format(avg_rmse_line.mean()), loc='bottom right')

    ax = vis.set_legend(ax, ncol=3, size=5, loc='top right')
    ax2 = vis.set_legend(ax2, ncol=3, size=5, loc='top right')

    vis.savefig('output/{0}.png'.format(label))


def get_row_label(df, i):
    """Build a configuration label for the row labelled ``i`` of ``df``.

    The label combines lead, month mode (``'linked'`` when month is 99,
    otherwise ``'segmented'``), timescale, time_behind, model type
    (``'lr'`` or ``'lstm'``), jack-knife interval and variable name. Spaces
    are removed from the result.
    """
    def cell(column):
        return df[column][i]

    fields = dict(
        lr='lr' if cell('lr') else 'lstm',
        month='linked' if cell('month') == 99 else 'segmented',
        timescale=cell('timescale'),
        var=cell('var'),
        jk_int=cell('jk_int'),
        lead=cell('lead'),
        time_behind=cell('time_behind'),
    )
    template = 'lead{lead}_{month}_{timescale}_tb{time_behind}_{lr}_jk{jk_int}_with_{var}'
    return template.format(**fields).replace(' ', '')


def get_summary(df):
    """Return the trailing slice of the sorted unique row labels.

    Every row of ``df`` is turned into a label with ``get_row_label``; the
    unique labels are sorted (string order, via ``np.unique``) and the last
    ``len(unique) / max_lead`` of them are returned, where ``max_lead`` is
    the largest lead number found in the labels.

    NOTE(review): the intent appears to be "one lead's worth of
    configurations"; because the labels sort as strings, the trailing slice
    is not necessarily the numerically largest lead. Confirm with the author.

    Returns
    -------
    numpy.ndarray
        Array of label strings (empty when ``df`` has no rows).
    """
    summary_unique = np.unique([get_row_label(df, i) for i in range(len(df))])
    summary_len = len(summary_unique)
    if summary_len == 0:
        return summary_unique

    # Parse the full lead number from each 'lead<N>_...' label. Reading only
    # the single character after 'lead' would turn lead 10 or 11 into 1.
    lead_numbers = []
    for summary_label in summary_unique:
        lead_match = re.match(r'lead(\d+)', summary_label)
        if lead_match:
            lead_numbers.append(int(lead_match.group(1)))
    max_leads = max(lead_numbers) if lead_numbers else 0
    if max_leads <= 0:
        # Nothing sensible to divide by: return every label.
        return summary_unique

    return summary_unique[int(-summary_len / max_leads):]

In [None]:
# Start time of the run; the main loop prints the elapsed time relative to it.
s = datetime.datetime.utcnow()

if CONCEPTUAL:
    # Synthetic daily frames for checking the data flow: the value is the
    # calendar month (sst) or ten times the calendar month (wwv, wnd).
    dt_range = pd.date_range('1982-01-01', '2017-12-31')
    sst_df = pd.DataFrame({'sst': dt_range.month}, index=dt_range)
    sst_df.index.name = 'time'

    wwv_df = pd.DataFrame({'wwv': dt_range.month * 10}, index=dt_range)
    wwv_df.index.name = 'time'

    wnd_df = pd.DataFrame({'wnd': dt_range.month * 10}, index=dt_range)
    wnd_df.index.name = 'time'
else:
    # NOTE(review): pickle files execute arbitrary code on load; only read
    # files from a trusted source.
    sst_df = pd.read_pickle(SST_FP)
    wwv_df = pd.read_pickle(WWV_FP)
    wnd_df = pd.read_pickle(WND_FP)

# Years to iterate over in the main loop.
if YEARS is None and JK_INT != 0:
    # Every JK_INT-th unique year, starting at position JK_INT.
    years = np.unique(sst_df.index.year)[JK_INT::JK_INT]
elif JK_INT == 0:
    # Fixed split: the main loop breaks after the first year in this mode.
    years = np.unique(sst_df.index.year)
else:
    years = YEARS

# Target series: monthly mean of the 'sst' column.
obs_col = sst_df['sst'].resample('1MS').mean()

# (name, frame) pairs for the requested predictors, in VAR_LIST order.
# Names other than 'sst', 'wwv' and 'wnd' are silently ignored.
# reshape_x reads this list from module scope.
var_df_list = []
for var in VAR_LIST:
    if var == 'sst':
        var_df_list.append(('sst', sst_df))
    elif var == 'wwv':
        var_df_list.append(('wwv', wwv_df))
    elif var == 'wnd':
        var_df_list.append(('wnd', wnd_df))
# Main experiment loop: for every (year, lead, month) combination build the
# predictor matrices, fit a model, evaluate it on the validation rows and
# append the statistics / time series to the output CSVs.
#
# The loop variables `year` and `lead` and the local `time_multiplier` are read
# as globals by divide_train_val, sort_df_cols and reshape_x.
# NOTE(review): df_list is not used anywhere in the visible code.
df_list = []
for year in years:
    # NOTE(review): this `i` is unused and is overwritten by the inner
    # `for i, (var_name, var_df)` loop below.
    for i, lead in enumerate(LEADS):
        print('\n{0} | {1} Lead: {2}'.format(datetime.datetime.utcnow() - s, year, lead), end=" - ")
        for month in MONTHS:
            print(month, end= " ")
            daily, time_multiplier = get_timescale_settings(TIMESCALE)

            fcst_pos_neg_dict = {}  # used if POS_NEG == True
            pos_neg_loc_dict = {}
            # With POS_NEG False this loop runs exactly once (see the break at
            # its end); with POS_NEG True it runs once per subset.
            for j, pos_neg in enumerate(POS_NEG_LIST):  # if POS_NEG
                fcst_dict = {'train_x': [], 'valid_x': []}
                for i, (var_name, var_df) in enumerate(var_df_list):
                    # Predictors for this variable: monthly rows, shifted by
                    # lead (+ lags), columns ordered, joined with the target
                    # in column 0; rows with any NaN are dropped.
                    shifted_df = shift_by_lead_lag(convert_resolution(var_df, daily), lead, TIME_BEHIND)
                    sorted_df = sort_df_cols(shifted_df, var_name, daily=daily)
                    entire_fcst_df = pd.concat([obs_col, sorted_df], axis=1).dropna()

                    if POS_NEG:
                        # Variable whose sign defines the two subsets:
                        # first of sst, wwv, wnd that is present in VAR_LIST.
                        if 'sst' in VAR_LIST:
                            pos_neg_var = 'sst'
                        elif 'wwv' in VAR_LIST:
                            pos_neg_var = 'wwv'
                        else:
                            pos_neg_var = 'wnd'
                    
                        # The row masks are computed only for the first
                        # variable in var_df_list and reused for the others.
                        # NOTE(review): this looks up a pos_neg_var column in
                        # the first variable's frame, so it presumably
                        # requires pos_neg_var to be first in VAR_LIST - confirm.
                        if i == 0:
                            if TIMESCALE == 'daily':
                                col_fmt = 'lag00_lead{0:02d}_{1}_day01'
                            elif TIMESCALE == 'monthly':
                                col_fmt = 'lag00_lead{0:02d}_{1}'

                            pos_neg_loc_dict['positive'] = entire_fcst_df[col_fmt.format(lead, pos_neg_var)] >= 0
                            pos_neg_loc_dict['negative'] = entire_fcst_df[col_fmt.format(lead, pos_neg_var)] < 0

                        fcst_df = entire_fcst_df.loc[pos_neg_loc_dict[pos_neg]]
                    else:
                        fcst_df = entire_fcst_df
                        
                    # 99 means "all months"; otherwise keep one calendar month.
                    if month != 99:
                        fcst_df = fcst_df.loc[fcst_df.index.month == month]

                    # NOTE(review): offset/scale are overwritten for every
                    # variable, so the values used later for unscaling are
                    # those of the last variable in var_df_list.
                    scaled_df, offset, scale = scale_df(fcst_df)

                    if SAVE:
                        save_dict = dict(lead=lead, month=month,
                             year=year, timescale=TIMESCALE,
                             var=var_name, lr=LR, jk_int=JK_INT,
                             time_behind=TIME_BEHIND)

                        offset_scale = np.array([offset, scale])
                        offset_scale_save = OFFSET_SCALE_SAVE_FMT.format(**save_dict)
                        # pkl_save is not used afterwards.
                        pkl_save = np.save(offset_scale_save, offset_scale)

                    # The conceptual run keeps the unscaled values.
                    if CONCEPTUAL:
                        train_x, train_y, valid_x, valid_y, valid_idx = divide_train_val(fcst_df)
                    else:
                        train_x, train_y, valid_x, valid_y, valid_idx = divide_train_val(scaled_df)

                    fcst_dict['train_x'].append(train_x.reshape(len(train_x), time_multiplier * (TIME_BEHIND + 1)))
                    fcst_dict['valid_x'].append(valid_x.reshape(len(valid_x), time_multiplier * (TIME_BEHIND + 1)))
                    # Targets and validation index are taken from the first variable.
                    if i == 0:
                        fcst_dict['train_y'] = train_y
                        fcst_dict['valid_y'] = valid_y
                        fcst_dict['valid_idx'] = valid_idx

                if POS_NEG:
                    fcst_pos_neg_dict[pos_neg] = fcst_dict

                elif not POS_NEG:
                    break

            # Fit and predict: once in total, or once per subset when POS_NEG.
            pos_neg_df_list = []
            for j, pos_neg in enumerate(POS_NEG_LIST):
                if POS_NEG:
                    fcst_dict = fcst_pos_neg_dict[pos_neg]

                fcst_dict = reshape_x(fcst_dict)
                valid_y, valid_z = get_prediction(fcst_dict, time_multiplier, offset, scale, lead, month, year)
                if POS_NEG:
                    pos_neg_df_list.append(pd.DataFrame(data={'valid_y': valid_y, 'valid_z': valid_z},
                                                        index=fcst_dict['valid_idx']))
                    # After the second subset, merge both sets of predictions.
                    if j == 1:
                        pos_neg_df = pd.concat(pos_neg_df_list)
                        valid_y = pos_neg_df['valid_y']
                        valid_z = pos_neg_df['valid_z']
                elif not POS_NEG:
                    break

            corr, pval, rmse = get_corr_pval_rmse(valid_y, valid_z)

            if not CONCEPTUAL:
                corr_fn = write_stat(corr, 'corr', lead, month, year)
                pval_fn = write_stat(pval, 'pval', lead, month, year)
                rmse_fn = write_stat(rmse, 'rmse', lead, month, year)

                df_ts_dict = dict(valid_y=valid_y, valid_z=valid_z,
                                  corr=corr, pval=pval, rmse=rmse,
                                  lead=lead, month=month,
                                  year=year, timescale=TIMESCALE,
                                  var='_'.join(VAR_LIST),
                                  lr=LR, jk_int=JK_INT,
                                  time_behind=TIME_BEHIND)
                # NOTE(review): valid_idx here is the one left over from the
                # last variable / subset processed above - confirm this is
                # the intended index when POS_NEG is True.
                df_ts_out = pd.DataFrame(df_ts_dict, index=valid_idx)

                time_fn = write_timeseries(df_ts_out)
            
            # The "all months" sentinel ends the month loop immediately.
            if month == 99:
                break
        # The conceptual run only processes the first lead ...
        if CONCEPTUAL:
            break

    # ... and the first year; the fixed split (JK_INT == 0) also needs one pass.
    if CONCEPTUAL or JK_INT == 0:
        break

# NOTE(review): reshape_x reads var_df_list from module scope, so after this
# `del` the main loop cannot be re-run without re-running the setup above.
del var_df_list
print('\nDONE!')

# Read back the CSVs written above, reduce them to the current configuration
# and plot correlation / RMSE against lead.
if not CONCEPTUAL:
    corr_fn = CSV_FP_FMT.format('corr', RUN_ID)
    pval_fn = CSV_FP_FMT.format('pval', RUN_ID)
    rmse_fn = CSV_FP_FMT.format('rmse', RUN_ID)
    time_fn = CSV_FP_FMT.format('time', RUN_ID)

    # `month` is the value left over from the last iteration of the main loop.
    label = get_label(month)
    corr_df = get_cleansed_df(corr_fn)
    pval_df = get_cleansed_df(pval_fn)
    rmse_df = get_cleansed_df(rmse_fn)
    # With time=True the cleaned frame is also written back to the time CSV.
    time_df = get_cleansed_df(time_fn, time=True)
    time_df = filter_and_pivot(time_df, month, filter_only=True)
    # The return value is discarded here (it is only displayed when this is
    # the last expression of a cell, which it is not).
    get_summary(corr_df)

    corr_df, pval_df, rmse_df = get_pivoted_csvs(corr_df, pval_df, rmse_df, month)
    # TODO (from the author): add print correlation
    plot(corr_df, pval_df, rmse_df, label)  # add print correlation