In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.metrics import compute_metrics
from sklearn.decomposition import PCA
from tqdm import tqdm

from dystformer.chronos.pipeline import ChronosPipeline
from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

In [None]:
# Load the PatchTST pipeline in "predict" mode from a saved checkpoint.
# NOTE(review): the checkpoint path is an absolute, machine-specific location and the
# device is pinned to "cuda:2"; both need adjusting to run this notebook elsewhere.
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path="/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_emb_w_poly-0/checkpoint-final",
    # Alternative checkpoint, currently disabled:
    # pretrain_path="/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_linattnpolyemb_from_scratch-0/checkpoint-final",
    device_map="cuda:2",
)

In [None]:
# Load the Chronos pipeline from a saved checkpoint, in float32, on a second GPU.
# NOTE(review): the checkpoint path is an absolute, machine-specific location and the
# device is pinned to "cuda:1"; both need adjusting to run this notebook elsewhere.
chronos_ft = ChronosPipeline.from_pretrained(
    # Alternative checkpoints, currently disabled:
    # "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_finetune_stand_updated-0/checkpoint-final",
    # "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_mini_ft-0/checkpoint-final",
    "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_bolt_mini-12/checkpoint-final",
    device_map="cuda:1",
    torch_dtype=torch.float32,
)

In [None]:
# Build the dataset path from the WORK environment variable.
# NOTE(review): when WORK is unset the "" default turns base_dir into the absolute
# path "/physics-datasets" — confirm that fallback is intended.
WORK = os.environ.get("WORK", "")
base_dir = f"{WORK}/physics-datasets"
# Value substituted into the file name below (presumably the Reynolds number — confirm).
# NOTE(review): the name `re` would shadow the stdlib `re` module if that were imported.
re = 450
fpath = (
    f"{base_dir}/von_karman_street/vortex_street_velocities_Re_{re}_4800timepoints.npz"
)

In [None]:
# Load the velocity data and combine finite differences of it into `vort_field`.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so only use
# it on trusted files. Also, np.load on a genuine .npz archive returns an NpzFile, not
# an ndarray — confirm this file really loads as a single array.
vfield = np.load(fpath, allow_pickle=True)
# Sum of two terms: component 1 of the last axis differenced along axis 1, and
# component 0 differenced along axis 2. Each term drops one trailing entry along the
# axis the other term differenced, so both terms have the same shape.
# NOTE(review): presumably a vorticity estimate; vorticity is usually written as a
# *difference* of cross-derivatives, so confirm the sign and axis convention.
vort_field = (
    np.diff(vfield, axis=1)[..., :-1, 1] + np.diff(vfield, axis=2)[:, :-1, :, 0]
)
# Collapse every axis after the first so each index along axis 0 is one flat row.
vort_field_flattened = vort_field.reshape(vort_field.shape[0], -1)

In [None]:
def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    title: str | None = None,
    show: bool = True,
    transpose: bool = True,
    indices: list[int] | None = None,
    **kwargs,
) -> np.ndarray:
    context = data[:, :context_length]
    groundtruth = data[:, context_length : context_length + prediction_length]
    context_tensor = torch.from_numpy(context.T if transpose else context).float()
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    if not transpose:
        pred = pred.T

    total_length = context.shape[1] + prediction_length
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    if show:
        if indices is None:
            indices = [0, 1, 2]
        fig = plt.figure(figsize=(15, 4))

        outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
        gs = outer_grid[1].subgridspec(
            3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0
        )
        ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
        ax_3d.plot(*context[indices], alpha=0.5, color="black", label="Context")
        ax_3d.plot(
            *groundtruth[indices], linestyle="--", color="black", label="Groundtruth"
        )
        ax_3d.plot(*pred.T[indices], color="red", label="Prediction")
        ax_3d.legend(loc="upper right", fontsize=12)
        ax_3d.set_xlabel("$x_{" + str(indices[0]) + "}$")
        ax_3d.set_ylabel("$x_{" + str(indices[1]) + "}$")
        ax_3d.set_zlabel("$x_{" + str(indices[2]) + "}$")
        if title is not None:
            ax_3d.set_title(title)

        axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
        for i, ax in zip(indices, axes_1d):
            ax.plot(context_ts, context[i], alpha=0.5, color="black")
            ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
            ax.plot(pred_ts, pred[:, i], color="red")
            index_str = "{" + str(i) + "}"
            ax.set_ylabel(f"$x_{index_str}$")
            ax.set_aspect("auto")
        axes_1d[-1].set_xlabel("Time")

        plt.show()

    return pred

