### Author: Ally Sprik
### Last-updated: 25-02-2024

The goal of this notebook is to explore data imputation with the MIDAS algorithm, a deep-learning autoencoder that can handle both continuous and categorical data. The algorithm can generate multiple imputations for the missing data.


In [None]:
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import MIDASpy as midas

# Columns kept out of the MIDAS model and re-attached unchanged before saving.
# NOTE(review): several are numeric (Age, BMI, CA125_PREOP, ...); presumably they
# are held out because `cat_conv` below treats every remaining column as
# categorical — confirm this is intended.
EXTRA_COLS = ["Study_number", "Included_in_training_cohort", "Comorbidity_index",
              "Platelets_numeric", "CA125_PREOP", "Age", "BMI"]

# Hide all CUDA devices so the model trains on CPU. This has to be an environment
# variable (a bare Python assignment has no effect), and it must be set before the
# deep-learning backend initialises its devices.
# NOTE(review): if a GPU is still picked up, move this line above the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

df = pd.read_csv('../0.1. Cleaned_data/Training_JAMA+Brno_model_cleaned.csv', sep=',')
extra_cols = df[EXTRA_COLS].copy()
df = df.drop(columns=EXTRA_COLS)

# Normalise every missing marker (None, pd.NA, NaN) to np.nan in one vectorised
# step. Element-wise chained assignment (`df[col][i] = ...`) is slow and does not
# write back to `df` under pandas Copy-on-Write.
df = df.where(df.notna(), np.nan)

# One-hot encode the remaining columns. As used in later cells, `cat_cols_list`
# holds one list of dummy-column names per column of `df`, in the same order.
encoded, cat_cols_list = midas.cat_conv(df)

Set up, build, and train the imputation model

In [None]:
# Hyper-parameters for the MIDAS imputer, named so they can be tuned in one place.
# NOTE(review): these look hand-picked; no tuning is shown in this notebook.
LAYER_STRUCTURE = [256, 256]  # passed as `layer_structure`
INPUT_DROP = 0.75             # passed as `input_drop`
MODEL_SEED = 123              # fixed seed for repeatable runs
TRAINING_EPOCHS = 15

imputer = midas.Midas(
    layer_structure=LAYER_STRUCTURE,
    vae_layer=True,
    seed=MODEL_SEED,
    input_drop=INPUT_DROP,
)
# Build the network around the one-hot encoded frame, then fit it.
imputer.build_model(encoded)
imputer.train_model(training_epochs=TRAINING_EPOCHS)

Impute the data

In [None]:
# Draw several completed data sets from the trained model. Each element of
# `imputations` is one imputed DataFrame in the one-hot encoded column layout.
n_imputations = 10
sampled = imputer.generate_samples(m=n_imputations)
imputations = sampled.output_list

Reapply the categorical labels

In [None]:
# Every one-hot dummy column, flattened, so they can all be dropped at once
# after the decoded label columns are added.
flat_cats = [dummy for group in cat_cols_list for dummy in group]
categorical = df.columns.values

for idx, imputation in enumerate(imputations):
    # For each original variable, take the name of its highest-valued dummy
    # column per row (idxmax returns the column label, e.g. "<variable>_<level>").
    label_series = [imputation[group].idxmax(axis=1) for group in cat_cols_list]
    decoded = {name: label_series[pos] for pos, name in enumerate(categorical)}
    cat_df = pd.DataFrame(decoded)
    # Append the decoded labels and discard the dummy columns they replace.
    combined = pd.concat([imputation, cat_df], axis=1)
    imputations[idx] = combined.drop(flat_cats, axis=1)

Strip the column-name prefixes from the category labels (one-hot encoding stored each level as `<column>_<level>`)

In [None]:
def strip_column_prefix(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `frame` with the "<column>_" prefix removed from string cells.

    The decoded labels have the form "<column>_<level>" (the one-hot dummy
    column names), so this restores the bare level. Non-string cells are left
    unchanged instead of raising.
    """
    stripped = frame.copy()
    for col in stripped.columns:
        prefix = f"{col}_"
        # Bind `prefix` as a default so each lambda keeps its own column's prefix.
        stripped[col] = stripped[col].map(
            lambda value, prefix=prefix: value.removeprefix(prefix) if isinstance(value, str) else value
        )
    return stripped


# Whole-column assignment on a copy replaces the element-wise chained assignment
# (`imputations[i][col][j] = ...`), which is slow and does not write back under
# pandas Copy-on-Write. Iterating the list also drops the hard-coded count of 10.
imputations = [strip_column_prefix(imputation) for imputation in imputations]

Save the last imputation, with the held-out columns re-attached

In [None]:
# Take the last imputed data set without hard-coding how many were drawn, and
# copy it so the entry in `imputations` is not modified.
result = imputations[-1].copy()

# Re-attach the columns held out of imputation. They do not exist in `result`,
# so per-cell writes (`result[col][i] = ...`) raise KeyError. Assign whole
# columns instead. `.to_numpy()` copies by row position, which keeps the
# original row order and raises if the row counts differ.
for col in extra_cols.columns:
    result[col] = extra_cols[col].to_numpy()

result.to_csv('../0.2. Imputed_data/MIDAS_Imputed_TCGATraining_JAMA_Brno.csv', index=False)