In [1]:
import os
import pandas as pd # type: ignore
import numpy as np # type: ignore

try:
    os.chdir('/container/mount/point')
except FileNotFoundError:
    print("Warning: Directory '/container/mount/point' does not exist.")

from utils.preprocessing import filter_and_process_asv_table
from utils.pair_matching import discrepancyMatrix, construct_network, process_matched_pairs, generate_simulated_outcomes  # noqa: E402

In [2]:
def preprocess_exposure(data, target_variable, mapping, new_col, dataset_name):
    """
    Preprocesses the exposure variable in the dataset.

    Parameters:
        data (pd.DataFrame): Input DataFrame.
        target_variable (str): Name of the column to map.
        mapping (dict): Dictionary mapping original values to new values.
        new_col (str): Name of the new column to create.
        dataset_name (str): Name of the dataset (for logging).

    Returns:
        pd.DataFrame: DataFrame with the new exposure column and rows with missing values dropped.
    """
    print(f"\nRunning dataset: {dataset_name}")
    print(f"Covariates data (before): {data.shape}")
    data.loc[:, new_col] = data[target_variable].map(mapping)
    data = data.dropna(subset=[new_col])
    print(f"Covariates data (after): {data.shape}")
    return data

def set_thresholds(df, column_thresholds):
    """
    Sets matching thresholds for covariates.

    Parameters:
        df (pd.DataFrame): Input DataFrame.
        column_thresholds (dict): Dictionary of column names and their threshold values.

    Returns:
        np.ndarray: Array of thresholds aligned with DataFrame columns.
    """
    thresholds = np.full(df.shape[1], np.nan)
    for col, val in column_thresholds.items():
        if col in df.columns:
            thresholds[df.columns.get_loc(col)] = val
    return thresholds

def match_and_simulate(df, target_variable, target_encoding, column_thresholds, n_col, output_prefix, dataset_name):
    """
    Performs matching and simulates outcomes for a given dataset.

    Parameters:
        df (pd.DataFrame): Input DataFrame.
        target_variable (str): Name of the exposure variable.
        target_encoding (dict): Mapping for exposure variable to string labels.
        column_thresholds (dict): Dictionary of covariate thresholds.
        n_col (int): Number of randomizations for simulation.
        output_prefix (str): Prefix for output file paths.
        dataset_name (str): Name of the dataset (for logging).

    Returns:
        pd.DataFrame: DataFrame of matched pairs.
    """
    print(f"\nMatching and simulating for dataset: {dataset_name}")
    df["W"] = df[target_variable]
    df["W_str"] = df["W"].map(target_encoding)
    df["is_treated"] = df["W"].astype(bool)
    df["pair_nb"] = np.nan

    test, control = df[df["W"] == 0], df[df["W"] == 1]
    print(f"Number of test - {len(test)}")
    print(f"Number of control - {len(control)}")

    thresholds = set_thresholds(df, column_thresholds)
    scaling = np.ones(df.shape[1], dtype=int)

    treated_units = df[df["is_treated"]]
    control_units = df[~df["is_treated"]]
    print(f"Number of treated units: {treated_units.shape[0]}")
    print(f"Number of control units: {control_units.shape[0]}")

    discrepancies = discrepancyMatrix(treated_units, control_units, thresholds, scaling)
    g, pairs_dict = construct_network(discrepancies, treated_units.shape[0], control_units.shape[0])
    matched_df = process_matched_pairs(pairs_dict, treated_units, control_units)

    print(f"Number of pairs: {len(matched_df.W)}")
    print(f"Number of test individuals: {len(matched_df[matched_df.W == 0])}")
    print(f"Number of control individuals: {len(matched_df[matched_df.W == 1])}\n")

    matched_df.to_csv(f'{output_prefix}_matched_df_{target_variable}.csv', index=True)
    simulated_outcomes = generate_simulated_outcomes(matched_df, n_col)
    simulated_outcomes.to_csv(f'{output_prefix}_simulated_outcomes_{target_variable}.csv', index=True)

    return matched_df

### AGP Dataset

In [11]:
# AGP parameters
K = 1000
agp_params = {
    "name": "AGP",
    "file": "data/AGP/agdata_smoke.csv",
    "index_col": 0,
    "target_variable": "smoking_frequency",
    "mapping": {"Daily": 0, "Never": 1},
    "new_col": "W",
    "output_prefix": "data/AGP",
    "column_thresholds": {"sex": 0, "age_cat": 0, "bmi_corrected": 4},
    "encoding": {0: "Yes", 1: "No"}
}

