# Association between the gut microbiota and host behavior
Assess to what extent SGB abundances at each timepoint (D0, D15, D22, and D30) predict host behavior.

## Set up

### Import packages

In [1]:
import argparse
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import linregress
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_predict, permutation_test_score
import numpy as np
import matplotlib
import seaborn as sns
import os

matplotlib.use('Agg')

### Set up argument parsing

In [9]:
# Set up argument parsing
def parse_args():
    parser = argparse.ArgumentParser(description="Run Random Forest to predict phenotype from microbiome abundance.")
    parser.add_argument(
        "--output_dir", 
        required=True, 
        help="Directory to save output files"
    )
    parser.add_argument(
        "--phenotype_vars", 
        type=str, 
        nargs="+", 
        help="List of phenotype variables to analyze, e.g., --phenotype_vars Social_preference Groom_duration Center_time"
    )
    parser.add_argument(
        "--phenotype_file", 
        type=str, 
        required=True, 
        help="Path to the phenotype metadata file"
    )
    parser.add_argument(
        "--abundance_file", 
        type=str, 
        required=True, 
        help="Path to the abundance data file"
    )
    parser.add_argument(
        "--timepoint", 
        required=True, 
        help="Timepoint of the data")
    return parser.parse_args()

# Parse arguments
args = parse_args()

usage: ipykernel_launcher.py [-h] --output_dir OUTPUT_DIR
                             [--phenotype_vars PHENOTYPE_VARS [PHENOTYPE_VARS ...]]
                             --phenotype_file PHENOTYPE_FILE --tissue
                             {Adr,VAT} --transcriptome_file TRANSCRIPTOME_FILE
ipykernel_launcher.py: error: the following arguments are required: --output_dir, --phenotype_file, --tissue, --transcriptome_file


SystemExit: 2

In [4]:
# Simulate arguments for troubleshooting
class Args:
    output_dir = "output/Behavior_T1"
    phenotype_vars = ["Center_occupancy", "Grooming_duration", "Social_preference"]
    phenotype_file = "Behavior_T1_data.txt"
    timepoint = "T1"
    abundance_file = "rep_SGB_raref_relative_abundance_MAG_ID.txt"

args = Args()

### Set up output directory

In [5]:
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

### Define timepoint

In [7]:
timepoint = args.timepoint

### Define palette

In [42]:
treatment_color_map = {
    "Pair_H2O": "#63B8FF",     # steelblue2
    "Pair_TMT": "#EEC900",     # gold2
    "Single_H2O": "#CD1076",   # deeppink3
    "Single_TMT": "#EE9A49"    # sienna2
}

### Load data

In [26]:
# Load phenotype metadata
phenotype = pd.read_csv(args.phenotype_file, sep="\t")

# Load abundance data
abundance = pd.read_csv(args.abundance_file, sep="\t", index_col=0)

In [27]:
# Data overview

# Display first 5 rows
display(phenotype.head())
display(abundance.head())

# Check data types and missing values
print(phenotype.info())
print(abundance.info())

Unnamed: 0,Mouse_ID,Sample_ID,Sex,DOB,Dame,Sire,Parents,Litter,Weaned,Experiement_Start,...,TMT,Treatment,Pair_mouse,Cage_ID,Treatment_Timepoint,Stress,Center_occupancy,Grooming_duration,Social_preference,Weight
0,2028,2028_T1,Female,01/01/2022,,,NA_NA,NA_NA_44562,31/01/2022,15/02/2022,...,H2O,Pair_H2O,2029.0,2028_2029,Pair_H2O_T1,Control,344.533333,1,0.093117,15.04
1,2029,2029_T1,Female,01/01/2022,,,NA_NA,NA_NA_44562,31/01/2022,15/02/2022,...,H2O,Pair_H2O,2028.0,2028_2029,Pair_H2O_T1,Control,121.866667,65,0.446575,16.09
2,594,594_T1,Male,01/01/2022,550.0,560.0,550_560,550_560_44562,24/01/2022,15/02/2022,...,H2O,Pair_H2O,595.0,594_595,Pair_H2O_T1,Control,195.866667,14,0.895172,20.35
3,595,595_T1,Male,01/01/2022,550.0,560.0,550_560,550_560_44562,24/01/2022,15/02/2022,...,H2O,Pair_H2O,594.0,594_595,Pair_H2O_T1,Control,201.933333,88,0.474453,20.31
4,596,596_T1,Female,01/01/2022,550.0,560.0,550_560,550_560_44562,24/01/2022,15/02/2022,...,TMT,Pair_TMT,597.0,596_597,Pair_TMT_T1,Stressor,216.566667,17,0.87976,16.18


