In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.base import DynSys
from scipy.integrate import solve_ivp
from tqdm import trange

from dystformer.augmentations import StandardizeTransform
from dystformer.chronos.pipeline import ChronosPipeline
from dystformer.patchtst.pipeline import PatchTSTPipeline

In [None]:
# Layer the project style sheet on top of ggplot, but only if the file exists.
custom_style_path = "../custom_style.mplstyle"
if os.path.exists(custom_style_path):
    plt.style.use(["ggplot", custom_style_path])

## KS Equation

In [None]:
class KuramotoShivashinsky(DynSys):
    """Implements the 1+1D KS equation in fourier space

    The state vector has length ``2 * modes``: the first ``modes`` entries hold
    the real parts and the last ``modes`` entries the imaginary parts of the
    Fourier coefficients carried by the model.

    :param L: length scale used to build the wave numbers ``2 * pi * k / L``
    :param modes: number of complex Fourier coefficients in the state
    """

    def __init__(self, L: float, modes: int):
        """Store the parameters and precompute the spectral multipliers used by ``rhs``."""
        super().__init__(metadata_path=None, dimension=2 * modes, parameters={})
        self.L = L
        self.modes = modes
        self.dimension = 2 * self.modes
        # modes + 2 wave numbers: k = 0 .. modes + 1
        self.wave_nums = 2 * np.pi * np.arange(0, self.modes + 2) / self.L
        # grid size used by rhs; rfft of N = 2 * (modes + 1) real samples
        # returns modes + 2 coefficients, matching wave_nums
        self.N = self.dimension + 2

        # precompute some quantities
        # scratch buffer reused by rhs; only entries 1..-2 are ever written,
        # so the first and last coefficients stay zero
        self.freq_domain = np.zeros(self.modes + 2, dtype=np.complex128)
        # multiplier applied to rfft(u * u) in rhs
        # NOTE(review): the factor N presumably compensates numpy's irfft
        # normalisation — confirm
        self.nonlinear_factor = -0.5 * 1j * self.wave_nums * self.N
        # multiplier of the linear term: k^2 - k^4
        self.diffusion_ffts = self.wave_nums**2 - self.wave_nums**4

    def to_spatial(self, q: np.ndarray, N: int) -> np.ndarray:
        """Inverse FFT of the modes to get u(x) at a certain time

        :param q: array of flattened fourier coefficients (real and imag components), can have batch dimensions
        :param N: grid resolution in the spatial domain

        :returns: solution in the spatial domain
        """
        # pad with a zero coefficient at both ends of the last axis
        coeffs = np.zeros(q.shape[:-1] + (self.modes + 2,), dtype=complex)
        coeffs[..., 1:-1] = q[..., : self.modes] + 1j * q[..., self.modes :]
        return np.fft.irfft(coeffs, n=N)

    def rhs(self, t: float, X: np.ndarray) -> np.ndarray:
        """Time derivative of the flattened state ``X``.

        ``t`` is not used. Each call overwrites the shared buffer
        ``self.freq_domain``.

        :param t: time (unused)
        :param X: flattened state, real parts followed by imaginary parts

        :returns: flattened derivative, in the same layout as ``X``
        """
        # rebuild the complex coefficients inside the zero-padded buffer
        self.freq_domain[1:-1] = X[: self.modes] + 1j * X[self.modes :]
        u = np.fft.irfft(self.freq_domain, n=self.N)
        # the quadratic term is formed on the grid and transformed back
        pseudospectral_term = self.nonlinear_factor * np.fft.rfft(u * u)
        linear_term = self.diffusion_ffts * self.freq_domain

        # repackage components
        # drop the two padded coefficients, then split into real / imag halves
        flow = (linear_term + pseudospectral_term)[1:-1]
        return np.concatenate([np.real(flow), np.imag(flow)])

In [None]:
# Integrate one KS trajectory starting from a small random initial condition.
ks = KuramotoShivashinsky(L=100, modes=64)

