# Housekeeping

## Library imports

In [None]:
# Install the pinned project requirements into the interpreter that runs this
# kernel. `!{sys.executable}` is IPython shell-escape syntax, so this cell only
# runs inside Jupyter/IPython. Change `True` to `False` to skip reinstalling.
if True:
    import sys
    !{sys.executable} -m pip install -r requirements.txt

In [None]:
# TODO: remove unnecessary packages
import pandas as pd
import numpy as np

import sklearn
from sklearn import tree
from sklearn.preprocessing import LabelEncoder
from scipy import stats
from click.formatting import iter_rows


from sklearn.model_selection import (
    GridSearchCV,
    RepeatedStratifiedKFold,
    cross_val_score,
    train_test_split,
)
from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    RocCurveDisplay,
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_recall_curve,
    roc_auc_score,
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV

from colorama import Fore, Back, Style

import warnings

import matplotlib.pyplot as plt

import seaborn as sns

from supertree import SuperTree

## Settings

In [None]:
# Notebook-wide settings:
#  * silence all warnings,
#  * make scikit-learn transformers return DataFrames,
#  * reset any colorama colour state before the first coloured print.
warnings.simplefilter("ignore")
sklearn.set_config(transform_output="pandas")
print(Style.RESET_ALL)

## Data imports
The data file was edited manually to convert `mpa411.txt` from TSV to CSV. Without this, pandas loaded it as a single column, because `read_csv` splits on commas by default (passing `sep="\t"` would also work). The first row, which contains only `#mpa_vJun23_CHOCOPhlAnSGB_202403`, was removed.

In [None]:
# Load the hand-converted MetaPhlAn abundance table and the sample metadata.
RAW_DIR = '../data/raw/'
data = pd.read_csv(RAW_DIR + 'MAI3004_lucki_mpa411.csv')
metadata = pd.read_csv(RAW_DIR + 'MAI3004_lucki_metadata_safe.csv')

print(
    "Data successfully imported. \n shape of data: " + str(data.shape)
    + " \n Shape of metadata: " + str(metadata.shape)
)

# Guard against a badly converted CSV (e.g. everything parsed into one column).
for _frame, _expected, _label in (
    (data, (6903, 932), "Data"),
    (metadata, (930, 6), "Metadata"),
):
    assert _frame.shape == _expected, f"{_label} has the wrong shape. Check the CSV formatting."


## Function definitions
| Function Name | Description | Parameters |
|---------------|-------------|------------|


# Data preprocessing

## Merge data and metadata

In [None]:
# Abundance columns are named "mpa411_<sample_id>"; every other column
# (e.g. clade_name) describes the taxon rather than a sample.
sample_cols = [col for col in data.columns if col.startswith("mpa411_")]

# Reshape from taxa x samples to samples x taxa: one row per sample, one column
# per clade, plus a 'sample_id' column built from the old column headers.
sample_abundances = (
    data[['clade_name'] + sample_cols]
    .set_index('clade_name')
    .transpose()
    .rename_axis('sample_id')
    .reset_index()
)

# Strip the file-specific prefix so the IDs line up with metadata["sample_id"].
sample_abundances["sample_id"] = sample_abundances["sample_id"].str.removeprefix("mpa411_")

metadata_common = metadata[
    metadata["sample_id"].isin(sample_abundances["sample_id"])
].copy()
merged_samples = metadata_common.merge(
    sample_abundances,
    on="sample_id",
    how="inner",
)

# year_of_birth is not usable without sample dates; body_product is constant
# (all samples are fecal), so both are dropped.
merged_samples = merged_samples.drop(columns=['year_of_birth', 'body_product'])
# TODO: should we be accounting for sex? Do statistical analysis

print(f"Metadata rows (original): {metadata.shape[0]}")
print(f"Metadata rows with matching samples: {metadata_common.shape[0]}")
# Unmatched = all metadata rows minus the ones that found an abundance column.
print(
    f"Metadata rows without matching samples: "
    f"{metadata.shape[0] - metadata_common.shape[0]}"
)
print(f"Merged dataframe shape: {merged_samples.shape}")

In [None]:
merged_samples.head(n=5)  # preview the first five merged rows

## Encoding

In [None]:
# Sex and family_ID
# Rows without an age label cannot serve as training targets, so drop them first.
encoded_samples = merged_samples.dropna(subset="age_group_at_sample").copy()

