# Gut Microbiome Disease Prediction - Streamlined Pipeline


In [None]:
import os, sys, subprocess, math
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from scipy.stats import mannwhitneyu
from scipy import stats
from skbio import DistanceMatrix
from skbio.stats.distance import permanova
from sklearn.decomposition import PCA
from sklearn.model_selection import (StratifiedKFold, cross_validate,
                                     train_test_split)
from sklearn.linear_model import LogisticRegressionCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (roc_auc_score, roc_curve, auc,
                             make_scorer, balanced_accuracy_score)
from statsmodels.stats.multitest import multipletests
import seaborn as sns
import warnings
# NOTE(review): this suppresses every warning category for the whole session,
# including deprecation and numerical warnings from the libraries used below.
warnings.filterwarnings('ignore')

# Global seaborn theme applied to all figures produced in this notebook.
sns.set(style="whitegrid", font_scale=1.05)
# Seed NumPy's global random state so code relying on it is repeatable.
np.random.seed(2025)


# 1. Setup & Data Loading

In [None]:
# Output layout: <data_dir>/results_EDA/{figures,tables}
data_dir = "data"
RESULTS = os.path.join(data_dir, "results_EDA")
FIGS, TABLES = (os.path.join(RESULTS, sub) for sub in ("figures", "tables"))
for out_dir in (FIGS, TABLES):
    os.makedirs(out_dir, exist_ok=True)

# Read the three input tables; the first CSV column is the row index in each
ab_raw, meta, taxa_meta = (
    pd.read_csv(os.path.join(data_dir, csv_name), index_col=0)
    for csv_name in (
        "MetaCardis2020_relative_abundance.csv",
        "MetaCardis2020_sample_metadata.csv",
        "MetaCardis2020_taxa_metadata.csv",
    )
)

print("Data shapes:", *(table.shape for table in (ab_raw, meta, taxa_meta)))


# 2. Data Preprocessing

In [None]:
# Flip the abundance table so its index can be matched against the sample
# metadata index, then restrict both tables to the shared labels.
ab = ab_raw.transpose().copy()
common = ab.index.intersection(meta.index)
ab, meta = ab.loc[common], meta.loc[common]
# Create health status labels
def map_health(row):
    """Derive a binary health label from one metadata row.

    Parameters
    ----------
    row : mapping
        Must provide 'study_condition' and 'disease' entries.

    Returns
    -------
    str or None
        'Healthy' if the study condition mentions "control" or the disease
        field mentions "healthy" (case-insensitive), 'Disease' otherwise.
        None when BOTH fields are missing, so the caller's subsequent
        ``dropna(subset=['health_status'])`` can discard the unlabeled row
        instead of silently counting it as 'Disease'.
    """
    cond_raw = row['study_condition']
    dis_raw = row['disease']

    cond_missing = pd.isna(cond_raw)
    dis_missing = pd.isna(dis_raw)
    if cond_missing and dis_missing:
        # No label information at all: leave the status undefined.
        return None

    # A missing field contributes no evidence either way.
    cond = '' if cond_missing else str(cond_raw).lower()
    dis = '' if dis_missing else str(dis_raw).lower()
    return 'Healthy' if ('control' in cond) or ('healthy' in dis) else 'Disease'

# Label every sample, discard unlabeled ones, and keep the abundance rows in sync
meta['health_status'] = meta.apply(map_health, axis=1)
meta = meta.dropna(subset=['health_status'])
ab = ab.loc[meta.index]

# Collapse the free-text antibiotic field to yes / no / unknown
if 'antibiotics_current_use' not in meta.columns:
    meta['antibiotic_use'] = 'unknown'
else:
    def clean_antibiotic(x):
        """Map a raw antibiotic entry to 'yes', 'no' or 'unknown'."""
        token = str(x).lower().strip()
        if token != 'nan':
            if token.startswith('y'):
                return 'yes'
            if token.startswith('n'):
                return 'no'
        return 'unknown'
    meta['antibiotic_use'] = meta['antibiotics_current_use'].apply(clean_antibiotic)

# Filter low-prevalence taxa (≥1% samples)
n_samples = len(ab)
min_samples = math.ceil(0.01 * n_samples)
prevalence = ab.gt(0).sum(axis=0)
keep_taxa = prevalence.index[prevalence >= min_samples]
ab_filt = ab.loc[:, keep_taxa].copy()

# Renormalize rows to sum to 1; rows that sum to 0 become all-zero and are removed
row_totals = ab_filt.sum(axis=1)
ab_rel = ab_filt.div(row_totals, axis=0).fillna(0)
zero_sum = ab_rel.sum(axis=1) == 0
if zero_sum.any():
    ab_rel = ab_rel.drop(index=zero_sum.index[zero_sum])
    meta = meta.loc[ab_rel.index]

