### Import libraries

In [None]:
import sys
sys.path.append("..")
import os
from model import experimental2d_model, grapher
from data import loader
from helpers import helpers, metrics
import tensorflow as tf
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
def embed_position(t, d, TΔmin, Tmax):
    """Sinusoidal embedding of real-valued (floating point) positions.

    A variant of the positional encoding from "Attention Is All You Need",
    adapted to floating point positions rather than integer ones.

    Parameters
    ----------
    t : ndarray of shape (T, B)
        Positions; T = sequence length, B = batch size.
    d : int
        Embedding width of the output.
    TΔmin, Tmax : float
        Scale parameters. The divisors form a geometric progression that
        starts at TΔmin and grows toward (but never reaches) 100 * Tmax.

    Returns
    -------
    ndarray of shape (T, B, d)
        ceil(d/2) sine features followed by floor(d/2) cosine features.
    """
    ratio = Tmax / TΔmin * 100
    sin_scales = TΔmin * ratio ** (np.arange(0, d, 2) / d)
    cos_scales = TΔmin * ratio ** ((np.arange(1, d, 2) - 1) / d)
    # Add a trailing axis so every position is divided by every scale.
    phase = t[:, :, None]
    return np.concatenate((np.sin(phase / sin_scales), np.cos(phase / cos_scales)), axis=2)

In [None]:
# Raw inputs.
# NOTE(review): both np.load calls unpickle their files (allow_pickle=True);
# pickle can execute arbitrary code, so only load these from a trusted source.
df = np.load(
    '/home/omernivr/Downloads/riverflow/maurer.pickle.npy',
    allow_pickle=True,
)
df_attributes = np.load(
    '/home/omernivr/Downloads/riverflow/attributes.pickle',
    allow_pickle=True,
)
# NOTE(review): the file name is spelled 'attibutes.csv' here - confirm it matches the file on disk.
att_csv = pd.read_csv(
    '/home/omernivr/Downloads/riverflow/data/attibutes.csv',
)
# One basin id per line, no header row.
basin_list = pd.read_csv(
    '/home/omernivr/Downloads/riverflow/data/basin_list.txt',
    header=None,
)

In [None]:
# Static catchment attribute columns selected from att_csv (standardized in
# a later cell). 'area_geospa_fabric' is also the key used to scale Q by area.
attributes = [
    'p_mean', 'pet_mean', 'p_seasonality', 'frac_snow', 'aridity',
    'high_prec_freq', 'high_prec_dur',
    'low_prec_freq', 'low_prec_dur',
    'carbonate_rocks_frac', 'geol_permeability',
    'soil_depth_pelletier', 'soil_depth_statsgo',
    'soil_porosity', 'soil_conductivity', 'max_water_content',
    'sand_frac', 'silt_frac', 'clay_frac',
    'elev_mean', 'slope_mean', 'area_geospa_fabric',
    'frac_forest', 'lai_max', 'lai_diff', 'gvf_max', 'gvf_diff',
]

In [None]:
# Basin ids are 8-character strings with leading zeros (e.g. '01013500').
# Reading the text file as numbers drops those zeros. Left-pad back to 8
# characters; zfill restores any number of lost zeros, not just one.
basin_list = basin_list[0].apply(lambda x: str(x).zfill(8))

In [None]:
# Z-score every selected attribute column over all rows of att_csv, using the
# population standard deviation (ddof=0, which is what np.std applies).
att_csv1 = att_csv[attributes]
att_csv1 = (att_csv1 - att_csv1.mean(axis=0)) / att_csv1.std(axis=0, ddof=0)

In [None]:
# np.load returns the pickled object wrapped in a 0-d object array; unwrap it
# (same as indexing with the empty tuple). The result is used as a mapping
# keyed by basin id below.
df = df.item()

In [None]:
# Month table taken from basin '01013500': one row per (Year, Mnth) with the
# number of daily rows in that month. reset_index() on the unnamed size()
# Series names the count column 0.
ymd = (
    df['01013500']
    .groupby(by=['Year', 'Mnth'])
    .size()
    .reset_index()
)

In [None]:
# d_cumsum = number of days in the earlier months of the same year: an
# exclusive running sum of the per-month counts (column 0). d_cumsum + Day
# is then the day index within the year. For a year only partly covered in
# the reference basin, counting starts at its first recorded month.
# Relies on ymd being ordered by (Year, Mnth), which groupby().size() ensures.
# One grouped cumsum replaces the old per-row loop, which recomputed each
# year once per month and sliced a label-indexed Series positionally.
ymd['d_cumsum'] = ymd.groupby('Year')[0].cumsum() - ymd[0]

