# In [None]:
import pandas as pd
import numpy as np
from scipy.stats import spearmanr
import os

# === 1. Load fungal and bacterial data ===
# Sample columns shared by both OTU tables (matched after the header
# normalization below). The same list/order is used for every profile, so
# fungal and bacterial vectors are always paired sample-by-sample.
sample_cols = [
    'n1b','n1s','n2b','n2s','n3s','n3b','n4b','n4s',
    'n5b','n5s','n6b','n6s','n7b','n7s'
]

fungi = pd.read_csv("data/fungi_OTU_table_subset.csv", sep=";")
# The bacterial path previously lacked the ".csv" extension used by the
# fungal table. NOTE(review): confirm the file on disk really is *.csv.
bacteria = pd.read_csv("data/bacteria_OTU_table_subset.csv", sep=";")

for label, df in (("fungi", fungi), ("bacteria", bacteria)):
    # Normalize headers (strip whitespace, lowercase) so they match the
    # lowercase names used throughout the script.
    df.columns = df.columns.str.strip().str.lower()
    # Fail early with a clear message instead of a KeyError deep in groupby.
    missing = [c for c in ['genus', *sample_cols] if c not in df.columns]
    if missing:
        raise KeyError(f"{label} table is missing required columns: {missing}")

# === 2. Centered log-ratio (CLR) transformation ===
def clr_transform(df, pseudocount=1e-6, cols=None):
    """Return a copy of *df* with its sample columns CLR-transformed.

    Each sample column is treated as one composition over the rows (taxa).
    After adding ``pseudocount`` to every value, each value is logged and
    centered on its own column's mean log, i.e. divided by the geometric
    mean of its sample before taking the log:
    ``clr(x_ij) = log(x_ij) - mean_i(log(x_ij))``.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per taxon and one column per sample. Other columns (e.g.
        ``'genus'``) are carried through unchanged.
    pseudocount : float, default 1e-6
        Added to every sample value so zeros can be logged. A value this
        small maps zero counts to very large negative CLR values; 0.5 or 1
        is a more common choice for raw counts.
    cols : list of str, optional
        Sample columns to transform. Defaults to the module-level
        ``sample_cols``.

    Returns
    -------
    pandas.DataFrame
        A copy of *df*; the input is not modified.
    """
    if cols is None:
        cols = sample_cols
    df_clr = df.copy()
    log_vals = np.log(df_clr[cols] + pseudocount)
    # Center down each column (axis=0), i.e. within a sample. Centering per
    # row instead only shifts every taxon by a constant, which leaves rank
    # (Spearman) correlations identical to those of the raw counts.
    df_clr[cols] = log_vals - log_vals.mean(axis=0)
    return df_clr

# === 3. Aggregate by genus and apply CLR ===
# Collapse OTUs to genus level by summing counts per sample. groupby's
# default dropna=True means rows with a missing genus are not aggregated.
fungi_genus, bacteria_genus = (
    otu_table.groupby('genus')[sample_cols].sum()
    for otu_table in (fungi, bacteria)
)

# clr_transform operates on a flat table, so move 'genus' out of the index
# for the transform and restore it afterwards.
fungi_clr, bacteria_clr = (
    clr_transform(genus_table.reset_index()).set_index('genus')
    for genus_table in (fungi_genus, bacteria_genus)
)

