# Explore the LSTM/EALSTM

In [32]:
from pathlib import Path
import os
import warnings

%load_ext autoreload
%autoreload 2

# Suppress every warning for the rest of the session (temporary measure).
warnings.filterwarnings('ignore')

# If the directory two levels above the current one is called 'ml_drought',
# make it the working directory.
two_levels_up = Path('.').absolute().parents[1]
if two_levels_up.name == 'ml_drought':
    os.chdir(two_levels_up)

!pwd

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/tommy/ml_drought


In [70]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import matplotlib as mpl
from tqdm import tqdm
import torch

mpl.rcParams['figure.dpi'] = 150

In [34]:
# Root of the data store. The default is the machine-specific location used so
# far; set the ML_DROUGHT_DATA_DIR environment variable to point the notebook
# at another copy (e.g. a repo-local 'data/') without editing this cell.
data_dir = Path(os.environ.get('ML_DROUGHT_DATA_DIR', '/cats/datastore/data/'))

# Fail early, and report which path was tried, when the data store is missing.
assert data_dir.exists(), f"data_dir does not exist: {data_dir}"

In [35]:
from src.utils import drop_nans_and_flatten

from src.analysis import read_train_data, read_test_data, read_pred_data
from src.analysis.evaluation import join_true_pred_da
from src.models import load_model

In [79]:
from ruamel.yaml import YAML

# Read in the data

In [36]:
# Experiment names: EXPERIMENT is only referenced by the disabled models check
# below; TRUE_EXPERIMENT selects the features folder that is read.
EXPERIMENT = 'one_timestep_forecast'
TRUE_EXPERIMENT = 'one_timestep_forecast'
TARGET_VAR = 'discharge_spec'

# File names of the static and dynamic datasets.
# (alternative static file: '2020_04_23:112630_data.nc_')
STATIC_DATA_FILE = 'data.nc'
DYNAMIC_DATA_FILE = 'data.nc'

N_EPOCHS = 100

# Disabled check: assert (data_dir / f"models/{EXPERIMENT}").exists()
# The features folder for the chosen experiment must be present.
assert Path(data_dir, f"RUNOFF/features/{TRUE_EXPERIMENT}").exists()

In [37]:
# Training data for the selected experiment; station ids are cast to int.
dynamic_path = data_dir / f'RUNOFF/features/{TRUE_EXPERIMENT}/{DYNAMIC_DATA_FILE}'
ds = xr.open_dataset(dynamic_path)
ds['station_id'] = ds['station_id'].astype(int)

# Static data from the interim folder (previously read from
# 'data/features/static/data.nc'); `static` is a second name for the same object.
static_path = data_dir / 'RUNOFF/interim/static/data.nc'
all_static = xr.open_dataset(static_path)
all_static['station_id'] = all_static['station_id'].astype(int)
static = all_static

In [38]:
# (station id, gauge label) pairs for the catchments singled out in this notebook.
_CATCHMENTS = (
    (12002, "Dee@Park"),
    (15006, "Tay@Ballathie"),
    (27009, "Ouse@Skelton"),
    (27034, "Ure@Kilgram"),
    (27041, "Derwent@Buttercrambe"),
    (39001, "Thames@Kingston"),
    (39081, "Ock@Abingdon"),
    (43021, "Avon@Knapp"),
    (47001, "Tamar@Gunnislake"),
    (54001, "Severn@Bewdley"),
    (54057, "Severn@Haw"),
    (71001, "Ribble@Samlesbury"),
    (84013, "Clyde@Daldowie"),
)
catchment_ids = [station_id for station_id, _label in _CATCHMENTS]
catchment_names = [label for _station_id, label in _CATCHMENTS]
# Lookup from integer station id to its label, in the order listed above.
station_map = dict(_CATCHMENTS)

# Read AWS Trained Models

In [39]:
# Names of everything directly under the runs folder.
run_dir_names = [entry.name for entry in (data_dir/'runs/').iterdir()]
print(run_dir_names)

# Result files ending in 'E015.csv' inside the '*_less_vars*' run folders.
e015_result_names = [csv_path.name for csv_path in (data_dir/'runs').glob('*_less_vars*/*E015.csv')]
print(e015_result_names)

['ealstm_less_vars_2004_1607_1334', 'lstm_less_vars_1307_1717', 'ealstm_less_vars_2004_1707_1424', 'lstm_less_vars_2004_1507_1028']
['results_lstm_less_vars_1307_1717_E015.csv', 'results_ealstm_less_vars_2004_1707_1424_E015.csv', 'results_lstm_less_vars_2004_1507_1028_E015.csv']


