In [None]:
# This is required to run multiple processes on Unity for some reason.
from multiprocessing import set_start_method
try:
    set_start_method('spawn')
except RuntimeError:
    # set_start_method raises RuntimeError once the start method has been
    # fixed (e.g. when this cell is re-run in a live kernel). That case is
    # harmless; anything else (KeyboardInterrupt, ValueError for a bad method
    # name, ...) should still surface instead of being silently swallowed.
    pass

# Disable CUDA graphs
# NOTE(review): presumably this has to be set before JAX is imported in the
# following cell for the flag to be picked up - confirm.
import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_command_buffer='

In [None]:
%matplotlib widget
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
import pickle
from pathlib import Path
from importlib import reload

# Make the project's ../src directory importable. The path is resolved to an
# absolute string, and the membership check keeps re-runs of this cell from
# appending duplicates to sys.path.
src = str(Path('../src').resolve())
if src not in sys.path:
    sys.path.append(src)

from train import load_last_state
from data import TAPDataset, TAPDataLoader
from evaluate import *

In [None]:
# Re-import the project `data` module so edits to it are picked up without
# restarting the kernel, then refresh the two names taken from it.
import data
reload(data)
from data import TAPDataset, TAPDataLoader

# Restore the most recent saved state of the selected run.
# NOTE(review): the run directory is hard-coded; point it at the run to inspect.
# NOTE(review): the last recorded output of this cell is a NotADirectoryError
# for '../runs/fusion/6_head_6_layer.yml' - confirm the referenced paths exist.
run_dir = Path("../runs/hybrid_multitarget_0.5d_064_030/6_head_6_layer_all_sat_all_target_20240620_150312/flux_finetune_20240626_131838")
cfg, model, trainer_state, opt_state, _ = load_last_state(run_dir)
cfg['quiet'] = False

# Output folders for figures. mkdir is called without parents=True, so run_dir
# itself must already exist; exist_ok=True makes re-running the cell safe.
fig_dir = run_dir / "figures"
fig_dir.mkdir(exist_ok=True)

ts_dir = fig_dir / "timeseries"
ts_dir.mkdir(exist_ok=True)

# Build the dataset from the restored configuration.
dataset = TAPDataset(cfg)


NotADirectoryError: [Errno 20] Not a directory: '../runs/fusion/6_head_6_layer.yml'

In [None]:
# Point the shared cfg at the test split (no basin filter) and build a loader
# over the dataset created above. Note that cfg is mutated in place, so these
# settings persist into later cells until they are overwritten again.
cfg['data_subset'] = 'test'
cfg['num_workers'] = 1
cfg['basin_subset'] = None
cfg['batch_size'] = 1048  # NOTE(review): possibly meant 1024 - confirm
dataloader = TAPDataLoader(cfg, dataset)

# Run inference, then compute the aggregate and per-basin metrics using the
# helpers star-imported from the project's `evaluate` module.
results = predict(model, dataloader, seed=0, denormalize=True, return_dt=True)
bulk_metrics = get_all_metrics(results, False)
basin_metrics = get_basin_metrics(results, True)

# Cache predictions and metrics next to the run so that later sessions can
# reload them instead of repeating inference.
with open(run_dir / f"{cfg['data_subset']}_data.pkl", 'wb') as f:
    pickle.dump((results, bulk_metrics, basin_metrics), f)

In [None]:
# Reload the cached (results, bulk_metrics, basin_metrics) tuple for the test
# split. The file name is hard-coded here, whereas the writing cell derives it
# from cfg['data_subset'].
# NOTE: pickle.load can execute arbitrary code; only open files produced by
# this project.
with open(run_dir / "test_data.pkl", 'rb') as f:
    results, bulk_metrics, basin_metrics = pickle.load(f)

In [None]:
# Display the restored model's repr to inspect its structure.
model

In [None]:
# Median of the 'MAPE' column for the 'flux' target in the per-basin metrics
# (presumably one row per basin - confirm against get_basin_metrics).
basin_metrics['flux']['MAPE'].median()

In [None]:
# Import the module object explicitly before reloading it: the earlier
# `from evaluate import *` does not bind the module itself to the name
# `evaluate`, so on a fresh kernel `reload(evaluate)` would fail.
import evaluate
reload(evaluate)
from evaluate import mosaic_scatter

# Draw the scatter mosaic for the current results; the run directory string is
# passed as the final argument to the project helper.
plt.close('all')
fig = mosaic_scatter(cfg, results, bulk_metrics, str(run_dir))

