# Set up

In [1]:
import torch
import gpytorch
import pandas as pd
import numpy as np
import tqdm as tqdm

import ast
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score
import os
import pickle

import re
import contextlib

import math

import arviz as az
import seaborn as sns

import os
import scipy.stats as stats

from statsmodels.graphics.tsaplots import plot_acf
import statsmodels

from scipy.stats import gaussian_kde
from pyro.ops.stats import (
    gelman_rubin,
    split_gelman_rubin,
    autocorrelation,
    effective_sample_size,
    resample,
    quantile,
    weighed_quantile
)

import pickle

import torch.nn.functional as F


import GP_functions.Loss_function as Loss_function
import GP_functions.bound as bound
import GP_functions.Estimation as Estimation
import GP_functions.Training as Training
import GP_functions.Prediction as Prediction
import GP_functions.GP_models as GP_models
import GP_functions.Tools as Tools
import GP_functions.FeatureE as FeatureE

import joblib

# Data

In [2]:
def _read_matrix(csv_path):
    """Read a header-less, comma-separated CSV file and return its values as a NumPy array."""
    return pd.read_csv(csv_path, header=None, delimiter=',').values


# Emulator inputs (train / test split).
X_train = _read_matrix('RealCase/RealCase_X_train.csv')
X_test = _read_matrix('RealCase/RealCase_X_test.csv')

# Outputs read from the *_std files (presumably the standardized outputs — confirm).
Y_train_std = _read_matrix('RealCase/RealCase_Y_train_std.csv')
Y_test_std = _read_matrix('RealCase/RealCase_Y_test_std.csv')
Realcase_data_std = _read_matrix('RealCase/RealCase_Y_std.csv')

# Outputs read from the files without the _std suffix.
Y_train = _read_matrix('RealCase/RealCase_Y_train.csv')
Y_test = _read_matrix('RealCase/RealCase_Y_test.csv')
Realcase_data = _read_matrix('RealCase/RealCase.csv')


In [3]:
# Convert the NumPy arrays to float32 tensors: inputs first, then the *_std outputs.
train_x, test_x = (
    torch.tensor(array, dtype=torch.float32) for array in (X_train, X_test)
)

train_y, test_y, realcase_y = (
    torch.tensor(array, dtype=torch.float32)
    for array in (Y_train_std, Y_test_std, Realcase_data_std)
)


# Emulator

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so that the
# later `.to(Device)` / `map_location=Device` calls do not fail on a CPU-only machine.
Device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Restore the multitask variational GP and its likelihood from the saved checkpoint.
checkpoint = torch.load('multitask_gp_checkpoint_Realcase.pth', map_location=Device)
model_params = checkpoint['model_params']

# Constructor settings stored alongside the weights.
_gp_kwargs = {
    name: model_params[name]
    for name in ('num_latents', 'num_inducing', 'covar_type')
}
MVGP_models = GP_models.MultitaskVariationalGP(train_x, train_y, **_gp_kwargs)
MVGP_models = MVGP_models.to(Device)
MVGP_models.load_state_dict(checkpoint['model_state_dict'])

# One task per column of train_y.
_num_tasks = train_y.shape[1]
MVGP_likelihoods = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=_num_tasks)
MVGP_likelihoods = MVGP_likelihoods.to(Device)
MVGP_likelihoods.load_state_dict(checkpoint['likelihood_state_dict'])

# Switch both modules to evaluation mode before predicting.
MVGP_models.eval()
MVGP_likelihoods.eval()

# MCMC Samples

In [None]:
# Load the saved samples file, mapping any stored tensors onto `Device`.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
loaded_samples = torch.load("mcmc_RealCase_final_mcmc_infer_2block_independent.pt", map_location=Device)

In [None]:
# Extract the "params" entry of the loaded samples under the names theta_0 ... theta_9.
# The helper lives in GP_functions.Tools; its return type is not visible here.
posterior_samples = Tools.extract_vector_params_from_mcmc(
    loaded_samples,
    key="params",
    param_names=[f"theta_{i}" for i in range(10)]  # the real parameter names can be used instead
)

# NOTE(review): .clone().detach() needs a Tensor, but the next cell calls
# posterior_samples.update(...), which needs a dict — both cannot work on the same
# object; confirm what extract_vector_params_from_mcmc returns.
posterior_samples_theta = posterior_samples.clone().detach()

