# Model Recovery 

The goal of this notebook is to show how to recover parameters from a model that didn't finish running.

This is critical for long-running jobs that mysteriously die, running intermediate diagnostics, 
or even for pausing and resuming model training.

There are a few steps that need to be performed before running this notebook.  

First, all of the checkpoint files need to be present.  You recognize these files that either end in `index` or `meta`.  You also need the file called `checkpoint`.  This file is particularly important -- you need to make sure that all of the paths within this file are correct (i.e. all of the paths in that file actually exist).  If those paths aren't correct (i.e. these files have been moved to another machine), then those paths need to be corrected.

The other crappy thing about recovering parameters is that you lose information about the underlying microbe-metabolite ids.  So you need to match those yourself.  The procedure on how to do that is outlined below, but this also means reproducing the **exact** preprocessing procedure performed in minstrel.  So it is absolutely critical to save the minstrel command that you ran to produce the model.  This also includes fetching all of the files used in the minstrel command.

Once you have all of the this information, you can start loading up the files for recovery.  Here, I have placed all of the input files that I need under the **data** directory (but this is just personal preference).

In [6]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from biom import load_table

# Below, we will be retrieving the input data files that were used with minstrel.
data_dir = '../data/'
microbes = load_table(os.path.join(data_dir, 'otus_nt.biom'))
metabolites = load_table(os.path.join(data_dir, 'lcms_nt.biom'))

Here, we will retreive the actual model parameters.  In this case, I've stored all of the checkpoint files to the `results` folder.

Personally, I prefer to name the folder based on the parameters that I ran for the model.  It allows for easy comparison between different models using Tensorboard.  I would recommend looking over the Songbird readme on more details on how to diagnostics using Tensorboard.

If you get dumb errors, it is likely because of the checkpoints file -- make sure that all of those filepaths are correct!

In [10]:
summary_dir = os.path.join(
    '../results/latent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.85_beta2_0.90'
)

with tf.Session() as sess:
    fname = os.path.join(
            summary_dir, 'model.ckpt-18681103.meta'
    )
    new_saver = tf.train.import_meta_graph(fname)
    new_saver.restore(sess, tf.train.latest_checkpoint(summary_dir))
    graph = tf.get_default_graph()
    qU = graph.get_tensor_by_name("qU:0").eval()
    qV = graph.get_tensor_by_name("qV:0").eval()
    qUbias = graph.get_tensor_by_name("qUbias:0").eval()
    qVbias = graph.get_tensor_by_name("qVbias:0").eval()

INFO:tensorflow:Restoring parameters from ../results/latent_dim_3_input_prior_1.00_output_prior_1.00_beta1_0.85_beta2_0.90/model.ckpt-18681103


Now I am going to reproduce the minstrel reprocessing procedure.  Specifically, this will filter out low abundance samples and low abundance features.  In addition, this will filter out any samples that aren't shared between the microbe and metabolite biom tables.

In [11]:
# the imported function is duplicating the filtering done here
# https://github.com/mortonjt/minstrel/blob/master/scripts/minstrel#L108
from minstrel.util import split_tables

# the parameters below were the ones used in the run
res = split_tables(
        microbes, metabolites,
        num_test=10,    # --num-testing-examples
        min_samples=10  # --min-feature-count
)

(train_microbes_df, test_microbes_df,
 train_metabolites_df, test_metabolites_df) = res

# Rank computation

Since we we are retrieving the underlying model parameters, we need to compute the ranks ourselves.  
We also need to match up the microbe/metabolite ids after the preprocessing as done in the previous step.
That can be done as follows.

In [14]:
U_ = np.hstack(
    (np.ones((qU.shape[0], 1)),
     qUbias.reshape(-1, 1), qU)
)
V_ = np.vstack(
    (qVbias.reshape(1, -1),
     np.ones((1, qV.shape[1])), qV)
)