tfinal = 100
rng = np.random.default_rng(12)  # 1234
ic = 0.1 * rng.normal(size=(ks.dimension,))
teval = np.linspace(0, tfinal, 4096)
sol = solve_ivp(
    fun=ks.rhs,
    t_span=(0, tfinal),
    y0=ic,
    method="DOP853",
    t_eval=teval,
    rtol=1e-8,
    atol=1e-8,
)
ts = sol.t
freq_traj = sol.y.T  # rows are time steps
spatial_traj = ks.to_spatial(freq_traj, N=ks.dimension)

In [None]:
# to_spatial() calls irfft with n = ks.dimension, which samples one period at
# x_j = j * L / n for j = 0 .. n-1. The right endpoint x = L duplicates x = 0
# and is not among the samples, so it must be excluded from the grid.
grid = np.linspace(0, ks.L, ks.dimension, endpoint=False)
plt.figure(figsize=(10, 4))
plt.pcolormesh(ts, grid, spatial_traj.T, cmap="Spectral", shading="gouraud")
plt.colorbar()
plt.ylabel("x")
plt.xlabel("t")
plt.show()

In [None]:
# Load the PatchTST checkpoint in prediction mode.
run_name = "pft_chattn_emb_w_poly-0"  # "pft_chattn_noembed_pretrained_correct-0"
patchtst_checkpoint = (
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final"
)
pipeline = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=patchtst_checkpoint,
    device_map="cuda:3",
    torch_dtype=torch.float32,
)
pipeline

## Forecasting

### Our Model

In [None]:
def forecast(
    pipeline,
    trajectory: np.ndarray,
    context_length: int,
    normalize: bool = True,
    transpose: bool = False,
    prediction_length: int | None = None,
    **kwargs,
) -> np.ndarray:
    context = trajectory[:context_length]
    if normalize:
        normalizer = StandardizeTransform()
        context = normalizer(context, axis=0)

    if prediction_length is None:
        prediction_length = trajectory.shape[0] - context_length

    if transpose:
        context = context.T

    predictions = (
        pipeline.predict(
            context=torch.tensor(context).float(),
            prediction_length=prediction_length,
            limit_prediction_length=False,
            **kwargs,
        )
        .squeeze()
        .cpu()
        .numpy()
    )
    full_trajectory = np.concatenate([context, predictions], axis=1 if transpose else 0)

    if transpose:
        full_trajectory = full_trajectory.T

    if normalize:
        return normalizer(
            full_trajectory,
            axis=0,
            context=trajectory[:context_length],
            denormalize=True,
        )

    return full_trajectory

