In [None]:
import time
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd

# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================
# 1. Load data + split + standardize
# =========================
df = pd.read_excel("../../data/ENB2012_data.xlsx")
# Columns 0-7 are the inputs; column 8 is the regression target, kept 2-D (n, 1).
X = df.iloc[:, :8].values.astype(np.float32)
y = df.iloc[:, 8].values.astype(np.float32).reshape(-1, 1)

# Hold out 20% of the rows as the test set.
X_train_full, X_test, y_train_full, y_test = train_test_split(
    X, y, random_state=0, test_size=0.2
)

# The scalers are fitted on the training portion only, then reused on the test set.
scaler_x = StandardScaler()
scaler_y = StandardScaler()
X_train_full = scaler_x.fit_transform(X_train_full)
y_train_full = scaler_y.fit_transform(y_train_full)
X_test = scaler_x.transform(X_test)
y_test = scaler_y.transform(y_test)

# Split the training set in half again: train / cal, as needed by split CP.
X_tr, X_cal, y_tr, y_cal = train_test_split(
    X_train_full, y_train_full, random_state=1, test_size=0.5
)


def _to_device(array):
    """Copy a NumPy array into a float32 tensor placed on ``device``."""
    return torch.tensor(array, dtype=torch.float32).to(device)


# X_test / y_test stay as NumPy arrays; the tensor copies get a ``_t`` suffix.
X_tr, y_tr = _to_device(X_tr), _to_device(y_tr)
X_cal, y_cal = _to_device(X_cal), _to_device(y_cal)
X_test_t, y_test_t = _to_device(X_test), _to_device(y_test)

print("Train:", X_tr.shape, y_tr.shape)
print("Cal  :", X_cal.shape, y_cal.shape)
print("Test :", X_test_t.shape, y_test_t.shape)