ranks = pd.DataFrame(
    clr(centralize(clr_inv(np.hstack(
        (np.zeros((qU.shape[0], 1)), U_ @ V_))))),
    index=train_microbes_df.columns,
    columns=train_metabolites_df.columns)

# save the ranks to a csv file.
ranks.to_csv('../results/ranks.csv')

And that's it!  Now you have the conditional probability matrix (aka `ranks.csv`).

And if you want to convert these quantities to co-occurence probabilities, you just have to run `clr_inv` on each row of this matrix (as shown below).

This can tell you the probability of a metabolite occurring for a specific microbe. Ultimately, these probabilities can be ranked, so that you can find the most highly associated molecules for a given microbe.

In [17]:
probs = ranks.apply(clr_inv)
probs.head()

Unnamed: 0,X290.0883mz60.1277,X291.0489mz61.9903,X254.1601mz62.7917,X265.1154mz65.0188,X118.0839mz71.5519,X127.0605mz78.9317,X86.0928mz97.5564,X141.9563mz99.3446,X226.1042mz146.3815,X188.0692mz152.4098,...,X716.4112mz590.2013,X812.5824mz593.0447,X343.2854mz593.0689,X741.5355mz593.2980,X434.3601mz593.5585,X412.3782mz593.7169,X439.3937mz598.8498,X461.3787mz599.1107,X460.3757mz599.3083,X438.3935mz599.3187
TACGTAGGTGGCAAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTCCTTAAGTCTGATGTGAAAGCCCCCGGCTCAACCGGGGAGGGTCATTGGAAACTGGGGAACTTGAGTGCAGAAGAGGAGAGTGGAATTCCATG,0.00123,0.009182,0.005149,0.005459,0.003798,0.003523,0.004734,0.005888,0.022202,0.005883,...,0.00537,0.005991,0.004949,0.00431,0.004989,0.005669,0.004276,0.004757,0.005107,0.00571
TACGGAGGGTGCGAGCGTTAATCGGAATAACTGGGCGTAAAGGGCACGCAGGCGGTGACTTAAGTGAGGTGTGAAAGCCCCGGGCTTAACCTGGGAATTGCATTTCATACTGGGTCGCTAGAGTACTTTAGGGAGGGGTAGAATTCCACG,0.001099,0.001336,0.008822,0.008269,0.012435,0.012446,0.003789,0.007966,0.001347,0.005681,...,0.005835,0.005468,0.009051,0.007707,0.009216,0.008705,0.008762,0.007989,0.009276,0.008711
TACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGTGGCTTAACCATAGTATGCTTTGGAAACTGTTTAACTTGAGTGCAGAAGGGGAGAGTGGAATTCCATGT,2e-06,0.035953,0.014225,0.016832,0.005955,0.005985,0.098926,0.005973,0.001798,0.050488,...,0.008974,0.050385,0.021741,0.000309,0.016209,0.016135,0.029957,0.02752,0.015741,0.015083
TACGTAGGTCCCGAGCGTTGTCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGTGGCTTAACCATAGTACGCTTTGGAAACTGTTTAACTTGAGTGCAAGAGGGGAGAGTGGAATTCCATGT,0.002955,0.01717,0.005722,0.007336,0.00384,0.003514,0.007954,0.005589,0.019146,0.008787,...,0.005062,0.008416,0.006219,0.002795,0.005775,0.006606,0.005764,0.005966,0.005922,0.006556
TACGTAGGTCCCGAGCGTTATCCGGATTTATTGGGCGTAAAGCGAGCGCAGGCGGTTAGATAAGTCTGAAGTTAAAGGCTGTGGCTTAACCATAGTACGCTTTGGAAACTGTTTAACTTGAGTGCAGAAGGGGAGAGTGGAATTCCATGT,0.005554,0.025121,0.006316,0.010412,0.004542,0.004075,0.010293,0.005468,0.013734,0.011383,...,0.004329,0.009852,0.007775,0.002485,0.006701,0.007647,0.007593,0.006993,0.006928,0.0075