# AGP workflow
agp_data = pd.read_csv(agp_params["file"], index_col=agp_params["index_col"], low_memory=False)
agp_data = preprocess_exposure(agp_data, agp_params["target_variable"], agp_params["mapping"], agp_params["new_col"], agp_params["name"])
agp_data.to_csv(f"{agp_params['output_prefix']}_preprocessed.csv", index=True)

agp_matched_df = match_and_simulate(
    agp_data, agp_params["new_col"], agp_params["encoding"], agp_params["column_thresholds"], K, agp_params["output_prefix"], agp_params["name"]
)
agp_matched_df.to_csv("data/smoking_AGP_experiment.csv", index=True)


Running dataset: AGP
Covariates data (before): (12089, 660)
Covariates data (after): (12089, 660)

Matching and simulating for dataset: AGP
Number of test - 234
Number of control - 11855
Number of treated units: 11855
Number of control units: 234
Number of pairs: 468
Number of test individuals: 234
Number of control individuals: 234



### KORA Dataset

In [3]:
# --- PARAMETERS ---
K = 1000
kora_params = {
    "name": "KORA",
    "file": "data/kora_full_preprocessed_masked.csv",
    "index_col": "u3_16s_id",
    "target_variable": "smoking_(cat)",
    "mapping": {1: 0, 3: 1},
    "new_col": "smoking_bin",
    "output_prefix": "data/KORA",
    "column_thresholds": {"sex": 0, "age_exm": 0, "bmi": 4},
    "encoding": {0: "Yes", 1: "No"}
}

# --- LOAD DATA ---
kora_data = pd.read_csv(kora_params["file"], index_col=kora_params["index_col"], low_memory=False)
kora_data = preprocess_exposure(
    kora_data,
    kora_params["target_variable"],
    kora_params["mapping"],
    kora_params["new_col"],
    kora_params["name"]
)
asv = pd.read_csv("data/feature_table.tsv", index_col=0, sep='\t')

# --- OVERLAP SAMPLES BEFORE MATCHING ---
metadata_ids = kora_data.index.astype(str)
asv_ids = asv.columns.astype(str)
common_ids = sorted(set(metadata_ids) & set(asv_ids))

kora_data_filtered = kora_data.loc[kora_data.index.astype(str).isin(common_ids)].copy()
asv_filtered = asv.loc[:, asv.columns.astype(str).isin(common_ids)].copy()

# --- PAIR-MATCHING ON FILTERED METADATA ---
kora_matched_df = match_and_simulate(
    kora_data_filtered,
    kora_params["new_col"],
    kora_params["encoding"],
    kora_params["column_thresholds"],
    K,
    kora_params["output_prefix"],
    kora_params["name"]
)

# --- FILTER ASV TABLE TO MATCHED SAMPLES ---
matched_ids = sorted(kora_matched_df.index.astype(str))
asv_matched = asv_filtered.loc[:, matched_ids]

# --- OPTIONAL: FILTER AND PROCESS ASV TABLE ---
asv_top99_samples, asv_samples_ids = filter_and_process_asv_table(asv_matched, freq_threshold=0.01)

# --- FINAL ALIGNMENT AND BALANCING ---
final_ids = sorted(set(kora_matched_df.index.astype(str)) & set(asv_top99_samples.columns.astype(str)))
kora_matched_df_final = kora_matched_df.loc[kora_matched_df.index.astype(str).isin(final_ids)]
asv_top99_samples_final = asv_top99_samples.loc[:, final_ids]

# Re-balance groups to the minimum group size
group_sizes = kora_matched_df_final[kora_params["new_col"]].value_counts()
min_group_size = group_sizes.min()

group_0 = kora_matched_df_final[kora_matched_df_final[kora_params["new_col"]] == 0].sample(n=min_group_size, random_state=42)
group_1 = kora_matched_df_final[kora_matched_df_final[kora_params["new_col"]] == 1].sample(n=min_group_size, random_state=42)
kora_matched_df_final = pd.concat([group_0, group_1]).sort_index()

# Final alignment
final_ids = sorted(kora_matched_df_final.index.astype(str))
asv_top99_samples_final = asv_top99_samples_final.loc[:, final_ids]

# Final checks
assert len(final_ids) == min_group_size * 2, "Group sizes are not equal!"
assert asv_top99_samples_final.shape[1] == kora_matched_df_final.shape[0], "Sample counts do not match!"

print("Final matched metadata shape:", kora_matched_df_final.shape)
print("Final ASV table shape:", asv_top99_samples_final.shape)
print("Final group sizes:")
print(kora_matched_df_final[kora_params["new_col"]].value_counts())

