In [3]:
from typing import Optional
import torch
from torch.distributions.categorical import Categorical
from torch import einsum
from einops import reduce

In [4]:
class CategoricalMasked(Categorical):
    """Categorical distribution whose invalid actions can be masked out.

    Entries where ``mask`` is False are replaced by the most negative finite
    logit for the dtype. Their probability therefore underflows to exactly 0,
    and they are excluded from the entropy sum.

    Args:
        logits: 2-D tensor of shape ``(batch, nb_action)``. The constructor
            unpacks exactly two dimensions.
        mask: optional tensor broadcastable to ``logits``. Truthy entries are
            allowed actions and falsy entries are masked. Non-bool masks
            (e.g. 0/1 integers) are converted to bool on ``logits``' device.
            ``None`` disables masking.
        validate_args: forwarded unchanged to :class:`Categorical`.

    NOTE(review): a row whose actions are *all* masked ends up uniform over
    the masked actions (every logit equals the same floor value). Its entropy
    is reported as 0. Callers should avoid fully-masked rows.
    """

    def __init__(
        self,
        logits: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        validate_args: Optional[bool] = None,
    ):
        self.batch, self.nb_action = logits.size()
        if mask is not None:
            # torch.where needs a bool condition; keep it on the logits' device.
            mask = mask.to(device=logits.device, dtype=torch.bool)
            # Use the most negative *finite* value rather than -inf.
            # This avoids NaNs from logsumexp when a row is fully masked.
            self.mask_value = torch.tensor(
                torch.finfo(logits.dtype).min,
                dtype=logits.dtype,
                device=logits.device,
            )
            logits = torch.where(mask, logits, self.mask_value)
        self.mask = mask
        super().__init__(logits=logits, validate_args=validate_args)

    def entropy(self) -> torch.Tensor:
        """Return the per-row entropy, summing only over unmasked actions.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        if self.mask is None:
            return super().entropy()
        # After Categorical.__init__, self.logits holds normalized log-probabilities.
        p_log_p = self.logits * self.probs
        # Masked actions contribute nothing: zero them explicitly.
        p_log_p = torch.where(self.mask, p_log_p, torch.zeros_like(p_log_p))
        return -p_log_p.sum(dim=-1)


In [5]:
# Seed so the random logits (and every output below) are reproducible.
torch.manual_seed(0)
logits_or_qvalues = torch.randn((2, 3), requires_grad=True)  # (batch size, nb action)
print(logits_or_qvalues)

# True -> action allowed, False -> action masked out.
mask = torch.tensor(
    [[False, False, True],
     [True, True, False]]
)  # (batch size, nb action), dtype torch.bool
print(mask)

tensor([[ 0.3031, -0.2697, -0.0426],
        [-1.8454, -1.9453,  0.3548]], requires_grad=True)
tensor([[False, False,  True],
        [ True,  True, False]])


In [6]:
# Without a mask, all 3 actions in each row remain available.
head = CategoricalMasked(logits=logits_or_qvalues)
print(head.probs)

# With the mask, disallowed actions get probability exactly 0.
# The remaining mass is renormalized over the allowed actions
# (row 0: 1 action left, row 1: 2 actions left).
head_masked = CategoricalMasked(logits=logits_or_qvalues, mask=mask)
print(head_masked.probs)
assert (head_masked.probs[~mask] == 0).all()
assert torch.allclose(head_masked.probs.sum(dim=-1), torch.ones(mask.shape[0]))

# Entropy of the unmasked distribution.
print(head.entropy())

# Masked entropy sums only over allowed actions.
# A row with a single allowed action is deterministic, so its entropy is 0.
print(head_masked.entropy())

tensor([[0.4402, 0.2483, 0.3115],
        [0.0915, 0.0828, 0.8257]], grad_fn=<SoftmaxBackward0>)
tensor([[0.0000, 0.0000, 1.0000],
        [0.5250, 0.4750, 0.0000]], grad_fn=<SoftmaxBackward0>)
tensor([1.0704, 0.5831], grad_fn=<NegBackward0>)
tensor([-0.0000, 0.6919], grad_fn=<NegBackward0>)


In [11]:
# Base values.
example_dict = {'Y': 800, 'M': 500, 'K': 450}

# Multiplication factors, matched to example_dict's keys by insertion order.
factor_list = [1.1, 1, 1.1]

# Positional pairing only makes sense with exactly one factor per key.
# Fail loudly instead of raising IndexError or silently ignoring extra factors.
assert len(factor_list) == len(example_dict), "need exactly one factor per key"

# Multiply each value by its positional factor.
# Results like 880.0000000000001 are ordinary binary floating-point rounding.
result_dict = {
    key: value * factor
    for (key, value), factor in zip(example_dict.items(), factor_list)
}

print(result_dict)


{'Y': 880.0000000000001, 'M': 500, 'K': 495.00000000000006}
