In [None]:
# ===== REPLACE EVERYTHING ABOVE YOUR OLD CLASS WITH THIS BLOCK =====
import warnings

import numpy as np
import pandas as pd
from scipy.linalg import solve_triangular
from scipy.optimize import minimize


class FIMLMissing:
    """
    FIML (Full Information Maximum Likelihood) linear regression with missing data.

    (X, y) is modelled as jointly multivariate normal. The joint mean vector and
    covariance matrix are estimated by maximising the observed-data likelihood
    over all partially observed rows; the regression
    E[y | X] = beta_0 + X @ beta is then derived from the joint moments.

    Implementation notes:
    - Rows are grouped by missingness pattern, so each distinct pattern needs a
      single Cholesky factorisation per likelihood evaluation, and all rows of a
      group are handled by one batched triangular solve (no per-row loop).
    - The covariance is optimised in log-Cholesky coordinates
      (Sigma = L L^T, diagonal of L stored as its log), so every parameter vector
      visited by the optimiser gives a positive-definite Sigma; no jitter or
      penalty terms are needed.
    - Columns are standardised before optimisation (the Gaussian MLE is
      equivariant under per-column affine rescaling) and BFGS receives an
      analytic gradient instead of finite differences.

    Parameters
    ----------
    maxiter : int, default 200
        Maximum number of BFGS iterations used by ``fit``.
    """

    def __init__(self, maxiter=200):
        self.maxiter = maxiter

        # Regression coefficients
        self.beta_0 = None
        self.beta_coeffs = None

        # Estimated joint distribution parameters (original data scale)
        self.est_mu = None          # (p+1,)
        self.est_sigma = None       # (p+1, p+1)

        # Column names seen during fit (X column order, then the y name)
        self.feature_names = None
        self.target_name = None

        # Joint data and missingness groups cached by fit()
        self._X_joint = None
        self._mask_groups = None    # list of (rows_idx, cols_obs)

    # --------- Helpers: covariance upper-triangle flatten / rebuild ---------
    @staticmethod
    def _flatten_sigma_upper(S):
        """Return the upper triangle of S (including the diagonal) in row-major order."""
        S = np.asarray(S)
        return S[np.triu_indices(S.shape[0])]

    @staticmethod
    def _unflatten_sigma_upper(elems, n):
        """Inverse of _flatten_sigma_upper: rebuild the symmetric n x n matrix."""
        S = np.zeros((n, n))
        rows, cols = np.triu_indices(n)
        elems = np.asarray(elems, dtype=float)
        S[rows, cols] = elems
        S[cols, rows] = elems
        return S

    # --------- Helpers: log-Cholesky parameterisation ---------
    @staticmethod
    def _pack_log_cholesky(L):
        """Lower triangle of L (row-major) with the diagonal replaced by its log."""
        rows, cols = np.tril_indices(L.shape[0])
        theta = L[rows, cols].astype(float)
        on_diag = rows == cols
        theta[on_diag] = np.log(theta[on_diag])
        return theta

    @staticmethod
    def _unpack_log_cholesky(theta, n):
        """Inverse of _pack_log_cholesky; the returned L has a strictly positive diagonal."""
        rows, cols = np.tril_indices(n)
        L = np.zeros((n, n))
        L[rows, cols] = theta
        diag = np.arange(n)
        L[diag, diag] = np.exp(L[diag, diag])
        return L

    # --------- Precompute: group rows by missingness pattern ---------
    @staticmethod
    def _build_mask_groups(X_joint):
        """
        Group row indices by their pattern of observed columns.

        X_joint: (N, d) float array with NaN for missing entries.
        Returns a list of (rows_idx, cols_obs) in order of first appearance.
        """
        mask = ~np.isnan(X_joint)  # True where the entry is observed
        groups = {}
        # packbits compresses each boolean row into a hashable byte key.
        for i, key in enumerate(np.packbits(mask, axis=1)):
            groups.setdefault(key.tobytes(), []).append(i)
        return [(np.asarray(rows, dtype=int), np.flatnonzero(mask[rows[0]]))
                for rows in groups.values()]

    # --------- Likelihood core (grouped, batched) ---------
    @staticmethod
    def _moments_nll_and_grad(X, groups, mu, Sigma):
        """
        Observed-data negative log-likelihood of the rows of X under N(mu, Sigma).

        Parameters
        ----------
        X : (N, d) float array, NaN marks a missing entry.
        groups : list of (rows_idx, cols_obs) as built by _build_mask_groups(X).
        mu : (d,) mean vector.
        Sigma : (d, d) symmetric covariance matrix.

        Returns
        -------
        nll : float
            Summed negative log-likelihood over all rows.
        grad_mu : (d,) array
            d nll / d mu.
        grad_sigma : (d, d) symmetric array
            d nll / d Sigma, treating the entries of Sigma as independent.

        Raises
        ------
        np.linalg.LinAlgError
            If an observed sub-block of Sigma is not positive definite.
        """
        d = mu.shape[0]
        nll = 0.0
        grad_mu = np.zeros(d)
        grad_sigma = np.zeros((d, d))
        log2pi = np.log(2.0 * np.pi)

        for rows_idx, cols_obs in groups:
            k = len(cols_obs)
            m = len(rows_idx)
            if k == 0 or m == 0:
                continue  # fully missing rows carry no information

            # One Cholesky factorisation per missingness pattern.
            L = np.linalg.cholesky(Sigma[np.ix_(cols_obs, cols_obs)])
            centered = X[np.ix_(rows_idx, cols_obs)] - mu[cols_obs]   # (m, k), no NaN
            z = solve_triangular(L, centered.T, lower=True)            # L^{-1} (x - mu), (k, m)
            w = solve_triangular(L, z, lower=True, trans="T")          # Sigma_sub^{-1} (x - mu)

            logdet = 2.0 * np.sum(np.log(np.diag(L)))                 # log|Sigma_sub|
            nll += 0.5 * (m * (k * log2pi + logdet) + np.sum(z * z))

            L_inv = solve_triangular(L, np.eye(k), lower=True)
            sigma_sub_inv = L_inv.T @ L_inv
            grad_mu[cols_obs] -= w.sum(axis=1)
            # d nll / d Sigma_sub = 0.5 * (m Sigma^{-1} - Sigma^{-1} S Sigma^{-1}),
            # where S is the scatter matrix of the group; w @ w.T equals the second term.
            grad_sigma[np.ix_(cols_obs, cols_obs)] += 0.5 * (m * sigma_sub_inv - w @ w.T)

        return nll, grad_mu, grad_sigma

    def _neg_log_likelihood(self, params_flat, n_vars):
        """
        Negative observed-data log-likelihood at explicit moments on the cached data.

        params_flat = [mu (n_vars), upper triangle of Sigma (row-major)], on the
        original data scale. Returns +inf when an observed sub-block of Sigma is
        not positive definite. fit() optimises the log-Cholesky objective
        instead; this is kept for diagnostics.

        Raises
        ------
        RuntimeError
            If fit() has not prepared the cached data yet.
        """
        if self._X_joint is None or self._mask_groups is None:
            raise RuntimeError("Internal state not prepared. Call fit() first.")
        mu = np.asarray(params_flat[:n_vars], dtype=float)
        Sigma = self._unflatten_sigma_upper(params_flat[n_vars:], n_vars)
        try:
            nll, _, _ = self._moments_nll_and_grad(self._X_joint, self._mask_groups, mu, Sigma)
        except np.linalg.LinAlgError:
            return np.inf
        return nll

    def _objective_log_cholesky(self, theta, Z, groups, n_vars):
        """
        Mean negative log-likelihood and its gradient in log-Cholesky coordinates.

        theta = [mu (n_vars), lower triangle of L (row-major, log-diagonal)] with
        Sigma = L L^T. The value is divided by the number of rows so that the
        BFGS gradient tolerance does not depend on the sample size.
        Returns (value, gradient), as required by minimize(..., jac=True).
        """
        n_rows = max(Z.shape[0], 1)
        mu = theta[:n_vars]
        L = self._unpack_log_cholesky(theta[n_vars:], n_vars)
        sigma = L @ L.T
        if not np.all(np.isfinite(sigma)):
            return np.inf, np.zeros_like(theta)
        try:
            nll, grad_mu, grad_sigma = self._moments_nll_and_grad(Z, groups, mu, sigma)
        except np.linalg.LinAlgError:
            return np.inf, np.zeros_like(theta)

        # Chain rule through Sigma = L L^T with symmetric G: d nll / dL = 2 G L,
        # restricted to the lower triangle; diagonal entries are exp(theta_ii).
        grad_L = 2.0 * grad_sigma @ L
        rows, cols = np.tril_indices(n_vars)
        grad_chol = grad_L[rows, cols]
        grad_chol[rows == cols] *= np.diag(L)
        grad = np.concatenate([grad_mu, grad_chol])
        return nll / n_rows, grad / n_rows

    # --------- Starting values ---------
    @staticmethod
    def _column_location_scale(X_joint):
        """
        Per-column observed mean and (ddof=0) standard deviation.

        Columns with no observations get (0, 1); columns with no spread get
        scale 1, so the standardisation is always well defined.
        """
        observed = ~np.isnan(X_joint)
        n_obs = np.maximum(observed.sum(axis=0), 1)
        center = np.where(observed, X_joint, 0.0).sum(axis=0) / n_obs
        dev = np.where(observed, X_joint - center, 0.0)
        scale = np.sqrt((dev ** 2).sum(axis=0) / n_obs)
        scale[~(scale > 0)] = 1.0
        return center, scale

    @staticmethod
    def _initial_covariance(Z, floor=1e-4):
        """
        Pairwise-deletion covariance of Z, repaired to be positive definite.

        Undefined variances are set to 1 and undefined covariances to 0.
        Eigenvalues are then clipped at `floor`, because pairwise deletion does
        not guarantee a positive-definite matrix.
        """
        S = pd.DataFrame(Z).cov().to_numpy()
        d = S.shape[0]
        diag = np.diag(S).copy()
        diag[~np.isfinite(diag) | (diag <= 0)] = 1.0
        S[~np.isfinite(S)] = 0.0
        S[np.diag_indices(d)] = diag
        S = 0.5 * (S + S.T)
        eigval, eigvec = np.linalg.eigh(S)
        return (eigvec * np.maximum(eigval, floor)) @ eigvec.T

    # --------- Training ---------
    def fit(self, X, y):
        """
        Estimate the joint moments of (X, y) by FIML and derive the regression.

        Parameters
        ----------
        X : pd.DataFrame or array-like, shape (n_samples, p)
            Predictors; NaN marks a missing entry.
        y : pd.Series, single-column pd.DataFrame, or 1-D array-like
            Target; NaN marks a missing entry. Rows are matched with X by
            position (both indices are discarded).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If y has more than one column, or X and y differ in length.
        """
        if not isinstance(X, pd.DataFrame):
            X = pd.DataFrame(X)
        if not isinstance(y, (pd.Series, pd.DataFrame)):
            y = pd.Series(y)

        if isinstance(y, pd.DataFrame):
            if y.shape[1] != 1:
                raise ValueError("y must be a 1-D vector or single-column DataFrame.")
            y = y.iloc[:, 0]
        if len(X) != len(y):
            # pd.concat below would otherwise silently pad the shorter input with NaN rows.
            raise ValueError(f"X and y must have the same length ({len(X)} != {len(y)}).")

        self.feature_names = list(X.columns)
        self.target_name = y.name if y.name is not None else "target"

        # Joint data, column order: X columns, then y (y is always the last column).
        data_df = pd.concat([X.reset_index(drop=True),
                             y.reset_index(drop=True).rename(self.target_name)],
                            axis=1)
        n_vars = data_df.shape[1]  # p + 1
        X_joint = data_df.to_numpy(dtype=float, na_value=np.nan)

        # Cache the data and missingness groups.
        self._X_joint = X_joint
        self._mask_groups = self._build_mask_groups(X_joint)

        # Optimise on standardised columns for conditioning; the Gaussian MLE is
        # equivariant under this rescaling, so the estimates are mapped back exactly.
        center, scale = self._column_location_scale(X_joint)
        Z = (X_joint - center) / scale

        # Start from the observed means (0 after centering) and the repaired
        # pairwise-deletion covariance.
        chol0 = np.linalg.cholesky(self._initial_covariance(Z))
        theta0 = np.concatenate([np.zeros(n_vars), self._pack_log_cholesky(chol0)])

        result = minimize(
            fun=self._objective_log_cholesky,
            x0=theta0,
            args=(Z, self._mask_groups, n_vars),
            jac=True,
            method="BFGS",
            options={"maxiter": self.maxiter, "disp": False},
        )
        if not result.success:
            warnings.warn(f"FIML optimization did not fully converge: {result.message}",
                          RuntimeWarning, stacklevel=2)

        # Map the standardised estimates back to the original scale.
        mu_z = result.x[:n_vars]
        chol = self._unpack_log_cholesky(result.x[n_vars:], n_vars)
        sigma = (chol @ chol.T) * np.outer(scale, scale)
        self.est_mu = center + scale * mu_z
        self.est_sigma = 0.5 * (sigma + sigma.T)

        # Regression from the joint moments:
        # beta = Sigma_XX^{-1} Sigma_Xy, beta_0 = mu_y - beta . mu_X
        idx_y = n_vars - 1
        mu_X = self.est_mu[:idx_y]
        mu_y = self.est_mu[idx_y]
        sigma_XX = self.est_sigma[:idx_y, :idx_y]
        sigma_Xy = self.est_sigma[:idx_y, idx_y]
        beta = np.linalg.solve(sigma_XX, sigma_Xy) if idx_y > 0 else np.zeros(0)

        self.beta_coeffs = beta                       # (p,)
        self.beta_0 = float(mu_y - beta @ mu_X)       # beta @ mu_X is a 0-d scalar
        return self

    # --------- Prediction ---------
    def predict(self, X_new):
        """
        Predict E[y | X] for fully observed rows.

        X_new: pd.DataFrame containing at least the training columns (reordered
        to the training order), or an array-like whose columns follow the
        training order. Missing entries yield NaN predictions.
        Returns: np.ndarray of shape (n_samples,).

        Raises
        ------
        RuntimeError
            If the model has not been fitted.
        """
        if self.beta_0 is None or self.beta_coeffs is None:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        if not isinstance(X_new, pd.DataFrame):
            X_new = pd.DataFrame(X_new, columns=self.feature_names)
        X_ord = X_new[self.feature_names]  # enforce the training column order
        return self.beta_0 + X_ord.to_numpy(dtype=float, na_value=np.nan) @ self.beta_coeffs
# ===== END OF REPLACEMENT BLOCK =====