# Draw random outcome for smoking status
simulated_outcomes = generate_simulated_outcomes(kora_matched_df_final, K)

asv_top99_samples_final.to_csv("data/filtered_count_table.csv", index=True)
kora_matched_df_final.to_csv("data/smoking_KORA_experiment.csv", index=True)
simulated_outcomes.to_csv("data/simulated_KORA_outcomes.csv", index=True)


Running dataset: KORA
Covariates data (before): (1938, 75)
Covariates data (after): (1084, 76)

Matching and simulating for dataset: KORA
Number of test - 237
Number of control - 733
Number of treated units: 733
Number of control units: 237
Number of pairs: 440
Number of test individuals: 220
Number of control individuals: 220

These columns have not variance and will be dropped: Index(['33231', '50139'], dtype='object')
Final matched metadata shape: (436, 80)
Final ASV table shape: (1469, 436)
Final group sizes:
smoking_bin
0.0    218
1.0    218
Name: count, dtype: int64


### Check BMI before and after pair-matching

In [24]:
import plotly.figure_factory as ff

# Before pair-matching
group_0 = kora_matched_df_final[kora_matched_df_final['smoking_bin'] == 0]
group_1 = kora_matched_df_final[kora_matched_df_final['smoking_bin'] == 1]

# After pair-matching
smoker_df = kora_data[kora_data['smoking_bin'] == 0]
non_smoker_df = kora_data[kora_data['smoking_bin'] == 1]

def clean_bmi(df, col='bmi'):
    """Remove NaN and inf values from BMI column."""
    return df[col].replace([np.inf, -np.inf], np.nan).dropna()

def plot_bmi_dist(hist_data, group_labels, title, colors):
    """Plot probability density for BMI distributions."""
    fig = ff.create_distplot(
        hist_data, 
        group_labels, 
        bin_size=2, 
        show_rug=False,
        histnorm="probability density", 
        colors=colors
    )
    fig.update_layout(
        title_text=title,
        xaxis_title_text="BMI (kg/m²)",
        yaxis_title_text='Probability density',
        bargap=0.1,
        bargroupgap=0.01
    )

    return fig

colors = ['slategray', 'magenta']
group_labels = ['Smoker', 'Non-smoker']

# Before pair-matching
bmi_before = [
    clean_bmi(group_0),
    clean_bmi(group_1)
]

# After pair-matching
bmi_after = [
    clean_bmi(smoker_df),
    clean_bmi(non_smoker_df)
]

bin_size = 2  # keep one source of truth

def trace_ymax(fig):
    import numpy as np
    ymax = 0.0
    for tr in fig.data:
        y = getattr(tr, "y", None)
        if y is None:
            continue
        arr = np.asarray(y, dtype=float)
        if arr.size == 0:
            continue
        arr = arr[np.isfinite(arr)]
        if arr.size:
            ymax = max(ymax, float(arr.max()))
    return ymax

# Global x-range from all BMI values
all_bmi = pd.concat([
    clean_bmi(group_0), clean_bmi(group_1),
    clean_bmi(smoker_df), clean_bmi(non_smoker_df)
])
x_min, x_max = all_bmi.min(), all_bmi.max()
x_min = np.floor(x_min / bin_size) * bin_size
x_max = np.ceil(x_max / bin_size) * bin_size

# Rebuild using the same bin size everywhere
fig_before = plot_bmi_dist(
    [clean_bmi(group_0), clean_bmi(group_1)],
    ['Smoker', 'Non-smoker'],
    'KORA: BMI probability density before pair-matching',
    ['slategray', 'magenta']
)
fig_after = plot_bmi_dist(
    [clean_bmi(smoker_df), clean_bmi(non_smoker_df)],
    ['Smoker', 'Non-smoker'],
    'KORA: BMI probability density after pair-matching',
    ['slategray', 'magenta']
)

# Align axes
y_max = max(trace_ymax(fig_before), trace_ymax(fig_after))
for fig in (fig_before, fig_after):
    fig.update_xaxes(range=[x_min, x_max])
    fig.update_yaxes(range=[0, y_max])

fig_before.show()
fig_after.show()


fig_before.write_image("plots/png/bmi_before.png")
fig_before.write_image("plots/svg/bmi_before.svg")
fig_before.write_html("plots/html/bmi_before.html")

fig_after.write_image("plots/png/bmi_after.png")
fig_after.write_image("plots/svg/bmi_after.svg")
fig_after.write_html("plots/html/bmi_after.html")