plt.show()
# fig.savefig(fig_dir / f"epoch{trainer_state['epoch']:03d}_{cfg['data_subset']}_density_scatter.png",  dpi=300)


In [None]:
# Import the module object explicitly before reloading it: a star import does
# not bind the module itself to the name `evaluate`, so on a fresh kernel
# `reload(evaluate)` would fail.
import evaluate
reload(evaluate)
from evaluate import *

# Per-metric options handed to basin_metric_histograms.
# NOTE(review): 'range' is presumably the histogram x-axis range - confirm
# against the helper's implementation.
metric_args = {
    'nBias':{'range':[-1,1]},
    'rRMSE':{'range':[0,500]},
    'KGE':{'range':[-2,1]},
    'NSE':{'range':[-5,1]},
    'Agreement':{'range':[0,1]}}

plt.close('all')

# The returned object is iterated with .items() in the commented block below,
# i.e. it is expected to be a mapping of target name -> figure.
figs = basin_metric_histograms(basin_metrics, metric_args)

# for target, fig in figs.items():
    # fig.show()
    # fig.savefig(fig_dir / f"epoch{trainer_state['epoch']:03d}_{cfg['data_subset']}_{target}_metrics_hist.png",  dpi=300)
    

In [None]:
import matplotlib.gridspec as gridspec

def timeseries_plot(results, feature):
    """Draw an observed-vs-predicted scatter and a time series for one feature.

    Besides its arguments, the function reads the notebook-level names
    ``basin``, ``basin_metrics`` and ``run_dir`` (plus ``plt`` and
    ``gridspec``), so those must be defined before it is called.

    Args:
        results: DataFrame with two-level columns whose first level contains
            'obs' and 'pred' and whose second level contains ``feature``.
        feature: name of the feature (second column level) to plot.

    Returns:
        The created matplotlib figure, or None when fewer than two rows have
        both an observation and a prediction for ``feature``.
    """
    predicted = results['pred'][feature]
    observed = results['obs'][feature]
    # Rows where both observation and prediction are present.
    paired = results[[('obs', feature), ('pred', feature)]].dropna()

    if len(paired) < 2:
        return None

    # Two panels separated by a narrow empty spacer column.
    fig = plt.figure(figsize=(10, 4))
    grid = gridspec.GridSpec(1, 3, width_ratios=[1, 0.2, 2])
    ax_scatter = fig.add_subplot(grid[0, 0])
    ax_series = fig.add_subplot(grid[0, 2])

    # Shared extent of the paired values, used for the 1:1 reference line.
    lo = paired.min().min()
    hi = paired.max().max()

    # Left panel: log-log scatter with the 1:1 line, square axes.
    ax_scatter.loglog(observed, predicted, linestyle='None', marker='.', alpha=0.2)
    ax_scatter.plot([lo, hi], [lo, hi], 'r--')
    ax_scatter.axis('square')
    ax_scatter.set_xlabel(f"Observed {feature}")
    ax_scatter.set_ylabel(f"Predicted {feature}")
    ax_scatter.set_title(f"Basin: {basin}")

    # Right panel: predictions as a line, observations as markers.
    ax_series.plot(predicted)
    ax_series.plot(observed, linestyle='None', marker='.', alpha=0.5)
    ax_series.set_title(f"{feature}")

    # Metrics for the current global basin, shown in the top-right corner.
    metric_lines = []
    for key in ['rRMSE', 'KGE', 'NSE', 'Agreement']:
        metric_lines.append(f"{key}: {basin_metrics.loc[basin][feature][key]:0.2f}")
    textstr = '\n'.join(metric_lines)
    box_style = dict(boxstyle='round', facecolor='white', alpha=0.5)
    ax_series.text(
        0.98, 0.97, textstr,
        transform=ax_series.transAxes,
        fontsize=10,
        va='top',
        ha='right',
        bbox=box_style,
    )

    fig.suptitle(str(run_dir))
    fig.tight_layout()
    fig.autofmt_xdate()

    return fig

In [None]:
# Number of entries in basin_dataset.all_basins.
# NOTE(review): `basin_dataset` is first assigned in a later cell, so this
# cell raises NameError under Restart & Run All.
len(list(basin_dataset.all_basins))

In [None]:
import evaluate
reload(evaluate)
from evaluate import predict, get_all_metrics

# Pick one test basin at random (unseeded); swap in one of the fixed IDs below
# for a reproducible figure.
basin = np.random.choice(dataset.test_basins).tolist()
# basin = 'USGS-09367540'
# basin = 'USGS-06109500' #no flux but nice seasonality
# basin = 'USGS-08332010' #Nice flux temporal distribution and seasonality

