In [None]:
import pandas as pd
import numpy as np
import micom
import seaborn as sns
import matplotlib.pyplot as plt
from plotnine import *
import sklearn.metrics
import scipy.stats

%matplotlib inline

# Validation of Probiotic Growth Predictions
This notebook simulates the growth of microbial community-scale metabolic models (MCMMs) and validates the predicted probiotic growth rates against qPCR measurements.

## Collect Metadata

In [None]:
# Sample metadata: one row per sequenced sample, keyed by the 'Name' column
METADATA_PATH = '../data/Hiseq_metagenomic_202_190916 metadata_conditions.txt'

# Positional rows that are not samples: the extra header row (0) and the two control rows (83, 84).
# NOTE(review): these are positions, not labels — confirm they still hold if the file is regenerated.
NON_SAMPLE_ROWS = [0, 83, 84]

metadata = pd.read_table(METADATA_PATH)
metadata = metadata.drop(index=metadata.index[NON_SAMPLE_ROWS])

# Lookups from sample name to subject ID and to treatment group
metadata_by_sample = metadata.set_index('Name')
subject_dict = metadata_by_sample['subject_id'].to_dict()
treatment_dict = metadata_by_sample['treatment_group'].to_dict()

## Import qPCR Table
qPCR data serves as ground truth for validation of predicted growth rates

In [None]:
# qPCR reaction table: the ground truth used to validate predicted growth rates
qPCR = pd.read_table('../data/qPCR-reaction-table.tsv')

# Set to False to keep reactions that failed QC / filtering (alternative analysis)
APPLY_QC_FILTER = True

# Drop reactions without a cycle threshold (Ct == -inf) and sort by Ct
qPCR_Ct = qPCR[qPCR['Ct'] != -np.inf].sort_values(by='Ct')

# Keep only reactions that passed both QC and filtering
if APPLY_QC_FILTER:
    qPCR_filter = qPCR_Ct[qPCR_Ct['PassQC'].eq(True) & qPCR_Ct['PassFilter'].eq(True)]
else:
    qPCR_filter = qPCR_Ct

# Remove the less efficient primer sets for EHAL and CBEI.
# The mask is built from qPCR_filter itself (the previous version used the raw qPCR table and only
# worked through implicit index alignment), and .copy() detaches the result so the column
# assignments below neither raise SettingWithCopyWarning nor write through a view.
LOW_EFFICIENCY_PRIMERS = ['EHAL_AN', 'CBEI_AB']
qPCR_filter = qPCR_filter.loc[~qPCR_filter['Primers'].isin(LOW_EFFICIENCY_PRIMERS)].copy()

# Week is used as a string label downstream (e.g. '0', '12')
qPCR_filter['Week'] = qPCR_filter['Week'].astype('str')

# Sort table
qPCR_filter = qPCR_filter.sort_values(by=['Strain', 'Week'])

# Expand strain abbreviations to the species names used for the model compartments
STRAIN_NAMES = {
    'AMUC': 'Akkermansia_muciniphila',
    'BINF': 'Bifidobacterium_longum',
    'CBEI': 'Clostridium_beijerinckii',
    'CBUT': 'Clostridium_butyricum',
    'EHAL': 'Anaerobutyricum_hallii',
}
for abbreviation, species in STRAIN_NAMES.items():
    qPCR_filter['Strain'] = qPCR_filter['Strain'].str.replace(abbreviation, species, regex=False)

# Inverse cycle threshold, so that a lower Ct (earlier amplification) gives a higher value
qPCR_filter['Ct_inv'] = 1 / qPCR_filter['Ct']

qPCR_filter = qPCR_filter.set_index(['Subject', 'Strain'])

## Isolate the Treatment Group Data
Binarizing 1/Ct values around strain-specific thresholds (the mean week-0 1/Ct of each strain) allows comparison with the binarized predicted growth rates.

In [None]:
# Isolate treatment group
qPCR_wbf11 = qPCR_filter.loc[qPCR_filter['Treatment'] == 'WBF-011'].reset_index()

# Average multiple hits per sample 
qPCR_wbf11 = qPCR_wbf11.groupby(['Subject','Strain','Week']).mean(numeric_only = True).reset_index()

# Calculate strain-specific threshold for 1/Ct
strain_thresholds = qPCR_wbf11.loc[qPCR_wbf11['Week'] == '0'].groupby('Strain')['Ct_inv'].mean().to_dict()

# Map to qPCR data
qPCR_wbf11['threshold'] = qPCR_wbf11['Strain'].map(strain_thresholds)

