# Amortized LDA implementation

## Loading libraries

In [None]:
# Load libraries
import functools
import logging

import matplotlib.pyplot as plt
import pandas as pd
import pyro
import torch

from zzz_utils import *
from amortized_lda import *

# Log INFO messages prefixed with the milliseconds elapsed since logging started.
logging.basicConfig(level=logging.INFO, format="%(relativeCreated) 9d %(message)s")

# Start from an empty Pyro parameter store and a fixed RNG seed for reproducibility.
pyro.clear_param_store()
pyro.set_rng_seed(123)

## Amortized LDA graphical model

First we simulate a toy dataset to render the Pyro models.

In [None]:
# Small toy dataset, used only to render the Pyro model and guide graphs.
nTopics = 3  # number of topics
nCells = 50  # number of cells
nRegions = 100  # number of regions
N = [30] * nCells  # number of counts in each cell

# Simulate data with all-ones `a` and `b` vectors.
obj = simulate_lda_dataset(
    nTopics=nTopics,
    nCells=nCells,
    nRegions=nRegions,
    N=N,
    a=[1] * nTopics,
    b=[1] * nRegions,
)
# Pyro's model takes the transposed matrix (nCounts x nCells).
D = torch.from_numpy(obj["D"].T)

Below we define the LDA model with Pyro. Note that the data `D` is an nCounts x nCells matrix; here we assume that nCounts is the same for all cells.

In [None]:
# Render the generative model graph, starting from an empty parameter store.
pyro.clear_param_store()
pyro.render_model(
    amortized_lda_model,
    model_args=(D, nTopics, nRegions),
    render_distributions=True,
    render_params=True,
)

## Amortized LDA guide (variational approximation)

In [None]:
# Build the neural-network predictor used by the amortized guide.
# NOTE(review): layer-size string format is defined by nn_predictor
# (presumably two hidden layers of 100 units) — confirm.
layer_sizes = "100-100"
pred = nn_predictor(nTopics, nRegions, layer_sizes)

# Fix the predictor as the guide's first positional argument.
guide = functools.partial(amortized_lda_guide, pred)

# NOTE(review): the trailing 20 looks like a minibatch size (later cells pass
# batch_size=64) — confirm against amortized_lda_guide's signature.
pyro.render_model(
    guide,
    model_args=(D, nTopics, nRegions, 20),
    render_distributions=True,
    render_params=True,
)

# Testing variational inference

## Simulate data
We use data simulated from the LDA model to test the inference performance of amortized LDA, i.e. how close the inferred values are to the true values used to simulate the data. 

__Note__: mixture and mixed-membership models have a well-known identifiability issue (label switching). However, we would still expect the inferred cell assignments to be consistent with the simulated data, up to a permutation of the topic labels.


In [None]:
# Dataset sizes for the inference test.
nTopics = 2  # number of topics
nCells = 1000  # number of cells
nRegions = 300  # number of regions
N = [100] * nCells  # number of counts in each cell
a = [1 / 5] * nTopics
b = [1 / 10] * nRegions

# Simulate the dataset used to check how well inference recovers the truth.
obj = simulate_lda_dataset(
    nTopics=nTopics,
    nCells=nCells,
    nRegions=nRegions,
    N=N,
    a=a,
    b=b,
)
# Pyro's model takes the transposed matrix (nCounts x nCells).
D = torch.from_numpy(obj["D"].T)

In [None]:
D.size()  # dimensions of the simulated data: nCounts x nCells

In [None]:
# True topic probabilities of the first 10 cells (rows 0-9; the old 1:10 slice skipped row 0).
obj['theta_true'][:10]

In [None]:
# True probabilities of the first 10 regions (columns 0-9), transposed so that rows are regions.
# NOTE(review): assumes phi_true is laid out topics x regions — confirm.
obj['phi_true'][:, :10].transpose()

## Fit AmortizedLDA 

To perform inference for Amortized LDA we use ClippedAdam to optimize the ELBO with Pyro's
__trace implementation of ELBO-based SVI with enumeration__ (`TraceEnum_ELBO`), which supports exhaustive enumeration 
over discrete sample sites — in our case, the latent topic assignments __z__.

In [None]:
# Start fitting from an empty parameter store.
pyro.clear_param_store()