# Sex -> integer code (male 0, female 1, missing 2).
encoded_samples["sex"] = (
    encoded_samples["sex"]
    .fillna("unknown")
    .replace({"female": 1, "male": 0, "unknown": 2})
    # pandas has deprecated the silent object->int downcast inside replace();
    # infer_objects() requests the numeric dtype explicitly.
    .infer_objects()
)
# family_id is only an identifier; map it to consecutive integers.
encoded_samples["family_id"] = LabelEncoder().fit_transform(
    encoded_samples["family_id"]
)


In [None]:
# Encode each age group as an approximate age in days, so that the numeric
# distance between groups reflects the real time between them.
encoding_guide = {
    '1-2 weeks': 10,
    '4 weeks': 28,
    '8 weeks': 56,
    '4 months': 120,
    '5 months': 150,
    '6 months': 180,
    '9 months': 270,
    '11 months': 330,
    '14 months': 420,
}
# Assign the result back. Calling replace(..., inplace=True) on a selected column
# does not update the DataFrame under pandas Copy-on-Write. infer_objects()
# yields a numeric dtype once every label is mapped.
encoded_samples["age_group_at_sample"] = (
    encoded_samples["age_group_at_sample"].replace(encoding_guide).infer_objects()
)

# consider in interpretation that the distances between the real age bins are not the same as our age groups

In [None]:
# Fallback: any age label that encoding_guide did not cover is still a string.
# In that case, label-encode the column so the regression target is numeric.
age_col = encoded_samples["age_group_at_sample"]
# Check the *values*: `False in some_dataframe` would test the column labels
# instead. A value that fails numeric conversion is an unmapped label (rows with
# a missing age were already dropped during encoding).
has_unencoded = pd.to_numeric(age_col, errors="coerce").isna().any()

if has_unencoded:
    # astype(str): LabelEncoder sorts its classes, which fails on a mix of
    # already-mapped ints and leftover strings.
    age_labels = age_col.astype(str)
    age_encoder = LabelEncoder().fit(age_labels)
    encoded_samples["age_group_at_sample"] = age_encoder.transform(age_labels)

    age_groups = dict(
        zip(age_encoder.classes_, age_encoder.transform(age_encoder.classes_))
    )
    print("Age group encoding:", age_groups)

else:
    print("Fallback encoding not needed")

## Missing check

In [None]:
# Per-column count and percentage of missing values (only columns with any).
missing_table = (
    encoded_samples.isna()
    .sum()
    .to_frame(name="missing_count")
    .assign(
        missing_percent=lambda df: (
            (df["missing_count"] / encoded_samples.shape[0] * 100).round(2)
        ),
    )
    .reset_index()
    .rename(columns={"index": "column"})
    .sort_values("missing_count", ascending=False)
    .query("missing_count != 0")
)

if not missing_table.empty:
    # A bare expression inside an if-block is never rendered by Jupyter (only a
    # cell's last top-level expression is), so print the table explicitly.
    print(missing_table.to_string(index=False))
else:
    print("No missing values detected.")

## Outlier check

In [None]:
# IQR rule: flag values below Q1 - 1.5*IQR or above Q3 + 1.5*IQR, per numeric column.
numeric_cols = encoded_samples.select_dtypes(include=[np.number]).columns
numeric_view = encoded_samples[numeric_cols]

q1, q3 = numeric_view.quantile(0.25), numeric_view.quantile(0.75)
iqr = q3 - q1
lower_bounds, upper_bounds = q1 - 1.5 * iqr, q3 + 1.5 * iqr

# Comparisons align the bound Series with the frame's columns.
outlier_mask = numeric_view.lt(lower_bounds) | numeric_view.gt(upper_bounds)
outlier_counts = outlier_mask.sum()
outlier_percent = (outlier_counts / encoded_samples.shape[0] * 100).round(2)

summary = pd.DataFrame({
    "column": numeric_cols,
    "lower_bound": lower_bounds,
    "upper_bound": upper_bounds,
    "outlier_count": outlier_counts,
    "outlier_percent": outlier_percent,
})
outlier_table = (
    summary.query("outlier_count > 0")
    .sort_values("outlier_percent", ascending=False)
    .reset_index(drop=True)
)

outlier_table

## Normalisation check

In [None]:
# Copy of the encoded table; the modelling section below reads from this name.
normalized_samples = encoded_samples.copy()
print("Shapiro-Wilk Normality Test")

# Green = cannot reject normality at alpha 0.05, red otherwise (incl. NaN p).
for column in numeric_cols:
    _, p_value = stats.shapiro(normalized_samples[column].dropna())
    looks_normal = p_value > 0.05
    colour = Fore.GREEN if looks_normal else Fore.RED
    verdict = "Normally Distributed" if looks_normal else "Not Normally Distributed"
    print(colour + f"{column}: {verdict} (p={p_value:.4f})")