In [None]:
# Join each basin's daily rows with the month table `ymd` on (Year, Mnth).
# Then derive n_day (day index within the year) and t (fractional year).
for k, v in df.items():
    # Move the current index into a column. This is in place and v is df[k],
    # so the stored frame changes too.
    v.reset_index(inplace=True)
    df[k] = v.merge(ymd, how='inner', left_on=['Year', 'Mnth'], right_on=['Year', 'Mnth'])
    # NOTE(review): 'd_cumsum_x' only exists if the basin frame already had its
    # own 'd_cumsum' column. merge suffixes the left copy with _x and ymd's with _y,
    # so this reads the basin's column, not ymd's. Confirm that is intended.
    df[k]['n_day'] = df[k]['d_cumsum_x'] + df[k]['Day']
    # Dividing by 366 keeps t below Year + 1 as long as n_day <= 366.
    df[k]['t'] = df[k]['Year'] + (df[k]['n_day'] - 1) / 366

In [None]:
# Keep basins with at least 10593 daily rows (the slicing cell uses the first
# 10500) and restrict each one to the columns used downstream.
# .copy() detaches each frame from df[b]. The next cell modifies these frames
# in place, and without it pandas raises chained-assignment warnings
# (SettingWithCopyWarning).
df_filtered = {
    b: df[b][['Date', 't', 'Dayl(s)', 'PRCP(mm/day)', 'SRAD(W/m2)',
              'Tmax(C)', 'Tmin(C)', 'Vp(Pa)', 'Q']].copy()
    for b in basin_list
    if df[b].shape[0] >= 10593
}
# Drop the short basins from basin_list in one pass, instead of
# re-filtering the whole Series once per dropped basin.
basin_list = basin_list[basin_list.isin(list(df_filtered))]

In [None]:
%%time
# Choose type of transform, i.e., 'standardize' or 'normalize' 
dist = 'gaussian'

if dist == 'gaussian':
    transform = 'standardize'
    log_P = True
    log_Q = True
    
if dist == 'gamma':
    transform = 'normalize'
    log_P = True
    log_Q = True
    gamma_shift = 1e-3

divide_by_area = True
cols = ['t','Dayl(s)', 'PRCP(mm/day)', 'SRAD(W/m2)', 'Tmax(C)', 'Tmin(C)', 'Vp(Pa)', 'Q']
epsilon = 1e-3

x_maxs, x_mins, x_means, x_stds = [], [], [], []
for k,v in df_filtered.items():
#     # Scale streamflow values by catchment area
    if divide_by_area:
        v['Q'] = v['Q']/df_attributes[k]['area_geospa_fabric'].values
    
    # Calculate mean (after scaling by area)
#     v['Q_mu'] = v['Q'].mean()
    
#     Log-transform precipitation
    if log_P: 
        v['PRCP(mm/day)'] = np.log(v['PRCP(mm/day)'] + epsilon)
    
    # Log-transform streamflow
    if log_Q: 
        v['Q'] = np.log(v['Q'] + epsilon)
    
#     x_maxs.append(v[cols].max().values)
#     x_mins.append(v[cols].min().values)
    x_means.append(v[cols].values)
    x_stds.append(v[cols].values)

# x_max = np.concatenate(x_maxs).reshape(-1,len(cols)).max(axis=0)
# x_min = np.concatenate(x_mins).reshape(-1,len(cols)).min(axis=0)
x_mean = np.concatenate(x_means, axis=0).mean(axis=0)
x_std = np.concatenate(x_stds, axis=0).std(axis = 0)

