# Description of feature table

## Setup

In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from src.utils import extract_color
from src.display_meta import display_diet_information
from src.viz_alpha_div import read_and_prep_abx_exposure_data

# Reload edited modules (e.g. from src/) automatically before each cell runs;
# the extension has to be loaded before %autoreload can be configured.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Global plot defaults; individual figure cells further below update
# rcParams again (font sizes) for their specific layouts.
plt.rcParams.update({"font.family": "DejaVu Sans", "font.size": 14})
plt.style.use("tableau-colorblind10")

In [None]:
# USER DEFINED variables
# version tag embedded in the input file names and the output directory name
tag = "20240806"
# feature-table variant suffix, embedded in the input and output names
tag_output = "entero_family"
# directory holding the processed input tables
path_to_data = "../data/final/"

# threshold for coloring microbiome samples after abx exposure: a sample is
# flagged when its `abx_any_last_t_dmonths` value is <= this threshold
# (presumably months since the last abx exposure - confirm in the feature table)
th_sample_after_abx_months = 1
# END USER DEFINED variables

In [None]:
# path to the processed feature table (read further below)
path_to_ft = os.path.join(path_to_data, f"ft_vat19_anomaly_v{tag}_{tag_output}.tsv")
# path to the time-series of exact abx exposure data (read further below)
path_to_abx = os.path.join(path_to_data, f"ts_vat19_abx_v{tag}.tsv")

# location to save all outputs of this notebook; exist_ok=True keeps the cell
# idempotent and avoids the check-then-create race of isdir() + makedirs()
path_to_output = os.path.join("../results", f"desc_{tag}_{tag_output}")
os.makedirs(path_to_output, exist_ok=True)

## Read datasets


In [None]:
# read and prep abx exposure data via the project helper from src.viz_alpha_div
# NOTE(review): cells below rely on this frame having "host_id" and
# "abx_start_age_months" columns - confirm in src/viz_alpha_div.py
abx_df = read_and_prep_abx_exposure_data(path_to_abx)

In [None]:
# read and prep metadata (the processed feature table carries the per-sample metadata)
md_df = pd.read_csv(path_to_ft, sep="\t", index_col=0)

# columns for below plots:
# - sample_lt_xm_after_abx: True if abx_any_last_t_dmonths <= threshold
#   (NaN, e.g. no prior abx, compares as False)
# - max_abx_w_microbiome: per host, the maximum abx_any_cumcount over all of the
#   host's samples (0 -> no abx counted at any sampled time point)
# Rows are then ordered by exposure level, host and age. The sort returns a new
# frame instead of mutating in place, so the cell is idempotent.
md_df = md_df.assign(
    sample_lt_xm_after_abx=lambda df: df["abx_any_last_t_dmonths"]
    <= th_sample_after_abx_months,
    max_abx_w_microbiome=lambda df: df.groupby("host_id")["abx_any_cumcount"].transform(
        "max"
    ),
).sort_values(
    [
        "abx_max_count_ever",
        "max_abx_w_microbiome",
        "host_id",
        "age_months_rounded05",
    ],
    ascending=True,
)

In [None]:
# sort abx_df in same host order as md_df and remove hosts that don't exist in md_df:
# a left merge keeps the left frame's key order; hosts without any abx event get a
# single row with NaN event columns.
# abx_df is intentionally not deleted so this cell can be re-run on its own.
abx_events = pd.DataFrame({"host_id": md_df["host_id"].unique()}).merge(
    abx_df, on="host_id", how="left"
)
assert abx_events.host_id.unique().tolist() == md_df.host_id.unique().tolist()

