Skip to content

Commit

Permalink
Jax: Add log_multivariate_normal
Browse files Browse the repository at this point in the history
  • Loading branch information
andreArtelt committed Nov 27, 2019
1 parent 8987e02 commit 4cc159b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion ceml/backend/jax/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@ def normal_distribution(x, mean, variance):
return npx.exp(-.5 * npx.square(x - mean) / variance) / npx.sqrt(2. * npx.pi * variance)

def log_normal_distribution(x, mean, variance):
return -.5 * npx.square(x - mean) / variance - 0.5 * (2. + npx.pi + variance)
return -.5 * npx.square(x - mean) / variance - .5 * (2. + npx.pi + variance)

def log_multivariate_normal(x, mean, sigma_inv, k):
return .5 * npx.log(npx.linalg.det(sigma_inv)) - .5 * k * npx.log(2. * npx.pi) - .5 * npx.dot(x - mean, npx.dot(sigma_inv, (x - mean)))

0 comments on commit 4cc159b

Please sign in to comment.