In [1]:
import torch 
import torch.nn as nn
from utils import init                                   # import it from utils.py or uncomment the (def init()-- function)
from torch.utils.data import WeightedRandomSampler
from torch.distributions import Categorical
import torch.nn.functional as F

In [2]:
def init(module, weight_init, bias_init, gain=1):
    """Apply the given initialisers to a module's weight and (optional) bias.

    Note: the notebook's first cell also imports an ``init`` from ``utils``;
    running this cell rebinds that name to the definition below.

    Args:
        module: Object exposing ``weight.data`` and, optionally, a ``bias``
            attribute whose ``.data`` is initialised when it is not ``None``.
        weight_init: Callable invoked as ``weight_init(weight_data, gain=gain)``.
        bias_init: Callable invoked as ``bias_init(bias_data)``.
        gain: Value forwarded to ``weight_init`` as the ``gain`` keyword.

    Returns:
        The very same ``module`` object, so the call can wrap a constructor.
    """
    weight_init(module.weight.data, gain=gain)

    # A missing ``bias`` attribute and ``bias=None`` both mean "nothing to initialise".
    bias = getattr(module, 'bias', None)
    if bias is not None:
        bias_init(bias.data)

    return module

In [3]:
# Pick the compute device once: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
class _Categorical(Categorical):
    """Categorical distribution with a Gumbel-softmax sampler and an argmax ``mode``.

    Args:
        _logits: Unnormalised log-probabilities; the last dimension indexes
            the categories, every leading dimension is a batch dimension.
    """

    def __init__(self, _logits):
        super().__init__(logits = _logits)
        # Alias of the logits as stored by the parent class (which normalises them).
        self._logits = self.logits
        # Not used by any method of this class; kept so existing attribute access still works.
        self.weighted_sampler = WeightedRandomSampler

    def gumbel_softmax_sample(self, tau, device):
        """Draw one category index per batch element from a Gumbel-softmax relaxation.

        The logits are perturbed with Gumbel noise and softened with temperature
        ``tau``; a category is then drawn from the resulting probabilities with
        ``torch.multinomial``. The result holds integer indices, so no gradient
        flows through it.

        Args:
            tau: Temperature passed to ``F.gumbel_softmax``.
            device: Device the returned indices are moved to.

        Returns:
            Integer tensor of category indices whose shape is the logits' shape
            without the last (category) dimension.
        """
        relaxed = F.gumbel_softmax(self._logits, tau = tau, hard = False)
        # torch.multinomial only accepts 1-D or 2-D input, so fold every leading
        # batch dimension into rows and restore the batch shape afterwards.
        rows = relaxed.reshape(-1, relaxed.shape[-1])
        action = torch.multinomial(rows, num_samples = 1).reshape(relaxed.shape[:-1])
        return action.to(device)

    def mode(self):
        """Return the index of the largest logit along the last dimension."""
        return torch.argmax(self._logits, dim = -1, keepdim = False)

In [5]:
class MultiHeadCategorical(nn.Module):
    """A bank of independent categorical heads that all read the same input.

    Each of the ``action_num`` heads is a linear layer producing ``action_dim``
    logits. ``forward`` builds one distribution per head and caches the list on
    the module; every other method reads that cache, so ``forward`` has to be
    called before any of them.

    For more than one head, ``gumbel_softmax_sample``, ``mode``, ``sample`` and
    ``entropy`` concatenate the per-head results along dimension 0, i.e. the
    result is head-major: all entries of head 0 first, then head 1, and so on.

    Args:
        num_inputs: Size of the input feature vector.
        action_dim: Number of categories per head.
        action_num: Number of heads.
        device: Device the linear layers are placed on and sampled actions are moved to.
    """

    def __init__(self, num_inputs, action_dim, action_num, device):
        super().__init__()

        # Orthogonal weights with a small gain and zero biases for every head.
        init_ = lambda m: init(m,
                               nn.init.orthogonal_,
                               lambda x: nn.init.constant_(x, 0),
                               gain = 0.01)
        self.action_num = action_num
        self.linear_list = torch.nn.ModuleList(
            [init_(nn.Linear(num_inputs, action_dim).to(device)) for _ in range(action_num)]
        )
        # The next two attributes are not read by any method of this class;
        # they are kept so existing attribute access keeps working.
        self.logits_head = []
        self.weight_sample = WeightedRandomSampler
        self.device = device
        self.categorical_list = []   # filled by forward()
        self.action = None           # last result of gumbel_softmax_sample()
        self.train()

    def forward(self, inputs):
        """Build one distribution per head from ``inputs`` and cache the list.

        Nothing is sampled here and nothing is returned; the distributions are
        stored in ``self.categorical_list`` for the other methods to use.
        """
        self.categorical_list = [_Categorical(linear(inputs)) for linear in self.linear_list]

    def gumbel_softmax_sample(self, tau):
        """Sample every head once via Gumbel-softmax and remember the result.

        Args:
            tau: Temperature forwarded to each head's sampler.

        Returns:
            The per-head samples concatenated along dimension 0. The same
            tensor is stored in ``self.action``.
        """
        # Sample exactly once so the returned value and self.action always agree.
        action = torch.cat([p.gumbel_softmax_sample(tau, self.device) for p in self.categorical_list])
        self.action = action
        return action

    def probs(self):
        """Return the category probabilities.

        A single head returns its probabilities unchanged; several heads are
        joined along a new last dimension, giving ``(..., action_dim, action_num)``.
        """
        if self.action_num == 1:
            return self.categorical_list[0].probs
        # unsqueeze(-1) adds the trailing head dimension the heads are joined on.
        return torch.cat([p.probs.unsqueeze(-1) for p in self.categorical_list], dim = -1)

    def log_probs(self, action):
        """Return the log-probability of ``action`` under each head.

        Args:
            action: For a single head, the actions to score. For several heads,
                either a tensor whose first dimension indexes the heads, or the
                flat head-major vector produced by ``sample``, ``mode`` or
                ``gumbel_softmax_sample``.

        Returns:
            For a single head, that head's log-probabilities; otherwise the
            per-head log-probabilities joined along a new last dimension.
        """
        if self.action_num == 1:
            return self.categorical_list[0].log_prob(action)

        # A flat vector longer than action_num is the head-major concatenation
        # of per-head batches; split it back into one row per head so each head
        # scores its own actions instead of only the first action_num entries.
        if action.dim() == 1 and action.numel() != self.action_num:
            action = action.reshape(self.action_num, -1)
        return torch.cat([p.log_prob(a).unsqueeze(-1) for a, p in zip(action, self.categorical_list)], dim = -1)

    def mode(self):
        """Return the most likely category of every head (concatenated along dim 0)."""
        if self.action_num == 1:
            return self.categorical_list[0].mode()
        return torch.cat([p.mode() for p in self.categorical_list])

    def sample(self):
        """Draw a discrete sample from every head (concatenated along dim 0)."""
        if self.action_num == 1:
            return self.categorical_list[0].sample()
        return torch.cat([p.sample() for p in self.categorical_list])

    def entropy(self):
        """Return the entropy of every head (concatenated along dim 0)."""
        if self.action_num == 1:
            return self.categorical_list[0].entropy()
        return torch.cat([p.entropy() for p in self.categorical_list])

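For debugging: a quick smoke test that builds a MultiHeadCategorical and calls each of its methods on one random input.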

In [6]:
# Smoke-test configuration: four heads with three categories each, one input row.
num_inputs = 16
action_dim = 3
action_num = 4
tau = 0.85

# Fix the seed so the random weights, input and samples below are repeatable.
torch.manual_seed(0)

model = MultiHeadCategorical(num_inputs, action_dim, action_num, device)
inputs = torch.randn(1, num_inputs).to(device)

# Calling the module runs forward(), which caches one distribution per head
# on the model and returns nothing.
model(inputs)
print(model.categorical_list)                # one distribution object per head
print(len(model.categorical_list))

actions = model.gumbel_softmax_sample(tau)   # one sampled category index per head
assert actions.shape == (action_num,)        # batch of one -> one entry per head
print("actions:", actions)

print("probabilities:", model.probs())
print("mode:", model.mode())
print("sample:", model.sample())
print("entropy:", model.entropy())

print("log-probs;", model.log_probs(actions))


[_Categorical(logits: torch.Size([1, 3])), _Categorical(logits: torch.Size([1, 3])), _Categorical(logits: torch.Size([1, 3])), _Categorical(logits: torch.Size([1, 3]))]
4
actions: tensor([0, 2, 1, 0])
probabilities: tensor([[[0.3314, 0.3324, 0.3295, 0.3357],
         [0.3295, 0.3362, 0.3344, 0.3309],
         [0.3391, 0.3314, 0.3361, 0.3334]]], grad_fn=<CatBackward0>)
mode: tensor([2, 1, 2, 0])
sample: tensor([1, 0, 1, 0])
entropy: tensor([1.0985, 1.0986, 1.0986, 1.0986], grad_fn=<CatBackward0>)
log-probs; tensor([[-1.1045, -1.1043, -1.0953, -1.0915]], grad_fn=<CatBackward0>)