print(f"Final dataset: {ab_rel.shape[0]} samples × {ab_rel.shape[1]} taxa")


# 3. Alpha Diversity Analysis

In [None]:
def shannon(p):
    """Return the Shannon index (natural log) of the strictly positive entries of *p*."""
    positive = p[p > 0]  # zero entries are excluded before taking the log
    return -np.sum(positive * np.log(positive))

# Per-sample alpha diversity (count of non-zero taxa and Shannon index),
# joined with the grouping columns from the metadata
alpha = pd.DataFrame({
    'richness': ab_rel.gt(0).sum(axis=1),
    'shannon': ab_rel.apply(shannon, axis=1),
}).join(meta[['health_status', 'antibiotic_use']])

# Violin plot with a narrow box plot overlaid on the same axes
fig, ax = plt.subplots(figsize=(6, 5))
sns.violinplot(data=alpha, x='health_status', y='shannon',
               inner=None, palette='Set2', ax=ax)
sns.boxplot(data=alpha, x='health_status', y='shannon', width=0.15,
            showcaps=True, boxprops={'zorder': 2}, fliersize=2, ax=ax)
ax.set_title("Shannon Diversity: Healthy vs Disease")
ax.set_ylabel("Shannon Index")
fig.tight_layout()
fig.savefig(os.path.join(FIGS, "shannon_diversity.png"), dpi=300, bbox_inches='tight')
plt.show()

# Two-sided Mann-Whitney U test on Shannon index between the two groups
status = alpha['health_status']
healthy = alpha.loc[status == 'Healthy', 'shannon']
disease = alpha.loc[status == 'Disease', 'shannon']
u_stat, p_val = mannwhitneyu(healthy, disease, alternative='two-sided')
print(f"Mann–Whitney test: p = {p_val:.3e}, U = {u_stat}")


# 4. CLR Transformation & Beta Diversity

In [None]:
# CLR transform
def clr_transform(df, pseudo=1e-6):
    """Centered log-ratio transform of each row of *df*.

    *pseudo* is added to every value before the log. Each row's mean log
    value is then subtracted from that row. The result keeps the index and
    columns of *df*.
    """
    log_vals = np.log(df.to_numpy() + pseudo)
    centered = log_vals - log_vals.mean(axis=1, keepdims=True)
    return pd.DataFrame(centered, index=df.index, columns=df.columns)

# CLR-transform the relative abundances and write the table to disk
ab_clr = clr_transform(ab_rel)
ab_clr.to_csv(os.path.join(TABLES, "abundance_clr.csv"))

# PCA: first two components of the CLR matrix, joined with grouping columns
pca = PCA(n_components=2, random_state=2025)
pcs = pca.fit_transform(ab_clr.values)
pcs_df = pd.DataFrame(pcs, index=ab_clr.index, columns=['PC1', 'PC2']).join(
    meta[['health_status', 'antibiotic_use']]
)
pc1_pct, pc2_pct = pca.explained_variance_ratio_[:2] * 100

fig, ax = plt.subplots(figsize=(7, 5))
sns.scatterplot(data=pcs_df, x='PC1', y='PC2', hue='health_status',
                style='antibiotic_use', palette='Set2', s=60,
                edgecolor='black', alpha=0.8, ax=ax)
ax.set_title("PCA (CLR-transformed)")
ax.set_xlabel(f"PC1 ({pc1_pct:.1f}% var)")
ax.set_ylabel(f"PC2 ({pc2_pct:.1f}% var)")
fig.tight_layout()
fig.savefig(os.path.join(FIGS, "pca_clr.png"), dpi=300, bbox_inches='tight')
plt.show()

# PERMANOVA on Euclidean distances between CLR profiles, grouped by health status
clr_dist = pdist(ab_clr.values, metric='euclidean')
clr_DM = squareform(clr_dist)
dm_clr = DistanceMatrix(clr_DM, ids=ab_clr.index.tolist())
res_health = permanova(dm_clr, pcs_df['health_status'], permutations=999)
print("\nPERMANOVA (Aitchison) ~ health_status", res_health, sep="\n")


# 5. Disease Group Classification

