# Step 4: Multiple Imputation by Chained Equations (MICE)

In [1]:
import pandas as pd
import miceforest as mf
import os
import matplotlib.pyplot as plt
import numpy as np
from typing import Union
import json

# Configure matplotlib so figure text containing CJK characters can be rendered:
# 'PingFang HK' becomes the only candidate for the sans-serif family.
# NOTE(review): no fallback font is listed; presumably this font is only installed
# on macOS - confirm before running the notebook on another platform.
from matplotlib import rcParams
rcParams['font.sans-serif'] = ['PingFang HK']
rcParams['axes.unicode_minus'] = False  # Ensure minus sign renders correctly
from plotnine import theme, element_text

In [2]:
# Input locations, relative to the notebook's working directory
_processed_dir, _raw_dir = "outputs", "raw_data"
df_path = os.path.join(_processed_dir, "Processed_Data.csv")      # table read into df_full
translation_path = os.path.join(_raw_dir, "translation.json")     # column-name translation map

## 1. Load Data

In [3]:
# Read the processed table; the trailing expression echoes (rows, columns) as the cell output
df_full = pd.read_csv(filepath_or_buffer=df_path)
df_full.shape

(6724, 100)

In [4]:
# Keep the record identifier out of the frame that goes on to translation and imputation
ID_col = "匹配ID_日期"
df = df_full.drop(labels=[ID_col], axis="columns")
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6724 entries, 0 to 6723
Data columns (total 99 columns):
 #   Column              Non-Null Count  Dtype  
---  ------              --------------  -----  
 0   胎膜早破                6724 non-null   int64  
 1   胎儿宫内窘迫              6724 non-null   int64  
 2   巨大儿                 6724 non-null   int64  
 3   羊水污染                6724 non-null   int64  
 4   妊娠期糖尿病              6724 non-null   int64  
 5   妊娠期高血压              6724 non-null   int64  
 6   妊娠合并肝损害             6724 non-null   int64  
 7   妊娠合并肝内胆汁淤积症         6724 non-null   int64  
 8   孕妇产次                6724 non-null   int64  
 9   足月产次数               6724 non-null   int64  
 10  早产次数                6724 non-null   int64  
 11  流产次数                6724 non-null   int64  
 12  人流次数                6724 non-null   int64  
 13  体重                  6724 non-null   float64
 14  身高                  6724 non-null   float64
 15  1小时葡萄糖              6724 non-null   float64
 16  2小时葡萄糖

## 2. English Translation

In [5]:
# Load the column-name translation mapping (original column name -> English name)
with open(translation_path, 'r', encoding='utf-8') as f:
    column_mapping = json.load(f)

# Every column must have a translation: list all gaps first, then fail loudly
unmapped = [col for col in df.columns if col not in column_mapping]
if unmapped:
    print("Missing translations detected:")
    for missing_col in unmapped:
        print(f"\t{missing_col}")
    raise ValueError(f"{len(unmapped)} columns are not mapped in the translation file.")

# Rename columns using the translation mapping (returns a new frame; `df` is untouched)
df_translated = df.rename(columns=column_mapping)

# Two source columns translated to the same English label would silently leave
# duplicate column names, which break per-column lookups such as `df[col]` later on.
duplicated_names = df_translated.columns[df_translated.columns.duplicated()].unique().tolist()
if duplicated_names:
    raise ValueError(f"Translation produces duplicate column names: {duplicated_names}")

df_translated.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6724 entries, 0 to 6723
Data columns (total 99 columns):
 #   Column                                      Non-Null Count  Dtype  
---  ------                                      --------------  -----  
 0   Premature Rupture of Membranes              6724 non-null   int64  
 1   Fetal Distress                              6724 non-null   int64  
 2   Macrosomia                                  6724 non-null   int64  
 3   Amniotic Fluid Contamination                6724 non-null   int64  
 4   Gestational Diabetes Mellitus               6724 non-null   int64  
 5   Gestational Hypertension                    6724 non-null   int64  
 6   Pregnancy Complicated by Hepatic Injury     6724 non-null   int64  
 7   Intrahepatic Cholestasis of Pregnancy       6724 non-null   int64  
 8   Pregnant woman's parity                     6724 non-null   int64  
 9   Number of full-term births                  6724 non-null   int64  
 10  Number of pr

## 3. Apply MICE Imputation

In [6]:
IMPUTED_DATASETS = 3
ITERATIONS = 20

# Build the imputation kernel; random_state is pinned (presumably so reruns are repeatable)
kernel = mf.ImputationKernel(
    data=df_translated,
    num_datasets=IMPUTED_DATASETS,
    random_state=42,
)

# Run ITERATIONS rounds of MICE
kernel.mice(ITERATIONS)

# Pull one completed DataFrame out of the kernel per imputed dataset
imputed_datasets = []
for dataset_idx in range(IMPUTED_DATASETS):
    imputed_datasets.append(kernel.complete_data(dataset=dataset_idx))

# Sanity check: each completed frame must line up row-for-row with the input frame
expected_rows = df_translated.shape[0]
for i, imputed_df in enumerate(imputed_datasets, start=1):
    assert imputed_df.shape[0] == expected_rows, f"Row count mismatch in dataset {i}"
    assert (imputed_df.index == df_translated.index).all(), f"Index mismatch in dataset {i}"
print("All imputed datasets match the original DataFrame indexes.")


All imputed datasets match the original DataFrame indexes.


## 4. Get Metrics

In [7]:
# Root folder for every diagnostic artefact written by the functions below.
# exist_ok=True makes the call idempotent and avoids the check-then-create race.
METRICS_BASE_DIR = os.path.join(os.getcwd(), "MICE", "imputation_metrics")
os.makedirs(METRICS_BASE_DIR, exist_ok=True)