print(Style.RESET_ALL)

## Train-test split before pre-processing

In [None]:
from sklearn.model_selection import GroupShuffleSplit

# Columns that are identifiers, metadata or the target rather than taxon
# abundances. They are kept out of X because the next step applies a CLR
# transform, which is only meaningful for compositional abundance columns.
# A label-encoded family_id also carries no biological signal. This matches
# `metadata_cols` used in the EDA section.
# NOTE(review): sex is excluded for the same reason. If it should be a model
# input, re-attach it *after* the CLR step.
non_feature_cols = ["sample_id", "age_group_at_sample", "family_id", "sex"]
feature_cols = normalized_samples.columns.difference(non_feature_cols)

X = normalized_samples[feature_cols]
Y = normalized_samples["age_group_at_sample"]
# Split by family, so the test set measures generalisation to unseen families.
groups = normalized_samples["family_id"]

gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=3004)
train_indicies, test_indicies = next(gss.split(X, Y, groups=groups))
X_train_raw = X.iloc[train_indicies]
X_test_raw = X.iloc[test_indicies]
Y_train = Y.iloc[train_indicies]
Y_test = Y.iloc[test_indicies]

assert X_train_raw.shape[1] == X_test_raw.shape[1], "Feature columns do not match between train and test sets."
assert X_train_raw.shape[0] == Y_train.shape[0] and X_test_raw.shape[0] == Y_test.shape[0], "X and Y do not have the same length."

print("Train shape:", X_train_raw.shape, "| Test shape:", X_test_raw.shape)

## Normalising data using clr transformation

In [None]:
# CLR implementation based on:
# https://medium.com/@nextgendatascientist/a-guide-for-data-scientists-log-ratio-transformations-in-machine-learning-a2db44e2a455

def clr_transform(X, epsilon=1e-9):
    """
    Centred log-ratio (CLR) transform with a tunable zero-replacement value.

    Each row (sample) is treated as one composition. Exact zeros are replaced
    by ``epsilon`` because log(0) is undefined. Each value is then logged and
    centred on the row's mean log value, i.e. divided by the row's geometric
    mean. Every transformed row therefore sums to 0.

    Parameters
    ----------
    X : pandas.DataFrame or array-like
        Non-negative abundances with samples along rows. A 1-D input is
        treated as a single sample. Negative values yield NaN.
    epsilon : float, default 1e-9
        Pseudo-count substituted for exact zeros.

    Returns
    -------
    pandas.DataFrame or numpy.ndarray
        A DataFrame with the original index and columns if ``X`` was a
        DataFrame, otherwise a float ndarray with the same shape as ``X``.
    """
    is_frame = isinstance(X, pd.DataFrame)
    # Convert every input type to one float array, so lists, ndarrays and
    # DataFrames all get the same element-wise zero replacement.
    X_arr = np.asarray(X, dtype=float)

    # 1. Replace zeros with epsilon (tunable parameter)
    X_replaced = np.where(X_arr == 0, epsilon, X_arr)

    # 2-3. log(x / geometric_mean) == log(x) - mean(log(x)). Computing it in
    # log space avoids an exp/log round-trip. axis=-1 is the row axis for 2-D
    # input and the only axis for 1-D input.
    log_X = np.log(X_replaced)
    X_clr = log_X - log_X.mean(axis=-1, keepdims=True)

    # Rebuild a DataFrame with the caller's labels when given one.
    if is_frame:
        return pd.DataFrame(X_clr, index=X.index, columns=X.columns)

    return X_clr

In [None]:
# CLR is computed per row, so transforming each split separately shares no
# information between train and test.
X_train, X_test = (clr_transform(split) for split in (X_train_raw, X_test_raw))

In [None]:
#To activate back the raw data without normalisation

#X_train = X_train_raw
#X_test = X_test_raw

In [None]:
# Normality check of every CLR-transformed feature, separately on train and test.
# Green: both splits pass (p > 0.05). Yellow: only one passes. Red: neither.
print("Shapiro-Wilk Normality Test (after CLR Normalisation)\n")