In [None]:
# Copy the two sigma_meas_* entries from the loaded samples into posterior_samples.
# NOTE(review): _sorted_param_keys orders keys by their trailing integer only, so after
# this update "sigma_meas_1"/"sigma_meas_2" sort next to "theta_1"/"theta_2" and
# posterior_predictive_resample stacks them into the emulator input — confirm this is intended.
keys_to_extract = ['sigma_meas_1', 'sigma_meas_2']
posterior_samples.update({k: loaded_samples[k] for k in keys_to_extract})

# MCMC Plots

In [None]:
# Plot every entry of posterior_samples with the helper from GP_functions.Tools
# (presumably a histogram and an autocorrelation plot per parameter, given bins/acf_lags — confirm).
Tools.visualize_posterior_1d_params(
    posterior_samples,
    bins=15,
    acf_lags=40,
    clip_percentiles=(0.5, 99.5),
    xlim=None,          # pass (low, high) to pin the x-axis range
)

# Posterior predictive resample

In [None]:



def _sorted_param_keys(sample_dict):
    """Return the keys of ``sample_dict`` sorted by their trailing integer.

    Accepts keys such as "param_0", "param_1", ..., "theta_0", "theta_1", ...,
    or any other key that ends in digits. Only the trailing number is used for
    ordering, so keys with different prefixes but the same index keep their
    insertion order relative to each other.

    Raises
    ------
    ValueError
        If a key does not end in digits.
    """
    def _trailing_index(key: str) -> int:
        match = re.search(r"(\d+)$", key)
        if match is None:
            raise ValueError(f"Cannot parse index from key='{key}'. Expect suffix like _0, _1, ...")
        return int(match.group(1))

    return sorted(sample_dict.keys(), key=_trailing_index)


def _set_eval_mode(module_or_modules):
    """Call ``.eval()`` on one module, or on every member of a list/tuple of modules."""
    if isinstance(module_or_modules, (list, tuple)):
        members = module_or_modules
    else:
        members = (module_or_modules,)
    for member in members:
        # Objects without an eval() method (e.g. None) are skipped on purpose.
        if hasattr(member, "eval"):
            member.eval()


def _draws_to_bsd(draws: torch.Tensor, B: int) -> torch.Tensor:
    """Normalize sampled draws to shape [B, num_draws, D].

    Accepted input layouts:
      - [num_draws, B, D]
      - [num_draws, 1, B, D]  (extra singleton batch)
      - [num_draws, B, 1, D]  (extra singleton batch)
      - [num_draws, D]        (B == 1)
    """
    if draws.dim() == 4:
        # [num_draws, 1, B, D] -> [num_draws, B, D]
        if draws.shape[1] == 1 and draws.shape[2] == B:
            draws = draws.squeeze(1)
        # [num_draws, B, 1, D] -> [num_draws, B, D]
        elif draws.shape[2] == 1 and draws.shape[1] == B:
            draws = draws.squeeze(2)

    if draws.dim() == 2:
        # [num_draws, D] -> [num_draws, 1, D]
        draws = draws.unsqueeze(1)

    if draws.dim() != 3:
        raise RuntimeError(f"Unexpected draws shape {tuple(draws.shape)}; cannot convert to [B,S,D].")

    if draws.shape[1] != B:
        raise RuntimeError(
            f"Draws has shape {tuple(draws.shape)}; expected second dim == B={B}."
        )
    # [num_draws, B, D] -> [B, num_draws, D]
    return draws.permute(1, 0, 2).contiguous()


def _moments_to_bd(x: torch.Tensor, B: int) -> torch.Tensor:
    """Normalize a predictive mean/variance tensor to shape [B, D].

    Accepted input layouts: [B, D], [1, B, D], and [D] when B == 1.
    """
    if x.dim() == 3 and x.shape[0] == 1 and x.shape[1] == B:
        x = x.squeeze(0)
    if x.dim() == 2 and x.shape[0] == B:
        return x
    if x.dim() == 1 and B == 1:
        # [D] -> [1, D]
        return x.unsqueeze(0)
    raise RuntimeError(f"Unexpected mean/var shape {tuple(x.shape)}; expected [B,D] (B={B}).")