# Names of the features that contain at least one missing value
def get_na_feature_names(df: pd.DataFrame):
    """
    Return the column labels of `df` that hold one or more NA entries.

    The labels are returned as a list, in the same order as `df.columns`.
    """
    names_with_gaps = []
    for column_name in df.columns:
        if df[column_name].isna().any():
            names_with_gaps.append(column_name)
    return names_with_gaps

#Convergence diagnostic
def get_convergence_diagnostic(kernel: mf.ImputationKernel, feature_names: list[str], iterations_cap: int=ITERATIONS):
    """
    Save one mean-vs-iteration line plot per feature for every imputed dataset.

    For each dataset held by `kernel`, a folder "Convergence Imputed_<n>" (1-based)
    is created under METRICS_BASE_DIR and one PNG per feature is written into it.
    Each plot shows the mean of the feature's completed column at MICE iterations
    1..iterations_cap.

    Args:
        kernel: Imputation kernel on which `.mice()` has already been run.
        feature_names: Columns to plot.
        iterations_cap: Highest iteration to plot. The kernel must have run at
            least this many iterations, otherwise `complete_data` cannot serve it.
    """
    # miceforest keeps the pre-MICE initial fill as iteration 0, so the MICE
    # iterations themselves are numbered 1..N. Read exactly the iterations the
    # X axis is labelled with instead of 0..N-1.
    iteration_numbers = list(range(1, iterations_cap + 1))

    for dataset_id in range(kernel.num_datasets):
        # One output folder per imputed dataset
        dataset_file_dir = f"Convergence Imputed_{dataset_id + 1}"
        save_dir = os.path.join(METRICS_BASE_DIR, dataset_file_dir)
        os.makedirs(save_dir, exist_ok=True)

        # The completed frame depends only on (dataset, iteration): build it once
        # per iteration and reuse it for every feature.
        means_by_feature = {feature_name: [] for feature_name in feature_names}
        for iteration in iteration_numbers:
            current_imputed = kernel.complete_data(dataset=dataset_id, iteration=iteration)
            for feature_name, means_per_iteration in means_by_feature.items():
                means_per_iteration.append(np.mean(current_imputed[feature_name]))

        for feature_name, means_per_iteration in means_by_feature.items():
            fig, ax = plt.subplots()
            ax.plot(iteration_numbers, means_per_iteration, marker='o')
            ax.set_xlabel("Iteration")
            # NOTE(review): the mean is taken over the whole completed column
            # (observed + imputed values), not over the imputed cells only.
            ax.set_ylabel("Mean of Imputed Values")
            ax.set_title(f"Mean Convergence for '{feature_name}'")
            ax.set_xticks(iteration_numbers)

            # A path separator inside a feature name would point into a
            # non-existent sub-folder; neutralise it in the file name only.
            safe_name = str(feature_name).replace("/", "_").replace(os.sep, "_")
            save_path = os.path.join(save_dir, safe_name + ".png")
            fig.savefig(save_path, bbox_inches='tight')
            plt.close(fig)  # release the figure; many are created in this loop

        print(f"{dataset_file_dir} completed.")

# Imputed distributions
def get_imputed_distributions(kernel: mf.ImputationKernel, feature_names: Union[list[str], None]=None):
    """
    Write the imputed-value distribution figure to "Imputed Distributions.png".

    The figure is produced by miceforest's own `.plot_imputed_distributions()`
    and stored under METRICS_BASE_DIR. `feature_names` is forwarded to that
    method as `variables`.

    Legend:
    * Red lines are the distribution of original data.
    * Black lines are the distribution of the imputed values.
    """
    print("Imputed Distribution Legend:\n\tRed lines are the distribution of original data.\n\tBlack lines are the distribution of the imputed values.")

    # Theme override so the figure text uses a font that supports CJK characters
    cjk_theme = theme(text=element_text(family="PingFang HK"))
    distribution_plot = kernel.plot_imputed_distributions(variables=feature_names) + cjk_theme

    distribution_plot.save(
        os.path.join(METRICS_BASE_DIR, "Imputed Distributions.png"),
        width=25,
        height=18,
        dpi=250,
        verbose=False,
    )

In [8]:
# Produce both diagnostics for the features that had missing values before imputation
na_feature_names = get_na_feature_names(df_translated)
get_convergence_diagnostic(kernel, na_feature_names)
get_imputed_distributions(kernel, na_feature_names)

Convergence Imputed_1 completed.
Convergence Imputed_2 completed.
Convergence Imputed_3 completed.
Imputed Distribution Legend:
	Red lines are the distribution of original data.
	Black lines are the distribution of the imputed values.


## 5. Save Imputed Datasets

In [9]:
# Output folder for the completed datasets.
# exist_ok=True makes the call idempotent and avoids the check-then-create race.
DATASETS_OUTPUT_DIR = os.path.join(os.getcwd(), "MICE", "Imputed_Datasets")
os.makedirs(DATASETS_OUTPUT_DIR, exist_ok=True)

# Save each imputed dataset under a 1-based, zero-padded name (imputed_01.csv, imputed_02.csv, ...)
for i, imputed_df in enumerate(imputed_datasets, start=1):
    file_name = f"imputed_{i:02d}.csv"
    output_path = os.path.join(DATASETS_OUTPUT_DIR, file_name)
    imputed_df.to_csv(output_path, index=False)  # index is a plain row counter, not data
    print(f"Saved {file_name} with shape {imputed_df.shape}")

Saved imputed_01.csv with shape (6724, 99)
Saved imputed_02.csv with shape (6724, 99)
Saved imputed_03.csv with shape (6724, 99)
