# In [None]:
import numpy as np
from sklearn.impute import SimpleImputer
from scipy.stats import multivariate_normal

# Missing-value imputation based on a multivariate normal (MVN) model.
def mvn_impute(data, tol=1e-6, max_iter=100):
    """
    Impute missing values with conditional means under a multivariate normal model.

    Missing entries are first filled with their column means. Each iteration then
    re-estimates the mean vector and covariance matrix from the currently filled
    data. It replaces every missing entry of a row with its conditional mean given
    that row's observed entries:

        x_m = mu_m + S_mo^T S_oo^+ (x_o - mu_o)

    Here ``S_oo^+`` is the Moore-Penrose pseudo-inverse. Observed entries are never
    modified. The covariance is estimated from the filled-in data alone, with no
    conditional-covariance correction term as in full EM. As a result, the spread
    of imputed columns tends to be underestimated.

    Parameters:
        data (array_like): 2-D array of shape (n_samples, n_features); missing
            values are marked with np.nan. The input is not modified.
        tol (float): Convergence threshold on the Frobenius norm of the change
            in the imputed matrix between consecutive iterations.
        max_iter (int): Maximum number of iterations.

    Returns:
        numpy.ndarray: float64 array of the same shape with missing values filled in.

    Raises:
        ValueError: If ``data`` is not 2-D, contains +/-inf, or has a column
            with no observed values.

    Warns:
        RuntimeWarning: If the iteration does not converge within ``max_iter``.
    """
    import warnings

    data = np.array(data, dtype=np.float64)
    if data.ndim != 2:
        raise ValueError(f"data must be a 2-D array, got {data.ndim} dimension(s).")
    if np.isinf(data).any():
        raise ValueError("data must not contain infinite values.")

    missing_mask = np.isnan(data)
    empty_columns = np.flatnonzero(missing_mask.all(axis=0))
    if empty_columns.size:
        # A column with no observed values has no mean to start from; dropping it
        # (as SimpleImputer does) would misalign every column index used below.
        raise ValueError(
            f"Columns with no observed values cannot be imputed: {empty_columns.tolist()}"
        )

    # Initial fill: replace each missing entry with its column's observed mean.
    column_means = np.nanmean(data, axis=0)
    data_imputed = np.where(missing_mask, column_means, data)

    # Observed/missing column indices per row never change, so compute them once
    # and skip rows that have nothing to impute.
    row_patterns = [
        (i, np.flatnonzero(~row_mask), np.flatnonzero(row_mask))
        for i, row_mask in enumerate(missing_mask)
        if row_mask.any()
    ]

    prev_data_imputed = data_imputed.copy()
    for iteration in range(max_iter):
        # Re-estimate the mean and covariance from the current filled-in data.
        mu = data_imputed.mean(axis=0)
        # np.cov returns a 0-d array for a single feature; keep it 2-D for indexing.
        cov = np.atleast_2d(np.cov(data_imputed, rowvar=False))

        for i, observed, missing in row_patterns:
            if observed.size == 0:
                # Nothing to condition on: the conditional mean is the marginal mean.
                data_imputed[i, missing] = mu[missing]
                continue

            cov_oo = cov[np.ix_(observed, observed)]
            cov_om = cov[np.ix_(observed, missing)]
            residual = data_imputed[i, observed] - mu[observed]

            # Pseudo-inverse instead of inv: cov_oo is singular whenever an
            # observed column has zero variance or observed columns are collinear.
            data_imputed[i, missing] = mu[missing] + cov_om.T @ (
                np.linalg.pinv(cov_oo) @ residual
            )

        # Stop once a full sweep changes the imputed matrix by less than tol.
        if np.linalg.norm(data_imputed - prev_data_imputed) < tol:
            print(f"Converged in {iteration + 1} iterations.")
            break

        prev_data_imputed = data_imputed.copy()
    else:
        warnings.warn(
            f"mvn_impute did not converge within {max_iter} iterations.",
            RuntimeWarning,
            stacklevel=2,
        )

    return data_imputed


# Example usage: a small dataset with missing values marked as np.nan.
# Guarded so importing this module does not run the solver or print; the
# block still runs as a script or in a notebook cell (__name__ == "__main__").
if __name__ == "__main__":
    data = np.array([
        [1.0, 2.0, np.nan],
        [2.0, np.nan, 0.0],
        [np.nan, 1.0, 3.0],
        [4.0, 3.0, 2.0]
    ])

    # Impute the missing values and show the result.
    imputed_data = mvn_impute(data)
    print("Imputed Data:")
    print(imputed_data)
