In [1]:
import torch 
import torch.nn as nn
from utils import init                                   # import it from utils.py or uncomment the (def init()-- function)
from torch.utils.data import WeightedRandomSampler
from torch.distributions import Categorical

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class _Categorical(Categorical):
    """Categorical distribution with two extra action-selection helpers.

    Subclass of ``torch.distributions.Categorical`` that adds:

    * ``gumbel_softmax_sample`` -- draw an action index from a Gumbel-softmax
      relaxation of the distribution.
    * ``mode`` -- return the most likely (argmax) action index.
    """

    def __init__(self, _logits):
        """Build the distribution from unnormalized log-probabilities.

        Args:
            _logits: tensor whose last dimension indexes the categories.
        """
        super().__init__(logits=_logits)
        # Keep the logits exactly as the parent class stored them, so both
        # helpers below work on the same values the distribution itself uses.
        self._logits = self.logits
        # Unused by this class; kept so existing code that reads the attribute
        # keeps working.
        self.weighted_sampler = WeightedRandomSampler

    def gumbel_softmax_sample(self, tau, device=None):
        """Sample one action index per batch element via Gumbel-softmax.

        The logits are perturbed with Gumbel noise and softened with
        temperature ``tau``; one index is then drawn from the resulting
        per-row probabilities.

        Args:
            tau: Gumbel-softmax temperature (positive scalar).
            device: optional device to move the result to. When ``None`` the
                result stays on the device of the logits.

        Returns:
            Integer tensor of category indices whose shape is the logits'
            shape without the last (category) dimension.
        """
        # ``nn.functional`` is used instead of an ``F`` alias because the file
        # never imports ``torch.nn.functional as F``.
        relaxed = nn.functional.gumbel_softmax(self._logits, tau=tau, hard=False)

        # torch.multinomial only accepts 1-D / 2-D input, so collapse every
        # leading batch dimension into one and restore it afterwards.
        batch_shape = relaxed.shape[:-1]
        flat = relaxed.reshape(-1, relaxed.shape[-1])
        action = torch.multinomial(flat, 1).reshape(batch_shape)

        if device is not None:
            action = action.to(device)
        return action

    def mode(self):
        """Return the index of the highest-logit category (greedy action).

        NOTE(review): recent PyTorch versions appear to expose ``mode`` as a
        property on distributions; this method shadows it, so it must be
        called as ``dist.mode()`` -- confirm against the installed version.
        """
        return torch.argmax(self._logits, dim=-1, keepdim=False)


class MultiHeadCategorical(nn.Module):
    """Multi-head categorical output layer for multi-label classification.

    Each head is an independent linear layer that maps the shared input
    features to the logits of one categorical distribution. All heads have
    the same number of categories (``action_dim``).

    Args:
        num_inputs: dimension of the flattened input feature.
        action_num: number of heads, i.e. number of distributions produced.
        action_dim: number of categories in every head.
        device: device the linear layers are moved to.

    ``forward`` takes the flattened input features and returns a list holding
    one ``_Categorical`` per head.
    """

    def __init__(self, num_inputs, action_num, action_dim, device):
        super().__init__()

        def init_(module):
            # `init` is imported from utils; presumably it applies the weight
            # and bias initializers and returns the module -- verify there.
            return init(module,
                        nn.init.orthogonal_,
                        lambda x: nn.init.constant_(x, 0),
                        gain=0.01)

        self.action_num = action_num
        # Previously this value was written to `action_num` and then
        # overwritten, so the per-head category count was never kept.
        self.action_dim = action_dim
        self.linear_list = torch.nn.ModuleList(
            [init_(nn.Linear(num_inputs, action_dim).to(device)) for _ in range(action_num)])
        # The next two attributes are not used inside this class; they are
        # kept so external code reading them does not break.
        self.logits_head = []
        self.weight_sample = WeightedRandomSampler
        self.device = device
        # Refreshed on every forward pass with the latest distributions.
        self.categorical_list = []
        self.train()

    def forward(self, inputs, verbose=True):
        """Build one categorical distribution per head from ``inputs``.

        Args:
            inputs: flattened input features fed to every head's linear layer.
            verbose: when true (the default, matching the historical
                behaviour) print one random sample and the greedy mode of
                every head. Note that printing draws a sample per head.

        Returns:
            List of ``_Categorical`` distributions, one per head; the same
            list is also stored in ``self.categorical_list``.
        """
        self.categorical_list = [_Categorical(head(inputs)) for head in self.linear_list]

        if verbose:
            print("\nCategorical distributions:")
            for i, dist in enumerate(self.categorical_list):
                print(f"Head {i}:")
                print(" Sample:", dist.sample())   # random action selection
                print("Mode:", dist.mode())        # greedy action (highest probability)

        return self.categorical_list

In [5]:
# Demo: 3 heads with 4 categories each, fed a single 16-dimensional feature vector.
model = MultiHeadCategorical(16, 3, 4, device)
inputs = torch.randn(1, 16).to(device)

# Call the module itself rather than `.forward` so nn.Module's hook machinery
# runs, and use the returned list instead of reading it back off the model.
distributions = model(inputs)

# Draw one action index per head.
actions = [dist.sample() for dist in distributions]
print("Sampled actions:", actions)


Categorical distributions:
Head 0:
 Sample: tensor([1])
Mode: tensor([2])
Head 1:
 Sample: tensor([3])
Mode: tensor([1])
Head 2:
 Sample: tensor([1])
Mode: tensor([3])
Sampled actions: [tensor([2]), tensor([3]), tensor([2])]