In [None]:
# Project the flattened field onto its leading principal components.
n_components = 512
# NOTE(review): no random_state is set; if scikit-learn picks its randomized SVD solver
# for an input of this size, the components can differ between runs — confirm and seed
# if reproducibility matters.
pca = PCA(n_components=n_components)
pca.fit(vort_field_flattened)
X_ts = pca.transform(vort_field_flattened)  # (T, D)
eigenvectors = pca.components_  # (D, H*W)

## Show low-rank structure
# Explained-variance ratio per component on a logarithmic y-axis.
plt.figure()
plt.plot(np.arange(n_components), pca.explained_variance_ratio_)
plt.semilogy()

## Plot trajectory
# Coefficient of component 0 against coefficient of component 1.
plt.figure()
plt.plot(X_ts[:, 0], X_ts[:, 1])

In [None]:
def reconstruct(
    pca_coeffs: np.ndarray, eigenvectors: np.ndarray, modes: int = -1
) -> np.ndarray:
    """Map PCA coefficients back to feature space using the leading components.

    Args:
        pca_coeffs: 2-D array of coefficients, one row per sample.
        eigenvectors: 2-D array whose rows are the component vectors.
        modes: number of leading components to use; ``-1`` selects as many as
            ``pca_coeffs`` has columns.

    Returns:
        The matrix product ``pca_coeffs[:, :modes] @ eigenvectors[:modes, :]``.
    """
    n_used = pca_coeffs.shape[1] if modes == -1 else modes
    leading_coeffs = pca_coeffs[:, :n_used]
    leading_vectors = eigenvectors[:n_used, :]
    return leading_coeffs @ leading_vectors

In [None]:
# Rebuild the field from every retained PCA mode and restore vort_field's layout.
vort_recon = reconstruct(X_ts, eigenvectors).reshape(vort_field.shape)
# Preview a single reconstructed frame (index 100 + 512), transposed for display.
plt.figure()
plt.imshow(vort_recon[100 + 512].T, cmap="seismic")
plt.colorbar(shrink=0.5);

In [None]:
# Forecast standardized PCA coefficients with the PatchTST pipeline and plot the result.
start = 2048  # ignore transient
stride = 1
subsampled_pca_coeffs = X_ts[start::stride, :]
# NOTE(review): standardization is applied to the whole slice, i.e. to the context and
# to the continuation used as ground truth alike — confirm `safe_standardize` semantics
# and whether that is intended.
stand_subsampled_pca_coeffs = safe_standardize(subsampled_pca_coeffs, axis=0)
# Passed transposed because plot_model_prediction slices the context along columns.
# The extra keyword arguments are forwarded to pft_model.predict.
predictions = plot_model_prediction(
    pft_model,
    stand_subsampled_pca_coeffs.T,
    context_length=512,
    prediction_length=128,
    limit_prediction_length=False,
    sliding_context=True,
    title="Von Karman Vortex Sheet PCA modes",
    indices=[0, 1, 2],
)


