In [1]:
# Force the 'spawn' start method for worker processes. The original author
# found this necessary to run multiple processes on Unity.
from multiprocessing import set_start_method
try:
    set_start_method('spawn')
except RuntimeError:
    # Raised when the start method has already been set (e.g. this cell was
    # re-run in the same kernel). Anything else is allowed to surface instead
    # of being silently swallowed.
    pass

# Disable CUDA graphs. This is set ahead of the `import jax` in the next cell.
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_command_buffer='


In [2]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
from pathlib import Path
from importlib import reload

# Put the project's source directory on the import path (once), so the
# project modules imported right after this can be found.
src = str(Path('../src').resolve())
if sys.path.count(src) == 0:
    sys.path.append(src)
import config, data, models, train, evaluate

In [10]:
from config import read_config

# Load the run configuration from the `hybrid` run directory. read_config
# returns two values: cfg is what the dataset, dataloader and trainer cells
# consume; cfg_str is presumably the raw config text — TODO confirm.
cfg, cfg_str = read_config("../runs/hybrid/config.yml")

In [8]:
from data import TAPDataset

# Build the dataset from the loaded config. The saved cell output shows it
# dropping zero-variance static attributes and then loading basins.
dataset = TAPDataset(cfg)

Dropping static attributes with 0 variance: ['glc_pc_s03', 'glc_pc_s05', 'glc_pc_s07', 'glc_pc_s08', 'glc_pc_s17', 'glc_pc_s19', 'pnv_pc_s03', 'wet_pc_s05', 'wet_pc_s06', 'wet_pc_s07']


Loading Basins:   0%|          | 0/740 [00:00<?, ?it/s]

In [None]:
from config import set_model_data_args
from data import TAPDataLoader

# Reload so that edits to the train module are picked up without restarting
# the kernel; the import must come after the reload to bind the fresh class.
reload(train)
from train import Trainer

# NOTE(review): this rebinds cfg from its own previous value, so re-running the
# cell feeds the already-updated cfg back into set_model_data_args — confirm
# that is harmless.
cfg = set_model_data_args(cfg, dataset)
# Manual override, applied after set_model_data_args.
cfg['num_layers'] = 2

# Build a dataloader over the dataset, hand it to a new Trainer, and start training.
dataloader = TAPDataLoader(cfg, dataset)
trainer = Trainer(cfg, dataloader)
trainer.start_training()

Dataloader using 2 parallel CPU worker(s).
Batch sharding set to 1 cpu(s)


Epoch:001:   0%|          | 0/2292 [00:18<?, ?it/s]

Anomalous batch loss (1.1480, z-score of 7.3).
Anomalous batch loss (0.7820, z-score of 5.0).
Anomalous batch loss (0.8537, z-score of 5.8).
Anomalous batch loss (0.9494, z-score of 7.7).
Anomalous batch loss (0.8990, z-score of 7.5).
Anomalous batch loss (2.1523, z-score of 20.4).
Anomalous batch loss (0.7228, z-score of 5.2).


Epoch:002:   0%|          | 0/2292 [00:00<?, ?it/s]

Anomalous batch loss (0.5728, z-score of 6.3).
Anomalous batch loss (0.9714, z-score of 11.5).
Anomalous batch loss (0.7176, z-score of 8.1).
Anomalous batch loss (0.7934, z-score of 9.0).
Anomalous batch loss (0.8520, z-score of 9.7).
Anomalous batch loss (2.3305, z-score of 23.9).
Anomalous batch loss (0.6651, z-score of 5.1).
Anomalous batch loss (0.7411, z-score of 6.2).
Anomalous batch loss (0.7380, z-score of 6.5).
Anomalous batch loss (0.8682, z-score of 8.0).
Anomalous batch loss (0.6458, z-score of 5.5).


Epoch:003:   0%|          | 0/2292 [00:00<?, ?it/s]

Anomalous batch loss (0.5903, z-score of 6.7).
Anomalous batch loss (0.7968, z-score of 7.1).
Anomalous batch loss (0.5863, z-score of 5.6).
Anomalous batch loss (2.2342, z-score of 15.7).
Anomalous batch loss (0.8083, z-score of 6.7).
Anomalous batch loss (0.7218, z-score of 5.9).
Anomalous batch loss (0.9140, z-score of 8.7).
Anomalous batch loss (1.0205, z-score of 10.5).
Anomalous batch loss (0.6955, z-score of 6.4).
Anomalous batch loss (0.5683, z-score of 5.0).


Epoch:004:   0%|          | 0/2292 [00:00<?, ?it/s]

Anomalous batch loss (0.9707, z-score of 11.7).
Anomalous batch loss (0.8885, z-score of 10.3).


In [None]:
# Continue training with the in-memory trainer, or uncomment one of the load
# calls below to restore a checkpoint first.
import optax

# trainer.load_state('epoch100')
# trainer.load_last_state()

more_epochs = 0  # Extra epochs to add on top of trainer.num_epochs.

# Replace the learning-rate schedule with a new exponential decay that begins
# at the trainer's current epoch.
new_schedule = optax.exponential_decay(
    init_value=0.01,
    transition_steps=trainer.epoch + more_epochs,
    decay_rate=0.001,
    transition_begin=trainer.epoch,
)
trainer.lr_schedule = new_schedule
trainer.num_epochs = trainer.num_epochs + more_epochs

