In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.base import DynSys
from scipy.integrate import solve_ivp
from tqdm import trange

from dystformer.chronos.pipeline import ChronosPipeline
from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

# CUDA device index; interpolated into the `cuda:{device_rank}` device maps of
# the pipelines loaded further down.
device_rank = 3

## KS Equation

In [None]:
class KuramotoShivashinsky(DynSys):
    """Pseudospectral 1+1D Kuramoto-Sivashinsky system evolved in Fourier space.

    The state vector stacks the real parts of the ``modes`` retained Fourier
    coefficients followed by their imaginary parts. Internally the spectrum has
    ``modes + 2`` entries; the first and last entries are never written and
    therefore stay zero.
    """

    def __init__(self, L: float, modes: int):
        super().__init__(metadata_path=None, dimension=2 * modes, parameters={})
        self.L = L
        self.modes = modes
        self.dimension = 2 * self.modes

        n_coeffs = self.modes + 2
        self.wave_nums = 2 * np.pi * np.arange(0, n_coeffs) / self.L
        # number of real-space samples used by `rhs`
        self.N = self.dimension + 2

        # Quantities reused on every right-hand-side evaluation.
        # `freq_domain` is a scratch buffer that `rhs` overwrites in place.
        self.freq_domain = np.zeros(n_coeffs, dtype=np.complex128)
        self.nonlinear_factor = -0.5 * 1j * self.wave_nums * self.N
        self.diffusion_ffts = self.wave_nums**2 - self.wave_nums**4

    def to_spatial(self, q: np.ndarray, N: int) -> np.ndarray:
        """Map flattened Fourier coefficients to the spatial field u(x).

        :param q: flattened coefficients, real parts first and imaginary parts
            second along the last axis; leading batch axes are allowed
        :param N: number of grid points of the returned spatial signal

        :returns: inverse real FFT of the coefficients along the last axis
        """
        batch_shape = q.shape[:-1]
        spectrum = np.zeros(batch_shape + (self.modes + 2,), dtype=complex)
        real_part = q[..., : self.modes]
        imag_part = q[..., self.modes :]
        # the first and last spectral entries are left at zero
        spectrum[..., 1:-1] = real_part + 1j * imag_part
        return np.fft.irfft(spectrum, n=N)

    def rhs(self, t: float, X: np.ndarray) -> np.ndarray:
        """Evaluate the time derivative of the flattened state ``X``.

        ``t`` is accepted for the ODE-solver call signature but is not used.
        The shared ``self.freq_domain`` buffer is overwritten on every call.
        """
        spectrum = self.freq_domain
        spectrum[1:-1] = X[: self.modes] + 1j * X[self.modes :]

        # nonlinear term is formed in real space and transformed back
        u = np.fft.irfft(spectrum, n=self.N)
        pseudospectral_term = self.nonlinear_factor * np.fft.rfft(u * u)
        linear_term = self.diffusion_ffts * spectrum

        # drop the two untouched entries and split into real / imaginary halves
        flow = (linear_term + pseudospectral_term)[1:-1]
        return np.concatenate([flow.real, flow.imag])

In [None]:
# Integrate one KS trajectory from a small random initial condition and keep
# both the Fourier-space and the spatial representation.
ks = KuramotoShivashinsky(L=100, modes=64)

tfinal = 100
rng = np.random.default_rng(12)  # 1234
ic = 0.1 * rng.normal(size=(ks.dimension,))
teval = np.linspace(0, tfinal, 4096)
sol = solve_ivp(
    fun=ks.rhs,
    t_span=(0, tfinal),
    y0=ic,
    method="DOP853",
    t_eval=teval,
    rtol=1e-8,
    atol=1e-8,
)
ts = sol.t
freq_traj = sol.y.T  # time along the first axis
spatial_traj = ks.to_spatial(freq_traj, N=ks.dimension)

In [None]:
# Sample locations of the inverse-FFT output. The field is periodic, so the
# n-point irfft grid is x_j = j * L / n and x = L coincides with x = 0;
# including the endpoint would stretch the spacing to L / (n - 1).
grid = np.linspace(0, ks.L, ks.dimension, endpoint=False)

