In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Training iterations of the evaluated checkpoints; iteration 0 is the "base" checkpoint.
NUM_ITERATIONS = np.array([0, 20000, 40000, 60000, 80000, 100000])
EVAL_RESULTS_DIR = '../eval_results'


def metrics_csv_path(iteration):
    """Return the evaluation-metrics CSV path for the checkpoint at ``iteration``.

    Iteration 0 maps to the ``train_base`` file; every other iteration maps to
    ``train_<iteration>``.
    """
    tag = 'base' if iteration == 0 else str(iteration)
    return f'{EVAL_RESULTS_DIR}/train_{tag}_metrics.csv'


# Derived from NUM_ITERATIONS so the file list and the x axis can never drift apart;
# the i-th file holds the metrics for NUM_ITERATIONS[i].
result_csv_files = [metrics_csv_path(iteration) for iteration in NUM_ITERATIONS]

In [None]:
from collections import defaultdict

# Per-dataset metric trajectories: each dict maps the value of the CSV's `dataset`
# column -> list of metric values, one entry appended per file in `result_csv_files`.
smape_dict = defaultdict(list)
mase_dict = defaultdict(list)
rmse_dict = defaultdict(list)
wql_dict = defaultdict(list)

# CSV column name -> dict that accumulates that column's values.
metric_column_dicts = {
    'sMAPE': smape_dict,
    'MASE': mase_dict,
    'RMSE': rmse_dict,
    'WQL': wql_dict,
}

for file in result_csv_files:
    print('Processing file:', file)
    df = pd.read_csv(file)
    # Pair each row's dataset name with its metric value, column by column.
    for column, metric_dict in metric_column_dicts.items():
        for dataset_name, value in zip(df['dataset'], df[column]):
            metric_dict[dataset_name].append(value)

# Every dataset must appear exactly once in every file; otherwise its list is
# misaligned with the file order and the plots below break or mislead.
# (All four dicts are filled from the same rows, so checking one suffices.)
misaligned_datasets = {
    name: len(values)
    for name, values in smape_dict.items()
    if len(values) != len(result_csv_files)
}
assert not misaligned_datasets, (
    f'datasets not present exactly once in each of the {len(result_csv_files)} '
    f'result files (name -> count): {misaligned_datasets}'
)


In [None]:
# Tabular view of the sMAPE trajectories: one row per dataset, one column per
# checkpoint iteration. Rich display (pandas truncates long tables) instead of
# printing the raw defaultdict as one long line of text.
smape_table = pd.DataFrame.from_dict(smape_dict, orient='index', columns=NUM_ITERATIONS)
smape_table

In [None]:
# Dataset names in first-seen order (unpacking a dict yields its keys).
[*smape_dict]

In [None]:
def plot_metrics_dict(metrics_dict, title, top_n=None, log_scale=True, iterations=None):
    """Plot one metric across checkpoints, one line per dataset.

    Args:
        metrics_dict: mapping dataset name -> sequence of metric values, one per
            checkpoint, in the same order as the x values.
        title: metric name; used as both the plot title and the y-axis label.
        top_n: plot only the first ``top_n`` datasets in the dict's iteration
            order (``None`` plots all). This is insertion order, not a ranking.
        log_scale: if True, use a logarithmic y axis.
        iterations: x values for the checkpoints. When omitted, the notebook-level
            ``NUM_ITERATIONS`` is used; if that is also ``None``, values are
            plotted against their checkpoint index.
    """
    if iterations is None:
        iterations = NUM_ITERATIONS
    fig, ax = plt.subplots(figsize=(5, 5))
    for key in list(metrics_dict)[:top_n]:
        values = metrics_dict[key]
        if iterations is not None:
            ax.plot(iterations, values, '.-', alpha=0.5, label=key)
        else:
            ax.plot(values, '.-', alpha=0.5, label=key)
    # Axis configuration applies once to the whole figure, not per plotted line.
    if log_scale:
        ax.set_yscale('log')
    ax.set_xlabel('Iterations' if iterations is not None else 'Checkpoint index')
    ax.set_ylabel(title)
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    ax.set_title(title)
    # Avoid matplotlib's "no artists with labels" warning when nothing was plotted.
    if ax.get_lines():
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

