# Construct tables

In [None]:
import os
import csv
import functools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Notebook switches (both enabled):
#  - FILTER_READ_STEALERS: when loading per-sample rejected-read tables, keep
#    only rows whose "name" is listed in read_stealers.tsv.
#  - CLOSE_PLOTS: close each figure right after it has been saved.
FILTER_READ_STEALERS = CLOSE_PLOTS = True

In [None]:
OUTDIRBASE = "out/res_optimal"

# CSV tables and PNG images are written under the base output directory.
OUTDIR = f"{OUTDIRBASE}/tables"
IMGDIR = f"{OUTDIRBASE}/images"

for _out_subdir in (OUTDIR, IMGDIR):
    os.makedirs(_out_subdir, exist_ok=True)

In [None]:
# Soil sample labels used in this analysis ("Soil<number>").
SOILS = [f"Soil{num}" for num in (3, 5, 6, 9, 11, 12, 14, 15, 16, 17)]

In [None]:
READ_STEALERS_FPATH = "read_stealers.tsv"

# Tab-separated read-stealer table; its "sseqid" column is matched against the
# "name" column of the rejected-read tables when FILTER_READ_STEALERS is set.
df_read_stealers = pd.read_csv(READ_STEALERS_FPATH, sep="\t")
df_read_stealers

In [None]:
DMD_TOPHITS_FPATH = f"{OUTDIRBASE}/dmnd_combined_top_hits.tsv"

# Combined top-hits table (tab-separated), loaded for inspection.
df_dmd_tophits = pd.read_csv(DMD_TOPHITS_FPATH, sep="\t")
df_dmd_tophits

In [None]:
KO_INFO_FPATH = "data/ko_information.tsv"

# KO annotation table indexed by its first column (the KO id); later cells
# read its "NAME" and "SYMBOL" columns.
DF_KO_INFO = pd.read_csv(KO_INFO_FPATH, sep="\t", index_col=0)

def _mapfunc(s):
    s = s.lower()
    if "nitrate reductase" in s:
        return "nitrate reductase"
    elif "nitrite reductase" in s:
        return "nitrite reductase"
    elif "nitric oxide reductase" in s:
        return "nitric oxide reductase"
    elif "nitrous oxide reductase" in s or "nitrous-oxide reductase" in s:
        return "nitrous oxide reductase"
    elif "hydroxylamine reductase" in s:
        return "hydroxylamine reductase"
    else:
        return "other"
    
# Coarse enzyme category for every KO, derived from its NAME text.
DF_KO_INFO["category"] = DF_KO_INFO["NAME"].map(_mapfunc)

# All KO ids of interest (the table index), in file order.
KO_LIST = DF_KO_INFO.index.tolist()

DF_KO_INFO

In [None]:
# Distinct categories in sorted order, and for each one the array of KO ids
# (index values) that belong to it.
KO_CATEGORIES = sorted(DF_KO_INFO["category"].unique())
KO_CATEGORY_SETS = {}
for category_name in KO_CATEGORIES:
    in_category = DF_KO_INFO["category"] == category_name
    KO_CATEGORY_SETS[category_name] = DF_KO_INFO.loc[in_category].index.values

for category_name, category_kos in KO_CATEGORY_SETS.items():
    print(f"{category_name}:\n{category_kos}\n")


In [None]:
TAXA_INFO_FPATH = "data/taxid_to_scaffold.csv"

TAXA_DF = pd.read_csv(TAXA_INFO_FPATH)
# Taxon ids to disambiguate by; rows with a missing taxid are skipped.
# NOTE(review): if the column contains NaNs, pandas reads it as float, so the
# ids are floats (e.g. 562.0); that form ends up in output file names.
TAXA_LIST = list(TAXA_DF["taxid"].dropna().to_numpy())

TAXA_DF


In [None]:
COVERAGE_DIR = "data/coverage_arrays"

