# Validation of Probiotic Growth Predictions
This notebook simulates the growth of MCMMs and validates the predicted probiotic growth rates against qPCR measurements for participants in Validation Study A.

In [None]:
import pandas as pd
import numpy as np
import micom
import seaborn as sns
import matplotlib.pyplot as plt
from plotnine import *
import scipy.stats
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import confusion_matrix

%matplotlib inline

## Collect Metadata

In [None]:
# Load the Study A sample metadata (tab-separated)
metadata = pd.read_table('../data/studyA_metadata.tsv')

# Discard the header and control rows, identified by their positions in the table
non_sample_rows = metadata.index[[0, 83, 84]]
metadata = metadata.drop(index=non_sample_rows)

# Lookup tables keyed by sample name: Name -> subject_id and Name -> treatment_group
metadata_by_name = metadata.set_index('Name')
subject_dict = metadata_by_name['subject_id'].to_dict()
treatment_dict = metadata_by_name['treatment_group'].to_dict()

## Import qPCR Table
qPCR data serves as ground truth for validation of predicted growth rates

In [None]:
qPCR = pd.read_table('../data/qPCR-reaction-table.tsv')

# Remove reactions whose cycle threshold is -inf
qPCR_Ct = qPCR[qPCR['Ct'] != -np.inf].sort_values(by='Ct')

# Keep only reactions that passed both QC and filtering
passed = (qPCR_Ct['PassQC'] == True) & (qPCR_Ct['PassFilter'] == True)
qPCR_filter = qPCR_Ct[passed]

# Remove the less efficient primers for EHAL and CBEI.
# The mask is built from qPCR_filter itself (not the unfiltered table), so the
# selection does not rely on implicit index alignment; .copy() makes the result
# an independent frame so the column assignments below never act on a slice.
efficient_primers = ~qPCR_filter['Primers'].isin(['EHAL_AN', 'CBEI_AB'])
qPCR_filter = qPCR_filter.loc[efficient_primers].copy()

# Format column types: weeks are compared as strings ('0', '12') downstream
qPCR_filter['Week'] = qPCR_filter['Week'].astype('str')

# Sort table
qPCR_filter = qPCR_filter.sort_values(by=['Strain', 'Week'])

# Rename strain abbreviations to the species names used by the models
strain_names = {
    'AMUC': 'Akkermansia_muciniphila',
    'BINF': 'Bifidobacterium_longum',
    'CBEI': 'Clostridium_beijerinckii',
    'CBUT': 'Clostridium_butyricum',
    'EHAL': 'Anaerobutyricum_hallii',
}
for abbreviation, species in strain_names.items():
    # Literal substring replacement, applied in the order listed above
    qPCR_filter['Strain'] = qPCR_filter['Strain'].str.replace(
        abbreviation, species, regex=False)

# Calculate negative cycle threshold
qPCR_filter['-Ct'] = -qPCR_filter['Ct']

qPCR_filter = qPCR_filter.set_index(['Subject', 'Strain'])

## Isolate the Treatment Group Data
Binarizing the negative cycle threshold (-Ct) values around strain-specific thresholds allows for comparison with predicted growth rates.

In [None]:
# Isolate treatment group
qPCR_wbf11 = qPCR_filter.loc[qPCR_filter['Treatment'] == 'WBF-011'].reset_index()

# Average multiple hits per sample (numeric columns only)
qPCR_wbf11 = (qPCR_wbf11
              .groupby(['Subject', 'Strain', 'Week'])
              .mean(numeric_only=True)
              .reset_index())

# Strain-specific threshold: the mean -Ct of each strain at week 0
week0 = qPCR_wbf11.loc[qPCR_wbf11['Week'] == '0']
strain_thresholds = week0.groupby('Strain')['-Ct'].mean().to_dict()

# Map to qPCR data
qPCR_wbf11['threshold'] = qPCR_wbf11['Strain'].map(strain_thresholds)

