# Analyse the FUSE Experiments vs. LSTM work

In [128]:
from pathlib import Path
import os
import warnings
from typing import Optional, List, Tuple, Dict

%load_ext autoreload
%autoreload 2

# ignore warnings for now ...
warnings.filterwarnings('ignore')

if Path('.').absolute().parents[1].name == 'ml_drought':
    os.chdir(Path('.').absolute().parents[1])

!pwd

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/tommy/ml_drought


In [129]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import matplotlib as mpl
from tqdm import tqdm

mpl.rcParams['figure.dpi'] = 150

In [130]:
data_dir = Path('/cats/datastore/data/')

assert data_dir.exists()

In [131]:
from src.utils import drop_nans_and_flatten

from src.analysis import read_train_data, read_test_data, read_pred_data
from src.analysis.evaluation import join_true_pred_da
from src.models import load_model

# Read in the data

In [133]:
# read in the training data
ds = xr.open_dataset(data_dir / "RUNOFF/ALL_dynamic_ds.nc")
ds['station_id'] = ds['station_id'].astype(int)

all_static = xr.open_dataset(data_dir / f'RUNOFF/interim/static/data.nc')
all_static['station_id'] = all_static['station_id'].astype(int)
static = all_static

In [134]:
# 13 test stations
catchment_ids = [int(c) for c in ["12002", "15006", "27009", "27034", "27041", "39001", "39081", "43021", "47001", "54001", "54057", "71001", "84013",]]
catchment_names = ["Dee@Park", "Tay@Ballathie", "Ouse@Skelton", "Ure@Kilgram", "Derwent@Buttercrambe", "Thames@Kingston", "Ock@Abingdon", "Avon@Knapp", "Tamar@Gunnislake", "Severn@Bewdley", "Severn@Haw", "Ribble@Samlesbury", "Clyde@Daldowie"]
station_map = dict(zip(catchment_ids, catchment_names))

# Open the Less Vars experiments (2004-2015 test)

In [135]:
print([d.name for d in (data_dir/'runs/').iterdir()])
print([d.name for d in (data_dir/'runs').glob('*_less_vars*/*E015.csv')])

['ealstm_less_vars_2004_1707_1424', 'lstm_less_vars_2004_1507_1028', 'train_data.h5', 'train_data_scaler.p', 'lstm_ALL_vars_2004_2210_1035', 'lstm_all_vars_1998_2008_2210_110347', 'lstm_all_vars_1998_2008_2210_110727', 'lstm_all_vars_1998_2008_nh_2310_101443', 'lstm_all_vars_1998_2008_nh_2310_142625', 'ensemble__', 'ensemble_EALSTM', 'ensemble_lstm10', 'ensemble', '0_ensemble_results']
['results_ealstm_less_vars_2004_1707_1424_E015.csv', 'results_lstm_less_vars_2004_1507_1028_E015.csv']


In [136]:
lstm_df = pd.read_csv(data_dir / "runs/lstm_less_vars_2004_1507_1028/results_lstm_less_vars_2004_1507_1028_E015.csv")
lstm_df["time"] = pd.to_datetime(lstm_df["time"])

In [137]:
ealstm_df = pd.read_csv(data_dir / "runs/ealstm_less_vars_2004_1707_1424/results_ealstm_less_vars_2004_1707_1424_E015.csv")
ealstm_df["time"] = pd.to_datetime(ealstm_df["time"])

In [138]:
lstm_ensemble_df = pd.read_csv("/cats/datastore/data/runs/ensemble/data_ENS.csv").drop("Unnamed: 0", axis=1)
lstm_ensemble_df["time"] = pd.to_datetime(lstm_ensemble_df["time"])

lstm_ensemble = lstm_ensemble_df.set_index(["station_id", "time"]).to_xarray()
lstm_df = lstm_ensemble_df
lstm_ensemble

In [139]:
# lstm_df.set_index(["station_id", "time"]).to_xarray()

lstm_preds = lstm_ensemble
ealstm_preds = ealstm_df.set_index(["station_id", "time"]).to_xarray()

# Calculate Errors

In [140]:
from src.analysis.evaluation import spatial_rmse, spatial_r2, spatial_nse
from src.analysis.evaluation import temporal_rmse, temporal_r2, temporal_nse
from src.analysis.evaluation import _nse_func, _rmse_func, _r2_func

