# Set up

In [None]:
import torch
import gpytorch
import pandas as pd
import numpy as np
import tqdm as tqdm
from linear_operator import settings

import pyro
import math
import pickle
import time
from joblib import Parallel, delayed

from sklearn.preprocessing import StandardScaler

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
import arviz as az
import seaborn as sns

import os

import GP_functions.Loss_function as Loss_function
import GP_functions.bound as bound
import GP_functions.Estimation as Estimation
import GP_functions.Training as Training
import GP_functions.Prediction as Prediction
import GP_functions.GP_models as GP_models
import GP_functions.Tools as Tools
import GP_functions.FeatureE as FeatureE

# Data

In [None]:
# Read the four header-less, comma-separated CSV files as NumPy arrays,
# in the same order as before: X train/test, then the "std_21" Y train/test.
X_train, X_test, Y_train_21, Y_test_21 = (
    pd.read_csv(csv_path, header=None, delimiter=',').values
    for csv_path in (
        'Data/X_train.csv',
        'Data/X_test.csv',
        'Data/Y_train_std_21.csv',
        'Data/Y_test_std_21.csv',
    )
)

# float32 tensor copies of the arrays above; these are what the later cells
# hand to the GP training and MCMC helpers.
train_x, test_x, train_y_21, test_y_21 = (
    torch.tensor(array, dtype=torch.float32)
    for array in (X_train, X_test, Y_train_21, Y_test_21)
)


# Device

In [None]:
# Compute device used by the training and MCMC cells. Prefer the GPU, but fall
# back to the CPU so the notebook still runs on hosts without CUDA.
Device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training

In [None]:
# Call the project's minibatch trainer and keep the (models, likelihoods) pair
# it returns. Keyword values are unchanged; they are only grouped for reading.
MVGP_models, MVGP_likelihoods = Training.train_MultitaskVGP_minibatch(
    # --- data and target device ---
    train_x=train_x.to(Device),
    train_y=train_y_21.to(Device),
    device=Device,
    # --- kernel / latent / inducing settings ---
    covar_type='RBF',
    num_latents=32,
    num_inducing=400,
    # --- learning rates, iteration and batch settings ---
    lr_hyper=0.01,
    lr_variational=0.1,
    num_iterations=10000,
    batch_size=512,
    patience=10,
    # --- evaluation settings ---
    eval_every=100,
    eval_batch_size=1024,
)

# MCMC

In [None]:
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from statsmodels.graphics.tsaplots import plot_acf

# Pyro diagnostics
from pyro.ops.stats import gelman_rubin, split_gelman_rubin, effective_sample_size

def split_chain(chain_tensor: torch.Tensor):
    """Split one chain of samples into two equal-length "pseudo-chains".

    Args:
        chain_tensor: samples indexed along dimension 0 (length N).

    Returns:
        A ``(first_half, second_half)`` tuple, each of length ``N // 2``.
        When N is odd the final sample is dropped so both halves match.
    """
    half_len = chain_tensor.shape[0] // 2
    first_half = chain_tensor[:half_len]
    second_half = chain_tensor[half_len:half_len * 2]
    return first_half, second_half