# Binarize the qPCR data by the threshold: 1.0 where -Ct exceeds the strain
# threshold, 0.0 otherwise (including rows where either value is missing).
# Assigning the whole column at once avoids a chained in-place fillna, which
# does not update the parent frame when pandas Copy-on-Write is active.
above_threshold = qPCR_wbf11['-Ct'] > qPCR_wbf11['threshold']
qPCR_wbf11['qPCR_engraftment'] = above_threshold.astype(float)
qPCR_wbf11

## Prepare MCMM Medium
Load the standard European medium as the assumed background diet, and add inulin.

In [None]:
# Background diet: the standard European medium
eu_medium = pd.read_csv('../european_medium.csv')

# One-row table describing the inulin supplement given in the trial
inulin_supplement = pd.DataFrame(
    {
        'metabolite': ['inulin'],
        'flux': [0.061448],
        'name': 'inulin',
        'reaction': 'EX_inulin_m',
        'global_id': 'EX_inulin(e)',
    },
    index=['EX_inulin_m'],
)

# Medium used for the treated samples: background diet plus inulin
eu_medium_wbf11 = pd.concat([eu_medium, inulin_supplement])

## Grow MCMMs
Grow MCMMs with high tradeoff value 

In [None]:
# Manifest listing the models built for the WBF-011 treated samples
manifest = pd.read_csv('../WBF011_treated/manifest.csv')

# Simulation settings passed to micom's grow workflow
grow_settings = dict(
    model_folder='../WBF011_treated/',
    medium=eu_medium_wbf11,   # European medium supplemented with inulin
    tradeoff=0.99,            # high tradeoff value
    strategy='none',
    threads=10,
)

# Grow models
growth = micom.workflows.grow(manifest, **grow_settings)

## Collect Predicted Growth Rates
Isolate predictions for the probiotic species

In [None]:
# Predicted growth rates for every taxon in every sample
rates = growth.growth_rates

# Keep only the rows whose index label is one of the probiotic species
probiotic_species = [
    'Akkermansia_muciniphila',
    'Clostridium_beijerinckii',
    'Clostridium_butyricum',
    'Bifidobacterium_longum',
    'Anaerobutyricum_hallii',
]
is_probiotic = rates.index.isin(probiotic_species)
rates_probiotic = rates[is_probiotic].reset_index()

# Growth rates on a log10 scale
rates_probiotic['growth_rate_scaled'] = np.log10(rates_probiotic['growth_rate'])

# Binarize: log10 rate of at least -2 counts as engraftment (1), below -2 as none (0)
scaled_rate = rates_probiotic['growth_rate_scaled']
rates_probiotic.loc[scaled_rate >= -2, 'MCMM_engraftment'] = 1
rates_probiotic.loc[scaled_rate < -2, 'MCMM_engraftment'] = 0

## Create Agreement Heatmap
Find agreement between predictions and measurements

In [None]:
# Measured engraftment at week 12 (strains x subjects); missing cells count as 0
df1 = pd.pivot_table(qPCR_wbf11[qPCR_wbf11['Week'] == '12'],
                     index='Strain',
                     columns='Subject',
                     values='qPCR_engraftment').fillna(0.0)

# Predicted engraftment (strains x samples); missing cells count as 0
df2 = pd.pivot_table(rates_probiotic,
                     index='compartments',
                     columns='sample_id',
                     values='MCMM_engraftment').fillna(0.0)

# Ensure both dataframes are numeric (convert from strings if necessary)
df1 = df1.astype(float)
df2 = df2.astype(float)

# Put the predictions in exactly the same row/column layout as the measurements.
# Element-wise comparison (and the flattened comparisons in later cells) require
# identical labels in identical order; .loc raises a KeyError if a measured
# strain or subject has no prediction instead of comparing mismatched cells.
df2 = df2.loc[df1.index, df1.columns]

# Agreement matrix: 1 where prediction and measurement match, 0 otherwise
agreement_matrix = (df1 == df2).astype(int)
agreement_matrix.index.name = 'Probiotic Species'
agreement_matrix.columns.name = 'Samples'

# Display names for the figure: spaces instead of underscores, and 'infantis'
# in place of 'longum'
agreement_matrix.index = (agreement_matrix.index
                          .str.replace('_', ' ')
                          .str.replace('longum', 'infantis'))

