In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy import stats
from sklearn.linear_model import LinearRegression
from tqdm import tqdm

from panda.patchtst.pipeline import PatchTSTPipeline
from panda.utils import (
    apply_custom_style,
    get_system_filepaths,
    load_trajectory_from_arrow,
)

In [None]:
# Apply the matplotlib style from the plotting config file.
# The path is relative, so it resolves against the notebook's working directory.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Name of the training run to analyse. It is interpolated into the checkpoint
# path, the figure output directory and the eval-results directory further down.
# Alternative run kept for quick switching:
# run_name = "mlm_stand_chattn_noembed-0"
run_name = "panda_mlm_nh12_dmodel768_mixedp-2"

In [None]:
# Load the pipeline in "pretrain" mode from this run's final checkpoint.
# NOTE(review): the checkpoint root and the "cuda:6" device are hard-coded for
# one particular machine — adjust both before running anywhere else.
model_pipeline = PatchTSTPipeline.from_pretrained(
    mode="pretrain",
    pretrain_path=f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    device_map="cuda:6",
)

In [None]:
# Display the first encoder layer as a quick look at the loaded architecture.
model_pipeline.model.model.encoder.layers[0]

In [None]:
def get_model_completion(
    pipeline,
    context: np.ndarray,
    return_normalized_completions: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """Run the pipeline's completion model on one context window.

    Args:
        pipeline: object exposing ``device`` and ``model.generate_completions``.
        context: 2D array; it is transposed and given a leading batch axis of
            size 1 before being passed to the model, so ``context.shape[0]``
            ends up as the last axis of the model input
            (assumes layout (n_channels, n_timesteps) — TODO confirm).
        return_normalized_completions: if False (default), completions and the
            processed context are mapped back with ``value * scale + loc``
            using the ``loc`` / ``scale`` returned by the model.
        verbose: print intermediate and final array shapes.
        **kwargs: forwarded unchanged to ``generate_completions``.

    Returns:
        Tuple ``(completions, processed_context, timestep_mask)`` of numpy
        arrays. The first two are transposed back so that their first axis has
        length ``context.shape[0]``; ``timestep_mask`` is the model's patch
        mask with every entry repeated ``patch_size`` times along its last axis.

    Raises:
        ValueError: if the model output lacks ``mask`` or
            ``patched_past_values``, or lacks ``loc`` / ``scale`` when
            denormalization is requested.
    """
    # Prepare input tensor: transpose, cast to float32, add a batch axis
    context_tensor = torch.from_numpy(context.T).float().to(pipeline.device)[None, ...]

    # Generate completions. This is pure inference, so autograd is switched off
    # to avoid building (and holding memory for) a graph that is never used.
    with torch.no_grad():
        completions_output = pipeline.model.generate_completions(
            context_tensor,
            past_observed_mask=None,
            **kwargs,
        )

    if verbose:
        print(f"context_tensor shape: {context_tensor.shape}")
        print(f"completions output shape: {completions_output.completions.shape}")

    # The last axis of the raw completions holds the values of one patch
    patch_size = completions_output.completions.shape[-1]

    # Check for required outputs
    if completions_output.mask is None or completions_output.patched_past_values is None:
        raise ValueError("Required completion outputs are None")

    batch_size = context_tensor.shape[0]
    n_channels = context_tensor.shape[-1]

    def process_tensor(tensor, reshape=True):
        """Move a tensor to numpy, optionally flattening its patches per channel."""
        if reshape:
            # (batch, channels, ...) -> (batch, channels, flat) -> (batch, flat, channels)
            return (
                tensor.reshape(batch_size, n_channels, -1)
                .detach()
                .cpu()
                .numpy()
                .transpose(0, 2, 1)
            )
        return tensor.detach().cpu().numpy()

    completions = process_tensor(completions_output.completions)
    processed_context = process_tensor(completions_output.patched_past_values)
    patch_mask = process_tensor(completions_output.mask, reshape=False)
    # Expand the per-patch mask to one entry per timestep
    timestep_mask = np.repeat(patch_mask, repeats=patch_size, axis=2)

    # Denormalize if needed
    if not return_normalized_completions:
        if completions_output.loc is None or completions_output.scale is None:
            raise ValueError("Loc or scale is None")
        loc = completions_output.loc.detach().cpu().numpy()
        scale = completions_output.scale.detach().cpu().numpy()
        completions = completions * scale + loc
        processed_context = processed_context * scale + loc

    # Drop the batch axis and put the channel axis first for plotting
    processed_context = processed_context.squeeze(0).transpose(1, 0)
    completions = completions.squeeze(0).transpose(1, 0)
    timestep_mask = timestep_mask.squeeze(0)

    if verbose:
        print(f"processed context shape: {processed_context.shape}")
        print(f"completions shape: {completions.shape}")
        print(f"timestep mask shape: {timestep_mask.shape}")

    return completions, processed_context, timestep_mask

In [None]:
def plot_model_completion(
    completions: np.ndarray,
    processed_context: np.ndarray,
    timestep_mask: np.ndarray,
    figsize: tuple[int, int] = (6, 8),
    save_path: str | None = None,
) -> None:
    """Plot completions against the context in one 3D panel and three 1D panels.

    The figure has four stacked rows: a 3D trajectory built from rows 0, 1 and 2
    of the inputs, followed by one line plot per row (0, 1, 2). Context is drawn
    in black and completions in red on mask-selected segments.

    Args:
        completions: array indexed as ``[row, timestep]``; rows 0-2 are plotted.
        processed_context: array indexed the same way, with the same number of
            timesteps as ``completions``.
        timestep_mask: array indexed as ``[row, timestep]``; it is cast to bool.
        figsize: size passed to ``plt.figure``.
        save_path: if not None, the figure is saved there before being shown.

    NOTE(review): the 3D panel draws red over runs where the mask is True,
    whereas the 1D panels draw red over runs where the mask is False (and
    ``fill_between`` shades where the mask is False). Confirm which polarity
    is intended.
    """
    n_timesteps = processed_context.shape[1]
    # n_timesteps comes from processed_context, so this effectively checks only
    # that completions has the same number of timesteps.
    assert n_timesteps == completions.shape[1] == processed_context.shape[1]

    # Create figure with grid layout (3D panel is three times taller than each 1D panel)
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(4, 1, height_ratios=[3, 1, 1, 1])

    # Create axes
    ax_3d = fig.add_subplot(gs[0], projection="3d")
    axes_2d = [fig.add_subplot(gs[i]) for i in range(1, 4)]

    # Plot the full context trajectory in 3D as a translucent black backdrop
    ax_3d.plot(
        processed_context[0, :],
        processed_context[1, :],
        processed_context[2, :],
        alpha=0.5,
        color="black",
        linewidth=2,
    )
    # ax_3d.set_title("Completions", y=0.94, fontweight="bold")
    ax_3d.axis("off")
    ax_3d.grid(False)

    # Plot masked segments in 3D
    mask_bool = timestep_mask.astype(bool)
    for dim in range(3):
        # Find contiguous True runs in this row's mask. Padding with False on
        # both sides makes every run produce a (start, end-exclusive) pair.
        change_indices = np.where(
            np.diff(np.concatenate(([False], mask_bool[dim], [False])))
        )[0]

        # Plot each contiguous block. The run boundaries come from row `dim`,
        # but all three coordinates are sliced with them, so identical masks
        # across rows cause the same segment to be drawn once per row.
        for i in range(0, len(change_indices), 2):
            if i + 1 < len(change_indices):
                start_idx, end_idx = change_indices[i], change_indices[i + 1]
                # Plot the completion over this run in red, on top (zorder=10)
                ax_3d.plot(
                    completions[0, start_idx:end_idx],
                    completions[1, start_idx:end_idx],
                    completions[2, start_idx:end_idx],
                    alpha=1,
                    color="red",
                    linewidth=2,
                    zorder=10,
                )
                # Plot the context over the same run in opaque black
                ax_3d.plot(
                    processed_context[0, start_idx:end_idx],
                    processed_context[1, start_idx:end_idx],
                    processed_context[2, start_idx:end_idx],
                    alpha=1,
                    color="black",
                    linewidth=2,
                )

    # Plot univariate series for each dimension
    for dim, ax in enumerate(axes_2d):
        mask_bool_dim = timestep_mask[dim, :].astype(bool)

        # Plot context
        ax.plot(processed_context[dim, :], alpha=0.5, color="black", linewidth=2)

        # Find segments where mask changes. An entry i in change_indices means
        # mask[i] != mask[i + 1]. Prepending 0 when the mask starts False makes
        # the even-numbered segments below the ones where the mask is False.
        diffs = np.diff(mask_bool_dim.astype(int))
        change_indices = np.where(diffs)[0]
        if not mask_bool_dim[0]:
            change_indices = np.concatenate(([0], change_indices))
        segment_indices = np.concatenate((change_indices, [n_timesteps]))

        # Plot completions for the selected segments (every other one, starting
        # with the first). See the polarity note in the docstring.
        segments = zip(segment_indices[:-1], segment_indices[1:])
        masked_segments = [idx for i, idx in enumerate(segments) if (i + 1) % 2 == 1]
        for start, end in masked_segments:
            # Extend by one sample, unless the segment already reaches the end
            # of the series, so the slice includes index `end`
            if end < n_timesteps - 1:
                end += 1
            ax.plot(
                range(start, end),
                completions[dim, start:end],
                alpha=1,
                color="red",
                linewidth=2,
                zorder=10,
            )
            ax.plot(
                range(start, end),
                processed_context[dim, start:end],
                alpha=1,
                color="black",
                linewidth=2,
            )

        # Shade the gap between context and completions where the mask is False
        ax.fill_between(
            range(n_timesteps),
            processed_context[dim, :],
            completions[dim, :],
            where=~mask_bool_dim,
            alpha=0.2,
        )
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
# Dataset split and the sub-split (a subdirectory of it) whose systems are used below.
# NOTE(review): the data root is an absolute, machine-specific path — adjust
# before running anywhere else.
split = "final_base40"
test_data_dir = f"/stor/work/AMDG_Gilpin_Summer2024/data/improved/{split}"
subsplit = "test_zeroshot"

In [None]:
# Every directory directly under the sub-split is treated as one system;
# plain files at that level are ignored.
test_system_subdirs = [
    entry
    for entry in os.listdir(os.path.join(test_data_dir, subsplit))
    if os.path.isdir(os.path.join(test_data_dir, subsplit, entry))
]
print(len(test_system_subdirs))

In [None]:
n_systems_to_plot = 4
selection_seed = 0  # fixed so the same systems are drawn on every fresh run

# Draw n_systems_to_plot distinct systems at random. The candidates are sorted
# first because os.listdir returns entries in arbitrary order, which would
# otherwise make the seeded draw differ between machines.
selection_rng = np.random.default_rng(selection_seed)
selected_pair_names = selection_rng.choice(
    sorted(test_system_subdirs), n_systems_to_plot, replace=False
)
print(selected_pair_names)

In [None]:
# For every system, record (number of entries in its directory) - 1.
num_sample_idxs = {
    system_name: len(os.listdir(os.path.join(test_data_dir, subsplit, system_name))) - 1
    for system_name in test_system_subdirs
}
print(num_sample_idxs)

In [None]:
chosen_start_time = 512

# Each value is a (sample_idx, start_time, subsample_interval) tuple.
# All selected systems share the same settings: sample 0, no subsampling.
chosen_completions_settings = dict.fromkeys(
    selected_pair_names, (0, chosen_start_time, 1)
)

# Alternative 1: a single hand-picked system.
# chosen_completions_settings = {
#     "LorenzStenflo_VallisElNino": (0, chosen_start_time, 1),
# }

# Alternative 2: use the index stored in num_sample_idxs for every chosen system.
# chosen_completions_settings = {
#     pair_name: (num_sample_idxs[pair_name], chosen_start_time, 1)
#     for pair_name in chosen_completions_settings
# }

In [None]:
# Display the per-system settings: (sample_idx, start_time, subsample_interval).
chosen_completions_settings

In [None]:
# Number of systems that will be run through the model
print(len(chosen_completions_settings))

In [None]:
# Requested number of timesteps per context window. The window is shorter when
# the trajectory ends before start_time + context_length.
context_length = 4096

# Per-system results, keyed by system name
completions_dict = {}

show_plot = True
save_plot = False

for dyst_name, settings in tqdm(chosen_completions_settings.items()):
    print(dyst_name)
    sample_idx, start_time, subsample_interval = settings

    syspaths = get_system_filepaths(dyst_name, test_data_dir, subsplit)
    trajectory, _ = load_trajectory_from_arrow(syspaths[sample_idx])
    # Axis 1 is the one sliced by time below, so subsample along it
    trajectory = trajectory[:, ::subsample_interval]

    # Clamp to the trajectory length so the file name reports the window that
    # is actually sliced, not the requested one.
    end_time = min(start_time + context_length, trajectory.shape[1])
    if end_time <= start_time:
        raise ValueError(
            f"start_time={start_time} is beyond the end of the {dyst_name} "
            f"trajectory (length {trajectory.shape[1]})"
        )

    save_path = os.path.join(
        "../figures",
        run_name,
        split,
        subsplit,
        f"{dyst_name}_sample{sample_idx}_context{start_time}-{end_time}.pdf",
    )
    # Only touch the filesystem when a figure will really be written
    if save_plot:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

    completions, processed_context, timestep_mask = get_model_completion(
        model_pipeline,
        trajectory[:, start_time:end_time],  # context
        return_normalized_completions=False,
        verbose=False,
    )
    completions_dict[dyst_name] = {
        "completions": completions,
        "processed_context": processed_context,
        "timestep_mask": timestep_mask,
    }
    if show_plot:
        plot_model_completion(
            completions,
            processed_context,
            timestep_mask,
            figsize=(24, 8),
            save_path=save_path if save_plot else None,
        )

In [None]:
# Shape of the completions array left over from the last iteration of the loop above.
completions.shape

## Plot computed GP Dims

See `scripts/compute_gpdims.py` for the script that computes these dimensions.

In [None]:
import pickle

# Root of the evaluation outputs; the WORK environment variable must be set.
WORK_DIR = os.environ["WORK"]

metrics_save_dir = f"{WORK_DIR}/eval_results/patchtst/{run_name}/test_zeroshot"
# Sorted so the result files are accumulated in a reproducible order
# (os.listdir returns entries in arbitrary order).
gpdims_fnames = sorted(
    f for f in os.listdir(metrics_save_dir) if f.endswith(".pkl") and "gpdim" in f
)

# system name -> list with one value per result file
gpdims_completions_all_runs = {}
gpdims_groundtruth_all_runs = {}
for gpdims_fname in gpdims_fnames:
    # NOTE: pickle.load can execute arbitrary code — only load result files
    # that you produced yourself.
    with open(os.path.join(metrics_save_dir, gpdims_fname), "rb") as f:
        gp_dims = pickle.load(f)
    print(f"number of systems in {gpdims_fname}: {len(gp_dims)}")

    # Preview whichever system really comes first in this file, rather than a
    # hard-coded name that may be missing from it.
    first_sys_name = next(iter(gp_dims), None)
    if first_sys_name is not None:
        print(
            f"gpdim of completions of first system ({first_sys_name}) in {gpdims_fname}: {gp_dims[first_sys_name]['completions']}"
        )
        print(
            f"gpdim of groundtruth of first system ({first_sys_name}) in {gpdims_fname}: {gp_dims[first_sys_name]['groundtruth']}"
        )

    for sys_name, gp_dim_val in gp_dims.items():
        gpdims_completions_all_runs.setdefault(sys_name, []).append(
            gp_dim_val["completions"]
        )
        gpdims_groundtruth_all_runs.setdefault(sys_name, []).append(
            gp_dim_val["groundtruth"]
        )

In [None]:
# Number of distinct systems collected across all result files
len(gpdims_completions_all_runs)

In [None]:
# Peek at one system: its name and how many values were accumulated for it
test_system_name, test_gpdim_vals = next(iter(gpdims_completions_all_runs.items()))
print(test_system_name)
print(len(test_gpdim_vals))

In [None]:
# Average each system's values across result files. The means go into new
# dicts instead of overwriting the per-run lists, so this cell and the cells
# that inspect those lists can be re-run in any order.
gpdims_completions_mean = {
    sys_name: np.mean(run_vals)
    for sys_name, run_vals in gpdims_completions_all_runs.items()
}
gpdims_groundtruth_mean = {
    sys_name: np.mean(run_vals)
    for sys_name, run_vals in gpdims_groundtruth_all_runs.items()
}

# Pair the two quantities by system name so x[i] and y[i] always refer to the
# same system, rather than relying on two dicts sharing iteration order.
paired_sys_names = [
    sys_name for sys_name in gpdims_groundtruth_mean if sys_name in gpdims_completions_mean
]
groundtruth_gp_dims = [gpdims_groundtruth_mean[sys_name] for sys_name in paired_sys_names]
completions_gp_dims = [gpdims_completions_mean[sys_name] for sys_name in paired_sys_names]

print(len(groundtruth_gp_dims))

In [None]:
print(len(groundtruth_gp_dims))
print(len(completions_gp_dims))

# Convert to numpy arrays for easier manipulation
x = np.array(groundtruth_gp_dims)
y = np.array(completions_gp_dims)

# A point is an outlier when either of its coordinates has |z-score| above this
# threshold; z-scores are computed separately for x and for y.
outlier_z_threshold = 3
z_scores = np.abs(stats.zscore(np.vstack([x, y]).T))
outliers = np.any(z_scores > outlier_z_threshold, axis=1)
num_outliers = np.sum(outliers)
print(f"Number of outliers: {num_outliers}")

# Filter out outliers
x_clean = x[~outliers]
y_clean = y[~outliers]

fit_intercept = True

model = LinearRegression(fit_intercept=fit_intercept)
model.fit(x_clean.reshape(-1, 1), y_clean)
slope = model.coef_[0]
intercept = model.intercept_ if fit_intercept else 0
# score() already returns R². Taking its square root and squaring it again is
# a no-op at best and NaN when R² is negative (possible without an intercept).
r_squared = model.score(x_clean.reshape(-1, 1), y_clean)
line_x = np.linspace(x_clean.min(), x_clean.max(), 100)
line_y = slope * line_x + intercept

# Common range for both axes, so the plot is square and y = x is its diagonal
min_val = min(x_clean.min(), y_clean.min())
max_val = max(x_clean.max(), y_clean.max())

fig, ax = plt.subplots(figsize=(4, 4))

# Non-outlier points
ax.scatter(x_clean, y_clean, color="black", s=5, alpha=0.1)

# Regression line, labelled with its equation
if fit_intercept:
    sign = "+" if intercept >= 0 else "-"  # a zero intercept reads "+ 0.00"
    regression_eq = (
        f"y = {slope:.2f}x {sign} {abs(intercept):.2f} (R² = {r_squared:.2f})"
    )
else:
    regression_eq = f"y = {slope:.2f}x (R² = {r_squared:.2f})"
ax.plot(line_x, line_y, "r-", alpha=0.9, zorder=10, label=regression_eq)

# Reference line with slope 1
diagonal = np.linspace(min_val, max_val, 100)
ax.plot(diagonal, diagonal, "r--", alpha=0.9, zorder=9, label="y = x")

ax.set_xlim(min_val, max_val)
ax.set_ylim(min_val, max_val)

ax.set_xlabel("Ground Truth", fontweight="bold")
ax.set_ylabel("Completions", fontweight="bold")
ax.set_title("Estimated Correlation Dimension", fontweight="bold")
ax.legend(loc="best")
fig.tight_layout()

# Make sure the output directory exists even if no earlier cell created it
figures_dir = "../figures"
os.makedirs(figures_dir, exist_ok=True)
fig.savefig(os.path.join(figures_dir, "gpdims.pdf"), bbox_inches="tight")
plt.show()