In [4]:
import os
import pandas as pd
import numpy as np
from tqdm.notebook import trange, tqdm
import jax
import optax
import matplotlib.pyplot as plt
import importlib
from pathlib import Path

# Force reload of project files
import data, train, models, evaluate, metrics
importlib.reload(data)
from data import TAPDataset

data_dir = Path("../data/wqp")
basin_list_file = data_dir / "metadata" / "site_lists" / "sites_all.txt"
# basin_list_file = data_dir / "metadata" / "site_lists" / "sites_ssc_area1000_n10.txt"

with open(basin_list_file, 'r') as file:
    basin_list = file.readlines()
    basin_list = [basin.strip() for basin in basin_list]

basin_list = basin_list[:10]

data_params = {'data_dir': data_dir,
               'basin_list': basin_list,
               'features_dict': {'daily':['grfr_q'],
                                 'irregular':['Blue','Green','Red','Nir','Swir1','Swir2']},
               'target': 'turbidity',
               'time_slice': slice('1979-01-01', '2018-12-31'),
               'split_time': np.datetime64('2010-01-01'),
               'sequence_length': 7,
               'log_norm_cols': ['turbidity', 'grfr_q'],
               'clip_target_to_zero': True,
               'discharge_col': 'grfr_q'}

dataset = TAPDataset(**data_params)

Loading Basins:   0%|          | 0/10 [00:00<?, ?it/s]

In [6]:
importlib.reload(data)
from data import TAPDataset, TAPDataLoader

loader_params = {'suffle': True,
                 'batch_size': 3,
                 'num_workers': 1,
                 'pin_memory': True} 
dataloader = TAPDataLoader(dataset, **loader_params)
# dataloader.set_mode(train=True, sequence=True, end2end=True)


Dataloader using 1 parallel CPU worker(s).
Batch sharding set to 1 cpu(s)


In [7]:
for _, _, batch in dataloader:
    break
batch['y'][:,-1]

PicklingError: Can't pickle <class 'data.TAPDataset'>: it's not the same object as data.TAPDataset

In [None]:
importlib.reload(data)
importlib.reload(models)
importlib.reload(train)
from data import TAPDataLoader
from models import TAPLSTM, EALSTM, LSTM
from train import Trainer


daily_in_size = len(data_params['features_dict']['daily'])
irregular_in_size = len(data_params['features_dict']['irregular'])
static_in_size = dataset.x_s[basin_list[0]].shape[0]
output_size = 1

hidden_size = 16
dropout = 0.4
key = jax.random.PRNGKey(0)

model = TAPLSTM(daily_in_size, irregular_in_size, static_in_size, output_size, hidden_size, key=key, dropout=dropout)
# model = LSTM(daily_in_size, output_size, hidden_size, key=key, dropout=dropout)
# model = EALSTM(daily_in_size, static_in_size, output_size, hidden_size, key=key, dropout=dropout)

loader_params = {'suffle': True,
                 'batch_size': 3,
                 'train':True,
                 'sequence':True,
                 'end2end': True,
                 'num_workers': 1,
                 'pin_memory': True} 
dataloader = TAPDataLoader(dataset, **loader_params)

num_epochs =  15
# lr_schedule = optax.polynomial_schedule(0.01, 0.0001, 2, num_epochs)
lr_schedule = optax.exponential_decay(0.005, num_epochs, 0.5)
trainer = Trainer(model, dataloader, lr_schedule, num_epochs, max_grad_norm=2)
model = trainer.start_training()


In [None]:
trainer.num_epochs=1000
model = trainer.start_training()

In [None]:
model

In [None]:
basin

In [None]:
from metrics import get_all_metrics

importlib.reload(evaluate)
from evaluate import predict

#Need a sr evaluate mode. Maybe just rethink the entire indexing scheme.
dataloader = TAPDataLoader(dataset, **loader_params)
basin = np.random.choice(basin_list).tolist()
dataloader.set_mode(train=False, sequence=True, end2end=True, basin_subset=basin)
results = predict(model, dataloader, denormalize=True)
results['pred'] = results['pred'] * (results['pred']>0) #Clip predictions to 0
metrics = get_all_metrics(results['obs'],results['pred'])
metrics

In [None]:


# Plot the true values and predictions
fig, ax = plt.subplots(figsize=(12, 6))
basin_results['pred'].plot(ax=ax)
basin_results['obs'].plot(ax=ax,linestyle='None',marker='.')

metrics = get_all_metrics(basin_results['obs'],basin_results['pred'])

plt.title(f"Basin: {basin_subset}, KGE: {metrics['kge']:0.4f}")
plt.legend()
plt.ylim([0,500])
plt.show()

In [None]:
results.plot.scatter('obs','pred')
plt.gca().axis('square')
plt.xlim([0,20])
plt.ylim([0,20])
plt.show()