In [None]:
def plot_forecast(
    ts: np.ndarray,
    grid: np.ndarray,
    trajectory: np.ndarray,
    predictions: np.ndarray,
    run_name: str = "",
    context_length: int = 0,
    save_path: str | None = None,
    v_abs: float | None = None,
    show_colorbar: bool = True,
    show_ticks: bool = True,
    display_pred_error: bool = True,
    figsize: tuple[int, int] = (5, 5),
    title_kwargs: dict | None = None,
    pred_window: int = 128,
):
    """Draw ground truth, predictions and their difference as three stacked heat maps.

    :param ts: time values, one per row of ``trajectory`` / ``predictions``
    :param grid: spatial coordinates, one per column of ``trajectory`` / ``predictions``
    :param trajectory: ground-truth array
    :param predictions: forecast array with the same shape as ``trajectory``
    :param run_name: model label inserted into the panel titles
    :param context_length: index into ``ts`` where a solid vertical line is drawn
    :param save_path: if given, the figure is written there (parent directory is created when needed)
    :param v_abs: symmetric colour limit; defaults to the largest absolute value in either array
    :param show_colorbar: attach a colour bar to every panel
    :param show_ticks: show axis labels; when False, ticks and labels are removed
    :param display_pred_error: include the mean absolute error in the error panel title
    :param figsize: figure size passed to ``plt.subplots``
    :param title_kwargs: extra keyword arguments for ``ax.set_title``
    :param pred_window: offset after ``context_length`` at which a dashed line is drawn on the
        error panel; the line is skipped when it would fall outside ``ts``
    """
    # avoid a shared mutable default
    title_kwargs = {} if title_kwargs is None else title_kwargs

    fig, axes = plt.subplots(3, 1, sharex=True, figsize=figsize)

    vmin = min(trajectory.min(), predictions.min())
    vmax = max(trajectory.max(), predictions.max())
    vabs = max(abs(vmin), abs(vmax))
    if v_abs is not None:
        print(f"Using v_abs: {v_abs} instead of {vabs}")
        vabs = v_abs

    error = predictions - trajectory
    error_summary = f"({np.mean(np.abs(error)):.2e}) " if display_pred_error else ""
    panels = [
        (trajectory, "Ground Truth"),
        (predictions, f"Predictions ({run_name})"),
        (error, f"Prediction Error {error_summary}({run_name})"),
    ]
    window_end = context_length + pred_window

    for i, (ax, (data, label)) in enumerate(zip(axes, panels)):
        im = ax.pcolormesh(
            ts, grid, data.T, cmap="Spectral", shading="gouraud", vmin=-vabs, vmax=vabs
        )
        if show_ticks:
            ax.set_ylabel("x")
        else:
            ax.set_ylabel("")
            ax.set_yticks([])
            ax.set_xticks([])
        ax.set_title(label, **title_kwargs)
        if show_colorbar:
            fig.colorbar(im, ax=ax)
        # solid line where the context ends and the forecast begins
        ax.axvline(ts[context_length], color="black", linewidth=1)
        # dashed line at the end of the first prediction window (error panel only);
        # skipped when the plotted window is too short to contain it
        if i == 2 and 0 <= window_end < len(ts):
            ax.axvline(ts[window_end], color="gray", linestyle="--", linewidth=1)
    if show_ticks:
        axes[-1].set_xlabel("t")
    plt.tight_layout()

    if save_path:
        # a bare filename has an empty dirname, and os.makedirs("") raises
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        plt.savefig(save_path, bbox_inches="tight")

In [None]:
# Slice [start_time, end_time) of each trajectory is used for forecasting;
# the first context_length steps of that slice are given to the model.
start_time, end_time, context_length = 1024, 2048, 512

In [None]:
# predict in spatial domain
# Forecast the single trajectory directly in the spatial domain.
preds_spatial = forecast(
    pipeline,
    trajectory=spatial_traj[start_time:end_time],
    context_length=context_length,
    normalize=True,
    prediction_length=None,
    sliding_context=True,
)

In [None]:
# Ground truth / forecast / error panels for our model.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=preds_spatial,
    run_name="Our Model",
    context_length=context_length,
    save_path="figs_ks/ks_our_model_spatial.pdf",
    figsize=(5, 6),
    title_kwargs={"fontweight": "bold", "fontsize": 10},
    show_colorbar=True,
    show_ticks=False,
    display_pred_error=False,
    # v_abs=0.025,
)

In [None]:
# Simulate an ensemble of trajectories from independent random initial conditions.
n_runs = 20
parent_rng = np.random.default_rng(12)  # 1234
rng_stream = parent_rng.spawn(n_runs)

# every run is evaluated on the same time grid, so build it once
teval = np.linspace(0, tfinal, 4096)

traj_lst = []

# Loop-local names are prefixed with `run_` so the notebook-level `rng`, `ic`,
# `sol`, `freq_traj` and `spatial_traj` of the single-trajectory experiment are
# not overwritten by the last ensemble member.
for run_rng in rng_stream:
    run_ic = 0.1 * run_rng.normal(size=(ks.dimension,))
    run_sol = solve_ivp(
        ks.rhs,
        (0, tfinal),
        run_ic,
        method="DOP853",
        t_eval=teval,
        rtol=1e-8,
        atol=1e-8,
    )
    run_freq_traj = run_sol.y.T
    traj_lst.append(ks.to_spatial(run_freq_traj, N=ks.dimension))