# Per-sample coverage arrays (*.npz). os.listdir returns names in arbitrary
# order, so sort them. This makes the listing, and every sample ordering
# derived from it (sample_ids, contam_dfs), reproducible across runs and
# filesystems.
coverage_filelist = sorted(
    fname for fname in os.listdir(COVERAGE_DIR) if fname.endswith(".npz")
)

print(f"coverage files ({(len(coverage_filelist))}):", coverage_filelist)

In [None]:
# Sample id = coverage file name without its "coverage_arrays_" prefix and
# ".npz" suffix. The ids keep coverage_filelist's order, so the two lists can
# be zipped together later.
_COVERAGE_PREFIX, _COVERAGE_SUFFIX = "coverage_arrays_", ".npz"
sample_ids = [
    name.removeprefix(_COVERAGE_PREFIX).removesuffix(_COVERAGE_SUFFIX)
    for name in coverage_filelist
]
print(f"sample ids ({len(sample_ids)}):", sample_ids)

### Accepted Reads

In [None]:
KO_TABLE_DIR = "data/ko_tables"

# Set for O(1) membership tests while scanning the KO tables.
ko_lookup = set(KO_LIST)

# Nested mapping sample_id -> {ko: avg_depth}. Every sample seen in any table
# gets an entry, even if none of its KOs are in KO_LIST. Files are read in
# sorted order; if a (sample, KO) pair appears twice, the later file wins.
data_rows = {}
for table_name in sorted(os.listdir(KO_TABLE_DIR)):
    if not table_name.endswith(".tsv"):
        continue
    with open(os.path.join(KO_TABLE_DIR, table_name), "r", newline="") as fh:
        csvreader = csv.reader(fh, delimiter="\t")
        next(csvreader, None)  # skip the header row (tolerates empty files)
        for row in csvreader:
            # First three columns: sample id, KO id, average depth.
            sample_id, ko, avg_depth = row[0:3]
            sample_kos = data_rows.setdefault(sample_id, {})
            if ko in ko_lookup:
                sample_kos[ko] = float(avg_depth)

# Convert nested dict -> DataFrame (rows: samples, columns: KOs)
df_full = pd.DataFrame.from_dict(data_rows, orient="index", dtype=float)

# Columns follow KO_LIST order; KOs absent from every table become NaN columns.
df_full = df_full.reindex(columns=KO_LIST)

# Rows are the samples found in the KO tables.
# NOTE(review): only KO-table samples are kept. A sample that has a coverage
# file but no KO-table rows is NOT added. Confirm this is intended.
all_sample_ids = set(df_full.index.values)
sample_id_lookup = set(sample_ids)
missing_samples = sorted(s for s in all_sample_ids if s not in sample_id_lookup)
# Sorted for a deterministic row order (a set has hash-dependent order).
df_full = df_full.reindex(sorted(all_sample_ids))
# A KO with no reported depth for a sample counts as zero depth.
df_full = df_full.fillna(0.)

print(len(df_full))
# Drop KO-table samples that have no coverage file (not in sample_ids).
df_full = df_full.drop(index=missing_samples)  # DROP PROBLEM SAMPLE

print(len(df_full))
all_sample_ids.difference_update(missing_samples)
print(f"Dropped {missing_samples}")

# Add boolean screens for CHL+/-, T0/T9, Nitrate/No_Nitrate, derived from
# substrings of the sample id ("None" presumably marks the no-CHL treatment).
screen_df = pd.DataFrame({
    "no_nitrate": df_full.index.str.contains("No_Nitrate"),
    "nitrate": ~df_full.index.str.contains("No_Nitrate"),
    "t0": df_full.index.str.contains("T0"),
    "t9": df_full.index.str.contains("T9"),
    "chl_pos": df_full.index.str.contains("CHL"),
    "chl_neg": df_full.index.str.contains("None"),
}, index=df_full.index)

# Prepend these to the KO columns; nscreens marks where the KO columns start.
nscreens = screen_df.shape[1]
print(f"{nscreens} screen columns")
df_full = pd.concat([screen_df, df_full], axis=1)
df_full

In [None]:
##############################################################################
##  Construct sample subsets satisfying conditions of interest