for col in X_train.columns:
    _, p_train = stats.shapiro(X_train[col].dropna())
    _, p_test = stats.shapiro(X_test[col].dropna())

    is_train_normal = p_train > 0.05
    is_test_normal = p_test > 0.05
    p_summary = f"(Train p={p_train:.3f}, Test p={p_test:.3f})"

    if is_train_normal and is_test_normal:
        print(Fore.GREEN + f"✔ {col}: Normal {p_summary}")
    elif is_train_normal or is_test_normal:
        # Mixed result: the two splits disagree.
        print(Fore.YELLOW + f"⚠ {col}: Inconsistent {p_summary}")
    else:
        print(Fore.RED + f"✘ {col}: Not Normal {p_summary}")

print(Style.RESET_ALL)

# Exploratory data analysis

In [None]:
print(merged_samples.shape)  # (samples, metadata + taxon columns)
merged_samples.head(n=5)

In [None]:
# Dataset overview: sample count, distinct families, total column count.
for _label, _value in (
    ("Number of samples:", len(merged_samples)),
    ("Number of unique families (family_id):", merged_samples["family_id"].nunique()),
    ("Number of columns (metadata + features):", merged_samples.shape[1]),
):
    print(_label, _value)

In [None]:
# Samples per family
samples_per_family = merged_samples["family_id"].value_counts()
samples_per_family.describe()

In [None]:
ax = samples_per_family.hist(bins=20)
ax.set_xlabel("Number of samples per family")
ax.set_ylabel("Number of families")
ax.set_title("Distribution of samples per family")
plt.show()

In [None]:
# Distribution of age groups, including missing labels.
merged_samples.age_group_at_sample.value_counts(dropna=False)

In [None]:
# Bars are ordered by frequency (value_counts order), not by age.
ax = merged_samples["age_group_at_sample"].value_counts().plot.bar()
ax.set_title("Distribution of age groups")
ax.set_ylabel("Number of samples")
plt.xticks(rotation=45)
plt.show()

In [None]:
# Dimensionality and sparsity of the microbiome feature matrix.
# Every column that is not sample metadata is treated as a taxon abundance.
metadata_cols = [
    "sample_id",
    "family_id",
    "sex",
    "body_product",
    "age_group_at_sample",
    "year_of_birth",
]
_metadata_set = set(metadata_cols)
feature_cols = [c for c in merged_samples.columns if c not in _metadata_set]

X = merged_samples[feature_cols]

print("Feature matrix shape:", X.shape)
print("Overall fraction of zeros:", X.eq(0).mean().mean())

In [None]:
# Number of observed (non-zero) taxa in each sample.
nonzero_per_sample = X.gt(0).sum(axis="columns")
nonzero_per_sample.describe()

In [None]:
ax = nonzero_per_sample.hist(bins=50)
ax.set_xlabel("Number of non-zero taxa per sample")
ax.set_ylabel("Number of samples")
ax.set_title("Non-zero taxa per sample")
plt.show()

In [None]:
# Sanity check: summed abundance of all taxon columns per sample.
total_abundance = X.sum(axis="columns")
total_abundance.describe()

In [None]:
ax = total_abundance.hist(bins=50)
ax.set_xlabel("Total abundance per sample")
ax.set_ylabel("Number of samples")
ax.set_title("Total microbial abundance per sample")
plt.show()

In [None]:
# Prevalence: in how many samples each taxon is present (non-zero).
feature_prevalence = X.gt(0).sum(axis="index")
feature_prevalence.describe()

In [None]:
ax = feature_prevalence.hist(bins=50)
ax.set_xlabel("Number of samples in which taxon is present")
ax.set_ylabel("Number of taxa")
ax.set_title("Feature prevalence distribution")
plt.show()

In [None]:
# Distribution of the non-zero abundance values on a log10 scale
# (zeros are excluded because log10(0) is undefined).
abundance_values = X.to_numpy()
nonzero_values = abundance_values[abundance_values > 0]

plt.hist(np.log10(nonzero_values), bins=50)
plt.xlabel("log10(abundance)")
plt.ylabel("Frequency")
plt.title("Distribution of non-zero abundances (log10 scale)")
plt.show()

In [None]:
# PCA visualization
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Use only the 500 most prevalent taxa, to keep PCA fast.
prevalence = (X > 0).sum(axis=0)
top_features = prevalence.sort_values(ascending=False).head(500).index

X_sub = X[top_features]

# Scale features
X_scaled = StandardScaler().fit_transform(X_sub)

# PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
# Plain array view: works whether sklearn is set to return DataFrames or arrays.
pcs = np.asarray(X_pca)

# Colour by *chronological* age order. pd.factorize would number the groups in
# order of first appearance, so a colour gradient would carry no age meaning.
age = merged_samples["age_group_at_sample"]
age_order = sorted(encoding_guide, key=encoding_guide.get)
age_codes = pd.Categorical(age, categories=age_order, ordered=True).codes
known = age_codes >= 0  # code -1 = missing or unrecognised age label