In [None]:
def filter_both_dfs(md_df, abx_events, condition):
    """Restrict samples and abx events to the hosts that satisfy ``condition``.

    ``condition`` is evaluated once per host through
    ``md_df.groupby("host_id").filter``: it receives all rows of one host and
    must return a bool.

    Returns ``(md_df_filtered, abx_events_filtered)``, both independent copies:
    the retained ``md_df`` rows (original order) and the ``abx_events`` rows of
    the same hosts.
    """
    kept_samples = md_df.groupby("host_id").filter(condition).copy()
    kept_hosts = kept_samples.host_id.unique()
    event_mask = abx_events["host_id"].isin(kept_hosts)
    kept_events = abx_events[event_mask].copy()
    return kept_samples, kept_events


# separate abx and no abx hosts: max_abx_w_microbiome is constant within a host,
# so these row masks split whole hosts; each condition is computed only once
is_noabx = md_df["max_abx_w_microbiome"] == 0
is_abx = md_df["max_abx_w_microbiome"] > 0

md_df_noabx = md_df[is_noabx].copy()
md_df_abx = md_df[is_abx].copy()

abx_events_noabx = abx_events[
    abx_events["host_id"].isin(md_df_noabx.host_id.unique())
].copy()
abx_events_abx = abx_events[
    abx_events["host_id"].isin(md_df_abx.host_id.unique())
].copy()

# every sample must fall into exactly one group (fails e.g. on NaN or negative values)
assert md_df_noabx.shape[0] + md_df_abx.shape[0] == md_df.shape[0]


def invisible_condition(host_samples):
    """True if none of the host's samples is flagged as <= threshold after abx."""
    return not host_samples["sample_lt_xm_after_abx"].any()


def visible_condition(host_samples):
    """True if at least one of the host's samples is flagged as <= threshold after abx."""
    return bool(host_samples["sample_lt_xm_after_abx"].any())


# distinguish between invisible and visible abx with this threshold
md_df_abx_invisible, abx_events_abx_invisible = filter_both_dfs(
    md_df_abx, abx_events, invisible_condition
)
md_df_abx_visible, abx_events_abx_visible = filter_both_dfs(
    md_df_abx, abx_events, visible_condition
)

# the two conditions are complementary, so every abx sample lands in exactly one group
assert md_df_abx_invisible.shape[0] + md_df_abx_visible.shape[0] == md_df_abx.shape[0]

## Visualize overview of samples and abx events

In [None]:
# fraction of "visible" abx samples from all available microbiome samples
all_samples = md_df.shape[0]

# the flag column is boolean, so its sum is the number of flagged samples
frac_abx = 100 * md_df["sample_lt_xm_after_abx"].sum() / all_samples

print(f"Fraction of potentially abx influenced samples: {frac_abx:.2f} %")

In [None]:
hide_ylabel_thickmarks = True  # hide the y-axis tick labels (host IDs) for slides
fonts = 14

# one common fontsize for every text element of the figure
plt.rcParams.update(
    {
        rc_key: fonts
        for rc_key in (
            "font.size",
            "axes.titlesize",
            "axes.labelsize",
            "xtick.labelsize",
            "ytick.labelsize",
            "legend.fontsize",
            "legend.title_fontsize",
        )
    }
)

fig, axs = plt.subplots(1, 2, figsize=(14, 5), sharex=True, dpi=400)
markersize = 10
dic_to_plot = {
    "w/o": [md_df_noabx, abx_events_noabx],
    "with": [md_df_abx, abx_events_abx],
}

# left panel: hosts without abx, right panel: hosts with abx
for ax, (title, (samples_df, events_df)) in zip(axs, dic_to_plot.items()):
    is_left_panel = ax is axs[0]

    # microbiome samples, colored by the "<= threshold after abx" flag
    sns.scatterplot(
        data=samples_df,
        x="age_months_rounded05",
        y="host_id",
        hue="sample_lt_xm_after_abx",
        palette={True: "#f4a461", False: "#014487"},
        s=markersize,
        ax=ax,
    )
    # abx events of the same hosts
    sns.scatterplot(
        data=events_df,
        x="abx_start_age_months",
        y="host_id",
        marker="x",
        color="darkred",
        s=markersize * 1.5,
        label="abx event",
        ax=ax,
    )

    ax.set_title(f"Hosts {title} abx exposure ({samples_df.host_id.nunique()})")
    ax.set_xlabel("Age [months]")
    if is_left_panel:
        ax.set_ylabel("Host ID$_{\\mathrm{no\\ abx}}$")
    else:
        ax.set_ylabel("Host ID$_{\\mathrm{abx}}$")
    ax.margins(y=0.005)
    ax.tick_params(axis="both", labelsize=fonts)

    if is_left_panel:
        # only the right panel keeps a legend; it is rebuilt below
        ax.get_legend().remove()
    if hide_ylabel_thickmarks:
        ax.tick_params(left=False, labelleft=False)

