In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.decomposition import PCA

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

In [None]:
# Root of the shared work area. If $WORK is unset this falls back to "" and the
# paths below resolve under the filesystem root (e.g. "/physics-datasets").
WORK = os.environ.get("WORK", "")
# Reynolds number embedded in the dataset filename.
REYNOLDS_NUMBER = 1200
base_dir = f"{WORK}/physics-datasets"
fpath = (
    f"{base_dir}/von_karman_street/"
    f"vortex_street_velocities_Re_{REYNOLDS_NUMBER}_largefile.npz"
)

In [None]:
# NOTE(review): allow_pickle=True means a pickled file can execute arbitrary
# code on load -- only point `fpath` at trusted data.
_loaded = np.load(fpath, allow_pickle=True)
if hasattr(_loaded, "files"):
    # A genuine .npz archive loads as a lazy NpzFile mapping rather than an
    # ndarray, and np.diff cannot operate on it; unwrap its single array and
    # close the underlying zip file.
    with _loaded:
        if len(_loaded.files) != 1:
            raise ValueError(
                f"Expected exactly one array in {fpath}, found {_loaded.files}"
            )
        vfield = _loaded[_loaded.files[0]]
else:
    vfield = _loaded
del _loaded

# The indexing below needs a 4-D array whose last axis holds (at least) two
# velocity components -- presumably (time, grid axis 1, grid axis 2, component).
assert vfield.ndim == 4 and vfield.shape[-1] >= 2, vfield.shape

# First difference of component 1 along axis 1 plus first difference of
# component 0 along axis 2; for vfield of shape (T, H, W, C) both terms are
# trimmed to a common (T, H-1, W-1) grid.
# NOTE(review): 2-D vorticity is a *difference* of cross-derivatives
# (dv/dx - du/dy). Summing them yields a shear-strain-like field (or the
# divergence, depending on how the components map to axes 1/2). Confirm the
# intended sign before interpreting `vort_field` as vorticity.
vort_field = (
    np.diff(vfield, axis=1)[..., :-1, 1] + np.diff(vfield, axis=2)[:, :-1, :, 0]
)
# One row per time step with the spatial grid flattened, as used for PCA below.
vort_field_flattened = vort_field.reshape(vort_field.shape[0], -1)

In [None]:
def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    title: str | None = None,
    show: bool = True,
    transpose: bool = True,
    **kwargs,
) -> np.ndarray:
    context = data[:, :context_length]
    groundtruth = data[:, context_length : context_length + prediction_length]
    context_tensor = torch.from_numpy(context.T if transpose else context).float()
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    if not transpose:
        pred = pred.T

    total_length = context.shape[1] + prediction_length
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    if show:
        fig = plt.figure(figsize=(15, 4))

        outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
        gs = outer_grid[1].subgridspec(
            3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0
        )
        ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
        ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
        ax_3d.plot(*groundtruth[:3], linestyle="--", color="black", label="Groundtruth")
        ax_3d.plot(*pred.T[:3], color="red", label="Prediction")
        ax_3d.legend(loc="upper right", fontsize=12)
        ax_3d.set_xlabel("$x_1$")
        ax_3d.set_ylabel("$x_2$")
        ax_3d.set_zlabel("$x_3$")
        if title is not None:
            ax_3d.set_title(title)

        axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
        for i, ax in enumerate(axes_1d):
            ax.plot(context_ts, context[i], alpha=0.5, color="black")
            ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
            ax.plot(pred_ts, pred[:, i], color="red")
            ax.set_ylabel(f"$x_{i + 1}$")
            ax.set_aspect("auto")
        axes_1d[-1].set_xlabel("Time")

        plt.show()

    return pred

In [None]:
# Project each flattened snapshot onto its leading principal components.
# random_state makes the fit reproducible: for large inputs PCA's default
# svd_solver="auto" can select the randomized SVD, which is otherwise unseeded.
pca = PCA(n_components=20, random_state=0)
X_ts = pca.fit_transform(vort_field_flattened)  # (T, n_components)
eigenvectors = pca.components_  # (n_components, H*W)

# Show low-rank structure: fraction of variance captured by each component.
fig, ax = plt.subplots()
ax.plot(pca.explained_variance_ratio_, marker="o")
ax.set(
    xlabel="Principal component",
    ylabel="Explained variance ratio",
    title="PCA spectrum of the vorticity field",
)

# Plot trajectory in the plane of the two leading modes.
fig, ax = plt.subplots()
ax.plot(X_ts[:, 0], X_ts[:, 1])
ax.set(xlabel="PC 1", ylabel="PC 2", title="Trajectory in the leading PCA plane");

In [None]:
# Rank-`num_approx` reconstruction of the field from the leading PCA modes.
num_approx = 20
assert num_approx <= eigenvectors.shape[0], (
    f"num_approx={num_approx} exceeds the {eigenvectors.shape[0]} fitted PCA modes"
)
# PCA centers the data before projecting, so the per-pixel mean must be added
# back -- the same formula PCA.inverse_transform uses (without whitening).
vort_recon = (
    X_ts[:, :num_approx] @ eigenvectors[:num_approx, :] + pca.mean_
).reshape(vort_field.shape[0], vort_field.shape[1], vort_field.shape[2])

fig, ax = plt.subplots()
# Symmetric color limits so the diverging "seismic" colormap is centered at 0.
vlim = np.abs(vort_recon[0]).max()
im = ax.imshow(vort_recon[0, :, :], cmap="seismic", vmin=-vlim, vmax=vlim)
fig.colorbar(im, ax=ax)
ax.set_title(f"Rank-{num_approx} PCA reconstruction (t = 0)");

In [None]:
# Checkpoint location and device can be overridden via environment variables so
# the notebook runs on other machines; the defaults reproduce the original setup.
PFT_CHECKPOINT = os.environ.get(
    "PFT_CHECKPOINT",
    "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_emb_w_poly-0/checkpoint-final",
)
PFT_DEVICE = os.environ.get("PFT_DEVICE", "cuda:2")
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=PFT_CHECKPOINT,
    device_map=PFT_DEVICE,
)

In [None]:
# Forecast the standardized PCA coefficients with the pretrained model.
start = 1000  # leading snapshots to drop (presumably the initial transient -- TODO confirm)
stride = 2  # temporal subsampling factor
CONTEXT_LENGTH = 512
PREDICTION_LENGTH = 128

subsampled_pca_coeffs = X_ts[start::stride, :]  # (time, modes)
# NOTE(review): safe_standardize presumably computes its statistics over the
# whole subsampled series, including the forecast window used as groundtruth,
# so the context would be normalized with information from the future.
# Consider computing the statistics from the first CONTEXT_LENGTH steps only.
stand_subsampled_pca_coeffs = safe_standardize(subsampled_pca_coeffs, axis=0)
print(stand_subsampled_pca_coeffs.shape)
predictions = plot_model_prediction(
    pft_model,
    stand_subsampled_pca_coeffs.T,  # (modes, time), as plot_model_prediction slices it
    CONTEXT_LENGTH,
    PREDICTION_LENGTH,
    limit_prediction_length=False,
    title="Von Karman Vortex Street PCA modes",
)