for k,v in df_filtered.items():
    for i, col in enumerate(cols):
        if transform == 'normalize':
            
            v[col] = (v[col] - x_min[i]) / (x_max[i] - x_min[i])
            
            if dist=="gamma":
                v['Q'] = v['Q'] + gamma_shift
            
            def rev_transform(x):
                x = x * (x_max[0] - x_min[0]) + x_min[0]
                if log_Q:
                    x = np.exp(x) - epsilon
                if dist == "gamma":
                    x = x - gamma_shift
                return x
            
            def rev_transform_tensor(x):
                x = x * (x_max[0] - x_min[0]) + x_min[0]
                if log_Q:
                    x = torch.exp(x) - epsilon 
                if dist == "gamma":
                    x = x - gamma_shift
                return x
        
        elif transform == 'standardize':
            
            v[col] = (v[col] - x_mean[i]) / x_std[i]
            
            #WARNING -- NO GAMMA SHIFT
            
            def rev_transform(x):
                x = x * x_std[0] + x_mean[0]
                if log_Q:
                    x = np.exp(x) - epsilon
                if dist == "gamma":
                    x = x - gamma_shift
                return x

            def rev_transform_tensor(x):
                x = x * x_std[0] + x_mean[0]
                if log_Q:
                    x = torch.exp(x) - epsilon 
                if dist == "gamma":
                    x = x - gamma_shift
                return x
        
        else:
            print("No transform has been applied.")

In [None]:
# Cut the first 10500 days of every basin into 21 consecutive 500-day blocks.
# For each basin store:
#   'x'    - forcing inputs,
#   't'    - per-block position embeddings of t, shape (21, 500, 1, 122),
#   'y'    - target Q,
#   'date' - the dates.
# Note t was scaled in the preprocessing cell, so TΔmin / Tmax are in those
# scaled units.
df_filt_sliced = {}
for b in basin_list:
    temp = df_filtered[b].reset_index().iloc[:10500]
    # .to_numpy() first: Series[:, None] (multi-dim indexing) was removed in pandas 2.0.
    t_vals = temp['t'].to_numpy()
    # embed_position expects (T, B) input; each block is passed as (500, 1).
    t_y = np.array([
        embed_position(t_vals[j:j + 500, None], d=122, TΔmin=0.2, Tmax=180)
        for j in range(0, 10500, 500)
    ])
    df_filt_sliced[b] = {
        'x': temp[['PRCP(mm/day)', 'Dayl(s)', 'SRAD(W/m2)', 'Tmax(C)', 'Tmin(C)', 'Vp(Pa)']],
        't': t_y,
        'y': temp['Q'],
        'date': temp['Date'],
    }

In [None]:
# Obtain the train/test step callables plus the stateful loss and metric
# trackers the training loop uses. The loop calls reset_states()/result() on
# train_loss, test_loss, m_tr and m_te (presumably Keras metrics - confirm in grapher).
train_step, test_step, train_loss, test_loss, m_tr, m_te = grapher.build_graph()

### Training

In [None]:
# Split basins by row position in basin_list:
#   - 400 random training basins,
#   - all remaining basins except the first 50 (in index order) for validation.
# NOTE(review): test_b_idx[:50] is presumably the held-out test set - confirm.
# The split uses its own seeded generator. The training loop restores from the
# latest checkpoint, and an unseeded split redrawn on every run would put
# previously trained basins into validation.
split_seed = 0
basin_list = basin_list.reset_index()
all_idx = np.arange(len(basin_list))
train_b_idx = np.random.default_rng(split_seed).choice(all_idx, 400, replace=False)
test_b_idx = np.setdiff1d(all_idx, train_b_idx)  # sorted, like the old boolean mask
valid_b_idx = test_b_idx[50:]

In [None]:
# Root directory for TensorBoard logs and checkpoints.
save_dir = os.path.expanduser('~/Downloads/riverflow/river_flow')
# Not read in this section; train_step/test_step below receive d=True explicitly.
d = False

In [None]:
def _basin_id(basin_table, row):
    """Return the basin id stored in column 0 of ``basin_table`` at position ``row``."""
    return basin_table[0].iloc[row]


def _block_inputs(sliced, blk):
    """Return ``(x, y)`` for 500-day block ``blk`` of one ``df_filt_sliced`` entry.

    ``x`` is (500, 128): the 6 forcing columns concatenated with the block's
    122-dim position embedding. ``y`` is (500,): that block's target Q.
    """
    rows = slice(500 * blk, 500 * (blk + 1))
    forcing = sliced['x'].iloc[rows].to_numpy()
    position = sliced['t'][blk].squeeze()  # (500, 1, 122) -> (500, 122)
    x = np.concatenate((forcing, position), axis=1)
    y = sliced['y'].iloc[rows].to_numpy()
    return x, y