def error_func(preds_xr: xr.Dataset, error_str: str) -> pd.DataFrame:
    lookup = {
        "nse": _nse_func,
        "rmse": _rmse_func,
        "r2": _r2_func,
    }
    error_func = lookup[error_str]
    
    df = preds_xr.to_dataframe()
    df = df.dropna(how='any')
    df = df.reset_index().set_index("time")

    station_ids = df["station_id"].unique()
    errors = []
    for station_id in station_ids:
        d = df.loc[df["station_id"] == station_id]
        if error_str == "rmse":
            _error_calc = error_func(d["obs"].values, d["sim"].values, n_instances=d.size)
        else:
            _error_calc = error_func(d["obs"].values, d["sim"].values)
        errors.append(_error_calc)

    error = pd.DataFrame({"station_id": station_ids, error_str: errors})
    
    return error

In [141]:
errors = [ 
    error_func(ealstm_preds, "nse").set_index('station_id'),
    error_func(ealstm_preds, "r2").set_index('station_id'), 
    error_func(ealstm_preds, "rmse").set_index('station_id'),
]
ealstm_metric_df = errors[0].join(errors[1].join(errors[2])).reset_index()

errors = [ 
    error_func(lstm_preds, "nse").set_index('station_id'),
    error_func(lstm_preds, "r2").set_index('station_id'), 
    error_func(lstm_preds, "rmse").set_index('station_id'),
]
lstm_metric_df = errors[0].join(errors[1].join(errors[2])).reset_index()

errors = [ 
    error_func(lstm_ensemble, "nse").set_index('station_id'),
    error_func(lstm_ensemble, "r2").set_index('station_id'), 
    error_func(lstm_ensemble, "rmse").set_index('station_id'),
]
ensemble_metric_df = errors[0].join(errors[1].join(errors[2])).reset_index()

In [142]:
ensemble_metric_df.head()

Unnamed: 0,station_id,nse,r2,rmse
0,1001,0.87609,0.876249,0.412211
1,2001,0.796159,0.798533,0.657134
2,2002,0.799168,0.79982,0.865827
3,3003,0.878744,0.878744,1.190749
4,4001,0.872644,0.872956,0.772388


In [143]:
metric_df = pd.read_csv(data_dir / "runs/ensemble/metric_df.csv", index_col=0)
metric_df.columns = [c.lower() for c in metric_df.columns]

metric_df.head()

Unnamed: 0,station_id,nse,kge,mse,fhv,fms,flv
0,10002,0.898328,0.857653,0.229347,-10.042169,20.175607,4.281218
1,10003,0.926555,0.88111,0.11198,-6.074578,1.342742,28.237229
2,1001,0.87609,0.920412,0.509755,3.647934,-23.746845,70.234474
3,101002,0.757246,0.647206,0.300506,-25.176707,-4.734227,73.409471
4,101005,0.824926,0.79643,0.208704,-18.346511,1.971726,45.197245


# Open FUSE Models

In [144]:
all_paths = [d for d in (data_dir / "RUNOFF/FUSE/Timeseries_SimQ_Best/").glob("*_Best_Qsim.txt")]

if not (data_dir / "RUNOFF/ALL_fuse_ds.nc").exists():
    all_dfs = []
    for txt in tqdm(all_paths):
        df = pd.read_csv(txt, skiprows=3, header=0)
        df.columns = [c.rstrip().lstrip() for c in df.columns]
        df = df.rename(columns={"YYYY": "year", "MM": "month", "DD": "day"})
        df["time"] = pd.to_datetime(df[["year", "month", "day"]])
        station_id = int(str(txt).split("/")[-1].split("_")[0])
        df["station_id"] = [station_id for _ in range(len(df))]
        df = df.drop(["year", "month", "day", "HH"], axis=1).set_index(["station_id", "time"])
        all_dfs.append(df)
        
    fuse_ds = pd.concat(all_dfs).to_xarray()
    fuse_ds.to_netcdf(data_dir / "RUNOFF/fuse_ds.nc")
    
else:
    fuse_ds = xr.open_dataset(data_dir / "RUNOFF/ALL_fuse_ds.nc")

### NOTE: only test performance on 2004-2008

In [145]:
fuse_ds = fuse_ds.sel(time=slice('2004-01-01', '2009-01-01'))

# join with observations for stations that exist
obs = (
    ds.sel(station_id=np.isin(ds["station_id"], fuse_ds["station_id"]), time=fuse_ds["time"])["target_var_original"]
).rename("discharge_spec")
fuse_data = fuse_ds.sel(station_id=obs.station_id).merge(obs)
fuse_data = fuse_data.rename({'discharge_spec': 'obs'})

### Calculate FUSE Errors

In [146]:
fuse_errors = pickle.load((data_dir / 'RUNOFF/FUSE_errors.pkl').open("rb"))

