In [1]:
import torch

import numpy as np

from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical

In [2]:
class MultivariateCategorical(Distribution):
    def __init__(self, nvec, probs=None, logits=None, validate_args=None):
        nvec = list(nvec)
        dims = np.cumsum([0] + nvec)
        if probs is not None:
            self._dists = [
                Categorical(probs=probs[:, i:j]) for i, j in zip(dims[:-1], dims[1:])
            ]
        elif logits is not None:
            self._dists = [
                Categorical(logits=logits[:, i:j]) for i, j in zip(dims[:-1], dims[1:])
            ]
        else:
            raise ValueError("probs and logits are both None.")
        batch_shape = self._dists[0].batch_shape
        super(MultivariateCategorical, self).__init__(batch_shape, validate_args=False)

    def sample(self, sample_shape=torch.Size()):
        sample = torch.stack([d.sample(sample_shape) for d in self._dists], dim=1)
        return sample

    def log_prob(self, value):
        log_prob = torch.stack(
            [d.log_prob(value[:,i]) for i, d in enumerate(self._dists)], dim=1
        )
        return log_prob.sum(1)

    def entropy(self):
        entropy = torch.stack([d.entropy() for d in self._dists], dim=1)
        return entropy.sum(1)

    def argmax(self):
        return torch.stack([d.logits.argmax(1) for d in self._dists], dim=1)

In [3]:
logits = torch.randn(1,13)

In [4]:
logits

tensor([[ 0.5107, -1.8501,  1.5848, -0.2882, -1.7965, -0.0613, -0.8995,  0.6027,
         -0.2925,  0.7570, -0.6498,  1.5419, -0.5167]])

In [5]:
d = MultivariateCategorical(nvec=[10,3], logits=logits)

In [6]:
a = d.sample()

In [7]:
a

tensor([[2, 1]])

In [8]:
d.log_prob(a)

tensor([-1.2450])

In [9]:
d.entropy()

tensor([2.5188])

In [10]:
d.argmax()

tensor([[2, 1]])

In [11]:
d.log_prob(d.argmax())

tensor([-1.2450])