In [40]:
# Results CSV (E015) of the run 'lstm_less_vars_2004_1507_1028', with the
# 'time' column parsed to datetimes.
# Earlier, now-disabled variant: read
#   data_dir / "RUNOFF/lstm_less_vars/results_lstm_less_vars_1307_1717_E015.csv"
# and keep ["station_id", "time", "obs", "sim_E015"] with sim_E015 renamed to "sim".
lstm_results_csv = data_dir / "runs/lstm_less_vars_2004_1507_1028/results_lstm_less_vars_2004_1507_1028_E015.csv"
lstm_less_vars = pd.read_csv(lstm_results_csv).assign(
    time=lambda frame: pd.to_datetime(frame["time"])
)

In [41]:
# Results CSV (E015) of the run 'ealstm_less_vars_2004_1707_1424', with the
# 'time' column parsed to datetimes.
ealstm_results_csv = data_dir / "runs/ealstm_less_vars_2004_1707_1424/results_ealstm_less_vars_2004_1707_1424_E015.csv"
ealstm_less_vars = pd.read_csv(ealstm_results_csv).assign(
    time=lambda frame: pd.to_datetime(frame["time"])
)

In [42]:
# Index both result tables by (station_id, time) and convert them to xarray.
_prediction_index = ["station_id", "time"]
lstm_preds = lstm_less_vars.set_index(_prediction_index).to_xarray()
ealstm_preds = ealstm_less_vars.set_index(_prediction_index).to_xarray()

In [43]:
lstm_preds

# FUSE Data

In [44]:
# FUSE simulation text files, plus the NetCDF file that caches their
# combined contents so the text files are parsed only once.
fuse_cache_path = data_dir / "RUNOFF/ALL_fuse_ds.nc"
all_paths = list((data_dir / "RUNOFF/FUSE/Timeseries_SimQ_Best/").glob("*_Best_Qsim.txt"))

if not fuse_cache_path.exists():
    if not all_paths:
        # Without this check pd.concat fails with an unhelpful
        # "No objects to concatenate".
        raise FileNotFoundError(
            f"No '*_Best_Qsim.txt' files under {data_dir / 'RUNOFF/FUSE/Timeseries_SimQ_Best/'} "
            f"and no cached dataset at {fuse_cache_path}"
        )

    all_dfs = []
    for txt in tqdm(all_paths):
        # The column header is the 4th line of each file (3 lines skipped).
        df = pd.read_csv(txt, skiprows=3, header=0)
        # Header names carry surrounding whitespace.
        df.columns = [c.strip() for c in df.columns]
        df = df.rename(columns={"YYYY": "year", "MM": "month", "DD": "day"})
        df["time"] = pd.to_datetime(df[["year", "month", "day"]])
        # The station id is the leading '_'-separated token of the file name;
        # Path.name avoids depending on the OS path separator.
        df["station_id"] = int(txt.name.split("_")[0])
        df = df.drop(["year", "month", "day", "HH"], axis=1).set_index(["station_id", "time"])
        all_dfs.append(df)

    fuse_ds = pd.concat(all_dfs).to_xarray()
    fuse_ds.to_netcdf(fuse_cache_path)

else:
    fuse_ds = xr.open_dataset(fuse_cache_path)

In [45]:
# Keep only the 2004-01-01 .. 2009-01-01 window of the FUSE data.
fuse_period = slice('2004-01-01', '2009-01-01')
fuse_ds = fuse_ds.sel(time=fuse_period)

In [46]:
# Take 'target_var_original' for the stations that FUSE also has, at the
# FUSE time steps, name it 'obs', and merge it into the FUSE dataset.
in_fuse = np.isin(ds["station_id"], fuse_ds["station_id"])
obs = ds.sel(station_id=in_fuse, time=fuse_ds["time"])["target_var_original"].rename("obs")
fuse_data = fuse_ds.sel(station_id=obs.station_id).merge(obs)

In [47]:
fuse_data

# Match Stations / Times

In [48]:
# Boolean mask over lstm_preds' stations: True where the station is also in fuse_data.
all_stations = np.isin(lstm_preds["station_id"], fuse_data["station_id"])

In [49]:
# Restrict both prediction sets to the shared stations and the FUSE time steps.
# NOTE(review): the mask was built from lstm_preds' station order and is reused
# for ealstm_preds — confirm both datasets list the same stations in the same order.
shared_selection = dict(station_id=all_stations, time=fuse_data.time)
lstm_preds = lstm_preds.sel(**shared_selection)
ealstm_preds = ealstm_preds.sel(**shared_selection)