In [None]:
# sMAPE for every dataset (top_n is left at its default of None, i.e. all datasets).
plot_metrics_dict(metrics_dict=smape_dict, title='sMAPE')

In [None]:
# sMAPE restricted to the first 10 datasets (dict insertion order) for a readable legend.
plot_metrics_dict(metrics_dict=smape_dict, title='sMAPE', top_n=10)

In [None]:
# MASE for the first 10 datasets.
plot_metrics_dict(metrics_dict=mase_dict, title='MASE', top_n=10)

In [None]:
# RMSE for the first 10 datasets.
plot_metrics_dict(metrics_dict=rmse_dict, title='RMSE', top_n=10)

In [None]:
# WQL for the first 10 datasets.
plot_metrics_dict(metrics_dict=wql_dict, title='WQL', top_n=10)

In [None]:
def plot_agg_metrics_dict(all_metrics_dict, title='Averaged Metrics', top_n=None, log_scale=True, iterations=None):
    """Plot each metric averaged over datasets at every checkpoint, one line per metric.

    Args:
        all_metrics_dict: mapping metric name -> (mapping dataset name -> sequence
            of metric values, one per checkpoint).
        title: plot title and y-axis label.
        top_n: average only over the first ``top_n`` datasets of each metric dict
            (dict iteration order; ``None`` uses all datasets).
        log_scale: if True, use a logarithmic y axis.
        iterations: x values for the checkpoints. When omitted, the notebook-level
            ``NUM_ITERATIONS`` is used; if that is also ``None``, values are
            plotted against their checkpoint index.

    Raises:
        ValueError: if the datasets of one metric have different numbers of
            values (e.g. a dataset missing from some result file).
    """
    if iterations is None:
        iterations = NUM_ITERATIONS
    fig, ax = plt.subplots(figsize=(5, 5))
    for metric_name, metrics_dict in all_metrics_dict.items():
        per_dataset = list(metrics_dict.values())[:top_n]
        if not per_dataset:
            continue
        # np.mean on ragged lists fails with an opaque error; report which metric is broken.
        lengths = {len(values) for values in per_dataset}
        if len(lengths) != 1:
            raise ValueError(
                f'{metric_name}: datasets have differing numbers of checkpoint values {sorted(lengths)}'
            )
        # NOTE: a single NaN value for any dataset makes that checkpoint's average NaN.
        avg_values = np.mean(per_dataset, axis=0)
        # Label by metric in both branches (the index-based branch used to label
        # every metric 'Average', making the legend useless).
        if iterations is not None:
            ax.plot(iterations, avg_values, '.-', alpha=0.5, label=metric_name)
        else:
            ax.plot(avg_values, '.-', alpha=0.5, label=metric_name)
    if log_scale:
        ax.set_yscale('log')
    ax.set_xlabel('Iterations' if iterations is not None else 'Checkpoint index')
    ax.set_ylabel(title)
    ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
    ax.set_title(title)
    if ax.get_lines():
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

In [None]:
# Metric name -> per-dataset trajectories, in plotting/legend order.
all_metrics_dict = dict(
    sMAPE=smape_dict,
    MASE=mase_dict,
    RMSE=rmse_dict,
    WQL=wql_dict,
)

In [None]:
# Every metric averaged over all datasets on a log y axis; the title, top_n and
# log_scale values used here are all the function's defaults.
plot_agg_metrics_dict(all_metrics_dict)

In [None]:
from dysts.flows import Lorenz

# Simulate a Lorenz trajectory; its first coordinate is used below as the series to forecast.
LORENZ_IC = np.array([0.1, 0.0, 5])  # initial condition, one entry per state variable
N_TRAJECTORY_POINTS = 1024           # number of points requested from make_trajectory

model = Lorenz()
# NOTE(review): `gamma` does not appear to be one of dysts' Lorenz parameters
# (which look like sigma/rho/beta); if so this assignment has no effect — confirm and remove.
model.gamma = 1
model.ic = LORENZ_IC
# Assumed to return an array of shape (N_TRAJECTORY_POINTS, 3) — TODO confirm.
sol = model.make_trajectory(N_TRAJECTORY_POINTS)

In [None]:
# Univariate series to forecast: the first state coordinate (column 0) of the trajectory.
X_COLUMN = 0
sol_x = sol[:, X_COLUMN]