Unnamed: 0,MAG_ID,2028_T1,2028_T5,2029_T1,2029_T5,594_T1,594_T5,595_T1,595_T5,596_T1,...,640_T3,640_T4,641_T3,641_T4,642_T3,642_T4,643_T3,643_T4,660_T3,660_T4
SGB001,SGB001,1.653143,0.232005,0.081545,0.153245,1.5e-05,1.5e-05,0.0,0.0,0.328732,...,0.94932,0.33497,0.649551,0.889639,0.383688,0.516491,0.64041,1.02194,0.978599,0.278992
SGB002,SGB002,2.462126,0.473809,0.406443,0.709397,0.0,1.5e-05,1.5e-05,0.0,1.6e-05,...,3.3e-05,3.4e-05,3.2e-05,3.3e-05,3.2e-05,0.0,3.3e-05,3.4e-05,0.0,0.0
SGB012,SGB012,1.551652,1.108049,0.176861,1.4217,0.845357,1.414299,1.040065,0.915453,2.027241,...,4.3e-05,4.4e-05,4.1e-05,4.3e-05,4.1e-05,4.2e-05,0.0,4.4e-05,0.671992,1.282148
SGB035,SGB035,0.09881,0.152776,0.0,0.126336,0.130409,0.09096,0.055733,0.082296,1.8e-05,...,3.7e-05,0.0,3.6e-05,0.0,3.6e-05,0.0,0.0,0.0,3.7e-05,3.8e-05
SGB103,SGB103,6.220864,6.058197,3.336119,7.028358,10.585193,7.374783,5.469805,8.262121,3.451959,...,3.344646,7.991148,6.376491,4.813945,5.59397,7.400267,3.140161,4.628417,4.166448,6.480963


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 47 entries, 0 to 46
Data columns (total 27 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Mouse_ID             47 non-null     int64  
 1   Sample_ID            47 non-null     object 
 2   Sex                  47 non-null     object 
 3   DOB                  47 non-null     object 
 4   Dame                 45 non-null     object 
 5   Sire                 45 non-null     float64
 6   Parents              47 non-null     object 
 7   Litter               47 non-null     object 
 8   Weaned               47 non-null     object 
 9   Experiement_Start    47 non-null     object 
 10  Starting_Age         47 non-null     int64  
 11  Days_P_Wean          47 non-null     int64  
 12  Cohort               47 non-null     object 
 13  Timepoint            47 non-null     object 
 14  Sampling_date        47 non-null     object 
 15  Sampling_Age         47 non-null     int64

In [28]:
# Transpose the abundance table: Rows become Sample_IDs, Columns become MAG features
abundance = abundance.T
abundance.index.name = "Sample_ID"
abundance = abundance.reset_index()

# Convert abundance values to numeric (force non-numeric values to NaN and handle appropriately)
abundance.iloc[:, 1:] = abundance.iloc[:, 1:].apply(pd.to_numeric, errors='coerce')

# Remove 1st line
abundance = abundance.iloc[1:]

display(abundance.head())
print(abundance.info())

Unnamed: 0,Sample_ID,SGB001,SGB002,SGB012,SGB035,SGB103,SGB086,SGB152,SGB003,SGB011,...,SGB127,SGB065,SGB112,SGB077,SGB078,SGB061,SGB098,SGB017,SGB158,SGB146
1,2028_T1,1.653143,2.462126,1.551652,0.09881,6.220864,2.82632,0.583152,1.403031,5.083632,...,0.150603,9e-06,0.08632,0.0,0.087762,0.248802,0.0,0.0,0.0,0.0
2,2028_T5,0.232005,0.473809,1.108049,0.152776,6.058197,1.599059,0.684653,1.301366,1.742463,...,0.263139,0.076161,0.163405,8e-06,0.196704,0.0,0.522487,1e-05,3.461532,1.4e-05
3,2029_T1,0.081545,0.406443,0.176861,0.0,3.336119,0.351705,0.083235,0.283863,0.492405,...,0.893146,0.095887,0.330952,0.173563,0.242426,0.518093,1.3e-05,0.0,0.0,0.0
4,2029_T5,0.153245,0.709397,1.4217,0.126336,7.028358,2.16906,0.66416,1.971886,1.253085,...,0.281828,0.067306,0.262906,0.0,0.093791,1.1e-05,0.543518,0.0,3.097634,0.0
5,594_T1,1.5e-05,0.0,0.845357,0.130409,10.585193,1.411901,0.381543,1.830642,5.654489,...,0.108834,0.090903,0.154122,0.052562,0.232023,0.734222,1e-05,0.0,0.0,0.0


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 181 entries, 1 to 181
Columns: 173 entries, Sample_ID to SGB146
dtypes: object(173)
memory usage: 244.8+ KB
None


In [29]:
# Select relevant columns from phenotype metadata
phenotype = phenotype[["Sample_ID", "Treatment"] + args.phenotype_vars]  # Dynamic selection of phenotype variables

display(phenotype.head())

Unnamed: 0,Sample_ID,Treatment,Center_occupancy,Grooming_duration,Social_preference
0,2028_T1,Pair_H2O,344.533333,1,0.093117
1,2029_T1,Pair_H2O,121.866667,65,0.446575
2,594_T1,Pair_H2O,195.866667,14,0.895172
3,595_T1,Pair_H2O,201.933333,88,0.474453
4,596_T1,Pair_TMT,216.566667,17,0.87976


In [32]:
# Merge datasets on Sample_ID
merged = pd.merge(phenotype, abundance, on="Sample_ID")

display(merged)

Unnamed: 0,Sample_ID,Treatment,Center_occupancy,Grooming_duration,Social_preference,SGB001,SGB002,SGB012,SGB035,SGB103,...,SGB127,SGB065,SGB112,SGB077,SGB078,SGB061,SGB098,SGB017,SGB158,SGB146
0,2028_T1,Pair_H2O,344.533333,1,0.093117,1.653143,2.462126,1.551652,0.09881,6.220864,...,0.150603,9e-06,0.08632,0.0,0.087762,0.248802,0.0,0.0,0.0,0.0
1,2029_T1,Pair_H2O,121.866667,65,0.446575,0.081545,0.406443,0.176861,0.0,3.336119,...,0.893146,0.095887,0.330952,0.173563,0.242426,0.518093,1.3e-05,0.0,0.0,0.0
2,594_T1,Pair_H2O,195.866667,14,0.895172,1.5e-05,0.0,0.845357,0.130409,10.585193,...,0.108834,0.090903,0.154122,0.052562,0.232023,0.734222,1e-05,0.0,0.0,0.0
3,595_T1,Pair_H2O,201.933333,88,0.474453,0.0,1.5e-05,1.040065,0.055733,5.469805,...,0.141697,0.073088,0.166101,0.0,0.211688,0.77027,1.1e-05,1e-05,0.0,1.4e-05
4,596_T1,Pair_TMT,216.566667,17,0.87976,0.328732,1.6e-05,2.027241,1.8e-05,3.451959,...,0.210063,0.228806,0.14249,0.186937,0.103387,0.202915,0.0,1.027051,0.0,0.0
5,597_T1,Pair_TMT,118.633333,107,0.380531,0.215855,1.5e-05,1.298721,0.107695,7.48088,...,0.253092,0.130737,0.11882,0.10127,0.080589,0.323852,0.0,1.192089,0.0,0.0
6,598_T1,Single_TMT,175.5,59,0.626506,1.7e-05,0.0,1.155102,0.058953,4.632725,...,0.348862,0.158319,0.313817,0.176335,0.314371,0.301876,1.2e-05,1.078787,0.0,0.0
7,599_T1,Single_H2O,164.833333,0,0.68,0.112511,0.0,1.775267,0.0,5.268881,...,0.198408,0.597122,0.369971,0.166216,0.148677,0.486181,1.1e-05,1.892376,0.0,0.0
8,600_T1,Pair_TMT,164.2,31,0.654839,0.12375,0.0,1.167028,0.0,4.627308,...,0.291782,0.239598,0.588204,0.184423,0.472715,0.382977,1.1e-05,1e-05,1e-05,1.4e-05
9,601_T1,Pair_TMT,118.366667,12,0.385124,0.233345,1.6e-05,1.369058,0.094255,4.930521,...,0.192376,0.223241,0.390331,0.198744,0.228945,0.256814,1.1e-05,0.0,1e-05,0.0


## Random Forest

In [39]:
# Define the phenotype variables to analyze (from command line argument)
phenotype_vars = args.phenotype_vars

# Initialize results dataframe
results_df = pd.DataFrame(columns=["Phenotype", "Perm_test_R2", "p_value"])
importance_df = pd.DataFrame(columns=["Phenotype", "MAG_ID", "Importance"])
plot_df = pd.DataFrame(columns=["Phenotype", "Treatment", "Predicted", "Actual"])

# Define cross-validation strategy
cv = KFold(n_splits=5, shuffle=True, random_state=42)

In [43]:
# Loop through each phenotype variable
phenotype_var = "Grooming_duration"

for phenotype_var in phenotype_vars:
    print(f"Processing {phenotype_var}...")

    # Define the target variable
    y = merged[phenotype_var]

    # One-hot encode Treatment as a covariate
    covariates = pd.get_dummies(merged["Treatment"], drop_first=True)

    # Residualize the phenotype variable against Treatment
    cov_model = LinearRegression()
    cov_model.fit(covariates, y)
    y_pred_cov = cov_model.predict(covariates)
    y_res = y - y_pred_cov  # Residualized target variable

    # Define the abundance predictors (excluding phenotype metadata columns)
    phenotype_cols = {"Sample_ID", "Treatment"} | set(phenotype_vars)
    abundance_features = [col for col in merged.columns if col not in phenotype_cols]
    X_abundance = merged[abundance_features]

    # Define the Random Forest model
    rf_model = RandomForestRegressor(n_estimators=20, random_state=42, n_jobs=-1)

    # Get cross-validated predictions instead of using training predictions
    y_pred_cv = cross_val_predict(rf_model, X_abundance, y_res, cv=cv)

    # Perform permutation test using cross-validation
    score, permutation_scores, pvalue = permutation_test_score(
        rf_model, X_abundance, y_res, cv=cv, n_permutations=1000, scoring="r2", random_state=42
    )

    # Store results
    new_row = pd.DataFrame({
        "Phenotype": [phenotype_var],
        "Perm_test_R2": [score],
        "p_value": [pvalue]
    })
    results_df = pd.concat([results_df, new_row], ignore_index=True)
    
    #display(results_df.head())


    ## Find most important features

    # Train the Random Forest model on the full dataset
    rf_model.fit(X_abundance, y_res)

    # Store results
    new_imp = pd.DataFrame({
        "Phenotype": [phenotype_var] * len(X_abundance.columns),
        "MAG_ID": X_abundance.columns,
        "Importance": rf_model.feature_importances_
    }).sort_values(by="Importance", ascending=False)

    importance_df = pd.concat([importance_df, new_imp], ignore_index=True)
    
    #display(importance_df.head())


    ### Predicted vs. Actual Plot (Cross-Validation) with Regression Line & p-value
    
    # Add predicted values and residuals to a DataFrame for plotting
    new_plot = pd.DataFrame({
        "Phenotype": [phenotype_var] * len(y_pred_cv),
        "Treatment": merged["Treatment"],
        "Predicted": y_pred_cv,
        "Actual": y_res
    })
    plot_df = pd.concat([plot_df, new_plot], ignore_index=True)

    # Linear regression stats
    slope, intercept, r_value, p_value, std_err = linregress(y_pred_cv, y_res)

    # Plot
    plt.figure(figsize=(4, 4))
    sns.scatterplot(data=new_plot, x="Predicted", y="Actual", hue="Treatment", palette=treatment_color_map, alpha=0.6)
    sns.lineplot(x=new_plot["Predicted"], y=intercept + slope * new_plot["Predicted"], color="red", label="Regression Line")

    plt.text(
        x=new_plot["Predicted"].min() + (new_plot["Predicted"].max() - new_plot["Predicted"].min()) * 0.05,
        y=new_plot["Actual"].max() - (new_plot["Actual"].max() - new_plot["Actual"].min()) * 0.1,
        s=f"p = {p_value:.3g}",
        fontsize=8,
        color="black"
    )

    plt.xlabel(f"Predicted Residualized {phenotype_var}", fontsize=8)
    plt.ylabel(f"Actual Residualized {phenotype_var}", fontsize=8)
    plt.xticks(fontsize=8 * 0.33)
    plt.yticks(fontsize=8 * 0.33)
    plt.title(f"Predicted vs. Actual Plot (Cross-Validation) - {phenotype_var}", fontsize=8 * 0.33)
    plt.legend(title="Treatment", fontsize=8 * 0.33, title_fontsize=8 * 0.33)
    plt.savefig(os.path.join(output_dir, f"Predicted_vs_actual_{phenotype_var}_{timepoint}.pdf"), format="pdf")
    plt.show()


    ### Residual Plot (Cross-Validation)
    residuals = y_res - y_pred_cv
    plt.figure(figsize=(4, 4))
    plt.scatter(y_pred_cv, residuals, alpha=0.6, color="green", label="Residuals (CV)")
    plt.axhline(0, color="red", linestyle="--")
    plt.xlabel(f"Predicted Residualized {phenotype_var}", fontsize=8)
    plt.ylabel("Residuals", fontsize=8)
    plt.xticks(fontsize=8 * 0.33)
    plt.yticks(fontsize=8 * 0.33)
    plt.title(f"Residual Plot (Cross-Validation) - {phenotype_var}", fontsize=8 * 0.33)
    plt.legend(fontsize=8 * 0.33)
    plt.savefig(os.path.join(output_dir, f"Residual_plot_{phenotype_var}_{timepoint}.pdf"), format="pdf")  # Save as PDF
    plt.show()


    ### Feature Importance
    plt.figure(figsize=(8, 4))
    sns.barplot(data=new_imp.head(10), x="Importance", y="MAG_ID", palette="viridis")
    plt.xlabel("Feature Importance")
    plt.ylabel("SGB")
    plt.title("Top 10 Important SGBs for Prediction")
    plt.savefig(os.path.join(output_dir, f"Feature_importance_{phenotype_var}_{timepoint}.pdf"), format="pdf")  # Save as PDF
    plt.show()

Processing Center_occupancy...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=new_imp.head(10), x="Importance", y="MAG_ID", palette="viridis")


Processing Grooming_duration...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=new_imp.head(10), x="Importance", y="MAG_ID", palette="viridis")


Processing Social_preference...



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(data=new_imp.head(10), x="Importance", y="MAG_ID", palette="viridis")


In [44]:
# Save final results to CSV
results_df.to_csv(os.path.join(output_dir, f"RF_permutation_test_results_{timepoint}.csv"), index=False)
importance_df.to_csv(os.path.join(output_dir, f"RF_important_features_{timepoint}.csv"), index=False)
plot_df.to_csv(os.path.join(output_dir, f"RF_predict_v_actual_plot_{timepoint}.csv"), index=False)