In [None]:
def forecast(
    model,
    data,
    context_length: int,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = True,
    batch_size: int | None = None,
    start: int | None = None,
    **kwargs,
) -> np.ndarray:
    if start is None:
        start = 0
    context = data[start : start + context_length]

    if standardize:
        context = safe_standardize(context, axis=0)

    torch.cuda.empty_cache()
    context_tensor = torch.from_numpy(context).float()
    if transpose:
        context_tensor = context_tensor.T

    if batch_size is None:
        pred = model.predict(context_tensor, prediction_length, **kwargs)
    else:
        pred = []
        for i in range(0, context_tensor.shape[0], batch_size):
            pred.append(
                model.predict(
                    context_tensor[i : i + batch_size], prediction_length, **kwargs
                )
            )
        pred = torch.cat(pred, dim=0)

    pred = pred.squeeze().detach().cpu().numpy()
    if transpose:
        pred = pred.T

    if standardize:
        pred = safe_standardize(
            pred, axis=0, context=data[start : start + context_length], denormalize=True
        )

    return pred


def plot_predicted_flow(
    prediction: np.ndarray,
    data: np.ndarray,
    eigenvectors: np.ndarray,
    context_length: int,
    prediction_length: int,
    num_modes: int,
    shape: tuple[int, int] = (vort_field.shape[1], vort_field.shape[2]),
    time_indices: list[int] | None = None,
    save_path: str | None = None,
    camera_ready: bool = False,
    stride: int = 1,
):
    """Show predicted and ground-truth fields side by side at selected forecast steps.

    Both ``prediction`` and the matching slice of ``data`` are PCA coefficients. They
    are mapped back to fields with ``reconstruct`` using ``num_modes`` modes, so the
    ground truth shown is the low-rank reconstruction, not the raw field.

    Args:
        prediction: forecast coefficients, one row per forecast step.
        data: coefficient series; rows ``context_length`` to
            ``context_length + prediction_length`` are used as ground truth.
        eigenvectors: component vectors passed to ``reconstruct``.
        context_length: index of the first ground-truth row, also used in titles.
        prediction_length: number of forecast steps.
        num_modes: number of leading modes used in both reconstructions.
        shape: 2-D shape of one frame; defaults to the frame shape of the global
            ``vort_field`` at definition time.
        time_indices: forecast steps to show after step 0; defaults to every
            ``stride``-th step.
        save_path: when given the figure is saved there, otherwise it is shown.
        camera_ready: when True the row labels are omitted.
        stride: step between the default ``time_indices``. This used to be read from
            a notebook global of the same name; it is now an explicit argument.
    """
    groundtruth = data[context_length : context_length + prediction_length]

    frame_shape = (prediction_length, shape[0], shape[1])
    recon = reconstruct(prediction, eigenvectors, modes=num_modes).reshape(frame_shape)
    groundtruth = reconstruct(groundtruth, eigenvectors, modes=num_modes).reshape(
        frame_shape
    )
    # Symmetric colour limits need the largest *magnitude*. max(min, max) ignores a
    # dominant negative extreme and is negative for an all-negative field.
    vabs = np.abs(groundtruth).max()

    if time_indices is None:
        time_indices = list(range(0, prediction_length, stride))
    # Step 0 is always shown first; unpacking accepts any iterable of indices.
    frame_indices = [0, *time_indices]
    n_cols = len(frame_indices)

    aspect_ratio = shape[0] / shape[1]
    fig = plt.figure(figsize=(5 * n_cols / aspect_ratio, 5 * aspect_ratio))
    gs = fig.add_gridspec(
        2,
        n_cols,
        width_ratios=[1] * n_cols,
        height_ratios=[1, 1],
        wspace=0,
        hspace=0,
    )
    axes = np.array(
        [[fig.add_subplot(gs[row, col]) for col in range(n_cols)] for row in range(2)]
    )

    for col, index in enumerate(frame_indices):
        # Row 0 is the prediction, row 1 the low-rank ground truth.
        for row, field in enumerate((recon, groundtruth)):
            ax = axes[row, col]
            ax.imshow(field[index, :, :], vmin=-vabs, vmax=vabs, cmap="seismic")
            ax.set_xticks([])
            ax.set_yticks([])
            # Filled disc at a fixed relative position (presumably the obstacle in
            # the flow — confirm). A patch can belong to one axes only, so a new
            # one is created per panel.
            ax.add_patch(
                plt.Circle(
                    (0.5 * shape[1] + 1, 0.145 * shape[0]),
                    5,
                    fill=True,
                    color="black",
                )
            )
        axes[0, col].set_title(
            f"t={context_length}" + (f" + {index}" if index > 0 else ""), fontsize=8
        )

    if not camera_ready:
        axes[0, 0].set_ylabel("Prediction")
        axes[1, 0].set_ylabel("Groundtruth (Low-rank)")

    if save_path is not None:
        fig.savefig(save_path)
    else:
        plt.show()


