```
This file is part of Estimation of Causal Effects in the Alzheimer's Continuum (Causal-AD).

Causal-AD is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

Causal-AD is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Causal-AD. If not, see <https://www.gnu.org/licenses/>.
```

# Generate Synthetic Outcome with Confounding

In [None]:
# define the parameters (all of them are consumed by cells further below)

# Input file (.h5) read by `io.load_patient_data`.
data_path: str = "ukb_data_t.h5"
# Forwarded as `sparsity` to ConfoundingGenerator; presumably the share of zero
# outcome coefficients (compare with the non-zero fraction printed below) — TODO confirm.
sparsity: float = 0.8
# Forwarded as `prob_event` to ConfoundingGenerator; presumably the target share
# of positive outcomes — TODO confirm in ConfoundingGenerator.
prob_event: float = 0.5
# Forwarded as `var_x` / `var_z` to ConfoundingGenerator; their exact meaning is
# defined by that class — NOTE(review): verify before changing.
var_x: float = 0.4
var_z: float = 0.4
# Seed forwarded as `random_state` to ConfoundingGenerator.
random_state: int = 1802080521
# Destination passed to `io.write_synthetic_data` in the last cell.
output_file: str = "data_generated.h5"

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from causalad.ukb.generate import ConfoundingGenerator
from causalad.ukb import io

sns.set(style="ticks")

%matplotlib inline

## Generate Outcome

In [None]:
# Load the patient data; the returned object exposes (at least) a
# `.demographics` frame, which is used when building `df` below.
data = io.load_patient_data(data_path)

In [None]:
# Build the synthetic-outcome generator on top of the loaded patient data.
# Every knob comes straight from the parameters cell; the seed is listed first
# so it is easy to spot when comparing runs.
gen = ConfoundingGenerator(
    data,
    random_state=random_state,
    sparsity=sparsity,
    prob_event=prob_event,
    var_x=var_x, var_z=var_z,
)

In [None]:
# Draw the synthetic outcome. The result is used below through its `.coef`,
# `.outcome` and `.confounders` attributes and is finally written to `output_file`.
generated = gen.generate_outcome_with_site()

## Descriptive Statistics

Fraction of non-zero coefficients.

In [None]:
# Count of non-zero coefficients divided by the number of (leading-axis) entries.
(generated.coef != 0).sum() / len(generated.coef)

In [None]:
# Side-by-side join of outcome, confounders and demographics. AGE is dropped from
# the demographics frame to avoid a duplicate column; the AGE used for the
# quartile bins therefore comes from one of the other two frames.
df = pd.concat(
    [generated.outcome, generated.confounders, data.demographics.drop(columns="AGE")],
    axis=1,
).assign(AGE_cut=lambda frame: pd.qcut(frame["AGE"], [0, 0.25, 0.5, 0.75, 1]))

Relative size of each unobserved confounder cluster.

In [None]:
# Fraction of rows falling into each unobserved confounder cluster.
df["unobserved_confounder"].value_counts(normalize=True)

Joint distribution of the unobserved confounder cluster and the observed confounder (age quartile), as a proportion of all samples.

In [None]:
# Joint distribution of unobserved confounder cluster (rows) and age quartile
# (columns). normalize="all" divides by the number of rows actually tabulated,
# so the whole table sums to 1 (previously the divisor was taken from
# `generated.confounders`, which only matches when no rows are missing).
summary = pd.crosstab(df.unobserved_confounder, df.AGE_cut, normalize="all")

ax = summary.plot.bar(legend=False)
ax.legend(
    loc="center left",
    bbox_to_anchor=(1.02, 0.5),
    # join the level names instead of showing the FrozenList repr "['AGE_cut']"
    title=", ".join(map(str, summary.columns.names)),
)
ax.set_ylabel("proportion of all samples")
ax.yaxis.grid(True)

summary.round(3)

