# In [1]:
from numbers import Number
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.distributions import normal
import math

# In [34]:
class JohnsonSB(Distribution):
    r"""
    Creates a JohnsonSB distribution parameterized by :attr:`gamma`, :attr:`delta`, :attr:`mu` and :attr:`sigma`.

    Samples are produced by the transform

    .. math::
        X = \mu + \sigma \cdot \operatorname{sigmoid}\left(\frac{Z - \gamma}{\delta}\right),
        \qquad Z \sim \mathcal{N}(0, 1),

    so every sample lies in the open interval :math:`(\mu, \mu + \sigma)`; ``log_prob``
    is ``-inf`` and ``cdf`` is 0 or 1 outside that interval.

    Example::

        >>> m = JohnsonSB(torch.tensor([1.0]), torch.tensor([8.0]), torch.tensor([4.0]), torch.tensor([1.0]), validate_args=True)
        >>> m.sample()  # random; always inside (4.0, 5.0)
        tensor([4.4506])

    Args:
        gamma (float or Tensor): shape of the distribution
        delta (float or Tensor): shape of the distribution (must be positive)
        mu (float or Tensor): location of the distribution (lower end of the support)
        sigma (float or Tensor): scale of the distribution (width of the support, must be positive)
        validate_args (bool, optional): passed to :class:`Distribution` to enable argument validation
    """

    arg_constraints = {'gamma': constraints.real, 'delta': constraints.positive,
                       'mu': constraints.real, 'sigma': constraints.positive}
    # rsample below is a differentiable transform of standard-normal noise.
    has_rsample = True

    @constraints.dependent_property
    def support(self):
        """Interval ``[mu, mu + sigma]``; the density is non-zero only strictly inside it."""
        return constraints.interval(self.mu, self.mu + self.sigma)

    def __init__(self, gamma, delta, mu, sigma, validate_args=None):
        self.gamma, self.delta, self.mu, self.sigma = broadcast_all(gamma, delta, mu, sigma)
        # All-scalar parameters give a scalar (unbatched) distribution.
        if all(isinstance(p, Number) for p in (gamma, delta, mu, sigma)):
            batch_shape = torch.Size()
        else:
            batch_shape = self.mu.size()
        super(JohnsonSB, self).__init__(batch_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """Return a view of this distribution with parameters expanded to ``batch_shape``."""
        new = self._get_checked_instance(JohnsonSB, _instance)
        batch_shape = torch.Size(batch_shape)
        new.gamma = self.gamma.expand(batch_shape)
        new.delta = self.delta.expand(batch_shape)
        new.mu = self.mu.expand(batch_shape)
        new.sigma = self.sigma.expand(batch_shape)
        super(JohnsonSB, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized samples of shape ``sample_shape + batch_shape``."""
        shape = self._extended_shape(sample_shape)
        # Match the parameters' dtype/device so float64 or CUDA instances work.
        eps = torch.randn(shape, dtype=self.mu.dtype, device=self.mu.device)
        # mu + sigma * e^y / (1 + e^y), written with sigmoid so large y does not
        # overflow to inf / inf = nan.
        return self.mu + self.sigma * torch.sigmoid((eps - self.gamma) / self.delta)

    def log_prob(self, value):
        """Log density at ``value``; ``-inf`` outside the open interval ``(mu, mu + sigma)``.

        Works elementwise on any ``value`` broadcastable against the batch shape.
        """
        value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)
        upper = self.mu + self.sigma
        inside = (value > self.mu) & (value < upper)
        # Substitute the midpoint for out-of-support entries so the logs stay
        # finite (and gradients NaN-free); those entries become -inf below.
        x = torch.where(inside, value, self.mu + 0.5 * self.sigma)
        log_left = torch.log(x - self.mu)
        log_right = torch.log(upper - x)
        t = self.gamma + self.delta * (log_left - log_right)
        # log[ delta * sigma / (sqrt(2 pi) (x - mu)(mu + sigma - x)) * exp(-t^2 / 2) ]
        lp = (torch.log(self.delta) + torch.log(self.sigma) - 0.5 * math.log(2 * math.pi)
              - log_left - log_right - 0.5 * t.pow(2))
        return torch.where(inside, lp, torch.full_like(lp, -math.inf))

    def cdf(self, value):
        """Cumulative distribution at ``value``: 0 at or below ``mu``, 1 at or above ``mu + sigma``.

        Works elementwise on any ``value`` broadcastable against the batch shape.
        """
        value = torch.as_tensor(value, dtype=self.mu.dtype, device=self.mu.device)
        upper = self.mu + self.sigma
        inside = (value > self.mu) & (value < upper)
        # Same midpoint substitution as log_prob to keep the logs finite.
        x = torch.where(inside, value, self.mu + 0.5 * self.sigma)
        t = self.gamma + self.delta * (torch.log(x - self.mu) - torch.log(upper - x))
        # Phi(t) = erfc(-t / sqrt(2)) / 2 is accurate in both tails, unlike
        # 0.5 * (1 + erf(t / sqrt(2))) which cancels for very negative t.
        inner = 0.5 * torch.erfc(-t / math.sqrt(2))
        outer = (value >= upper).to(inner.dtype)
        return torch.where(inside, inner, outer)