# Space-time heat map of the simulated field.
plt.figure(figsize=(10, 4))
plt.pcolormesh(ts, grid, spatial_traj.T, cmap="Spectral", shading="gouraud")
plt.colorbar()
plt.ylabel("x")
plt.xlabel("t")
plt.show()

In [None]:
def plot_3d_trajectory(
    trajectory: np.ndarray,
    title: str = "3D Trajectory",
    figsize: tuple[int, int] = (12, 8),
) -> None:
    """Draw the first three columns of a trajectory as a curve in 3D.

    Up to 20 evenly spaced points along the curve are overlaid as markers
    coloured with the viridis colormap in index order.

    Args:
        trajectory: Array of shape (T, D); only columns 0, 1 and 2 are drawn
        title: Title placed on the axes
        figsize: Figure size in inches (width, height)

    Raises:
        ValueError: If the trajectory has fewer than 3 columns.
    """
    n_dims = trajectory.shape[1]
    if n_dims < 3:
        raise ValueError(
            f"Trajectory must have at least 3 dimensions, but has {n_dims}"
        )

    xs, ys, zs = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
    n_steps = len(trajectory)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection="3d")

    # Continuous curve
    ax.plot(xs, ys, zs)

    # Evenly spaced markers along the curve
    marker_idx = np.linspace(0, n_steps - 1, min(20, n_steps), dtype=int)
    marker_colors = plt.cm.viridis(np.linspace(0, 1, len(marker_idx)))
    ax.scatter(
        xs[marker_idx],
        ys[marker_idx],
        zs[marker_idx],
        c=marker_colors,
        s=30,
        alpha=0.8,
    )

    ax.set_xlabel("Dimension 1")
    ax.set_ylabel("Dimension 2")
    ax.set_zlabel("Dimension 3")
    ax.set_title(title)

    plt.tight_layout()
    plt.show()


# Example: plot the full simulated trajectory using its first 3 state
# components (the real parts of the first three retained Fourier coefficients).
plot_3d_trajectory(freq_traj, title="KS Equation - First 3 Frequencies")

In [None]:
# Load the pretrained PatchTST checkpoint in prediction mode on the selected GPU.
run_name = "pft_chattn_emb_w_poly-0"
checkpoint_dir = (
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final"
)
pipeline = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=checkpoint_dir,
    device_map=f"cuda:{device_rank}",
    torch_dtype=torch.float32,
)
pipeline

# Forecast Visualization

In [None]:
def forecast(
    pipeline,
    trajectory: np.ndarray,
    context_length: int,
    normalize: bool = True,
    transpose: bool = False,
    prediction_length: int | None = None,
    **kwargs,
) -> np.ndarray:
    context = trajectory[:context_length]
    if normalize:
        context = safe_standardize(context, axis=0)

    if prediction_length is None:
        prediction_length = trajectory.shape[0] - context_length

    if transpose:
        context = context.T

    predictions = (
        pipeline.predict(
            context=torch.tensor(context).float(),
            prediction_length=prediction_length,
            limit_prediction_length=False,
            **kwargs,
        )
        .squeeze()
        .cpu()
        .numpy()
    )
    full_trajectory = np.concatenate([context, predictions], axis=1 if transpose else 0)

    if transpose:
        full_trajectory = full_trajectory.T

    if normalize:
        return safe_standardize(
            full_trajectory,
            axis=0,
            context=trajectory[:context_length],
            denormalize=True,
        )

    return full_trajectory

