In [1]:
import torch

import numpy as np

In [2]:
class MultivariateHypergeometric:
    """NumPy reference for sequential sampling of item orderings without replacement.

    Each draw picks one of the not-yet-chosen items with probability proportional
    to its entry in ``probs``. A sample is an ordering ``v`` of item indices, and
    (assuming ``probs`` sums to 1) its probability is

        P(v) = prod_k probs[v_k] / (1 - sum_{j<k} probs[v_j]).

    NOTE(review): despite the class name, this is a distribution over orderings
    (Plackett-Luce style), not the multivariate hypergeometric over counts.
    """

    # Lower bound on every denominator so a fully consumed mass never divides by zero.
    _EPS = 1e-8

    def __init__(self, probs):
        self.probs = np.array(probs)
        self.n = len(probs)

    def sample(self, size=1):
        """Draw ``size`` orderings; returns a float array of shape (size, n)."""
        a = np.zeros((size, self.n))
        items = np.arange(self.n)  # loop-invariant, hoisted out of the draw loop
        for i in range(size):
            a[i] = np.random.choice(items, self.n, p=self.probs, replace=False)
        return a

    def log_prob(self, value):
        """Log-probability of ordering(s) ``value``; sums per-step logs to avoid underflow."""
        return np.log(self._step_probs(value)).sum(axis=-1)

    def prob(self, value):
        """Probability of ordering(s) ``value`` of shape (k,) or (..., k), k <= n."""
        return np.prod(self._step_probs(value), axis=-1)

    def _step_probs(self, value):
        """Conditional probability of each successive draw in ``value`` (same shape as value)."""
        value = np.asarray(value).astype(int)
        numerator = self.probs[value]
        # Mass already taken by the items drawn before each position (exclusive cumsum).
        before = np.concatenate(
            [np.zeros_like(numerator[..., :1]), numerator[..., :-1]], axis=-1
        )
        remaining = 1 - np.cumsum(before, axis=-1)
        # The chosen item is part of the remaining mass, so remaining >= numerator
        # mathematically; enforce that against float cancellation instead of adding
        # EPS to every term (which biased all probabilities), and floor at EPS.
        remaining = np.maximum(np.maximum(remaining, numerator), self._EPS)
        return numerator / remaining

In [3]:
# Per-item selection weights (they sum to 1).
probs = np.array((0.6, 0.25, 0.15))

In [4]:
dist = MultivariateHypergeometric(probs=probs)  # NumPy reference distribution

In [5]:
# Seed the global NumPy RNG that dist.sample draws from (np.random.choice) so the
# empirical frequencies below are reproducible on Restart & Run All.
# Note: dist.sample makes one Python-level np.random.choice call per draw.
np.random.seed(0)
sample = dist.sample(size=1_000_000)

In [6]:
# Collapse the draws into distinct orderings (rows) and how often each occurred.
act, count = np.unique(sample, return_counts=True, axis=0)

In [7]:
act  # distinct orderings observed in the sample, one per row

array([[0., 1., 2.],
       [0., 2., 1.],
       [1., 0., 2.],
       [1., 2., 0.],
       [2., 0., 1.],
       [2., 1., 0.]])

In [8]:
count/count.sum()  # empirical frequency of each row of `act`

array([0.375602, 0.224338, 0.199959, 0.050254, 0.106042, 0.043805])

In [9]:
dist.prob(act)  # analytic probabilities; compare with the empirical frequencies above

array([0.37499996, 0.22499998, 0.19999998, 0.05      , 0.10588235,
       0.04411765])

In [10]:
dist.log_prob(act)  # log-probabilities of the same orderings

array([-0.98082935, -1.49165495, -1.609438  , -2.99573231, -2.24542674,
       -3.12089545])

In [11]:
EPS = 1e-8


