In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import mdsine2 as md2
from mdsine2.names import STRNAMES
import h5py

[DEBUG] Using default logger (stdout, stderr).


In [2]:
# Define useful lookup variables.
# `species`, `cols`, `pretty_names_dir` and `order` must describe the same set of taxa;
# the assertions at the end of this cell check that. D_longicatena is excluded from this run;
# its commented-out entries carry their commas so that uncommenting them is safe.
species = ['B_caccae',
           'B_cellulosilyticus_WH2',
           'B_ovatus',
           'B_thetaiotaomicron',
           'B_uniformis',
           'B_vulgatus',
           'C_aerofaciens',
           'C_scindens',
           'C_spiroforme',
           # 'D_longicatena',
           'P_distasonis',
           'R_obeum',
           ]

# Plot colours, assigned positionally to the taxa in `order` (not `species`).
cols = ['#016bff',
        '#b91f1c',
        '#308937',
        '#ff8137',
        '#6f4fc7',
        '#d4b300',
        '#893a2b',
        '#ff6c79',
        '#16c4ff',
        '#766f41',
        '#00c800',
        # '#af3261',
        ]

# Italic (mathtext) labels used in plot titles.
pretty_names_dir = {'B_caccae': '$B. caccae$',
                    'B_cellulosilyticus_WH2': '$B. cellulosilyticus$',
                    'B_ovatus': '$B. ovatus$',
                    'B_thetaiotaomicron': '$B. thetaiotaomicron$',
                    'B_uniformis': '$B. uniformis$',
                    'B_vulgatus': '$B. vulgatus$',
                    'C_aerofaciens': '$C. aerofaciens$',
                    'C_scindens': '$C. scindens$',
                    'C_spiroforme': '$C. spiroforme$',
                    # 'D_longicatena': '$D. longicatena$',
                    'P_distasonis': '$P. distasonis$',
                    'R_obeum': '$R. obeum$',
                    }

# Display order of the taxa; also fixes which colour each taxon gets.
order = ['B_cellulosilyticus_WH2',
         'B_caccae',
         'B_vulgatus',
         'B_thetaiotaomicron',
         'B_ovatus',
         'R_obeum',
         'B_uniformis',
         'P_distasonis',
         'C_scindens',
         'C_aerofaciens',
         'C_spiroforme',
         # 'D_longicatena',
         ]

# NOTE(review): presumably the subject IDs of the LF0 / HF0 cohorts; not used in the cells shown.
lf0 = ['1', '2', '3', '4', '5', '6', '7']
hf0 = ['8', '9', '10', '11', '12', '13', '14', '15']

assert len(cols) == len(order), 'every taxon in `order` needs exactly one colour'
assert sorted(order) == sorted(species), '`order` must be a permutation of `species`'
assert set(pretty_names_dir) == set(species), '`pretty_names_dir` must cover exactly `species`'

# Colour per taxon, following the display order.
taxa_color = dict(zip(order, cols))
taxa_color

{'B_cellulosilyticus_WH2': '#016bff',
 'B_caccae': '#b91f1c',
 'B_vulgatus': '#308937',
 'B_thetaiotaomicron': '#ff8137',
 'B_ovatus': '#6f4fc7',
 'R_obeum': '#d4b300',
 'B_uniformis': '#893a2b',
 'P_distasonis': '#ff6c79',
 'C_scindens': '#16c4ff',
 'C_aerofaciens': '#766f41',
 'C_spiroforme': '#00c800'}