In [None]:
def plot_forecast(
    ts: np.ndarray,
    grid: np.ndarray,
    trajectory: np.ndarray,
    predictions: np.ndarray,
    run_name: str = "",
    context_length: int = 0,
    save_path: str | None = None,
    v_abs: float | None = None,
    prediction_horizon: int = 128,
):
    """Plot ground truth, predictions and their difference as stacked heat maps.

    Args:
        ts: Time values for the horizontal axis.
        grid: Spatial coordinates for the vertical axis.
        trajectory: Ground-truth array, time along the first axis.
        predictions: Predicted array with the same shape as ``trajectory``.
        run_name: Label inserted into the panel titles.
        context_length: Index into ``ts`` where the context ends; marked with a
            solid vertical line on every panel.
        save_path: If given, the figure is saved to this path.
        v_abs: Symmetric colour limit. A falsy value (``None`` or ``0``) falls
            back to the largest absolute value found in either array.
        prediction_horizon: Number of steps after the context at which a dashed
            line is drawn on the error panel.

    Returns:
        The symmetric colour limit that was used, so other figures can share it.
    """
    fig, axes = plt.subplots(3, 1, sharex=True, figsize=(9, 9))

    vmin = min(trajectory.min(), predictions.min())
    vmax = max(trajectory.max(), predictions.max())
    vabs = v_abs or max(abs(vmin), abs(vmax))

    error = predictions - trajectory
    panels = [
        (trajectory, "Ground Truth"),
        (predictions, f"Predictions ({run_name})"),
        (error, f"Prediction Error ({np.mean(np.abs(error)):.2e}) ({run_name})"),
    ]

    for i, (ax, (data, label)) in enumerate(zip(axes, panels)):
        im = ax.pcolormesh(
            ts, grid, data.T, cmap="Spectral", shading="gouraud", vmin=-vabs, vmax=vabs
        )
        ax.set_ylabel("x")
        ax.set_title(label, fontweight="bold")
        fig.colorbar(im, ax=ax)
        # solid vertical line where the context ends and the forecast begins
        ax.axvline(ts[context_length], color="black", linewidth=1)
        if i == 2:
            # Dashed line at the end of the prediction-horizon window. It is
            # skipped when the window ends beyond the plotted time axis, which
            # previously raised an IndexError.
            horizon_idx = context_length + prediction_horizon
            if horizon_idx < len(ts):
                ax.axvline(
                    ts[horizon_idx],
                    color="gray",
                    linestyle="--",
                    linewidth=1,
                )
    axes[-1].set_xlabel("t")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")

    return vabs


In [None]:
# Slice [start_time:end_time] of the simulated trajectory is used for the
# forecasts below; its first `context_length` steps are the model context.
start_time, end_time = 1024, 2048
context_length = 512

### Our Model

In [None]:
# Forecast the Fourier-coefficient trajectory ...
preds_freq = forecast(
    pipeline=pipeline,
    trajectory=freq_traj[start_time:end_time],
    context_length=context_length,
    normalize=True,
    prediction_length=512,
    sliding_context=True,
)

# ... then map context + forecast to the spatial domain
preds_freq_to_spatial = ks.to_spatial(q=preds_freq, N=ks.dimension)

In [None]:
# Frequency-domain forecast shown in the spatial domain; the returned colour
# limit is kept so the baseline figures can reuse it.
our_freq_vabs = plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=preds_freq_to_spatial,
    run_name="Our Model",
    context_length=context_length,
    save_path="../figures/ks_our_model_freq_to_spatial.pdf",
)

In [None]:
# Forecast directly on the spatial field; the horizon defaults to the rest of
# the window after the context.
preds_spatial = forecast(
    pipeline=pipeline,
    trajectory=spatial_traj[start_time:end_time],
    context_length=context_length,
    normalize=True,
    prediction_length=None,
    sliding_context=True,
)

In [None]:
# Spatial-domain forecast; the returned colour limit is kept so the baseline
# figures can reuse it.
our_spatial_vabs = plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=preds_spatial,
    run_name="Our Model",
    context_length=context_length,
    save_path="../figures/ks_our_model_spatial.pdf",
)

### Chronos Finetune

In [None]:
# Fine-tuned Chronos checkpoint used as a baseline.
# alternative checkpoint:
# "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_mini_ft-0/checkpoint-final"
chronos_ft_checkpoint = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_finetune_stand_updated-0/checkpoint-final"
chronos_ft = ChronosPipeline.from_pretrained(
    chronos_ft_checkpoint,
    device_map=f"cuda:{device_rank}",
    torch_dtype=torch.float32,
)
chronos_ft

