In [None]:
# Required to run multiple worker processes on Unity: start them with 'spawn'.
# Run this cell before any cell that creates a data loader with worker processes.
from multiprocessing import set_start_method

# force=True keeps the cell re-runnable: without it, calling set_start_method a
# second time in the same kernel raises "RuntimeError: context has already been set".
set_start_method('spawn', force=True)

In [10]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
import optax
from pathlib import Path

# Make the project's source directory importable.
# sys.path entries must be plain strings: non-string entries are ignored during
# import, and a Path never compares equal to the str entries already in the list,
# so both the membership test and the append use the string form.
src = Path('../src').resolve()
if str(src) not in sys.path:
    sys.path.append(str(src))
    
# Import the project modules, then reload `data` so that edits made to it on disk
# are picked up without restarting the kernel. The classes are imported after the
# reload so the names bound here come from the reloaded module.
import importlib
import data
import train
import models
import evaluate

data = importlib.reload(data)
from data import TAPDataset, TAPDataLoader

# Dataset root and the site list file handed to TAPDataset.
data_dir = Path("../data/wqp")
site_list_dir = data_dir / "metadata" / "site_lists"
basin_file = site_list_dir / "sites_test.txt"
# Other site lists kept for reference:
# basin_list_file = site_list_dir / "sites_all.txt"
# basin_file = site_list_dir / "sites_turb_area1000_n10.txt"

# Feature names grouped under the keys 'daily', 'irregular', 'static' and 'target'.
feature_columns = dict(
    daily=['grfr_q'],
    irregular=['Blue', 'Green', 'Red', 'Nir', 'Swir1', 'Swir2'],
    static=['wet_pc_s02', 'glc_pc_s11', 'dor_pc_pva', 'soc_th_sav', 'snw_pc_s08',
            'wet_pc_s09', 'fmh_cl_smj', 'glc_pc_s14', 'pnv_pc_s11', 'pac_pc_sse'],
    target='turbidity',
)

# Keyword arguments forwarded unchanged to TAPDataset.
data_args = dict(
    data_dir=data_dir,
    basin_file=basin_file,
    features=feature_columns,
    time_slice=slice('1979-01-01', '2018-12-31'),
    split_time=np.datetime64('2010-01-01'),
    sequence_length=30,
    log_norm_cols=['turbidity', 'grfr_q'],
    clip_target_to_zero=True,
)

dataset = TAPDataset(**data_args)

Loading Basins:   0%|          | 0/10 [00:00<?, ?it/s]

In [5]:
# Reload the model and training modules so on-disk edits are picked up, then
# import the classes used in this cell from the reloaded modules.
models = importlib.reload(models)
train = importlib.reload(train)
from models import TAPLSTM, EALSTM, LSTM
from train import Trainer

# Loader settings; this dict is reused (and modified) by later cells.
loader_args = dict(
    shuffle=True,
    batch_size=16,
    data_subset='pre-train',
    num_workers=1,
    pin_memory=True,
)
dataloader = TAPDataLoader(dataset, **loader_args)

# Model constructor arguments; the input sizes are taken from the dataset's
# feature lists.
model_args = dict(
    daily_in_size=len(dataset.daily_features),
    irregular_in_size=len(dataset.irregular_features),
    static_in_size=len(dataset.static_features),
    out_size=1,
    hidden_size=16,
    dropout=0.4,
    seed=0,
)
# Alternative argument set with a single dynamic input size
# (presumably for the EALSTM / LSTM models imported above -- confirm before use):
# model_args = {'dynamic_in_size': len(data_args['features']['daily']),
#               'static_in_size': len(dataset.static_features),
#               'out_size': 1,
#               'hidden_size': 16,
#               'dropout': 0.4,
#               'seed': 0}

num_epochs = 5
# Exponential decay from 0.01 with decay rate 0.01 over `num_epochs` transition steps.
lr_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=num_epochs,
    decay_rate=0.01,
)
trainer_args = dict(
    model_func=TAPLSTM,
    model_args=model_args,
    dataloader=dataloader,
    lr_schedule=lr_schedule,
    num_epochs=num_epochs,
    max_grad_norm=2,
)

trainer = Trainer(**trainer_args)
trainer.start_training()

Dataloader using 1 parallel CPU worker(s).
Batch sharding set to 1 cpu(s)


  self.pid = os.fork()