In [None]:
# Shared settings for the forecasting, plotting and metric cells that follow.
stride = 1  # step used in the X_ts[start::stride] slices below
# NOTE(review): this rebinds `start`, which an earlier cell set to 2048 with the same
# "ignore transient" rationale — confirm which offset is intended.
start = 100  # ignore transient
num_modes = 64  # passed to plot_predicted_flow as num_modes
context_length = 512
prediction_length = 128
time_indices = [3, 6, 9, 12]  # forecast steps passed to plot_predicted_flow

In [None]:
# Forecast the PCA coefficients with the PatchTST pipeline. forecast() standardizes the
# context and de-standardizes the output by default; sliding_context and
# limit_prediction_length are forwarded to pft_model.predict.
pft_predictions = forecast(
    pft_model,
    X_ts[start::stride],
    context_length=context_length,
    prediction_length=prediction_length,
    sliding_context=True,
    limit_prediction_length=False,
)

# Compare the reconstructed forecast with the low-rank ground truth at the chosen steps.
plot_predicted_flow(
    pft_predictions,
    X_ts[start::stride],
    eigenvectors,
    context_length=context_length,
    prediction_length=prediction_length,
    num_modes=num_modes,
    time_indices=time_indices,
    camera_ready=True,
)


In [None]:
# Forecast the same coefficients with the Chronos pipeline. transpose=True makes
# forecast() hand the context over transposed, and batch_size=100 makes it call
# predict on slices of 100 entries along the first axis of that transposed tensor.
# deterministic, num_samples and limit_prediction_length are forwarded to
# chronos_ft.predict.
chronos_predictions = forecast(
    chronos_ft,
    X_ts[start::stride],
    context_length=context_length,
    prediction_length=prediction_length,
    transpose=True,
    deterministic=True,
    num_samples=1,
    limit_prediction_length=False,
    batch_size=100,
)

# Compare the reconstructed forecast with the low-rank ground truth at the chosen steps.
plot_predicted_flow(
    chronos_predictions,
    X_ts[start::stride],
    eigenvectors,
    context_length=context_length,
    prediction_length=prediction_length,
    num_modes=num_modes,
    time_indices=time_indices,
    camera_ready=True,
)