In [None]:
def forecast_chronos(
    pipeline,
    trajectory: np.ndarray,
    context_length: int,
    chunk_size: int,
) -> np.ndarray:
    """Forecast a multichannel trajectory in column chunks with a Chronos pipeline.

    Each chunk of at most ``chunk_size`` columns is passed to ``forecast`` with
    ``transpose=True``, ``normalize=False`` and ``num_samples=1``; the per-chunk
    results are concatenated back along the channel axis.

    Args:
        pipeline: Pipeline object handed through to ``forecast``.
        trajectory: Array of shape (T, D), time along the first axis.
        context_length: Number of leading time steps used as context.
        chunk_size: Maximum number of channels forecast per call.

    Returns:
        Array with one column per input channel: context followed by forecast.
    """
    n_channels = trajectory.shape[1]
    subchannel_predictions = []
    # Step through the channels by `chunk_size` so a trailing partial chunk is
    # forecast as well; floor-dividing the channel count silently dropped those
    # channels whenever D was not a multiple of `chunk_size`.
    # NOTE(review): a chunk holding a single channel may not survive the
    # `.squeeze()` inside `forecast` - confirm against the pipeline output shape.
    for chunk_start in trange(0, n_channels, chunk_size):
        subpreds = forecast(
            pipeline,
            trajectory[:, chunk_start : chunk_start + chunk_size],
            context_length,
            prediction_length=None,
            transpose=True,
            normalize=False,
            num_samples=1,
        )
        subchannel_predictions.append(subpreds)

    return np.concatenate(subchannel_predictions, axis=1)

In [None]:
# Fine-tuned Chronos forecast of the Fourier coefficients ...
chronos_preds_freq = forecast_chronos(
    pipeline=chronos_ft,
    trajectory=freq_traj[start_time:end_time],
    context_length=context_length,
    chunk_size=ks.dimension,
)

# ... mapped to the spatial domain
chronos_preds_freq_to_spatial = ks.to_spatial(q=chronos_preds_freq, N=ks.dimension)

In [None]:
# Same colour scale as the corresponding "Our Model" figure.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=chronos_preds_freq_to_spatial,
    run_name="Chronos 20M Finetune",
    context_length=context_length,
    save_path="../figures/ks_chronos_ft_freq_to_spatial.pdf",
    v_abs=our_freq_vabs,
)

In [None]:
# Fine-tuned Chronos forecast made directly on the spatial field
chronos_preds_spatial = forecast_chronos(
    pipeline=chronos_ft,
    trajectory=spatial_traj[start_time:end_time],
    context_length=context_length,
    chunk_size=ks.dimension,
)

In [None]:
# Same colour scale as the corresponding "Our Model" figure.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=chronos_preds_spatial,
    run_name="Chronos 20M Finetune",
    context_length=context_length,
    save_path="../figures/ks_chronos_ft_spatial.pdf",
    v_abs=our_spatial_vabs,
)

### Chronos Zeroshot

In [None]:
# Off-the-shelf Chronos model used as the zero-shot baseline.
chronos_zs_model_id = "amazon/chronos-t5-mini"
chronos_zs = ChronosPipeline.from_pretrained(
    chronos_zs_model_id,
    device_map=f"cuda:{device_rank}",
    torch_dtype=torch.float32,
)
chronos_zs

In [None]:
# Zero-shot Chronos forecast of the Fourier coefficients ...
chronos_zs_preds_freq = forecast_chronos(
    pipeline=chronos_zs,
    trajectory=freq_traj[start_time:end_time],
    context_length=context_length,
    chunk_size=ks.dimension,
)

# ... mapped to the spatial domain
chronos_zs_preds_freq_to_spatial = ks.to_spatial(q=chronos_zs_preds_freq, N=ks.dimension)

In [None]:
# Same colour scale as the corresponding "Our Model" figure.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=chronos_zs_preds_freq_to_spatial,
    run_name="Chronos 20M",
    context_length=context_length,
    save_path="../figures/ks_chronos_zs_freq_to_spatial.pdf",
    v_abs=our_freq_vabs,
)

In [None]:
# Zero-shot Chronos forecast made directly on the spatial field
chronos_zs_preds_spatial = forecast_chronos(
    pipeline=chronos_zs,
    trajectory=spatial_traj[start_time:end_time],
    context_length=context_length,
    chunk_size=ks.dimension,
)

In [None]:
# Same colour scale as the corresponding "Our Model" figure.
plot_forecast(
    ts=ts[start_time:end_time],
    grid=grid,
    trajectory=spatial_traj[start_time:end_time],
    predictions=chronos_zs_preds_spatial,
    run_name="Chronos 20M",
    context_length=context_length,
    save_path="../figures/ks_chronos_zs_spatial.pdf",
    v_abs=our_spatial_vabs,
)

# Rollout Evaluation

In [None]:
from collections import defaultdict