plt.figure(figsize=(8, 6))
if (~known).any():
    plt.scatter(pcs[~known, 0], pcs[~known, 1],
                color="lightgrey", alpha=0.6, label="Unknown age")
points = plt.scatter(pcs[known, 0], pcs[known, 1],
                     c=age_codes[known], cmap="viridis", alpha=0.6,
                     vmin=0, vmax=len(age_order) - 1)
cbar = plt.colorbar(points, label="Age group")
cbar.set_ticks(range(len(age_order)))
cbar.set_ticklabels(age_order)
if (~known).any():
    plt.legend()
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("PCA of samples colored by age group")
plt.show()

print("Explained variance ratio:", pca.explained_variance_ratio_)

### Summary of EDA
The dataset consists of 930 stool samples derived from multiple individuals across different families and contains approximately 6,900 microbiome features, making it a high-dimensional and highly sparse dataset. Each sample contains on average around 300 detected taxa, while the total microbial abundance per sample is relatively stable, indicating that sequencing depth is consistent across samples.
Most taxa are rare and occur in only a small fraction of samples, whereas a small subset of taxa is highly prevalent across the cohort. The distribution of non-zero abundances follows an approximately log-normal shape, which is typical for microbiome sequencing data (e.g., Lutz et al., 2022).
A PCA projection based on the most prevalent taxa does not reveal sharply separated clusters but shows a gradual age-related gradient, suggesting that age-related variation in microbiome composition is present but represents only a limited fraction of the total variance in the data.

# Model Training

### Filtering for features at the genus level

In [None]:
def filter_genus(df_uf):
    """Return only the genus-level feature columns of ``df_uf``.

    A column is kept when its label contains ``g__`` (genus or finer) and does
    not contain ``s__`` (species or finer). Coarser ranks are therefore dropped,
    and so are species-level ranks.
    """
    genus_or_finer = df_uf.filter(regex="g__")
    species_or_finer = genus_or_finer.filter(regex="s__").columns
    return genus_or_finer.drop(columns=species_or_finer)

In [None]:
# Keep only genus-level taxa in both CLR-transformed splits.
X_train_genus, X_test_genus = (filter_genus(split) for split in (X_train, X_test))

In [None]:
"target_names" in set(X_train_genus.columns)  # sanity check: expect False (no stray "target_names" column)

In [None]:
model_features = list(X_train_genus.columns)
print("Features used by model:")
print(model_features)

### Random Forest Regressor with Train/Test split (Genus)

### Base model

In [None]:
# Base random-forest regressor; its hyperparameters are tuned by the search below.
# oob_score=True also records an out-of-bag R2 estimate on the fitted forest.
rf = RandomForestRegressor(n_jobs=-1, oob_score=True, random_state=42)

### Search for the best model

In [None]:
from sklearn.model_selection import GroupKFold

# Hyperparameter space
param_dist = {
    "n_estimators": [300, 500, 800, 1000],
    "max_depth": [None, 10, 20, 40],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 5, 10],
    "max_features": ["sqrt", "log2"]
}

# Samples from the same family are correlated. Plain KFold would place samples
# of one family in both the fitting and the scoring folds, which inflates the
# CV score. Group the folds by family instead, matching the GroupShuffleSplit
# used for the train/test split. X_train_genus keeps normalized_samples' row
# labels, so the family ids can be looked up by index.
train_groups = normalized_samples.loc[X_train_genus.index, "family_id"]

# Randomized search with family-grouped 5-fold CV on the training set
search = RandomizedSearchCV(
    estimator=rf,
    param_distributions=param_dist,
    n_iter=25,
    cv=GroupKFold(n_splits=5),
    scoring="neg_mean_squared_error",
    random_state=42,
    n_jobs=-1,
    verbose=2
)

# Fit search
search.fit(X_train_genus, Y_train, groups=train_groups)

# Best model
best_model = search.best_estimator_
print("\nBest hyperparameters:", search.best_params_)

In [None]:
from sklearn.metrics import r2_score

# Evaluate the tuned forest on the held-out (unseen-family) test set.
yhat = best_model.predict(X_test_genus)
mse = mean_squared_error(Y_test, yhat)
print('Mean Squared Error: %.3f' % mse)
# Also report the test error as RMSE (same units as the target, days), so it
# can be compared directly with the CV RMSE below.
test_rmse = mse ** 0.5
print('Test RMSE: %.3f' % test_rmse)