In [None]:
def create_disease_groups(meta):
    """Add a 'disease_group' column derived from 'study_condition'.

    The column is written onto *meta* itself, which is also returned.
    Conditions without an entry in the lookup are left as missing values.
    The per-group counts are printed.
    """
    members_by_group = {
        'metabolic': ('IGT', 'T2D'),
        'control': ('control',),
        'cardiovascular': ('CAD', 'HF'),
    }
    # Invert to a condition -> group lookup for Series.map
    group_of = {
        condition: group
        for group, conditions in members_by_group.items()
        for condition in conditions
    }
    meta['disease_group'] = meta['study_condition'].map(group_of)

    print("\nDisease grouping:")
    print(meta['disease_group'].value_counts())
    return meta

# Adds the 'disease_group' column to the metadata; the function writes the
# column onto the frame it is given and returns that same frame.
meta = create_disease_groups(meta)


# 6. Disease-Specific Modeling

In [None]:
def train_disease_specific_models(ab_clr, meta, feature_names):
    """Train one L2 logistic-regression model per disease group vs. control.

    For each of the 'metabolic' and 'cardiovascular' groups, the samples of
    that group plus the 'control' samples are selected, a
    ``LogisticRegressionCV`` is scored with shuffled 5-fold stratified
    cross-validation, and the model is then refit on the whole subset to rank
    features by absolute coefficient.

    A group is skipped (with a printed message) when either class has fewer
    samples than the number of CV folds, because stratified splitting and
    ROC-AUC scoring cannot work in that situation.

    Parameters
    ----------
    ab_clr : pandas.DataFrame
        Feature table. Boolean masks built from ``meta`` are applied to it
        with ``.loc``, so its index must align with ``meta``'s index.
    meta : pandas.DataFrame
        Must contain a 'disease_group' column.
    feature_names : sequence of str
        Names of the columns of ``ab_clr``, in column order.

    Returns
    -------
    dict
        Maps each modelled disease group to a dict with keys 'n_control',
        'n_disease', 'roc_auc', 'roc_auc_std', 'balanced_acc', 'f1',
        'model' (fitted on the full subset), 'top_features' and 'top_coefs'
        (up to 20 features, ordered by decreasing absolute coefficient).
        Skipped groups are absent.
    """
    n_splits = 5
    results_by_disease = {}

    control_idx = meta['disease_group'] == 'control'
    disease_groups = ['metabolic', 'cardiovascular']

    for disease_group in disease_groups:
        print(f"\n{'='*60}")
        print(f"MODEL: {disease_group.upper()} vs CONTROL")
        print('='*60)

        # Select samples
        disease_idx = meta['disease_group'] == disease_group
        selected_idx = control_idx | disease_idx

        X_subset = ab_clr.loc[selected_idx].values
        # 1 = disease, 0 = control
        y_subset = disease_idx[selected_idx].astype(int).values

        n_disease = y_subset.sum()
        n_control = (y_subset == 0).sum()
        print(f"Samples: {n_control} controls vs {n_disease} {disease_group}")

        # Every CV fold needs at least one sample of each class.
        if min(n_control, n_disease) < n_splits:
            print(f"Skipping {disease_group}: need at least {n_splits} samples "
                  f"per class for {n_splits}-fold stratified CV")
            continue

        # Train model with cross-validation (the estimator tunes its
        # regularization strength with its own inner 5-fold CV)
        model = LogisticRegressionCV(
            cv=5, penalty='l2', max_iter=1000,
            class_weight='balanced', random_state=2025, n_jobs=-1
        )

        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=2025)
        scoring = {
            'roc_auc': 'roc_auc',
            'balanced_acc': make_scorer(balanced_accuracy_score),
            'f1': 'f1'
        }

        # Only test-fold scores are reported, so train scores are not computed.
        cv_results = cross_validate(
            model, X_subset, y_subset,
            cv=cv, scoring=scoring
        )

        # Store results
        summary = {
            'n_control': n_control,
            'n_disease': n_disease,
            'roc_auc': cv_results['test_roc_auc'].mean(),
            'roc_auc_std': cv_results['test_roc_auc'].std(),
            'balanced_acc': cv_results['test_balanced_acc'].mean(),
            'f1': cv_results['test_f1'].mean()
        }
        results_by_disease[disease_group] = summary

        print(f"\nPerformance Metrics:")
        print(f"  ROC-AUC: {summary['roc_auc']:.3f} ± "
              f"{summary['roc_auc_std']:.3f}")
        print(f"  Balanced Accuracy: {summary['balanced_acc']:.3f}")
        print(f"  F1 Score: {summary['f1']:.3f}")

        # Fit on full subset for feature importance
        model.fit(X_subset, y_subset)
        coefs = model.coef_[0]
        # Indices of the (up to) 20 largest |coefficient|, largest first
        top_idx = np.argsort(np.abs(coefs))[-20:][::-1]

        print(f"\nTop 10 Discriminative Taxa:")
        for i, idx in enumerate(top_idx[:10], 1):
            # Positive coefficient = higher in the disease class (label 1)
            direction = "↑" if coefs[idx] > 0 else "↓"
            print(f"  {i:2d}. {feature_names[idx][:50]:50s} {direction} ({coefs[idx]:+.3f})")

        summary['model'] = model
        summary['top_features'] = [feature_names[i] for i in top_idx]
        summary['top_coefs'] = coefs[top_idx]

    return results_by_disease