# =========================
# 2. Define a simple MLP
# =========================
class MLP(nn.Module):
    """Small fully connected regressor: two ReLU hidden layers, one linear output.

    Args:
        in_dim: Number of input features per sample.
        hidden: Width shared by both hidden layers.
    """

    def __init__(self, in_dim=8, hidden=64):
        super().__init__()
        # Layers are created front to back so parameter names stay net.0 / net.2 / net.4.
        widths = (in_dim, hidden, hidden)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return one output per row."""
        out = self.net(x)
        return out

def train_mlp(X, y, epochs=200, batch_size=64, lr=1e-3):
    """Fit a freshly initialised ``MLP`` to ``(X, y)`` with Adam on MSE loss.

    Args:
        X: Feature tensor; ``X.shape[1]`` sets the network's input width.
        y: Target tensor paired row-wise with ``X``.
        epochs: Number of full passes over the data.
        batch_size: Mini-batch size given to the ``DataLoader``.
        lr: Learning rate for Adam.

    Returns:
        The trained model. It is left in training mode; callers switch to
        ``eval()`` before predicting.
    """
    batches = DataLoader(TensorDataset(X, y), batch_size=batch_size, shuffle=True)
    net = MLP(in_dim=X.shape[1]).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = nn.MSELoss()

    net.train()
    for _epoch in range(epochs):
        for inputs, targets in batches:
            optimizer.zero_grad()
            criterion(net(inputs), targets).backward()
            optimizer.step()
    return net

# =========================
# 3. Split Conformal (a single training run)
# =========================
def split_conformal(alpha=0.1):
    """Run split conformal prediction once and report coverage, width and timings.

    One MLP is trained on the module-level ``(X_tr, y_tr)``. Absolute residuals
    on ``(X_cal, y_cal)`` serve as conformity scores, and the test interval is
    the symmetric band ``prediction +/- q_hat`` on ``X_test_t``.

    ``q_hat`` is the ``ceil((n_cal + 1) * (1 - alpha))``-th smallest calibration
    residual. This is the finite-sample conformal quantile; a plain interpolated
    ``1 - alpha`` quantile is slightly too small and loses the coverage
    guarantee. When that rank exceeds ``n_cal`` the calibration set is too small
    for the requested level and ``q_hat`` is ``inf`` (an unbounded interval).

    Args:
        alpha: Target miscoverage level in ``[0, 1]``; intervals aim at
            ``1 - alpha`` coverage.

    Returns:
        dict with keys ``coverage`` and ``avg_len`` (computed after mapping
        back through ``scaler_y.inverse_transform``), ``train_time``,
        ``cal_time``, ``test_time``, ``total_time`` (seconds) and ``q_hat``
        (a float, in the standardized target scale the model is trained on).

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """
    # Validate before training so a bad level fails immediately.
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must lie in [0, 1], got {alpha!r}")

    t0 = time.time()
    model = train_mlp(X_tr, y_tr)
    t1 = time.time()

    model.eval()
    with torch.no_grad():
        mu_cal = model(X_cal)
    resid = np.sort((y_cal - mu_cal).abs().cpu().numpy().ravel())
    n_cal = resid.shape[0]
    # Finite-sample corrected order statistic (1-based rank, at least 1).
    rank = max(int(np.ceil((n_cal + 1) * (1 - alpha))), 1)
    q_hat = float(resid[rank - 1]) if rank <= n_cal else float("inf")
    t2 = time.time()

    with torch.no_grad():
        mu_test = model(X_test_t)
    mu_test_np = mu_test.cpu().numpy().ravel()
    lower = mu_test_np - q_hat
    upper = mu_test_np + q_hat
    t3 = time.time()

    # Map back to the original scale so coverage / length are easy to interpret.
    y_test_np = scaler_y.inverse_transform(y_test).ravel()
    lower_raw = scaler_y.inverse_transform(lower.reshape(-1, 1)).ravel()
    upper_raw = scaler_y.inverse_transform(upper.reshape(-1, 1)).ravel()

    coverage = np.mean((y_test_np >= lower_raw) & (y_test_np <= upper_raw))
    avg_len = np.mean(upper_raw - lower_raw)

    train_time = t1 - t0
    cal_time = t2 - t1
    test_time = t3 - t2
    total_time = t3 - t0

    return dict(
        coverage=coverage,
        avg_len=avg_len,
        train_time=train_time,
        cal_time=cal_time,
        test_time=test_time,
        total_time=total_time,
        q_hat=q_hat,
    )

# =========================
# 4. Bootstrap prediction interval (train B times + predict B times)
# =========================
def bootstrap_pi(B=20, alpha=0.1, epochs=200):
    """Build percentile intervals from ``B`` bootstrap-refitted MLPs.

    Each round resamples the rows of the module-level ``(X_tr, y_tr)`` with
    replacement, trains a new model on the resample and predicts ``X_test_t``.
    The per-point interval is the ``alpha/2`` and ``1 - alpha/2`` quantiles
    across the ``B`` predictions. The band comes only from the spread of the
    models' point predictions; no residual term is added.

    Args:
        B: Number of bootstrap resamples (and models trained).
        alpha: Nominal miscoverage level used for the two quantiles.
        epochs: Training epochs passed to ``train_mlp`` for every model.

    Returns:
        dict with keys ``coverage`` and ``avg_len`` (computed after mapping
        back through ``scaler_y.inverse_transform``), ``train_time``,
        ``pred_time`` and ``total_time`` (seconds).
    """
    n_rows = X_tr.shape[0]
    started = time.time()

    member_preds = []
    for _ in range(B):
        rows = np.random.choice(n_rows, n_rows, replace=True)
        net = train_mlp(X_tr[rows], y_tr[rows], epochs=epochs)
        net.eval()
        with torch.no_grad():
            test_pred = net(X_test_t)
        member_preds.append(test_pred.cpu().numpy().ravel())

    fitted = time.time()
    stacked = np.stack(member_preds, axis=0)  # (B, n_test)

    tail = alpha / 2
    lower = np.quantile(stacked, tail, axis=0)
    upper = np.quantile(stacked, 1 - tail, axis=0)
    finished = time.time()

    # Map back to the original scale.
    truth = scaler_y.inverse_transform(y_test).ravel()
    lo_raw = scaler_y.inverse_transform(lower.reshape(-1, 1)).ravel()
    hi_raw = scaler_y.inverse_transform(upper.reshape(-1, 1)).ravel()

    inside = (truth >= lo_raw) & (truth <= hi_raw)

    return {
        "coverage": np.mean(inside),
        "avg_len": np.mean(hi_raw - lo_raw),
        # Time to train the B networks (each round's test prediction included).
        "train_time": fitted - started,
        # Time to stack the predictions and take the quantiles.
        "pred_time": finished - fitted,
        "total_time": finished - started,
    }

# =========================
# 5. Run the comparison once
# =========================
if __name__ == "__main__":
    # Shared miscoverage level so both methods target the same nominal coverage.
    alpha = 0.1

    print("=== Split Conformal ===")
    res_sc = split_conformal(alpha=alpha)
    print(res_sc)

    # Bootstrap ensemble sizes to sweep; the loop iterates this list directly
    # instead of repeating the literal.
    Bs = [5, 10, 20, 100, 300, 600]
    for B in Bs:
        print(f"\n=== Bootstrap, B={B} ===")
        res_boot = bootstrap_pi(B=B, alpha=alpha, epochs=200)
        print(res_boot)


Train: torch.Size([307, 8]) torch.Size([307, 1])
Cal  : torch.Size([307, 8]) torch.Size([307, 1])
Test : torch.Size([154, 8]) torch.Size([154, 1])
=== Split Conformal ===
{'coverage': 0.8896103896103896, 'avg_len': 3.0321617, 'train_time': 0.6756629943847656, 'cal_time': 0.0008099079132080078, 'test_time': 8.320808410644531e-05, 'total_time': 0.6765561103820801, 'q_hat': 0.15206149816513073}

=== Bootstrap, B=5 ===
{'coverage': 0.44805194805194803, 'avg_len': 2.092079302336545, 'train_time': 3.001145839691162, 'pred_time': 0.00034427642822265625, 'total_time': 3.0014901161193848}

=== Bootstrap, B=10 ===
{'coverage': 0.6038961038961039, 'avg_len': 2.371622004570438, 'train_time': 6.057910203933716, 'pred_time': 0.00043392181396484375, 'total_time': 6.058344125747681}

=== Bootstrap, B=20 ===
{'coverage': 0.6038961038961039, 'avg_len': 2.7865296517702665, 'train_time': 12.030375957489014, 'pred_time': 0.0005068778991699219, 'total_time': 12.030882835388184}

=== Bootstrap, B=100 ===
{'c