In [1]:
import torch

import numpy as np

In [2]:
EPS = 1e-8


class NPMultivariateHypergeometric:
    def __init__(self, probs):
        self.probs = np.array(probs)
        self.n = len(probs)
        
    def sample(self, size=1):
        return np.random.choice(np.arange(self.n), self.n, p=self.probs, replace=False)
    
    def log_prob(self, value):
        return np.log(self.prob(value))
    
    def prob(self, value):
        value = value.astype(int)
        numerator = self.probs[value]
        denominator = np.hstack([np.zeros((len(value), 1)), numerator[:,:-1]])
        denominator = 1 - np.cumsum(denominator, axis=1)
        p = np.prod(numerator/(denominator+EPS), axis=1)
        return p
    

class MultivariateHypergeometric(torch.distributions.Categorical):
    """Sample ranking index according to Softmax distribution"""

    def __init__(self, probs=None, logits=None, validate_args=None):
        super().__init__(probs, logits, validate_args)
        self._event_shape = (self._param.size()[-1],)

    def sample(self, sample_shape=torch.Size(), replacement=False):
        probs_2d = self.probs.reshape(-1, self._num_events)
        samples_2d = torch.multinomial(probs_2d, self._num_events, replacement)
        return samples_2d

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        return self.prob(value).log()

    def prob(self, value):
        numerator = self.probs[value.long()]
        denominator = torch.hstack(
            [torch.zeros_like(numerator[:, :1]), numerator[:, :-1]]
        )
        denominator = 1 - torch.cumsum(denominator, dim=1)
        p = torch.prod(numerator / (denominator + EPS), dim=1)
        return p

In [3]:
probs = np.array([0.6, 0.25, 0.15])
size = 100000

In [4]:
dist1 = NPMultivariateHypergeometric(probs)

sample1 = np.zeros((size, len(probs)))
for i in range(size):
    sample1[i] = dist1.sample()

act1, count1 = np.unique(sample1, axis=0, return_counts=True)

In [5]:
dist2 = MultivariateHypergeometric(torch.Tensor(probs))

sample2 = np.zeros((size, len(probs)))
for i in range(size):
    sample2[i] = dist2.sample().numpy().flatten()
    
act2, count2 = np.unique(sample2, axis=0, return_counts=True)

In [6]:
(act1 == act2).all()

True

In [7]:
count1/count1.sum()

array([0.3769 , 0.22439, 0.19878, 0.04956, 0.10618, 0.04419])

In [8]:
count2/count2.sum()

array([0.37553, 0.22441, 0.19887, 0.05032, 0.10618, 0.04469])

In [9]:
dist1.prob(act1)

array([0.37499996, 0.22499998, 0.19999998, 0.05      , 0.10588235,
       0.04411765])

In [10]:
dist2.prob(torch.Tensor(act1))

tensor([0.3750, 0.2250, 0.2000, 0.0500, 0.1059, 0.0441])

In [11]:
dist1.prob(act2)

array([0.37499996, 0.22499998, 0.19999998, 0.05      , 0.10588235,
       0.04411765])

In [12]:
dist2.prob(torch.Tensor(act2))

tensor([0.3750, 0.2250, 0.2000, 0.0500, 0.1059, 0.0441])

In [13]:
-(dist2.prob(torch.Tensor(act1)) * dist2.log_prob(torch.Tensor(act1))).sum()

tensor(1.5505)

In [14]:
dist2.entropy()

tensor(0.9376)

In [15]:
probs = np.array([1e-8, 0.99, 0.01-1e-8])
dist3 = MultivariateHypergeometric(torch.Tensor(probs))

In [16]:
dist3.entropy()

tensor(0.0560)

In [17]:
-(dist3.prob(torch.Tensor(act1)) * dist3.log_prob(torch.Tensor(act1))).sum()

tensor(0.0560)