axs[1].legend(
    loc="upper right",
    bbox_to_anchor=(1.6, 1),
    title=f" sample <={th_sample_after_abx_months}m after abx",
    prop={"size": fonts},
    title_fontsize=fonts,
)
plt.suptitle("Distribution of microbial samples over time", fontsize=fonts, y=1.0)
plt.tight_layout()
filename = os.path.join(
    path_to_output, f"overall_distribution_samples_t{hide_ylabel_thickmarks}.pdf"
)
plt.savefig(filename, dpi=400, bbox_inches="tight")
plt.show()

A publication based only on cross-sectional data found that the abx effect in infants was detectable only up to 30 days after exposure ([source](https://www.nature.com/articles/s41467-023-44289-6)). A threshold of <= 1 month therefore seems most promising (keeping in mind that our study is longitudinal and might reveal more detailed dynamics).

Unique host counts for different thresholds:

| <= `x` m after abx | # no abx | # invisible abx | # visible abx | # detectable samples [1] |
|--------------------|----------|-----------------|---------------|--------------------------|
| 1                  | 140      | 56              | 85            | 169                      |
| 2                  | "        | 34              | 107           |                          |
| 3                  | "        | 26              | 115           |                          |
| 4                  | "        | 21              | 120           |                          |
| 5                  | "        | 19              | 122           |                          |
| 6                  | "        | 17              | 124           |                          |
| 9                  | "        | 6               | 135           |                          |
| 12                 | "        | 2               | 139           |                          |

[1] Detectable samples are samples collected at most `x` months after abx exposure (i.e. at or below the threshold), but not in the same month in which abx was given:
````
df[np.logical_and(
    df["abx_any_last_t_dmonths"]<= th_sample_after_abx_months,
    df["abx_any_last_t_dmonths"]>0
)].shape[0]
````

In [None]:
# number of samples collected <= th months after abx exposure but not in the same
# month, per threshold (the "# detectable samples" column of the table above).
# The earlier commented-out draft used `np` (not imported in this notebook) and a
# leaked loop variable `df`; this version is explicit about its input (md_df).
# NOTE(review): noabx samples are assumed to have NaN abx_any_last_t_dmonths, so
# md_df and md_df_abx give the same counts - confirm.
last_abx_dmonths = md_df["abx_any_last_t_dmonths"]
detectable_samples = pd.Series(
    {
        th: int(((last_abx_dmonths <= th) & (last_abx_dmonths > 0)).sum())
        for th in [1, 2, 3, 4, 5, 6, 9, 12]
    },
    name="n_detectable_samples",
)
detectable_samples.index.name = "th_months"
detectable_samples

## Visualize distribution of available samples

In [None]:
def plot_box_violin(y, color, ax, horizontal=False):
    """Draw a violin plot of ``y`` with a narrow white box plot overlaid.

    Parameters
    ----------
    y : array-like or pandas Series
        Values to summarize.
    color : matplotlib color
        Fill color of the violin; the box is always white with black edges.
    ax : matplotlib Axes
        Axes to draw on.
    horizontal : bool, default False
        If True the values run along the x-axis, otherwise along the y-axis.

    Returns
    -------
    matplotlib Axes
        ``ax``, so callers can chain further customization.
    """
    orient = "h" if horizontal else "v"
    # the data goes on the value axis only; the other axis stays unset (None)
    value_axis = {"x": y} if horizontal else {"y": y}

    sns.violinplot(**value_axis, inner=None, ax=ax, color=color, orient=orient)
    sns.boxplot(
        **value_axis,
        width=0.1,
        # zorder=2 keeps the box drawn above the violin body
        boxprops={"facecolor": "white", "edgecolor": "black", "zorder": 2},
        flierprops={
            "marker": "o",
            "markerfacecolor": "none",
            "markeredgecolor": "black",
        },
        ax=ax,
        orient=orient,
    )
    return ax

In [None]:
plt.rcParams.update({"font.family": "DejaVu Sans", "font.size": 8})
fig, axs = plt.subplots(1, 2, figsize=(6, 3), sharey=True, dpi=300)

dic_to_plot = {
    "w/o abx": md_df_noabx,
    "w abx": md_df_abx,
    # 'with "invisible" abx': md_df_abx_invisible,
    # "with visible abx": md_df_abx_visible,
}
# same violin color for every panel (loop-invariant, so computed once)
c = extract_color("tableau-colorblind10", 0)

for ax, (title, df) in zip(axs, dic_to_plot.items()):
    # number of samples (non-null age_days) per host
    samples_per_host = df.groupby("host_id")["age_days"].count()
    plot_box_violin(samples_per_host, c, ax)

    ax.set_title(title)
    ax.set_ylabel("")
    ax.set_ylim(bottom=0)

axs[0].set_ylabel("# samples per host", fontsize=10)
plt.suptitle("Number of samples per host", fontsize=12, y=1.0)
plt.tight_layout()
filename = os.path.join(path_to_output, "nb_samples_per_host.png")
plt.savefig(filename, dpi=400, bbox_inches="tight")
plt.show()

In [None]:
fonts = 14  # ensure consistent fontsize

plt.rcParams.update({"font.family": "DejaVu Sans", "font.size": fonts})
fig, axs = plt.subplots(3, 1, figsize=(5, 5), dpi=400)

# samples per host across all samples in md_df
samples_per_host_all = md_df.groupby("host_id")["age_days"].count()

# days between consecutive samples of a host (NaN for each host's first sample);
# diff on the age_days Series returns a Series, not a one-column DataFrame
cols = ["host_id", "age_days"]
df_md_diff = md_df[cols].sort_values(cols)
df_md_diff["diff_age"] = df_md_diff.groupby("host_id")["age_days"].diff()

# age at first sample per host
first_sample_all = md_df[["age_days", "host_id"]].groupby("host_id").min()

# top to bottom: one horizontal violin + box per (values, color, title)
panels = [
    (samples_per_host_all, "#ababab", "# samples per host"),
    (df_md_diff["diff_age"], "#844f7f", "Duration between samples [days]"),
    (first_sample_all["age_days"], "#6c9cc2", "Age at first sample [days]"),
]
for ax, (values, color, title) in zip(axs, panels):
    plot_box_violin(values, color, ax, horizontal=True)
    ax.set_title(title, fontsize=fonts)
    ax.set_xlabel("", fontsize=fonts)
    ax.set_ylabel("", fontsize=fonts)
    ax.set_xlim(left=0)
    ax.tick_params(axis="both", labelsize=fonts)

plt.tight_layout()

# save
filename = os.path.join(path_to_output, "nb_samples_per_host_all_combined.pdf")
plt.savefig(filename, dpi=400, bbox_inches="tight")
plt.show()

# summary statistics of intervals
df_md_diff["diff_age"].describe()

## Visualize 1st, 2nd and 3rd abx exposure

In [None]:
# For the 1st, 2nd and 3rd abx exposure: how many abx hosts have a microbiome
# sample <= th_sample_after_abx_months after that exposure, and how old are they
# at that sample?
all_abx = md_df_abx.host_id.nunique()
ordinal_suffix = {1: "1st", 2: "2nd", 3: "3rd"}
# violin color (same value as in the samples-per-host figure)
c = extract_color("tableau-colorblind10", 0)

for nth, suff in ordinal_suffix.items():
    # samples whose abx_any_cumcount equals nth (presumably: after the nth and
    # before the next exposure) and that are flagged <= threshold after abx.
    # Both masks come from md_df_abx so they share one index (previously the flag
    # was taken from md_df_abx_visible, relying on index alignment across frames).
    nth_abx_exp_sample = (md_df_abx["abx_any_cumcount"] == nth) & md_df_abx[
        "sample_lt_xm_after_abx"
    ]
    num_hosts = md_df_abx.loc[nth_abx_exp_sample, "host_id"].nunique()

    print(
        f"Of the \033[1m{all_abx}\033[0m hosts with abx exposure,"
        f" \033[1m{round(100*num_hosts/all_abx,1)} % ({num_hosts}\033[0m)"
        f" have a sample <= {th_sample_after_abx_months} month(s) after their"
        f" {suff} abx exposure."
    )

    # first matching sample per host (md_df_abx rows are sorted by host and age);
    # its age serves as a proxy for the age at the nth exposure
    first_abx = (
        md_df_abx.loc[nth_abx_exp_sample, ["host_id", "age_months_rounded05"]]
        .groupby("host_id")
        .first()
    )
    print(f"Mean age: {first_abx['age_months_rounded05'].mean():.2f} months")

    fig, ax = plt.subplots(figsize=(6, 1), dpi=400)
    plot_box_violin(first_abx["age_months_rounded05"], c, ax, horizontal=True)
    ax.set_xlabel("Age [months]", fontsize=6)
    ax.tick_params(axis="x", labelsize=6)
    ax.set_xlim(-0.5, 38.5)
    ax.set_title(f"Age at {suff} abx exposure", fontsize=7)
    plt.show()

## Display diet covariates

In [None]:
# weaning diet per rounded age month via the project helper, then save the
# current figure (presumably the one drawn by the helper - confirm in src.display_meta)
weaning_plot_path = os.path.join(path_to_output, "diet_weaning_over_time.pdf")
ax_weaning = display_diet_information(
    md_df, "diet_weaning", "age_months_rounded1", "samples"
)
plt.savefig(weaning_plot_path, dpi=400, bbox_inches="tight")

In [None]:
# milk diet per rounded age month via the project helper, then save the
# current figure (presumably the one drawn by the helper - confirm in src.display_meta)
milk_plot_path = os.path.join(path_to_output, "diet_milk_over_time.pdf")
ax_milk = display_diet_information(md_df, "diet_milk", "age_months_rounded1", "samples")
plt.savefig(milk_plot_path, dpi=400, bbox_inches="tight")

## Check delivery mode and geolocation counts


In [None]:
# number of distinct hosts per delivery mode
md_df[["host_id", "delivery_mode"]].groupby("delivery_mode").nunique()

In [None]:
# number of distinct hosts per geolocation
md_df[["host_id", "geo_location_name"]].groupby("geo_location_name").nunique()

## Check duration between samples (all hosts)

In [None]:
# number of days between consecutive samples of an infant (NaN for the first sample);
# diff on the age_days Series returns a Series, not a one-column DataFrame
cols = ["host_id", "age_days"]
df_md_diff = md_df[cols].sort_values(cols)
df_md_diff["diff_age"] = df_md_diff.groupby("host_id")["age_days"].diff()

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 5))
c = extract_color("tableau-colorblind10", 0)
plot_box_violin(df_md_diff["diff_age"], c, ax)