In [3]:
def get_parameter_traces(dataset):
    """Load the post-burn-in MCMC traces of the main model parameters for one McNulty run.

    Parameters
    ----------
    dataset : str
        'LF0' or 'HF0'. Selects the inference directory ``mcnulty-{dataset}-seed{seed}``
        (seed 53 for LF0, seed 3 for HF0).

    Returns
    -------
    dict
        Keys 'growth', 'self_interaction', 'interactions', 'perturbation' and
        'process_variance'. Each value is the matching dataset from ``traces.hdf5``,
        read into memory with the first ``burnin`` samples dropped.

    Raises
    ------
    ValueError
        If `dataset` is not a supported run.
    """
    # Inference seed and HDF5 path of the perturbation trace for each supported run.
    runs = {'LF0': (53, 'HF/HS'), 'HF0': (3, 'LF/HPP')}
    if dataset not in runs:
        raise ValueError(f'Unknown dataset {dataset!r}; expected one of {sorted(runs)}')
    seed, perturbation_key = runs[dataset]
    run_dir = f'mcnulty-{dataset}-seed{seed}/inference'

    # The pickled MCMC object is only needed for its burn-in length.
    # NOTE(review): unpickling executes code from the file - only load trusted inference output.
    burnin = pd.read_pickle(f'{run_dir}/mcmc.pkl').burnin

    # Traces are looked up by name rather than by position in f.keys(): the positions differ
    # between runs, while the names are stable.
    hdf5_keys = {'growth': 'Growth parameter',
                 'self_interaction': 'Self interaction parameter',
                 'interactions': 'Interactions object',
                 'perturbation': perturbation_key,
                 'process_variance': 'Process Variance parameter'}

    traces = {}
    with h5py.File(f'{run_dir}/traces.hdf5', "r") as f:
        for name, key in hdf5_keys.items():
            # Echo which trace is being loaded (same log lines as before).
            print(f'Perturbation object: {key}' if name == 'perturbation' else key)
            traces[name] = f[key][()][burnin:]
    return traces

In [4]:
def get_mean(x):
    """Posterior mean across samples (axis 0); NaN entries are ignored."""
    sample_axis = 0
    return np.nanmean(x, axis=sample_axis)

def get_median(x):
    """Posterior median across samples (axis 0); NaN entries are ignored."""
    sample_axis = 0
    return np.nanmedian(x, axis=sample_axis)

def get_mode(x, bins=100):
    """Histogram estimate of the posterior mode across samples (axis 0), ignoring NaN.

    For each entry, the finite samples are binned into `bins` equal-width bins and the
    centre of the most populated bin is returned. An entry whose finite samples are all
    identical returns that value. An entry with no finite samples returns NaN. For 2-D
    (matrix) parameters the diagonal entries are set to NaN without being estimated.

    Parameters
    ----------
    x : array_like
        Samples stacked along axis 0; each sample may be a scalar, vector or matrix.
    bins : int, optional
        Number of histogram bins (default 100).

    Returns
    -------
    float or np.ndarray
        A scalar for scalar parameters, otherwise an array shaped like one sample.

    Raises
    ------
    ValueError
        If `x` is 0-dimensional (no sample axis).
    """
    samples = np.asarray(x, dtype=float)
    if samples.ndim == 0:
        raise ValueError('get_mode expects samples stacked along axis 0')

    def _mode_1d(values):
        # Mode of one entry's samples: centre of the tallest histogram bin.
        finite = values[np.isfinite(values)]
        if finite.size == 0:
            return np.nan
        if finite.min() == finite.max():
            # np.histogram pads a zero-width range by +/-0.5, which would put the
            # returned bin centre off the (only) observed value.
            return finite[0]
        counts, edges = np.histogram(finite, bins=bins)
        peak = np.argmax(counts)
        return (edges[peak] + edges[peak + 1]) / 2

    if samples.ndim == 1:
        return _mode_1d(samples)

    param_shape = samples.shape[1:]
    mode = np.full(param_shape, np.nan)
    for idx in np.ndindex(param_shape):
        if len(idx) == 2 and idx[0] == idx[1]:
            continue  # diagonal of a matrix parameter is not estimated
        mode[idx] = _mode_1d(samples[(slice(None),) + idx])
    return mode
    
def get_std(x):
    """Posterior standard deviation (ddof=0) across samples (axis 0); NaN entries are ignored."""
    sample_axis = 0
    return np.nanstd(x, axis=sample_axis)
    