In [None]:
def compute_rollout_metrics(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    step: int = 64,
    num_windows: int = 5,
    metrics: list[str] | None = None,
    **kwargs,
):
    """Score forecasts from randomly placed windows, binned along the horizon.

    For each of ``num_windows`` windows a start index is drawn with
    ``np.random.randint`` (so seeding NumPy's global RNG fixes the windows).
    ``forecast`` is run from that start, and the forecast is compared with ``data``
    in consecutive bins of ``step`` forecast steps using ``compute_metrics``.

    Args:
        model: passed through to ``forecast``.
        data: series indexed by time along axis 0.
        context_length: context window length passed to ``forecast``.
        prediction_length: forecast length passed to ``forecast``.
        step: number of forecast steps per bin; a shorter final bin is kept when
            ``prediction_length`` is not a multiple of ``step``.
        num_windows: number of random windows to evaluate.
        metrics: metric names passed to ``compute_metrics`` as ``include``;
            defaults to ``["mse", "mae", "smape"]``.
        **kwargs: forwarded to ``forecast``.

    Returns:
        ``(mean_metrics, std_metrics)``: dicts keyed by metric name. Each value has
        one entry per bin: the mean over windows, and the standard deviation over
        windows divided by ``sqrt(num_windows)``.

    Raises:
        ValueError: if ``step`` is not positive or ``data`` is shorter than
            ``context_length + prediction_length``.
    """
    if metrics is None:
        metrics = ["mse", "mae", "smape"]
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    last_start = len(data) - context_length - prediction_length
    if last_start < 0:
        raise ValueError(
            "data is too short for the requested context and prediction lengths: "
            f"{len(data)} < {context_length} + {prediction_length}"
        )

    # Ceiling division: the loop below visits a trailing partial bin, so the
    # accumulator needs a column for it.
    num_bins = -(-prediction_length // step)
    full_metrics = defaultdict(lambda: np.zeros((num_windows, num_bins)))

    for s in tqdm(range(num_windows), desc="Sampling contexts", total=num_windows):
        # randint's upper bound is exclusive; +1 keeps the last valid start reachable.
        start = np.random.randint(0, last_start + 1)
        prediction = forecast(
            model, data, context_length, prediction_length, start=start, **kwargs
        )
        target_begin = start + context_length
        for bin_index, lo in enumerate(range(0, prediction_length, step)):
            # Clip the final bin so prediction and ground truth stay the same length.
            hi = min(lo + step, prediction_length)
            pred = prediction[lo:hi]
            gt = data[target_begin + lo : target_begin + hi]
            submetrics = compute_metrics(pred, gt, include=metrics)
            for k, v in submetrics.items():
                full_metrics[k][s, bin_index] += v

    mean_metrics = {k: v.mean(axis=0) for k, v in full_metrics.items()}
    std_metrics = {
        k: v.std(axis=0) / np.sqrt(num_windows) for k, v in full_metrics.items()
    }
    return mean_metrics, std_metrics


In [None]:
# Binned rollout metrics for the PatchTST pipeline over 10 randomly placed windows.
# sliding_context and limit_prediction_length pass through forecast() to
# pft_model.predict.
pft_mean_metrics, pft_std_metrics = compute_rollout_metrics(
    pft_model,
    X_ts[start::stride],
    context_length=context_length,
    prediction_length=prediction_length,
    step=64,
    num_windows=10,
    sliding_context=True,
    limit_prediction_length=False,
)


In [None]:
# Binned rollout metrics for the Chronos pipeline over 10 randomly placed windows.
# transpose and batch_size are consumed by forecast(); num_samples,
# limit_prediction_length and deterministic pass through to chronos_ft.predict.
# NOTE(review): windows are drawn independently of the PatchTST run above, so the two
# models are not scored on the same windows unless the RNG is reseeded in between.
chronos_mean_metrics, chronos_std_metrics = compute_rollout_metrics(
    chronos_ft,
    X_ts[start::stride],
    context_length=context_length,
    prediction_length=prediction_length,
    step=64,
    num_windows=10,
    num_samples=1,
    limit_prediction_length=False,
    deterministic=True,
    transpose=True,
    batch_size=100,
)


In [None]:
# Compare the binned rollout metrics of both models, one panel per metric.
metrics = ["mse", "mae", "smape"]
# (legend label, colour, per-metric means, per-metric spreads) for each model.
model_curves = [
    ("PatchTST (pft_model)", "red", pft_mean_metrics, pft_std_metrics),
    ("Chronos (chronos_ft)", "blue", chronos_mean_metrics, chronos_std_metrics),
]

fig, axes = plt.subplots(1, len(metrics), figsize=(20, 5))
for ax, m in zip(axes, metrics):
    for label, color, mean_by_metric, spread_by_metric in model_curves:
        mean = mean_by_metric[m]
        spread = spread_by_metric[m]
        # Take the x positions from the data itself instead of repeating the bin
        # size used when the metrics were computed.
        bins = np.arange(len(mean))
        ax.plot(bins, mean, color=color, label=label)
        # Band of +/- the spread from compute_rollout_metrics (std / sqrt(num_windows)).
        ax.fill_between(bins, mean - spread, mean + spread, color=color, alpha=0.2)
    ax.set_title(m)
    ax.set_xlabel("Forecast bin index")
axes[0].legend()
plt.show()