sample_subsets = {}
print(f"{len(all_sample_ids)} total samples")

# Each rule: (subset key, text for the count message, membership test on the
# sample id). Nitrate-free samples form their own subset. The other three
# split the nitrate samples by time point and CHL treatment.
subset_rules = (
    ("no_nitrate", "no-nitrate", lambda s: "No_Nitrate" in s),
    ("t0_samples", "T0", lambda s: "T0" in s and "No_Nitrate" not in s),
    ("chl_pos_samples", "T9 CHL+",
     lambda s: "T9" in s and "No_Nitrate" not in s and "CHL" in s),
    ("chl_neg_samples", "T9 CHL-",
     lambda s: "T9" in s and "No_Nitrate" not in s and "None" in s),
)
for subset_key, subset_label, is_member in subset_rules:
    members = sorted(s for s in all_sample_ids if is_member(s))
    sample_subsets[subset_key] = members
    print(f"{len(members)} {subset_label} samples")

# Sanity check: taken together, the subsets must cover every sample.
assert len(all_sample_ids) == len(set().union(*sample_subsets.values()))

In [None]:
##############################################################################
##  Dataframe subsets

# Row slices of df_full (screen and KO columns), one per sample subset, with
# rows in the order given by sample_subsets.
DF_SAMP_SUBSETS_ACCEPTED = {
    subset_name: df_full.loc[sample_subsets[subset_name], :]
    for subset_name in (
        "t0_samples", "chl_pos_samples", "chl_neg_samples", "no_nitrate"
    )
}

In [None]:
# Notebook display: accepted-read table for the T0 (non-No_Nitrate) subset.
DF_SAMP_SUBSETS_ACCEPTED["t0_samples"]

In [None]:
# Notebook display: accepted-read table for the T9 CHL+ subset.
DF_SAMP_SUBSETS_ACCEPTED["chl_pos_samples"]

In [None]:
# Notebook display: accepted-read table for the T9 CHL- subset.
DF_SAMP_SUBSETS_ACCEPTED["chl_neg_samples"]

In [None]:
# Notebook display: accepted-read table for the No_Nitrate subset.
DF_SAMP_SUBSETS_ACCEPTED["no_nitrate"]

## Tables for rejected reads

### Aggregate over taxa

In [None]:
KO_REJECTED_SUBSETS_DIR = f"{OUTDIRBASE}/ko_expression"

# Per-sample rejected-read tables, keyed by sample id.
contam_dfs = {}
for sample_id, fname in zip(sample_ids, coverage_filelist):
    # coverage_arrays_<id>.npz -> coverage_<id>.csv. Only the file name is
    # rewritten, so directory components can never be altered by the
    # substitution.
    csv_name = fname.replace("coverage_arrays_", "coverage_").removesuffix(".npz") + ".csv"
    df = pd.read_csv(os.path.join(KO_REJECTED_SUBSETS_DIR, csv_name))
    if FILTER_READ_STEALERS:
        # Keep only rows whose "name" is listed in df_read_stealers["sseqid"].
        # NOTE(review): this keeps (rather than removes) read-stealer rows;
        # confirm that is the intended meaning of FILTER_READ_STEALERS.
        df = df[df["name"].isin(df_read_stealers["sseqid"])]
    contam_dfs[sample_id] = df


In [None]:
# Notebook display of whatever `df` currently holds. When run directly after
# the loading cell, this is the last sample's rejected-read table (after the
# optional read-stealer filter).
df

In [None]:
# Total rejected-read avg_depth per KO for each sample, summed over all rows
# (all taxa) of that sample's table. This does not depend on the subset being
# filled, so it is computed once rather than once per subset key.
ko_values_rej = {
    sample_id: df.groupby("ko")["avg_depth"].sum()
    for sample_id, df in contam_dfs.items()
}