# Reconfigure the shared cfg for a single-basin 'predict' run.
cfg['data_subset'] = 'predict'
cfg['basin_subset'] =  basin
cfg['num_workers'] = 1 # Faster for small runs
basin_dataset = TAPDataset(cfg)
dataloader = TAPDataLoader(cfg, basin_dataset)

# NOTE: this rebinds `results`, replacing the test-split predictions.
results = predict(model, dataloader, seed=0, denormalize=True, return_dt=True)
results['pred'] = results['pred'] * (results['pred']>0) #Clip predictions to 0

# Re-index by date only: flatten the index, order chronologically, and drop
# the top-level 'basin' column.
results = (
    results
    .reset_index()
    .sort_values(by='date')
    .drop('basin', axis=1, level=0)
    .set_index('date')
)

plt.close('all')
features = results.columns.get_level_values('Feature').unique()

# One figure per feature; timeseries_plot returns None when there is too
# little paired data to draw.
figs = []
for feature in features:
    fig = timeseries_plot(results, feature)
    if fig is not None:
        # fig.savefig(ts_dir / f"epoch{trainer_state['epoch']:03d}_{cfg['data_subset']}_{basin}_{feature}_timeseries.png",  dpi=300)
        figs.append(fig)
if not figs:
    print("No data in basin and period")

In [None]:
# Grab only the first (basin, date, batch) triple from the current dataloader.
# NOTE: this rebinds the notebook-level `basin` name that timeseries_plot reads.
for basin, date, batch in dataloader:
    break
# Display the 'x_s' entry of that batch.
batch['x_s']

In [None]:
# Split a fixed PRNG key into one key per entry along the first axis of x_s.
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, batch['x_s'].shape[0])

# Apply the static embedder entry-by-entry via vmap, pairing each entry with
# its own key.
embedded = jax.vmap(model.static_embedder)(batch['x_s'], keys)

In [None]:
# Show the embedder output for the first batch entry as an image.
plt.close('all')
plt.imshow(embedded[0,...],aspect='auto')

In [None]:
# Plot the transposed weight matrix of the model head, using the configured
# target feature names as legend labels.
plt.close('all')
plt.plot(model.head.weight.T, label=cfg['features']['target'])
plt.legend()


In [None]:
from train import Trainer, make_step

# Reconfigure the shared cfg for the test split and build a Trainer around a
# fresh loader; the trainer supplies the optimizer and filter spec used below.
cfg['data_subset'] = 'test'
cfg['num_workers'] = 2
cfg['basin_subset'] = None
cfg['log'] = False
cfg['quiet'] = False
dataloader = TAPDataLoader(cfg, dataset)
trainer = Trainer(cfg, dataloader)

# NOTE: step_kwargs aliases cfg['step_kwargs'], so the assignment below also
# changes cfg in place.
step_kwargs = cfg['step_kwargs']
step_kwargs['max_grad_norm'] = None

# Take only the first batch (rebinds the notebook-level `basin`, `date`, `batch`).
for basin, date, batch in dataloader:
    break

# One PRNG key per entry of `basin`.
key = jax.random.PRNGKey(0)
batch_keys = jax.random.split(key, len(basin))

# Run a single step on that batch.
# NOTE: `model` and `opt_state` are rebound to the step's outputs, so every
# re-run of this cell starts from the previous run's result.
loss, grads, model, opt_state = make_step(
    model, 
    batch,
    batch_keys,
    opt_state, 
    trainer.optim, 
    trainer.filter_spec, 
    **step_kwargs
    )

In [None]:
# Recompute the static embedding of the current batch with the current model.
key = jax.random.PRNGKey(0)
# One PRNG key per entry along the first axis of x_s (the original line had a
# duplicated assignment target, `keys = keys = ...`).
keys = jax.random.split(key, num=batch['x_s'].shape[0])
embedded = jax.vmap(model.static_embedder)(batch['x_s'], keys)

# Show the embedding of the first batch entry as an image.
plt.close('all')
plt.imshow(embedded[0,...])

In [None]:
# Histogram of the static embedder's layer-norm weights.
plt.close('all')
plt.hist(model.static_embedder.layernorm.weight)
plt.show()

In [None]:
# Display the current model's repr (rebound by the make_step cell if it was run).
model

In [None]:
# Display the decoder's pooler submodule.
model.decoder.pooler

In [None]:
# Shape of the static embedder's linear weight matrix.
model.static_embedder.linear.weight.shape