# === 4. Compute cross-kingdom correlations ===
def compute_edges(fungi_df, bacteria_df, corr_thresh=0.5, pval_thresh=0.05, cols=None):
    """Correlate every fungal row with every bacterial row (Spearman).

    Parameters
    ----------
    fungi_df, bacteria_df : pandas.DataFrame
        Indexed by taxon name, with one column per sample.
    corr_thresh : float
        Minimum ``|r|`` for an edge to count as significant.
    pval_thresh : float
        An edge is significant only if its p-value is strictly below this.
    cols : list of str, optional
        Sample columns to correlate over. Defaults to the module-level
        ``sample_cols``.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``(all_edges, significant_edges)``, both with columns
        ``source, target, correlation, pval, interaction_type``. Pairs whose
        correlation is undefined (NaN, e.g. a constant profile) are omitted.
        Both frames keep that schema even when empty.

    NOTE(review): p-values are not corrected for the F*B tests performed;
    consider a Benjamini-Hochberg adjustment before interpreting edges.
    """
    if cols is None:
        cols = sample_cols
    columns = ['source', 'target', 'correlation', 'pval', 'interaction_type']

    # Extract each bacterial profile once rather than once per fungal row.
    bacteria_rows = [(bc, bacteria_df.loc[bc, cols]) for bc in bacteria_df.index]

    records = []
    for fg in fungi_df.index:
        x = fungi_df.loc[fg, cols]
        for bc, y in bacteria_rows:
            r, p = spearmanr(x, y, nan_policy='omit')
            if pd.isna(r):
                continue  # undefined correlation: no edge
            records.append({
                'source': fg,
                'target': bc,
                'correlation': r,
                'pval': p,
                'interaction_type': 'Positive' if r > 0 else 'Negative'
            })

    # Explicit columns: an empty record list would otherwise give a frame
    # with no columns and the filter below would raise KeyError.
    df = pd.DataFrame(records, columns=columns)
    if df.empty:
        return df, df.copy()
    df_sig = df[(df['correlation'].abs() >= corr_thresh) & (df['pval'] < pval_thresh)]
    return df, df_sig

# Correlate every fungal genus with every bacterial genus on the CLR tables;
# the thresholds below are compute_edges' defaults, spelled out explicitly.
edges_all, edges_sig = compute_edges(
    fungi_clr,
    bacteria_clr,
    corr_thresh=0.5,
    pval_thresh=0.05,
)

# === 5. Generate node table ===
# One row per genus per kingdom; 'abundance' is the genus' count summed over
# all sample columns of the aggregated (pre-CLR) table.
# NOTE(review): a genus name present in both tables would yield duplicate ids.
fungi_totals = fungi_genus.sum(axis=1)
bacteria_totals = bacteria_genus.sum(axis=1)

nodes_fungi = pd.DataFrame(
    {'id': fungi_totals.index, 'abundance': fungi_totals.values, 'type': 'Fungus'}
)
nodes_bacteria = pd.DataFrame(
    {'id': bacteria_totals.index, 'abundance': bacteria_totals.values, 'type': 'Bacterium'}
)
nodes = pd.concat([nodes_fungi, nodes_bacteria], ignore_index=True)

# === 6. Extract top 20 most abundant nodes ===
# Rank all nodes (both kingdoms together) by total abundance and keep the
# 20 largest. NOTE(review): fungal and bacterial totals come from separate
# tables, so ranking them on one scale assumes comparable counts — confirm.
ranked_nodes = nodes.sort_values(by='abundance', ascending=False)
top20_ids = ranked_nodes['id'].head(20)

# Keep only significant edges whose endpoints are both among the top nodes.
source_in_top = edges_sig['source'].isin(top20_ids)
target_in_top = edges_sig['target'].isin(top20_ids)
edges_sig_top20 = edges_sig[source_in_top & target_in_top]
nodes_top20 = nodes.loc[nodes['id'].isin(top20_ids)]

# === 7. Save outputs ===
out_dir = "output/network_crosskingdom"
os.makedirs(out_dir, exist_ok=True)

# File name -> table, written in this order without the pandas index.
outputs = {
    "edges_all.csv": edges_all,
    "edges_significant.csv": edges_sig,
    "nodes.csv": nodes,
    "edges_top20.csv": edges_sig_top20,
    "nodes_top20.csv": nodes_top20,
}
for filename, table in outputs.items():
    table.to_csv(f"{out_dir}/{filename}", index=False)

print(f"Network exported: total edges={len(edges_all)}, significant={len(edges_sig)}, top20={len(edges_sig_top20)}")
