In [None]:
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap

def load_trace(path):
    """Deserialize and return the pickled object stored at ``path``.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    call this on trusted result files.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)

def add_closest_imputation_col(df, target):
    """Record, per row, which imputation strategy landed closest to the truth.

    Compares the three per-strategy imputations ``{target}_imputed_full``,
    ``{target}_imputed_full_selected`` and ``{target}_imputed_cluster`` against
    ``truth_{target}`` by absolute error and adds two columns to ``df`` in place:

    - ``closest_imputation``: the imputed value with the smallest absolute error.
    - ``closest_imputation_index``: its position among the candidates
      (0 = full, 1 = full_selected, 2 = cluster).

    A missing (NaN) imputation is never chosen while a non-missing candidate
    exists. If every candidate, or the truth itself, is missing, index 0 is
    reported (matching the previous behaviour for that case).

    Args:
        df: DataFrame containing the candidate and truth columns.
        target: Base name of the imputed variable.

    Returns:
        The same ``df`` object, with the two columns added.
    """
    candidate_cols = [
        f'{target}_imputed_full',
        f'{target}_imputed_full_selected',
        f'{target}_imputed_cluster',
    ]
    candidates = df[candidate_cols].to_numpy(dtype=float)
    truth = df[f'truth_{target}'].to_numpy(dtype=float)

    errors = np.abs(candidates - truth[:, None])
    # np.argmin returns the position of the first NaN, which would make a
    # missing imputation look "closest"; treat NaN errors as infinitely far.
    errors = np.where(np.isnan(errors), np.inf, errors)
    closest = np.argmin(errors, axis=1)

    # Positional assignment: one entry per row of df (also works for empty df).
    df['closest_imputation'] = candidates[np.arange(len(df)), closest]
    df['closest_imputation_index'] = closest
    return df


In [None]:

# Shap values analysis: explain the tree model's prediction for a single
# imputed instance of the target.
model_path = 'SMOKE_NC_tree.pkl'
target = 'SMOKE_NC'
instance_index = 15

# NOTE(review): load_trace uses pickle.load — only point it at trusted files.
model = load_trace(model_path)
trace = load_trace('./results/tree_hyperimpute_target_trace_LASSO_reduced_0.pkl')
raw_imputed_values = pd.DataFrame.from_dict(trace[target]['values'])
# Model inputs are the per-strategy imputations: drop the ground truth and
# the ensemble output columns.
imputed_values = raw_imputed_values.drop(columns=[f'truth_{target}', f'combined_{target}'])

explainer = shap.TreeExplainer(model)
values = explainer.shap_values(imputed_values)

instance_values = raw_imputed_values.iloc[instance_index].to_dict()
truth = instance_values[f'truth_{target}']
ensemble = instance_values[f'combined_{target}']
# Explain the class the ensemble chose. Round rather than truncate so a
# fractional ensemble value maps to the nearest class instead of flooring.
class_index = int(round(ensemble))
shap.initjs()
# NOTE(review): values[i][:, c] assumes shap_values returned a
# (n_samples, n_features, n_classes) array; some shap versions return a
# per-class list instead — confirm against the installed version.
# force_plot(matplotlib=True) draws its own figure, so no plt.figure() is
# created beforehand (it only left an empty, unused figure open).
shap.force_plot(
    explainer.expected_value[class_index],
    values[instance_index][:, class_index],
    imputed_values.iloc[instance_index, :],
    matplotlib=True,
)


In [None]:

import logging

# Debug-level file logging for this analysis (the logging calls further down
# are currently commented out).
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(filename='temp.log', level=logging.DEBUG, format=LOG_FORMAT)

# Target(s) to analyse and the trace file holding their clustering results.
targs = ['BREASTFEED_BIRTH_FLG_NC']
traces = load_trace('./results/hyperimpute_target_trace_deprived_LASSO_important_0.pkl')

# Earlier experiment: mean/std of the first cluster score across seeds.
# scores = []
# for seed in range(3):
#     traces = load_trace(f'./results/hyperimpute_target_trace_LASSO_reduced_{seed}.pkl')
#     for target in targs:
#         score = traces[target]['cluster_scores']
#         scores.append(score[0])
# logging.info('HRM')
# logging.info(f'{np.mean(scores)}, {np.std(scores, ddof=1)}')
# targs = ['BREASTFEED_BIRTH_FLG_NC']
# targs = ['BIRTH_WEIGHT_NC']

# Human-readable labels for dataset column names, used in plot titles/axes.
name_map = dict(
    BIRTH_WEIGHT_NC='Birth Weight',
    SMOKE_NC='Maternal Smoking',
    BREASTFEED_BIRTH_FLG_NC='Breastfeeding',
    BIRTH_WEIGHT_CAT_NC='Birth Weight Category',
    DEP_SCORE='Deprivation Score',
    MAT_AGE_NC='Maternal Age',
    APGAR_2_NC='Apgar-2 Score',
    GEST_AGE_NC='Gestational Age',
)

def visualise_var_across_clusters(target, var, data):
    """Bar-plot the per-cluster mean of ``var`` with ±1 sample-std error bars.

    Args:
        target: Imputation target the clustering belongs to (used in the title).
        var: Column of ``data`` to summarise per cluster.
        data: DataFrame with a ``'CLUSTER'`` column and a ``var`` column.
    """
    # groupby sorts the cluster ids and drops NaN labels (a NaN label would
    # otherwise crash int() in the tick labels). std uses ddof=1, i.e. the
    # sample standard deviation, and both stats skip NaN values.
    stats = data.groupby('CLUSTER')[var].agg(['mean', 'std'])

    # .get fallback: name_map is a global that other cells may redefine with
    # fewer keys; fall back to the raw column name rather than raise KeyError.
    var_label = name_map.get(var, var)
    target_label = name_map.get(target, target)

    x = np.arange(len(stats))
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.bar(x, stats['mean'], yerr=stats['std'], capsize=5, edgecolor='black')
    ax.set_title(f'Distribution of {var_label} per Cluster (for {target_label})')
    ax.set_ylabel(var_label)
    ax.set_xlabel('Cluster ID')
    ax.set_xticks(x)
    ax.set_xticklabels([f'Cluster {int(cluster)}' for cluster in stats.index])
    plt.show()
    plt.close(fig)

# Variables whose per-cluster distribution is plotted for each target.
interesting = ['MAT_AGE_NC', 'BIRTH_WEIGHT_NC', 'GEST_AGE_NC', 'DEP_SCORE']
# interesting = ['MAT_AGE_NC', 'GEST_AGE_NC', 'DEP_SCORE']

for target in targs:
    trace = traces[target]
    # average = np.mean([score[0] for score in traces[target]['cluster_scores']])
    # logging.info(f'Average Cluster Score: {trace['cluster_scores']}')
    selected_vars = trace['selected_vars']
    rankings = trace['feature_rankings']
    # Highest-scoring features first (score is the second element of each entry).
    sorted_rankings = sorted(rankings, key=lambda ranking: ranking[1], reverse=True)
    # logging.info(f'Rankings: {[var for var in sorted_rankings if var[1] > 0.01]}')
    # mnar = [var for var in sorted_rankings if 'missing' in var[0] and var[1] > 0]
    # logging.info(f'MNAR Vars: {mnar}')

    cluster_imputed = pd.DataFrame.from_dict(trace['clustering_imputed'])
    imputed_values = pd.DataFrame.from_dict(trace['values'])
    clusters_series = pd.Series(trace['clusters'])
    # NOTE(review): assign aligns on index — assumes cluster_imputed's index
    # matches clusters_series' default RangeIndex; confirm.
    clustered_imputed = cluster_imputed.assign(CLUSTER=clusters_series)
    print(clustered_imputed['CLUSTER'].unique())
    clusters = sorted(clustered_imputed['CLUSTER'].unique())
    # for cluster in clusters:
    #     logging.info(f'Proportion for Cluster: {cluster} with size {len(imputed_values[clustered_imputed['CLUSTER'] == cluster])}')
    #     logging.info(imputed_values[clustered_imputed['CLUSTER'] == cluster][f'truth_{target}'].value_counts(normalize=True) * 100)

    # Keep rows whose birth weight lies in [400, 8000] inclusive (units
    # presumably grams — confirm); rows with a missing weight are dropped too.
    conditions = clustered_imputed['BIRTH_WEIGHT_NC'].between(400, 8000)
    clustered_imputed = clustered_imputed[conditions]

    for var in interesting:
        if var == 'CLUSTER':
            continue
        visualise_var_across_clusters(target, var, clustered_imputed)


In [None]:
import matplotlib.pyplot as plt
import pandas as pd

trace_path = './results/hyperimpute_target_trace_deprived_LASSO_reduced_0.pkl'
# Global trace (target name -> trace dict) read by plot_performance below.
trace = load_trace(trace_path)

def get_gest_cat(value):
    """Map a gestational-age value to its category code.

    Bands (upper bounds inclusive):
        <= 27 -> 0.0, <= 31 -> 4.0, <= 36 -> 2.0, <= 41 -> 3.0, > 41 -> 1.0

    The non-monotonic codes presumably mirror the dataset's GEST_AGE_CAT_NC
    encoding (TODO confirm), and the input is presumably in weeks.

    Returns:
        The category code as a float, or NaN when ``value`` is missing.
        Previously a NaN failed every comparison and was silently binned
        as 1.0 (> 41).
    """
    if pd.isna(value):
        return np.nan
    if value <= 27.0:
        return 0.0
    if value <= 31.0:
        return 4.0
    if value <= 36.0:
        return 2.0
    if value <= 41.0:
        return 3.0
    return 1.0

def get_birth_cat(value):
    """Map a birth-weight value to its category code.

    Bands (upper bounds inclusive):
        <= 1000 -> 0.0, <= 1500 -> 5.0, <= 2500 -> 2.0, <= 4000 -> 1.0,
        <= 4500 -> 3.0, > 4500 -> 4.0

    The non-monotonic codes presumably mirror the dataset's
    BIRTH_WEIGHT_CAT_NC encoding (TODO confirm), and the input is presumably
    in grams.

    Returns:
        The category code as a float, or NaN when ``value`` is missing.
        Previously a NaN failed every comparison and was silently binned
        as 4.0 (> 4500).
    """
    if pd.isna(value):
        return np.nan
    if value <= 1000.0:
        return 0.0
    if value <= 1500.0:
        return 5.0
    if value <= 2500.0:
        return 2.0
    if value <= 4000.0:
        return 1.0
    if value <= 4500.0:
        return 3.0
    return 4.0

# Human-readable labels for dataset column names, used in plot titles/axes.
# Kept a superset of the earlier cell's map: functions defined in earlier
# cells (e.g. visualise_var_across_clusters) read the global name_map at call
# time, so a smaller redefinition here would remove labels they rely on.
name_map = {
    'BIRTH_WEIGHT_NC': 'Birth Weight',
    'SMOKE_NC': 'Maternal Smoking',
    'BREASTFEED_BIRTH_FLG_NC': 'Breastfeeding',
    'BIRTH_WEIGHT_CAT_NC': 'Birth Weight Category',
    'DEP_SCORE': 'Deprivation Score',
    'MAT_AGE_NC': 'Maternal Age',
    'APGAR_2_NC': 'Apgar-2 Score',
    'GEST_AGE_NC': 'Gestational Age',
    # Handled as a target by plot_performance.
    'GEST_AGE_CAT_NC': 'Gestational Age Category',
}

def strat_name(strategy):
    """Return a display label for an imputation-strategy column name.

    Keywords are tried in priority order, so ``X_imputed_full_selected`` is
    labelled 'Selected' rather than 'Full'. Returns None when no keyword
    occurs in ``strategy``.
    """
    keyword_labels = (
        ('cluster', 'Clustered'),
        ('selected', 'Selected'),
        ('combined', 'Ensemble'),
        ('full', 'Full'),
    )
    return next((label for keyword, label in keyword_labels if keyword in strategy), None)

def _categorical_pair(col, truth, imputed):
    """Return (truth, imputed) Series prepared for a value-count comparison of ``col``."""
    if col == 'BIRTH_WEIGHT_CAT_NC':
        return truth.apply(get_birth_cat), imputed.apply(get_birth_cat)
    if col == 'GEST_AGE_CAT_NC':
        return truth.apply(get_gest_cat), imputed.apply(get_gest_cat)
    if col == 'BREASTFEED_BIRTH_FLG_NC':
        # 9.0 is presumably an "unknown" code (TODO confirm): drop those
        # imputations and compare against the truth on exactly the same rows.
        # (The truth side previously came from the imputed column by mistake.)
        kept = imputed.replace(9.0, np.nan).dropna()
        return truth.loc[kept.index].dropna().round().astype(int), kept.round().astype(int)
    # astype(int) cannot represent NaN, so drop missing imputations first.
    return truth, imputed.dropna().round().astype(int)


def _plot_categorical(true_col, imputed_col, strat, label):
    """Draw a side-by-side bar chart of truth vs. imputed value counts."""
    fig, ax = plt.subplots()
    df_counts = pd.DataFrame({'True': true_col.value_counts(), 'Imputed': imputed_col.value_counts()})
    df_counts.plot(kind='bar', ax=ax)
    ax.set_title(f'{strat_name(strat)} Imputation against Ground Truth for {label}')
    ax.set_xlabel(label)
    ax.set_ylabel('Count')
    plt.show()
    plt.close(fig)


def _plot_continuous(true_col, imputed_col, strat, label):
    """Draw a per-row line plot of truth vs. imputed values."""
    fig, ax = plt.subplots()
    ax.plot(true_col.index, true_col, label='True', marker='o', linestyle='-')
    ax.plot(imputed_col.index, imputed_col, label='Imputed', marker='x', linestyle='--')
    ax.set_title(f'{strat_name(strat)} Imputation Values against Ground Truth for {label}')
    ax.set_xlabel('Row Index')
    # Label by the variable itself (was hard-coded to 'Weight', wrong for ages).
    ax.set_ylabel(label)
    ax.legend()
    plt.show()
    plt.close(fig)


def plot_performance(targets=None, trace_data=None):
    """Plot every imputation strategy against the ground truth, per target.

    For each target, the four strategy columns in its trace's ``'values'``
    are compared with ``truth_<col>``. The strategies are full, full_selected,
    cluster and combined.

    - Targets in ``cont_vars`` get per-row line plots.
    - All other targets are treated as categorical and get bar charts of
      value counts.

    Args:
        targets: Target column names to plot. Defaults to
            ``['SMOKE_NC', 'BREASTFEED_BIRTH_FLG_NC']``.
        trace_data: Mapping of target name -> trace dict with a ``'values'``
            entry. Defaults to the module-level ``trace``.
    """
    if targets is None:
        targets = ['SMOKE_NC', 'BREASTFEED_BIRTH_FLG_NC']
    if trace_data is None:
        trace_data = trace
    cont_vars = ['BIRTH_WEIGHT_NC', 'GEST_AGE_NC', 'MAT_AGE_NC']
    # Category targets are read from the underlying continuous target's trace
    # and binned afterwards.
    source_cols = {'BIRTH_WEIGHT_CAT_NC': 'BIRTH_WEIGHT_NC', 'GEST_AGE_CAT_NC': 'GEST_AGE_NC'}

    for col in targets:
        base_col = source_cols.get(col, col)
        imputed_values = pd.DataFrame.from_dict(trace_data[base_col]['values'])
        truth = imputed_values[f'truth_{base_col}']
        strats = [
            f'{base_col}_imputed_full',
            f'{base_col}_imputed_full_selected',
            f'{base_col}_imputed_cluster',
            f'combined_{base_col}',
        ]
        # .get fallback: name_map is a global other cells may redefine.
        label = name_map.get(col, col)

        for strat in strats:
            if col in cont_vars:
                _plot_continuous(truth, imputed_values[strat], strat, label)
            else:
                # Fresh per-strategy pair; the truth Series is never mutated.
                true_col, imputed_col = _categorical_pair(col, truth, imputed_values[strat])
                _plot_categorical(true_col, imputed_col, strat, label)


plot_performance()