# NOTE(review): this rebinds `obj` (previously the simulated dataset) to the fit result.
obj = fit_amortized_lda(
    D=D,
    nTopics=nTopics,
    nRegions=nRegions,
    nSteps=3000,
    batch_size=64,
    lr=0.01,
    seed=123,
)

Here we plot the ELBO loss during optimisation.

In [None]:
# Plot the ELBO loss recorded during optimisation, one point per SVI step.
losses = obj['losses']

fig, ax = plt.subplots(figsize=(5, 2))
ax.plot(losses)
ax.set_xlabel("SVI step")
ax.set_ylabel("ELBO loss")

## Assessing inferred parameters

Below we show estimates of the inferred model parameters. For now, I simply take a single sample from the 
posterior fit (i.e. I call the guide with the optimized set of variational parameters). This is certainly not the best way to summarise the posterior fit. 

__However__, if I take multiple samples from the posterior and subsequently summarise them (e.g. by the median), label switching makes the resulting summary meaningless. This label switching occurs when sampling 
$\theta \sim \mathrm{Dir}(\alpha)$. 

__TODO__ 

1. Define a better way to summarize the posterior distribution from posterior samples, e.g. by resolving the label-switching problem post hoc after sampling from the posterior (similar to the approach used for mixture models).
2. Make posterior predictive checks.
3. Need to understand Pyro's `poutine`.


In [None]:
# Get the fitted guide object, from which we will sample from.
guide = obj["guide"]  # fitted guide returned by fit_amortized_lda; posterior samples are drawn from it

In [None]:
# A single sample from the guide
# Draw one sample from the guide using the current (optimized) parameters.
post_sample = guide(
    D=D,
    nTopics=nTopics,
    nRegions=nRegions,
    batch_size=64,
)

In [None]:
post_sample["alpha"]  # alpha drawn in the posterior sample

In [None]:
# Sampled phi for the first 10 regions (columns 0-9; the old 1:10 slice skipped column 0),
# transposed so that rows are regions.
post_sample['phi'][:, :10].detach().numpy().transpose()

In [None]:
# Sampled phi for the first 10 regions (columns 0-9), transposed so that rows are regions.
# NOTE(review): this cell duplicates the previous one; presumably it was meant to show
# something else (e.g. a comparison with the truth) — confirm or remove.
post_sample['phi'][:, :10].detach().numpy().transpose()

In [None]:
# Variational parameter "b_vi" for the first 10 regions (columns 0-9, matching the cells above).
pyro.param("b_vi")[:, :10]

In [None]:
# Extract optimized values of variational parameters.
# Disabled: uncomment the loop below to print every entry in Pyro's parameter store.
#for name, value in pyro.get_param_store().items():
#    print(name, pyro.param(name).data.cpu().numpy())

## Testing (ignore for now)

In [None]:
from sklearn.feature_extraction.text import CountVectorizer

# NOTE(review): `obj` was rebound to the fit_amortized_lda(...) result in the fit cell,
# so 'D_str' must exist on that result — confirm (it may only exist on the simulated data).

# Bag-of-words count matrix, one row per document in obj['D_str'].
# min_df=1 keeps every term occurring at least once, which is identical to the former
# min_df=0; newer scikit-learn versions reject an integer min_df below 1.
vectorizer = CountVectorizer(max_df=1.0, min_df=1, stop_words=None)
docs = torch.from_numpy(vectorizer.fit_transform(obj['D_str']).toarray())

# get_feature_names() was removed in scikit-learn 1.2; use its replacement when available.
if hasattr(vectorizer, "get_feature_names_out"):
    words = vectorizer.get_feature_names_out()
else:
    words = vectorizer.get_feature_names()

# Vocabulary table: each term with its column index in `docs`.
vocab = pd.DataFrame({'word': words, 'index': range(len(words))})

In [None]:
docs.size()  # (number of documents, vocabulary size)

In [None]:
obj["D_freq"]  # NOTE(review): `obj` is the fit result at this point — confirm it has 'D_freq'

In [None]:
obj["D_tfidf"]  # NOTE(review): `obj` is the fit result at this point — confirm it has 'D_tfidf'

In [None]:
obj["D_freq"]  # NOTE(review): repeats an earlier cell; `obj` is the fit result at this point