def visualize_posterior_1d_params(
    single_chain_samples: dict,
    *,
    true_params_tensor=None,
    bins=15,
    acf_lags=40,
    clip_percentiles=(0.5, 99.5),
    xlim=None,
):
    """Print convergence diagnostics and draw diagnostic plots per parameter.

    Each parameter's single chain is split in half with ``split_chain`` and the
    two halves are treated as two chains. For every parameter the function
    prints R-hat, split R-hat and ESS, then shows a trace plot, a histogram
    with KDE and 2.5/50/97.5% quantile lines, and an ACF plot of the first half.

    Args:
        single_chain_samples: mapping ``name -> 1-D Tensor[N]`` of samples.
        true_params_tensor: optional vector of true parameter values (a torch
            tensor or anything ``torch.as_tensor`` accepts, e.g. a NumPy row).
            For keys of the form ``theta_<i>`` / ``param_<i>``, element ``i`` is
            marked on the trace and histogram plots.
        bins: number of histogram bins.
        acf_lags: lags passed to ``plot_acf``; an integer value is capped at
            the pseudo-chain length minus one.
        clip_percentiles: ``(low, high)`` percentiles of the samples used as the
            default x-range of the histogram panel.
        xlim: optional explicit ``(xmin, xmax)`` for the histogram panel;
            overrides the percentile-based range.

    Raises:
        ValueError: if any value in ``single_chain_samples`` is not 1-D.
    """
    # Display-name mapping: theta_0..theta_9 / param_0..param_9 -> Ca_1..Cb_5.
    param_labels = ["Ca_1", "Cb_1", "Ca_2", "Cb_2", "Ca_3", "Cb_3", "Ca_4", "Cb_4", "Ca_5", "Cb_5"]
    name_map = {
        f"{prefix}_{i}": label
        for prefix in ("theta", "param")
        for i, label in enumerate(param_labels)
    }

    def display_name(p: str) -> str:
        return name_map.get(p, p)

    def parse_index(p: str):
        # Only names shaped like "theta_7" / "param_7" yield an index.
        try:
            head, tail = p.split("_", 1)
            if head in ("theta", "param"):
                return int(tail)
        except (ValueError, AttributeError):
            # No underscore, non-integer suffix, or a non-string key.
            return None
        return None

    # Normalise the true-parameter vector to a flat float32 CPU tensor.
    # torch.as_tensor handles both existing tensors and array-likes, so
    # true_vec is always bound (previously a Tensor input left it undefined).
    true_vec = None
    if true_params_tensor is not None:
        true_vec = (
            torch.as_tensor(true_params_tensor, dtype=torch.float32)
            .detach()
            .flatten()
            .cpu()
        )

    # Build param -> Tensor[2, n_half] by splitting every chain in two.
    mcmc_samples = {}
    for param, samples in single_chain_samples.items():
        if samples.ndim != 1:
            raise ValueError(f"{param} should be 1D Tensor[N], got shape {tuple(samples.shape)}")
        chain_a, chain_b = split_chain(samples)
        mcmc_samples[param] = torch.stack([chain_a, chain_b], dim=0)  # [2, n_half]

    # Diagnostics and plots.
    for param, samples_chains in mcmc_samples.items():
        disp = display_name(param)

        idx = parse_index(param)
        true_val = None
        if (true_vec is not None) and (idx is not None) and (0 <= idx < true_vec.numel()):
            true_val = float(true_vec[idx].item())

        rhat = gelman_rubin(samples_chains, chain_dim=0, sample_dim=1)
        split_rhat = split_gelman_rubin(samples_chains, chain_dim=0, sample_dim=1)
        ess = effective_sample_size(samples_chains, chain_dim=0, sample_dim=1)

        if true_val is None:
            print(f"{disp}: R-hat = {rhat:.3f}, split R-hat = {split_rhat:.3f}, ESS = {ess:.1f}")
        else:
            print(f"{disp}: R-hat = {rhat:.3f}, split R-hat = {split_rhat:.3f}, ESS = {ess:.1f}, true = {true_val:.6g}")

        # --- Trace + Histogram/KDE ---
        plt.figure(figsize=(12, 4))

        # Trace
        plt.subplot(1, 2, 1)
        for i in range(2):
            plt.plot(samples_chains[i].detach().cpu().numpy(), marker='o', label=f"Chain {i+1}", alpha=0.7)

        if true_val is not None:
            # Horizontal line at the true value.
            plt.axhline(true_val, linestyle="--", linewidth=2, label="True value", color="green")

        plt.title(f"Trace Plot for {disp}")
        plt.xlabel("Sample Index")
        plt.ylabel(disp)
        plt.legend()

        # Histogram + KDE + Quantiles
        plt.subplot(1, 2, 2)
        all_samps = samples_chains.reshape(-1).detach().cpu().numpy()

        p_lo, p_hi = clip_percentiles
        xmin, xmax = np.percentile(all_samps, [p_lo, p_hi])

        # Widen the x-range so the true value is never clipped out of view.
        if true_val is not None:
            xmin = min(xmin, true_val)
            xmax = max(xmax, true_val)
            # Small margin on both sides.
            pad = 0.02 * (xmax - xmin + 1e-12)
            xmin, xmax = xmin - pad, xmax + pad

        plt.hist(all_samps, bins=bins, density=True, alpha=0.7, color='gray')

        # KDE can fail with too few samples or zero variance, so guard it.
        if np.std(all_samps) > 0 and len(all_samps) > 5:
            kde = gaussian_kde(all_samps)
            x_grid = np.linspace(xmin, xmax, 200)
            plt.plot(x_grid, kde(x_grid), linewidth=2)

        # np.quantile (linear interpolation by default) works for any float
        # dtype; torch.quantile requires q to match the sample dtype.
        for q in np.quantile(all_samps, [0.025, 0.5, 0.975]):
            plt.axvline(float(q), color='red', linestyle='--', linewidth=2)

        if true_val is not None:
            # Vertical line at the true value.
            plt.axvline(true_val, linewidth=2, label="True value", color="green")

        if xlim is not None:
            plt.xlim(*xlim)
        else:
            plt.xlim(xmin, xmax)

        plt.title(f"Histogram + 2.5/50/97.5% Quantiles ({disp})")
        plt.xlabel("Value")
        plt.ylabel("Density")
        if true_val is not None:
            # The true-value line is the only labelled artist on this panel;
            # calling legend() without one only produces a warning.
            plt.legend()
        plt.tight_layout()
        plt.show()

        # --- ACF (first pseudo-chain only) ---
        first_chain = samples_chains[0].detach().cpu().numpy()
        lags = acf_lags
        if isinstance(acf_lags, int):
            # Never request more lags than the pseudo-chain can provide.
            lags = min(acf_lags, max(len(first_chain) - 1, 0))

        # Pass the axes explicitly: without `ax`, plot_acf opens its own figure
        # and the one created here would be shown empty.
        fig, ax = plt.subplots(figsize=(6, 4))
        plot_acf(first_chain, ax=ax, lags=lags)
        ax.set_title(f"ACF for {disp} (Chain 1)")
        ax.set_xlabel("Lag")
        ax.set_ylabel("Autocorrelation")
        fig.tight_layout()
        plt.show()