def get_ci(x, cl=95, method='median_unbiased'):
    """Central credible interval across samples (axis 0), ignoring NaN.

    Parameters
    ----------
    x : array_like
        Samples stacked along axis 0.
    cl : float, optional
        Credible level in percent (default 95).
    method : str, optional
        Percentile estimation method forwarded to ``np.nanpercentile``.

    Returns
    -------
    np.ndarray
        Stacked ``[lower, upper]`` bounds along a new leading axis of length 2.
    """
    tail = (100 - cl) / 2  # percent of probability mass excluded on each side
    bounds = [tail, 100 - tail]
    return np.nanpercentile(x, bounds, axis=0, method=method)

In [5]:
def get_bayes_factors(traces, param):
    """Bayes factors for the presence of each interaction or perturbation.

    An entry counts as "on" in a sample when its trace value is not NaN. The Bayes factor
    is the posterior odds of "on" divided by the prior odds ``a / b`` of the Beta(a, b)
    indicator prior:

        BF = (n_on / n_off) / (a / b) = n_on * b / (n_off * a)

    Prior hyper-parameters follow the MDSINE2 defaults linked below:
    interactions use a = 0.5 and b = n_taxa * (n_taxa - 1); perturbations use a = b = 0.5.

    Parameters
    ----------
    traces : dict
        Must contain `param`. The value is an array with MCMC samples along axis 0.
        For 'interactions', the last axis is assumed to index taxa (n_taxa is read from it).
    param : {'interactions', 'perturbation'}
        Which indicator to evaluate.

    Returns
    -------
    np.ndarray
        One Bayes factor per entry of a single sample. It is ``inf`` where the entry is
        never NaN and 0 where it is always NaN.

    Raises
    ------
    ValueError
        If `param` is not supported.
    """
    if param == 'interactions':
        # https://htmlpreview.github.io/?https://raw.githubusercontent.com/gerberlab/MDSINE2/master/docs/mdsine2/posterior.html#mdsine2.posterior.ClusterInteractionIndicatorProbability.initialize
        trace = traces['interactions']
        n_taxa = trace.shape[-1]  # previously hard-coded as 11
        prior_a = 0.5
        prior_b = n_taxa * (n_taxa - 1)
    elif param == 'perturbation':
        # https://htmlpreview.github.io/?https://raw.githubusercontent.com/gerberlab/MDSINE2/master/docs/mdsine2/posterior.html#mdsine2.posterior.PerturbationProbabilities.initialize
        trace = traces['perturbation']
        prior_a = 0.5
        prior_b = 0.5
    else:
        raise ValueError('Parameter not supported for calculating Bayes Factors')

    is_on = ~np.isnan(trace)
    n_on = np.sum(is_on, axis=0)
    n_off = np.sum(~is_on, axis=0)
    # Entries that are never off give n_off == 0; the resulting inf is the intended BF.
    with np.errstate(divide='ignore', invalid='ignore'):
        return (n_on * prior_b) / (n_off * prior_a)

In [6]:
def matrix_to_table(matrix, stat):
    """Convert a species x species matrix into a long table with one row per pair.

    Row ``i`` of `matrix` is labelled 'Target' and column ``j`` is labelled 'Source',
    both using the global `species` list. The result is indexed by (Source, Target)
    and has a single value column named `stat`.
    """
    long_table = (
        pd.DataFrame(matrix, index=species, columns=species)
        .rename_axis('Target')
        .reset_index()
        .melt(id_vars=['Target'], var_name='Source', value_name=stat)
        .set_index(['Source', 'Target'])
    )
    return long_table

def _summary_columns(trace):
    """Posterior summaries of `trace` across samples, keyed by output column name."""
    ci = get_ci(trace)
    return {'mean': get_mean(trace),
            'median': get_median(trace),
            'mode': get_mode(trace),
            'std': get_std(trace),
            'ci_l': ci[0],
            'ci_u': ci[1]}