# Errors

In [50]:
from src.analysis.evaluation import spatial_rmse, spatial_r2, spatial_nse, spatial_bias
from src.analysis.evaluation import temporal_rmse, temporal_r2, temporal_nse
from src.analysis.evaluation import _nse_func, _rmse_func, _r2_func, _bias_func

In [51]:
def error_func(preds_xr: xr.Dataset, error_str: str) -> pd.DataFrame:
    """Compute one error metric per station.

    Args:
        preds_xr: dataset whose ``to_dataframe()`` table has ``station_id`` and
            ``time`` in its index and ``obs`` / ``sim`` columns. Rows with a
            NaN in any column are discarded before the metric is computed.
        error_str: which metric to compute: ``"nse"``, ``"rmse"``, ``"r2"``
            or ``"bias"``.

    Returns:
        DataFrame with a ``station_id`` column and a column named after
        ``error_str``, one row per station in order of first appearance.

    Raises:
        KeyError: if ``error_str`` is not one of the supported metric names.
    """
    lookup = {
        "nse": _nse_func,
        "rmse": _rmse_func,
        "r2": _r2_func,
        "bias": _bias_func,
    }
    # Named differently from the enclosing function so the two are not confused.
    metric_func = lookup[error_str]

    df = (
        preds_xr.to_dataframe()
        .dropna(how='any')
        .reset_index()
        .set_index("time")
    )

    station_ids = df["station_id"].unique()
    # Group once instead of re-scanning the whole table for every station.
    by_station = df.groupby("station_id", sort=False)

    errors = []
    for station_id in station_ids:
        d = by_station.get_group(station_id)
        if error_str == "rmse":
            # n_instances is the number of obs/sim pairs, i.e. the row count.
            # DataFrame.size counts every cell (rows x columns) and would
            # inflate it by the number of columns.
            # NOTE(review): assumes _rmse_func normalises by n_instances — confirm.
            _error_calc = metric_func(d["obs"].values, d["sim"].values, n_instances=len(d))
        else:
            _error_calc = metric_func(d["obs"].values, d["sim"].values)
        errors.append(_error_calc)

    error = pd.DataFrame({"station_id": station_ids, error_str: errors})

    return error

In [52]:
def _metric_tables(preds):
    """Return the nse, r2, rmse and bias tables for `preds`, each indexed by station_id."""
    return [
        error_func(preds, metric).set_index('station_id')
        for metric in ("nse", "r2", "rmse", "bias")
    ]


def _join_metric_tables(tables):
    """Join the four metric tables on station_id and restore station_id as a column."""
    nse_tbl, r2_tbl, rmse_tbl, bias_tbl = tables
    return nse_tbl.join(r2_tbl.join(rmse_tbl).join(bias_tbl)).reset_index()


# `errors` is reassigned for each model, so it ends up holding the LSTM tables.
errors = _metric_tables(ealstm_preds)
ealstm_df = _join_metric_tables(errors)

errors = _metric_tables(lstm_preds)
lstm_df = _join_metric_tables(errors)

In [53]:
def _ranked_nse(preds):
    """NSE per station sorted ascending, renumbered 0..n-1, with a flag for NSE below zero."""
    ranked = error_func(preds, "nse").sort_values('nse').reset_index(drop=True)
    ranked['negative'] = ranked['nse'] < 0
    return ranked


ealstm_nse = _ranked_nse(ealstm_preds)
lstm_nse = _ranked_nse(lstm_preds)

### FUSE

In [54]:
# Observations with dimensions ordered (station_id, time).
obs = fuse_data["obs"].transpose("station_id", "time")

# One simulated series per FUSE variable.
# NOTE(review): only `obs` is transposed explicitly; the simulations keep
# whatever dimension order fuse_data gives them.
topmodel, arnovic, prms, sacramento = (
    fuse_data["SimQ_TOPMODEL"],
    fuse_data["SimQ_ARNOVIC"],
    fuse_data["SimQ_PRMS"],
    fuse_data["SimQ_SACRAMENTO"],
)

In [55]:
# spatial_nse of every FUSE simulation against the observations, each result
# renamed to the label used in the tables.
top_nse, vic_nse, prms_nse, sac_nse = (
    spatial_nse(obs, simulated).rename(label)
    for simulated, label in (
        (topmodel, "TOPMODEL"),
        (arnovic, "VIC"),
        (prms, "PRMS"),
        (sacramento, "Sacramento"),
    )
)