@torch.no_grad()
def posterior_predictive_resample(
    samples,
    Pre_function,
    Models,
    Likelihoods,
    num_draws=20,
    *,
    device="cpu",
    chunk_size=256,
    group_by_chain=False,      # True: samples[k] is [C,S,...] -> flatten to [C*S,...]
    use_fast_pred_var=False,   # leave False if Pre_function already enables fast_pred_var itself
    return_mean_var=False,     # optionally also return the predictive mean/var of every theta
    use_rsample=True,          # True: rsample; False: sample
    param_keys=None,           # explicit, ordered list of keys that form theta
):
    """
    Posterior predictive resampling for GPyTorch multitask GP predictions.

    For every posterior sample theta the predictive distribution returned by
    ``Pre_function`` is sampled ``num_draws`` times. Thetas are processed in
    chunks of ``chunk_size``; if the vectorized call for a chunk fails, a
    warning is issued and the chunk is re-evaluated one theta at a time.

    Parameters
    ----------
    samples : dict[str, Tensor]
        MCMC samples dict, each key maps to Tensor on CPU or GPU.
        Typical shapes:
          - [S] for single-chain
          - [C,S] for multi-chain (if group_by_chain=True will flatten)

    Pre_function : callable(model, likelihood, x)
        Should return a GPyTorch distribution, e.g. MultitaskMultivariateNormal

    Models, Likelihoods : module or list/tuple of modules
        Passed through to ``Pre_function``; switched to eval mode first.

    num_draws : int or sequence of int
        Number of draws per theta (or a full sample shape).

    device : str or torch.device
        Device on which theta is assembled before it is passed to ``Pre_function``.

    chunk_size : int
        Number of thetas per call to ``Pre_function``; must be positive.

    param_keys : sequence of str, optional
        Keys of ``samples`` that make up theta, in column order. When omitted,
        every key of ``samples`` is used, ordered by its trailing integer.
        Use this when ``samples`` also holds entries that are not emulator inputs.

    Returns
    -------
    theta_all : Tensor [S_total, P]  (CPU)
    y_pp      : Tensor [S_total, num_draws, D]  (CPU)
    (optional) mean_all : Tensor [S_total, D] (CPU)
    (optional) var_all  : Tensor [S_total, D] (CPU)

    Raises
    ------
    ValueError
        If ``samples`` (or ``param_keys``) is empty, ``chunk_size`` is not
        positive, the stacked parameters are not 2-D, or there are no samples.
    KeyError
        If ``param_keys`` names an entry that is missing from ``samples``.
    """
    import warnings

    if len(samples) == 0:
        raise ValueError("samples is empty; there is nothing to resample.")
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}.")

    # ---- 0) make sample_shape a torch.Size (some GPyTorch versions require it) ----
    if isinstance(num_draws, int):
        sample_shape = torch.Size([num_draws])
    else:
        # allow the caller to pass a tuple/list like (n,) or [n]
        sample_shape = torch.Size(list(num_draws))

    if param_keys is None:
        keys = _sorted_param_keys(samples)
    else:
        keys = list(param_keys)
        if not keys:
            raise ValueError("param_keys is empty; at least one parameter key is required.")
        missing = [key for key in keys if key not in samples]
        if missing:
            raise KeyError(f"param_keys not present in samples: {missing}")

    device = torch.device(device)

    # ---- 1) build theta_all: [S_total, P] ----
    columns = []
    for key in keys:
        column = samples[key]
        # Multi-chain tensor [C,S,...] -> [C*S,...] when requested.
        if group_by_chain and column.dim() >= 2:
            column = column.reshape(-1, *column.shape[2:])
        columns.append(column.to(device=device, dtype=torch.float32))

    # Scalar parameters are stacked along the last dim, giving [S_total, P].
    theta_all = torch.stack(columns, dim=-1)
    if theta_all.dim() != 2:
        raise ValueError(
            f"Stacked parameters have shape {tuple(theta_all.shape)}; expected [S_total, P]. "
            "Each entry of samples must hold one scalar per sample."
        )
    S_total = theta_all.shape[0]
    if S_total == 0:
        raise ValueError("samples contain zero posterior samples; there is nothing to resample.")

    # ---- 2) set eval mode (supports single model or list/tuple) ----
    _set_eval_mode(Models)
    _set_eval_mode(Likelihoods)

    # ---- 3) chunked posterior predictive sampling ----
    def _predict(th):
        """Return (draws [B,num_draws,D], mean [B,D] | None, var [B,D] | None) for th [B,P]."""
        batch = th.shape[0]
        pred = Pre_function(Models, Likelihoods, th)
        if use_rsample and hasattr(pred, "rsample"):
            raw = pred.rsample(sample_shape=sample_shape)
        else:
            raw = pred.sample(sample_shape=sample_shape)
        draws = _draws_to_bsd(raw, batch)
        if not return_mean_var:
            return draws, None, None
        return draws, _moments_to_bd(pred.mean, batch), _moments_to_bd(pred.variance, batch)

    if use_fast_pred_var:
        import gpytorch
        fast_ctx = gpytorch.settings.fast_pred_var()
    else:
        fast_ctx = contextlib.nullcontext()

    y_chunks = []
    mean_chunks, var_chunks = [], []

    with fast_ctx:
        for start in range(0, S_total, chunk_size):
            th = theta_all[start:start + chunk_size]  # [B, P]

            try:
                # Preferred: one vectorized call for the whole chunk.
                draws, mean, var = _predict(th)
            except Exception as err:
                # Best-effort fallback: evaluate each theta on its own. Nothing has been
                # appended for this chunk yet, so its rows cannot be duplicated.
                warnings.warn(
                    f"Vectorized prediction failed for samples [{start}, {start + th.shape[0]}) "
                    f"({type(err).__name__}: {err}); falling back to one prediction per sample.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                per_theta = [_predict(row) for row in th.split(1)]  # each row is [1, P]
                draws = torch.cat([item[0] for item in per_theta], dim=0)
                if return_mean_var:
                    mean = torch.cat([item[1] for item in per_theta], dim=0)
                    var = torch.cat([item[2] for item in per_theta], dim=0)

            # Append only once the whole chunk (draws and moments) is available.
            y_chunks.append(draws.detach().cpu())
            if return_mean_var:
                mean_chunks.append(mean.detach().cpu())
                var_chunks.append(var.detach().cpu())

    y_pp = torch.cat(y_chunks, dim=0)  # [S_total, num_draws, D]
    theta_cpu = theta_all.detach().cpu()

    if return_mean_var:
        mean_all = torch.cat(mean_chunks, dim=0)  # [S_total, D]
        var_all = torch.cat(var_chunks, dim=0)    # [S_total, D]
        return theta_cpu, y_pp, mean_all, var_all

    return theta_cpu, y_pp


In [None]:
# Draw posterior predictive samples from the emulator for every posterior sample.
# NOTE(review): every key of posterior_samples is stacked into theta (ordered by its
# trailing integer), including any sigma_meas_* entries added earlier — confirm that
# this matches the inputs the emulator expects.
theta_samples, y_pp = posterior_predictive_resample(
    posterior_samples,
    Pre_function=Prediction.preds_distribution,
    Models=MVGP_models,
    Likelihoods=MVGP_likelihoods,
    num_draws=10,
    device=Device,            # same device the emulator was loaded onto
    chunk_size=256,
    group_by_chain=False,
    use_fast_pred_var=True,   # wrap every prediction of this call in gpytorch's fast_pred_var
)

print(theta_samples.shape)  # [S, P]
print(y_pp.shape)           # [S, num_draws, D] = [S, 10, D]  (e.g. D=33)

# Plot violin

In [None]:

# 1. Labels of the 33 outputs: "Volume" followed by the "cir" and "rad" entries,
#    each listed per level (Basal, Mid, Apical) and segment.
_SIX_SEGMENTS = ("InfSept", "AntSept", "Ant", "AntLat", "InfLat", "Inf")
_APICAL_SEGMENTS = ("Septal", "Anterior", "Lateral", "Inferior")
_SEGMENTS_BY_LEVEL = (
    ("Basal", _SIX_SEGMENTS),
    ("Mid", _SIX_SEGMENTS),
    ("Apical", _APICAL_SEGMENTS),
)

CARDIAC_LABELS = ["Volume"] + [
    f"{strain}-{level}-{segment}"
    for strain in ("cir", "rad")
    for level, segments in _SEGMENTS_BY_LEVEL
    for segment in segments
]

def plot_violin_by_dim_with_truth(
    y_pp: torch.Tensor,          # [S, R, D]
    y_true: torch.Tensor,        # [1, D] or [D]
    *,
    dim_labels=None,             # list of strings, one per output dimension
    max_points_per_dim=5000,
    show_median=True,
    show_extrema=False,
    figsize=(20, 6),             # wide canvas so that 33 dimensions fit
    label_rotation=90,           # rotation angle of the x tick labels
    title="Posterior Distribution of Volume, Circumferential & Radial Strains"
):
    """Draw one violin per output dimension and overlay the observed value.

    The draws of ``y_pp`` are pooled over its first two axes. Observations that
    fall inside the [min, max] range of the plotted draws are shown as dots,
    those outside as red crosses.

    Parameters
    ----------
    y_pp : Tensor [S, R, D]
    y_true : Tensor [D] or [1, D]
    dim_labels : list[str], optional
        Tick labels; replaced by "1".."D" (with a printed warning) if the length is not D.
    max_points_per_dim : int or None
        Upper bound on the pooled draws plotted per dimension (random subset); None keeps all.

    Raises
    ------
    ValueError
        If ``y_pp`` is not 3-D or ``y_true`` does not have D elements.
    """
    # ---------- shape checks ----------
    if y_pp.dim() != 3:
        raise ValueError(f"y_pp must be [S,R,D], got shape {tuple(y_pp.shape)}")
    num_dims = y_pp.shape[2]

    observed = y_true.detach()
    if observed.dim() == 2 and observed.shape[0] == 1:
        observed = observed.squeeze(0)
    if observed.dim() != 1 or observed.numel() != num_dims:
        raise ValueError(
            f"y_true must be [D] or [1,D] with D={num_dims}, got shape {tuple(observed.shape)}"
        )

    # ---------- pool the draws of every dimension: [S,R,D] -> [D, S*R] ----------
    per_dim = torch.movedim(y_pp, 2, 0).reshape(num_dims, -1)
    pooled = per_dim.shape[1]

    if max_points_per_dim is not None and pooled > max_points_per_dim:
        # The same random subset of columns is kept for every dimension.
        keep = torch.randperm(pooled, device=per_dim.device)[:max_points_per_dim]
        per_dim = per_dim[:, keep]

    samples_by_dim = list(per_dim.detach().cpu().numpy())
    truth = observed.detach().cpu().numpy()

    # ---------- labels ----------
    index_labels = [str(i + 1) for i in range(num_dims)]
    if dim_labels is None:
        dim_labels = index_labels
    elif len(dim_labels) != num_dims:
        # Guard against a label list that does not match the data dimension.
        print(f"Warning: dim_labels length ({len(dim_labels)}) != D ({num_dims}). Using indices instead.")
        dim_labels = index_labels

    positions = np.arange(1, num_dims + 1)

    # ---------- plot ----------
    plt.figure(figsize=figsize)
    plt.violinplot(
        samples_by_dim,
        positions=positions,
        showmeans=False,
        showmedians=show_median,
        showextrema=show_extrema,
    )

    # --- split the observations by whether they lie within the range of the draws ---
    lows = np.array([np.min(values) for values in samples_by_dim])
    highs = np.array([np.max(values) for values in samples_by_dim])
    outside = (truth < lows) | (truth > highs)
    inside = ~outside

    if inside.any():
        plt.scatter(positions[inside], truth[inside], s=25, zorder=3, c='C0',
                    label="Observation (in range)")
    if outside.any():
        plt.scatter(positions[outside], truth[outside], s=40, zorder=3, c='red', marker='x',
                    label="Observation (out of range)")

    # --- tick labels, rotated ---
    plt.xticks(positions, dim_labels, rotation=label_rotation)

    plt.xlim(0.5, num_dims + 0.5)  # leave a margin on both sides
    plt.xlabel("Output dimension")
    plt.ylabel("Value")
    plt.title(title)
    plt.legend()
    plt.tight_layout()  # keeps the rotated labels from being cut off
    plt.show()



In [None]:
# Violin plot of the draws (y_pp) against the observation (realcase_y).
plot_violin_by_dim_with_truth(
    y_pp, 
    realcase_y, 
    dim_labels=CARDIAC_LABELS, # pass the output labels
    max_points_per_dim=None,  # None plots every draw; set an integer to subsample instead
    figsize=(20, 8)            # taller canvas to leave room for the rotated labels
)

# Inverse standardize

In [None]:
import torch
import joblib

# Load the saved output scaler used below to undo the standardization.
# NOTE(review): joblib.load unpickles the file — only load files from a trusted source.
scaler = joblib.load("RealCase/y_scaler_RealCase.joblib")

def inverse_standardize_torch(y: torch.Tensor, scaler) -> torch.Tensor:
    """Map standardized values back to the original scale of a fitted StandardScaler.

    Mirrors ``StandardScaler.inverse_transform``: the values are multiplied by
    ``scaler.scale_`` when the scaler was fitted with scaling and ``scaler.mean_``
    is added when it was fitted with centering. A scaler-like object without the
    ``with_std`` / ``with_mean`` flags is treated as scaling and centering.

    Parameters
    ----------
    y : Tensor
        Values whose last dimension is D (the number of scaled features).
    scaler : object
        Must expose ``mean_`` and ``scale_`` (array-like of length D, or None).

    Returns
    -------
    Tensor
        New tensor with the same shape, dtype and device as ``y``.

    Raises
    ------
    TypeError
        If ``scaler`` has no ``mean_`` / ``scale_`` attributes.
    """
    if not (hasattr(scaler, "mean_") and hasattr(scaler, "scale_")):
        raise TypeError(
            "Scaler has no mean_/scale_. Only StandardScaler-like scalers are supported; "
            "for other scalers call scaler.inverse_transform on a NumPy array instead."
        )

    result = y

    # scale_ is None when the scaler was fitted with with_std=False.
    scale = scaler.scale_
    if getattr(scaler, "with_std", True) and scale is not None:
        result = result * torch.as_tensor(scale, dtype=y.dtype, device=y.device)  # [D]

    # With with_mean=False the scaler did not centre the data, so mean_ must not be added
    # back even though the attribute may still be populated.
    mean = scaler.mean_
    if getattr(scaler, "with_mean", True) and mean is not None:
        result = result + torch.as_tensor(mean, dtype=y.dtype, device=y.device)   # [D]

    # Always hand back a new tensor, as the scaled/shifted result would be.
    return y.clone() if result is y else result

# ---- usage ----
# Draws y_pp ([S,R,D]) mapped back through the scaler.
y_pp_inv = inverse_standardize_torch(y_pp, scaler)

# Observation: drop a leading singleton row ([1,D] -> [D]) before mapping it back.
if realcase_y.dim() == 2 and realcase_y.shape[0] == 1:
    y_true_t = realcase_y.squeeze(0)
else:
    y_true_t = realcase_y
y_true_inv = inverse_standardize_torch(y_true_t, scaler)  # [D]


# Histogram for Volume, violin for the remaining dims

In [3]:
def plot_first_hist_rest_violin(
    y_pp: torch.Tensor,                 # [S, R, D]
    y_true: torch.Tensor,               # [1,D] or [D]
    *,
    first_dim: int = 0,                 # 第一个维度的 index（默认 0）
    dim_labels=None,                    # list[str], len=D；默认 1..D
    scaler_path: str | None = None,     # 如 "RealCase/y_scaler_RealCase.joblib"；None则不逆标准化
    already_inverse_scaled: bool = False,  # 如果你已经提前 inverse 过，就设 True
    max_points_per_dim: int = 6000,     # 小提琴每维最多采样点数（避免太慢）
    max_hist_points: int = 200000,      # 直方图最多采样点数（避免太慢）
    bins: int = 60,
    density: bool = True,
    figsize=(18, 7),
    title_prefix="Posterior predictive (original scale)"
):
    # ---------- checks ----------
    if y_pp.dim() != 3:
        raise ValueError(f"y_pp must be [S,R,D], got {tuple(y_pp.shape)}")
    S, R, D = y_pp.shape

    y_true = y_true.detach()
    if y_true.dim() == 2 and y_true.shape[0] == 1:
        y_true = y_true.squeeze(0)
    if y_true.dim() != 1 or y_true.numel() != D:
        raise ValueError(f"y_true must be [D] or [1,D], with D={D}, got {tuple(y_true.shape)}")

    if not (0 <= first_dim < D):
        raise ValueError(f"first_dim must be in [0, {D-1}]")

    # ---------- inverse scaling if needed ----------
    if scaler_path is not None and (not already_inverse_scaled):
        scaler = joblib.load(scaler_path)
        y_pp = inverse_standardize_torch(y_pp, scaler)
        y_true = inverse_standardize_torch(y_true, scaler)

    # ---------- prepare samples by dimension ----------
    # y_pp: [S,R,D] -> y_dim: [D, N] where N=S*R
    y_dim = y_pp.permute(2, 0, 1).contiguous().reshape(D, -1)  # [D, N]
    N = y_dim.shape[1]

    # labels
    if dim_labels is None:
        dim_labels = [str(i + 1) for i in range(D)]
    if len(dim_labels) != D:
        raise ValueError("dim_labels length must equal D")

    # ---------- subsample for speed ----------
    # histogram samples (for first dim)
    first_samples = y_dim[first_dim]
    if max_hist_points is not None and first_samples.numel() > max_hist_points:
        idx = torch.randperm(first_samples.numel(), device=first_samples.device)[:max_hist_points]
        first_samples = first_samples[idx]

    # violin samples (for remaining dims)
    rest_dims = [d for d in range(D) if d != first_dim]
    rest_samples = y_dim[rest_dims]  # [D-1, N]
    if max_points_per_dim is not None and N > max_points_per_dim:
        idx = torch.randperm(N, device=rest_samples.device)[:max_points_per_dim]
        rest_samples = rest_samples[:, idx]  # [D-1, max_points_per_dim]

    # move to cpu numpy for matplotlib
    first_np = first_samples.detach().cpu().numpy()
    truth_np = y_true.detach().cpu().numpy()
    rest_list = [rest_samples[i].detach().cpu().numpy() for i in range(rest_samples.shape[0])]

    # ---------- plotting (two panels: hist + violin) ----------
    fig, axes = plt.subplots(2, 1, figsize=figsize, gridspec_kw={"height_ratios": [1, 1.4]})

    # (A) histogram for first dim
    ax0 = axes[0]
    ax0.hist(first_np, bins=bins, density=density)
    ax0.axvline(truth_np[first_dim], linestyle="--", linewidth=2, label="Truth")
    ax0.set_title(f"{title_prefix} — dim {dim_labels[first_dim]}: histogram")
    ax0.set_xlabel("Value")
    ax0.set_ylabel("Density" if density else "Count")
    ax0.legend()

    # (B) violin for remaining dims
    ax1 = axes[1]
    positions = np.arange(1, len(rest_dims) + 1)
    ax1.violinplot(rest_list, positions=positions, showmeans=False, showmedians=True, showextrema=False)

    # overlay truth points for rest dims
    truth_rest = truth_np[rest_dims]
    ax1.scatter(positions, truth_rest, s=22, zorder=3, label="Truth")

    ax1.set_title(f"{title_prefix} — all dims except {dim_labels[first_dim]}: violin")
    ax1.set_xlabel("Output dimension")
    ax1.set_ylabel("Value")
    ax1.set_xticks(positions)
    ax1.set_xticklabels([dim_labels[d] for d in rest_dims], rotation=0)
    ax1.legend()

    plt.tight_layout()
    plt.show()


# ---------------- usage example ----------------
# y_pp: [S_total, num_draws, 33]  (standardized or original)
# y_true: [1,33] or [33] (same scale as y_pp)

# 1) If y_pp / y_true are still on the standardized scale: pass scaler_path so the function applies the inverse transform internally
# plot_first_hist_rest_violin(y_pp, y_true, scaler_path="RealCase/y_scaler_RealCase.joblib")

# 2) If they were already inverse-transformed beforehand: pass already_inverse_scaled=True (or leave scaler_path=None)
# plot_first_hist_rest_violin(y_pp_inv, y_true_inv, already_inverse_scaled=True)


In [None]:
# Plot the already inverse-standardized draws; passing None for both limits disables subsampling.
plot_first_hist_rest_violin(y_pp_inv, y_true_inv, 
                            already_inverse_scaled=True, 
                            max_points_per_dim=None, 
                            max_hist_points=None)