def save_params_tables(dataset, traces):
    """Write a TSV table of posterior summaries for each parameter trace.

    Tables go to ``mcnulty-{dataset}-seed{seed}/inference/posteriors/{param}.tsv``; the
    directory is created if missing. Every table has the columns mean, median, mode, std,
    ci_l and ci_u. Perturbation and interaction tables also have a bayes_factors column.
    Interactions are written in long format indexed by (Source, Target), and rows with any
    NaN statistic are dropped.

    Parameters
    ----------
    dataset : str
        'LF0' (seed 53) or 'HF0' (seed 3).
    traces : dict
        Output of ``get_parameter_traces``. Keys other than the five known ones are skipped.

    Raises
    ------
    ValueError
        If `dataset` is not a supported run.
    """
    seeds = {'LF0': 53, 'HF0': 3}
    if dataset not in seeds:
        raise ValueError(f'Unknown dataset {dataset!r}; expected one of {sorted(seeds)}')
    output_path = Path(f'mcnulty-{dataset}-seed{seeds[dataset]}/inference/posteriors')
    # to_csv does not create directories, and this is the first writer into `posteriors`.
    output_path.mkdir(parents=True, exist_ok=True)

    for param, trace in traces.items():
        if param in ('growth', 'self_interaction', 'perturbation'):
            columns = _summary_columns(trace)
            if param == 'perturbation':
                columns['bayes_factors'] = get_bayes_factors(traces, param)
            table = pd.DataFrame(columns, index=species)
        elif param == 'interactions':
            columns = _summary_columns(trace)
            columns['bayes_factors'] = get_bayes_factors(traces, param)
            table = pd.concat([matrix_to_table(matrix, stat=stat) for stat, matrix in columns.items()],
                              axis=1).dropna()
        elif param == 'process_variance':
            # One row per process-variance value: a single row when the trace is scalar per
            # sample, instead of broadcasting the scalar into a fixed number of duplicate rows.
            columns = {stat: np.atleast_1d(value) for stat, value in _summary_columns(trace).items()}
            table = pd.DataFrame(columns, index=list(range(columns['mean'].size)))
        else:
            continue  # no summary table for other traces
        table.to_csv(output_path / f'{param}.tsv', sep='\t')

def save_interaction_matrices(dataset, traces):
    """Write each posterior summary of the interaction traces as a species x species TSV.

    Files are written to
    ``mcnulty-{dataset}-seed{seed}/inference/posteriors/interactions_matrices`` (created if
    missing) as ``{stat}_matrix.tsv`` for stat in mean, median, mode, std, ci_l, ci_u and
    bayes_factors. Rows and columns are both labelled with the global `species` list.

    Parameters
    ----------
    dataset : str
        'LF0' (seed 53) or 'HF0' (seed 3).
    traces : dict
        Output of ``get_parameter_traces``; only 'interactions' is used.

    Raises
    ------
    ValueError
        If `dataset` is not a supported run.
    """
    seeds = {'LF0': 53, 'HF0': 3}
    if dataset not in seeds:
        raise ValueError(f'Unknown dataset {dataset!r}; expected one of {sorted(seeds)}')
    output_path = Path(f'mcnulty-{dataset}-seed{seeds[dataset]}/inference/posteriors/interactions_matrices')
    output_path.mkdir(exist_ok=True, parents=True)

    interactions = traces['interactions']
    ci_l, ci_u = get_ci(interactions)  # one percentile pass gives both bounds
    matrices = {
        'mean': get_mean(interactions),
        'median': get_median(interactions),
        'mode': get_mode(interactions),
        'std': get_std(interactions),
        'ci_l': ci_l,
        'ci_u': ci_u,
        'bayes_factors': get_bayes_factors(traces, 'interactions'),
    }
    for stat, matrix in matrices.items():
        pd.DataFrame(matrix, index=species, columns=species) \
            .to_csv(output_path / f'{stat}_matrix.tsv', sep='\t')