ax.set_ylabel("Days between samples per host", fontsize=12)
ax.set_ylim(0)
# reference line at 30 days, labeled next to the line
c_thirty = extract_color("tableau-colorblind10", 1)
ax.axhline(y=30, color=c_thirty)
ax.text(x=-0.48, y=33, s="30", color=c_thirty)
plt.show()

df_md_diff["diff_age"].describe()

## Derive time of initial sample per "noabx" host


(originally used to define imputation for initial samples of noabx hosts)

### Time of initial sample per "noabx" host

In [None]:
# earliest sample age (days) for every noabx host
first_sample = md_df_noabx.groupby("host_id")[["age_days"]].min()
first_sample.describe()

In [None]:
# distribution of the first-sample age over noabx hosts; the vertical line marks
# the candidate t0 of 42 days examined in the following section
c_thirty = extract_color("tableau-colorblind10", 1)
fig, ax = plt.subplots()
first_sample["age_days"].hist(bins=100, ax=ax)
ax.axvline(x=42, color=c_thirty)
ax.set_ylabel("Number of hosts")
ax.set_xlabel("Age [days]")
ax.set_title("Age of first sample per noabx host")
plt.show()

In [None]:
# same first-sample ages as a violin + box plot
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(3, 5))
c = extract_color("tableau-colorblind10", 0)
plot_box_violin(first_sample["age_days"], c, ax)
ax.set_ylabel("First sample per host [age in days]", fontsize=12)
ax.set_ylim(0)
# show explicitly so the (bottom, top) tuple from set_ylim is not echoed as output
plt.show()

