In [None]:
## ***
## NOTE: before running, edit common/torch/ops.py to match your GPU setup
## ***

## versions:
## Python    : 3.11.5
## numpy     : 1.26.0
## torch     : 2.1.0
## pandas    : 2.1.1

# licensed under the Creative Commons - Attribution-NonCommercial 4.0
# International license (CC BY-NC 4.0):
# https://creativecommons.org/licenses/by-nc/4.0/. 

In [1]:
import numpy as np
import pandas as pd
import torch as t
import matplotlib.pyplot as plt
import os

from common.torch.snapshots import SnapshotManager

from covid_hub.data_utils import download_training_data
from covid_hub.forecast import Struct, default_settings, run_tests, generate_ensemble
from covid_hub.forecast import pickle_results, read_pickle, plotpred, output_figs, output_csv
from covid_hub.forecast import load_training, normalize_training, make_training_fn, ensemble_loop, generate_quantiles


In [2]:
# Build the baseline configuration and echo it (last expression -> rich display).
(settings := default_settings())

iterations = 400
init_LR = 0.00025
batch_size = 256
lookback_opts = [3, 4, 5, 6]
random_reps = 5
horizon = 40
data_suffix = 3ma
targ_var = h
sqrt_transform = True
lfn_name = t_nll
force_positive_forecast = False
model_prefix = t_sqrt
normalize_target = False
use_windowed_norm = True
use_static_cat = False
exog_vars = ['doy', 'vacc_rate']
nbeats_stacks = 12
nbeats_hidden_dim = 512
nbeats_dropout = 0.2
encoder_k = 5
encoder_hidden_dim = 128
encoder_dropout = 0.2
qtiles = [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99]
delete_models = False

In [None]:
## If running many tests with limited disk space, uncomment to enable settings.delete_models:
#settings.delete_models = True

In [4]:
# Set to True to download fresh data and build/export the live ensemble forecast.
generate_current_forecast: bool = False

In [None]:
if generate_current_forecast:
    # Refresh the training data so the forecast uses the latest observations.
    download_training_data()

    # No cut: train on all available data.
    cut = None
    rstate = generate_ensemble(settings, cut)
    pickle_results(rstate)
    output_figs(rstate)

    # Days from the end of the most recent data to the expected "day 0" of the forecast.
    forecast_delay = 10
    output_csv(rstate, forecast_delay)

In [3]:
#rstate = Struct()
#load_training(rstate, settings, None)
## '2023-06-24' --> forecast day 0 = 10 day delay from data end (date on covid hub = 8 days after data end)
#np.where(rstate.data_index == '2023-06-24')[0][0] - 7
## <-- '2023-06-10' : forecast day 0 = 2 day delay from data end (date on covid hub = data end date)
#np.where(rstate.data_index == '2022-10-01')[0][0] + 1


1068

In [None]:
# Set to True to run the full weekly back-test sweep (cuts 901-1257) in the next cell.
test_all_2023: bool = False

In [15]:
if test_all_2023:
    ## From '2023-06-24' on, forecast day 0 is 10 days after the data end (the COVID Hub
    ## date is 8 days after the data end). Up to '2023-06-10', forecast day 0 is 2 days
    ## after the data end (the COVID Hub date equals the data end date).
    test_cut_vals1 = [*range(901, 1068, 7)]   # every cut here is < 1068 -> 2-day delay
    test_cut_vals2 = [*range(1068, 1258, 7)]  # every cut here is >= 1068 -> 10-day delay
    test_cut_vals = [*test_cut_vals1, *test_cut_vals2]
    forecast_delay_days = [2 if cut_idx < 1068 else 10 for cut_idx in test_cut_vals]
    run_tests(settings, test_cut_vals, forecast_delay_days)


In [None]:
## Targeted back-test on a hand-picked subset of cuts, one random repetition.
## Cut subsets used in earlier runs (kept for reference only):
##   [1278, 1271, 1264, 1208, 1152, 1145, 1096] + [950, 908, 901, 740, 733, 642]
##   [908, 901, 740, 733]
##   [1187, 1180, 1173, 1131, 1068] + [1062, 1055, 1041, 1020]
##   [929, 922, 845, 831, 817, 810]
test_cut_vals = [1187, 1131, 1062, 1020, 929, 922, 845]

## Cuts from 1068 on use a 10-day delay to forecast day 0; earlier cuts use 2 days.
forecast_delay_days = [10 if x > 1067 else 2 for x in test_cut_vals]

