In [14]:
from HMC.gaussian_hmc import GaussianTargetHMC
import numpy as np
from scipy.stats import multivariate_normal
from utils import quick_MVN_marginals, quick_MVN_scatter

In [2]:
# Zero-mean bivariate Gaussian target; the off-diagonal 0.99 couples the two coordinates.
mu = np.zeros(2)
Sigma = np.array([
    [1.0, 0.99],
    [0.99, 2.0],
])
target = multivariate_normal(mu, Sigma)

In [12]:
# HMC settings. Meanings are inferred from the names and the notes above —
# NOTE(review): confirm against GaussianTargetHMC's signature.
n, M, T, epsilon = (
    100,        # n: presumably the number of samples to draw
    np.eye(2),  # M: presumably the mass matrix (momentum covariance), sized for the 2-D target
    1,          # T: presumably the integration time of each trajectory
    0.05,       # epsilon: presumably the leapfrog step size
)

In [13]:
# Seed NumPy's global RNG so this cell is reproducible under Restart & Run All:
# scipy's `rvs` draws from the global RandomState when no `random_state` is passed.
# NOTE(review): the seed also makes GaussianTargetHMC reproducible only if it draws
# from np.random's global state — confirm in HMC/gaussian_hmc.py.
np.random.seed(0)
x = target.rvs()  # initial state drawn from the target itself
samples = GaussianTargetHMC(x, n, M, T, epsilon, Sigma, mu).sample()

Let $x\in\mathbb{R}^d$ and $u\in\mathbb{R}$, and let $\gamma \ge 0$ be the unnormalised density of $x$. Define $q = (x, u)\in\mathbb{R}^{d+1}$. Then the target density is
$$
\pi(q) = \pi(x, u) = \frac{1}{Z_\pi}\mathbb{I}\left\{(x, u) \, :\, 0 < u < \gamma(x)\right\}
$$
Introduce a new variable $p\in\mathbb{R}^{d+1}$ and define the joint
$$
\pi(q, p) = \pi(p \mid q) \pi(q)
$$
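Writing the potential energy as $V(q) = -\log\pi(q)$ and the kinetic energy as $K(q, p) = -\log\pi(p \mid q)$, so that the Hamiltonian is $H = K + V$, Hamilton's equations are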
\begin{align}
\partial_t q &= \partial_p K \\
\partial_t p &= - \partial_q K - \partial_q V
\end{align}
If we take $\pi(p \mid q) = N(0, M)$ then we have
\begin{align}
\partial_p K &= M^{-1} p \\
\partial_q K &= 0
\end{align}
Now we only need to compute $-\partial_q V$.
\begin{align}
    \partial_q V(q) 
    &= -\partial_{q} \log \pi(q) \\
    &= -\nabla_{(x, u)} \left[- \log Z_\pi + \log \mathbb{I}\left\{(x, u)\, :\, 0 < u < \gamma(x)\right\}\right]
\end{align}

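The first term vanishes because the normalising constant $Z_\pi = \int_{\mathcal{X}} \gamma(x)\,dx$ does not depend on $(x, u)$: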
$$
\nabla_{(x, u)} \log \int_{\mathcal{X}} \gamma(x) dx = 0
$$

For the second term,
\begin{align}
- \nabla_{(x, u)} \log \mathbb{I}\left\{(x, u)\, :\, 0 < u < \gamma(x)\right\}
&= - \nabla_{(x, u)} \log \begin{cases}
1 & 0 < u < \gamma(x) \\
0 & \text{otherwise}
\end{cases}  \\
&= - \nabla_{(x, u)} \begin{cases}
    0 & 0 < u < \gamma(x) \\ 
    -\infty & \text{otherwise}
\end{cases}
\end{align}

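The function being differentiated is piecewise constant: it equals $0$ inside the region $\{0 < u < \gamma(x)\}$ and $-\infty$ outside it. Its gradient therefore vanishes everywhere except on the boundary of that region. Considering only the upper boundary $u = \gamma(x)$ (the lower boundary $u = 0$ is not treated here), the gradient is formally a point mass concentrated on that boundary: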
$$
\delta_{(x, \gamma^{-1}(u))} = \delta_{(x, \gamma(x))}(x, u)
$$

Therefore the dynamics become
\begin{align}
    \partial_t q &= M^{-1} p \\
    \partial_t p &= - \delta_{(x, \gamma(x))}(x, u)
\end{align}

# Zig Zag

In [17]:
# Sanity check: the maximum of the entries (0, 1, 2) is 2.
np.array((0, 1, 2)).max()

2

In [None]:
def pot(xi, sigma=1.0):
    """Potential U(xi) = xi^2 / (2 * sigma^2) of a zero-mean 1D Gaussian.

    Works on scalars or NumPy arrays (elementwise).
    """
    variance = sigma ** 2
    return xi ** 2 / (2 * variance)

def gradpot(xi, sigma=1.0):
    """Derivative dU/dxi = xi / sigma^2 of the 1D Gaussian potential ``pot``."""
    variance = sigma ** 2
    return xi / variance

def gamma(xi):
    """Additive term of the switching rate; constant 1 for every ``xi``.

    The argument is accepted (and ignored) so this function can be passed
    wherever a ``gammafunc(xi)`` callback is expected.
    """
    rate = 1
    return rate

def M_generator(t, xi, theta, gammafunc=gamma, sigma=1.0):
    """Upper bound on the switching rate at time ``t`` along a segment.

    Computes |theta * xi| / sigma^2 + gammafunc(xi) + t / sigma^2.

    Assuming theta is +1 or -1, t >= 0 and ``gammafunc`` constant along the
    segment (TODO confirm against callers), this is at least the switching
    rate max(0, theta * x / sigma^2) + gammafunc at x = xi + theta * t.
    That presumably makes it the dominating rate for Poisson thinning.
    """
    variance = sigma ** 2
    drift_part = np.abs(theta * xi) / variance
    return drift_part + gammafunc(xi) + t / variance

def switchrate(xi, theta, gammafunc=gamma, sigma=1.0):
    """Switching rate lambda(xi, theta) for the 1D Gaussian potential.

    lambda(xi, theta) = max(0, theta * xi / sigma^2) + gammafunc(xi),
    i.e. the positive part of theta times the gradient of ``pot``, plus the
    additive term supplied by ``gammafunc``.

    Parameters
    ----------
    xi : float or np.ndarray
        Position(s); arrays are handled elementwise.
    theta : float or np.ndarray
        Direction of motion; presumably +1 or -1 (TODO confirm against callers).
    gammafunc : callable, optional
        Additive term of the rate, expected to be non-negative; defaults to ``gamma``.
    sigma : float, optional
        Standard deviation of the Gaussian target.

    Returns
    -------
    float or np.ndarray
        The switching rate.
    """
    # Bug fix: the value used to be computed and discarded, so every call returned None.
    # np.maximum is elementwise: identical to np.max((0, v)) for scalars, and also
    # valid when xi/theta are arrays (np.max over a tuple containing an array is not).
    return np.maximum(0, theta * xi / (sigma ** 2)) + gammafunc(xi)

In [None]:
def zigzag(xi0, theta0):
    """Placeholder for the 1D zig-zag sampler; not implemented yet.

    Parameters
    ----------
    xi0 :
        Presumably the initial position (TODO confirm once implemented).
    theta0 :
        Presumably the initial direction, +1 or -1 (TODO confirm once implemented).

    Returns
    -------
    None
    """
    # TODO: implement the sampler.
    return None