# CV RMSE of best model (best_score_ is the negated MSE)
best_cv_rmse = (-search.best_score_) ** 0.5
print("Best CV RMSE:", best_cv_rmse)

# R-squared
r2 = r2_score(Y_test, yhat)
print('R2 Score: %.3f' % r2)

In [None]:
y_lo, y_hi = Y_test.min(), Y_test.max()

plt.figure(figsize=(6, 6))
sns.scatterplot(x=Y_test, y=yhat, alpha=0.6)
# Identity line: points on it are perfect predictions.
plt.plot([y_lo, y_hi], [y_lo, y_hi], color="red", linestyle="--")

plt.gca().set(
    xlabel="Actual Values",
    ylabel="Predicted Values",
    title="Predicted vs Actual (Genus-level RF)",
)
plt.tight_layout()
plt.show()

In [None]:
# Residual = actual - predicted. A structureless cloud around 0 is ideal.
residuals = Y_test - yhat

plt.figure(figsize=(6, 4))
sns.scatterplot(x=yhat, y=residuals, alpha=0.6)
plt.axhline(0, linestyle="--", color="red")

plt.gca().set(
    xlabel="Predicted values",
    ylabel="Residuals",
    title="Residuals vs Predictions",
)
plt.tight_layout()
plt.show()

In [None]:
top_n = 20

# Impurity-based importances of the tuned forest, highest first.
importances = pd.Series(best_model.feature_importances_, index=X_train_genus.columns)
importances = importances.sort_values(ascending=False)
top_importances = importances.head(top_n)

plt.figure(figsize=(8, 6))
sns.barplot(x=top_importances, y=top_importances.index)

plt.gca().set(
    xlabel="Feature importance",
    ylabel="Genus",
    title=f"Top {top_n} most important genera",
)
plt.tight_layout()
plt.show()

In [None]:
# Interactive view of the first tree of the tuned forest (supertree package).
st = SuperTree(best_model, X_train_genus, Y_train)
st.show_tree(which_tree=0)

In [None]:
# Per-tree predictions on the test set: shape (n_trees, n_test_samples).
# The forest fits its trees on a plain array, so pass one here as well. This
# avoids the "X has feature names" mismatch warning for every tree.
X_test_arr = X_test_genus.to_numpy()
all_tree_preds = np.array([est.predict(X_test_arr) for est in best_model.estimators_])

# Mean over trees = the Random Forest prediction
rf_pred = all_tree_preds.mean(axis=0)

# Spread across trees per sample (uncertainty)
rf_std = all_tree_preds.std(axis=0)

y_true = Y_test.to_numpy()
n_trees = all_tree_preds.shape[0]

plt.figure(figsize=(10, 6))

# All individual tree predictions in a single scatter call, rather than one
# plt.plot per tree (up to 1000 artists). ravel() is row-major (tree by tree),
# so it lines up with y_true tiled once per tree.
plt.scatter(np.tile(y_true, n_trees), all_tree_preds.ravel(),
            color='lightgray', alpha=0.3, label='Individual tree predictions')

# plot Random Forest mean prediction
plt.scatter(Y_test, rf_pred, color='blue', label='RF mean prediction', s=40)

plt.errorbar(Y_test, rf_pred, yerr=rf_std, fmt='o', color='red', alpha=0.5, label='±1 std across trees')

plt.plot([Y_test.min(), Y_test.max()],
         [Y_test.min(), Y_test.max()],
         color='black', linestyle='--', label='Perfect prediction')

plt.xlabel("Actual Age Group")
plt.ylabel("Predicted Age Group")
plt.title("Random Forest – Forest plot of tree predictions")
plt.legend()
plt.tight_layout()
plt.show()


# Alternative Models

### XGBoost Alternative

In [None]:
import xgboost as xgb
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import mean_squared_error, r2_score

# Base model
# NOTE(review): this model is only constructed in this cell, not fitted or
# searched. The hyperparameters recorded below are hard-coded again in the
# genus-vs-NN comparison cell further down.
xgb_model = xgb.XGBRegressor(
    objective='reg:squarederror',
    random_state=42,
    n_jobs=-1
)


#Best hyperparameters: {'subsample': 0.7, 'reg_lambda': 1.0, 'reg_alpha': 0.1, 'n_estimators': 1000, 'max_depth': 3, 'learning_rate': 0.01, 'colsample_bytree': 0.2}

# Feature Selection via neural networks

## The finding

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers, constraints, callbacks, initializers
import pandas as pd
import numpy as np
import random
import os