# Plot heatmap of agreement
plt.figure(figsize=(18, 12))
sns.set(font_scale=4)
ax = sns.heatmap(agreement_matrix, cmap='coolwarm_r',
                 annot=False, cbar=False, square=False,
                 linewidths=1, linecolor='white')
ax.set_xticklabels([])  # hide the individual sample labels
plt.title('Agreement Between MCMM and qPCR')
plt.savefig('../figures/agreement_heatmap_studyA.svg', dpi=300, bbox_inches='tight')  # Save as SVG

## Calculate Cohen's Kappa
Measure the agreement 

In [None]:
# One-dimensional views of the two tables: measured calls, then predicted calls
ratings1, ratings2 = (table.to_numpy().flatten() for table in (df1, df2))

# Cohen's kappa between the measured and predicted engraftment calls
kappa = cohen_kappa_score(ratings1, ratings2)

print(f"Cohen's kappa: {kappa:.4f}")

## Calculate % Agreement
Calculate agreement as a percentage

In [None]:
# Number of strain/sample cells where prediction and measurement agree
agree_ones = agreement_matrix.sum().sum()
# Number of cells where the model predicts growth (kept for inspection)
growth_ones = df2.sum().sum()
# Total number of strain/sample cells compared
total = agreement_matrix.size

# Round after scaling to percent: rounding the fraction first and then
# multiplying by 100 can print float artefacts such as 76.19000000000001%
percent_agreement = round(float(100 * agree_ones / total), 2)
print(f"{percent_agreement}%")

## Plot Confusion Matrices
Second visualization of the results

In [None]:
strains = df1.index

# Unified font sizes
FONTS = {
    "suptitle": 32,
    "title": 28,
    "axis_labels": 28,
    "tick_labels": 25,
    "heatmap_annotations": 28
}
# Set up Seaborn style
sns.set(style="whitegrid", font_scale=FONTS["tick_labels"]/10)  # scale factor for seaborn

# Set up subplot grid: 2 rows, 3 columns (5 strains + 1 aggregate)
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()  # Flatten for easy iteration

# Consistent colour limits for the per-strain panels
# NOTE(review): 21 is presumably the number of subjects per strain - confirm
vmin, vmax = 0, 21

# Class order shared by every confusion matrix: 0 = no growth, 1 = growth.
# Passing it explicitly guarantees a 2x2 matrix even when only one class occurs.
CLASS_LABELS = [0, 1]
TICK_LABELS = ["No Growth", "Growth"]


def _draw_confusion_panel(cm, ax, cmap, vmin, vmax):
    """Draw one annotated 2x2 confusion matrix on `ax` with the shared tick labels."""
    sns.heatmap(cm, annot=True, fmt="d", cmap=cmap,
                xticklabels=TICK_LABELS,
                yticklabels=TICK_LABELS,
                ax=ax,
                vmin=vmin, vmax=vmax,
                annot_kws={"fontsize": FONTS["heatmap_annotations"]})
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=90)


# Aggregate all results
all_y_true = df1.values.flatten()
all_y_pred = df2.values.flatten()
# Predictions are passed first on purpose: rows = model prediction,
# columns = qPCR ground truth, matching the figure's axis labels below
cm_agg = confusion_matrix(all_y_pred, all_y_true, labels=CLASS_LABELS)

# Plot aggregated confusion matrix; colour limits scale with the number of strains pooled
_draw_confusion_panel(cm_agg, axes[0], "Reds",
                      vmin*len(strains), vmax*len(strains))

# One panel per strain
for i, strain in enumerate(strains):
    y_true = df1.loc[strain].values.flatten()
    y_pred = df2.loc[strain].values.flatten()

    cm = confusion_matrix(y_pred, y_true, labels=CLASS_LABELS)
    _draw_confusion_panel(cm, axes[i+1], "Blues", vmin, vmax)