# Binarize the qPCR data by the threshold 
qPCR_wbf11.loc[qPCR_wbf11['Ct_inv']> qPCR_wbf11['threshold'], 'qPCR_engraftment'] = 1
qPCR_wbf11['qPCR_engraftment'].fillna(0, inplace = True)
qPCR_wbf11

## Prepare MCMM Medium
Load the standard European medium as the assumed background diet, then add the inulin supplement given in the trial.

In [None]:
# Load Medium 
eu_medium_wbf11 = micom.qiime_formats.load_qiime_medium('../european_medium.qza')

# Add inulin supplement, as in trial
eu_medium_wbf11.loc['EX_inulin_m', 'flux'] = 4

## Grow MCMMs
Grow the MCMMs with a high tradeoff value (0.99).

In [None]:
# Manifest describing the community models to simulate
manifest = pd.read_csv('../WBF011_models/manifest.csv')
# NOTE(review): the manifest is read from ../WBF011_models/ while models are loaded from
# ../WBF011_treated/ — confirm the manifest's file entries resolve against that folder.

# MICOM growth simulation settings (tradeoff = MICOM cooperative-tradeoff parameter)
grow_options = dict(
    model_folder='../WBF011_treated/',
    medium=eu_medium_wbf11,
    tradeoff=0.99,
    strategy='none',
    threads=10,
)

growth_agora1 = micom.workflows.grow(manifest, **grow_options)

## Collect Predicted Growth Rates
Isolate the predictions for the five probiotic species.

In [None]:
# Gather growth rates
rates_probiotic = growth_agora1.growth_rates

# Isolate Probiotics
rates_probiotic = rates_probiotic[rates_probiotic.index.isin(['Akkermansia_muciniphila',
                                                              'Clostridium_beijerinckii',
                                                              'Clostridium_butyricum',
                                                              'Bifidobacterium_longum',
                                                              'Anaerobutyricum_hallii'])].reset_index()

# Map subject IDs
rates_probiotic['subject_id'] = rates_probiotic['sample_id'].map(subject_dict)

# Set index to subject and compartment 
rates_probiotic.set_index(['subject_id', 'compartments'], inplace = True)
rates_probiotic

## Merge Data and Compare
Binarize predicted growth rates at a threshold of $10^{-3}$: rates $\geq 10^{-3}$ count as predicted engraftment.

In [None]:
# Log scale growth rates
rates_probiotic['growth_rate_scaled'] = np.log10(rates_probiotic[['growth_rate']])

# Map treatment group from metadata 
rates_probiotic['treatment'] = rates_probiotic['sample_id'].map(treatment_dict)

rates_probiotic.reset_index(inplace = True)

# Threshold log-scaled growth rates
rates_probiotic.loc[rates_probiotic['growth_rate_scaled']>=-3, 'MCMM_engraftment'] = 1
rates_probiotic['MCMM_engraftment'].fillna(0.0, inplace = True)

## Create Agreement Heatmap
Find the agreement between MCMM predictions and qPCR measurements at week 12.

In [None]:
# Create qPCR Engraftment Pivot Table
df1 = pd.pivot_table(qPCR_wbf11[qPCR_wbf11['Week'] == '12'], 
                    index='Strain',
                    columns='Subject',
                    values='qPCR_engraftment').fillna(0.0)

# Create MCMM Engraftment Pivot Table
df2 = pd.pivot_table(rates_probiotic, 
                    index='compartments',
                    columns='subject_id',
                    values='MCMM_engraftment').fillna(0.0)

# Ensure both dataframes are numeric (convert from strings if necessary)
df1 = df1.astype(float)
df2 = df2.astype(float)

# Calculate agreement matrix
agreement_matrix = (df1 == df2).astype(int)

agreement_matrix.index = agreement_matrix.index.str.replace('_', ' ')
# Plot heatmap of agreement
plt.figure(figsize=(18, 6))
sns.set(font_scale=1.5) 
sns.heatmap(agreement_matrix, cmap='coolwarm_r', annot=True, cbar=False, square=True)
plt.title('Agreement Between MCMM and qPCR')
plt.savefig('../figures/agreement_heatmap_qPCR.svg', dpi=300, bbox_inches='tight')  # Save as PNG with 300 DPI

## Calculate Cohen's Kappa
Measure the chance-corrected agreement between MCMM predictions and qPCR measurements.

In [None]:
# Flatten predictions and measures
ratings1 = df1.values.flatten()
ratings2 = df2.values.flatten()

# Compute Cohen's kappa
kappa = sklearn.metrics.cohen_kappa_score(ratings1, ratings2)

print(f"Cohen's kappa: {kappa:.4f}")

## Calculate % Agreement
Calculate agreement as a percentage

