In [2]:
from torch.distributions import constraints
from pyro.distributions.transforms import ELUTransform
from pyro.infer import MCMC, NUTS, Predictive
from pyro.distributions.constraints import positive_definite
import numpy as np
import pyro.distributions as dist
import pyro
import scipy as sp
import torch

## Wishart Distribution

In [3]:
from scipy.stats import wishart, chi2

In [4]:
x = np.linspace(1e-5, 8, 100)
w = wishart.rvs(df=3, scale=1)

x.shape, w

((100,), 3.545261090740356)

In [None]:
from numbers import Number
from torch.distributions.exp_family import ExponentialFamily

class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by :attr:`df`
    scalar and :attr:`scale` matrix, which must be symmetric and positive
    definite.

    The Wishart distribution is often denoted
    .. math::
        W_p(\nu, \Sigma)
    where :math:`\nu` is the degrees of freedom and :math:`\Sigma` is the
    :math:`p \times p` scale matrix.
    The probability density function for `wishart` has support over positive
    definite matrices :math:`S`; if :math:`S \sim W_p(\nu, \Sigma)`, then
    its PDF is given by:
    .. math::
        f(S) = \frac{|S|^{\frac{\nu - p - 1}{2}}}{2^{ \frac{\nu p}{2} }
               |\Sigma|^\frac{\nu}{2} \Gamma_p \left ( \frac{\nu}{2} \right )}
               \exp\left( -tr(\Sigma^{-1} S) / 2 \right)
    If :math:`S \sim W_p(\nu, \Sigma)` (Wishart) then
    :math:`S^{-1} \sim W_p^{-1}(\nu, \Sigma^{-1})` (inverse Wishart).
    If the scale matrix is 1-dimensional and equal to one, then the Wishart
    distribution :math:`W_1(\nu, 1)` collapses to the :math:`\chi^2(\nu)`
    distribution.

    Example::

        >>> w = Wishart(torch.tensor(3), torch.tensor([]))
        >>> w.sample()  # 30% chance 1; 70% chance 0
        tensor([ 0.])

    Args:
        df (Number, Tensor): degrees of freedom
        scale (SPSD Matrix, Tensor): the log-odds of sampling `1`
    """
    arg_constraints = {'df': constraints.positive,
                       'scale': constraints.positive_definite}
    support = constraints.positive_definite

    def __init__(self, df, scale, validate_args=None):
        
        if probs is not None:
            is_scalar = isinstance(probs, Number)
            self.probs, = broadcast_all(probs)
        else:
            is_scalar = isinstance(logits, Number)
            self.logits, = broadcast_all(logits)
        self._param = self.probs if probs is not None else self.logits
        if is_scalar:
            batch_shape = torch.Size()
        else:
            batch_shape = self._param.size()
        super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)

## Model

In [147]:
pyro.enable_validation(True)

In [148]:
def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

In [149]:
N = 50
D = 3


X_train = np.random.dirichlet([3, 3, 3], (N, D))
Y_train = torch.tensor(np.argmax(np.sum(X_train, axis=1), axis=1))

X_test = np.random.dirichlet([3, 3, 3], (N, D))
Y_test = torch.tensor(np.argmax(np.sum(X_test, axis=1), axis=1))

In [150]:
Mu = pyro.param("M_mu", torch.eye(3).double())
Sigma = pyro.param("M_sigma", torch.tensor(np.tile(np.expand_dims(np.eye(3), 0), (3, 1, 1))).double())
M = dist.MultivariateNormal(Mu, Sigma)

In [151]:
M.shape(), M.batch_shape, M.event_shape

(torch.Size([3, 3]), torch.Size([3]), torch.Size([3]))

In [152]:
M_prime = M.sample()
M_prime, np.linalg.eigvals(M_prime), is_pos_def(M_prime)

(tensor([[ 2.5894,  1.9562, -0.1126],
         [-0.9601,  0.5194, -0.3486],
         [ 0.4703, -0.9829, -0.7804]], dtype=torch.float64),
 array([ 1.66210418+1.02720085j,  1.66210418-1.02720085j,
        -0.99590463+0.j        ]),
 False)

## Model

In [156]:
def model(X, Y=None):
    N, L, Dim = X.shape
    with torch.no_grad():
        M = pyro.sample("M", dist.MultivariateNormal(Mu, Sigma))
        print(M)
        for i in pyro.plate("sequences", N):
            x = X[i]
            if Y is None:
                y = None
            else:
                y = Y[i]
            x_i_j = torch.zeros(Dim).double()
            for j in pyro.markov(range(L)):
                print('M @ x_i_j', M @ x_i_j)
                print('x[j]', x[j])
                x_i_j = pyro.sample(f"x_{i}_{j}", dist.Dirichlet(ELUTransform(M @ x_i_j) + x[j]))
            y_i = pyro.sample(f"y_{i}", dist.Categorical(x_i_j), obs=y)

In [157]:
kernel = NUTS(model)
mcmc = MCMC(kernel, warmup_steps=2, num_samples=3)

In [158]:
mcmc.run(X_train, Y_train)

Warmup:   0%|          | 0/5 [00:00, ?it/s]

tensor([[ 0.3682,  0.2255, -0.2509],
        [-0.5418,  0.5020, -0.7181],
        [-0.4653,  0.9348,  0.7187]], dtype=torch.float64)
M @ x_i_j tensor([0., 0., 0.], dtype=torch.float64)
x[j] [0.25636354 0.43385948 0.30977699]


RuntimeError: Boolean value of Tensor with more than one value is ambiguous
 Trace Shapes:       
  Param Sites:       
 Sample Sites:       
        M dist  3 | 3
         value  3 | 3
sequences dist    |  
         value 50 |  

In [126]:
samples = mcmc.get_samples()

In [127]:
predictive = Predictive(model, samples)

In [128]:
predictive(X_test)

M @ x_i_j tensor([0., 0., 0.], dtype=torch.float64)
x[j] [0.34420525 0.34873567 0.30705909]
M @ x_i_j tensor([2.7776, 3.9797, 0.0981], dtype=torch.float64)
x[j] [0.31403831 0.47507691 0.21088477]
M @ x_i_j tensor([2.0215, 3.3321, 1.0980], dtype=torch.float64)
x[j] [0.21412387 0.47838123 0.3074949 ]
M @ x_i_j tensor([0., 0., 0.], dtype=torch.float64)
x[j] [0.16079475 0.18820794 0.65099731]
M @ x_i_j tensor([0.3430, 1.0972, 3.2434], dtype=torch.float64)
x[j] [0.50477315 0.13786737 0.35735947]
M @ x_i_j tensor([-0.4387,  1.0138,  4.3317], dtype=torch.float64)
x[j] [0.3050602  0.4569554  0.23798439]


ValueError: The parameter concentration has invalid values
 Trace Shapes:       
  Param Sites:       
 Sample Sites:       
        M dist  3 | 3
         value  3 | 3
sequences dist    |  
         value 50 |  
    x_0_0 dist    | 3
         value    | 3
    x_0_1 dist    | 3
         value    | 3
    x_0_2 dist    | 3
         value    | 3
      y_0 dist    |  
         value    |  
    x_1_0 dist    | 3
         value    | 3
    x_1_1 dist    | 3
         value    | 3

In [109]:
pyro.poutine.trace(model).get_trace(X_train)

ValueError: The parameter concentration has invalid values
 Trace Shapes:       
  Param Sites:       
 Sample Sites:       
        M dist  3 | 3
         value  3 | 3
sequences dist    |  
         value 50 |  
    x_0_0 dist    | 3
         value    | 3
    x_0_1 dist    | 3
         value    | 3