DF_SAMP_SUBSETS_REJECTED = {}
for key in DF_SAMP_SUBSETS_ACCEPTED:
    print(key)
    # Same rows and screen columns as the accepted subset. KO values start as
    # NaN and are filled only where the sample's table has rows for that KO.
    df_rej = DF_SAMP_SUBSETS_ACCEPTED[key].copy()
    df_rej.iloc[:, nscreens:] = np.nan
    for sample_id, series in ko_values_rej.items():
        if sample_id not in df_rej.index:
            continue
        for ko, value in series.items():
            if ko in df_rej.columns:
                df_rej.at[sample_id, ko] = value

    DF_SAMP_SUBSETS_REJECTED[key] = df_rej


In [None]:
# Notebook display: rejected-read KO depths for the T0 subset (NaN = no rows).
DF_SAMP_SUBSETS_REJECTED["t0_samples"]

In [None]:
# Notebook display: rejected-read KO depths for the T9 CHL+ subset.
DF_SAMP_SUBSETS_REJECTED["chl_pos_samples"]

In [None]:
# Notebook display: rejected-read KO depths for the T9 CHL- subset.
DF_SAMP_SUBSETS_REJECTED["chl_neg_samples"]

In [None]:
# Notebook display: rejected-read KO depths for the No_Nitrate subset.
DF_SAMP_SUBSETS_REJECTED["no_nitrate"]

## Plots and tables

In [None]:
def make_violin_plot(
        df_subsets,
        ko_set,
        keys,
        ko_labels=None,
        width=None,
        spacing=None,
        gap=None,
        margin=0,
        legend=True,
        legend_labels=None,
        colors=None,
        alpha=None,
        hatch=None,
        ax=None,
        verbosity=1,
        **kwargs
):
    """Draw grouped horizontal violin plots of log10(1 + depth) per KO.

    For every KO in ``ko_set``, one group of violins is drawn. The group has
    one violin per entry of ``keys``, and each key selects a DataFrame from
    ``df_subsets``. Groups are stacked along the y axis. The x axis shows
    ``log10(1 + value)``, with tick labels in the original units.

    Args:
        df_subsets: mapping from key to a DataFrame with one column per KO.
        ko_set: non-empty sequence of KO column names, one group each.
        keys: non-empty sequence of keys into ``df_subsets``; one violin per
            key in every group.
        ko_labels: y tick labels, one per KO group.
        width: violin width (defaults to ``0.9 / len(keys)`` when falsy).
        spacing: space between violins in a group (defaults to
            ``0.01 / len(keys)`` when falsy).
        gap: space between KO groups (defaults to ``0.5 / len(keys)`` when
            falsy).
        margin: y offset of the first group. The lower y limit stays at
            ``-gap``, so a positive margin shows as empty space at the bottom.
        legend: whether to add this call's entries to the axes legend. Entries
            of an existing legend are kept.
        legend_labels: legend text per key (default: ``keys``). Never modified.
        colors: optional list/tuple of colours, one per key.
        alpha: optional alpha applied to the violin bodies.
        hatch: optional hatch pattern applied to the violin bodies.
        ax: axes to draw on; a new figure is created when None.
        verbosity: currently unused.
        **kwargs: only ``figsize`` is read, for a newly created figure.

    Returns:
        The matplotlib axes that was drawn on.

    Raises:
        ValueError: if ``ko_set`` or ``keys`` is empty.
    """
    n = len(ko_set)
    k = len(keys)
    if n == 0 or k == 0:
        raise ValueError("make_violin_plot needs a non-empty ko_set and keys")

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=kwargs.get("figsize", (8, 10)))

    # Layout in y-axis units. Defaults scale with 1/k, so a KO group is about
    # the same total height whatever the number of keys.
    width = width if width else 0.9 * (1 / k)    # width of each violin
    d = spacing if spacing else 0.01 * (1 / k)   # spacing between violins
    gap = gap if gap else 0.5 * (1 / k)          # gap between KO groups
    ko_group_width = k * width + (k - 1) * d
    group_starts = margin + np.arange(n) * (ko_group_width + gap)

    legend_handles = []
    for i, key in enumerate(keys):
        df = df_subsets[key]
        # log10(1 + depth) keeps zero depth at 0 on the plotted scale.
        data = [np.log10(1 + df[col]) for col in ko_set]
        pos = group_starts + i * (width + d)
        li = ax.violinplot(
            data, pos,
            orientation="horizontal",
            showmedians=True,
            showextrema=True,
            widths=width,
        )
        if isinstance(colors, (list, tuple)):
            c = colors[i]
            for body in li["bodies"]:
                body.set_facecolor(c)
                body.set_edgecolor(c)
            for partname in ("cbars", "cmins", "cmaxes", "cmeans", "cmedians"):
                if partname in li:
                    li[partname].set_edgecolor(c)

        for body in li["bodies"]:
            if hatch:
                body.set_hatch(hatch)
            if alpha is not None:
                body.set_alpha(alpha)

        # One body per key represents that key in the legend.
        legend_handles.append(li["bodies"][0])

    if legend:
        labels = list(keys if legend_labels is None else legend_labels)
        existing = ax.get_legend()
        if existing is None:
            old_handles, old_labels = [], []
        else:
            old_handles = list(existing.legend_handles)
            old_labels = [t.get_text() for t in existing.texts]
        # Reverse so the legend reads top-to-bottom in the same order as the
        # stacked violins (the last key is drawn highest in each group).
        # Reversed copies are used, so the caller's ``keys`` /
        # ``legend_labels`` lists are never mutated.
        ax.legend(
            old_handles + legend_handles[::-1],
            old_labels + labels[::-1],
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )

    # x ticks at depths 1, 10, 100, ..., placed at their log10(1 + depth)
    # positions, up to the current upper x limit.
    xticklabels = np.array(
        [10**i for i in range(0, int(np.ceil(ax.get_xlim()[1])))]
    )
    ax.set_xticks(np.log10(1 + xticklabels), labels=xticklabels)

    # One y tick per KO group, centred between its first and last violin.
    yticks = group_starts + (ko_group_width - width) / 2
    ax.set_yticks(yticks, labels=ko_labels)

    ax.set_xlabel("total avg depth")
    ax.set_ylabel("KO")
    ax.set_title("")
    ax.set_ylim(-gap, pos.max() + gap)

    return ax