In [7]:
def _plot_posterior(samples, param, title, out_file):
    """Save a histogram + trace plot of one scalar parameter's samples, with summary lines."""
    samples = np.asarray(samples, dtype=float)
    # hist cannot autodetect a bin range from NaN-only input; NaNs are not binned anyway.
    finite = samples[np.isfinite(samples)]
    mean, median, mode, std = get_mean(samples), get_median(samples), get_mode(samples), get_std(samples)
    ci_l, ci_u = get_ci(samples)

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    axs[0].hist(finite, bins=100, alpha=0.9, color='gray')
    axs[1].plot(samples, alpha=0.9, marker='x', markersize=1, linestyle='None', color='gray')

    # (value, colour, linestyle, linewidth, legend label); '_nolegend_' hides paired lines.
    reference_lines = [
        (mean, 'mediumblue', '--', 2, 'Mean'),
        (median, 'green', '--', 2, 'Median'),
        (mode, 'firebrick', '--', 2, 'Mode'),
        (ci_l, 'gold', ':', 1.5, '95% CI'),
        (ci_u, 'gold', ':', 1.5, '_nolegend_'),
        (mean + std, 'tomato', ':', 1.5, 'Std Dev'),
        (mean - std, 'tomato', ':', 1.5, '_nolegend_'),
    ]
    for value, color, style, width, label in reference_lines:
        axs[0].axvline(x=value, color=color, linestyle=style, linewidth=width, label=label)
        axs[1].axhline(y=value, color=color, linestyle=style, linewidth=width, label=label)

    axs[0].set_xlabel(f'{param} value')
    axs[0].set_ylabel('Frequency')
    axs[1].set_xlabel('Step')
    axs[1].set_ylabel(f'{param} value')
    fig.suptitle(title, y=0.94)
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=9, bbox_to_anchor=(0.5, -0.05))
    fig.savefig(out_file, bbox_inches='tight')
    plt.close(fig)  # close this figure explicitly; many are created in a loop


def save_traces_plots(dataset, traces):
    """Save posterior diagnostic plots (histogram + trace) for every parameter entry.

    Each PDF shows the samples with mean, median, mode, 95% CI and mean +/- std lines. Files
    are written to ``mcnulty-{dataset}-seed{seed}/inference/posteriors/plots/{param}/``:
    ``{species}_{param}.pdf`` for growth, self_interaction and perturbation;
    ``{species_i}_{species_j}_{param}.pdf`` for each off-diagonal interaction entry
    ``[:, i, j]``; and ``{param}.pdf`` for process_variance.

    Parameters
    ----------
    dataset : str
        'LF0' (seed 53) or 'HF0' (seed 3).
    traces : dict
        Output of ``get_parameter_traces``. Keys other than the five known ones are skipped.

    Raises
    ------
    ValueError
        If `dataset` is not a supported run.
    """
    seeds = {'LF0': 53, 'HF0': 3}
    if dataset not in seeds:
        raise ValueError(f'Unknown dataset {dataset!r}; expected one of {sorted(seeds)}')
    plots_root = Path(f'mcnulty-{dataset}-seed{seeds[dataset]}/inference/posteriors/plots')

    for param, trace in traces.items():
        output_path = plots_root / param
        if param in ('growth', 'self_interaction', 'perturbation'):
            output_path.mkdir(exist_ok=True, parents=True)
            for sp, name in enumerate(species):
                _plot_posterior(trace[:, sp], param,
                                f'{pretty_names_dir[name]} {param} posterior',
                                output_path / f'{name}_{param}.pdf')
        elif param == 'interactions':
            output_path.mkdir(exist_ok=True, parents=True)
            for spi, name_i in enumerate(species):
                for spj, name_j in enumerate(species):
                    if spi == spj:
                        continue  # diagonal is not an inter-species entry
                    _plot_posterior(trace[:, spi, spj], param,
                                    f'{pretty_names_dir[name_i]}-{pretty_names_dir[name_j]} {param} posterior',
                                    output_path / f'{name_i}_{name_j}_{param}.pdf')
        elif param == 'process_variance':
            output_path.mkdir(exist_ok=True, parents=True)
            _plot_posterior(trace, param, f'{param} posterior', output_path / f'{param}.pdf')

In [8]:
# For each diet group: load the post-burn-in traces, then write the summary tables,
# the interaction matrices and the per-parameter diagnostic plots (in that order).
for dataset in ('LF0', 'HF0'):
    traces = get_parameter_traces(dataset)
    for export in (save_params_tables, save_interaction_matrices, save_traces_plots):
        export(dataset, traces)

Growth parameter
Self interaction parameter
Interactions object
Perturbation object: HF/HS
Process Variance parameter
Growth parameter
Self interaction parameter
Interactions object
Perturbation object: LF/HPP
Process Variance parameter