In [56]:
# Merge the four NSE results, attach the gauge name (as 'Name') from the
# static data, and save the flat table.
nse = xr.merge([top_nse, vic_nse, prms_nse, sac_nse])
gauge_names = static['gauge_name'].to_dataframe()
nse_df = gauge_names.join(nse.to_dataframe()).rename(columns={"gauge_name": "Name"})
nse_df.to_csv(data_dir / 'RUNOFF/FUSE_nse_table.csv')

# Add an outer column level holding the metric name, applied after the CSV is written.
nse_df.columns = [["nse"] * len(nse_df.columns), nse_df.columns]


In [57]:
# spatial_rmse of every FUSE simulation against the observations.
top_rmse, vic_rmse, prms_rmse, sac_rmse = (
    spatial_rmse(obs, simulated).rename(label)
    for simulated, label in (
        (topmodel, "TOPMODEL"),
        (arnovic, "VIC"),
        (prms, "PRMS"),
        (sacramento, "Sacramento"),
    )
)

# Merge, drop the 'time' column, and save the flat table.
rmse = xr.merge([top_rmse, vic_rmse, prms_rmse, sac_rmse])
rmse_df = rmse.to_dataframe().drop(columns='time')
rmse_df.to_csv(data_dir / 'RUNOFF/FUSE_rmse_table.csv')

# Add an outer column level holding the metric name, applied after the CSV is written.
rmse_df.columns = [["rmse"] * len(rmse_df.columns), rmse_df.columns]


In [58]:
# spatial_r2 of every FUSE simulation against the observations.
top_r2, vic_r2, prms_r2, sac_r2 = (
    spatial_r2(obs, simulated).rename(label)
    for simulated, label in (
        (topmodel, "TOPMODEL"),
        (arnovic, "VIC"),
        (prms, "PRMS"),
        (sacramento, "Sacramento"),
    )
)

# Merge, drop the 'time' column, and save the flat table.
r2 = xr.merge([top_r2, vic_r2, prms_r2, sac_r2])
r2_df = r2.to_dataframe().drop(columns='time')
r2_df.to_csv(data_dir / 'RUNOFF/FUSE_r2_table.csv')

# Add an outer column level holding the metric name, applied after the CSV is written.
r2_df.columns = [["r2"] * len(r2_df.columns), r2_df.columns]

In [59]:
# spatial_bias of every FUSE simulation against the observations.
top_bias, vic_bias, prms_bias, sac_bias = (
    spatial_bias(obs, simulated).rename(label)
    for simulated, label in (
        (topmodel, "TOPMODEL"),
        (arnovic, "VIC"),
        (prms, "PRMS"),
        (sacramento, "Sacramento"),
    )
)

# Merge and save the flat table (no 'time' column is dropped here).
bias = xr.merge([top_bias, vic_bias, prms_bias, sac_bias])
bias_df = bias.to_dataframe()
bias_df.to_csv(data_dir / 'RUNOFF/FUSE_bias_table.csv')

# Add an outer column level holding the metric name, applied after the CSV is written.
bias_df.columns = [["bias"] * len(bias_df.columns), bias_df.columns]

In [60]:
# One wide table of all FUSE error metrics. The columns have two levels:
# level 0 is the metric ('nse', 'rmse', 'r2', 'bias'), level 1 is the column
# name inside that metric's table.
fuse_errors = pd.concat([nse_df, rmse_df, r2_df, bias_df], axis=1)

# Remove a leftover 'time' column if any metric table still carries one.
# The levels stay as (metric, column): the cell that builds fuse_nse_df and
# fuse_bias selects by metric on level 0. The earlier try/except swapped and
# sorted the levels only when a 'time' column happened to exist, so the layout
# (and whether that cell worked) depended on the data.
if 'time' in fuse_errors.columns.get_level_values(1):
    fuse_errors = fuse_errors.drop('time', axis=1, level=1)

fuse_errors.to_csv(data_dir / 'RUNOFF/FUSE_errors.csv')
fuse_errors.to_pickle(data_dir / 'RUNOFF/FUSE_errors.pkl')

In [61]:
# Single-metric views: drop the other metrics from column level 0, then
# remove that level so only the inner column names remain.
_metric_level = dict(axis=1, level=0)
fuse_nse_df = fuse_errors.drop(['bias', 'r2', 'rmse'], **_metric_level).droplevel(**_metric_level)
fuse_bias = fuse_errors.drop(['nse', 'r2', 'rmse'], **_metric_level).droplevel(**_metric_level)

