Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix multivariate normal bug (#21105)
Browse files Browse the repository at this point in the history
  • Loading branch information
hankaj committed Jul 29, 2022
1 parent 5e5e0e3 commit db39bb1
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion python/mxnet/numpy_op_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def forward(self, is_train, req, in_data, out_data, aux):
scale = _mx_np.linalg.cholesky(cov)
#set context
noise = _mx_np.random.normal(size=out_data[0].shape, dtype=loc.dtype, device=loc.device)
out = loc + _mx_np.einsum('...jk,...j->...k', scale, noise)
out = loc + _mx_np.einsum('...jk,...k->...j', scale, noise)
self.assign(out_data[0], req[0], out)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
Expand Down

0 comments on commit db39bb1

Please sign in to comment.