# Train models
# Feature names are the CLR table's column labels in column order, so that
# coefficient positions can be mapped back to taxa.
disease_results = train_disease_specific_models(ab_clr, meta, list(ab_clr.columns))


# 7. Results Visualization

In [None]:
def compare_disease_models(results_by_disease):
    """Plot per-disease model scores next to the sample counts and save the figure.

    Left panel: grouped bars of ROC-AUC, balanced accuracy and F1 for every
    key of *results_by_disease*, with a dashed reference line at 0.5.
    Right panel: stacked control / disease sample counts. The figure is
    written to ``FIGS`` and shown.
    """
    fig, (perf_ax, size_ax) = plt.subplots(1, 2, figsize=(14, 5))

    diseases = list(results_by_disease)
    tick_labels = [name.capitalize() for name in diseases]
    positions = np.arange(len(diseases))
    bar_width = 0.25

    # --- Left panel: one bar per metric, offset side by side ---
    for offset, metric in enumerate(('roc_auc', 'balanced_acc', 'f1')):
        scores = [results_by_disease[name][metric] for name in diseases]
        perf_ax.bar(positions + offset * bar_width, scores, bar_width,
                    label=metric.upper().replace('_', ' '), alpha=0.8)

    perf_ax.set_title('Disease-Specific Model Performance', fontweight='bold')
    perf_ax.set_ylabel('Score', fontweight='bold')
    perf_ax.set_xticks(positions + bar_width)  # centre tick under the middle bar
    perf_ax.set_xticklabels(tick_labels)
    perf_ax.legend()
    perf_ax.grid(axis='y', alpha=0.3)
    perf_ax.set_ylim([0, 1])
    perf_ax.axhline(0.5, color='red', linestyle='--', alpha=0.5)

    # --- Right panel: stacked sample sizes ---
    n_controls = [results_by_disease[name]['n_control'] for name in diseases]
    n_diseases = [results_by_disease[name]['n_disease'] for name in diseases]

    size_ax.bar(positions, n_controls, bar_width * 2,
                label='Control', alpha=0.8, color='green')
    size_ax.bar(positions, n_diseases, bar_width * 2, bottom=n_controls,
                label='Disease', alpha=0.8, color='red')

    size_ax.set_title('Sample Distribution', fontweight='bold')
    size_ax.set_ylabel('Number of Samples', fontweight='bold')
    size_ax.set_xticks(positions)
    size_ax.set_xticklabels(tick_labels)
    size_ax.legend()
    size_ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    out_path = os.path.join(FIGS, 'disease_specific_comparison.png')
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.show()

# Draw, save and show the comparison figure for the trained models.
compare_disease_models(disease_results)


# 8. Taxonomic Analysis

In [None]:
def aggregate_by_taxonomy(ab_rel, taxa_meta, level='phylum'):
    """Sum abundance columns that share the same value in ``taxa_meta[level]``.

    Only columns of *ab_rel* that also appear in the index of *taxa_meta*
    are used. Entries whose *level* value is missing are skipped. Returns a
    DataFrame with the rows of *ab_rel* and one column per group, and prints
    the number of groups.
    """
    shared = ab_rel.columns.intersection(taxa_meta.index)
    abundances = ab_rel[shared]
    lineage = taxa_meta.loc[shared]

    ab_aggregated = pd.DataFrame(index=abundances.index)
    for taxon, members in lineage.groupby(level).groups.items():
        if pd.isna(taxon):
            continue
        ab_aggregated[taxon] = abundances.loc[:, members].sum(axis=1)

    print(f"{level.capitalize()}-level: {ab_aggregated.shape[1]} taxa")
    return ab_aggregated