Epoch:001:   0%|          | 0/43 [00:00<?, ?it/s]

ERROR: Unexpected segmentation fault encountered in worker.
 

RuntimeError: DataLoader worker (pid(s) 1928876) exited unexpectedly

In [None]:
more_epochs = 150

# Load the saved 'epoch100' state and switch the data loader to the 'train' subset.
trainer.load_state('epoch100')
loader_args.update(data_subset='train')
trainer.dataloader = TAPDataLoader(dataset, **loader_args)

# Replace the schedule with a decay that begins at trainer.epoch.
# NOTE(review): transition_steps is trainer.epoch + more_epochs even though
# transition_begin is already trainer.epoch -- confirm this is the intended span.
trainer.lr_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=trainer.epoch + more_epochs,
    decay_rate=0.001,
    transition_begin=trainer.epoch,
)
# NOTE(review): this increment is applied again every time the cell is re-run.
trainer.num_epochs = trainer.num_epochs + more_epochs
trainer.freeze_components('tealstm_i', True)
trainer.start_training()

In [None]:
# Preview an exponential-decay schedule by evaluating it at 50 evenly spaced
# points between 0 and 2 * num_epochs.
# NOTE(review): this rebinds the notebook-level name `lr_schedule`.
lr_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=trainer.epoch + num_epochs,
    decay_rate=0.001,
    transition_begin=trainer.epoch,
)
x = np.linspace(0, 2 * num_epochs, num=50)
y = lr_schedule(x)
plt.plot(x, y)

In [None]:
# Evaluate the current model on the 'test' subset of one randomly chosen basin.
importlib.reload(evaluate)
from evaluate import predict, get_all_metrics

# Build the list of candidate basins from the same site list file the dataset was
# created from, so this cell does not depend on a name defined outside the notebook.
# NOTE(review): assumes the site list file holds one site id per line -- confirm.
basin_list = [line.strip()
              for line in Path(basin_file).read_text().splitlines()
              if line.strip()]
basin = np.random.choice(basin_list).tolist()

# Use a copy of the shared loader settings: the test-only overrides (basin subset,
# no worker processes) must not leak into `loader_args`, which the training cells
# reuse when they are re-run.
test_loader_args = {**loader_args,
                    'data_subset': 'test',
                    'basin_subset': basin,
                    'num_workers': 0}  # Faster for small runs
dataloader = TAPDataLoader(dataset, **test_loader_args)

results = predict(trainer.model, dataloader, denormalize=True)
results['pred'] = results['pred'] * (results['pred']>0) #Clip predictions to 0
metrics = get_all_metrics(results['obs'],results['pred'])
metrics

In [None]:
# Show the `results` object as the cell output for inspection.
results

In [None]:
# Draw predictions and observations on one set of axes; observations are drawn as
# unconnected point markers.
fig, ax = plt.subplots(figsize=(12, 6))
results['pred'].plot(ax=ax)
results['obs'].plot(ax=ax, marker='.', linestyle='None')

# Recompute the metrics here so the title reflects the current `results`.
metrics = get_all_metrics(results['obs'], results['pred'])

ax.set_title(f"Basin: {basin}, KGE: {metrics['kge']:0.4f}")
ax.legend()
ax.set_ylim([0, 500])
plt.show()

In [None]:
# Scatter of 'obs' (x) against 'pred' (y) on square axes, with both axes limited
# to [0, 20]. The limits are set after axis('square'), matching the original order.
results.plot.scatter('obs', 'pred')
scatter_ax = plt.gca()
scatter_ax.axis('square')
scatter_ax.set_xlim([0, 20])
scatter_ax.set_ylim([0, 20])
plt.show()

In [None]:
from train import make_step
from tqdm.notebook import trange

# See if we can recreate the error...
# The loaded state is bound to `exception_state` rather than `data`, so the project
# module `data` imported at the top of the notebook is not shadowed by this cell.
exception_state = trainer.load_state("exceptions/epoch130_exception0")
saved_batch = exception_state['batch']

# Repeat the same step on the saved batch; the return value of make_step is
# discarded, so trainer.model and trainer.opt_state are not reassigned here.
for i in trange(1000):
    make_step(trainer.model, saved_batch, trainer.opt_state, trainer.optim,
              trainer.filter_spec, loss_name="mse", max_grad_norm=None, l2_weight=None)