# traj_lst = np.array(traj_lst)

In [None]:
# Sanity check: one simulated trajectory per spawned generator.
len(traj_lst)

In [None]:
# predict in spatial domain, one forecast per ensemble trajectory
# (a comprehension keeps the loop variable local, so the notebook-level
# `spatial_traj` is not rebound to the last ensemble member)
preds_spatial_lst = [
    forecast(
        pipeline,
        run_traj[start_time:end_time],
        context_length,
        prediction_length=None,
        normalize=True,
        sliding_context=True,
    )
    for run_traj in traj_lst
]

In [None]:
# Sanity check: one forecast per trajectory in traj_lst.
len(preds_spatial_lst)

In [None]:
def smape(x, y):
    """Symmetric mean absolute percentage error

    Computes ``200 * mean(|x - y| / (|x| + |y|))``. Entries where both ``x``
    and ``y`` are zero contribute an error of 0 instead of turning the whole
    metric into NaN through a 0/0 division.

    :param x: first array (or scalar)
    :param y: second array (or scalar), broadcastable against ``x``

    :returns: sMAPE in percent, between 0 and 200
    """
    numerator = np.abs(np.asarray(x) - np.asarray(y))
    denominator = np.abs(x) + np.abs(y)
    # the 0/0 entries are replaced just below, so silence their warnings
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = numerator / denominator
    ratio = np.where(denominator == 0, 0.0, ratio)
    return 100 * np.mean(ratio) * 2

In [None]:
# Shape of the single-trajectory forecast (context rows followed by predicted rows).
preds_spatial.shape

In [None]:
def compute_pred_error(prediction, ground_truth, time_intervals_lst):
    """Error metrics over the first ``h`` steps, for every horizon ``h`` given.

    :param prediction: predicted array, time along the first axis
    :param ground_truth: reference array with the same layout as ``prediction``
    :param time_intervals_lst: iterable of horizons (numbers of leading steps)

    :returns: dict mapping each horizon to ``{"mae", "rmse", "mse", "smape"}`` values
    """
    pred_error_dict = {}
    for horizon in time_intervals_lst:
        pred_head = prediction[:horizon]
        truth_head = ground_truth[:horizon]
        residual = pred_head - truth_head
        squared_residual = residual**2
        pred_error_dict[horizon] = {
            "mae": np.mean(np.abs(residual)),
            "rmse": np.sqrt(np.mean(squared_residual)),
            "mse": np.mean(squared_residual),
            "smape": smape(pred_head, truth_head),
        }
    return pred_error_dict