In [None]:
import torch
from chronos_dysts.pipeline import ChronosPipeline

In [None]:
# Load model from checkpoint.
# The default is a machine-specific absolute path; set CHRONOS_CHECKPOINT to
# point at a different checkpoint instead of editing this cell.
model_id = os.environ.get(
    "CHRONOS_CHECKPOINT",
    "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/checkpoint-final",
)
device = "cpu"
# NOTE(review): bfloat16 matmuls can be slow on many CPUs and may contribute to the
# slow sampling noted below; float32 is worth benchmarking when running on CPU.
torch_dtype = torch.bfloat16
print(f"Loading Chronos checkpoint: {model_id} onto device: {device}")
pipeline = ChronosPipeline.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch_dtype,
)

In [None]:
# Forecast window sizes in trajectory time steps: the model is given the first
# `context_length` points of sol_x and asked for the next `prediction_length`.
context_length, prediction_length = 512, 64

In [None]:
# TODO: this takes a long time, how about we multiprocess it?
# NOTE: about 2.5 min for 1 sample on CPU, seemingly scales linearly with num_samples
NUM_SAMPLES = 20   # number of forecast sample paths to draw
FORECAST_SEED = 0  # seed for reproducible sampling across re-runs

# Sampling is stochastic; seed torch's global RNG so Restart & Run All reproduces
# the same paths (assumes pipeline.predict draws from torch's global RNG — confirm).
torch.manual_seed(FORECAST_SEED)
forecast = pipeline.predict(
    context=torch.tensor(sol_x[:context_length]),
    prediction_length=prediction_length,
    num_samples=NUM_SAMPLES,
)

In [None]:
# Inspect the forecast's dimensions (np.shape returns the object's own `.shape`).
np.shape(forecast)

In [None]:
# Ground truth vs. forecast over the prediction window.
# With several samples, `forecast.squeeze()` is 2-D (samples x steps), and plt.plot
# draws one line per *column*, i.e. across samples instead of across time. So
# summarize the samples per time step instead: median plus a 10-90% band.
true_future = sol_x[context_length:context_length + prediction_length]
# Assumes forecast converts to a NumPy array whose last axis has length prediction_length — confirm.
forecast_paths = np.asarray(forecast).reshape(-1, prediction_length)  # (n_paths, prediction_length)
q10, q50, q90 = np.quantile(forecast_paths, [0.1, 0.5, 0.9], axis=0)

fig, ax = plt.subplots()
ax.plot(true_future, color='k', label="True")
ax.plot(q50, label="Forecast (median)")
ax.fill_between(np.arange(prediction_length), q10, q90, alpha=0.3, label="Forecast 10-90%")
ax.set_xlabel("Steps after context")
ax.set_ylabel("x (first state coordinate)")
ax.set_title("Forecast vs. truth")
ax.legend();

In [None]:
# One sampled path per column: squeeze out singleton dims, then transpose
# (assumes the squeezed forecast is (num_samples, prediction_length) — TODO confirm).
all_forecasts = forecast.squeeze()
all_forecasts = all_forecasts.T
print(np.shape(all_forecasts))

In [None]:
# Every sampled forecast path (one column of `all_forecasts` each) against the truth.
fig, ax = plt.subplots()
ax.plot(sol_x[context_length:context_length + prediction_length], color='k', label="True")
sample_lines = ax.plot(all_forecasts, alpha=0.5, linewidth=1)
# Label only one sample line so the legend shows a single entry, not one per sample.
if sample_lines:
    sample_lines[0].set_label("Forecast samples")
ax.set_xlabel("Steps after context")
ax.set_ylabel("x (first state coordinate)")
ax.set_title("Sampled forecast paths vs. truth")
ax.legend();

In [None]:
# Full simulated series, with the model's context window and the forecast horizon shaded.
fig, ax = plt.subplots()
ax.plot(sol_x, color='k', label="True")
ax.axvspan(0, context_length, color='tab:blue', alpha=0.1, label="Context")
ax.axvspan(context_length, context_length + prediction_length,
           color='tab:orange', alpha=0.2, label="Forecast horizon")
ax.set_xlabel("Time step")
ax.set_ylabel("x (first state coordinate)")
ax.set_title("Lorenz trajectory")
ax.legend();