#### Accepted reads

In [None]:
# Accepted-read violins: green / blue / red, one per subset key, no hatching.
colors_acc = ["g", "b", "r"]
alpha_acc = None
hatch_acc = None

# Rejected-read violins: brown / cyan / orange, with diagonal hatching.
colors_rej = ["brown", "cyan", "orange"]
alpha_rej = None
hatch_rej = "///"

In [None]:
# Export the KO columns (screen columns dropped) of each accepted subset.
# na_rep must be a string; "nan" matches what the float NaN produced.
for key, df in DF_SAMP_SUBSETS_ACCEPTED.items():
    df.iloc[:, nscreens:].to_csv(
        f"{OUTDIR}/{key}_accepted.csv", float_format="%.3f", na_rep="nan"
    )

In [None]:
##############################################################################
##  Violin plots (ACCEPTED)

for category in KO_CATEGORIES:
    ko_set = KO_CATEGORY_SETS[category]
    # Tick label per KO: "<KO id>\n<gene symbol>".
    ko_labels = [ko + "\n" + DF_KO_INFO.loc[ko, "SYMBOL"] for ko in ko_set]
    # Rebuilt on every iteration, so an in-place change made by the plotting
    # helper can never carry over to the next category.
    keys = ["t0_samples", "chl_pos_samples", "chl_neg_samples"]

    fig, ax = plt.subplots(1, 1, figsize=[8, 10])
    make_violin_plot(
        DF_SAMP_SUBSETS_ACCEPTED, ko_set, keys,
        ko_labels=ko_labels, ax=ax,
        colors=colors_acc, alpha=alpha_acc, hatch=hatch_acc,
    )
    ax.set_title(f"{category} (accepted)")

    saveas = f"{IMGDIR}/{category.replace(' ', '_')}_accepted.png"
    print(f"Saving {saveas}")
    fig.savefig(saveas, bbox_inches="tight")
    if CLOSE_PLOTS:
        plt.close(fig)