In [62]:
fuse_nse_df.head()

Unnamed: 0_level_0,Name,TOPMODEL,VIC,PRMS,Sacramento
station_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1001,Wick at Tarroul,0.806001,0.782846,0.81074,0.833511
2001,Helmsdale at Kilphedir,0.702887,0.72446,0.733039,0.745913
2002,Brora at Bruachrobie,0.759962,0.73194,0.754227,0.753752
3003,Oykel at Easter Turnaig,0.826229,0.818363,0.820548,0.830304
4001,Conon at Moy Bridge,0.730351,0.814088,0.761567,0.794666


# Interpret the Static Embedding

In [77]:
# Names of everything directly under the runs folder.
print([d.name for d in (data_dir / "runs/").iterdir()])
# Names of everything inside the EALSTM run 'ealstm_less_vars_2004_1607_1334'.
print([d.name for d in (data_dir / "runs/ealstm_less_vars_2004_1607_1334").glob("*")])

# Bare path expressions: nothing is assigned, and only the last one is shown
# as the cell's output.
(data_dir / "runs/ealstm_less_vars_2004_1607_1334")
(data_dir / "runs/lstm_less_vars_2004_1507_1028")

['ealstm_less_vars_2004_1607_1334', 'lstm_less_vars_1307_1717', 'ealstm_less_vars_2004_1707_1424', 'lstm_less_vars_2004_1507_1028']
['config.yml', 'events.out.tfevents.1594906656.GPU_MachineLearning.22441.0', 'model_epoch001.pt', 'model_epoch002.pt', 'model_epoch003.pt', 'model_epoch004.pt', 'model_epoch005.pt', 'model_epoch006.pt', 'model_epoch007.pt', 'model_epoch008.pt', 'model_epoch009.pt', 'model_epoch010.pt', 'model_epoch011.pt', 'valid_ds.nc', 'results_ealstm_less_vars_2004_1607_1334_E002.csv', 'results_ealstm_less_vars_2004_1607_1334_E006.csv', 'model_epoch012.pt', 'results_ealstm_less_vars_2004_1607_1334_E007.csv', 'results_ealstm_less_vars_2004_1607_1334_E001.csv', 'results_ealstm_less_vars_2004_1607_1334_E008.csv', 'results_ealstm_less_vars_2004_1607_1334_E009.csv', 'results_ealstm_less_vars_2004_1607_1334_E010.csv', 'model_epoch013.pt', 'results_ealstm_less_vars_2004_1607_1334_E005.csv', 'results_ealstm_less_vars_2004_1607_1334_E004.csv', 'results_ealstm_less_vars_2004_1607

PosixPath('/cats/datastore/data/runs/lstm_less_vars_2004_1507_1028')

In [85]:
# Epoch-14 checkpoint and config of the run 'ealstm_less_vars_2004_1607_1334'.
# NOTE(review): the EALSTM predictions read earlier come from the run
# 'ealstm_less_vars_2004_1707_1424' (E015) — confirm that a different
# run/epoch is intended here.
model_path = (data_dir / "runs/ealstm_less_vars_2004_1607_1334/model_epoch014.pt")
config_path = (data_dir / "runs/ealstm_less_vars_2004_1607_1334/config.yml")

# typ="safe" selects ruamel's safe loader.
yaml = YAML(typ="safe")
cfg = yaml.load(config_path)

# map_location="cpu" matches how the same file is loaded when the model is
# initialised, and lets a checkpoint saved from a GPU be read on a CPU-only host.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model_dict = torch.load(model_path, map_location="cpu")

In [88]:
# Point the run config at the local CAMELS GB dataset, after checking it exists.
camels_gb_root = data_dir / "CAMELS_GB_DATASET"
assert camels_gb_root.exists()
cfg["data_dir"] = camels_gb_root

# input_size_dyn = len(cfg["dynamic_inputs"])
# input_size_stat = len(cfg["static_inputs"] + cfg["camels_attributes"])

In [102]:
import sys
sys.path.insert(1, '/home/tommy/tommy_multiple_forcing')

from codebase.modelzoo.ealstm import EALSTM
from codebase.modelzoo.cudalstm import CudaLSTM
from codebase.data.camelstxt import CamelsGBCSV
from codebase.data.utils import load_basin_file

In [103]:
# Counts taken from the config: the dynamic inputs, and the static inputs plus
# the CAMELS attributes. Neither value is used again in the visible cells.
input_size_dyn = len(cfg["dynamic_inputs"])
input_size_stat = len(cfg["static_inputs"] + cfg["camels_attributes"])

# Initialize model
# Build the EALSTM from the config and load the checkpoint weights onto the CPU.
model = EALSTM(cfg)
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# # extract weight and bias of input gate
# weight = model.lstm.weight_sh
# bias = model.lstm.bias_s
# NOTE(review): the two assignments above are disabled, yet the per-basin loop
# further down reads `weight` and `bias` — define them before running that loop.

# load in the attributes
# NOTE(review): `load_attributes` and `BASE_RUN_DIR` are not imported or defined
# in the visible cells, and `basins` is only assigned in a later cell — confirm
# all three exist before this cell runs.
attributes = load_attributes(db_path=str(BASE_RUN_DIR / "attributes.db"), 
                             basins=basins,
                             drop_lat_lon=True)
# Mean and standard deviation of `attributes`; passed on as
# attribute_means / attribute_stds when the datasets are built.
means = attributes.mean()
stds = attributes.std()


<All keys matched successfully>

In [105]:
cfg.keys()

dict_keys(['additional_feature_files', 'batch_size', 'cache_data', 'cache_validation_data', 'camels_attributes', 'clip_gradient_norm', 'clip_gradient_value', 'commit_hash', 'data_dir', 'dataset', 'device', 'dynamic_inputs', 'embedding_activation', 'embedding_dropout', 'embedding_hiddens', 'epochs', 'experiment_name', 'h5_file', 'head', 'hidden_size', 'initial_forget_bias', 'learning_rate', 'log_dir', 'log_figure', 'log_figure_n_basins', 'log_interval', 'log_tensorboard', 'loss', 'metrics', 'model', 'num_workers', 'number_of_basins', 'optimizer', 'output_activation', 'output_dropout', 'predict_last_n', 'run_dir', 'save_validation_results', 'save_weights_every', 'scaler_file', 'seed', 'seq_length', 'static_inputs', 'target_variable', 'test_basin_file', 'test_end_date', 'test_start_date', 'train_basin_file', 'train_dir', 'train_end_date', 'train_start_date', 'use_basin_id_encoding', 'validate_every', 'validate_n_random_basins', 'validation_basin_file', 'validation_end_date', 'validation_s

In [104]:
# Load the contents of the training basin file named in the config and take the
# first entry as a sample (presumably a basin id — verify against load_basin_file).
basins = load_basin_file(cfg["train_basin_file"])
basin = basins[0]
# Dataset for that one basin over the config's test period, not in training
# mode, with attributes and the mean/std computed from `attributes`.
# NOTE(review): `means` / `stds` come from the model-initialisation cell, and
# `BASE_RUN_DIR` is not defined in the visible cells — confirm before running.
ds_test = CamelsGBCSV(camels_root=cfg["data_dir"],
                        basin=basin,
                        dates=[cfg["test_start_date"], cfg["test_end_date"]],
                        is_train=False,
                        with_attributes=True,
                        attribute_means=means,
                        attribute_stds=stds,
                        db_path=str(BASE_RUN_DIR / "attributes.db"))

SyntaxError: invalid syntax (<ipython-input-104-6875ee063512>, line 2)

In [101]:
# For every basin: build its dataset, evaluate
# sigmoid(bias + attributes @ weight) on the dataset's attributes, and store
# the result as a NumPy array under the basin's key.
# NOTE(review): `CamelsTXT`, `CAMELS_DIR`, `VAL_START`, `VAL_END`,
# `BASE_RUN_DIR`, `bias`, `weight` and `lstm_embedding` are not defined in the
# visible cells (the import cell brings in `CamelsGBCSV`, and the `weight` /
# `bias` assignments are commented out) — confirm they exist before running.
for basin in tqdm(basins):
    ds_test = CamelsTXT(camels_root=CAMELS_DIR,
                        basin=basin,
                        dates=[VAL_START, VAL_END],
                        is_train=False,
                        with_attributes=True,
                        attribute_means=means,
                        attribute_stds=stds,
                        db_path=str(BASE_RUN_DIR / "attributes.db"))
    # torch.addmm(bias, A, W) computes bias + A @ W.
    input_gate = torch.sigmoid(torch.addmm(bias, ds_test.attributes, weight))
    # detach() before .numpy() so the array is taken outside autograd tracking.
    lstm_embedding[basin] = input_gate.detach().numpy()


NameError: name 'ds_test' is not defined