# A fresh dataloader is needed whenever the previous one was interrupted.
trainer.dataloader = TAPDataLoader(cfg, dataset)
trainer.start_training()

In [None]:
# Run the model on a single basin and shape the output into a date-indexed
# frame for plotting.
reload(evaluate)
from evaluate import predict, get_all_metrics

# basin = np.random.choice(dataset.basins).tolist()
basin = 'USGS-09367540'

# NOTE: these assignments modify the shared cfg in place, so later cells that
# build a dataloader from cfg see the 'predict' subset settings.
cfg['data_subset'] = 'predict'
cfg['basin_subset'] =  basin
cfg['num_workers'] = 0 # Faster for small runs
dataloader = TAPDataLoader(cfg, dataset)

results, metrics = predict(trainer.model, dataloader, seed=0, denormalize=True)
# Clip negative predictions to 0. clip() avoids the -0.0 (and NaN for -inf)
# values that multiplying by a boolean mask would produce.
results['pred'] = results['pred'].clip(lower=0)

# Flatten the index, order by date, drop the (single) basin column and index
# by date. Written as a chain so no step mutates `results` in place.
results = (
    results
    .reset_index()
    .sort_values(by='date')
    .drop(columns=['basin'])
    .set_index('date')
)



In [None]:
# Overlay the predicted series (line) and the observations (points) for the
# selected basin on a single set of axes.
fig, ax = plt.subplots(figsize=(12, 6))
results['pred'].plot(ax=ax)
results['obs'].plot(ax=ax, marker='.', linestyle='None')

ax.set_title(f"Basin: {basin}")
ax.legend()
fig.autofmt_xdate()
# ax.set_ylim([0, 20000])
plt.show()

In [None]:
# Display the basin ID currently selected (assigned in the single-basin prediction cell).
basin

In [None]:
"""
'USGS-09367540'
"""

In [None]:
# Observed vs. predicted scatter, drawn with a square aspect.
scatter_ax = results.plot.scatter(x='obs', y='pred')
scatter_ax.axis('square')
# Optional log scaling / axis limits:
# scatter_ax.set_xscale('log')
# scatter_ax.set_yscale('log')
# scatter_ax.set_xlim([0, 20])
# scatter_ax.set_ylim([0, 20])
plt.show()

In [None]:
import train
from data import TAPDataset, TAPDataLoader

# Restore a saved state from a specific run/epoch directory. load_state returns
# four values; cfg and model replace whatever was configured earlier in the notebook.
# NOTE(review): the run directory is hard-coded to one timestamped run — update as needed.
state_dir = Path("../runs/notebook/20240603_1359/epoch18")
cfg, model, trainer_state, opt_state = train.load_state(state_dir)
# Rebuild the dataset from the restored config.
dataset = TAPDataset(cfg)

In [None]:
# Evaluate the restored model on the test subset and compute metrics.
reload(evaluate)
from evaluate import predict, get_all_metrics

cfg['data_subset'] = 'test'
cfg['num_workers'] = 4
dataloader = TAPDataLoader(cfg, dataset)

# The single-basin prediction cell earlier in this notebook unpacks predict()
# into two values with the same arguments; unpack the same way here so that
# `results` is the frame itself rather than the whole return value.
results, predict_metrics = predict(model, dataloader, seed=0, denormalize=True)
# Predictions are intentionally left unclipped in this cell.

# results = results.reset_index()
# results = results.sort_values(by='date')

metrics = get_all_metrics(results['obs'],results['pred'])
metrics

In [None]:
%matplotlib widget
# NOTE(review): `batch` and `pred` are not assigned in any cell visible in this
# notebook; this cell presumably relies on variables left in the kernel from an
# earlier session — confirm, or define them before use.
plt.close('all')
# Scatter the last index along the final axis of batch['y'] against the same slice of pred.
plt.scatter(batch['y'][...,-1],pred[...,-1])
plt.show()

In [None]:
# Three side-by-side panels are created; only the first two are drawn on here.
# NOTE(review): `batch` is not assigned in any cell visible in this notebook — confirm where it comes from.
fig, axes = plt.subplots(1, 3, figsize=(10,3))
# Panel 0: index 0 along the last axis of batch['x_dd'], shown as an image with its own colorbar.
xd = axes[0].imshow(batch['x_dd'][:,:,0],aspect='auto')
fig.colorbar(xd, ax=axes[0])
# Panel 1: batch['x_s'], shown as an image with its own colorbar.
xs = axes[1].imshow(batch['x_s'],aspect='auto')
fig.colorbar(xs, ax=axes[1]) 

In [None]:
# Inspect the shape of the index-0 slice along the last axis of batch['x_dd'].
# NOTE(review): `batch` is not assigned in any cell visible here.
batch['x_dd'][:,:,0].shape

In [None]:
# NOTE(review): neither `basins` nor `idx_max_err` is assigned in any cell visible here.
# Presumably this looks up the basin with the largest error — confirm and define both before use.
basins[idx_max_err]

In [None]:
# Show the positional-encoding array held by the trainer's model as an image.
positional_encoding = trainer.model.d_encoder.embedder.positional_encoding

pe_fig, pe_ax = plt.subplots(figsize=(10, 8))
pe_ax.imshow(positional_encoding, cmap='viridis')
pe_ax.set_xlabel('Embedding Dimension')
pe_ax.set_ylabel('Position')
pe_ax.set_title('Positional Encodings')
plt.show()