plt.show()

#### Rejected reads

In [None]:
# Export the KO columns (screen columns dropped) of each rejected subset.
# NaN marks KOs with no rejected-read rows. It is written as "nan"; na_rep must
# be a string, and this matches what the float NaN produced.
for key, df in DF_SAMP_SUBSETS_REJECTED.items():
    df.iloc[:, nscreens:].to_csv(
        f"{OUTDIR}/{key}_rejected.csv", float_format="%.3f", na_rep="nan"
    )

In [None]:
##############################################################################
##  Violin plots (REJECTED)

for category in KO_CATEGORIES:
    ko_set = KO_CATEGORY_SETS[category]
    # Tick label per KO: "<KO id>\n<gene symbol>".
    ko_labels = [ko + "\n" + DF_KO_INFO.loc[ko, "SYMBOL"] for ko in ko_set]
    # Rebuilt on every iteration, so an in-place change made by the plotting
    # helper can never carry over to the next category.
    keys = ["t0_samples", "chl_pos_samples", "chl_neg_samples"]

    fig, ax = plt.subplots(1, 1, figsize=[8, 10])
    make_violin_plot(
        DF_SAMP_SUBSETS_REJECTED, ko_set, keys,
        ko_labels=ko_labels, ax=ax,
        colors=colors_rej, alpha=alpha_rej, hatch=hatch_rej,
    )
    ax.set_title(f"{category} (rejected)")

    saveas = f"{IMGDIR}/{category.replace(' ', '_')}_rejected.png"
    print(f"Saving {saveas}")
    fig.savefig(saveas, bbox_inches="tight")
    if CLOSE_PLOTS:
        plt.close(fig)

plt.show()

In [None]:
# Notebook display: rejected-read KO depths for the No_Nitrate subset. This
# subset is exported to CSV but not included in the violin plots.
DF_SAMP_SUBSETS_REJECTED["no_nitrate"]

### Combined

In [None]:
# Notebook display: subset keys available in DF_SAMP_SUBSETS_REJECTED.
DF_SAMP_SUBSETS_REJECTED.keys()

In [None]:
##############################################################################
##  Violin plots (COMBINED)

for category in KO_CATEGORIES:
    ko_set = KO_CATEGORY_SETS[category]
    # Tick label per KO: "<KO id>\n<gene symbol>".
    ko_labels = [ko + "\n" + DF_KO_INFO.loc[ko, "SYMBOL"] for ko in ko_set]
    keys = ["t0_samples", "chl_pos_samples", "chl_neg_samples"]

    fig, ax = plt.subplots(1, 1, figsize=(8, 10))

    # Accepted first, then rejected, on the same axes. The second call adds
    # its entries to the legend created by the first.
    overlay_layers = (
        (DF_SAMP_SUBSETS_ACCEPTED, " (accepted)", colors_acc, alpha_acc, hatch_acc),
        (DF_SAMP_SUBSETS_REJECTED, " (rejected)", colors_rej, alpha_rej, hatch_rej),
    )
    for layer_subsets, layer_suffix, layer_colors, layer_alpha, layer_hatch in overlay_layers:
        make_violin_plot(
            layer_subsets, ko_set, keys,
            ko_labels=ko_labels,
            ax=ax,
            legend=True,
            legend_labels=[k + layer_suffix for k in keys],
            colors=layer_colors,
            alpha=layer_alpha,
            hatch=layer_hatch,
        )

    ax.set_title(f"{category} (accepted vs rejected)")

    saveas = f"{IMGDIR}/{category.replace(' ', '_')}_comparison.png"
    print(f"Saving {saveas}")
    fig.savefig(saveas, bbox_inches="tight")
    if CLOSE_PLOTS:
        plt.close(fig)

plt.show()

# Disambiguate by taxa