In [None]:
# Forecast horizons at which the errors are evaluated: 64, 128, ..., 512.
time_intervals_lst = 64 * np.arange(1, 512 // 64 + 1)
print(time_intervals_lst)

In [None]:
# Shape of the ground-truth segment that follows the context window, for the first run.
traj_lst[0][start_time:end_time][context_length:].shape

In [None]:
def get_mean_median_std_metrics_dicts_rollout(pred_lst):
    """Aggregate per-run forecast errors into mean / median / std tables.

    Reads the notebook-level ``traj_lst``, ``start_time``, ``end_time``,
    ``context_length`` and ``time_intervals_lst``. Forecasts in ``pred_lst``
    are paired with ``traj_lst`` in order.

    :param pred_lst: list of forecasts (context rows followed by predicted rows)

    :returns: three dicts (mean, median, std), each indexed as ``d[metric][horizon]``
    """
    per_run_errors = []
    for run_preds, run_traj in zip(pred_lst, traj_lst):
        # keep only the part after the context window on both sides
        forecast_part = run_preds[context_length:]
        truth_part = run_traj[start_time:end_time][context_length:]
        print(forecast_part.shape, truth_part.shape)
        per_run_errors.append(
            compute_pred_error(forecast_part, truth_part, time_intervals_lst)
        )

    metric_names = ["mse", "mae", "rmse", "smape"]
    reducers = {"mean": np.mean, "median": np.median, "std": np.std}

    # summary[horizon][metric] -> {"mean": ..., "median": ..., "std": ...}
    summary = defaultdict(dict)
    for horizon in per_run_errors[0].keys():
        for name in metric_names:
            samples = np.array(
                [run_errors[horizon][name] for run_errors in per_run_errors]
            )
            summary[horizon][name] = {
                stat: reduce_fn(samples, axis=0) for stat, reduce_fn in reducers.items()
            }

    def _table_for(stat):
        # re-index one statistic as table[metric][horizon]
        table = defaultdict(dict)
        for horizon in time_intervals_lst:
            for name in metric_names:
                table[name][horizon] = summary[horizon][name][stat]
        return table

    mean_table = _table_for("mean")
    median_table = _table_for("median")
    std_table = _table_for("std")
    return mean_table, median_table, std_table

In [None]:
# Rollout error statistics for our model.
our_model_rollout_stats = get_mean_median_std_metrics_dicts_rollout(preds_spatial_lst)
mean_metrics_dict, median_metrics_dict, std_metrics_dict = our_model_rollout_stats

In [None]:
# Mean MSE across runs, one value per horizon.
list(mean_metrics_dict["mse"].values())

In [None]:
# Median MSE across runs, one value per horizon.
list(median_metrics_dict["mse"].values())

### Chronos Finetune

In [None]:
# Chronos baseline loaded from a local checkpoint.
chronos_ft_checkpoint = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_bolt_mini-12/checkpoint-final"
chronos_ft = ChronosPipeline.from_pretrained(
    chronos_ft_checkpoint,
    device_map="cuda:1",
    torch_dtype=torch.float32,
)
chronos_ft

In [None]:
def forecast_chronos(
    pipeline,
    trajectory: np.ndarray,
    context_length: int,
    chunk_size: int,
) -> np.ndarray:
    """Forecast ``trajectory`` in groups of ``chunk_size`` channels.

    Each group of columns is forecast separately through ``forecast`` (transposed,
    without normalization, one sample) and the results are concatenated back along
    the channel axis. A trailing group narrower than ``chunk_size`` is forecast as
    well, so the output has as many columns as ``trajectory``.

    :param pipeline: prediction pipeline passed through to ``forecast``
    :param trajectory: array with time on axis 0 and channels on axis 1
    :param context_length: number of leading time steps used as context
    :param chunk_size: number of channels forecast per call; must be positive

    :returns: context plus predictions, same channel count as ``trajectory``
    :raises ValueError: if ``chunk_size`` is not positive
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    n_channels = trajectory.shape[1]
    subchannel_predictions = []
    # Step through channel offsets rather than counting whole chunks, so channels
    # beyond the last full chunk are not silently dropped.
    # NOTE(review): forecast() squeezes the model output, so a trailing chunk of
    # width 1 may lose its channel axis — confirm before relying on that case.
    for start in trange(0, n_channels, chunk_size):
        subpreds = forecast(
            pipeline,
            trajectory[:, start : start + chunk_size],
            context_length,
            prediction_length=None,
            transpose=True,
            normalize=False,
            num_samples=1,
        )
        subchannel_predictions.append(subpreds)

    return np.concatenate(subchannel_predictions, axis=1)

In [None]:
# # predict in frequency domain
# chronos_preds_freq = forecast_chronos(
#     chronos, freq_traj[start_time:end_time], context_length, chunk_size=ks.dimension
# )

# # convert to spatial domain
# chronos_preds_freq_to_spatial = ks.to_spatial(chronos_preds_freq, N=ks.dimension)

In [None]:
# plot_forecast(
#     ts[start_time:end_time],
#     grid,
#     spatial_traj[start_time:end_time],
#     chronos_preds_freq_to_spatial,
#     run_name="Chronos 20M Finetune",
#     context_length=context_length,
#     save_path="ks_freq_to_spatial.pdf",
# )

In [None]:
# Chronos (fine-tuned) forecast of the single trajectory in the spatial domain.
chronos_ft_preds_spatial = forecast_chronos(
    pipeline=chronos_ft,
    trajectory=spatial_traj[start_time:end_time],
    context_length=context_length,
    chunk_size=ks.dimension,
)

In [None]:
# Ground truth / forecast / error panels for the fine-tuned Chronos model.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=chronos_ft_preds_spatial,
    run_name="Chronos 20M Finetune",
    context_length=context_length,
    save_path="figs_ks/ks_chronos_ft_spatial.pdf",
    show_colorbar=True,
    show_ticks=False,
    display_pred_error=False,
    # v_abs=0.025,
)

In [None]:
# predict in spatial domain, one forecast per ensemble trajectory
# (a comprehension keeps the loop variable local, so the notebook-level
# `spatial_traj` is not rebound to the last ensemble member)
chronos_ft_preds_spatial_lst = [
    forecast_chronos(
        chronos_ft,
        run_traj[start_time:end_time],
        context_length,
        chunk_size=ks.dimension,
    )
    for run_traj in traj_lst
]

In [None]:
# Rollout error statistics for the fine-tuned Chronos model.
chronos_ft_rollout_stats = get_mean_median_std_metrics_dicts_rollout(
    chronos_ft_preds_spatial_lst
)
mean_metrics_dict, median_metrics_dict, std_metrics_dict = chronos_ft_rollout_stats

### Chronos Zeroshot

In [None]:
# Chronos baseline loaded by model id.
chronos_zs_model_id = "amazon/chronos-t5-mini"
chronos_zs = ChronosPipeline.from_pretrained(
    chronos_zs_model_id,
    device_map="cuda:1",
    torch_dtype=torch.float32,
)
chronos_zs

In [None]:
# chronos_zs_preds_freq = forecast_chronos(
#     chronos_zs, freq_traj[start_time:end_time], context_length, chunk_size=ks.dimension
# )

# # convert to spatial domain
# chronos_zs_preds_freq_to_spatial = ks.to_spatial(chronos_zs_preds_freq, N=ks.dimension)

In [None]:
# plot_forecast(
#     ts[start_time:end_time],
#     grid,
#     spatial_traj[start_time:end_time],
#     chronos_zs_preds_freq_to_spatial,
#     run_name="Chronos 20M",
#     context_length=context_length,
#     save_path="ks_chronos_zs_freq_to_spatial.pdf",
# )

In [None]:
# Chronos (zero-shot) forecast of the single trajectory in the spatial domain.
chronos_zs_preds_spatial = forecast_chronos(
    pipeline=chronos_zs,
    trajectory=spatial_traj[start_time:end_time],
    context_length=context_length,
    chunk_size=ks.dimension,
)

In [None]:
# Ground truth / forecast / error panels for the zero-shot Chronos model.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=chronos_zs_preds_spatial,
    run_name="Chronos 20M",
    context_length=context_length,
    save_path="figs_ks/ks_chronos_zs_spatial.pdf",
    show_colorbar=True,
    show_ticks=False,
    display_pred_error=False,
    # v_abs=0.025,
)

In [None]:
# predict in spatial domain, one forecast per ensemble trajectory
# (a comprehension keeps the loop variable local, so the notebook-level
# `spatial_traj` is not rebound to the last ensemble member)
chronos_zs_preds_spatial_lst = [
    forecast_chronos(
        chronos_zs,
        run_traj[start_time:end_time],
        context_length,
        chunk_size=ks.dimension,
    )
    for run_traj in traj_lst
]

In [None]:
# Rollout error statistics for the zero-shot Chronos model.
chronos_zs_rollout_stats = get_mean_median_std_metrics_dicts_rollout(
    chronos_zs_preds_spatial_lst
)
mean_metrics_dict, median_metrics_dict, std_metrics_dict = chronos_zs_rollout_stats

In [None]:
# NOTE(review): this displays every forecast array in full; a summary such as
# the list length and one array's shape may be enough — confirm intent.
preds_spatial_lst

In [None]:
# One figure per metric comparing all models: median curve plus a
# mean +/- standard-error band.
model_names = ["Our Model", "Chronos 20M Finetune", "Chronos 20M"]
model_pred_lsts = [
    preds_spatial_lst,
    chronos_ft_preds_spatial_lst,
    chronos_zs_preds_spatial_lst,
]

# The rollout statistics do not depend on which metric is drawn, so compute
# them once per model instead of once per (metric, model) pair.
rollout_stats = {
    model_name: (get_mean_median_std_metrics_dicts_rollout(plist), len(plist))
    for model_name, plist in zip(model_names, model_pred_lsts)
}

# make sure the output directory exists before the first savefig
os.makedirs("figs_ks", exist_ok=True)

for metric_to_plot, title_metric_name in [
    ("smape", "sMAPE"),
    ("mse", "MSE"),
    ("mae", "MAE"),
    ("rmse", "RMSE"),
]:
    plt.figure(figsize=(5, 4))
    for model_name in model_names:
        (mean_stats, median_stats, std_stats), n_model_runs = rollout_stats[model_name]
        mean_curve = np.array(list(mean_stats[metric_to_plot].values()))
        # The std is taken across runs, so the standard error of the mean
        # divides by sqrt(number of runs), not by the number of horizons.
        sem_curve = np.array(list(std_stats[metric_to_plot].values())) / np.sqrt(
            n_model_runs
        )
        plt.plot(
            time_intervals_lst,
            list(median_stats[metric_to_plot].values()),
            label=model_name,
        )
        plt.fill_between(
            time_intervals_lst,
            mean_curve - sem_curve,
            mean_curve + sem_curve,
            alpha=0.2,
        )
    plt.legend(loc="lower right")
    plt.title(f"{title_metric_name}", fontweight="bold")
    plt.xlabel("Prediction Length")
    plt.tight_layout()
    plt.savefig(f"figs_ks/ks_all_models_{metric_to_plot}.pdf", bbox_inches="tight")
    plt.show()
    plt.close()

In [None]:
# make a plot with two subplots: one for smape and one for mae, stacked vertically
panel_model_names = ["Our Model", "Chronos 20M Finetune", "Chronos 20M"]
panel_pred_lsts = [
    preds_spatial_lst,
    chronos_ft_preds_spatial_lst,
    chronos_zs_preds_spatial_lst,
]

# rollout statistics are shared by both panels, so compute them once per model
panel_stats = {
    model_name: (get_mean_median_std_metrics_dicts_rollout(plist), len(plist))
    for model_name, plist in zip(panel_model_names, panel_pred_lsts)
}

fig, axes = plt.subplots(2, 1, figsize=(4, 6))

for ax, (metric_key, metric_title) in zip(axes, [("smape", "sMAPE"), ("mae", "MAE")]):
    for model_name in panel_model_names:
        (mean_stats, median_stats, std_stats), n_model_runs = panel_stats[model_name]
        mean_curve = np.array(list(mean_stats[metric_key].values()))
        # The std is taken across runs, so the standard error of the mean
        # divides by sqrt(number of runs), not by the number of horizons.
        sem_curve = np.array(list(std_stats[metric_key].values())) / np.sqrt(
            n_model_runs
        )
        ax.plot(
            time_intervals_lst,
            list(median_stats[metric_key].values()),
            label=model_name,
        )
        ax.fill_between(
            time_intervals_lst,
            mean_curve - sem_curve,
            mean_curve + sem_curve,
            alpha=0.2,
        )
    ax.legend(loc="lower right")
    ax.set_title(metric_title, fontweight="bold")

# only the bottom panel carries the x label
axes[-1].set_xlabel("Prediction Length")

fig.tight_layout()
# make sure the output directory exists before saving
os.makedirs("figs_ks", exist_ok=True)
fig.savefig("figs_ks/ks_all_models_smape_mae.pdf", bbox_inches="tight")
plt.show()