## Later experiment cells overwrite model_prefix/targ_var on the shared settings
## object; restore the defaults so re-running this cell reproduces the same run.
baseline = default_settings()
settings.model_prefix = baseline.model_prefix
settings.targ_var = baseline.targ_var
settings.random_reps = 1
run_tests(settings, test_cut_vals, forecast_delay_days)


In [None]:
## Log-target experiment: single repetition on a few cuts.
settings.random_reps = 1
settings.targ_var = "h_log"
settings.model_prefix = "log"
## NOTE(review): settings.sqrt_transform is not touched here, so it keeps whatever
## value settings already has (True in the default display) - confirm that is
## intended for the "h_log" target.

test_cut_vals = [1020, 922, 1131, 1208, 1264]
## Cuts before 1068 use a 2-day delay to forecast day 0; later cuts use 10 days.
forecast_delay_days = [2 if cut_idx < 1068 else 10 for cut_idx in test_cut_vals]
run_tests(settings, test_cut_vals, forecast_delay_days)


In [None]:
## "sqrt"-prefixed experiment: single repetition on a few cuts.
## Cuts tried previously (reference only):
##   [1278, 1271, 1152, 1145, 1096] + [950, 908, 901]
##   + [1187, 1062, 929, 1180, 1173, 1068, 1055, 1041]
test_cut_vals = [1264, 1208, 1020, 1131]
## Cuts from 1068 on use a 10-day delay to forecast day 0; earlier cuts use 2 days.
forecast_delay_days = [10 if x > 1067 else 2 for x in test_cut_vals]

settings.model_prefix = "sqrt"
## The log-experiment cell sets settings.targ_var = "h_log" on this shared object.
## Restore the default target explicitly so this run does not silently inherit it.
settings.targ_var = default_settings().targ_var
settings.random_reps = 1
run_tests(settings, test_cut_vals, forecast_delay_days)


In [None]:
## Compare the forecast keys stored in rstate with the configured lookback options.
fc_keys = list(rstate.fc_med.keys())
print(fc_keys, '\n', settings.lookback_opts)

In [None]:
## Plot stored forecasts for one location and for the summed (national) series.
## NOTE(review): in this notebook rstate is assigned only in the
## generate_current_forecast cell; run that first (or load a pickled result).
horizon = rstate.settings.horizon
vals_train = rstate.vals_train ## read only
test_targets = rstate.test_targets ## read only
us_train = vals_train["nat_scale"].sum(axis=0, keepdims=True)
us_test = test_targets.sum(axis=0, keepdims=True) if test_targets is not None else None

## Start the plot `plot_window` steps before the cut (or before the end of the
## training data when there is no cut). Clamp at 0 so a short series can't
## produce a negative start index.
plot_window = 400
plot_end = rstate.cut if rstate.cut is not None else vals_train["nat_scale"].shape[1]
x0 = max(0, plot_end - plot_window)

## Which forecast to plot: "median" or one of the keys printed above. Index 3 is
## clamped so fewer stored keys doesn't raise IndexError.
#k = "median"
fc_keys = [*rstate.fc_med.keys()]
k = fc_keys[min(3, len(fc_keys) - 1)]
loc_idx = 20
plotpred(rstate.fc_med, k, loc_idx, vals_train["nat_scale"], test_targets, horizon, rstate.fc_lower, rstate.fc_upper, x0)
plotpred(rstate.us_med, k, 0, us_train, us_test, horizon, rstate.us_lower, rstate.us_upper, x0)


In [None]:
#f = "https://media.githubusercontent.com/media/reichlab/covid19-forecast-hub/master/data-truth/truth-Incident%20Hospitalizations.csv"
#df = pd.read_csv(f,dtype={"location":str})
#d_str = pd.to_datetime(df.date).max().strftime("%m-%d")
#df.to_csv("storage/truth-inc-hosp-" + d_str + ".csv", index=False)

In [None]:
## List the model keys stored in rstate.mu_fc (iteration order).
print(list(rstate.mu_fc))

In [None]:
## Plot the saved training-loss history for the first model key in rstate.mu_fc.
k = next(iter(rstate.mu_fc))
total_iter = rstate.settings.iterations
snapshot_manager = SnapshotManager(snapshot_dir=os.path.join('hub_model_snapshots', k), total_iterations=total_iter)

ldf = snapshot_manager.load_training_losses()
## Bind the figure to `fig`, not `_`: assigning `_` in IPython shadows the output
## cache, and IPython stops updating `_` after that.
fig, ax = plt.subplots(figsize=[4, 3])
ax.plot(ldf)
ax.set(title=f"training losses: {k}", ylabel="loss")
plt.show()