In [None]:
%load_ext autoreload
%autoreload 2

# basic imports
import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import torch.nn as nn

if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using GPU:", device)
else:
    device = torch.device("cpu")
    print("No GPU -> using CPU:", device)

import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


import sys


sys.path.append("../../PFN4CASHPlus")
sys.path.append("../Priors/")

from data.synth_ablation import get_batch_func as synth_get_batch_func

In [None]:
import lcpfn

# Checkpoint files of the two models compared in this notebook.
neg_ckpt = "./trained_models/Neg_03-08-07-02-36_full_model.pt"
neg_limited_ckpt = "./trained_models/Neg_limited_12_03-03-10-03-33_full_model.pt"

# NOTE(review): LCPFN(model_name=...) presumably loads the checkpoint at the
# given path -- confirm against the lcpfn module.
Neg = lcpfn.LCPFN(model_name=neg_ckpt)
Neg_limited = lcpfn.LCPFN(model_name=neg_limited_ckpt)

# One row per model: (legend label, model wrapper, plot colour).
# `models` and `colors` share the same keys so the plotting loop can look
# up the colour by label.
_model_rows = [
    ("Neg", Neg, "blue"),
    ("Neg(limited)", Neg_limited, "red"),
]
models = {label: wrapper for label, wrapper, _ in _model_rows}
colors = {label: colour for label, _, colour in _model_rows}

In [None]:
# Draw one synthetic sample from the prior.
# NOTE(review): the positional arguments (1, 200) are presumably batch size and
# sequence length -- confirm against data.synth_ablation.get_batch_func.
xs, ys, y_raw = synth_get_batch_func(1, 200, num_outputs=1, to_torch=True)

# NumPy copies of the first batch element, used only for plotting.
data = ys[:, 0, 0].detach().cpu().numpy()
raw_data = y_raw[:, 0, 0].detach().cpu().numpy()
x = xs[:, 0].detach().cpu().numpy()

# Number of leading points handed to the models as training context;
# everything from this index on is predicted.
cutoff = 1
# Number of leading points of the sample that are used for plotting/prediction.
seq_len = 200

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Large base font for every text element of the figure.
plt.rcParams["font.size"] = 24

# One row, two panels: a wide main panel (ratio 3) for the curves and a
# narrow panel (ratio 0.8) next to it for the marginal distribution.
fig, axs = plt.subplots(
    1,
    2,
    figsize=(8, 4.5),
    gridspec_kw=dict(width_ratios=[3, 0.8]),
)

ax_main = axs[0]
ax_hist_y = axs[1]

# --- Main Plot: Line Plot ---
ax_main.plot(x[:seq_len], data[:seq_len], "black", label="Target Data")

# The context/query split is identical for every model, so slice and move the
# tensors to the device once instead of repeating it for every call.
context_x = xs[:seq_len][:cutoff, 0].to(device)
context_y = ys[:seq_len][:cutoff, 0].to(device)
query_x = xs[:seq_len][cutoff:, 0].to(device)
query_x_np = x[:seq_len][cutoff:, 0]

# Quantiles requested from each model: lower bound, median, upper bound.
# The shaded band spans the first and last entry, so its nominal coverage is
# derived from them (0.05..0.95 -> 90%); use 0.025/0.975 for a 95% band.
band_qs = [0.05, 0.5, 0.95]
band_coverage = round(100 * (band_qs[-1] - band_qs[0]))

for model_name, model_obj in models.items():
    # Mean extrapolation beyond the cutoff.
    mean_pred = model_obj.predict_mean(
        x_train=context_x,
        y_train=context_y,
        x_test=query_x,
    )
    ax_main.plot(
        query_x_np,
        mean_pred[:, 0, 0].detach().cpu().numpy(),
        colors[model_name],
        label=f"Extrapolation by {model_name}",
    )

    # Uncertainty band between the lowest and highest requested quantile.
    # NOTE(review): assumes axis 1 of the result follows the order of `qs`
    # (index 0 -> band_qs[0], index 2 -> band_qs[2]) -- confirm in lcpfn.
    quantile_pred = model_obj.predict_quantiles(
        x_train=context_x,
        y_train=context_y,
        x_test=query_x,
        qs=band_qs,
    )
    ax_main.fill_between(
        query_x_np.flatten(),
        quantile_pred[:, 0, 0].detach().cpu().numpy(),
        quantile_pred[:, 2, 0].detach().cpu().numpy(),
        color=colors[model_name],
        alpha=0.2,
        label=f"CI of {band_coverage}%",
    )

# Dashed vertical marker at the x position of index `cutoff`.
ax_main.axvline(x=x[cutoff], color="k", ls="--", lw=0.8, label="Cutoff")

# Legend is anchored to the right of the main axes
# (alternative placement inside the axes: loc="lower right").
ax_main.legend(bbox_to_anchor=(1.3, 0.5), loc="center left")

# Y axis: label and fixed range, the same range as the marginal panel.
ax_main.set(ylabel="Values", ylim=(0.0, 1))

# --- Marginal panel: density of the raw values along the y axis ---
kde_style = dict(color="black", fill=True, alpha=0.7)
sns.kdeplot(
    y=raw_data,
    ax=ax_hist_y,
    clip=[raw_data.min(), raw_data.max()],  # no density outside observed range
    **kde_style,
)

# Put the x-axis ticks and label at the top of the marginal panel.
ax_hist_y.xaxis.tick_top()
ax_hist_y.xaxis.set_label_position("top")
ax_hist_y.set(xlabel="Frequency", ylim=(0.0, 1))

# Remove the horizontal gap so both panels touch.
fig.subplots_adjust(wspace=0)

# Strip the frame lines and the y axis that the marginal panel does not need.
for side in ("right", "bottom"):
    ax_hist_y.spines[side].set_visible(False)
ax_hist_y.yaxis.set_visible(False)

# Align the marginal panel vertically with the main panel: keep its own
# horizontal placement, take vertical offset and height from the main axes.
main_box = ax_main.get_position()
hist_box = ax_hist_y.get_position()
ax_hist_y.set_position(
    [hist_box.x0, main_box.y0, hist_box.width, main_box.height]
)
# Write the figure to disk. savefig does not create missing folders, so make
# sure the output directory exists first (no-op when it is already there).
figure_path = "./synth/results/figures/Neg_limited_vs_Neg_extrapolation.pdf"
os.makedirs(os.path.dirname(figure_path), exist_ok=True)

# Save through the explicit figure handle rather than pyplot's "current
# figure" so the right figure is written even if another one was created.
fig.savefig(
    figure_path,
    bbox_inches="tight",
    dpi=600,
)