Joint distribution of outcome and sex within each unobserved confounder cluster (proportions of the cluster's samples).

In [None]:
def compute_summary_statistics(x):
    """Relative frequency of each value combination within one group.

    Drops the ``unobserved_confounder`` column from ``x``, counts every distinct
    combination of the remaining columns and divides by the number of rows of
    ``x``.

    Returns:
        Series named ``"percentage"`` indexed by the value combinations. Despite
        the name, the values are fractions in [0, 1], not percentages.
    """
    xx = x.drop("unobserved_confounder", axis=1)
    perc = xx.value_counts().rename("percentage") / x.shape[0]
    return perc


def compute_summary_statistics_per_group(data, columns=("outcome", "SEX")):
    """Per-cluster joint distribution of ``columns``.

    Args:
        data: DataFrame containing ``unobserved_confounder`` and every column
            named in ``columns``; any other columns are ignored.
        columns: column name(s) whose value combinations are counted within each
            unobserved confounder cluster. Defaults to outcome and sex.

    Returns:
        DataFrame indexed by ``unobserved_confounder`` with one column per
        combination of ``columns`` values; each cell is the fraction of the
        cluster's rows with that combination (NaN when it never occurs there).
    """
    if isinstance(columns, str):
        columns = [columns]
    columns = list(columns)
    # Keep only the grouping key and the requested columns: any extra column
    # would otherwise split the value counts into finer combinations, which
    # pivot_table then silently averages, yielding wrong proportions.
    subset = data.loc[:, ["unobserved_confounder", *columns]]

    counts = []
    for grp_name, grp_df in subset.groupby("unobserved_confounder"):
        grp_counts = compute_summary_statistics(grp_df)
        # Prepend the cluster id as the outermost index level.
        idx = grp_counts.index.to_frame(index=False)
        idx.insert(0, "unobserved_confounder", grp_name)
        grp_counts.index = pd.MultiIndex.from_frame(idx)
        counts.append(grp_counts)

    counts = pd.concat(counts, axis=0)
    stats = counts.reset_index().pivot_table(
        index="unobserved_confounder", columns=columns, values="percentage"
    )
    return stats

In [None]:
# Per-cluster joint distribution of outcome and sex (each row sums to <= 1).
summary = compute_summary_statistics_per_group(df)

_, ax = plt.subplots(figsize=(6, 4))
summary.plot.bar(ax=ax, legend=False)
ax.legend(
    loc="center left",
    bbox_to_anchor=(1.02, 0.5),
    # "outcome, SEX" rather than the FrozenList repr "['outcome', 'SEX']"
    title=", ".join(map(str, summary.columns.names)),
)
ax.set_ylabel("proportion within cluster")
ax.yaxis.grid(True)

summary.round(3)

Joint distribution of outcome and age quartile within each unobserved confounder cluster (proportions of the cluster's samples).

In [None]:
# Proportion of every (outcome, age quartile) combination within each unobserved
# confounder cluster: per-cluster counts divided by the cluster size.
# crosstab is used instead of groupby(...).apply(...): the applied function needed
# the grouping column inside each group (deprecated since pandas 2.2), and
# groupby.apply returns a wide frame rather than a long Series whenever all groups
# share the same value_counts index, which broke the subsequent pivot.
summary = pd.crosstab(
    df.unobserved_confounder, [df.outcome, df.AGE_cut]
).div(df.groupby("unobserved_confounder").size(), axis=0)

_, ax = plt.subplots(figsize=(6, 4))
summary.plot.bar(ax=ax, legend=False, width=0.85)
ax.legend(
    loc="center left",
    bbox_to_anchor=(1.02, 0.5),
    title=", ".join(map(str, summary.columns.names)),
)
ax.set_ylabel("proportion within cluster")
ax.yaxis.grid(True)

summary.round(3)

## Write Data

In [None]:
# Persist the generated outcome/confounders to `output_file` (set in the parameters cell).
io.write_synthetic_data(generated, output_file)