# Variational message passing
This is a deliberately minimal implementation of the variational message passing (VMP) example from the original paper by
Winn & Bishop.

Model description: The precision has a Gamma prior, and the mean has a Gaussian prior.

We then have a dataset sampled from Normal(5, 1). The goal is to find the posterior over the mean and the precision.
This is done using VMP. Everything here is tailor-made for the current model. Hence, there is no clever scoping of
messages going on; we simply hard-code the structure and the sequencing, and get the expected result.

There is no ELBO calculation here, and also no convergence monitoring.
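
For reference, the model and the factorized approximation described above can be written out as follows. The hyperparameter symbols $m_0, v_0, a_0, b_0$ are notation introduced here and do not appear in the code; the Gamma is assumed to use the shape/rate convention, which matches the natural parameters `[alpha - 1, -beta]` used later.

```latex
\begin{aligned}
\tau &\sim \mathrm{Gamma}(a_0, b_0), \\
\mu &\sim \mathcal{N}(m_0, v_0), \\
x_i \mid \mu, \tau &\sim \mathcal{N}(\mu, \tau^{-1}), \quad i = 1, \dots, N, \\
p(\mu, \tau \mid x_{1:N}) &\approx q(\mu)\, q(\tau).
\end{aligned}
```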

**Imports**

In [None]:
import numpy as np
from scipy import special, stats
import matplotlib.pyplot as plt

## Definition of the Exponential Family superclass
Many of the operations in VMP are shared among ExpFam distributions. We therefore define a superclass that takes care of that, and make subclasses for each of the actual distributions afterwards.

The methods we need are as follows:
* **Initialization** (done by the `__init__`-method). The code here assumes that the **priors** are supplied at initialization time. Furthermore, the method sets aside space for an observation (may be None) and messages.
* **Translation** between _natural parameters_ and _moment parameters_. As we have seen previously, the way to do the translation depends on the distribution, hence these methods (`translate_to_moment_parameters` and `translate_to_natural_parameters`) are virtual in the superclass.
* Accepting **incoming messages**. This amounts to keeping track of what comes from whom, and is straightforward using the `dict` data structure. It is the same for all distributions, and will not be refined in the subclasses.
* Sending **outgoing messages**. We need to know how to send messages to both parents and children. What is sent depends on the distribution of the sender and the receiver, hence it is overridden in the subclasses.

This class is finished, no need to fiddle with it.
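
As a small illustration of the update rule the superclass applies (posterior natural parameters = prior natural parameters + sum of incoming messages), here is a minimal sketch. The sender names and numbers are hypothetical and chosen only for illustration; they are not taken from the model below.

```python
import numpy as np

# Hypothetical numbers, for illustration only.
prior = np.array([0.0, -0.5])            # natural parameters of a Normal(0, 1) prior
messages = {
    "child_1": np.array([4.0, -0.5]),    # messages are stored per sender
    "child_2": np.array([6.0, -0.5]),
}

# Start from a copy of the prior so that in-place addition cannot change the prior.
posterior = prior.copy()
for sender, message in messages.items():
    assert message.shape == posterior.shape
    posterior += message

# Translate back to moment parameters [mean, variance].
variance = -1.0 / (2.0 * posterior[1])
mean = posterior[0] * variance
print(posterior, mean, variance)
```

With these numbers the result corresponds to a Normal(0, 1) prior combined with two unit-precision observations, 4 and 6.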

In [None]:
class ExponentialFamilyDistribution:
    def __init__(self, moment_parameters=None, natural_parameters=None, observation=None):
        if moment_parameters is not None:
            self.moment_parameters = np.array(moment_parameters)
            self.translate_to_natural_parameters()
        else:
            self.natural_parameters = np.array(natural_parameters)
            self.translate_to_moment_parameters()

        # Keep a copy of the prior, so that later in-place updates cannot alter it.
        self.prior = np.array(self.natural_parameters, dtype=float)

        # Enable a storage point for all messages. These are stored in a dictionary
        self.messages = {}

        # If we have an observation, then store that, too
        self.observation = observation

    def translate_to_moment_parameters(self):
        # This depends on the distribution type, so will be defined in subclasses
        raise NotImplementedError

    def translate_to_natural_parameters(self):
        # This depends on the distribution type, so will be defined in subclasses
        raise NotImplementedError

    def accept_incoming_messages(self, message, sender):
        self.messages[sender] = message

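    def update_based_on_incoming_messages(self):
        # Start from a fresh copy of the prior and add each message to the natural parameters.
        # The copy matters: "+=" works in place and would otherwise modify the prior itself.
        self.natural_parameters = np.array(self.prior, dtype=float)
        for sender in self.messages.keys():
            message = self.messages[sender]
            assert np.shape(message) == np.shape(self.natural_parameters)
            self.natural_parameters += np.array(message)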

        # Update moment parameters -- just to have them for plotting and whatnot
        self.translate_to_moment_parameters()

        # The messages we have received are all incorporated, and will be voided.
        self.messages = {}

    def generate_message_to_child(self):
        # Messages are calculated as the expectation of the sufficient statistics.
        # It is depending on the distribution, hence defined at subclass level
        raise NotImplementedError

    def generate_message_to_parent(self, receiver):
        # Messages are calculated based on the recipient. It sends the expected
        # natural parameters in the conjugate representation.
        # Quite "nasty" stuff, and written at subclass level
        raise NotImplementedError


# The Gaussian class
We implement the Gaussian distribution as a subclass of the general ExponentialFamily.
Some functionality is inherited directly from the superclass (namely initialization, accepting messages, and updating natural parameters based on the received messages).

**These things are to be implemented:**
* Translation **from moment parameters to natural parameters** (`translate_to_natural_parameters`). This has already been implemented in the previous code task. Note that we only care about natural parameters here, not the log partition function.
* Translation **from natural parameters to moment parameters** (`translate_to_moment_parameters`). This is the "inverse" of `translate_to_natural_parameters`.
* Generate **messages to children:** In this model, the only message from a Gaussian to a child is from the latent/unobserved variable representing the unknown mean down to the data points, which are also Gaussian.
* Generate **message to parents**: The observed variables will send a message to the mean as well as to the precision variable. That means we must be able to send to both Gaussian and Gamma parents. Messages are described in the slides, and should be in `generate_message_to_parent`.
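
The rules listed above can be sketched as standalone functions, which may help when checking the class methods against known values. This is only one possible formulation, based on the rules stated in the code comments below; the function names are made up for this illustration and are not part of the class.

```python
def gaussian_to_natural(mean, variance):
    """[mu, sigma^2] -> [mu * q, -q / 2], where q = 1 / sigma^2."""
    precision = 1.0 / variance
    return [mean * precision, -0.5 * precision]


def gaussian_to_moment(eta1, eta2):
    """[eta1, eta2] -> [mu, sigma^2]; inverse of gaussian_to_natural."""
    variance = -1.0 / (2.0 * eta2)
    return [eta1 * variance, variance]


def gaussian_message_to_child(mean, variance):
    """Expected sufficient statistics [E[X], E[X^2]]."""
    return [mean, mean ** 2 + variance]


def message_to_gaussian_parent(x, expected_tau):
    """Message from an observed x to its mean parent: [E[tau] * x, -E[tau] / 2]."""
    return [expected_tau * x, -0.5 * expected_tau]


def message_to_gamma_parent(x, expected_mu, expected_mu_sq):
    """Message from an observed x to its precision parent: [1/2, -E[(x - mu)^2] / 2]."""
    return [0.5, -0.5 * (x ** 2 - 2.0 * x * expected_mu + expected_mu_sq)]


# Round trip: moment -> natural -> moment should return the starting values.
print(gaussian_to_moment(*gaussian_to_natural(5.0, 2.0)))
```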


In [None]:
class Gaussian(ExponentialFamilyDistribution):
    def __init__(self, moment_parameters=None, natural_parameters=None, observation=None):
        # Initialization is simply handed off to superclass
        super(Gaussian, self).__init__(moment_parameters=moment_parameters,
                                       natural_parameters=natural_parameters,
                                       observation=observation)

    def translate_to_moment_parameters(self):
        # Rule is: Natural params == (mu*q, -.5q), where q = 1/variance
        # Want to return [mu, sigma_square]
        assert self.natural_parameters is not None
        
        self.moment_parameters = [????, ????]
        return self.moment_parameters

    def translate_to_natural_parameters(self):
        # Rule is: Natural params == (mu*q, -.5q), where q = 1/variance
        # Want to return [eta1, eta2]
        assert self.moment_parameters is not None

        self.natural_parameters = [????, ????]
        return self.natural_parameters

    def generate_message_to_child(self):
        # In our model a Gaussian sends a message to a child only in one case:
        # The variable mu sends to its children X_i. X_i are Gaussians.
        # The Gaussian will send a message giving
        # E[X], E[X**2]
        # where the expectation is to be taken over its own distribution.
        # Using moment params this is simple:
        #     E[X] = self.moment_parameters[0]
        #     E[X**2] = self.moment_parameters[0]**2 + self.moment_parameters[1]
        
        return [????, ????]

    def generate_message_to_parent(self, receiver):
        # The receiver can be either a Gaussian (X_i sends to mu) or a Gamma (X_i sends to tau).
        # The shape of the message depends on the receiver, so we need to make sure we do this accordingly.
        if isinstance(receiver, Gaussian):
            # Message to a Gaussian is the local model's
            # best guess on the natural parameters,
            # [ E[Q]  * data_value,  -.5 * E(Q)]
            # These are the expected natural parameters
            # For this to work, the node must have already received the message from parents determining Q.
            # We therefore check for incoming messages being filled, where the sender is Gamma distributed.
            incoming_from_gamma = None
            for sender in self.messages.keys():
                # Go through all messages the variable has received
                if isinstance(sender, Gamma):
                    # We have something from a Gamma. Since the model is known to have only
                    # one Gamma parent, this is the one we look for
                    incoming_from_gamma = self.messages[sender]
                    break
            # Check we have something there
            assert incoming_from_gamma is not None

            # The value of incoming is a message from my Gamma-distributed parent.
            # It has the information [E[log(tau)], E[tau]], and the E[tau] part plays the role of Q here.
            # Since X_i is observed, the observation itself is used as the data value.
            message = [????, ????]

        elif isinstance(receiver, Gamma):
            # Message to a Gamma is the local model's
            # best guess on the natural parameters,
            # [.5,  -.5(x_i^2 - 2* x_i * E[mu] + E[mu^2])]
            # These are the expected natural parameters

            # For this to work, I must have already received the message from my parent determining mu.
            # So, check if we have indeed received an incoming message from a Gaussian.
            # In a general setup we may have more than two parents, and then this simple check would not work.
            # Rather, we should go through all parents and children, check for all but one message being received,
            # and then the variable with a missing message would be the one we could send to.
            incoming_from_gauss = None
            for sender in self.messages.keys():
                if isinstance(sender, Gaussian):
                    incoming_from_gauss = self.messages[sender]
            assert incoming_from_gauss is not None

            # Now, incoming is a message from my Gaussian-distributed parent.
            # It has the information [E[mu], E[mu**2]]
            # The message to send to the Gamma is [1/2, -1/2 E[(x_i - mu)**2]]
            message = [????, ????]

        else:
            raise ValueError("Not a conjugate family member this code supports.")

        return message

# The Gamma class
Much like the Gaussian class, we here need to implement the distribution-specific operations. We make a simplification based on the model structure: There is no Gamma-distributed variable in the model that has a parent, hence we need not consider an implementation of `generate_message_to_parent`.
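
As with the Gaussian, the Gamma rules can be sketched as standalone functions for checking purposes. The function names are invented for this illustration, and the Gamma is assumed to be in the shape/rate form (`alpha`, `beta`) implied by the natural parameters `[alpha - 1, -beta]`.

```python
import numpy as np
from scipy import special


def gamma_to_natural(alpha, beta):
    """[alpha, beta] -> [alpha - 1, -beta]."""
    return [alpha - 1.0, -beta]


def gamma_to_moment(eta1, eta2):
    """[eta1, eta2] -> [alpha, beta]; inverse of gamma_to_natural."""
    return [eta1 + 1.0, -eta2]


def gamma_message_to_child(alpha, beta):
    """Expected sufficient statistics [E[log X], E[X]] under Gamma(alpha, beta)."""
    expected_log_x = special.digamma(alpha) - np.log(beta)
    expected_x = alpha / beta
    return [expected_log_x, expected_x]


# Round trip: moment -> natural -> moment should return the starting values.
print(gamma_to_moment(*gamma_to_natural(2.0, 3.0)))
print(gamma_message_to_child(2.0, 3.0))
```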

In [None]:
class Gamma(ExponentialFamilyDistribution):
    def __init__(self, moment_parameters=None, natural_parameters=None, observation=None):
        super(Gamma, self).__init__(moment_parameters=moment_parameters,
                                    natural_parameters=natural_parameters,
                                    observation=observation)

    def translate_to_moment_parameters(self):
        # Rule is: Natural params == [alpha - 1, -beta]
        # Want to return [alpha, beta]
        assert self.natural_parameters is not None

        
        self.moment_parameters = [????, ????]
        return self.moment_parameters

    def translate_to_natural_parameters(self):
        # Rule is: Natural params == [alpha - 1, -beta]
        # Want to return [eta1, eta2]
        assert self.moment_parameters is not None
        
        
        self.natural_parameters = [????, ????]
        return self.natural_parameters

    def generate_message_to_child(self):
        # The Gamma will send a message to Gaussian variables giving
        # [E[log(X)], E[X]]
        # where the expectation is to be taken over its own distribution.
        # Using moment params [alpha, beta] this is simple:
        #     E[log(X)] = digamma(alpha) - log(beta)
        #     E[X] = alpha / beta

        return [????, ????]

    def generate_message_to_parent(self, receiver):
        # No parent for the Gamma in this model, so we do not have to implement it
        raise NotImplementedError

# VariationalMessagePassingExample class
Simply implements the example: During initialization it generates the required variables. It then provides a training procedure that hard-codes the model structure (that is, which variable sends messages to which recipients before a variable can be updated), and finally a plotting method that shows the posterior over the (mean, precision)-space.

This class is finished, no need to fiddle with it.
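
Because the structure is hard-coded, each training cycle amounts to two closed-form updates, one for q(mu) and one for q(tau). The sketch below is not the class itself; it is a standalone cross-check under the same priors and the same update order. The function name `vmp_updates`, its arguments, and the example data are made up for illustration.

```python
import numpy as np


def vmp_updates(data, m0=0.0, v0=1e6, a0=1e-3, b0=1e-3, no_iter=5):
    """Alternate the q(mu) and q(tau) updates; returns (mean, variance, alpha, beta)."""
    data = np.asarray(data, dtype=float)
    n = data.size
    a, b = a0, b0  # q(tau) starts at the prior
    for _ in range(no_iter):
        # tau -> observations -> mu: prior plus N messages [E[tau] * x_i, -E[tau] / 2]
        expected_tau = a / b
        eta1 = m0 / v0 + expected_tau * data.sum()
        eta2 = -0.5 / v0 - 0.5 * n * expected_tau
        var_mu = -1.0 / (2.0 * eta2)
        mean_mu = eta1 * var_mu

        # mu -> observations -> tau: prior plus N messages [1/2, -E[(x_i - mu)^2] / 2]
        expected_mu_sq = mean_mu ** 2 + var_mu
        expected_sq_dev = np.sum(data ** 2 - 2.0 * data * mean_mu + expected_mu_sq)
        a = a0 + 0.5 * n
        b = b0 + 0.5 * expected_sq_dev
    return mean_mu, var_mu, a, b


mean_mu, var_mu, a, b = vmp_updates([4.0, 5.0, 6.0])
print(mean_mu, var_mu, a, b, a / b)
```

If the class methods are filled in consistently with the rules above, the printed posterior parameters from `train` should agree with this function when it is given the same data and number of iterations.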

In [None]:
class VariationalMessagePassingExample:

    def __init__(self, data_set):

        # data_set is a vector of observations -- length = N
        self.data_set = data_set
        self.N = len(self.data_set)

        # tau: Scalar value for precision.
        # A priori Gamma distributed with "uninformative" parameters:
        self.tau = Gamma(moment_parameters=[1e-3, 1E-3])

        # mu: Scalar value for the mean.
        # A priori Normal distributed with high variance:
        self.mu = Gaussian(moment_parameters=[0, 1E6])

        # observations: These are the x_i variables.
        # The prior distribution p(x_i|mu, tau) is a
        # Gaussian with mean mu and precision tau, but since the
        # variables are all observed we do not really have to relate to this
        # during "start-up". Rather, we will initialize with some numerical values just to get going.
        self.observations = []
        for idx in range(self.N):
            self.observations.append(Gaussian(moment_parameters=[0, 1],
                                              observation=self.data_set[idx]))

    def train(self, no_iter=10, plot_all=False):

        # Here we strongly utilize the structure of the domain.
        # If we do the following passing scheme, everything will work just fine:
        # 1) tau sends to all observations.
        # 2) observations x_i send to mu.
        # 3) mu needs updating, as it has received all its messages
        # 4) mu sends to all observations
        # 5) observations x_i send to tau.
        # 6) tau can update itself
        # At this point we have done one cycle.
        # mu and tau are updated, while x_i (observed) does not need to do anything.
        # If we want to do another cycle, we just go back to step 1) again, and send
        # a new message, based on the updated information at tau, to the observations.

        for i in range(no_iter):

            # Message from tau to all the observations:
            msg = self.tau.generate_message_to_child()
            for obs in self.observations:
                obs.accept_incoming_messages(msg, self.tau)

            # Message from the observations to mu:
            for obs in self.observations:
                msg = obs.generate_message_to_parent(receiver=self.mu)
                self.mu.accept_incoming_messages(msg, obs)

            # Update mu
            self.mu.update_based_on_incoming_messages()

            # Message from mu to all the data-nodes:
            msg = self.mu.generate_message_to_child()
            for obs in self.observations:
                obs.accept_incoming_messages(msg, self.mu)

            # Message from the data-nodes to tau:
            for obs in self.observations:
                msg = obs.generate_message_to_parent(receiver=self.tau)
                self.tau.accept_incoming_messages(msg, obs)

            # Update tau
            self.tau.update_based_on_incoming_messages()

            print("\n\nUpdated {:d} time(s):".format(i + 1))
            print("Posterior mean is Normal({:.2f}, {:.2f})".format(
                self.mu.moment_parameters[0], self.mu.moment_parameters[1]))
            print("Posterior precision is Gamma({:.2f}, {:.2f}), with mean {:.2f}".format(
                self.tau.moment_parameters[0], self.tau.moment_parameters[1],
                self.tau.moment_parameters[0] / self.tau.moment_parameters[1]))

            if plot_all or i == no_iter - 1:
                self.plot_curve(iteration=i)
                
    def plot_curve(self, iteration):
        # This method plots the posterior over the parameter space (mu, tau)
        mu_range = np.linspace(3, 7, 500).astype(np.float32)
        precision_range = np.linspace(1E-10, 2, 500).astype(np.float32)
        mu_mesh, precision_mesh = np.meshgrid(mu_range, precision_range)
        variational_log_pdf = \
            stats.norm.logpdf(mu_mesh,
                              loc=self.mu.moment_parameters[0],
                              # moment_parameters[1] is the variance; scipy expects a standard deviation
                              scale=np.sqrt(self.mu.moment_parameters[1])) + \
            stats.gamma.logpdf(x=precision_mesh,
                               a=self.tau.moment_parameters[0],
                               scale=1. / self.tau.moment_parameters[1])

        plt.figure()
        plt.contour(mu_mesh, precision_mesh, variational_log_pdf, 25)
        plt.title('Iteration {:d}'.format(iteration + 1))
        plt.show()
        plt.close('all')

## Finally, this is some code to test everything

In [None]:
if __name__ == '__main__':
    np.random.seed(123)
    dataset = 5 + np.random.randn(4)
    example = VariationalMessagePassingExample(dataset)
    example.train(no_iter=5, plot_all=True)