from dysts.metrics import smape

In [None]:
# repeated for convenience
start_time = 1024
end_time = 2048
context_length = 512

In [None]:
# Simulate `n_runs` independent trajectories, one per child generator.
n_runs = 20
parent_rng = np.random.default_rng(12)
rng_stream = parent_rng.spawn(n_runs)

predict_spatial = True  # predict in spatial domain instead of frequency domain
convert_to_spatial = False  # if prediction in freq domain, convert to spatial domain

trajectories = []

# NOTE: this loop rebinds `rng`, `ic`, `teval`, `sol`, `ts` and `freq_traj`
# from the single-trajectory cell; afterwards they refer to the last run.
for rng in rng_stream:
    ic = 0.1 * rng.normal(size=(ks.dimension,))
    teval = np.linspace(0, tfinal, 4096)
    sol = solve_ivp(
        fun=ks.rhs,
        t_span=(0, tfinal),
        y0=ic,
        method="DOP853",
        t_eval=teval,
        rtol=1e-8,
        atol=1e-8,
    )
    ts = sol.t
    freq_traj = sol.y.T
    trajectories.append(
        ks.to_spatial(freq_traj, N=ks.dimension) if predict_spatial else freq_traj
    )

# evaluation windows (0, 64), (0, 128), ..., (0, 512), relative to forecast start
time_intervals = [(0, end) for end in np.arange(64, 512 + 64, 64)]

In [None]:
def compute_pred_error(prediction, ground_truth, time_intervals: list[tuple[int, int]]):
    """Compute MAE, RMSE, MSE and sMAPE of a prediction over several windows.

    Args:
        prediction: Predicted values, sliced along the first axis.
        ground_truth: Reference values, sliced the same way.
        time_intervals: ``(start, end)`` index pairs defining the windows.

    Returns:
        Dict keyed by ``(start, end)`` whose values are dicts with the keys
        ``"mae"``, ``"rmse"``, ``"mse"`` and ``"smape"``.
    """
    pred_error_dict = {}
    for start, end in time_intervals:
        pred_window = prediction[start:end]
        truth_window = ground_truth[start:end]
        residual = pred_window - truth_window
        mean_squared = np.mean(residual**2)
        pred_error_dict[start, end] = {
            "mae": np.mean(np.abs(residual)),
            "rmse": np.sqrt(mean_squared),
            "mse": mean_squared,
            "smape": smape(pred_window, truth_window),
        }
    return pred_error_dict


def get_mean_median_std_metrics_dicts_rollout(
    predictions: list[np.ndarray],
    trajectories: list[np.ndarray],
    time_intervals: list[tuple[int, int]],
):
    """Aggregate per-run forecast errors into mean / median / std tables.

    Reads the module-level ``context_length``, ``start_time`` and ``end_time``
    to cut the forecast part out of each prediction and the matching
    ground-truth slice out of each trajectory.

    Args:
        predictions: One array per run (context followed by forecast).
        trajectories: One full ground-truth array per run.
        time_intervals: ``(start, end)`` windows relative to the forecast start.

    Returns:
        Three ``defaultdict`` tables (mean, median, std) indexed as
        ``table[metric][time_interval]``; statistics are taken across runs.
    """
    per_run_errors = [
        compute_pred_error(
            run_preds[context_length:],
            run_traj[start_time:end_time][context_length:],
            time_intervals,
        )
        for run_preds, run_traj in zip(predictions, trajectories)
    ]

    metrics_lst = ["mse", "mae", "rmse", "smape"]

    # metric_dict[time_interval][metric] -> {"mean", "median", "std"} across runs
    metric_dict = defaultdict(dict)
    for time_interval in per_run_errors[0].keys():
        for metric in metrics_lst:
            samples = np.array(
                [run_errors[time_interval][metric] for run_errors in per_run_errors]
            )
            metric_dict[time_interval][metric] = {
                "mean": np.mean(samples, axis=0),
                "median": np.median(samples, axis=0),
                "std": np.std(samples, axis=0),
            }

    def _table_for(stat_name):
        # re-index one statistic as table[metric][time_interval]
        table = defaultdict(dict)
        for time_interval in time_intervals:
            for metric in metrics_lst:
                table[metric][time_interval] = metric_dict[time_interval][metric][
                    stat_name
                ]
        return table

    mean_metrics_dict = _table_for("mean")
    median_metrics_dict = _table_for("median")
    std_metrics_dict = _table_for("std")
    return mean_metrics_dict, median_metrics_dict, std_metrics_dict