In [None]:
# UMAP of the transposed static-embedder weight matrix, labelled with the
# keys of dataset.attributes_scale.
# NOTE(review): plot_umap is defined in a later cell, so this cell raises
# NameError under Restart & Run All unless that cell has been run first.
mat_plot =  model.static_embedder.linear.weight.T
labels = dataset.attributes_scale.keys()
plot_umap(mat_plot, 16, labels)


In [None]:
import umap

def plot_umap(mat, neighbors, labels=(), components=2):
    """Reduce the rows of ``mat`` with UMAP and scatter the first two components.

    Closes all open matplotlib figures before drawing.

    Args:
        mat: array-like passed to ``UMAP.fit_transform``.
        neighbors: forwarded to UMAP as ``n_neighbors``.
        labels: optional iterable of text labels paired with the embedded rows
            in order (via ``zip``, so a shorter iterable leaves the remaining
            points unlabelled).
        components: forwarded to UMAP as ``n_components``; must be at least 2.
            When the embedding has a third column it is used to colour the
            points.

    Returns:
        Tuple ``(fig, embedding_nd)`` with the figure and the UMAP embedding.

    Raises:
        ValueError: if ``components`` is smaller than 2.
    """
    if components < 2:
        # The scatter below needs two coordinates; fail before the costly fit.
        raise ValueError("plot_umap needs components >= 2 to draw a 2D scatter")

    reducer = umap.UMAP(n_neighbors=neighbors, n_components=components, metric='euclidean')
    embedding_nd = reducer.fit_transform(mat)

    # Plot using matplotlib
    plt.close('all')
    fig, ax = plt.subplots(figsize=(6,6))

    # Column index 2 only exists when at least three components were fitted.
    # The previous `components>=2` test indexed it for the default
    # two-component embedding and raised IndexError.
    c = embedding_nd[:, 2] if embedding_nd.shape[1] >= 3 else None
    ax.scatter(embedding_nd[:, 0], embedding_nd[:, 1], c=c)
    ax.set_title('2D Projection of Embeddings')

    for xy, label in zip(embedding_nd, labels):
        ax.text(xy[0], xy[1], label)

    return fig, embedding_nd


In [None]:
import equinox as eqx

# Rebind `model` to its inference-mode version before computing embeddings.
model = eqx.nn.inference_mode(model)
key = jax.random.PRNGKey(0)
basins = []
embeddings = []
# Embed every training basin's static data; the same PRNG key is passed on
# each call. NOTE: the loop variable rebinds the notebook-level `basin`.
for basin in dataset.train_basins:
    basins.append(basin)
    static_data = dataset.x_s[basin]
    # embeddings.append(model.tealstm_i.cell.input_linear(static_data).flatten())
    embeddings.append(model.static_embedder(static_data, key).flatten())
# Stack into one array with one flattened embedding per basin (row order
# matches `basins`).
embeddings = np.stack(embeddings)

In [None]:
from sklearn.preprocessing import MinMaxScaler

# Three-component UMAP of the per-basin embeddings, without point labels.
fig, embedding_nd = plot_umap(embeddings, 64, [], 3)

# Normalize the UMAP embeddings to the range [0, 1]
# NOTE(review): embedding_nd_norm is not referenced by any later visible cell.
scaler = MinMaxScaler()
embedding_nd_norm = scaler.fit_transform(embedding_nd)

In [None]:
# Root-mean-square of each row of the UMAP embedding. Note this is computed
# from `embedding_nd` (the UMAP output), not from the raw `embeddings` array.
embeddings_mag = np.sqrt(np.mean(embedding_nd**2,axis=1))
# embeddings_mag

In [None]:
import geopandas as gpd
# NOTE(review): absolute, cluster-specific path; consider moving it to a
# configurable data directory.
wqp_locs = gpd.read_file("/work/pi_kandread_umass_edu/tss-ml/data/NA_WQP/metadata/wqp_sites.shp")
# Keep only the sites that were embedded above. Series.isin builds the boolean
# mask in one vectorised pass instead of scanning the `basins` list once per row.
wqp_locs = wqp_locs[wqp_locs.LocationID.isin(basins)]
wqp_locs = wqp_locs.set_index('LocationID')

In [None]:
# Table of UMAP coordinates: one column per component (UMAP0, UMAP1, ...),
# one row per basin in the order the embeddings were computed.
columns = [f"UMAP{i:d}" for i in range(embedding_nd.shape[1])]
df = pd.DataFrame(embedding_nd, columns=columns, index=basins)