### Noabx hosts that have samples prior to or at 42 days

(originally used to suggest new t0)

In [None]:
# noabx samples taken at or before day 42 (35 + 7)
is_early_noabx = md_df_noabx["age_days"] <= (35 + 7)
early_noabx_samples = md_df_noabx.loc[is_early_noabx, ["host_id", "age_days"]]
early_noabx_samples

In [None]:
# how many unique noabx hosts do we have <= 42 days?
early_mask = md_df_noabx["age_days"].le(35 + 7)
init_hosts = md_df_noabx["host_id"][early_mask].unique()
len(init_hosts)

In [None]:
# delivery mode of noabx hosts with samples at t0=42 days
belongs_to_init_host = md_df_noabx["host_id"].isin(init_hosts)
init_host_delmode = md_df_noabx[belongs_to_init_host][
    ["host_id", "delivery_mode"]
].drop_duplicates()
init_host_delmode["delivery_mode"].value_counts()

### Abx host samples missed with t0=42 days

In [None]:
# how many unique abx hosts do we have < 42 days?
# -> we would miss those by setting t0=42 days
t0_days = 28 + 2 * 7
abx_init_hosts = md_df_abx["host_id"][md_df_abx["age_days"].lt(t0_days)].unique()
len(abx_init_hosts)

### Abx hosts with antibiotics exposure at birth

In [None]:
# samples whose age equals abx_any_last_t_dmonths, i.e. (presumably) whose last
# abx exposure happened at age 0 = birth.
# NOTE(review): exact float equality - fine if both columns are on the same
# 0.5-month grid; confirm.
age_at_last_abx = (
    md_df_abx["age_months_rounded05"] - md_df_abx["abx_any_last_t_dmonths"]
)
cond_birth_abx = age_at_last_abx.eq(0.0)

hosts_wabx_at_birth = md_df_abx["host_id"][cond_birth_abx].unique()
print(f"# hosts with ABX exposure at birth: {len(hosts_wabx_at_birth)}")

In [None]:
# delivery modes of the samples flagged as abx-at-birth
md_df_abx["delivery_mode"][cond_birth_abx].unique()

All of these hosts were vaginally born.

In [None]:
# do we have early microbial samples for these hosts at birth?
# -> only 1 host has sample at t0=42 days

birth_abx_cols = ["host_id", "age_days", "abx_max_count_ever", "delivery_mode"]
birth_abx_samples = md_df_abx[md_df_abx["host_id"].isin(hosts_wabx_at_birth)]
birth_abx_samples[birth_abx_cols].groupby("host_id").min()