### Our Model

In [None]:
# Forecast every rollout trajectory with our model.
preds = []

for traj in trajectories:
    sample_pred = forecast(
        pipeline=pipeline,
        trajectory=traj[start_time:end_time],
        context_length=context_length,
        normalize=True,
        prediction_length=None,
        sliding_context=True,
    )
    # only frequency-domain forecasts can be mapped to the spatial domain
    if not predict_spatial and convert_to_spatial:
        sample_pred = ks.to_spatial(q=sample_pred, N=ks.dimension)
    preds.append(sample_pred)


### Chronos Finetune

In [None]:
# Forecast every rollout trajectory with the fine-tuned Chronos baseline.
chronos_ft_preds = []

for traj in trajectories:
    chronos_ft_sample_pred = forecast_chronos(
        pipeline=chronos_ft,
        trajectory=traj[start_time:end_time],
        context_length=context_length,
        chunk_size=ks.dimension,
    )
    # only frequency-domain forecasts can be mapped to the spatial domain
    if not predict_spatial and convert_to_spatial:
        chronos_ft_sample_pred = ks.to_spatial(q=chronos_ft_sample_pred, N=ks.dimension)
    chronos_ft_preds.append(chronos_ft_sample_pred)

### Chronos Zeroshot

In [None]:
# Forecast every rollout trajectory with the zero-shot Chronos baseline.
chronos_zs_preds = []

for traj in trajectories:
    chronos_zs_sample_pred = forecast_chronos(
        pipeline=chronos_zs,
        trajectory=traj[start_time:end_time],
        context_length=context_length,
        chunk_size=ks.dimension,
    )
    # only frequency-domain forecasts can be mapped to the spatial domain
    if not predict_spatial and convert_to_spatial:
        chronos_zs_sample_pred = ks.to_spatial(q=chronos_zs_sample_pred, N=ks.dimension)
    chronos_zs_preds.append(chronos_zs_sample_pred)

### Plot Results

In [None]:
# x positions: the end index of every evaluation window
end_times = [interval_end for _, interval_end in time_intervals]

# The rollout statistics do not depend on which metric is drawn, so compute
# them once per model instead of once per (metric, model) pair.
rollout_stats = {}
for model_label, plist in zip(
    ["Our Model", "Chronos 20M Finetune", "Chronos 20M"],
    [preds, chronos_ft_preds, chronos_zs_preds],
):
    mean_metrics_dict, median_metrics_dict, std_metrics_dict = (
        get_mean_median_std_metrics_dicts_rollout(plist, trajectories, time_intervals)
    )
    # number of runs that entered the statistics (predictions are zipped with
    # trajectories, so the shorter list decides)
    n_samples = min(len(plist), len(trajectories))
    rollout_stats[model_label] = (mean_metrics_dict, std_metrics_dict, n_samples)

for metric_to_plot, title_metric_name in [
    ("smape", "sMAPE"),
    ("mse", "MSE"),
    ("mae", "MAE"),
    ("rmse", "RMSE"),
]:
    plt.figure(figsize=(5, 4))
    for model_label, (mean_by_metric, std_by_metric, n_samples) in rollout_stats.items():
        mean_curve = np.array(list(mean_by_metric[metric_to_plot].values()))
        # Standard error of the mean: the std is taken across rollout runs, so
        # it is scaled by the number of runs, not the number of time intervals.
        sem_curve = np.array(list(std_by_metric[metric_to_plot].values())) / np.sqrt(
            n_samples
        )
        plt.plot(end_times, mean_curve, label=model_label)
        plt.fill_between(
            end_times,
            mean_curve - sem_curve,
            mean_curve + sem_curve,
            alpha=0.2,
        )
    plt.xticks(end_times)
    plt.legend(loc="lower right")
    plt.title(f"{title_metric_name}", fontweight="bold")
    plt.xlabel("Prediction Length")
    plt.tight_layout()
    plt.savefig(f"../figures/ks_all_models_{metric_to_plot}.pdf", bbox_inches="tight")
    plt.show()
    plt.close()