In [1]:
from numbers import Number
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.distributions import normal
import math

In [312]:
class JohnsonSU(Distribution):
    r"""
    Creates a Johnson's :math:`S_U` distribution parameterized by the shape parameters
    :attr:`gamma` and :attr:`delta`, the location :attr:`mu` and the scale :attr:`sigma`.

    Samples are generated by the (reparameterizable) transform

    .. math::
        X = \mu + \sigma \sinh\left(\frac{Z - \gamma}{\delta}\right), \qquad Z \sim \mathcal{N}(0, 1).

    Example::

        >>> m = JohnsonSU(torch.tensor([2.0]), torch.tensor([8.0]), torch.tensor([4.0]), torch.tensor([1.0]), validate_args=True)
        >>> m.sample()  # random draw; the value shown is illustrative only
        tensor([3.8110])

    Args:
        gamma (float or Tensor): first shape parameter (real)
        delta (float or Tensor): second shape parameter (must be positive)
        mu (float or Tensor): location of the distribution
        sigma (float or Tensor): scale of the distribution (must be positive)
    """

    arg_constraints = {'gamma': constraints.real, 'delta': constraints.positive,
                       'mu': constraints.real, 'sigma': constraints.positive}
    support = constraints.real
    # rsample below is a differentiable transform of standard-normal noise, so
    # gradients flow to the parameters; advertise that to callers.
    has_rsample = True

    @property
    def mean(self):
        # E[X] = mu - sigma * exp(1 / (2 delta^2)) * sinh(gamma / delta)
        return self.mu - self.sigma * torch.exp(0.5 / self.delta.pow(2)) * torch.sinh(self.gamma / self.delta)

    @property
    def variance(self):
        # Var[X] = sigma^2 / 2 * (exp(1/delta^2) - 1) * (exp(1/delta^2) * cosh(2 gamma / delta) + 1)
        # The cosh form avoids multiplying exp(-2 gamma/delta) by exp((1 + 4 gamma delta)/delta^2),
        # which overflows to inf for moderate |gamma/delta| even when the variance is representable;
        # expm1 keeps (exp(1/delta^2) - 1) accurate when delta is large.
        inv_delta_sq = 1 / self.delta.pow(2)
        return (0.5 * self.sigma.pow(2)
                * torch.expm1(inv_delta_sq)
                * (torch.exp(inv_delta_sq) * torch.cosh(2 * self.gamma / self.delta) + 1))

    @property
    def stddev(self):
        return self.variance.sqrt()

    @property
    def median(self):
        # Median = mu + sigma * sinh(-gamma / delta), i.e. the transform applied to Z = 0.
        return self.mu + self.sigma * torch.sinh(-self.gamma / self.delta)

    def __init__(self, gamma, delta, mu, sigma, validate_args=None):
        """Broadcast the four parameters together and initialise the batch shape.

        All-Python-scalar parameters give an empty batch shape; otherwise the
        batch shape is the common broadcast shape of the parameters.
        """
        self.gamma, self.delta, self.mu, self.sigma = broadcast_all(gamma, delta, mu, sigma)
        if all(isinstance(p, Number) for p in (gamma, delta, mu, sigma)):
            batch_shape = torch.Size()
        else:
            batch_shape = self.mu.size()
        super(JohnsonSU, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(JohnsonSU, _instance)
        batch_shape = torch.Size(batch_shape)
        new.gamma = self.gamma.expand(batch_shape)
        new.delta = self.delta.expand(batch_shape)
        new.mu = self.mu.expand(batch_shape)
        new.sigma = self.sigma.expand(batch_shape)
        super(JohnsonSU, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples of shape ``sample_shape + batch_shape``."""
        shape = self._extended_shape(sample_shape)
        # Draw Z ~ N(0, 1) directly on the parameters' dtype/device. Inverting a
        # torch.rand draw through Normal.icdf yields -inf whenever rand returns 0
        # and always produced default-dtype CPU noise.
        eps = torch.randn(shape, dtype=self.mu.dtype, device=self.mu.device)
        return self.mu + self.sigma * torch.sinh((eps - self.gamma) / self.delta)

    def log_prob(self, value):
        """Log-density at ``value``, evaluated entirely in log space.

        log p(x) = log(delta) - log(2 pi)/2 - log(sqrt((x - mu)^2 + sigma^2))
                   - (gamma + delta * asinh((x - mu) / sigma))^2 / 2
        """
        value = self._as_tensor(value)
        if self._validate_args:
            self._validate_sample(value)
        z = self._to_standard_normal(value)
        # hypot(x - mu, sigma) == sqrt((x - mu)^2 + sigma^2) without overflowing the square.
        return (torch.log(self.delta)
                - 0.5 * z.pow(2)
                - 0.5 * math.log(2 * math.pi)
                - torch.log(torch.hypot(value - self.mu, self.sigma)))

    def cdf(self, value):
        """CDF at ``value``: Phi(gamma + delta * asinh((value - mu) / sigma))."""
        value = self._as_tensor(value)
        if self._validate_args:
            self._validate_sample(value)
        return 0.5 * (1 + torch.erf(self._to_standard_normal(value) / math.sqrt(2)))

    def icdf(self, value):
        """Inverse CDF: maps a probability in [0, 1] back to the real line."""
        value = self._as_tensor(value)
        # Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1), then apply the sampling transform.
        z = math.sqrt(2) * torch.erfinv(2 * value - 1)
        return self.mu + self.sigma * torch.sinh((z - self.gamma) / self.delta)

    def _as_tensor(self, value):
        # Keep accepting Python scalars, but place them on the parameters' dtype/device.
        if isinstance(value, torch.Tensor):
            return value
        return torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)

    def _to_standard_normal(self, value):
        # Inverse of the sampling transform: the N(0, 1) variate corresponding to ``value``.
        return self.gamma + self.delta * torch.arcsinh((value - self.mu) / self.sigma)


In [303]:
# Demo instance: gamma=1, delta=8, mu=4, sigma=1 (the class docstring example uses gamma=2).
# NOTE(review): this cell ran as In[303], before the class cell (In[312]), so `m` was presumably
# built from an earlier JohnsonSU definition and the outputs below may be stale -- Restart & Run All.
m = JohnsonSU(torch.tensor([1.0]), torch.tensor([8.0]), torch.tensor([4.0]), torch.tensor([1.0]), validate_args=True)

In [314]:
# Random draw; no seed is set, so this output is not reproducible (consider torch.manual_seed).
m.sample()

tensor([3.8110])

In [304]:
# Log-density at x = 2.
m.log_prob(torch.tensor([2.0]))

tensor([-55.2858])

In [305]:
# At x == mu the asinh term vanishes, so the CDF reduces to Phi(gamma) = Phi(1).
m.cdf(torch.tensor([4.0]))

tensor([0.8413])

In [311]:
# Median = mu + sigma * sinh(-gamma / delta) = 4 + sinh(-1/8).
m.median

tensor([3.8747])