In [147]:
def get_error_df(model: str, fuse_errors: pd.DataFrame) -> pd.DataFrame:
    all_models = ["TOPMODEL", "VIC", "PRMS", "Sacramento"]
    assert model in all_models
    remove_models = [m for m in all_models if m != model]
    error_df = fuse_errors.drop(remove_models, axis=1, level=1).swaplevel(axis=1).sort_index(axis=1).droplevel(axis=1, level=0)
    rename_cols = pd.io.parsers.ParserBase({'names': error_df.columns})._maybe_dedup_names(error_df.columns)
    error_df.columns = [n if n != "nse" else "Name" for n in rename_cols]
    return error_df.rename({"nse.1": "nse"}, axis=1)


topmodel = fuse_data[["obs", "SimQ_TOPMODEL"]].to_dataframe().reset_index()
vic = fuse_data[["obs", "SimQ_ARNOVIC"]].to_dataframe().reset_index()
prms = fuse_data[["obs", "SimQ_PRMS"]].to_dataframe().reset_index()
sacramento = fuse_data[["obs", "SimQ_SACRAMENTO"]].to_dataframe().reset_index()

top_error = get_error_df("TOPMODEL", fuse_errors)
vic_error = get_error_df("VIC", fuse_errors)
prms_error = get_error_df("PRMS", fuse_errors)
sac_error = get_error_df("Sacramento", fuse_errors)

In [1]:
from pathlib import Path
import os
import warnings

%load_ext autoreload
%autoreload 2

# ignore warnings for now ...
warnings.filterwarnings('ignore')

if Path('.').absolute().parents[1].name == 'ml_drought':
    os.chdir(Path('.').absolute().parents[1])

!pwd

/home/tommy/ml_drought


In [2]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch

from typing import List, Union, Optional, Tuple, Dict

data_dir = Path('/cats/datastore/data/')
# data_dir = Path('/Volumes/Lees_Extend/data/zip_data')
# data_dir = Path('/Volumes/Lees_Extend/data/ecmwf_sowc/data/')
# plot_dir = Path('/Users/tommylees/Downloads')

assert data_dir.exists()

In [21]:
[d.name for d in (data_dir / "runs/lstm_less_vars_2004_1507_1028/").iterdir()]

['config.yml',
 'events.out.tfevents.1594810205.GPU_MachineLearning.10968.0',
 'model_epoch001.pt',
 'valid_ds.nc',
 'results_lstm_less_vars_2004_1507_1028_E001.csv',
 'all_lstm_less_vars_2004_1507_1028_results.csv',
 'model_epoch002.pt',
 'model_epoch003.pt',
 'model_epoch004.pt',
 'model_epoch005.pt',
 'model_epoch006.pt',
 'model_epoch007.pt',
 'model_epoch008.pt',
 'model_epoch009.pt',
 'model_epoch010.pt',
 'model_epoch011.pt',
 'model_epoch012.pt',
 'model_epoch013.pt',
 'model_epoch014.pt',
 'model_epoch015.pt',
 'results_lstm_less_vars_2004_1507_1028_E002.csv',
 'results_lstm_less_vars_2004_1507_1028_E006.csv',
 'results_lstm_less_vars_2004_1507_1028_E007.csv',
 'results_lstm_less_vars_2004_1507_1028_E008.csv',
 'results_lstm_less_vars_2004_1507_1028_E012.csv',
 'results_lstm_less_vars_2004_1507_1028_E014.csv',
 'results_lstm_less_vars_2004_1507_1028_E009.csv',
 'results_lstm_less_vars_2004_1507_1028_E010.csv',
 'results_lstm_less_vars_2004_1507_1028_E005.csv',
 'results_lstm_l

# Load csv

In [39]:
def get_all_data_csv(run_dir: Path):
    csv_path = sorted(list(run_dir.glob("*.csv")))[-1]
    number = csv_path.as_posix().split("_")[-1]
    print(f"Reading {number}")
    return pd.read_csv(csv_path)
    
    
run_dir = data_dir / "runs/ealstm_less_vars_2004_1607_1334"
ealstm = get_all_data_csv(run_dir)
run_dir = data_dir / "runs/lstm_less_vars_2004_1507_1028"
lstm = get_all_data_csv(run_dir)

lstm.head()

Reading E011.csv
Reading E015.csv


Unnamed: 0,station_id,time,obs,sim
0,1001,2004-01-01,9.16,10.026257
1,1001,2004-01-02,6.23,5.819096
2,1001,2004-01-03,5.6,5.267475
3,1001,2004-01-04,4.45,5.064303
4,1001,2004-01-05,4.46,5.131446


# Calculate Error Metrics

In [48]:
import hydroeval

In [51]:
hydroeval.nse(lstm.sim.values, lstm.obs.values)

nan