In [None]:
# Share of strain x subject cells where the MCMM prediction matches qPCR
agree_ones = agreement_matrix.sum().sum()
total = agreement_matrix.size

# Format with the percent spec directly: `.round(4)*100` could print artefacts like 57.14000000000001%
print(f"{agree_ones / total:.2%}")

## Plot Confusion Matrices
A second visualization of the results, aggregated across strains and per strain.

In [None]:
strains = df1.index

# Shared colour limits so all panels are visually comparable.
# NOTE(review): 21 is a hard-coded upper bound (presumably the number of subjects) — confirm.
vmin, vmax = 0, 21

BINARY_LABELS = [0.0, 1.0]            # No Growth, Growth
TICK_LABELS = ["No Growth", "Growth"]
# Panel titles that differ from the strain name with underscores replaced by spaces
STRAIN_TITLES = {'Bifidobacterium_longum': 'Bifidobacterium longum \n sp. infantis'}


def _plot_confusion(y_true, y_pred, ax, cmap, title, vmin, vmax):
    """Draw a 2x2 No Growth/Growth confusion matrix of ``y_pred`` vs ``y_true`` on ``ax``.

    ``labels=BINARY_LABELS`` always yields a 2x2 matrix. Without it, a single observed class
    gives a 1x1 matrix, and padding that would put an all-"Growth" count into the
    "No Growth" cell.

    Returns the confusion matrix (rows = qPCR ground truth, columns = model prediction).
    """
    cm = sklearn.metrics.confusion_matrix(y_true, y_pred, labels=BINARY_LABELS)
    sns.heatmap(cm, annot=True, fmt="d", cmap=cmap,
                xticklabels=TICK_LABELS, yticklabels=TICK_LABELS,
                ax=ax, vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=16)
    return cm


# Aggregate all results across strains
all_y_true = df1.values.flatten()
all_y_pred = df2.values.flatten()
# Non-binary values (e.g. averaged duplicates from pivot_table) would be silently ignored otherwise
assert np.isin(all_y_true, BINARY_LABELS).all() and np.isin(all_y_pred, BINARY_LABELS).all(), \
    "engraftment tables must contain only 0/1 values"

# One panel for the aggregate plus one per strain, three panels per row
n_panels = len(strains) + 1
n_rows = -(-n_panels // 3)  # ceiling division
fig, axes = plt.subplots(n_rows, 3, figsize=(18, 4 * n_rows))
axes = axes.flatten()

cm_agg = _plot_confusion(all_y_true, all_y_pred, axes[0], "Reds",
                         "Aggregate Across Strains", vmin, vmax)

# Titles come from the strain itself, so they cannot drift out of sync with panel order
for ax, strain in zip(axes[1:], strains):
    _plot_confusion(df1.loc[strain].values.flatten(),
                    df2.loc[strain].values.flatten(),
                    ax, "Blues",
                    STRAIN_TITLES.get(strain, strain.replace('_', ' ')),
                    vmin, vmax)

# Hide grid cells left over when the panel count is not a multiple of three
for ax in axes[n_panels:]:
    ax.set_visible(False)

fig.suptitle("Confusion Matrices of MICOM Predictions vs. qPCR", fontsize=24)
fig.text(-0.02, 0.5, "qPCR Ground Truth", va='center', rotation='vertical', fontsize=24)
fig.text(0.5, -0.02, "Model Prediction", ha='center', fontsize=24)

plt.tight_layout()
# bbox_inches='tight' keeps the axis labels placed outside the figure area in the saved SVG
plt.savefig('../figures/confusion_matricies.svg', dpi=300, bbox_inches='tight')
plt.show()


## Calculate Fisher's Exact Test for Aggregate Result


In [None]:
# Fisher's exact test on 2x2 confusion matrices (rows = qPCR, columns = MCMM prediction).
# labels=[0.0, 1.0] always yields a 2x2 table, even when a class is absent.
# The module-qualified sklearn.metrics.confusion_matrix is used because the bare name was never
# imported (NameError).
FISHER_LABELS = [0.0, 1.0]

# Anaerobutyricum hallii alone
cm = sklearn.metrics.confusion_matrix(
    df1.loc['Anaerobutyricum_hallii'].values.flatten(),
    df2.loc['Anaerobutyricum_hallii'].values.flatten(),
    labels=FISHER_LABELS,
)
print('Anaerobutyricum hallii:', scipy.stats.fisher_exact(cm))

# Aggregate across all strains (the result this section reports). Recomputed here, so the cell
# does not depend on the plotting cell having run.
cm_agg = sklearn.metrics.confusion_matrix(
    df1.values.flatten(), df2.values.flatten(), labels=FISHER_LABELS
)
scipy.stats.fisher_exact(cm_agg)