[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gerberlab/MDSINE2_Paper/blob/master/google_colab/tutorial3_reproduce_figures.ipynb)
# Figures for MDSINE2

The code below reproduces the figures in the paper "***Intrinsic instability of the dysbiotic microbiome revealed through dynamical systems inference at scale***". The cells below save the output figures in `MDSINE2_Paper/analysis/output/gibson/plots/`.


## Dependency Initializations

### Install Necessary libraries

In [1]:
!git clone https://github.com/gerberlab/MDSINE2
!pip install MDSINE2/.


Cloning into 'MDSINE2'...
remote: Enumerating objects: 3021, done.
remote: Counting objects: 100% (930/930), done.
remote: Compressing objects: 100% (650/650), done.
remote: Total 3021 (delta 623), reused 531 (delta 271), pack-reused 2091
Receiving objects: 100% (3021/3021), 78.20 MiB | 9.60 MiB/s, done.
Resolving deltas: 100% (1913/1913), done.
Processing ./MDSINE2
  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.
Collecting matplotlib>=3.3.1
  Downloading matplotlib-3.5.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (11.2 MB)
Collecting h5py==2.9.0
  Down

In [2]:
!pip install wget
!pip install zenodo-get
!pip install scikit-bio
!pip install tables==3.6.1 --force-reinstall

Collecting wget
  Downloading wget-3.2.zip (10 kB)
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... done
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9672 sha256=0ffe17f853c4cda51bc1d32fcea632d02d6fe979955aa0276ed694f4223ba1cf
  Stored in directory: /root/.cache/pip/wheels/a1/b6/7c/0e63e34eb06634181c63adacca38b79ff8f35c37e3c13e3c02
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2
Collecting zenodo-get
  Downloading zenodo_get-1.3.4-py2.py3-none-any.whl (17 kB)
Installing collected packages: zenodo-get
Successfully installed zenodo-get-1.3.4
Collecting scikit-bio
  Downloading scikit-bio-0.5.6.tar.gz (8.4 MB)
Collecting lockfile>=0.10.2
  Downloading lockfile-0.12.2-py2.py3-none-any.whl (13 kB)
Collecting hdmedians>=0.13
  Downloading hdmedians-0.14.2.tar.gz (7.6 kB)
  Installing build dependencies ...

### Zenodo file retrieval

The files needed to make the figures are hosted on Zenodo.

In [None]:
!wget https://zenodo.org/record/5781848/files/forward_sims.tgz
!wget https://zenodo.org/record/5781848/files/mixed_prior_fixed.tgz
!wget https://zenodo.org/record/5781848/files/mixed_prior_unfixed.tgz
!wget https://zenodo.org/record/5781848/files/other_files.tgz

--2021-12-20 20:58:47--  https://zenodo.org/record/5781848/files/forward_sims.tgz
Resolving zenodo.org (zenodo.org)... 137.138.76.77
Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5183489132 (4.8G) [application/octet-stream]
Saving to: ‘forward_sims.tgz’


2021-12-20 21:05:49 (11.7 MB/s) - ‘forward_sims.tgz’ saved [5183489132/5183489132]

--2021-12-20 21:05:49--  https://zenodo.org/record/5781848/files/mixed_prior_fixed.tgz
Resolving zenodo.org (zenodo.org)... 137.138.76.77
Connecting to zenodo.org (zenodo.org)|137.138.76.77|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6286873326 (5.9G) [application/octet-stream]
Saving to: ‘mixed_prior_fixed.tgz’

mixed_prior_fixed.t  27%[====>               ]   1.63G  22.1MB/s    eta 3m 11s 

In [None]:
# Fallback: uncomment the line below if wget does not work
#!zenodo_get -d 10.5281/zenodo.5781391

In [None]:
#@title unzip the .tgz files
!tar -xzvf "mixed_prior_fixed.tgz" 
!tar -xzvf "mixed_prior_unfixed.tgz" 
!tar -xzvf "forward_sims.tgz" 
!tar -xzvf "other_files.tgz" 
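The archives are several gigabytes each, so a download or extraction can fail partway through. As a minimal sketch (not part of the original notebook), the check below lists which of the `/content/...` directories referenced by the later figure cells are missing. The helper name `missing_paths` and the list `EXPECTED_DIRS` are introduced here for illustration; the directory names themselves are taken from the commands further down.

```python
from pathlib import Path

# Directories that later cells in this notebook read from.
EXPECTED_DIRS = [
    "/content/forward_sims",
    "/content/clv_results",
    "/content/mixed_prior_fixed",
    "/content/mixed_prior_unfixed",
    "/content/coarsening",
    "/content/fwsim_random_pert",
]


def missing_paths(paths):
    """Return the subset of `paths` that does not exist on disk."""
    return [str(p) for p in map(Path, paths) if not p.exists()]


missing = missing_paths(EXPECTED_DIRS)
if missing:
    print("Missing directories (re-run the download/extract cells):")
    for path in missing:
        print("  ", path)
else:
    print("All expected directories are present.")
```

If any directory is reported missing, re-running the corresponding `wget` and `tar` commands is usually enough.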

In [None]:
!git clone --branch sawal_final_changes https://github.com/gerberlab/MDSINE2_Paper

In [None]:
%cd /content/MDSINE2_Paper/analysis/
%pwd
%matplotlib notebook
import matplotlib.pyplot as plt 
from pathlib import Path
import os

saveloc = "output/gibson/plots"
os.makedirs(saveloc, exist_ok=True)

## Figure 2

The figure is made in two steps. First, we plot every panel except the heatmap showing the DESeq results; the heatmap is generated in the next section. The animations in panel A are not produced by this script; they are added manually using Adobe Illustrator.

In [None]:
!python gibson_inference/figures/figure2.py \
       -file1 "gibson_inference/figures/preprocessed_all/gibson_healthy_agg_taxa.pkl" \
       -file2 "gibson_inference/figures/preprocessed_all/gibson_uc_agg_taxa.pkl" \
       -file3 "gibson_inference/figures/preprocessed_all/gibson_inoculum_agg_taxa.pkl" \
       -o_loc "output/gibson/plots"

## Figure 2 Heatmap

The output files are called `mat_order_high_ss` and `mat_order_low_ss`. They show the DESeq results for the orders/families whose relative abundances are >0.5% and <0.5%, respectively. The heatmaps are compiled into the paper version manually using Adobe Illustrator.


In [None]:
!python gibson_inference/figures/deseq_heatmap_ss.py \
    -loc "gibson_inference/figures/figure2_files" \
    -abund "high" \
    -txt "abundant_species" \
    -taxo "order" \
    -o "mat_order_high_ss" \
    -o_loc "output/gibson/plots"


!python gibson_inference/figures/deseq_heatmap_ss.py \
    -loc "gibson_inference/figures/figure2_files" \
    -abund "low" \
    -txt "abundant_species" \
    -taxo "order" \
    -o "mat_order_low_ss" \
    -o_loc "output/gibson/plots"


## Figure 3 

In [None]:
!python gibson_inference/figures/figure3.py \
    --mdsine_path "/content/forward_sims/" \
    --clv_elas_path "/content/clv_results/results_rel_elastic/" \
    --clv_ridge_path "/content/clv_results/results_rel_ridge/" \
    --glv_elas_path "/content/clv_results/results_abs_elastic/" \
    --glv_ridge_path "/content/clv_results/results_abs_ridge/forward_sims_abs_ridge/" \
    --output_path "output/gibson/plots/"

## Figure 4

This script renders all subplots except the network diagrams for the healthy and dysbiotic cohorts. The networks are created using Cytoscape and added to the figure manually using Adobe Illustrator.

In [None]:
!python gibson_inference/figures/figure4.py \
    --chain_healthy "/content/mixed_prior_fixed/healthy-seed0-mixed/mcmc.pkl" \
    --chain_uc "/content/mixed_prior_fixed/uc-seed0-mixed/mcmc.pkl" \
    --tree_fname 'files/phylogenetic_placement_OTUs/phylogenetic_tree_only_query.nhx' \
    --study_healthy "output/gibson/preprocessed/gibson_healthy_agg_taxa.pkl" \
    --study_uc "output/gibson/preprocessed/gibson_uc_agg_taxa.pkl" \
    --study_inoc "output/gibson/preprocessed/gibson_inoculum_agg_taxa.pkl" \
    --detected_study_healthy "gibson_inference/figures/preprocessed_all/gibson_healthy_agg_taxa_filtered3.pkl" \
    --detected_study_uc "gibson_inference/figures/preprocessed_all/gibson_uc_agg_taxa_filtered3.pkl" \
    --output_loc "output/gibson/plots"

## Figure 5

This script renders the result of the Phylogenetic Neighborhood Analysis (Figure 5B).
The cartoon (Figure 5A) is created separately.

In [None]:
!python gibson_inference/figures/figure5.py \
    -file1 "/content/coarsening/distance.csv" \
    -file2 "/content/coarsening/arithmetic_mean_data.csv" \
    -file3 "/content/coarsening/arithmetic_mean_null_all.csv" \
    -o_loc "output/gibson/plots"

## Figure 6

Custom code needed to render Figure 6 from pre-computed data.

In [None]:
#@title
import mdsine2 as md2
from mdsine2.names import STRNAMES
import numpy as np
import scipy
import scipy.stats
import pandas as pd
from pathlib import Path
from tqdm.notebook import tqdm
import seaborn as sns


# COLORS
_default_colors = sns.color_palette()
_default_healthy_color = _default_colors[0]
_default_uc_color = _default_colors[1]


def stat_annotate(x1, x2, y, h, color, ax, lw=1.0, desc='*'):
    ax.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=lw, c=color)
    ax.text((x1+x2)*.5, y+h, desc, ha='center', va='bottom', color=color)


class PerturbationSimFigure():
    """Render figures for perturbation simulations. (Figures 6A,6B)"""

    def __init__(self, data_dir: Path, healthy_color=_default_healthy_color, uc_color=_default_uc_color):
        self.healthy_color = healthy_color
        self.uc_color = uc_color

        # DATA DIR
        self.data_dir = data_dir

        # ============ Preprocessing
        print("Loading dataframes from disk.")
        healthy_random_pert_metadata_df = pd.read_hdf(
            data_dir / 'healthy_metadata.h5', key="df", mode="r")
        healthy_random_pert_fwsim_df = pd.read_hdf(
            data_dir / 'healthy_fwsim.h5', key="df", mode="r"
        )

        uc_random_pert_metadata_df = pd.read_hdf(
            data_dir / 'uc_metadata.h5', key="df", mode="r"
        )
        uc_random_pert_fwsim_df = pd.read_hdf(
            data_dir / 'uc_fwsim.h5', key="df", mode="r"
        )

        print("Merging dataframes.")
        healthy_random_pert_merged_df = self.posthoc_random_pert_helper(healthy_random_pert_fwsim_df, healthy_random_pert_metadata_df)
        uc_random_pert_merged_df = self.posthoc_random_pert_helper(uc_random_pert_fwsim_df, uc_random_pert_metadata_df)

        print("Computing difference levels.")
        self.random_pert_concat_df = self.random_pert_diff_figure_df(healthy_random_pert_merged_df, uc_random_pert_merged_df)

        print("Computing diversities.")
        self.random_pert_diversity_df, self.healthy_random_pert_baseline_diversity, self.uc_random_pert_baseline_diversity = self.precompute_diversities(
            healthy_random_pert_fwsim_df, uc_random_pert_fwsim_df
        )

        print("Finished initialization.")

    @staticmethod
    def posthoc_random_pert_helper(fwsim: pd.DataFrame, metadata: pd.DataFrame):
        merged_df = fwsim.loc[(fwsim["PerturbedFrac"] != 0.0), :].merge(
            fwsim.loc[(fwsim["PerturbedFrac"] == 0.0), ["OTU", "SteadyState", "Died", "SampleIdx"]],
            left_on=["OTU", "SampleIdx"],
            right_on=["OTU", "SampleIdx"],
            how="inner",
            suffixes=["", "Base"]
        ).merge(
            metadata.loc[
                :,
                ["OTU", "PerturbedFrac", "Perturbation", "Trial", "IsPerturbed"]
            ],
            left_on=["OTU", "PerturbedFrac", "Perturbation", "Trial"],
            right_on=["OTU", "PerturbedFrac", "Perturbation", "Trial"]
        ).set_index(["OTU", "PerturbedFrac", "Perturbation", "Trial", "SampleIdx"])

        merged_df["SteadyStateDiff"] = np.abs(np.log10(merged_df["SteadyState"] + 1e5) - np.log10(merged_df["SteadyStateBase"] + 1e5))

        return merged_df

    @staticmethod
    def random_pert_diff_figure_df(healthy_merged_df: pd.DataFrame, uc_merged_df: pd.DataFrame, perturbation=-2.0):
        #   OTU 	PerturbedFrac 	Perturbation 	Trial 	SampleIdx

        healthy_agg_df = healthy_merged_df.loc[
            (slice(None), slice(None), perturbation, slice(None), slice(None)),
            ["SteadyStateDiff"]
        ].groupby(
            level=[1, 2, 3, 4]  # "PerturbedFrac", "Perturbation", "Trial", "SampleIdx"
        ).mean().groupby(
            level=[0, 1, 2]  # "PerturbedFrac", "Perturbation", "Trial"
        ).mean()

        uc_agg_df = uc_merged_df.loc[
            (slice(None), slice(None), perturbation, slice(None), slice(None)),
            ["SteadyStateDiff"]
        ].groupby(
            level=[1, 2, 3, 4]  # "PerturbedFrac", "Perturbation", "Trial", "SampleIdx"
        ).mean().groupby(
            level=[0, 1, 2]  # "PerturbedFrac", "Perturbation", "Trial"
        ).mean()

        healthy_agg_df["Dataset"] = "Healthy"
        uc_agg_df["Dataset"] = "UC"
        concat_df = pd.concat([
            healthy_agg_df.reset_index(),
            uc_agg_df.reset_index()
        ])

        concat_df["key"] = r'$\alpha$:' + concat_df["PerturbedFrac"].astype(str) + "\nPert:" + concat_df["Perturbation"].astype(str)
        return concat_df

    @staticmethod
    def precompute_diversities(healthy_fwsim_df, uc_fwsim_df, perturbation=-2.0):
        def agg(x):
            p = x["SteadyState"].to_numpy()
            u = np.ones(p.shape[0])
            return scipy.stats.entropy(p) / scipy.stats.entropy(u)

        # ======= Altered
        healthy_diversity = healthy_fwsim_df.loc[
            (healthy_fwsim_df["PerturbedFrac"] != 0.0) & (healthy_fwsim_df["Perturbation"] == perturbation),
            ["Trial", "SampleIdx", "SteadyState", "PerturbedFrac"]
        ].groupby(["PerturbedFrac", "Trial", "SampleIdx"]).apply(
            agg
        ).groupby(level=[0, 1]).mean()
        healthy_diversity = pd.DataFrame({"Diversity": healthy_diversity})
        healthy_diversity["Dataset"] = "Healthy"

        uc_diversity = uc_fwsim_df.loc[
            (uc_fwsim_df["PerturbedFrac"] != 0.0) & (uc_fwsim_df["Perturbation"] == perturbation),
            ["Trial", "SampleIdx", "SteadyState", "PerturbedFrac"]
        ].groupby(["PerturbedFrac", "Trial", "SampleIdx"]).apply(
            agg
        ).groupby(level=[0, 1]).mean()
        uc_diversity = pd.DataFrame({"Diversity": uc_diversity})
        uc_diversity["Dataset"] = "UC"

        diversities = pd.concat([healthy_diversity.reset_index(), uc_diversity.reset_index()]).reset_index()

        # ======= Baselines
        healthy_baseline_diversity = healthy_fwsim_df.loc[
            healthy_fwsim_df["PerturbedFrac"] == 0.0,
            ["SampleIdx", "SteadyState"]
        ].groupby("SampleIdx").apply(
            agg
        ).mean()

        uc_baseline_diversity = uc_fwsim_df.loc[
            uc_fwsim_df["PerturbedFrac"] == 0.0,
            ["SampleIdx", "SteadyState"]
        ].groupby("SampleIdx").apply(
            agg
        ).mean()

        return diversities, healthy_baseline_diversity, uc_baseline_diversity

    def plot_deviations(
        self,
        ax,
        ymin=0.0, ymax=0.35
    ):
        df = self.random_pert_concat_df
        sns.swarmplot(x="PerturbedFrac",
                      y="SteadyStateDiff",
                      hue="Dataset",
                      ax=ax,
                      data=df,
                      palette={"Healthy": self.healthy_color, "UC": self.uc_color},
                      size=3,
                      dodge=True)

        sns.boxplot(
            data=df,
            x="PerturbedFrac", y="SteadyStateDiff",
            hue="Dataset",
            whis=[2.5, 97.5],
            ax=ax,
            showfliers=False,
            palette={"Healthy": self.healthy_color, "UC": self.uc_color},
            boxprops=dict(alpha=.4)
        )

        ax.set_ylim([ymin, ymax])

        # Axis labels
        ax.set_xlabel("Fraction of OTUs perturbed")
        ax.set_ylabel("Difference from Baseline Steady State")

        # =========== P-values + Benjamini-Hochberg correction
        df_healthy = df.loc[df["Dataset"] == "Healthy", ["PerturbedFrac", "SteadyStateDiff"]]
        df_uc = df.loc[df["Dataset"] == "UC", ["PerturbedFrac", "SteadyStateDiff"]]
        df_merged = df_healthy.merge(df_uc, on="PerturbedFrac", how="inner", suffixes=["Healthy", "UC"])

        # Compute statistic (raw p-values)
        def fn(tbl):
            u = scipy.stats.mannwhitneyu(tbl["SteadyStateDiffHealthy"], tbl["SteadyStateDiffUC"], alternative="less")
            return u

        pvalues = df_merged.groupby("PerturbedFrac").apply(fn)
        pvalues_df = pd.DataFrame({"pvalue": pvalues.sort_values()})

        # Apply BH correction
        p_adjusted = []
        p_adj_prev = 0.0
        for i, (index, row) in enumerate(pvalues_df.iterrows()):
            p_adj = row["pvalue"].pvalue * pvalues_df.shape[0] / (i+1)
            p_adj = min(max(p_adj, p_adj_prev), 1)
            p_adjusted.append(p_adj)
            p_adj_prev = p_adj

        pvalues_df["pvalue_adj"] = p_adjusted
        pvalues_df = pvalues_df.sort_values("PerturbedFrac").reset_index()
        sig_indices = pvalues_df.index[pvalues_df["pvalue_adj"] <= 1e-3]

        print("Pvalues for deviations")
        display(pvalues_df)

    def plot_diversity(
        self,
        ax
    ):
        diversity_df = self.random_pert_diversity_df
        healthy_baseline = self.healthy_random_pert_baseline_diversity
        uc_baseline = self.uc_random_pert_baseline_diversity

        sns.swarmplot(x="PerturbedFrac",
                      y="Diversity",
                      hue="Dataset",
                      ax=ax,
                      data=diversity_df,
                      palette={"Healthy": self.healthy_color, "UC": self.uc_color},
                      size=3,
                      dodge=True)

        sns.boxplot(
            x="PerturbedFrac",
            hue="Dataset",
            y="Diversity",
            data=diversity_df,
            ax=ax,
            whis=[2.5, 97.5],
            showfliers=False,
            palette={"Healthy": self.healthy_color, "UC": self.uc_color},
            boxprops=dict(alpha=.4)
        )

        ax.plot([-0.5, 7], [healthy_baseline] * 2, color='blue', linestyle='dashed')
        ax.plot([-0.5, 7], [uc_baseline] * 2, color='orange', linestyle='dashed')

        ax.set_ylabel("Diversity (Normalized Entropy)")
        ax.set_xlabel("Fraction of OTUs perturbed")

        # =========== P-values + Benjamini-Hochberg correction
        df_healthy = diversity_df.loc[diversity_df["Dataset"] == "Healthy", ["PerturbedFrac", "Diversity"]]
        df_uc = diversity_df.loc[diversity_df["Dataset"] == "UC", ["PerturbedFrac", "Diversity"]]
        df_merged = df_healthy.merge(df_uc, on="PerturbedFrac", how="inner", suffixes=["Healthy", "UC"])

        # Compute statistic (raw p-values)
        def fn(tbl):
            u = scipy.stats.mannwhitneyu(tbl["DiversityHealthy"], tbl["DiversityUC"], alternative="greater")
            return u

        pvalues = df_merged.groupby("PerturbedFrac").apply(fn)
        pvalues_df = pd.DataFrame({"pvalue": pvalues.sort_values()})

        # Apply BH correction
        p_adjusted = []
        p_adj_prev = 0.0
        for i, (index, row) in enumerate(pvalues_df.iterrows()):
            p_adj = row["pvalue"].pvalue * pvalues_df.shape[0] / (i+1)
            p_adj = min(max(p_adj, p_adj_prev), 1)
            p_adjusted.append(p_adj)
            p_adj_prev = p_adj

        pvalues_df["pvalue_adj"] = p_adjusted
        pvalues_df = pvalues_df.sort_values("PerturbedFrac").reset_index()
        sig_indices = pvalues_df.index[pvalues_df["pvalue_adj"] <= 1e-3]

        print("Pvalues for diversity")
        display(pvalues_df)


class EigenvalueFigure():
    def __init__(self, healthy_pickle_path, uc_pickle_path, healthy_color=_default_healthy_color, uc_color=_default_uc_color):
        self.healthy_color = healthy_color
        self.uc_color = uc_color

        print("Computing Healthy dataset eigenvalues.")
        self.healthy_eig_X = self.compute_eigenvalues(healthy_pickle_path)

        print("Computing Dysbiotic dataset eigenvalues.")
        self.uc_eig_X = self.compute_eigenvalues(uc_pickle_path)

    def compute_eigenvalues(self, mcmc_pickle_path, upper_bound: float = 1e20):
        # ================ Data loading
        mcmc = md2.BaseMCMC.load(mcmc_pickle_path)
        si_trace = -np.absolute(mcmc.graph[STRNAMES.SELF_INTERACTION_VALUE].get_trace_from_disk(section='posterior'))
        interactions = mcmc.graph[STRNAMES.INTERACTIONS_OBJ].get_trace_from_disk(section='posterior')

        interactions[np.isnan(interactions)] = 0
        for i in range(len(mcmc.graph.data.taxa)):
            interactions[:,i,i] = si_trace[:,i]

        # ================ Eigenvalue computation
        N = interactions.shape[0]
        M = interactions.shape[1]
        eigs = np.zeros(shape=(N, M), dtype=complex)  # the np.complex alias was removed in NumPy 1.24
        for i in tqdm(range(N)):
            matrix = interactions[i]

            # only for interaction matrices
            matrix = np.nan_to_num(matrix, nan=0.0)

            slice_eigs = np.linalg.eigvals(matrix)

            # Throw out samples where eigenvalues blow up
            if np.sum(np.abs(slice_eigs) > upper_bound) > 0:
                print("Upper bound threshold {th} passed for sample {i}; skipping.".format(
                    th=upper_bound,
                    i=i
                ))
                continue

            eigs[i, :] = slice_eigs

        # Get real positive parts only.
        eigs = eigs.real.flatten()
        return eigs[eigs > 0]

    def plot(self, ax, alpha: float = 0.7):
        ####################################################
        # Eigenvalue histogram.
        ####################################################

        bins = np.linspace(0, 5e-9, 45)
        sns.histplot(self.uc_eig_X, ax=ax, bins=bins, label='Dysbiosis',
                     alpha=alpha, color=self.uc_color)
        sns.histplot(self.healthy_eig_X, ax=ax, bins=bins, label='Healthy',
                     alpha=alpha, color=self.healthy_color)

        ax.set_xlabel('Pos. Real Part of Eigenvalues', labelpad=20)
        ax.set_ylabel('Total count with multiplicity')
        ax.ticklabel_format(style="sci", scilimits=(0,0), useMathText=True, useOffset=False)


class CycleFigure():
    def __init__(self, healthy_pickle_path, uc_pickle_path, healthy_color=_default_healthy_color, uc_color=_default_uc_color):
        self.healthy_color = healthy_color
        self.uc_color = uc_color

        self.healthy_pickle_path = healthy_pickle_path
        self.uc_pickle_path = uc_pickle_path

        healthy_fixed_cluster = MdsineOutput(
            "Healthy",
            self.healthy_pickle_path
        )
        uc_fixed_cluster = MdsineOutput(
            "UC",
            self.uc_pickle_path
        )

        print("Computing cycles for Healthy dataset.")
        self.healthy_signed_cycles = self.compute_signed_statistics(
            healthy_fixed_cluster.get_clustered_interactions()
        )

        print("Computing cycles for Dysbiotic dataset.")
        self.uc_signed_cycles = self.compute_signed_statistics(
            uc_fixed_cluster.get_clustered_interactions()
        )
        self.n_samples = 10000

    def compute_signed_statistics(self, interactions):
        """
        Loop through each gibbs sample. For each gibbs sample, compute the number of cycles, assorted by sign.
        (Does not tell us exactly which cycles appear frequently.)
        """
        N = interactions.shape[0]
        ans = {
            '++': np.zeros(N),
            '--': np.zeros(N),
            '+-': np.zeros(N),
            '+++': np.zeros(N),
            '---': np.zeros(N),
            '++-': np.zeros(N),
            '--+': np.zeros(N)
        }
        for idx, mat in tqdm(enumerate(interactions), total=N):
            signed_cycle_counts = self.count_signed_cycles(mat)
            for sgn, counts in ans.items():
                counts[idx] = signed_cycle_counts[sgn]
        return ans

    def count_signed_cycles(self, mat):
        '''
        Count length 2 and 3 cycles, given a particular interaction matrix (corresp. to a single gibbs sample).
        '''
        adj = np.copy(mat).T
        plus_adj = np.zeros(shape=adj.shape, dtype=int)  # the np.int alias was removed in NumPy 1.24
        plus_adj[adj > 0] = 1
        minus_adj = np.zeros(shape=adj.shape, dtype=int)
        minus_adj[adj < 0] = 1

        return {
            '++': self.count_cycles_with_sign(['++'], plus_adj, minus_adj) / 2,
            '--': self.count_cycles_with_sign(['--'], plus_adj, minus_adj) / 2,
            '+-': self.count_cycles_with_sign(['+-', '-+'], plus_adj, minus_adj) / 2,
            '+++': self.count_cycles_with_sign(['+++'], plus_adj, minus_adj) / 3,
            '---': self.count_cycles_with_sign(['---'], plus_adj, minus_adj) / 3,
            '++-': self.count_cycles_with_sign(['++-', '+-+', '-++'], plus_adj, minus_adj) / 3,
            '--+': self.count_cycles_with_sign(['--+', '-+-', '+--'], plus_adj, minus_adj) / 3,
        }

    @staticmethod
    def count_cycles_with_sign(signs, plus, minus):
        """
        Multiply adjacency matrices to count the number of cycles. Example: #(+-+) = Trace[(M+) * (M-) * (M+)]
        """
        ans = 0
        for pattern in signs:
            M = np.eye(plus.shape[0])
            for sign in pattern:
                if sign == "+":
                    M = M @ plus
                elif sign == "-":
                    M = M @ minus
            ans = ans + np.sum(np.diag(M))
        return ans

    def plot(self, ax):
        lengths = [2, 3]
        signs = ['++', '--', '+-', '+++', '---', '++-', '--+']
        sign_order = {
            "({})".format(" ".join(sgn)): i for i, sgn in enumerate(signs)
        }
        df = pd.DataFrame(columns=["Count"],
                          dtype=float,
                          index=pd.MultiIndex.from_product(
                              [signs, range(self.n_samples), ["Healthy", "UC"]],
                              names=["Sign", "Index", "Dataset"]
                          ))

        for pattern in signs:
            df.loc[(pattern, slice(None), "Healthy")] = self.healthy_signed_cycles[pattern].reshape(-1, 1)
            df.loc[(pattern, slice(None), "UC")] = self.uc_signed_cycles[pattern].reshape(-1, 1)

        df = df.reset_index()
        df["Sign"] = df["Sign"].map({
            "++": "(+ +)",
            "--": "(- -)",
            "+-": "(+ -)",
            "+++": "(+ + +)",
            "---": "(- - -)",
            "++-": "(+ + -)",
            "--+": "(- - +)",
        })

        medianprops = dict(linestyle='--', linewidth=2.5)

        sns.violinplot(x="Sign",
                       y="Count",
                       hue="Dataset", data=df,
                       ax=ax,
                       scale="count",
                       cut=0,
    #                    inner="quartile",
                       bw=0.5,
                       palette={"Healthy": self.healthy_color, "UC": self.uc_color})

        # =========== P-values + Benjamini-Hochberg correction
        df_healthy = df.loc[df["Dataset"] == "Healthy", ["Sign", "Index", "Count"]]
        df_uc = df.loc[df["Dataset"] == "UC", ["Sign", "Index", "Count"]]
        df_merged = df_healthy.merge(
            df_uc,
            left_on=["Sign", "Index"],
            right_on=["Sign", "Index"],
            how="inner",
            suffixes=["Healthy", "UC"]
        )

        # Compute statistic (raw p-values)
        def fn(tbl):
            u = scipy.stats.mannwhitneyu(
                tbl["CountHealthy"], tbl["CountUC"],
                alternative="less"
            )
            return u

        pvalues = df_merged.groupby(
            "Sign"
        ).apply(fn)

        pvalues_df = pd.DataFrame(
            {"pvalue": pvalues.sort_values()}
        )

        # Apply BH correction
        p_adjusted = []
        p_adj_prev = 0.0
        for i, (index, row) in enumerate(pvalues_df.iterrows()):
            p_adj = row["pvalue"].pvalue * pvalues_df.shape[0] / (i+1)
            p_adj = min(max(p_adj, p_adj_prev), 1)
            p_adjusted.append(p_adj)
            p_adj_prev = p_adj

        pvalues_df["pvalue_adj"] = p_adjusted



        pvalues_df = pvalues_df.reset_index()
        sig_signs = pvalues_df.loc[pvalues_df["pvalue_adj"] <= 1e-3, "Sign"]
        display(pvalues_df)

        for sgn in sig_signs:
            idx = sign_order[sgn]
            pval = pvalues_df.loc[pvalues_df['Sign'] == sgn, 'pvalue_adj'].item()
            if pval <= 1e-4:
                indicator = "****"
            elif pval <= 1e-3:
                indicator = "***"
            else:
                indicator = "ERR"
            y = df.loc[df['Sign'] == sgn, "Count"].max()
            stat_annotate(idx-0.5, idx+0.5, y=y, h=0.5, ax=ax, color='black', desc=indicator)
        ax.set_ylabel("Number of cycles per sample")


# Preprocessing for Cycles (Figure 6D)
class MdsineOutput(object):
    '''
    A class to encode the data output by MDSINE.
    '''
    def __init__(self, dataset_name, pkl_path):
        self.dataset_name = dataset_name
        self.mcmc = md2.BaseMCMC.load(pkl_path)
        self.taxa = self.mcmc.graph.data.taxa
        self.name_to_taxa = {otu.name: otu for otu in self.taxa}

        self.interactions = None
        self.clustering = None

        self.clusters_by_idx = {
            (c_idx): [self.get_taxa(oidx) for oidx in cluster.members]
            for c_idx, cluster in enumerate(self.get_clustering())
        }

    @property
    def num_samples(self) -> int:
        return self.mcmc.n_samples

    def get_cluster_df(self):
        return pd.DataFrame([
            {
                "id": cluster.id,
                "idx": c_idx+1,
                "otus": ",".join([self.get_taxa(otu_idx).name for otu_idx in cluster.members]),
                "size": len(cluster)
            }
            for c_idx, cluster in enumerate(self.clustering)
        ])

    def get_interactions(self):
        if self.interactions is None:
            self.interactions = self.mcmc.graph[STRNAMES.INTERACTIONS_OBJ].get_trace_from_disk(section='posterior')
        return self.interactions

    def get_taxa(self, idx):
        return self.taxa.index[idx]

    def get_taxa_by_name(self, name: str):
        return self.name_to_taxa[name]

    def get_taxa_str(self, idx):
        tax = self.taxa.index[idx].taxonomy
        family = tax["family"]
        genus = tax["genus"]
        species = tax["species"]

        if genus == "NA":
            return "{}**".format(family)
        elif species == "NA":
            return "{}, {}*".format(family, genus)
        else:
            return "{}, {} {}".format(family, genus, species)

    def get_taxa_str_long(self, idx):
        return "{}\n[{}]".format(self.get_taxa(idx).name, self.get_taxa_str(idx))

    def get_clustering(self):
        if self.clustering is None:
            self.clustering = self.mcmc.graph[STRNAMES.CLUSTERING_OBJ]
            for cidx, cluster in enumerate(self.clustering):
                cluster.idx = cidx
        return self.clustering

    def get_clustered_interactions(self):
        clusters = self.get_clustering()
        otu_interactions = self.get_interactions()
        cluster_interactions = np.zeros(
            shape=(
                otu_interactions.shape[0],
                len(clusters),
                len(clusters)
            ),
            dtype=float
        )
        cluster_reps = [
            next(iter(cluster.members)) for cluster in clusters
        ]
        for i in range(cluster_interactions.shape[0]):
            cluster_interactions[i] = otu_interactions[i][np.ix_(cluster_reps, cluster_reps)]
        return cluster_interactions


### Figure 6A + 6B

A: Deviation from unperturbed steady state for random perturbations on a random α-fraction of taxa.

B: Change in α-diversity for the same set of simulated trajectories.
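The two quantities plotted in panels A and B can be illustrated on plain numbers. The sketch below is a dependency-free restatement of what the classes above compute with pandas and SciPy: the absolute log10 difference between perturbed and baseline steady states (with the `1e5` pseudocount used in `posthoc_random_pert_helper`), and Shannon entropy divided by the entropy of a uniform distribution over the same number of taxa (as in `precompute_diversities`). The function names and toy inputs are illustrative assumptions, not part of the notebook.

```python
import math


def steady_state_diff(perturbed, baseline, pseudocount=1e5):
    """Absolute log10 difference between two steady-state abundances."""
    return abs(math.log10(perturbed + pseudocount) - math.log10(baseline + pseudocount))


def normalized_entropy(abundances):
    """Shannon entropy of the abundance profile, divided by log(number of taxa).

    Abundances are rescaled to sum to 1; zero entries contribute nothing.
    Requires at least two taxa, otherwise the normalizer log(1) is zero.
    """
    total = sum(abundances)
    probs = [a / total for a in abundances if a > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    return entropy / math.log(len(abundances))


# Toy values: an even community has normalized entropy 1,
# a community dominated by a single taxon has normalized entropy 0.
print(normalized_entropy([1.0, 1.0, 1.0, 1.0]))
print(normalized_entropy([5.0, 0.0, 0.0, 0.0]))
print(steady_state_diff(9e5, 0.0))
```

Because of the pseudocount, taxa whose abundance is far below `1e5` contribute almost nothing to the deviation, which keeps extinct or near-extinct taxa from dominating the average.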

In [None]:
pert_path = "/content/fwsim_random_pert"
pert_fig = PerturbationSimFigure(Path(
    pert_path  # Used to be /data/cctm/darpa_perturbation_mouse_study/youn_notebooks/fwsim_random_pert
))

fig, axes = plt.subplots(2, 1, figsize=(15,20))

pert_fig.plot_deviations(axes[0], ymin=0.0, ymax=1.5)
pert_fig.plot_diversity(axes[1])
fig.savefig("{}/figure6ab.pdf".format(saveloc))
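The adjusted p-values printed by `plot_deviations` and `plot_diversity` come from the correction loop in the class above. The sketch below restates that loop as a standalone function (`adjust_running_max`, a name chosen here) and places it next to the textbook Benjamini-Hochberg step-up adjustment (`adjust_step_up`), which enforces monotonicity with a running minimum from the largest p-value downward. The toy p-values are made up for illustration.

```python
def adjust_running_max(pvalues):
    """Mirror of the notebook loop: sort ascending, scale by m / rank,
    take a running maximum from the smallest p-value up, cap at 1."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    adjusted = [0.0] * m
    prev = 0.0
    for rank, i in enumerate(order, start=1):
        p_adj = min(max(pvalues[i] * m / rank, prev), 1.0)
        adjusted[i] = p_adj
        prev = p_adj
    return adjusted


def adjust_step_up(pvalues):
    """Textbook BH adjustment: same scaling, but a running minimum
    taken from the largest p-value down."""
    m = len(pvalues)
    order = sorted(range(m), key=lambda i: pvalues[i])
    adjusted = [0.0] * m
    prev = 1.0
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        p_adj = min(pvalues[i] * m / rank, prev)
        adjusted[i] = p_adj
        prev = p_adj
    return adjusted


toy = [0.01, 0.011, 0.5]  # illustrative values only
print(adjust_running_max(toy))
print(adjust_step_up(toy))
```

For every input, the running-maximum variant returns values at least as large as the step-up values, so significance calls made with it are never more liberal than textbook BH.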

### Figure 6C

Histogram of the positive real parts of the eigenvalues.
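`EigenvalueFigure.compute_eigenvalues` builds one interaction matrix per posterior sample (negative self-interactions on the diagonal), calls `np.linalg.eigvals`, and keeps only the eigenvalues whose real part is positive. The sketch below shows the same filtering step on two hand-made 2×2 matrices, using the closed-form eigenvalues of a 2×2 matrix so that it runs without NumPy. The matrices and function names are illustrative assumptions.

```python
import cmath


def eigenvalues_2x2(matrix):
    """Eigenvalues of a 2x2 matrix from its trace and determinant."""
    (a, b), (c, d) = matrix
    trace = a + d
    det = a * d - b * c
    root = cmath.sqrt(trace * trace - 4 * det)
    return [(trace + root) / 2, (trace - root) / 2]


def positive_real_parts(eigs):
    """Keep only the real parts that are strictly positive."""
    return [e.real for e in eigs if e.real > 0]


# Weak off-diagonal interactions: both eigenvalues have negative real part.
weak = [[-1.0, 0.5], [0.5, -1.0]]
# Strong mutual interactions: one eigenvalue crosses into the positive half-plane.
strong = [[-1.0, 2.0], [2.0, -1.0]]

print(positive_real_parts(eigenvalues_2x2(weak)))    # []
print(positive_real_parts(eigenvalues_2x2(strong)))  # [1.0]
```

In the notebook the same filter is applied to every posterior sample and the surviving values are pooled into one histogram per cohort.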

In [None]:
mcmc_healthy_path_unfixed = "/content/mixed_prior_unfixed/healthy-seed0-mixed/mcmc.pkl"
mcmc_uc_path_unfixed = "/content/mixed_prior_unfixed/uc-seed0-mixed/mcmc.pkl"
eig_fig = EigenvalueFigure(mcmc_healthy_path_unfixed, mcmc_uc_path_unfixed)

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
eig_fig.plot(ax)
fig.savefig("{}/figure6c.pdf".format(saveloc))
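
For a linear system dx/dt = A x, eigenvalues of A with a positive real part correspond to modes that grow over time, which is why they are a natural indicator of instability. The sketch below shows how such values could be collected and binned. It is not the implementation of `EigenvalueFigure`: the matrix is a synthetic stand-in for a single posterior sample, and the sparsity, scale, and bin count are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for one posterior sample of an interaction matrix:
# sparse random off-diagonal terms plus negative self-interactions.
n_taxa = 20
mask = rng.random((n_taxa, n_taxa)) < 0.2
A = rng.normal(scale=0.3, size=(n_taxa, n_taxa)) * mask
np.fill_diagonal(A, -1.0)

eigenvalues = np.linalg.eigvals(A)

# Keep only the real parts that are strictly positive (growing modes).
positive_real_parts = eigenvalues.real[eigenvalues.real > 0]

counts, bin_edges = np.histogram(positive_real_parts, bins=10)
print(len(positive_real_parts), counts)
```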

### Figure 6D

Counts of module-level simple cycles.

In [None]:
mcmc_healthy_path_fixed = "/content/mixed_prior_fixed/healthy-seed0-mixed/mcmc.pkl"
mcmc_uc_path_fixed = "/content/mixed_prior_fixed/uc-seed0-mixed/mcmc.pkl"
cycle_fig = CycleFigure(
    mcmc_healthy_path_fixed, # Fixed clustering run
    mcmc_uc_path_fixed  # Fixed clustering run
)

fig, ax = plt.subplots(1, 1, figsize=(15, 10))
cycle_fig.plot(ax)
fig.savefig("{}/figure6d.pdf".format(saveloc))
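
A simple cycle is a closed path through the interaction graph that visits no module more than once. The sketch below enumerates simple cycles in a small directed graph with a depth-first search, reporting each cycle once from its smallest node. It is an illustration only: the graph is hypothetical, and whether `CycleFigure` counts self-loops or distinguishes cycles by interaction sign is not shown in this section.

```python
def simple_cycles(adjacency):
    """Enumerate simple cycles of a directed graph given as {node: set(successors)}.

    Each cycle is reported once, starting from its smallest node.
    """
    cycles = []

    def extend(start, node, path, visited):
        for nxt in sorted(adjacency.get(node, ())):
            if nxt == start:
                cycles.append(list(path))
            elif nxt > start and nxt not in visited:
                visited.add(nxt)
                path.append(nxt)
                extend(start, nxt, path, visited)
                path.pop()
                visited.remove(nxt)

    for start in sorted(adjacency):
        extend(start, start, [start], {start})
    return cycles

# Hypothetical module-level graph: an edge u -> v means module u
# has a nonzero effect on module v.
module_graph = {
    0: {1},
    1: {0, 2},
    2: {0},
    3: {3},  # self-loop, reported as a cycle of length 1
}

cycles = simple_cycles(module_graph)
cycle_lengths = sorted(len(c) for c in cycles)
print(cycles, cycle_lengths)
```

Restricting the search to nodes larger than the start node is what prevents the same cycle from being reported once per rotation.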

## Figure 7

Custom code necessary for rendering the keystoneness figure.



In [None]:
#@title
import numpy as np
import mdsine2 as md2
from tqdm.notebook import tqdm
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm, gridspec
import matplotlib.colors as mcolors
import seaborn as sns


def cluster_nonmembership_df(md):
    entries = []
    for cluster in md.get_clustering():
        for otu in md.taxa:
            if otu.idx not in cluster.members:
                entries.append({
                    "ClusterID": cluster.id,
                    "OTU": otu.name
                })
    return pd.DataFrame(entries)


def cluster_membership_df(md):
    entries = []
    for cluster in md.get_clustering():
        for oidx in cluster.members:
            otu = md.get_taxa(oidx)
            entries.append({
                "ClusterOfOTU": cluster.id,
                "OTU": otu.name
            })
    return pd.DataFrame(entries)


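def create_cmap(tag, nan_value="red"):
    # Copy the registered colormap so set_bad() does not mutate the shared instance
    # (cm.get_cmap is deprecated in recent Matplotlib releases).
    cmap = plt.get_cmap(tag).copy()
    cmap.set_bad(color=nan_value)
    return cmap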


class KeystonenessFigure():
    def __init__(self, dataset_name: str, mcmc_pickle_path, subjset_path, fwsim_path):
        print("Loading pickle files.")
        self.md = MdsineOutput(dataset_name, mcmc_pickle_path)
        self.study = md2.Study.load(subjset_path)
        
        print("Loading dataframe from disk.")
        self.fwsim_df = pd.read_hdf(fwsim_path, key='df', mode='r')
        
        print("Compiling dataframe.")
        self.ky_df = self.generate_keystoneness_df()
        
        print("Compiling abundance data.")
        self.abundance_array = self.get_abundance_array()
        
        print("Extracting keystoneness values.")
        self.ky_array = self.get_ky_array()
        
        print("Extracting baseline abundances.")
        self.day20_array = self.get_day20_abundances()

    def generate_keystoneness_df(self):
        md = self.md
        fwsim_df = self.fwsim_df
        
        nonmembers_df = cluster_nonmembership_df(md)

        baseline = fwsim_df.loc[fwsim_df["ExcludedCluster"] == "None"]

        altered = fwsim_df.loc[fwsim_df["ExcludedCluster"] != "None"]
        altered = altered.merge(
            right=nonmembers_df,
            how="inner",
            left_on=["ExcludedCluster", "OTU"],
            right_on=["ClusterID", "OTU"]
        )

        merged = altered.merge(
            baseline[["OTU", "SampleIdx", "StableState"]],
            how="left",
            left_on=["OTU", "SampleIdx"],
            right_on=["OTU", "SampleIdx"],
            suffixes=["", "Base"] 
        )
        
        merged["DiffStableState"] = np.log10(merged["StableStateBase"] + 1e5) - np.log10(merged["StableState"] + 1e5)

        return merged[
            ["ExcludedCluster", "SampleIdx", "DiffStableState"]
        ].groupby(
            ["ExcludedCluster", "SampleIdx"]
        ).mean().rename(columns={"DiffStableState": "Ky"})
    
    def get_abundance_array(self):
        md = self.md
        fwsim_df = self.fwsim_df
        
        clustering = md.get_clustering()
        membership_df = cluster_membership_df(md)
        merged_df = fwsim_df.merge(
            membership_df,
            how="left",
            left_on="OTU",
            right_on="OTU"
        )
        
        abund_array = np.zeros(shape=(len(clustering) + 1, len(clustering)))

        # Baseline abundances (no cluster removed) -- sum across OTUs (per sample), median across samples.
        subset_df = merged_df.loc[merged_df["ExcludedCluster"] == "None"]
        subset_df = subset_df[
            ["ClusterOfOTU", "SampleIdx", "StableState"]
        ].groupby(
            ["ClusterOfOTU", "SampleIdx"]
        ).sum(
            # Aggregate over OTUs  (e.g. Baseline abundance of a cluster is the sum of its constituents.)
        ).groupby(
            level=0
        ).median(
            # Aggregate over samples
        )
        for cluster in clustering:
            abund_array[0, cluster.idx] = subset_df.loc[cluster.id]

        # Altered abundances (remove 1 cluster at a time)
        for removed_cluster in tqdm(clustering, total=len(clustering), desc="Heatmap Abundances"):
            subset_df = merged_df.loc[merged_df["ExcludedCluster"] == removed_cluster.id]

            # Compute the total abundance (over OTUs) for each cluster, for each sample. Then aggregate (median) across samples.
            subset_df = subset_df[
                ["ClusterOfOTU", "SampleIdx", "StableState"]
            ].groupby(
                ["ClusterOfOTU", "SampleIdx"]
            ).sum(
                # Aggregate over OTUs
            ).groupby(
                level=0
            ).median(
                # Aggregate over samples
            )

            for cluster in clustering:
                abund_array[removed_cluster.idx + 1, cluster.idx] = subset_df.loc[cluster.id]
        return abund_array
    
    def get_ky_array(self):
        # Group by Cluster, aggregate (mean/median) across samples.
        agg_ky_df = self.ky_df.groupby(level=0).median()
        return np.array(
            [agg_ky_df.loc[cluster.id, "Ky"] for cluster in self.md.get_clustering()]
        )
    
    def get_day20_abundances(self):
        M = self.study.matrix(dtype='abs', agg='mean', times='intersection', qpcr_unnormalize=True)
        day20_state = M[:, 19]
        cluster_day20_abundances = np.zeros(len(self.md.get_clustering()))

        for cidx, cluster in enumerate(self.md.get_clustering()):
            cluster_day20_abundances[cidx] = np.sum(
                [day20_state[oidx] for oidx in cluster.members]
            )
        return cluster_day20_abundances
    
    def plot(self, fig):
        md = self.md
        abund_array = self.abundance_array
        ky_array = self.ky_array
        day20_array = self.day20_array
        
        # Main abundance grid shows the _difference_ from baseline, instead of the abundances themselves.
        n_clusters = len(ky_array)

        # =========== Pre-sorting. ===========
        ky_order = np.argsort(ky_array)
        ky_order = ky_order[::-1]
        ky_array = ky_array[ky_order]

        day20_array = day20_array[ky_order].reshape(1, len(day20_array))

        baseline_array = abund_array[[0], :]
        baseline_array = baseline_array[:, ky_order]

        altered_array = abund_array[1 + ky_order, :]  # Reorder the rows first (exclude the baseline row),
        altered_array = altered_array[:, ky_order]  # Then reorder the columns.

        baseline_diff_array = np.log10(baseline_array + 1e5) - np.log10(altered_array + 1e5)
        for i in range(baseline_diff_array.shape[0]):
            baseline_diff_array[i, i] = np.nan

        # =========== Heatmap settings. ========
        gridspec_kw = {"height_ratios":[1, 1, abund_array.shape[0] - 1], "width_ratios" : [1, abund_array.shape[1]]}

        # Colors and normalization (abund)
        abund_min = np.max([
            np.min(abund_array[abund_array > 0]), 
            1e5
        ])
        abund_max = np.min([
            np.max(abund_array[abund_array > 0]),
            1e13
        ])
        print("abund_min = {}, abund_max = {}".format(abund_min, abund_max))

        abund_cmap = create_cmap("Greens", nan_value="white")
        abund_norm = matplotlib.colors.LogNorm(vmin=abund_min, vmax=abund_max)

        # Colors and normalization (Ky)
        gray = np.array([0.95, 0.95, 0.95, 1.0])
        red = np.array([1.0, 0.0, 0.0, 1.0])
        blue = np.array([0.0, 0.0, 1.0, 1.0])
        n_interp=128

        top = blue
        bottom = red
        top_middle = 0.05 * top + 0.95 * gray
        bottom_middle = 0.05 * bottom + 0.95 * gray

        ky_cmap = matplotlib.colors.ListedColormap(
            np.vstack(
                [(1-t) * bottom + t * bottom_middle for t in np.linspace(0, 1, n_interp)]
                +
                [(1-t) * top_middle + t * top for t in np.linspace(0, 1, n_interp)]
            ),
            name='Keystoneness'
        )
        ky_cmap.set_bad(color="white")


        ky_min = 0.90 * np.min(ky_array) + 0.10 * np.min(baseline_diff_array[altered_array > 0])
        ky_max = 0.90 * np.max(ky_array) + 0.10 * np.max(baseline_diff_array[altered_array > 0])

        # Signed square-root scaling around zero, applied through FuncNorm below.
        def _forward(x):
            y = x.copy()
            positive_part = x[x > 0]
            y[x > 0] = np.sqrt(positive_part / ky_max)

            negative_part = x[x < 0]
            y[x < 0] = -np.sqrt(np.abs(negative_part / ky_min))
            return y
        def _reverse(x):
            y = x.copy()
            positive_part = x[x > 0]
            y[x > 0] = ky_max * np.power(positive_part, 2)

            negative_part = x[x < 0]
            y[x < 0] = -np.abs(ky_min) * np.power(negative_part, 2)
            return y
        ky_norm = matplotlib.colors.FuncNorm((_forward, _reverse), vmin=ky_min, vmax=ky_max)

        # Seaborn Heatmap Kwargs
        abund_heatmapkws = dict(square=False, 
                                cbar=False, 
                                cmap=abund_cmap, 
                                linewidths=0.5,
                                norm=abund_norm)
        ky_heatmapkws = dict(square=False, cbar=False, cmap=ky_cmap, linewidths=0.5, norm=ky_norm)

        # ========== Plot layout ===========
        # [left, bottom, width, height]
        main_x = 0.67
        main_y = 0.5
        box_unit = 0.03
        main_width = box_unit * n_clusters
        main_height = main_width
        main_left = main_x - 0.5 * main_width
        main_bottom = main_y - 0.5 * main_width
        print("Left: {}, bottom: {}, width: {}, height: {}".format(main_left, main_bottom, main_width, main_height))
        print("Right: {}, Top: {}".format(main_left + main_width, main_bottom + main_height))

        ky_ax = fig.add_axes([main_left + main_width + 0.5 * box_unit, main_bottom, box_unit, main_height])
        abundances_ax = fig.add_axes([main_left, main_bottom, main_width, main_height])
        obs_ax = fig.add_axes([main_left, main_bottom + main_height + 1.5 * box_unit, box_unit * n_clusters, box_unit])
        baseline_ax = fig.add_axes([main_left, main_bottom + main_height + 0.5 * box_unit, box_unit * n_clusters, box_unit])

        # ========= Rendering. ==========
        # ====== Bottom left: Keystoneness
        hmap_ky = sns.heatmap(
            ky_array.reshape(len(ky_array), 1),
            ax=ky_ax,
            xticklabels=False,
            yticklabels=False,
            **ky_heatmapkws
        )
        hmap_ky.xaxis.set_tick_params(width=0)
        fig.text(main_left + main_width + 2*box_unit, main_y, "Keystoneness", ha='center', va='center', rotation=-90)

        for _, spine in hmap_ky.spines.items():
            spine.set_visible(True)
            spine.set_linewidth(1.0)

        # ====== Top right 1: Observed levels (day 20)
        hmap_day20_abund = sns.heatmap(day20_array, 
                                       ax=obs_ax, 
                                       xticklabels=False, 
                                       yticklabels=["Observation"],
                                       **abund_heatmapkws)
        hmap_day20_abund.set_yticklabels(hmap_day20_abund.get_yticklabels(), rotation=0)
        for _, spine in hmap_day20_abund.spines.items():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
        fig.text(main_x, main_bottom + main_height + 3 * box_unit, "Steady State Abundance", ha='center', va='center')

        # ====== Top right 2: Baseline abundances
        hmap_base_abund = sns.heatmap(baseline_array, 
                                      ax=baseline_ax, 
                                      xticklabels=False, 
                                      yticklabels=["Simulation"],
                                      **abund_heatmapkws)
        hmap_base_abund.set_yticklabels(hmap_base_abund.get_yticklabels(), rotation=0)
        for _, spine in hmap_base_abund.spines.items():
            spine.set_visible(True)
            spine.set_linewidth(1.0)

        # ====== Bottom right: Abundances with clusters removed.
        ticklabels = [
            "{}{}".format(
                "H" if md.dataset_name == "Healthy" else "D", 
                c_idx + 1
            ) 
            for c_idx in ky_order
        ]
        hmap_removed_cluster_abund = sns.heatmap(
            baseline_diff_array, 
            ax=abundances_ax, 
            xticklabels=ticklabels, 
            yticklabels=ticklabels, 
            **ky_heatmapkws
        )
        # Draw a marker ("X") on top of NaNs.
        abundances_ax.scatter(*np.argwhere(np.isnan(baseline_diff_array.T)).T + 0.5, marker="x", color="black", s=100)
        abundances_ax.set_ylabel("Module Removed")
        abundances_ax.set_xlabel("Per-Module Change")
        for _, spine in hmap_removed_cluster_abund.spines.items():
            spine.set_visible(True)
            spine.set_linewidth(1.0)
        hmap_removed_cluster_abund.xaxis.set_ticks_position('bottom')
        hmap_removed_cluster_abund.set_xticklabels(
            hmap_removed_cluster_abund.get_xticklabels(), rotation=90, horizontalalignment='center'
        )
        abundances_ax.tick_params(direction='out', length=0, width=0)

        # ======= Draw the colormaps ========
        cbar_from_main = 0.25
        cbar_width = 0.01
        cbar_height = 0.35

        # Cbar on the right (steady state diff, green)
        cax = fig.add_axes([main_left - cbar_from_main, main_y - 0.5 * cbar_height, cbar_width, cbar_height])
        sm = matplotlib.cm.ScalarMappable(cmap=abund_cmap, norm=abund_norm)
        sm.set_array(np.array([]))
        cbar = fig.colorbar(sm, cax=cax)

        yticks = cbar.get_ticks()
        yticklabels = [str(np.log10(y)) for y in yticks]
        yticklabels[0] = "<{}".format(yticklabels[0])  
        cax.set_yticklabels(yticklabels)
        cax.set_ylabel("Log-Abundance")

        # Cbar on the left (Keyst., RdBu)
        cax = fig.add_axes([main_left - cbar_from_main - 2*cbar_width, main_y - 0.5 * cbar_height, cbar_width, cbar_height])
        sm = matplotlib.cm.ScalarMappable(cmap=ky_cmap, norm=ky_norm)
        sm.set_array(np.array([]))
        cbar = fig.colorbar(sm, cax=cax)
        cax.yaxis.set_ticks_position('left')
        cax.set_ylabel("Log-Difference from Base")
        cax.yaxis.set_label_position("left")

        yticks = cbar.get_ticks()
        yticklabels = ["{:.1f}".format(y) for y in yticks]
        cax.set_yticklabels(yticklabels)

        # Legend label text
        fig.text(
            main_left - cbar_from_main - cbar_width,
            main_y + 0.5 * cbar_height + 0.05,
            "Legend",
            ha='center', va='center',
            fontweight='bold'
        )
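
The keystoneness value computed in `generate_keystoneness_df` is the mean, over taxa outside the removed cluster, of the log10 difference between the baseline steady state and the steady state after removal. The sketch below reproduces that calculation on a tiny hypothetical table with the same column names. To stay short it lists only non-member taxa for each removed cluster, so it skips the inner merge against `cluster_nonmembership_df` that the class uses to filter them.

```python
import numpy as np
import pandas as pd

# Hypothetical forward-simulation table: clusters c1 = {A, B} and c2 = {C},
# one posterior sample.
fwsim_df = pd.DataFrame([
    # Baseline: no cluster removed.
    {"ExcludedCluster": "None", "OTU": "A", "SampleIdx": 0, "StableState": 1e8},
    {"ExcludedCluster": "None", "OTU": "B", "SampleIdx": 0, "StableState": 1e7},
    {"ExcludedCluster": "None", "OTU": "C", "SampleIdx": 0, "StableState": 1e9},
    # Cluster c1 removed: only C (a non-member) is scored.
    {"ExcludedCluster": "c1", "OTU": "C", "SampleIdx": 0, "StableState": 1e7},
    # Cluster c2 removed: A and B are scored and are unchanged.
    {"ExcludedCluster": "c2", "OTU": "A", "SampleIdx": 0, "StableState": 1e8},
    {"ExcludedCluster": "c2", "OTU": "B", "SampleIdx": 0, "StableState": 1e7},
])

baseline = fwsim_df[fwsim_df["ExcludedCluster"] == "None"]
altered = fwsim_df[fwsim_df["ExcludedCluster"] != "None"]

# Attach each taxon's baseline steady state as the column "StableStateBase".
merged = altered.merge(
    baseline[["OTU", "SampleIdx", "StableState"]],
    on=["OTU", "SampleIdx"],
    how="left",
    suffixes=("", "Base"),
)

pseudocount = 1e5  # same offset as in the class above
merged["DiffStableState"] = (
    np.log10(merged["StableStateBase"] + pseudocount)
    - np.log10(merged["StableState"] + pseudocount)
)

ky = merged.groupby(["ExcludedCluster", "SampleIdx"])["DiffStableState"].mean()
print(ky)
```

With this sign convention, a positive value means that removing the cluster lowers the steady-state abundance of the remaining taxa. Here removing `c1` drops taxon C by roughly two orders of magnitude, while removing `c2` changes nothing.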

### Healthy Keystoneness values

In [None]:
healthy_subjset_path = "/content/MDSINE2_Paper/analysis/output/gibson/preprocessed/gibson_healthy_agg_taxa_filtered.pkl"
healthy_keystone_path = "/content/keystoneness/healthy_fwsim_day20.h5"
fig = plt.figure(figsize=(10,10))
KeystonenessFigure(
    "Healthy",
    mcmc_healthy_path_fixed,  # Fixed clustering run
    healthy_subjset_path,
    healthy_keystone_path  # Used to be: darpa_perturbation_mouse_study/youn_notebooks/keystoneness/healthy_fwsim_day20.h5
).plot(fig)
fig.savefig("{}/figure7_healthy.pdf".format(saveloc))

### Dysbiosis Keystoneness values

In [None]:
uc_subjset_path = "/content/MDSINE2_Paper/analysis/output/gibson/preprocessed/gibson_uc_agg_taxa_filtered.pkl"
uc_keystone_path = "/content/keystoneness/uc_fwsim_day20.h5"
fig = plt.figure(figsize=(10,10))

KeystonenessFigure(
    "Dysbiotic",
    mcmc_uc_path_fixed,  # Fixed clustering run
    uc_subjset_path,
    uc_keystone_path  # Used to be: darpa_perturbation_mouse_study/youn_notebooks/keystoneness/uc_fwsim_day20.h5
).plot(fig)
fig.savefig("{}/figure7_dysbiotic.pdf".format(saveloc))
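
The color scale in `KeystonenessFigure.plot` passes a pair of functions to `matplotlib.colors.FuncNorm` so that small differences near zero remain visible: positive and negative values are each divided by their own limit and then square-rooted. The sketch below isolates that mapping in plain NumPy and checks that the inverse undoes it. The limits `ky_min` and `ky_max` are hypothetical, whereas the figure code derives them from the data.

```python
import numpy as np

ky_min, ky_max = -0.5, 2.0  # hypothetical limits, with ky_min < 0 < ky_max

def forward(x):
    """Map [ky_min, ky_max] onto [-1, 1] with a signed square root."""
    x = np.asarray(x, dtype=float)
    y = x.copy()
    y[x > 0] = np.sqrt(x[x > 0] / ky_max)
    y[x < 0] = -np.sqrt(np.abs(x[x < 0] / ky_min))
    return y

def inverse(y):
    """Undo forward(): square the magnitude and restore the original scale."""
    y = np.asarray(y, dtype=float)
    x = y.copy()
    x[y > 0] = ky_max * np.power(y[y > 0], 2)
    x[y < 0] = -np.abs(ky_min) * np.power(y[y < 0], 2)
    return x

values = np.array([-0.5, -0.125, 0.0, 0.5, 2.0])
scaled = forward(values)
print(scaled)
```

Because the two halves are scaled independently, a value at a quarter of either limit lands halfway along its side of the color scale.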

## Supplemental Figure 1


In [None]:
!python gibson_inference/figures/supplemental_figure1.py \
    -file1 "gibson_inference/figures/preprocessed_all/gibson_healthy_agg_taxa.pkl" \
    -file2 "gibson_inference/figures/preprocessed_all/gibson_uc_agg_taxa.pkl" \
    -file3 "gibson_inference/figures/preprocessed_all/gibson_inoculum_agg_taxa.pkl" \
    -o_loc "output/gibson/plots"

## Supplemental Figure 2

The figure is assembled in multiple steps. First, we plot the relative abundance at the phylum level. Then, we separately generate the heatmaps showing the DESeq results. The final figure is made by combining the sub-figures in Adobe Illustrator.


### First Step

In [None]:
!python gibson_inference/figures/supplemental_figure2.py \
    -file1 "gibson_inference/figures/preprocessed_all/gibson_healthy_agg_taxa.pkl" \
    -file2 "gibson_inference/figures/preprocessed_all/gibson_uc_agg_taxa.pkl" \
    -file3 "gibson_inference/figures/preprocessed_all/gibson_inoculum_agg_taxa.pkl" \
    -o_loc "output/gibson/plots"

### Heatmap showing DESeq results at steady state at phylum level

The heatmaps are saved as `mat_phylum_high_ss` and `mat_phylum_low_ss`.

In [None]:
!python gibson_inference/figures/deseq_heatmap_ss.py \
    -loc "gibson_inference/figures/supplemental_figure2_files" \
    -abund "high" \
    -txt "abundant_species_phylum" \
    -taxo "phylum" \
    -o "mat_phylum_high_ss" \
    -o_loc "output/gibson/plots"


!python gibson_inference/figures/deseq_heatmap_ss.py \
    -loc "gibson_inference/figures/supplemental_figure2_files" \
    -abund "low" \
    -txt "abundant_species_phylum" \
    -taxo "phylum" \
    -o "mat_phylum_low_ss" \
    -o_loc "output/gibson/plots"



### Heatmap showing DESeq results for different perturbations at phylum level

The heatmaps are saved as `mat_phylum_high` and `mat_phylum_low`.

In [None]:
!python gibson_inference/figures/deseq_heatmap_phylum.py \
    -loc "gibson_inference/figures/supplemental_figure2_files" \
    -abund "high" \
    -txt "abundant_species_phylum" \
    -taxo "phylum" \
    -o "mat_phylum_high" \
    -o_loc "output/gibson/plots"


!python gibson_inference/figures/deseq_heatmap_phylum.py \
    -loc "gibson_inference/figures/supplemental_figure2_files" \
    -abund "low" \
    -txt "abundant_species_phylum" \
    -taxo "phylum" \
    -o "mat_phylum_low" \
    -o_loc "output/gibson/plots"

### Heatmap showing results of DESeq analysis at order level

The heatmaps are saved as `mat_order_high` and `mat_order_low`.

In [None]:
!python gibson_inference/figures/deseq_heatmap_order.py \
    -loc "gibson_inference/figures/supplemental_figure2_files" \
    -abund "high" \
    -txt "abundant_species_order" \
    -taxo "order" \
    -o "mat_order_high" \
    -o_loc "output/gibson/plots"


!python gibson_inference/figures/deseq_heatmap_order.py \
    -loc "gibson_inference/figures/supplemental_figure2_files" \
    -abund "low" \
    -txt "abundant_species_order" \
    -taxo "order" \
    -o "mat_order_low" \
    -o_loc "output/gibson/plots"


## Supplemental Figure 3

In [None]:
!python gibson_inference/figures/supplemental_figure3.py \
    -file1 "output/gibson/preprocessed/gibson_healthy_agg_taxa.pkl" \
    -file2 "output/gibson/preprocessed/gibson_uc_agg_taxa.pkl" \
    -o_loc "output/gibson/plots"


## Supplemental Figure 4

In [None]:
!python gibson_inference/figures/supplemental_figure4.py \
    --mdsine_path "/content/forward_sims/" \
    --clv_elas_path "/content/clv_results/results_rel_elastic/" \
    --clv_ridge_path "/content/clv_results/results_rel_ridge/" \
    --glv_elas_path "/content/clv_results/results_abs_elastic/" \
    --glv_ridge_path "/content/clv_results/results_abs_ridge/forward_sims_abs_ridge/" \
    --output_path "output/gibson/plots/"

## Supplemental Figure 5

In [None]:
!python gibson_inference/figures/supplemental_figure5.py \
    -loc1 "/content/mixed_prior_fixed/healthy-seed0-mixed/mcmc.pkl" \
    -loc2 "/content/mixed_prior_fixed/uc-seed0-mixed/mcmc.pkl" \
    -o_loc "output/gibson/plots"