In [None]:
# taxid -> subset key -> rejected-read table restricted to that taxon's rows.
DF_SAMP_SUBSETS_REJECTED_BY_TAXA = {}
for taxid in TAXA_LIST:
    # Per-sample KO avg_depth sums over this taxon's rows only. This does not
    # depend on the subset key, so it is computed once per taxon instead of
    # once per (taxon, subset).
    ko_values_rej = {
        sample_id: df[df["taxid"] == taxid].groupby("ko")["avg_depth"].sum()
        for sample_id, df in contam_dfs.items()
    }

    DF_SAMP_SUBSETS_REJECTED_BY_TAXA[taxid] = {}
    for key in DF_SAMP_SUBSETS_ACCEPTED:
        # Same rows and screen columns as the accepted subset. KO values start
        # as NaN and are filled only where this taxon has rows for the KO.
        df_rej = DF_SAMP_SUBSETS_ACCEPTED[key].copy()
        df_rej.iloc[:, nscreens:] = np.nan
        for sample_id, series in ko_values_rej.items():
            if sample_id not in df_rej.index:
                continue
            for ko, value in series.items():
                if ko in df_rej.columns:
                    df_rej.at[sample_id, ko] = value
        DF_SAMP_SUBSETS_REJECTED_BY_TAXA[taxid][key] = df_rej


In [None]:
suboutdir = f"{OUTDIR}/by_taxa"
os.makedirs(suboutdir, exist_ok=True)

# One CSV per (taxon, subset) with only the KO columns. NaN marks KOs for
# which the taxon has no rejected-read rows. It is written as "nan"; na_rep
# must be a string, and this matches what the float NaN produced.
for taxid in TAXA_LIST:
    dftaxa = DF_SAMP_SUBSETS_REJECTED_BY_TAXA[taxid]
    for key, df in dftaxa.items():
        df.iloc[:, nscreens:].to_csv(
            f"{suboutdir}/{key}_{taxid}_rejected.csv",
            float_format="%.3f", na_rep="nan"
        )

In [None]:
##############################################################################
##  Violin plots (COMPARISON, DISAMBIGUATED)

subimgdir = f"{IMGDIR}/by_taxa"
os.makedirs(subimgdir, exist_ok=True)

for category in KO_CATEGORIES:
    ko_set = KO_CATEGORY_SETS[category]
    # Tick label per KO: "<KO id>\n<gene symbol>".
    ko_labels = [ko + "\n" + DF_KO_INFO.loc[ko, "SYMBOL"] for ko in ko_set]
    keys = ["t0_samples", "chl_pos_samples", "chl_neg_samples"]

    for taxid in TAXA_LIST:
        fig, ax = plt.subplots(1, 1, figsize=(8, 10))

        # All accepted reads, overlaid with the rejected reads of this taxon.
        # The second call adds its entries to the first call's legend.
        overlay_layers = (
            (DF_SAMP_SUBSETS_ACCEPTED, " (accepted)",
             colors_acc, alpha_acc, hatch_acc),
            (DF_SAMP_SUBSETS_REJECTED_BY_TAXA[taxid], " (rejected)",
             colors_rej, alpha_rej, hatch_rej),
        )
        for layer_subsets, layer_suffix, layer_colors, layer_alpha, layer_hatch in overlay_layers:
            make_violin_plot(
                layer_subsets, ko_set, keys,
                ko_labels=ko_labels,
                ax=ax,
                legend=True,
                legend_labels=[k + layer_suffix for k in keys],
                colors=layer_colors,
                alpha=layer_alpha,
                hatch=layer_hatch,
            )

        # Species name from the first TAXA_DF row with this taxid.
        spec = TAXA_DF.loc[TAXA_DF["taxid"] == taxid, "species"].values[0]
        ax.set_title(f"{category} (accepted vs rejected)\n {taxid} ({spec})")

        saveas = f"{subimgdir}/{category.replace(' ', '_')}_comparison_{taxid}.png"
        print(f"Saving {saveas}")
        fig.savefig(saveas, bbox_inches="tight")
        if CLOSE_PLOTS:
            plt.close(fig)

plt.show()