# Attach the coordinates to the site geometries by matching on the index of
# both frames.
gdf_embeddings = wqp_locs.merge(
    df,
    left_index=True,
    right_index=True,
)

In [None]:
# Plot the sites coloured by the chosen UMAP component column.
col = 'UMAP2'
gdf_embeddings.plot(col)
plt.title(col)

In [None]:
# Write the sites with their UMAP columns to a shapefile under the run's
# figures/embeddings folder (created here if missing).
embed_dir = fig_dir / "embeddings"
embed_dir.mkdir(exist_ok=True)
gdf_embeddings.to_file(embed_dir / "sites_umap.shp")

In [None]:
# Plot the bias vector of model.tealstm_i.cell.input_linear.
# NOTE(review): other cells access `model.static_embedder` instead and the
# `tealstm_i` call is commented out there - confirm this attribute exists on
# the loaded model.
plt.close('all')
# plt.imshow(model.tealstm_i.cell.input_linear.weight,aspect='auto')
plt.plot(model.tealstm_i.cell.input_linear.bias)

In [None]:

# Two pairs of site IDs whose embeddings are compared below.
ms = 'USGS-07289000'
ms_t = 'USGS-07288955'

oh = 'USGS-03612600'
oh_t = 'USGS-03438500'


def embedder(b):
    """Apply model.tealstm_i.cell.input_linear to the static data of basin ``b``."""
    # Alternative: model.static_embedder(dataset.x_s[b], key)
    return model.tealstm_i.cell.input_linear(dataset.x_s[b])


plt.close('all')
# plt.plot(embedder(ms))
# plt.plot(embedder(oh))

# Element-wise difference between the embeddings of each pair.
ms_delta = embedder(ms) - embedder(ms_t)
oh_delta = embedder(oh) - embedder(oh_t)
plt.plot(ms_delta)
plt.plot(oh_delta)



In [None]:
# UMAP of the first 64 rows of ealstm_d's weight_ih matrix (transposed),
# labelled with the configured daily feature names.
# NOTE(review): the 64-row slice is a hard-coded size - confirm it matches
# the intended block of the weight matrix.
mat_plot =  model.ealstm_d.cell.weight_ih[0:64,:].T
labels = cfg['features']['daily']
plot_umap(mat_plot, 8, labels)


In [None]:
# Sum of absolute values of `embedded` over its first axis, and the indices
# that sort those sums in descending order.
static_sums = np.sum(np.abs(embedded),axis=0)
sorted_ids = np.argsort(-static_sums)

# Print the 25 largest sums next to the attribute name at the same position.
# NOTE(review): assumes static_sums is 1-D with one entry per key of
# dataset.attributes_scale and at least 25 entries - confirm, since `embedded`
# is produced by the static embedder in an earlier cell.
static_features = list(dataset.attributes_scale.keys())
for i in range(25):
    idx = sorted_ids[i]
    print(f"{static_features[idx]}: {static_sums[idx]}")

In [None]:
# Histogram of the summed absolute values computed in the previous cell.
plt.figure()
plt.hist(static_sums)

In [None]:
# Show `static_embeddings` as an image.
# NOTE(review): `static_embeddings` is not assigned anywhere in the visible
# notebook, so this cell raises NameError on a fresh kernel.
plt.close('all')
plt.imshow(static_embeddings, aspect='auto')
plt.show()

In [None]:
# Display the attention submodule of the first decoder layer.
model.decoder.layers[0].attention_block.attention

In [None]:
plt.close('all')

fig, ax = plt.subplots(figsize=(6,6))

x = results['obs']
y = results['pred']

# Compute the 'lNSE' metric separately for each absolute value of the 'dt'
# column; the final bin collects everything at or beyond dt_max.
# NOTE(review): get_all_metrics is called here as (obs, pred), whereas the
# test-evaluation cell calls it as (results, False) - confirm the signature.
dt_max = 10
abs_dt = np.abs(results['dt'])
z = []
for dt in range(dt_max+1):
    mask = (abs_dt == dt) if dt < dt_max else (abs_dt >= dt)
    metrics = get_all_metrics(x[mask], y[mask])
    z.append(metrics['lNSE'])

ax.plot(z)

# Tick every second bin and mark the last one as open-ended.
xticks = range(0,dt_max+1,2)
xtick_labels = ax.get_xticklabels()
new_labels = []
for tick in xticks:
    text = str(int(tick))
    if tick == dt_max:
        text = "≥" + text
    new_labels.append(text)
ax.set_xticks(xticks)
ax.set_xticklabels(new_labels)

plt.show()