In [1]:
:ext TemplateHaskell DataKinds DeriveGeneric TypeApplications ScopedTypeVariables
:ext UndecidableInstances FlexibleContexts PartialTypeSignatures

In [2]:
import           Inference.Conjugate

import           Control.Monad                  ( foldM )
import           Data.Dynamic                   ( Dynamic
                                                , toDyn
                                                , fromDynamic
                                                , Typeable
                                                )
import qualified Data.Sequence                 as S
import qualified Data.Vector                   as V
import           GHC.Generics
import           Lens.Micro.TH                  ( makeLenses )
import           System.Random.MWC.Probability

# Conjugate Inference with Probabilistic Programming

## Bayesian Models

Bayesian modeling specifies a domain as a probability distribution that links observed ($x$) and latent ($z$) variables.
With Bayesian inference, we try to obtain the *posterior distribution* of the latent variables given the observations:
$$ p(z \mid x) = \dfrac{p(x, z)}{p(x)} $$

Often, the relationship between the latent and observed variables is assumed to be a generative process,
by which the observations are "generated" based on the latent variables.
In this case, the model factorizes into a *prior distribution* over latent variables $p(z)$ and a likelihood $p(x \mid z)$.
$$ p(x, z) = p(x \mid z) p(z) $$
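
As a small numerical illustration of these formulas (independent of the library used below), the following sketch enumerates a latent variable with finitely many values, forms the joint $p(x, z) = p(x \mid z) p(z)$, and normalizes it by the evidence $p(x) = \sum_z p(x, z)$. The names `discretePosterior` and `coinPosterior` and the coin numbers are hypothetical, chosen only for illustration.

```haskell
-- Bayes' rule for a latent variable z with finitely many values:
-- p(z | x) = p(x, z) / p(x), with p(x, z) = p(x | z) p(z) and p(x) = sum over z of p(x, z).
discretePosterior :: [(z, Double)]  -- prior p(z) for every value of z
                  -> (z -> Double)  -- likelihood p(x | z) of the fixed observation x
                  -> [(z, Double)]  -- posterior p(z | x)
discretePosterior pz lik = [ (z, j / evidence) | (z, j) <- joint ]
  where
    joint    = [ (z, lik z * p) | (z, p) <- pz ]  -- p(x, z) = p(x | z) p(z)
    evidence = sum (map snd joint)                -- p(x)

-- Hypothetical example: is a coin fair (False) or biased towards heads (True),
-- after observing a single toss that came up heads?
coinPosterior :: [(Bool, Double)]
coinPosterior = discretePosterior [(False, 0.5), (True, 0.5)]
                                  (\biased -> if biased then 0.9 else 0.5)
```

Here $p(x) = 0.5 \cdot 0.5 + 0.9 \cdot 0.5 = 0.7$, so `coinPosterior` is approximately `[(False, 0.357), (True, 0.643)]`.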

If we have many observations generated (independently) by the same process,
some latent variables can be *global* (i.e., they affect all datapoints in the same way),
while others are *local* to each datapoint:
$$ p(x, z) = \prod_i \left( p(x_i \mid z_i^l, z^g) p(z_i^l \mid z^g) \right) p(z^g) $$
A common case is that the global variables are parameters of the probability distributions used in the likelihood.

## Conjugate Priors

[Wikipedia](https://en.wikipedia.org/wiki/Conjugate_prior)

When the posterior of a latent variable comes from the same family of distributions as its prior,
then the posterior and prior are called *conjugate*.
Since the posterior is a result of the prior and the likelihood,
conjugacy depends on the combination of both.
For example, a beta distribution is a conjugate prior for a Bernoulli likelihood:
- prior: $\theta \sim Beta(\alpha, \beta)$
- likelihood: $\forall i: x_i \sim Bernoulli(\theta)$
- posterior: $\theta \mid x \sim Beta(\alpha', \beta')$

In this model, each $x_i$ is a local observed variable, while $\theta$ is a global latent variable
(there are no local latent variables).
Here, as in many other cases, the posterior distribution can be obtained analytically:
- $\alpha' = \alpha + \sum_i x_i = \alpha + s$
- $\beta' = \beta + \sum_i (1 - x_i) = \beta + f$

for $s$ successes and $f$ failures.
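
The update rule fits in a few lines. The following is a self-contained sketch in plain Haskell (independent of `Inference.Conjugate`; the name `betaBernoulliPosterior` is hypothetical) that counts successes and failures and adds them to the prior parameters:

```haskell
-- Conjugate update for a Beta prior with a Bernoulli likelihood:
-- alpha' = alpha + s (number of successes), beta' = beta + f (number of failures).
betaBernoulliPosterior :: (Double, Double)  -- prior parameters (alpha, beta)
                       -> [Bool]            -- observations x_i (True stands for 1)
                       -> (Double, Double)  -- posterior parameters (alpha', beta')
betaBernoulliPosterior (alpha, beta) xs = (alpha + s, beta + f)
  where
    s = fromIntegral (length (filter id xs))   -- successes
    f = fromIntegral (length (filter not xs))  -- failures
```

For example, `betaBernoulliPosterior (0.5, 0.5) [True, False, True]` evaluates to `(2.5, 1.5)`.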

## Locally Conjugate Models

If a model factorizes into a prior distribution over independent global variables
and a likelihood that uses these global variables directly as the parameters of a sequence of observed choices,
each of which forms a conjugate pair with the prior of its parameter,
then the posterior can be computed analytically, with one conjugate update per observed choice.

Example:

The global parameters consist of a coin probability $\theta$
and two categorical probability vectors $cat1$ and $cat2$.
They are drawn independently from the following priors:
- $\theta \sim Beta(0.5, 0.5)$
- $cat1 \sim Dirichlet_3(0.5)$
- $cat2 \sim Dirichlet_3(0.5)$

The likelihood uses these probabilities to generate the following observed variables:
- $coin \sim Bernoulli(\theta)$
- if $coin = 1$
  - then $category \sim Categorical_3(cat1)$
  - else $category \sim Categorical_3(cat2)$
  
For every execution of the generative process described by the likelihood,
we observe all values sampled along the way (for example $coin = 1, category = 0$),
which allows us to retrace which choices were made and which global variables they used.
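
Since every sampled value is observed, the probability of one run can be read off choice by choice. Writing $c \in \{0, 1\}$ for the observed coin and $k$ for the observed category (counting from 0), the likelihood of a single datapoint is

$$ p(coin = c, category = k \mid \theta, cat1, cat2) = \theta^{c} (1 - \theta)^{1 - c} \left( cat1_k \right)^{c} \left( cat2_k \right)^{1 - c} $$

For instance, $coin = 1, category = 0$ has probability $\theta \cdot cat1_0$. Seen as a function of each global variable, every factor is a Bernoulli or categorical likelihood, so each one pairs with the Beta or Dirichlet prior of that variable.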

# Expressing a Model in Code

## Global Variables and Priors

Since all variables in the prior are independent,
we can express it through a collection of variables with individual priors,
e.g. a record:

In [3]:
-- An example of a record that describes global variables.
-- It can be instantiated with different type constructors for 'f' (kind @* -> *@),
-- which allows it to represent both the hyperparameters and parameters of the model.
data ExampleParams f =
  ExampleParams { _epTheta :: f Beta
                , _epCat1 :: f (Dirichlet 3)
                , _epCat2 :: f (Dirichlet 3)
                }
  deriving (Generic)

Let's add a bit of extra code to make it easier to work with this record:

In [4]:
-- We need lenses to the fields of ExampleParams, which can be generated automatically.
makeLenses ''ExampleParams

-- A plain 'deriving (Show)' clause doesn't work for ExampleParams,
-- because GHC can't infer the Show (f ...) constraints on the fields by itself.
-- Instead of a standalone deriving declaration with an explicit context,
-- we write the instance by hand, which also gives nicer output.
instance ( Show (f Beta)
         , Show (f (Dirichlet 3)))
         => Show (ExampleParams f) where
  show (ExampleParams p cat1 cat2) =
    "ExampleParams"
      <> "\n  epTheta = "
      <> show p
      <> "\n  epCat1  = "
      <> show cat1
      <> "\n  epCat2  = "
      <> show cat2

This record can be instantiated in different ways, e.g. to represent the values of the global variables:

In [5]:
-- newtype ProbsRep p = ProbsRep { runProbs :: Probs (AsPrior p)}

exampleProbs :: ExampleParams ProbsRep
exampleProbs = ExampleParams { _epTheta = ProbsRep 0.7
                             , _epCat1  = ProbsRep $ V.fromList [0.3, 0.1, 0.6]
                             , _epCat2  = ProbsRep $ V.fromList [0.1, 0.8, 0.1]
                             }

exampleProbs

ExampleParams
  epTheta = ProbsRep {runProbs = 0.7}
  epCat1  = ProbsRep {runProbs = [0.3,0.1,0.6]}
  epCat2  = ProbsRep {runProbs = [0.1,0.8,0.1]}

... or the parameters of the priors:

In [6]:
-- newtype HyperRep p = HyperRep { runHyper :: Hyper (AsPrior p) }

examplePrior :: ExampleParams HyperRep
examplePrior = ExampleParams { _epTheta = HyperRep (0.5, 0.5)
                              , _epCat1  = HyperRep $ V.fromList [0.5, 0.5, 0.5]
                              , _epCat2  = HyperRep $ V.fromList [0.5, 0.5, 0.5]
                              }

examplePrior

ExampleParams
  epTheta = HyperRep {runHyper = (0.5,0.5)}
  epCat1  = HyperRep {runHyper = [0.5,0.5,0.5]}
  epCat2  = HyperRep {runHyper = [0.5,0.5,0.5]}

The two helper types `ProbsRep` and `HyperRep` instantiate the fields of the record with the appropriate types
for the support and the parameters of the prior distribution, respectively.
These are defined as type-level functions on the distribution types, e.g.:
```haskell
Hyper (AsPrior Beta) = Params  Beta = (Double, Double)
Probs (AsPrior Beta) = Support Beta = Double
```
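
To see how such type-level functions and wrapper types fit together, here is a standalone sketch of the same higher-kinded record pattern. It is not the library's code: the tag types `BetaTag` and `DirichletTag`, the type families `ParamsOf` and `SupportOf`, the wrappers `HyperSketch` and `ProbsSketch`, and the record `SketchParams` are all hypothetical names, and lists replace the notebook's vectors (so the length 3 is not enforced).

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies #-}
import Data.Kind    (Type)
import GHC.TypeLits (Nat)

-- Empty tag types standing in for the library's distribution types (hypothetical).
data BetaTag
data DirichletTag (n :: Nat)

-- Type-level function from a distribution to the type of its parameters ...
type family ParamsOf (d :: Type) :: Type
type instance ParamsOf BetaTag          = (Double, Double)  -- (alpha, beta)
type instance ParamsOf (DirichletTag n) = [Double]          -- concentrations

-- ... and to the type of its samples (its support).
type family SupportOf (d :: Type) :: Type
type instance SupportOf BetaTag          = Double    -- a probability
type instance SupportOf (DirichletTag n) = [Double]  -- a probability vector

-- Wrappers that select one of the two readings for every field of a record.
newtype HyperSketch d = HyperSketch (ParamsOf d)
newtype ProbsSketch d = ProbsSketch (SupportOf d)

-- A record in the same style as ExampleParams (with a single categorical field).
data SketchParams f = SketchParams { spTheta :: f BetaTag, spCat :: f (DirichletTag 3) }

sketchHyper :: SketchParams HyperSketch
sketchHyper = SketchParams (HyperSketch (0.5, 0.5)) (HyperSketch [0.5, 0.5, 0.5])

sketchProbs :: SketchParams ProbsSketch
sketchProbs = SketchParams (ProbsSketch 0.7) (ProbsSketch [0.3, 0.1, 0.6])
```

The same record declaration thus holds hyperparameters or sampled values, depending only on the wrapper passed for `f`.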

For standard priors such as the Jeffreys prior, we can automatically generate the corresponding values
using the derived `Generic` instance of `ExampleParams`:

In [7]:
-- jeffreysPrior :: forall a. Jeffreys a => Hyper a
prior = jeffreysPrior @ExampleParams
prior

ExampleParams
  epTheta = HyperRep {runHyper = (0.5,0.5)}
  epCat1  = HyperRep {runHyper = [0.5,0.5,0.5]}
  epCat2  = HyperRep {runHyper = [0.5,0.5,0.5]}
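
The library assembles the record-level prior generically; the per-field part can be sketched with a type class, assuming each distribution type knows its own Jeffreys hyperparameters. The Jeffreys prior for a Bernoulli parameter is $Beta(1/2, 1/2)$, and for a categorical over $n$ outcomes it is $Dirichlet_n(1/2)$, which matches the output above. The class `JeffreysSketch` and the tag types are hypothetical, and the parameters are returned as a plain list, whereas the notebook's `Hyper` values use a tuple for the Beta and a vector for the Dirichlet.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables #-}
import Data.Proxy   (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical tag types standing in for the library's Beta and Dirichlet n.
data BetaTag
data DirichletTag (n :: Nat)

-- A sketch of a per-distribution Jeffreys prior, returned as concentration parameters.
class JeffreysSketch d where
  jeffreysParams :: Proxy d -> [Double]

-- Jeffreys prior for the parameter of a Bernoulli likelihood: Beta(1/2, 1/2).
instance JeffreysSketch BetaTag where
  jeffreysParams _ = [0.5, 0.5]

-- Jeffreys prior for a categorical likelihood over n outcomes: Dirichlet(1/2, ..., 1/2).
-- ScopedTypeVariables brings the instance's n into scope for natVal.
instance KnownNat n => JeffreysSketch (DirichletTag n) where
  jeffreysParams _ = replicate (fromIntegral (natVal (Proxy :: Proxy n))) 0.5
```

For example, `jeffreysParams (Proxy :: Proxy (DirichletTag 3))` evaluates to `[0.5,0.5,0.5]`.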

Similarly, we can generically sample values for the global variables from a given prior:

In [8]:
-- create a random generator
gen <- createSystemRandom

-- sample global variables from the prior
probs <- sample (sampleProbs @ExampleParams prior) gen
probs

ExampleParams
  epTheta = ProbsRep {runProbs = 6.374331018604273e-4}
  epCat1  = ProbsRep {runProbs = [0.10712072138983836,0.5807054554863914,0.31217382312377034]}
  epCat2  = ProbsRep {runProbs = [8.153175746581849e-3,0.9606048595216943,3.124196473172396e-2]}
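
Since the prior is fully factorized, sampling amounts to drawing each field independently from its own prior. A hand-written analogue for the example model might look as follows. It assumes the `beta` and `dirichlet` samplers of mwc-probability (the package providing `System.Random.MWC.Probability`, imported above) and uses plain tuples and lists instead of the `ExampleParams` record; `sampleExampleGlobals` and `drawFromJeffreys` are hypothetical names, not part of `Inference.Conjugate`.

```haskell
import System.Random.MWC.Probability (Prob, beta, createSystemRandom, dirichlet, sample)

-- Draw the example's global variables independently from their priors:
-- theta ~ Beta(a, b), cat1 ~ Dirichlet(alphas1), cat2 ~ Dirichlet(alphas2).
sampleExampleGlobals :: (Double, Double) -> [Double] -> [Double]
                     -> Prob IO (Double, [Double], [Double])
sampleExampleGlobals (a, b) alphas1 alphas2 = do
  theta <- beta a b
  cat1  <- dirichlet alphas1  -- non-negative values summing to 1
  cat2  <- dirichlet alphas2
  pure (theta, cat1, cat2)

-- Usage: one draw from the Jeffreys prior, using a fresh system-seeded generator.
drawFromJeffreys :: IO (Double, [Double], [Double])
drawFromJeffreys = do
  gen <- createSystemRandom
  sample (sampleExampleGlobals (0.5, 0.5) (replicate 3 0.5) (replicate 3 0.5)) gen
```

With the Jeffreys prior's small concentrations, draws close to 0 or 1 (like the $\theta$ above) are common.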

## Likelihood

The likelihood can be a more complex process than the prior (in which all variables are independent),
so we express it as a probabilistic program.
This program is polymorphic, so it can be run by different "interpreters" for different purposes.

The key function here is `sampleValue`, which takes a distribution and a lens into the global record.
When generating samples, this lens is used to obtain the value of the global variable.
During inference, it is used to update the prior into the posterior.

In [9]:
-- The likelihood of the example model, described as a probabilistic program.
-- 'm' is the type variable that stands for the interpreter the program is run with.
exampleLk :: _ => RandomInterpreter m ExampleParams => m Int
exampleLk = do
  coin <- sampleValue Bernoulli epTheta
  sampleValue (Categorical @3) $ if coin then epCat1 else epCat2

Using the global variables sampled above, we can sample a datapoint from this likelihood:

In [10]:
(result, trace) <- sampleTrace probs exampleLk gen
(result, trace)

(1,Trace {runTrace = fromList [<<Bool>>,<<Int>>]})

The trace stores the sampled values as `Dynamic`s, so printing it only reveals their types (`<<Bool>>`, `<<Int>>`). To display it with more context, we can run the likelihood program again:

In [11]:
putStrLn "trace:"
printTrace trace exampleLk

trace:

Sampled value False from a Bernoulli.
Sampled value 1 from a Categorical 3.
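
Each entry of a trace is a `Data.Dynamic.Dynamic`, which is why only the types appear in the plain output. The following standalone sketch builds a trace the same way as the observations further down (`S.fromList` of `toDyn` values) and reads the values back with `fromDynamic`, which succeeds only when the requested type matches the stored one. The names `exampleTrace` and `decodeExample` are hypothetical.

```haskell
import Data.Dynamic  (Dynamic, fromDynamic, toDyn)
import Data.Foldable (toList)
import qualified Data.Sequence as S

-- A trace like the one above: the sampled values in order, each wrapped as a Dynamic
-- because the choices of one program can have different types.
exampleTrace :: S.Seq Dynamic
exampleTrace = S.fromList [toDyn False, toDyn (1 :: Int)]

-- Hypothetical decoder for traces of the example model: the first entry must be
-- the coin (a Bool), the second the category (an Int); anything else gives Nothing.
decodeExample :: S.Seq Dynamic -> Maybe (Bool, Int)
decodeExample tr = case toList tr of
  [coin, category] -> (,) <$> fromDynamic coin <*> fromDynamic category
  _                -> Nothing
```

`show exampleTrace` prints `fromList [<<Bool>>,<<Int>>]`, while `decodeExample exampleTrace` recovers `Just (False,1)`.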

Similarly, we can evaluate the log probability of the trace:

In [12]:
logp = snd <$> evalTraceLogP probs trace exampleLk
logp

Just (-4.0829767272934546e-2)
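
The value can be checked by hand: the trace records $coin = 0$ (False) and $category = 1$, so only $\theta$ and entry 1 of $cat2$ (counting from 0) enter the result. With the values sampled above,

$$ \log p = \log(1 - \theta) + \log cat2_1 = \log(1 - 0.000637\ldots) + \log(0.960604\ldots) \approx -0.000638 - 0.040192 \approx -0.040830 $$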

## Inference

Running the likelihood program once more, we can "observe" a trace, which updates the prior parameters to the posterior:

In [13]:
getPosterior prior trace exampleLk

Just ExampleParams
  epTheta = HyperRep {runHyper = (0.5,1.5)}
  epCat1  = HyperRep {runHyper = [0.5,0.5,0.5]}
  epCat2  = HyperRep {runHyper = [0.5,1.5,0.5]}
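
This is the conjugate update applied once per choice in the trace: the coin came up $0$, which adds one failure to the Beta parameters of $\theta$, and the category $1$ was drawn from $cat2$, which adds one count to entry 1 of its Dirichlet parameters $\alpha_{cat2}$. Since the trace never used $cat1$, its parameters stay at the prior values:

$$ (\alpha', \beta') = (0.5,\ 0.5 + 1) = (0.5, 1.5), \qquad \alpha_{cat2}' = (0.5,\ 0.5 + 1,\ 0.5) = (0.5, 1.5, 0.5) $$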

If we want to do the same with a set of observations from a dataset,
we have to express each observation as a trace:

In [14]:
observations =
        [ Trace $ S.fromList [toDyn True, toDyn (0 :: Int)]
        , Trace $ S.fromList [toDyn False, toDyn (1 :: Int)]
        , Trace $ S.fromList [toDyn True, toDyn (0 :: Int)]
        , Trace $ S.fromList [toDyn False, toDyn (2 :: Int)]
        , Trace $ S.fromList [toDyn True, toDyn (0 :: Int)]
        , Trace $ S.fromList [toDyn False, toDyn (1 :: Int)]
        , Trace $ S.fromList [toDyn False, toDyn (2 :: Int)]
        , Trace $ S.fromList [toDyn False, toDyn (1 :: Int)]
        ]
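-- foldM threads the hyperparameters through all traces in the Maybe monad:
-- each step updates them with one observation, and a trace that doesn't match the program gives Nothing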
Just posterior = foldM (\hyper obs -> getPosterior hyper obs exampleLk) prior observations
posterior

ExampleParams
  epTheta = HyperRep {runHyper = (3.5,5.5)}
  epCat1  = HyperRep {runHyper = [3.5,0.5,0.5]}
  epCat2  = HyperRep {runHyper = [0.5,3.5,2.5]}
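
Because every update is a closed-form counting step, the whole computation can be mirrored in a few lines of plain Haskell. The sketch below is not how `Inference.Conjugate` works internally: the type `ExampleHyper` and the functions `bump`, `observe` and `posteriorOf` are hypothetical, observations are plain `(coin, category)` pairs instead of traces, and, unlike `getPosterior`, it does not check that an observation matches the model.

```haskell
import Data.List (foldl')

-- Hyperparameters of the example model (a hypothetical stand-in for ExampleParams HyperRep).
data ExampleHyper = ExampleHyper
  { hTheta :: (Double, Double)  -- Beta(alpha, beta) prior of theta
  , hCat1  :: [Double]          -- Dirichlet concentrations of cat1
  , hCat2  :: [Double]          -- Dirichlet concentrations of cat2
  } deriving (Eq, Show)

-- Add one count to entry k of a concentration vector.
bump :: Int -> [Double] -> [Double]
bump k alphas = [ if i == k then a + 1 else a | (i, a) <- zip [0 ..] alphas ]

-- One observation (coin, category) updates exactly the parameters of the choices it made:
-- the coin updates theta's Beta prior, the category updates cat1 or cat2 depending on the coin.
observe :: ExampleHyper -> (Bool, Int) -> ExampleHyper
observe (ExampleHyper (a, b) c1 c2) (coin, k)
  | coin      = ExampleHyper (a + 1, b) (bump k c1) c2
  | otherwise = ExampleHyper (a, b + 1) c1 (bump k c2)

-- A dataset is processed by folding the update over all observations.
posteriorOf :: ExampleHyper -> [(Bool, Int)] -> ExampleHyper
posteriorOf = foldl' observe

jeffreysHyper :: ExampleHyper
jeffreysHyper = ExampleHyper (0.5, 0.5) (replicate 3 0.5) (replicate 3 0.5)

-- The same eight observations as in the notebook cell above.
exampleData :: [(Bool, Int)]
exampleData = [ (True, 0), (False, 1), (True, 0), (False, 2)
              , (True, 0), (False, 1), (False, 2), (False, 1) ]
```

`posteriorOf jeffreysHyper exampleData` evaluates to `ExampleHyper {hTheta = (3.5,5.5), hCat1 = [3.5,0.5,0.5], hCat2 = [0.5,3.5,2.5]}`, the same numbers as the notebook output.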

Voila!