def _stack_blocks(sliced_by_basin, basin_table, rows, blk):
    """Stack block ``blk`` of every basin position in ``rows`` into (n, 500, 128) and (n, 500) arrays.

    Only the requested block is built, not all 21 blocks of every basin.
    """
    pairs = [_block_inputs(sliced_by_basin[_basin_id(basin_table, r)], blk) for r in rows]
    return np.stack([x for x, _ in pairs]), np.stack([y for _, y in pairs])


def _context_mask(n_rows, n_context, length=500):
    """Return an (n_rows, length) 0/1 array with ones at the (row, col) pairs from helpers.gather_idx."""
    idx = helpers.gather_idx(n_context, l=length, b=n_rows)
    mask = np.zeros((n_rows, length))
    mask[idx[:, 0], idx[:, 1]] = 1
    return mask


if __name__ == '__main__':
    step = 0
    # change to run 9 if you want to overfit
    EPOCHS = 75
    batch_s = 16
    run = 52
    tr_regime = 'shuffle'
    # e = 128 matches the input width: 6 forcing columns + 122 embedding dims.
    l = [128, 64, 32]
    heads = 16
    e = 128
    context = 400
    c = 400
    name_comp = 'run_' + str(run)
    logdir = save_dir + '/logs/' + name_comp
    writer = tf.summary.create_file_writer(logdir)
    folder = save_dir + '/ckpt/check_' + name_comp
    optimizer_c = tf.keras.optimizers.Adam(3e-4)
    helpers.mkdir(folder)
    decoder = experimental2d_model.Decoder(e, l[0], l[1], l[2], num_heads=heads)
    num_batches = 500
    ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=optimizer_c, net=decoder)
    manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=3)
    ckpt.restore(manager.latest_checkpoint)
    if manager.latest_checkpoint:
        print("Restored from {}".format(manager.latest_checkpoint))
    else:
        print("Initializing from scratch.")
    with writer.as_default():
        for epoch in range(EPOCHS):
            start = time.time()
            for batch_n in range(num_batches):
                m_tr.reset_states()
                train_loss.reset_states()

                # Basins are sampled with replacement (a basin may appear twice in a
                # batch); one 500-day block is shared by the whole batch.
                batch_idx = np.random.choice(train_b_idx, batch_s)
                block_idx = np.random.randint(21)

                x_tr, y_tr = _stack_blocks(df_filt_sliced, basin_list, batch_idx, block_idx)
                mask_tr = _context_mask(batch_s, c)

                pred, pred_log, weights, names, shapes, y_real, g = train_step(
                    decoder, optimizer_c, train_loss, m_tr, x_tr, y_tr, d=True, to_gather=mask_tr)
                if epoch == 0 and batch_n == 0:
                    helpers.write_speci(folder, names, shapes, context, heads)

                if batch_n % 3 == 0:
                    m_te.reset_states()
                    test_loss.reset_states()

                    x_te, y_te = _stack_blocks(df_filt_sliced, basin_list, valid_b_idx, block_idx)
                    mask_te = _context_mask(len(valid_b_idx), c)
                    pred_te, pred_log_te = test_step(
                        decoder, test_loss, m_te, x_te=x_te, y_te=y_te, to_gather=mask_te, d=True)

                    # Plot one random validation basin. Dates, observations and
                    # predictions all come from the same basin p (the old code
                    # plotted basin 1's values against basin idx_p's dates).
                    # Colors: all observations (blue), first 400 days (red),
                    # predictions for the remaining days (goldenrod).
                    p = np.random.randint(len(valid_b_idx))
                    plot_basin = _basin_id(basin_list, valid_b_idx[p])
                    start_row = 500 * block_idx
                    dates = df_filt_sliced[plot_basin]['date'].iloc[start_row:start_row + 500]
                    fig, ax = plt.subplots(figsize=(12, 8))
                    ax.scatter(dates, y_te[p], c='blue')
                    ax.scatter(dates.iloc[:400], y_te[p, :400], c='red')
                    # NOTE(review): pred_te[p, 399:] has 100 entries only if
                    # pred_te has 499 steps per row - confirm against test_step.
                    ax.scatter(dates.iloc[400:], pred_te[p, 399:], c='goldenrod')
                    plt.show()
                    plt.close(fig)  # avoid accumulating open figures across batches

                    helpers.print_progress(epoch, batch_n, train_loss.result(), test_loss.result(), m_tr.result(), m_te.result())
                    manager.save()
                step += 1
                ckpt.step.assign_add(1)
            print ('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))