# Set titles (aggregate first, then the strains in the row order of df1,
# which is assumed to be the alphabetical order listed here)
strain_titles = ["Aggregate \n Across Strains",
                 'Akkermansia \n muciniphila',
                 'Anaerobutyricum \n hallii',
                 'Bifidobacterium longum \n sp. infantis',
                 'Clostridium \n beijerinckii',
                 'Clostridium \n butyricum']

for ax, title in zip(axes, strain_titles):
    ax.set_title(title, fontsize=FONTS["title"])

# Set global labels
fig.suptitle("Confusion Matrices of Model Predictions vs. qPCR", fontsize=FONTS["suptitle"])
fig.text(-0.02, 0.5, "Model Prediction", va='center', rotation='vertical', fontsize=FONTS["axis_labels"])
fig.text(0.5, -0.02, "qPCR Ground Truth", ha='center', fontsize=FONTS["axis_labels"])

plt.tight_layout()
plt.savefig('../figures/confusion_matricies_studyA.svg', dpi=300, bbox_inches='tight')
plt.show()


## Calculate Fisher's Exact Test for Aggregate Result


In [None]:
# Two-sided Fisher's exact test on the aggregate prediction-vs-qPCR confusion matrix
scipy.stats.fisher_exact(cm_agg, alternative='two-sided')

## Sensitivity Analysis
Test how growth rate thresholds impact prediction accuracy

In [None]:
# Measured calls are the same for every threshold, so flatten them once
measured_calls = df1.values.flatten()
growth_rate = rates_probiotic['growth_rate']

# One record per threshold; collected in a list and converted to a DataFrame
# once, instead of concatenating inside the loop.
# Loop-local names are used throughout so that df2, agreement_matrix, kappa and
# rates_probiotic['MCMM_engraftment'] from the earlier cells are left untouched.
sensitivity_records = []
for threshold in np.logspace(-5, 0.7, num=50):
    # Binarize the raw growth rates at this threshold; rows with a missing
    # growth rate stay missing and are handled by the pivot/fillna below
    calls = (growth_rate >= threshold).astype(float).where(growth_rate.notna())

    # Predicted engraftment table, laid out exactly like the measured table
    predicted_table = (
        rates_probiotic.assign(MCMM_engraftment=calls)
        .pivot_table(index='compartments',
                     columns='sample_id',
                     values='MCMM_engraftment')
        .fillna(0.0)
        .astype(float)
        .loc[df1.index, df1.columns]
    )
    predicted_calls = predicted_table.values.flatten()

    # Fraction of strain/sample cells where prediction and measurement match
    agreement = (measured_calls == predicted_calls).sum() / measured_calls.size

    # labels=[0, 1] keeps the table 2x2 even when every call is the same class,
    # so Fisher's exact test always receives a valid contingency table
    contingency = confusion_matrix(measured_calls, predicted_calls, labels=[0, 1])
    fisher = scipy.stats.fisher_exact(contingency)

    sensitivity_records.append({'agreement': agreement,
                                'kappa': cohen_kappa_score(measured_calls, predicted_calls),
                                'fisher': fisher[1],  # p-value
                                'threshold': threshold})

res = pd.DataFrame(sensitivity_records,
                   columns=['agreement', 'kappa', 'fisher', 'threshold'])

# Convert results into long format
res_long = res.melt(id_vars='threshold', value_vars=['agreement', 'kappa'],
                    var_name='metric', value_name='value')

## Plot sensitivity

In [None]:
# Agreement and kappa as a function of the growth-rate threshold (log-scaled x axis);
# the vertical line marks a growth rate of 0.01
threshold_mapping = aes(x='threshold', y='value', color='metric')
metric_scale = scale_color_discrete(limits=['agreement', 'kappa'],
                                    labels=['Percent Agreement', 'Cohens Kappa'])

sensitivity_plot = (
    ggplot(res_long, threshold_mapping)
    + geom_line(size=1.2)
    + scale_x_log10()
    + geom_vline(xintercept=0.01)
    + labs(y='Value', x='Growth Rate Threshold', color='Metric')
    + metric_scale
    + theme_minimal()
)

# Write the figure to disk, then display it as the cell output
sensitivity_plot.save('../figures/sensitivity_studyA.svg', dpi=200)
sensitivity_plot