In [None]:
import jax
import jax.numpy as np
from jax.scipy.linalg import solve_triangular

from utils import *

from priors import *
from transformations import *
from likelihoods import *
from advi import *

import matplotlib.pyplot as plt
plt.style.use('ggplot')

# Data

## Digit 2

In [None]:
X_mnist_2 = np.load('X_mnist_2.npy')
X_mnist_2_proj = np.load('X_mnist_2_proj.npy')

## The Evens: 4, 6, 8

In [None]:
X_mnist_evens = np.load('X_mnist_evens.npy')
X_mnist_evens_proj = np.load('X_mnist_evens_proj.npy')

## Data helpers

In [None]:
def project_and_resize_2(X):
    """Multiply ``X`` by the digit-2 projection matrix and shape it as an image.

    The product ``X @ X_mnist_2_proj`` (matrix loaded from
    ``X_mnist_2_proj.npy``) is reshaped to 28 rows by 28 columns, so it must
    contain exactly 784 entries. NOTE(review): assumes ``X`` is a single
    sample in the reduced space — confirm against callers.
    """
    image_shape = (28, 28)
    pixels = X @ X_mnist_2_proj
    return np.reshape(pixels, image_shape)
#

In [None]:
def project_and_resize_evens(X):
    """Multiply ``X`` by the even-digits projection matrix and shape it as an image.

    The product ``X @ X_mnist_evens_proj`` (matrix loaded from
    ``X_mnist_evens_proj.npy``) is reshaped to 28 rows by 28 columns, so it
    must contain exactly 784 entries. NOTE(review): assumes ``X`` is a single
    sample in the reduced space — confirm against callers.
    """
    image_shape = (28, 28)
    pixels = X @ X_mnist_evens_proj
    return np.reshape(pixels, image_shape)
#

# Plotting

In [None]:
def plot_image_grid(image_grid, projection_func, fig_width = 10):
    """Display a 2-D grid of images in one matplotlib figure.

    Parameters
    ----------
    image_grid : sequence of sequences
        Row-major grid; ``image_grid[row][col]`` is passed to
        ``projection_func``. The column count is taken from the first row, so
        no row may be longer than ``image_grid[0]``.
    projection_func : callable
        Maps one grid entry to an array accepted by ``Axes.imshow``
        (e.g. ``project_and_resize_2``).
    fig_width : float, default 10
        Figure width in inches; the height is scaled by rows/columns so each
        cell keeps the same aspect.

    Raises
    ------
    ValueError
        If ``image_grid`` or its first row is empty.
    """
    if len(image_grid) == 0 or len(image_grid[0]) == 0:
        raise ValueError("image_grid must contain at least one row and one column")
    n_rows, n_cols = len(image_grid), len(image_grid[0])
    # squeeze=False keeps `axes` 2-D even for a single row or column;
    # without it, plt.subplots returns a 1-D array (or a lone Axes) and
    # `axes[row, col]` indexing fails.
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for row, images in enumerate(image_grid):
        for col, image in enumerate(images):
            ax = axes[row, col]
            ax.imshow(projection_func(image))
            ax.set_axis_off()
        #
    #
    fig.set_figwidth(fig_width)
    fig.set_figheight(fig_width * n_rows / n_cols)
    plt.show()
#

# Probabilistic model builder
* Parameters: a list of components, with each component consisting of a mean vector and precision matrix, as well as a set of mixture weights.
* Gaussian prior on each component's mean vector, with prior mean 0 and prior precision 0.001 (`prior_precision`).
* Wishart prior on each component's precision matrix, with the inverse scale matrix set to the identity and the degrees of freedom set to `D + dof_delta`, where `D` is the data dimension and `dof_delta = 2` by default.

Should return:

1. An initialization of the variational parameters for the mean and precision, with all parameters initialized to zero.
2. The transformations for all model parameters.
3. The ELBO.

In [None]:
def assemble_mixture_model(X, n_components, concentration, prior_precision = 1e-3, dof_delta = 2):
    """Build the variational setup for a Bayesian Gaussian mixture model.

    NOT YET IMPLEMENTED: this is a scaffold. The previous version fell
    through to ``return ..., elbo`` with ``elbo`` never defined, which
    surfaced as an opaque ``NameError``; it now fails explicitly.

    Intended contract (from the specification in the markdown cell above):

    Parameters
    ----------
    X : array
        Data matrix; ``X.shape[1]`` is used as the data dimension ``D``.
    n_components : int
        Number of mixture components, each with a mean vector and a
        precision matrix.
    concentration : float
        NOTE(review): presumably the concentration of the prior on the
        mixture weights — confirm.
    prior_precision : float, default 1e-3
        Precision of the zero-mean Gaussian prior on each component mean.
    dof_delta : int, default 2
        The Wishart prior on each precision matrix uses ``D + dof_delta``
        degrees of freedom and an identity inverse scale matrix.

    Returns (once implemented)
    --------------------------
    init_variational_params : dict
        Initial variational parameters, all initialized to zero.
    transformations : dict
        Transformations for all model parameters.
    elbo
        The evidence lower bound.

    Raises
    ------
    NotImplementedError
        Always, until the priors, transformations, and ELBO are filled in.
    """
    D = X.shape[1]
    # D diagonal entries plus D*(D-1)/2 strictly-lower-triangular entries,
    # i.e. the entry count of a D x D lower triangle. NOTE(review): presumably
    # the size of each precision-matrix parameter — confirm against the
    # transformation used.
    D_prec = D + D*(D-1)//2
    init_variational_params, log_priors, transformations = {}, {}, {}

    # TODO: for each of the n_components components, register the mean
    #       (Gaussian prior, mean 0, precision prior_precision) and the
    #       precision (Wishart prior, identity inverse scale, D + dof_delta dof).
    # TODO: register the mixture weights using `concentration`.
    # TODO: define `elbo` from log_priors, the likelihood, and the
    #       transformations, then return
    #       (init_variational_params, transformations, elbo).
    raise NotImplementedError(
        "assemble_mixture_model: priors, transformations and the ELBO "
        "have not been implemented yet"
    )
#

# Bayesian mixture model for digit 2

In [None]:
# Optimization settings for the digit-2 model. Starts with few Monte Carlo
# samples for computational efficiency.
T = 1000  # number of optimization steps
starting_step_size = .2
alpha = .5
mc_samples = 4

# TODO: Assemble probabilistic model, optimize ELBO


# Bayesian mixture model for even digits

In [None]:
# Optimization settings for the even-digits model.
T = 1000  # number of optimization steps
starting_step_size = .2
alpha = .5
mc_samples = 4

# TODO: Assemble probabilistic model, optimize ELBO
