# Housekeeping

## Library imports

In [None]:
if True:
    pass
    !{sys.executable} -m pip install -r requirements.txt

In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
# Core Libraries
import pandas as pd
import seaborn as sns
# Machine Learning - Core
import sklearn
from colorama import Fore, Style
# Statistical & Data Processing
from scipy import stats
# Machine Learning - Models
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import LabelEncoder
from supertree import SuperTree

# Silence all warnings for the rest of the notebook session
# (the Settings cell below applies the same filter again).
warnings.simplefilter("ignore")

# Project helper functions from functions.py are imported in the cells that use them.

## Optional Imports (for specific analyses)

These imports are used for specific visualizations and analyses that may not be needed in every run.

In [None]:
# Optional: Advanced Visualizations
# from sklearn.decomposition import PCA
# from click.format import wrap_text  # May not be needed
# import collections  # Built-in, no install needed

# Note: the project helper functions (plots, benchmarks, explainers) live in
# functions.py and are imported in the individual cells that use them below.

## Settings

In [None]:
# Notebook-wide settings:
# - scikit-learn transformers return pandas DataFrames (so column names survive),
# - warnings are silenced,
# - any colorama colour state is reset before further output.
sklearn.set_config(transform_output="pandas")
warnings.filterwarnings(action="ignore")
print(Style.RESET_ALL)

## Data imports
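The data was manually converted from the MetaPhlAn TSV output (mpa411.txt) to CSV, because `pd.read_csv` uses a comma separator by default and therefore loaded the tab-separated file as a single column. The first row, containing only "#mpa_vJun23_CHOCOPhlAnSGB_202403", was removed. (Reading the original file with `pd.read_csv(path, sep="\t", skiprows=1)` should make this manual edit unnecessary.)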

In [None]:
# Load the abundance table (one row per clade, one "mpa411_<sample>" column per
# sample) and the per-sample metadata, then check both against the known export.
data = pd.read_csv("../data/raw/MAI3004_lucki_mpa411.csv")
metadata = pd.read_csv("../data/raw/MAI3004_lucki_metadata_safe.csv")

print(
    "Data successfully imported. \n shape of data: "
    f"{data.shape} \n Shape of metadata: {metadata.shape}"
)

assert data.shape == (6903, 932), "Data has the wrong shape. Check the CSV formatting."
assert metadata.shape == (930, 6), "Metadata has the wrong shape. Check the CSV formatting."


## Function definitions
| Function Name | Description | Parameters |
|---------------|-------------|------------|


# Data preprocessing

## Merge data and metadata

In [None]:
# Reshape the abundance table from one-row-per-clade to one-row-per-sample and
# attach each sample's metadata.
sample_cols = [col for col in data.columns if col.startswith("mpa411_")]

sample_abundances = (
    data[["clade_name"] + sample_cols]
    .set_index("clade_name")
    .transpose()
    .rename_axis("original_sample_id")
    .reset_index()
    .rename(columns={"original_sample_id": "sample_id"})
)

# Strip the "mpa411_" column prefix so the IDs can be matched against metadata.
sample_abundances["sample_id"] = sample_abundances["sample_id"].str.removeprefix("mpa411_")

metadata_common = metadata[
    metadata["sample_id"].isin(sample_abundances["sample_id"])
].copy()
merged_samples = metadata_common.merge(
    sample_abundances,
    on="sample_id",
    how="inner",
)

# YOB and body_product are omitted without sample dates.
# All samples are fecal.
# TODO: should we be accounting for sex? Do statistical analysis
merged_samples.drop(columns=["year_of_birth", "body_product"], inplace=True)

print(f"Metadata rows (original): {metadata.shape[0]}")
print(f"Metadata rows with matching samples: {metadata_common.shape[0]}")
# Unmatched = all metadata rows minus the matched ones (previously this
# subtracted metadata_common from itself and always printed 0).
print(
    f"Metadata rows without matching samples: "
    f"{metadata.shape[0] - metadata_common.shape[0]}"
)
print(f"Merged dataframe shape: {merged_samples.shape}")

In [None]:
merged_samples.head(5)  # preview of the merged metadata + abundance table

## Encoding

In [None]:
# Encode sex (female=1, male=0, missing=2) and family_id (integer labels).
# Rows without an age group have no regression target and are dropped.
encoded_samples = merged_samples.copy().dropna(subset="age_group_at_sample")

encoded_samples["sex"] = (
    encoded_samples["sex"]
    .fillna("unknown")
    .replace({"female": 1, "male": 0, "unknown": 2})
    # replace() is deprecating its silent object -> int downcast; infer_objects()
    # keeps the column numeric so select_dtypes(np.number) still picks it up.
    .infer_objects()
)
encoded_samples["family_id"] = LabelEncoder().fit_transform(
    encoded_samples["family_id"]
)


In [None]:
# Map each age-group label to an approximate age in days, so the regression
# target keeps the relative spacing between sampling time points.
encoding_guide = {
    "1-2 weeks": 10,
    "4 weeks": 28,
    "8 weeks": 56,
    "4 months": 120,
    "5 months": 150,
    "6 months": 180,
    "9 months": 270,
    "11 months": 330,
    "14 months": 420,
}
# Assign the result back instead of calling `.replace(..., inplace=True)` on
# `encoded_samples["age_group_at_sample"]`. That form is chained assignment,
# which does not update the DataFrame under pandas Copy-on-Write, and it may
# leave the column object-typed. infer_objects() converts the column to a
# numeric dtype when every label was mapped.
encoded_samples["age_group_at_sample"] = (
    encoded_samples["age_group_at_sample"].replace(encoding_guide).infer_objects()
)

# consider in interpretation that the distances between the real age bins are not the same as our age groups

In [None]:
# Fallback: if any age label was not covered by the day mapping (so a value is
# still non-numeric), label-encode the whole column instead.
# The check is element-wise on the values; note that `x in DataFrame` tests
# the column labels, not the cell values.
age_values = encoded_samples["age_group_at_sample"]
non_numeric = ~age_values.map(pd.api.types.is_number)

if non_numeric.any():
    print("Unmapped age groups:", sorted(age_values[non_numeric].astype(str).unique()))
    # Cast to str so LabelEncoder can sort a column that mixes ints and strings.
    age_labels = age_values.astype(str)
    age_encoder = LabelEncoder().fit(age_labels)
    encoded_samples["age_group_at_sample"] = age_encoder.transform(age_labels)

    age_groups = dict(
        zip(age_encoder.classes_, age_encoder.transform(age_encoder.classes_))
    )
    print("Age group encoding:", age_groups)

else:
    print("Fallback encoding not needed")

## Missing check

In [None]:
# Count and percentage of missing values per column, keeping only columns that
# have at least one missing value (largest count first).
missing_table = (
    encoded_samples.isna()
    .sum()
    .to_frame(name="missing_count")
    .assign(
        missing_percent=lambda df: (
            (df["missing_count"] / encoded_samples.shape[0] * 100).round(2)
        ),
    )
    .reset_index()
    .rename(columns={"index": "column"})
    .sort_values("missing_count", ascending=False)
    .query("missing_count != 0")
)

# Jupyter only renders a cell's final top-level expression, so a bare
# `missing_table` inside this `if` was never shown; print it explicitly.
if not missing_table.empty:
    print(missing_table.to_string(index=False))
else:
    print("No missing values detected.")

## Outlier check

In [None]:
# Tukey-fence outlier screen: in every numeric column, flag values more than
# 1.5 * IQR below the first quartile or above the third quartile, then list the
# columns with any flagged values (highest share first).
numeric_cols = encoded_samples.select_dtypes(include=[np.number]).columns
numeric_part = encoded_samples[numeric_cols]

first_quartile = numeric_part.quantile(0.25)
third_quartile = numeric_part.quantile(0.75)
fence_width = 1.5 * (third_quartile - first_quartile)

lower_bounds = first_quartile - fence_width
upper_bounds = third_quartile + fence_width

is_outlier = numeric_part.lt(lower_bounds) | numeric_part.gt(upper_bounds)
outlier_counts = is_outlier.sum()
n_rows = encoded_samples.shape[0]

outlier_table = (
    pd.DataFrame(
        {
            "column": numeric_cols,
            "lower_bound": lower_bounds,
            "upper_bound": upper_bounds,
            "outlier_count": outlier_counts,
            "outlier_percent": (outlier_counts / n_rows * 100).round(2),
        }
    )
    .query("outlier_count > 0")
    .sort_values("outlier_percent", ascending=False)
    .reset_index(drop=True)
)

outlier_table

## Normalisation check

In [None]:
# Per-column Shapiro-Wilk test on the untransformed numeric columns; p > 0.05
# means normality cannot be rejected. `normalized_samples` is the copy that the
# train/test split below works from.
normalized_samples = encoded_samples.copy()
print("Shapiro-Wilk Normality Test")

for column in numeric_cols:
    _, p_value = stats.shapiro(normalized_samples[column].dropna())
    if p_value > 0.05:
        print(Fore.GREEN + f"{column}: Normally Distributed (p={p_value:.4f})")
    else:
        print(Fore.RED + f"{column}: Not Normally Distributed (p={p_value:.4f})")

print(Style.RESET_ALL)

## Train-test split before pre-processing

In [None]:
from sklearn.model_selection import GroupShuffleSplit

# Every column except the sample identifier and the target becomes a feature.
# family_id stays in X because it is needed as the grouping key below.
feature_cols = normalized_samples.columns.difference(["sample_id", "age_group_at_sample"])
X = normalized_samples[feature_cols]
Y = normalized_samples["age_group_at_sample"]

# Group by family so that samples sharing a family_id never end up on both
# sides of the split.
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=3004)
train_idx, test_idx = next(splitter.split(X, Y, groups=X["family_id"]))

X_train_raw, X_test_raw = X.iloc[train_idx], X.iloc[test_idx]
Y_train, Y_test = Y.iloc[train_idx], Y.iloc[test_idx]

assert X_train_raw.shape[1] == X_test_raw.shape[1], "Feature columns do not match between train and test sets."
assert X_train_raw.shape[0] == Y_train.shape[0] and X_test_raw.shape[0] == Y_test.shape[0], "X and Y do not have the same length."

print("Train shape:", X_train_raw.shape, "| Test shape:", X_test_raw.shape)

## Normalising data using clr transformation

In [None]:
"""The CLR function based on: https://medium.com/@nextgendatascientist/a-guide-for-data-scientists-log-ratio-transformations-in-machine-learning-a2db44e2a455"""

def clr_transform(X, epsilon=1e-9):
    """
    Apply the centred log-ratio (CLR) transform to each row of X.

    Each row (the last axis) is treated as one composition. Zeros are replaced
    by ``epsilon`` because log(0) is undefined. Then every value is divided by
    the row's geometric mean and logged, so each output row sums to zero.

    Parameters
    ----------
    X : pandas.DataFrame or array-like
        Non-negative values, shape (n_samples, n_features); a 1-D input is
        treated as a single composition.
    epsilon : float, default 1e-9
        Replacement for zero entries (tunable).

    Returns
    -------
    pandas.DataFrame or numpy.ndarray
        A DataFrame with X's index and columns when X is a DataFrame,
        otherwise a float ndarray of the same shape as X.
    """
    is_frame = isinstance(X, pd.DataFrame)

    # Convert to a float ndarray first. For a plain list, `X == 0` is a single
    # bool rather than an element-wise mask, so zeros were never replaced.
    X_arr = np.asarray(X.values if is_frame else X, dtype=float)

    # 1. Replace zeros with epsilon
    X_replaced = np.where(X_arr == 0, epsilon, X_arr)

    # 2-3. log(x / geometric_mean) == log(x) - mean(log(x)) along each row
    log_X = np.log(X_replaced)
    X_clr = log_X - log_X.mean(axis=-1, keepdims=True)

    # Rebuild a DataFrame with the original labels
    if is_frame:
        return pd.DataFrame(X_clr, index=X.index, columns=X.columns)
    return X_clr

In [None]:
# CLR is a compositional transform, so apply it only to the abundance columns.
# The encoded metadata still present in X (family_id, sex) would otherwise
# enter each sample's geometric mean and shift all of that sample's taxa by a
# family/sex-dependent offset. The metadata columns are kept untransformed.
non_taxa_cols = ["family_id", "sex"]


def _clr_taxa_only(frame):
    """Return a copy of *frame* with only its taxon columns CLR-transformed."""
    taxa = frame.columns.difference(non_taxa_cols)
    transformed = clr_transform(frame[taxa])
    return pd.concat([frame.drop(columns=taxa), transformed], axis=1)[frame.columns]


X_train = _clr_taxa_only(X_train_raw)
X_test = _clr_taxa_only(X_test_raw)

In [None]:
# To use the raw (untransformed) data instead of the CLR-transformed data, uncomment the two lines below:

#X_train = X_train_raw
#X_test = X_test_raw

In [None]:
# Re-run Shapiro-Wilk on every column after the CLR transform, separately for
# the train and test sets. A column counts as normal only if p > 0.05 in both.
print("Shapiro-Wilk Normality Test (after CLR normalisation)\n")

for col in X_train.columns:
    _, p_train = stats.shapiro(X_train[col].dropna())
    _, p_test = stats.shapiro(X_test[col].dropna())

    is_train_normal = p_train > 0.05
    is_test_normal = p_test > 0.05
    p_summary = f"(Train p={p_train:.3f}, Test p={p_test:.3f})"

    if is_train_normal and is_test_normal:
        # normal in both sets
        print(Fore.GREEN + f"✔ {col}: Normal {p_summary}")
    elif is_train_normal or is_test_normal:
        # normal in exactly one set
        print(Fore.YELLOW + f"⚠ {col}: Inconsistent {p_summary}")
    else:
        # not normal in either set
        print(Fore.RED + f"✘ {col}: Not Normal {p_summary}")

print(Style.RESET_ALL)

# Exploratory data analysis

In [None]:
# Shape and first rows of the merged table (before encoding / row filtering).
print(merged_samples.shape)
merged_samples.head(n=5)

In [None]:
# Dataset overview: number of samples, number of distinct families, and total
# width (metadata + abundance features) of the merged table.
print("Number of samples:", merged_samples.shape[0])
print("Number of unique families (family_id):", merged_samples["family_id"].nunique())
print("Number of columns (metadata + features):", merged_samples.shape[1])

In [None]:
# Samples per family
samples_per_family = merged_samples["family_id"].value_counts()
samples_per_family.describe()

In [None]:
# Histogram of the number of samples per family.
ax = samples_per_family.hist(bins=20)
ax.set_xlabel("Number of samples per family")
ax.set_ylabel("Number of families")
ax.set_title("Distribution of samples per family")
plt.show()

In [None]:
# Distribution of age groups, including missing labels (NaN).
merged_samples.loc[:, "age_group_at_sample"].value_counts(dropna=False)

In [None]:
# Bar chart of samples per age group (bars ordered by count; NaN excluded).
ax = merged_samples["age_group_at_sample"].value_counts().plot.bar()
ax.set_title("Distribution of age groups")
ax.set_ylabel("Number of samples")
plt.xticks(rotation=45)
plt.show()

In [None]:
# Dimensionality and sparsity of the microbiome feature matrix.
# NOTE: this rebinds the notebook-level `feature_cols` and `X` that the
# train/test split cell created; the modelling cells use X_train/X_test instead.
metadata_cols = [
    "sample_id",
    "family_id",
    "sex",
    "body_product",
    "age_group_at_sample",
    "year_of_birth",
]
feature_cols = merged_samples.columns[~merged_samples.columns.isin(metadata_cols)].tolist()

X = merged_samples[feature_cols]

print("Feature matrix shape:", X.shape)
print("Overall fraction of zeros:", X.eq(0).mean().mean())

In [None]:
# Number of taxa with non-zero abundance in each sample.
nonzero_per_sample = X.gt(0).sum(axis=1)
nonzero_per_sample.describe()

In [None]:
# Histogram of detected taxa per sample.
ax = nonzero_per_sample.hist(bins=50)
ax.set_xlabel("Number of non-zero taxa per sample")
ax.set_ylabel("Number of samples")
ax.set_title("Non-zero taxa per sample")
plt.show()

In [None]:
# Sum of all feature values in each sample (sanity check).
total_abundance = X.sum(axis="columns")
total_abundance.describe()

In [None]:
# Histogram of the per-sample totals.
ax = total_abundance.hist(bins=50)
ax.set_xlabel("Total abundance per sample")
ax.set_ylabel("Number of samples")
ax.set_title("Total microbial abundance per sample")
plt.show()

In [None]:
# Prevalence: in how many samples each taxon is present (non-zero).
feature_prevalence = X.gt(0).sum(axis="index")
feature_prevalence.describe()

In [None]:
# Histogram of taxon prevalence.
ax = feature_prevalence.hist(bins=50)
ax.set_xlabel("Number of samples in which taxon is present")
ax.set_ylabel("Number of taxa")
ax.set_title("Feature prevalence distribution")
plt.show()

In [None]:
# Histogram of log10 of the strictly positive abundance values
# (zeros are excluded because log10(0) is undefined).
abundance_values = X.to_numpy()
nonzero_values = abundance_values[abundance_values > 0]

plt.hist(np.log10(nonzero_values), bins=50)
plt.xlabel("log10(abundance)")
plt.ylabel("Frequency")
plt.title("Distribution of non-zero abundances (log10 scale)")
plt.show()

In [None]:
# PCA visualization
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Use the 500 most prevalent taxa (present in the most samples) to keep it fast.
prevalence = (X > 0).sum(axis=0)
top_features = prevalence.sort_values(ascending=False).head(500).index

X_sub = X[top_features]

# Scale features
X_scaled = StandardScaler().fit_transform(X_sub)

# PCA (np.asarray works whether sklearn returns a DataFrame or an ndarray)
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
pcs = np.asarray(X_pca)

# Colour by age in chronological order. pd.factorize numbers labels by first
# appearance, which made the colour scale arbitrary. Use the ordered labels of
# `encoding_guide` instead; labels not in it (or missing) get code -1.
age = merged_samples["age_group_at_sample"]
age_order = list(encoding_guide)
age_codes = pd.Categorical(age, categories=age_order, ordered=True).codes

plt.figure(figsize=(8, 6))
plt.scatter(pcs[:, 0], pcs[:, 1], c=age_codes, cmap="viridis", alpha=0.6)
cbar = plt.colorbar(label="Age group")
cbar.set_ticks(range(len(age_order)))
cbar.set_ticklabels(age_order)
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("PCA of samples colored by age group")
plt.show()

print("Explained variance ratio:", pca.explained_variance_ratio_)

### Summary of EDA
The dataset consists of 930 stool samples derived from multiple individuals across different families and contains approximately 6,900 microbiome features, making it a high-dimensional and highly sparse dataset. Each sample contains on average around 300 detected taxa. The total abundance per sample is nearly constant, but this is expected by construction: MetaPhlAn reports relative abundances (each taxonomic level sums to roughly 100%). This check therefore confirms that the profiles are normalised; it does not indicate that sequencing depth is consistent across samples.
Most taxa are rare and occur in only a small fraction of samples, whereas a small subset of taxa is highly prevalent across the cohort. The distribution of non-zero abundances follows an approximately log-normal shape, which is typical for microbiome sequencing data (e.g., Lutz et al., 2022).
A PCA projection based on the most prevalent taxa does not reveal sharply separated clusters but shows a gradual age-related gradient, suggesting that age-related variation in microbiome composition is present but represents only a limited fraction of the total variance in the data.

# Model Training

## Filtering for features at the genus level

In [None]:
def filter_genus(df_uf):
    """
    Return only the genus-level feature columns of *df_uf*.

    A column is kept when its name contains the genus marker ``g__`` and does
    not contain the species marker ``s__``. Names without ``g__`` are dropped,
    e.g. higher-rank taxa or metadata columns such as ``family_id``. The input
    frame is left unmodified and column order is preserved.
    """
    keep = [("g__" in str(col)) and ("s__" not in str(col)) for col in df_uf.columns]
    return df_uf.loc[:, keep]

In [None]:
# Keep only genus-level taxa for both sets (metadata columns drop out too).
X_train_genus, X_test_genus = filter_genus(X_train), filter_genus(X_test)

In [None]:
X_train_genus.columns.isin(["target_names"]).any()  # sanity check: no stray target column

In [None]:
# List the genus-level features the models are trained on.
print("Features used by model:")
print(list(X_train_genus.columns))

## Random Forest Regressor with Train/Test split (Genus)

### Base model

In [None]:
# Base model: default Random Forest, evaluated once on the held-out test set.
rf_base = RandomForestRegressor(
    random_state=42,
    n_jobs=-1,
    oob_score=True,
)

rf_base.fit(X_train_genus, Y_train)
yhat_rf = rf_base.predict(X_test_genus)

mse_rf = mean_squared_error(Y_test, yhat_rf)
print(f"Mean Squared Error: {mse_rf:.3f}")
# RMSE on the held-out test set; no cross-validation happens in this cell.
print(f"Test RMSE: {np.sqrt(mse_rf):.3f}")
print(f"R2 Score: {r2_score(Y_test, yhat_rf):.3f}")
# oob_score=True was requested above; report the out-of-bag R2 from training.
print(f"OOB R2 Score: {rf_base.oob_score_:.3f}")


### Search for the best model

In [None]:
from functions import random_forest_benchmark

# Random Forest hyperparameter search (implemented in functions.py).
rf_results = random_forest_benchmark(
    X_train_genus, X_test_genus, Y_train, Y_test, label="Genus Level",
)

print(f"\nBest hyperparameters: {rf_results.best_params}")

In [None]:
best_rf_model = rf_results.model

# `.rmse` comes from random_forest_benchmark; squaring it recovers the MSE.
# NOTE(review): confirm in functions.py whether this RMSE is cross-validated or
# computed on the test set — the label below assumes CV.
print(f"Mean Squared Error: {rf_results.rmse ** 2:.3f}")
print(f"Best CV RMSE: {rf_results.rmse:.3f}")
print(f"R2 Score: {rf_results.r2:.3f}")

yhat = best_rf_model.predict(X_test_genus)

In [None]:
# Predicted vs actual on the test set; the dashed red diagonal marks perfect predictions.
diagonal = [Y_test.min(), Y_test.max()]

plt.figure(figsize=(6, 6))
sns.scatterplot(x=Y_test, y=yhat, alpha=0.6)
plt.plot(diagonal, diagonal, color="red", linestyle="--")

plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.title("Predicted vs Actual (Genus-level RF)")
plt.tight_layout()
plt.show()

In [None]:
# Residuals (actual - predicted) against predictions; a flat band around 0 is ideal.
residuals = Y_test - yhat

plt.figure(figsize=(6, 4))
ax = sns.scatterplot(x=yhat, y=residuals, alpha=0.6)
ax.axhline(0, linestyle="--", color="red")

ax.set_xlabel("Predicted values")
ax.set_ylabel("Residuals")
ax.set_title("Residuals vs Predictions")
plt.tight_layout()
plt.show()

In [None]:
# Rank genera by the tuned forest's impurity-based importances and plot the top 20.
top_n = 20
importances = pd.Series(
    best_rf_model.feature_importances_, index=X_train_genus.columns
).sort_values(ascending=False)
top_importances = importances.head(top_n)

plt.figure(figsize=(8, 6))
sns.barplot(x=top_importances, y=top_importances.index)

plt.xlabel("Feature importance")
plt.ylabel("Genus")
plt.title(f"Top {top_n} most important genera")
plt.tight_layout()
plt.show()

In [None]:
# Interactive view of the first tree (index 0) of the tuned forest.
st = SuperTree(best_rf_model, X_train_genus, Y_train)
st.show_tree(which_tree=0)

In [None]:
# Per-tree predictions on the test set: their mean is the forest prediction,
# and their per-sample standard deviation serves as an uncertainty band.
all_tree_preds = np.array([tree.predict(X_test_genus) for tree in best_rf_model.estimators_])
rf_pred = all_tree_preds.mean(axis=0)
rf_std = all_tree_preds.std(axis=0)

plt.figure(figsize=(10, 6))

# One grey marker series per tree (matplotlib draws one line per column of y).
plt.plot(Y_test.values, all_tree_preds.T, "o", color="lightgray", alpha=0.3)

# Forest mean prediction and its ±1 std spread across trees
plt.scatter(Y_test, rf_pred, color="blue", label="RF mean prediction", s=40)
plt.errorbar(Y_test, rf_pred, yerr=rf_std, fmt="o", color="red", alpha=0.5, label="±1 std across trees")

diagonal = [Y_test.min(), Y_test.max()]
plt.plot(diagonal, diagonal, color="black", linestyle="--", label="Perfect prediction")

plt.xlabel("Actual Age Group")
plt.ylabel("Predicted Age Group")
plt.title("Random Forest – Forest plot of tree predictions")
plt.legend()
plt.tight_layout()
plt.show()


# Alternative Models

## XGBoost Alternative

In [None]:
from xgboost import XGBRegressor
import re

# Base model: default XGBoost regressor, evaluated once on the held-out test set.
xgb_base = XGBRegressor(
    objective="reg:squarederror",
    random_state=42,
    n_jobs=-1,
)

# XGBoost refuses feature names containing characters such as "[", "]" or "<",
# so keep only letters, digits and underscores in the taxon names.
_unsafe_name_chars = re.compile(r"[^A-Za-z0-9_]+")


def _xgb_safe_columns(frame):
    """Return a copy of *frame* whose column names are safe for XGBoost."""
    cleaned = frame.copy()
    cleaned.columns = [_unsafe_name_chars.sub("", str(col)) for col in frame.columns]
    return cleaned


X_train_clean = _xgb_safe_columns(X_train_genus)
X_test_clean = _xgb_safe_columns(X_test_genus)

xgb_base.fit(X_train_clean, Y_train)
yhat_xgb = xgb_base.predict(X_test_clean)

mse_xgb = mean_squared_error(Y_test, yhat_xgb)
print(f"Mean Squared Error: {mse_xgb:.3f}")
# RMSE on the held-out test set; no cross-validation happens in this cell.
print(f"Test RMSE: {np.sqrt(mse_xgb):.3f}")
print(f"R2 Score: {r2_score(Y_test, yhat_xgb):.3f}")

# Hyperparameters found by an earlier search (not applied to xgb_base above):
# {'subsample': 0.7, 'reg_lambda': 1.0, 'reg_alpha': 0.1, 'n_estimators': 1000, 'max_depth': 3, 'learning_rate': 0.01, 'colsample_bytree': 0.2}

### Best XGBoost Parameters Search

## Additional Ensemble Methods

### AdaBoost

In [None]:
from functions import adaboost_benchmark

# AdaBoost hyperparameter search (implemented in functions.py).
adaboost_results = adaboost_benchmark(
    X_train_genus, X_test_genus, Y_train, Y_test, label="Genus Level",
)

print(f"\nBest hyperparameters: {adaboost_results.best_params}")

In [None]:
best_adaboost_model = adaboost_results.model

# `.rmse` comes from adaboost_benchmark; MSE = RMSE ** 2.
# NOTE(review): the "Best CV RMSE" label assumes a cross-validated RMSE — confirm in functions.py.
print(f"Mean Squared Error: {adaboost_results.rmse ** 2:.3f}")
print(f"Best CV RMSE: {adaboost_results.rmse:.3f}")
print(f"R2 Score: {adaboost_results.r2:.3f}")

yhat_ada = best_adaboost_model.predict(X_test_genus)

In [None]:
# Predicted vs actual on the test set; the dashed red diagonal marks perfect predictions.
diagonal = [Y_test.min(), Y_test.max()]

plt.figure(figsize=(6, 6))
sns.scatterplot(x=Y_test, y=yhat_ada, alpha=0.6)
plt.plot(diagonal, diagonal, color="red", linestyle="--")

plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.title("Predicted vs Actual (Genus-level AdaBoost)")
plt.tight_layout()
plt.show()

In [None]:
# Rank genera by AdaBoost's feature importances and plot the top 20.
top_n = 20
importances_ada = pd.Series(
    best_adaboost_model.feature_importances_, index=X_train_genus.columns
).sort_values(ascending=False)
top_ada = importances_ada.head(top_n)

plt.figure(figsize=(8, 6))
sns.barplot(x=top_ada, y=top_ada.index)

plt.xlabel("Feature importance")
plt.ylabel("Genus")
plt.title(f"Top {top_n} most important genera (AdaBoost)")
plt.tight_layout()
plt.show()

### Gradient Boosting

In [None]:
from functions import gradient_boosting_benchmark

# Gradient Boosting hyperparameter search (implemented in functions.py).
gb_results = gradient_boosting_benchmark(
    X_train_genus, X_test_genus, Y_train, Y_test, label="Genus Level",
)

print(f"\nBest hyperparameters: {gb_results.best_params}")

In [None]:
best_gb_model = gb_results.model

# `.rmse` comes from gradient_boosting_benchmark; MSE = RMSE ** 2.
# NOTE(review): the "Best CV RMSE" label assumes a cross-validated RMSE — confirm in functions.py.
print(f"Mean Squared Error: {gb_results.rmse ** 2:.3f}")
print(f"Best CV RMSE: {gb_results.rmse:.3f}")
print(f"R2 Score: {gb_results.r2:.3f}")

yhat_gb = best_gb_model.predict(X_test_genus)

In [None]:
# Predicted vs actual on the test set; the dashed red diagonal marks perfect predictions.
diagonal = [Y_test.min(), Y_test.max()]

plt.figure(figsize=(6, 6))
sns.scatterplot(x=Y_test, y=yhat_gb, alpha=0.6)
plt.plot(diagonal, diagonal, color="red", linestyle="--")

plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.title("Predicted vs Actual (Genus-level Gradient Boosting)")
plt.tight_layout()
plt.show()

In [None]:
# Rank genera by Gradient Boosting's feature importances and plot the top 20.
top_n = 20
importances_gb = pd.Series(
    best_gb_model.feature_importances_, index=X_train_genus.columns
).sort_values(ascending=False)
top_gb = importances_gb.head(top_n)

plt.figure(figsize=(8, 6))
sns.barplot(x=top_gb, y=top_gb.index)

plt.xlabel("Feature importance")
plt.ylabel("Genus")
plt.title(f"Top {top_n} most important genera (Gradient Boosting)")
plt.tight_layout()
plt.show()

### LightGBM

In [None]:
from functions import lightgbm_benchmark

# LightGBM hyperparameter search (implemented in functions.py).
lgb_results = lightgbm_benchmark(
    X_train_genus, X_test_genus, Y_train, Y_test, label="Genus Level",
)

print(f"\nBest hyperparameters: {lgb_results.best_params}")

In [None]:
best_lgb_model = lgb_results.model

# `.rmse` comes from lightgbm_benchmark; MSE = RMSE ** 2.
# NOTE(review): the "Best CV RMSE" label assumes a cross-validated RMSE — confirm in functions.py.
print(f"Mean Squared Error: {lgb_results.rmse ** 2:.3f}")
print(f"Best CV RMSE: {lgb_results.rmse:.3f}")
print(f"R2 Score: {lgb_results.r2:.3f}")

yhat_lgb = best_lgb_model.predict(X_test_genus)

In [None]:
# Predicted vs actual on the test set; the dashed red diagonal marks perfect predictions.
diagonal = [Y_test.min(), Y_test.max()]

plt.figure(figsize=(6, 6))
sns.scatterplot(x=Y_test, y=yhat_lgb, alpha=0.6)
plt.plot(diagonal, diagonal, color="red", linestyle="--")

plt.xlabel("Actual Values")
plt.ylabel("Predicted Values")
plt.title("Predicted vs Actual (Genus-level LightGBM)")
plt.tight_layout()
plt.show()

In [None]:
# Rank genera by LightGBM's feature importances and plot the top 20.
# NOTE(review): LightGBM's default importance_type is "split" (split counts, not
# gain); confirm which one lightgbm_benchmark configures before comparing models.
top_n = 20
importances_lgb = pd.Series(
    best_lgb_model.feature_importances_, index=X_train_genus.columns
).sort_values(ascending=False)
top_lgb = importances_lgb.head(top_n)

plt.figure(figsize=(8, 6))
sns.barplot(x=top_lgb, y=top_lgb.index)

plt.xlabel("Feature importance")
plt.ylabel("Genus")
plt.title(f"Top {top_n} most important genera (LightGBM)")
plt.tight_layout()
plt.show()

## Model Interpretability with LIME and SHAP

### LIME (Local Interpretable Model-agnostic Explanations)

In [None]:
from functions import explain_with_lime

# LIME explanations of the tuned Random Forest for a few test samples
# (num_samples / num_features are passed through to explain_with_lime).
genus_feature_names = X_train_genus.columns.tolist()
lime_explanations = explain_with_lime(
    model=best_rf_model,
    X_train=X_train_genus,
    X_test=X_test_genus,
    feature_names=genus_feature_names,
    num_samples=3,
    num_features=10,
)

print("LIME Explanations for Random Forest Model:")
print(f"Number of samples explained: {len(lime_explanations)}")

In [None]:
# Print every LIME explanation: predicted vs actual value, followed by the first
# ten (feature, weight) pairs of that sample's explanation.
for sample_idx, explanation in lime_explanations.items():
    print(f"\n{'=' * 80}")
    print(f"Sample {sample_idx}:")
    print(f"Predicted value: {explanation['predicted_value']:.3f}")
    print(f"Actual value: {explanation['actual_value']:.3f}")
    print("\nTop contributing features:")
    for feature, weight in explanation["explanation"][:10]:
        print(f"  {feature}: {weight:+.4f}")

### SHAP (SHapley Additive exPlanations)

In [None]:
from functions import explain_with_shap

# SHAP values for the tuned Random Forest on the test set, using 100
# background samples (passed through to explain_with_shap).
shap_result = explain_with_shap(
    model=best_rf_model,
    X_train=X_train_genus,
    X_test=X_test_genus,
    feature_names=list(X_train_genus.columns),
    background_samples=100,
)

print("SHAP Analysis Complete")
print(f"SHAP values shape: {shap_result['shap_values'].shape}")

In [None]:
import shap

# SHAP summary (beeswarm) plot of the 20 most influential genera.
genus_feature_names = X_train_genus.columns.tolist()
plt.figure(figsize=(10, 8))
shap.summary_plot(
    shap_result["shap_values"],
    X_test_genus,
    feature_names=genus_feature_names,
    max_display=20,
    show=False,
)
plt.tight_layout()
plt.show()

In [None]:
# Bar plot of mean absolute SHAP values for the top 20 genera.
genus_feature_names = X_train_genus.columns.tolist()
plt.figure(figsize=(10, 8))
shap.summary_plot(
    shap_result["shap_values"],
    X_test_genus,
    feature_names=genus_feature_names,
    plot_type="bar",
    max_display=20,
    show=False,
)
plt.tight_layout()
plt.show()

# Feature Selection via neural networks

## The finding

In [None]:
from functions import nn_feature_search

# The neural-network feature search is expensive and currently disabled.
# When it is skipped, `nn_res` is never defined; later cells check for that.
RUN_NN_FEATURE_SEARCH = False

if RUN_NN_FEATURE_SEARCH:
    nn_res = nn_feature_search(X_train, X_test, Y_train)

    if not nn_res:
        print("No feature set found within the target range.")
    else:
        elite_bacteria_list = nn_res.feature_names
        print(f"Selected {nn_res.n_features} elite bacterial drivers.")
        print(f"Expected Validation RMSE: {nn_res.rmse:.2f} days")

## Comparing the genus vs neural network features on all models

# Feature Cutoff Cross-Validation

Test model performance at different taxonomic levels

In [None]:
from functions import cross_validate_feature_cutoffs, plot_feature_cutoff_comparison

# Cross-validate a Random Forest on features cut at each taxonomic rank.
# NOTE(review): this uses X_train_raw (no CLR transform), unlike the genus-level
# models above, and the folds are built inside functions.py — confirm they are
# grouped by family_id so related samples do not leak across folds.
rf_levels = ["Phylum", "Class", "Order", "Family", "Genus", "Species"]
rf_cutoff_results = cross_validate_feature_cutoffs(
    X_train_raw,
    Y_train,
    feature_levels=rf_levels,
    model_type="RandomForest",
    cv_folds=5,
)

In [None]:
# Visualize the results: plot the per-rank cross-validation results returned
# by cross_validate_feature_cutoffs for the Random Forest.
plot_feature_cutoff_comparison(rf_cutoff_results, title="Random Forest Performance")

## Compare Different Models Across Taxonomic Levels

In [None]:
# Same taxonomic-rank cross-validation for XGBoost (Species level omitted).
boosting_levels = ["Phylum", "Class", "Order", "Family", "Genus"]
xgb_cutoff_results = cross_validate_feature_cutoffs(
    X_train_raw,
    Y_train,
    feature_levels=boosting_levels,
    model_type="XGBoost",
    cv_folds=5,
)

In [None]:
# Same taxonomic-rank cross-validation for LightGBM (Species level omitted).
lgbm_levels = ["Phylum", "Class", "Order", "Family", "Genus"]
lgbm_cutoff_results = cross_validate_feature_cutoffs(
    X_train_raw,
    Y_train,
    feature_levels=lgbm_levels,
    model_type="LightGBM",
    cv_folds=5,
)

# Advanced Model Visualizations

In [None]:
from functions import plot_residuals_analysis, plot_prediction_intervals, plot_learning_curves

# Residual diagnostics for the tuned Random Forest's test-set predictions (`yhat`).
plot_residuals_analysis(
    Y_test,
    yhat,
    title="Random Forest Residuals Analysis",
)

In [None]:
# Prediction intervals: rf_std is the per-sample standard deviation of the
# individual tree predictions (forest-plot cell). Only the first 50 test
# samples are shown.
first_50 = range(50)
plot_prediction_intervals(
    Y_test,
    yhat,
    prediction_std=rf_std,
    sample_indices=first_50,
    title="Random Forest Predictions with Uncertainty",
)

In [None]:
# Learning curves (cv_folds=5) for an untuned 100-tree forest on the genus-level
# features, to check for over- or under-fitting. RandomForestRegressor is
# already imported in the main imports cell.
rf_model = RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=-1)
plot_learning_curves(
    rf_model,
    X_train_genus,
    Y_train,
    cv_folds=5,
    title="Random Forest Learning Curves (Genus Level)",
)

## Model Comparison

In [None]:
from functions import plot_model_comparison_heatmap

# Compare the benchmark results produced earlier in the notebook. The names must
# match the variables assigned in the benchmark cells: `rf_results`,
# `adaboost_results`, `gb_results`, `lgb_results` (the old `ada_results` /
# `lgbm_results` never existed, so this always fell into the except branch).
try:
    model_results = {
        "Random Forest": rf_results,
        "AdaBoost": adaboost_results,
        "Gradient Boosting": gb_results,
        "LightGBM": lgb_results,
    }
except NameError:
    print("Some models haven't been run yet. Run all benchmark functions first.")
else:
    # Plot outside the try, so a NameError raised inside the plotting helper is
    # not misreported as "models haven't been run".
    plot_model_comparison_heatmap(model_results, title="Model Performance Comparison")

In [None]:
from functions import final_battle

# `nn_res` only exists if the neural-network feature search actually ran, and
# that search may produce a falsy result when no feature set is found.
try:
    nn_res
except NameError:
    print("Neural network did not run")
else:
    if nn_res:
        feature_sets = {
            "Genus Level": X_train_genus,
            "NN Level": nn_res.X_train_elite,
        }
        battle_stats = final_battle(feature_sets, Y_train)
    else:
        print("Neural network found no feature set; skipping comparison.")