## 21

array([ 21,  28,  79, 119, 103])

In [None]:
# Row of the test set to analyse.
row_idx = 103

# Row `row_idx` of the test outputs, used as the query point.
input_point = test_y_21[row_idx]

# Project helper: returns the 100 training pairs selected for the query point.
local_train_x, local_train_y = Tools.find_k_nearest_neighbors_CPU(
    input_point,
    train_x,
    train_y_21,
    k=100,
)

# Bounds are derived from the selected training inputs.
bounds = bound.get_bounds(local_train_x)

In [None]:
# Run the project's MCMC routine (single chain: 300 warm-up steps,
# 1200 sampling steps) for test row `row_idx`, using the bounds computed above.
mcmc_result_Uniform = Estimation.run_mcmc_Uniform(
    Prediction.preds_distribution,
    MVGP_models,
    MVGP_likelihoods,
    row_idx,
    test_y_21,
    bounds,
    num_sampling=1200,
    warmup_step=300,
    num_chains=1,
    device=Device,
)

# Extract the samples from the returned result object.
posterior_samples_Uniform = mcmc_result_Uniform.get_samples()

In [None]:
# Diagnostics and plots for the posterior samples; the matching row of the
# NumPy array X_test is passed as the true parameter vector.
visualize_posterior_1d_params(
    posterior_samples_Uniform,
    true_params_tensor=X_test[row_idx],
    xlim=None,
    clip_percentiles=(0.5, 99.5),
    acf_lags=40,
    bins=15,
)

In [None]:
# Persist the posterior samples for later plotting.
mcmc_dir = 'PlotData'
# Create the output directory if needed; open(..., 'wb') would otherwise fail
# with FileNotFoundError on a fresh checkout.
os.makedirs(mcmc_dir, exist_ok=True)

# The file name uses row_idx + 1.
mcmc_file = os.path.join(mcmc_dir, f'result_{row_idx + 1}.pkl')
with open(mcmc_file, 'wb') as f:
    pickle.dump(posterior_samples_Uniform, f)