In [1]:
import sys
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
from pathlib import Path
from importlib import reload

# Make the project's ../src directory importable. Insert it at the FRONT of sys.path so the
# generically named project modules imported below (`config`, `train`) take precedence over
# any same-named module installed in site-packages (append would give them lowest priority).
# The membership check keeps re-runs of this cell from adding duplicate entries.
src = str(Path('../src').resolve())
if src not in sys.path:
    sys.path.insert(0, src)

from config import read_config
from train import Trainer

In [3]:
import train
reload(train)  # pick up edits to the train module without restarting the kernel
from train import Trainer

log_dir = Path('../runs/notebook/20240603_1359/')
# NOTE(review): pickle.load can execute arbitrary code; only load config files written by our own runs.
with open(log_dir / 'config.pkl', 'rb') as file:
    cfg = pickle.load(file)
cfg['num_workers'] = 0  # presumably: no data-loader worker processes in the notebook, confirm
cfg['log'] = False

trainer = Trainer(cfg, None)
# Point the trainer at the existing run; presumably load_state resolves paths relative to it (confirm).
trainer.log_dir = log_dir

# Stems of the saved exception states, sorted so that anomalies[0] refers to the same
# entry on every run (Path.glob yields entries in filesystem-dependent order).
anomalies = sorted(p.stem for p in (log_dir / "exceptions").glob('*'))


In [6]:
from collections import Counter

# Pool the sample metadata of every saved exception batch, then tally how often each
# site, date and (site, date) pair occurs across all of them.
all_pairs, all_sites, all_dates = [], [], []
for a in anomalies:
    basins, dates, _ = trainer.load_state("exceptions/" + a)
    all_sites += basins
    all_dates += dates
    batch_pairs = list(zip(basins, dates))
    all_pairs += batch_pairs

site_counts, date_counts, pair_counts = map(Counter, (all_sites, all_dates, all_pairs))

# Keep only the sites, dates and pairs that occur more than once.
common_sites, common_dates, common_pairs = (
    {key: n for key, n in tally.items() if n > 1}
    for tally in (site_counts, date_counts, pair_counts)
)


# Function to print sorted results
def print_sorted_counts(title, counts):
    sorted_counts = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    print(title)
    for item, count in sorted_counts:
        if count > (len(anomalies)*0.5):
            print(f"{item}: {count}")

# Report the frequently recurring sites, dates and site/date pairs.
for heading, tally in (
    ("Common Sites:", common_sites),
    ("Common Dates:", common_dates),
    ("Common Location/Date Pairs:", common_pairs),
):
    print_sorted_counts(heading, tally)

Common Sites:
USGS-01335770: 668
USGS-05325000: 277
USGS-01331095: 262
USGS-04193500: 250
USGS-01357500: 240
USGS-05474000: 238
USGS-06486000: 184
USGS-08354900: 177
USGS-04198000: 174
USGS-11303500: 155
USGS-05559600: 153
USGS-08358400: 151
USGS-04084445: 149
USGS-05586100: 147
USGS-04192500: 138
USGS-431510077363501: 131
USGS-05465500: 123
USGS-08330000: 121
USGS-08384500: 112
USGS-01578310: 109
USGS-06610000: 102
USGS-04108660: 93
USGS-12340500: 87
USGS-08332010: 84
USGS-04085059: 81
USGS-12334550: 78
USGS-09152500: 74
USGS-04085139: 74
USGS-12340000: 72
USGS-09095500: 72
USGS-05543500: 71
USGS-14243000: 68
USGS-07241550: 68
USGS-04069530: 68
USGS-02338000: 66
USGS-09261000: 63
USGS-06329500: 61
USGS-03374100: 61
USGS-08355490: 61
USGS-04102533: 61
USGS-04067651: 60
USGS-01327755: 59
USGS-04087170: 59
USGS-040851385: 58
USGS-06818000: 55
USGS-09251000: 55
USGS-04120250: 52
USGS-06452000: 52
USGS-09260000: 51
USGS-05532500: 51
USGS-02387000: 50
USGS-07288955: 50
USGS-02226010: 50
USG

In [None]:
import equinox as eqx
@eqx.filter_jit
def _predict_map(model, batch, keys):
    """Apply ``model`` to every sample of ``batch``, pairing each sample with its own key.

    ``jax.vmap`` maps over the leading axis of both ``batch`` and ``keys``; the whole
    batched call is compiled via ``eqx.filter_jit``.
    """
    batched_model = jax.vmap(model)
    return batched_model(batch, keys)

basins, dates, batch = trainer.load_state("exceptions/" + anomalies[0])

# One PRNG key per sample. jax.vmap maps over the leading axis of every leaf in `batch`,
# so the key count must equal that axis. Take it from the batch itself rather than
# cfg['batch_size'], because a saved batch may be smaller than the configured size.
key = jax.random.PRNGKey(0)
batch_keys = jax.random.split(key, batch['x_dd'].shape[0])
pred = _predict_map(trainer.model, batch, batch_keys)

# Select the sample with the largest ABSOLUTE error in the last element of the final axis.
# Using the signed error would miss over-predictions. err may carry extra axes after the
# batch axis, so the flat argmax is mapped back to a batch index.
# NOTE(review): np.argmax returns the first NaN if err contains any (from y or pred).
err = batch['y'][..., -1] - pred[..., -1]
abs_err = np.abs(np.asarray(err))
idx_max_err = int(np.unravel_index(np.argmax(abs_err), abs_err.shape)[0])

x = batch['x_dd'][idx_max_err, ...]

plt.close('all')
fig, ax = plt.subplots()
ax.plot(x)
ax.set(
    title=f"x_dd for {basins[idx_max_err]} @ {dates[idx_max_err]} (sample {idx_max_err})",
    xlabel="index along axis 1 of x_dd",
    ylabel="value",
)
plt.show()

In [None]:
# Raw model predictions for the batch loaded from anomalies[0].
pred

In [None]:
# Basin ID of the sample at idx_max_err (the sample plotted in the prediction cell).
basins[idx_max_err]

In [None]:
%matplotlib widget
plt.close('all')
plt.scatter(batch['y'][...,-1],pred[...,-1])
plt.show()

In [None]:
# Batch-wide view of the inputs: each row is one sample of the batch.
# Two panels only; the previous 1x3 grid left an empty third axis.
# NOTE(review): assumes axis 1 of x_dd is time, confirm.
plt.close('all')
fig, axes = plt.subplots(1, 2, figsize=(10, 3))
xd = axes[0].imshow(batch['x_dd'][:, :, 0], aspect='auto')
fig.colorbar(xd, ax=axes[0])
axes[0].set(title="x_dd[:, :, 0]", xlabel="axis 1 index", ylabel="batch sample")
xs = axes[1].imshow(batch['x_s'], aspect='auto')
fig.colorbar(xs, ax=axes[1])
axes[1].set(title="x_s", xlabel="column index", ylabel="batch sample")
# End with show() so the cell does not display a stray Colorbar repr.
plt.show()

In [None]:
# Shape of the first channel (index 0 along axis 2) of x_dd across the batch.
np.shape(batch['x_dd'][:, :, 0])