def analyze_taxonomic_signatures(ab_rel, taxa_meta, meta):
    """Screen phylum- and family-level abundances for disease-vs-control differences.

    For each level, abundances are aggregated with ``aggregate_by_taxonomy``.
    For each of the 'metabolic' and 'cardiovascular' groups, every taxon that
    is non-zero in at least 5 disease and 5 control samples is compared with a
    two-sided Mann-Whitney U test. P-values are adjusted with
    Benjamini-Hochberg FDR within each (level, disease group) pair.

    Parameters
    ----------
    ab_rel : pandas.DataFrame
        Abundance table passed on to ``aggregate_by_taxonomy``.
    taxa_meta : pandas.DataFrame
        Taxonomy table containing 'phylum' and 'family' columns.
    meta : pandas.DataFrame
        Sample metadata containing a 'disease_group' column.

    Returns
    -------
    dict
        ``{level: {disease_group: DataFrame}}``. Each DataFrame holds only the
        significant taxa, with columns 'taxon', 'log2_FC' (log2 ratio of mean
        disease to mean control abundance), 'pvalue', 'padj' and
        'significant', sorted by decreasing absolute 'log2_FC'. A disease
        group with no testable taxa has no entry.
    """
    results = {}

    for level in ['phylum', 'family']:
        print(f"\n--- {level.upper()} Level ---")
        ab_agg = aggregate_by_taxonomy(ab_rel, taxa_meta, level=level)

        common_idx = ab_agg.index.intersection(meta.index)
        ab_agg_aligned = ab_agg.loc[common_idx]
        meta_aligned = meta.loc[common_idx]

        # The control mask does not depend on the disease group.
        control_samples = meta_aligned['disease_group'] == 'control'

        enriched_taxa = {}

        for disease_group in ['metabolic', 'cardiovascular']:
            disease_samples = meta_aligned['disease_group'] == disease_group

            pvalues, fold_changes, taxa_names = [], [], []

            for taxon in ab_agg_aligned.columns:
                disease_abund = ab_agg_aligned.loc[disease_samples, taxon]
                control_abund = ab_agg_aligned.loc[control_samples, taxon]

                # Require the taxon to be present in at least 5 samples per group.
                if (disease_abund > 0).sum() < 5 or (control_abund > 0).sum() < 5:
                    continue

                try:
                    _, pval = stats.mannwhitneyu(
                        disease_abund, control_abund, alternative='two-sided'
                    )
                except ValueError:
                    # The test could not be computed for this taxon; skip it.
                    continue

                # Small offset avoids division by zero / log of zero.
                mean_disease = disease_abund.mean() + 1e-10
                mean_control = control_abund.mean() + 1e-10
                fc = np.log2(mean_disease / mean_control)

                pvalues.append(pval)
                fold_changes.append(fc)
                taxa_names.append(taxon)

            if len(pvalues) > 0:
                reject, pvals_corrected, _, _ = multipletests(pvalues, method='fdr_bh')

                results_df = pd.DataFrame({
                    'taxon': taxa_names,
                    'log2_FC': fold_changes,
                    'pvalue': pvalues,
                    'padj': pvals_corrected,
                    'significant': reject
                })

                sig_results = results_df[results_df['significant']].copy()
                sig_results = sig_results.sort_values('log2_FC', key=abs, ascending=False)
                enriched_taxa[disease_group] = sig_results

                print(f"\n{disease_group.upper()} vs Control: {sig_results.shape[0]} significant taxa")

        results[level] = enriched_taxa

    return results

# Run the phylum- and family-level screen on the filtered relative abundances.
taxonomic_results = analyze_taxonomic_signatures(ab_rel, taxa_meta, meta)


# 9. Save Results

In [None]:
# Save disease-specific results
# One row per disease-vs-control model. 'n_samples' is the total number of
# samples the model was evaluated on (controls + cases); the per-class counts
# are written next to it so the class balance is visible in the table.
disease_summary = pd.DataFrame([
    {
        'disease_group': disease,
        'n_samples': results['n_control'] + results['n_disease'],
        'n_control': results['n_control'],
        'n_disease': results['n_disease'],
        'roc_auc': results['roc_auc'],
        'roc_auc_std': results['roc_auc_std'],
        'balanced_acc': results['balanced_acc'],
        'f1': results['f1']
    }
    for disease, results in disease_results.items()
])
disease_summary.to_csv(os.path.join(TABLES, 'disease_specific_results.csv'), index=False)

# Save taxonomic signatures: one CSV per (taxonomic level, disease group)
for level, disease_dict in taxonomic_results.items():
    for disease, results_df in disease_dict.items():
        filename = f'taxonomic_signatures_{level}_{disease}.csv'
        results_df.to_csv(os.path.join(TABLES, filename), index=False)

print("\n" + "="*60)
print("ANALYSIS COMPLETE!")
print("="*60)
print(f"Figures saved to: {FIGS}/")
print(f"Tables saved to: {TABLES}/")