# 1. LOCK DOWN RANDOMNESS (Reproducibility)
# ---------------------------------------------------------
def set_seeds(seed=42):
    """Seed the Python, NumPy and TensorFlow random generators.

    Setting PYTHONHASHSEED at runtime only affects child processes. The hash
    seed of the running interpreter is fixed at start-up.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    print(f"✅ Random Seeds fixed to {seed}")


set_seeds(42)

# 2. SETUP DATA
# ---------------------------------------------------------
X_train_tf = X_train.astype('float32')
y_train_tf = Y_train.values.astype('float32')


# 3. THE GATEKEEPER LAYER
# ---------------------------------------------------------
class GatekeeperLayer(layers.Layer):
    """Element-wise gate with one non-negative, L1-penalised weight per feature.

    The NonNeg constraint and the L1 penalty push the weights of unhelpful
    features towards zero. The weights that survive act as a feature-selection
    mask.
    """

    def __init__(self, num_features, l1_penalty=0.01, **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.l1_penalty = l1_penalty

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(self.num_features,),
            # Start tiny (but not dead zero) so gradients can flow if needed
            initializer=initializers.RandomUniform(minval=0.0, maxval=0.002),
            trainable=True,
            constraint=constraints.NonNeg(),
            regularizer=regularizers.l1(self.l1_penalty)
        )

    def call(self, inputs):
        return inputs * self.w


# 4. BUILD & TRAIN FUNCTION (The "Sweet Spot" 32-Neuron Brain)
# ---------------------------------------------------------
def train_candidate(l1_strength, run_seed):
    """Train one gated network and report how many features its gate kept.

    Parameters
    ----------
    l1_strength : float
        L1 penalty on the gate weights; higher values keep fewer features.
    run_seed : int
        TensorFlow seed for this run (different seeds give different repeats).

    Returns
    -------
    (n_selected, val_rmse, weights)
        Number of gate weights > 1e-5, the validation RMSE *without* the L1
        penalty term, and the raw gate weight vector.
    """
    tf.random.set_seed(run_seed)

    n_inputs = X_train_tf.shape[1]
    inputs = layers.Input(shape=(n_inputs,))

    # Layer 1: Selection
    gate = GatekeeperLayer(n_inputs, l1_penalty=l1_strength)
    x = gate(inputs)

    # Layer 2: The "Brain" - 32 units
    x = layers.Dense(32, activation='relu')(x)

    outputs = layers.Dense(1, activation='linear')(x)

    # tf.keras.Model instead of models.Model: a later cell rebinds the global
    # name `models` to a dict, which would break re-running this function.
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.005),
                  loss='mse',
                  metrics=['mae', tf.keras.metrics.MeanSquaredError(name='mse')])

    # Train. Keras takes the *last* 20% of rows as validation (before any
    # shuffling). NOTE(review): that split is not grouped by family.
    history = model.fit(
        X_train_tf, y_train_tf,
        epochs=120,
        batch_size=32,
        validation_split=0.2,
        verbose=0,
        callbacks=[callbacks.EarlyStopping(patience=8, restore_best_weights=True)]
    )

    # val_loss includes the gate's L1 regularisation term, so comparing it
    # across penalties favours weak penalties. Report the unpenalised
    # validation MSE at the epoch with the lowest val_loss (the epoch whose
    # weights EarlyStopping restores).
    best_epoch = int(np.argmin(history.history['val_loss']))
    val_rmse = history.history['val_mse'][best_epoch] ** 0.5

    weights = gate.get_weights()[0]
    n_selected = np.sum(weights > 1e-5)

    return n_selected, val_rmse, weights


# 5. THE MONTE CARLO SEARCH LOOP
# ---------------------------------------------------------
# Smoother gradient of penalties
penalties = [0.5, 1.0, 2.0, 3.0, 5.0, 7.5, 10.0]
repeats_per_penalty = 5

champion_model = {
    'rmse': float('inf'),
    'n_features': 0,
    'penalty': 0,
    'weights': None
}

print(f"{'Penalty':<8} | {'Run':<4} | {'Features':<10} | {'Val RMSE':<10} | {'Status'}")
print("-" * 65)

for p in penalties:
    for i in range(repeats_per_penalty):
        # Vary the seed per run, so the repeats are not identical
        run_seed = 42 + i

        n_feats, val_rmse, weights = train_candidate(l1_strength=p, run_seed=run_seed)

        # Criteria: strictly between 50 and 1250 selected features
        is_valid = n_feats < 1250 and n_feats > 50
        status = "❌ Too Many"

        if is_valid:
            status = "✅ Candidate"
            if val_rmse < champion_model['rmse']:
                champion_model['rmse'] = val_rmse
                champion_model['n_features'] = n_feats
                champion_model['penalty'] = p
                champion_model['weights'] = weights
                status = "🏆 NEW BEST"
        elif n_feats <= 50:
            status = "⚠️ Too Sparse"

        print(f"{p:<8} | {i+1:<4} | {n_feats:<10} | {val_rmse:<10.2f} | {status}")

# 6. SAVE & EXPORT
# ---------------------------------------------------------
if champion_model['weights'] is not None:
    print("\n------------------------------------------------")
    print("🎉 FINAL CHAMPION FOUND!")
    print(f"Penalty: {champion_model['penalty']}")
    print(f"Features: {champion_model['n_features']}")
    print(f"Validation RMSE: {champion_model['rmse']:.4f}")
    print("------------------------------------------------")

    # Gate weights line up with X_train's columns (the model input order).
    gate_weights = champion_model['weights']
    df_imp = pd.DataFrame({'Bacteria': X_train.columns, 'Score': gate_weights})
    selected_bacteria = df_imp[df_imp['Score'] > 1e-5].sort_values('Score', ascending=False)

    print("Top 10 Selected Bacteria:")
    print(selected_bacteria.head(10))

    X_train_elite = X_train[selected_bacteria['Bacteria']]
    X_test_elite = X_test[selected_bacteria['Bacteria']]

    print(f"\nCreated X_train_elite with shape: {X_train_elite.shape}")
else:
    print("\n❌ No models found in the sweet spot. Penalties might still be too low/high.")

## Comparing the genus vs neural network features

In [None]:
from sklearn.model_selection import GroupShuffleSplit, cross_validate
import pandas as pd
import numpy as np
import re

# 1. SETUP REPEATED, FAMILY-AWARE VALIDATION
# 21 random splits, each holding out ~1/7 of the training families (roughly a
# 7-fold CV repeated 3 times). Grouping by family keeps all samples of a family
# on one side of each split. A plain RepeatedKFold would score models on
# families they were also trained on, which inflates R2.
# (The name `rkf` is kept from the original notebook.)
rkf = GroupShuffleSplit(n_splits=21, test_size=1 / 7, random_state=42)
# Family ids of the training rows; the X_train_* frames share Y_train's index.
train_groups = normalized_samples.loc[Y_train.index, "family_id"]

datasets = {f"Biology (Genus {X_train_genus.shape[1]})": X_train_genus}
# X_train_elite only exists if the NN feature search found a champion.
if "X_train_elite" in globals():
    datasets["Math (NN Elite)"] = X_train_elite
else:
    print("X_train_elite not defined (no NN champion) - skipping the NN feature set.")

# 2. DEFINING THE CHAMPIONS
# Using the previously optimised hyperparameters
models = {
    "XGBoost": xgb.XGBRegressor(
        n_estimators=1000, learning_rate=0.01, max_depth=3,
        subsample=0.7, colsample_bytree=0.2, reg_alpha=0.1,
        reg_lambda=1.0, n_jobs=-1, random_state=42
    ),
    "Random Forest": RandomForestRegressor(
        n_estimators=1000, max_depth=20, max_features='sqrt',
        n_jobs=-1, random_state=42
    )
}

# 3. THE ULTIMATE BATTLE
results = []

print(f"{'Dataset':<20} | {'Model':<15} | {'Avg R2':<10} | {'Std Dev':<10}")
print("-" * 65)

for data_name, X_data in datasets.items():

    # Keep only [A-Za-z0-9_] in column names: XGBoost rejects some characters
    # (e.g. brackets) in feature names.
    X_clean = X_data.copy()
    X_clean.columns = [re.sub('[^A-Za-z0-9_]+', '', str(col)) for col in X_clean.columns]

    for model_name, model in models.items():
        # 21 family-grouped train/validation splits
        cv_results = cross_validate(
            model, X_clean, Y_train,
            groups=train_groups,
            cv=rkf,
            scoring='r2',
            n_jobs=-1
        )

        mean_r2 = np.mean(cv_results['test_score'])
        std_r2 = np.std(cv_results['test_score'])

        print(f"{data_name:<20} | {model_name:<15} | {mean_r2:.4f}     | ±{std_r2:.3f}")

        results.append({
            'Dataset': data_name,
            'Model': model_name,
            'R2': mean_r2,
            'Std': std_r2
        })

print("-" * 65)