class MultivariateHypergeometric(torch.distributions.Categorical):
    """Sample ranking index according to Softmax distribution.

    Items are drawn one at a time without replacement; each draw picks a
    remaining item with probability proportional to its ``probs`` entry. A
    sample is an ordering ``v`` of item indices with probability

        P(v) = prod_k p[v_k] / (1 - sum_{j<k} p[v_j])

    (``Categorical`` normalizes ``probs`` to sum to 1).

    NOTE(review): despite the name, this is a distribution over orderings, not
    the multivariate hypergeometric over counts. Only ``sample``, ``prob`` and
    ``log_prob`` are ordering-aware. ``entropy`` and the other members are
    inherited unchanged from ``Categorical`` and describe a single categorical
    draw, not an ordering.
    """

    def __init__(self, probs=None, logits=None, validate_args=None):
        super().__init__(probs, logits, validate_args)
        # One sample is a full ordering of the items on the last dimension.
        self._event_shape = (self._param.size()[-1],)

    def sample(self, sample_shape=torch.Size(), replacement=False):
        """Draw orderings of the item indices.

        Args:
            sample_shape: leading shape of independent draws (previously ignored).
            replacement: forwarded to ``torch.multinomial``; ``False`` gives permutations.

        Returns:
            LongTensor of shape ``sample_shape + (B, n)``. Here ``B`` is the
            flattened batch size (1 for unbatched ``probs``) and ``n`` the number
            of items. With the default empty ``sample_shape`` this is the same
            2-D ``(B, n)`` tensor as before.
        """
        sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
        # multinomial yields one ordering per row, so repeat rows once per requested draw.
        rows = probs_2d.repeat(sample_shape.numel(), 1)
        samples_2d = torch.multinomial(rows, self._num_events, replacement)
        return samples_2d.reshape(sample_shape + probs_2d.shape)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        # Summing per-step logs avoids underflow of log(prod(...)) for long rankings.
        return self._step_probs(value).log().sum(-1)

    def prob(self, value):
        """Probability of ordering(s) ``value`` of shape ``(..., k)``, ``k <= n``."""
        return torch.prod(self._step_probs(value), dim=-1)

    def _step_probs(self, value):
        """Conditional probability of each successive draw in ``value``."""
        index = value.long()
        # gather needs input and index of equal rank; broadcast both over the leading
        # (sample/batch) dims so a single ordering or a batch of orderings both work.
        batch = torch.broadcast_shapes(index.shape[:-1], self.probs.shape[:-1])
        probs = self.probs.expand(batch + self.probs.shape[-1:])
        index = index.expand(batch + index.shape[-1:])
        numerator = probs.gather(-1, index)
        # Mass already taken by the items drawn before each position (exclusive cumsum).
        before = torch.cat(
            [torch.zeros_like(numerator[..., :1]), numerator[..., :-1]], dim=-1
        )
        remaining = 1 - torch.cumsum(before, dim=-1)
        # The chosen item is part of the remaining mass, so remaining >= numerator
        # mathematically. Enforce that against float cancellation rather than adding
        # EPS to every denominator (which biased every probability), and floor at EPS.
        remaining = torch.maximum(remaining, numerator).clamp_min(EPS)
        return numerator / remaining

In [12]:
dist = MultivariateHypergeometric(probs=torch.Tensor(probs))  # torch version, float32 weights

In [13]:
dist.prob(torch.Tensor(act))  # should reproduce the NumPy probabilities above

tensor([0.3750, 0.2250, 0.2000, 0.0500, 0.1059, 0.0441])

In [14]:
dist.log_prob(torch.Tensor(act))  # should reproduce the NumPy log-probabilities above

tensor([-0.9808, -1.4917, -1.6094, -2.9957, -2.2454, -3.1209])

In [15]:
# NOTE(review): entropy() is not overridden, so this is the inherited Categorical
# entropy of a single per-item draw, not the entropy of the ordering distribution
# (compare with the exact value in the next cell).
dist.entropy()

tensor(0.9376)

In [16]:
# Exact ranking entropy: -sum_v P(v) log P(v) over the orderings in `act`.
-torch.sum(dist.prob(torch.Tensor(act)) * dist.log_prob(torch.Tensor(act)))

tensor(1.5505)

In [17]:
# Near-degenerate weights: item 1 carries almost all mass, item 0 almost none.
probs = np.array((1e-8, 0.99, 0.01 - 1e-8))

In [18]:
dist = MultivariateHypergeometric(probs=torch.Tensor(probs))  # rebuilt with the near-degenerate weights

In [19]:
# Tiny weights make some denominators in prob() close to zero, exercising the EPS guard.
dist.prob(torch.Tensor(act))

tensor([9.9000e-09, 1.0000e-10, 9.9000e-07, 9.9000e-01, 1.0101e-10, 1.0000e-02])

In [20]:
dist.log_prob(torch.Tensor(act))  # log-probabilities for the near-degenerate weights

tensor([-1.8431e+01, -2.3026e+01, -1.3826e+01, -1.0051e-02, -2.3016e+01,
        -4.6052e+00])

In [21]:
# NOTE(review): identical to the previous cell; candidate for removal.
dist.log_prob(torch.Tensor(act))

tensor([-1.8431e+01, -2.3026e+01, -1.3826e+01, -1.0051e-02, -2.3016e+01,
        -4.6052e+00])

In [22]:
# Inherited Categorical (per-item) entropy; compare with the exact ranking entropy
# in the next cell. They agree here, unlike with the first set of weights.
dist.entropy()

tensor(0.0560)

In [23]:
# Exact ranking entropy for the near-degenerate weights.
-torch.sum(dist.prob(torch.Tensor(act)) * dist.log_prob(torch.Tensor(act)))

tensor(0.0560)