## Policy Gradient

### Define Loss

In [2]:
from collections import namedtuple
import torch

# Input bundle for `pg_error`: action logits, the actions that were taken,
# and the return that weights each taken action.
pg_data = namedtuple(
    'pg_data',
    ['logit', 'action', 'return_'],
)
# Output bundle of `pg_error`: the two scalar loss terms, kept separate.
pg_loss = namedtuple(
    'pg_loss',
    ['policy_loss', 'entropy_loss'],
)


def pg_error(data: namedtuple) -> namedtuple:
    """Compute the policy-gradient loss terms for a batch of discrete actions.

    Args:
        data: A ``pg_data`` (or any 3-item sequence) of ``(logit, action,
            return_)``. ``logit`` holds unnormalized action scores, ``action``
            the indices of the actions taken, and ``return_`` the weight
            applied to each action's log-probability.
            Assumes ``logit`` is ``(B, N)`` and ``action`` / ``return_`` are
            ``(B,)``, as in ``test_pg`` -- shapes are not validated here.

    Returns:
        A ``pg_loss`` with two scalar tensors:

        * ``policy_loss``: ``-mean(return_ * log pi(action))``.
        * ``entropy_loss``: the mean entropy ``H = -sum(pi * log pi)`` of the
          policy. It is returned as-is (neither negated nor weighted).
    """
    logits, taken_action, returns = data

    # Build the categorical policy pi(action | state) from the raw logits.
    policy = torch.distributions.Categorical(logits=logits)

    # Log-probability the policy assigns to the actions that were actually taken.
    taken_log_prob = policy.log_prob(taken_action)

    # RL maximizes the return-weighted log-likelihood while torch optimizers
    # minimize, hence the leading minus sign.
    weighted_log_prob = returns * taken_log_prob
    surrogate = -weighted_log_prob.mean()

    # Batch-averaged policy entropy.
    mean_entropy = policy.entropy().mean()

    return pg_loss(surrogate, mean_entropy)

### Test

In [16]:
def test_pg():
    """Smoke-test ``pg_error`` on random data.

    Checks that both loss terms are scalars and that back-propagating their
    sum populates the gradient of the logits. Prints that gradient.
    """
    batch_size, num_actions = 4, 32

    # Random leaf tensor of shape (4, 32) that records gradients.
    logit = torch.randn(batch_size, num_actions).requires_grad_(True)
    # One action index in [0, 32) per sample, shape (4,).
    action = torch.randint(0, num_actions, size=(batch_size, ))
    # Random per-sample returns, shape (4,).
    return_ = 2 * torch.randn(batch_size)

    losses = pg_error(pg_data(logit, action, return_))

    # Every loss term is a 0-dim tensor, and nothing has been back-propagated yet.
    assert all(term.shape == tuple() for term in losses)
    assert logit.grad is None

    # Back-propagate the plain sum of the two terms down to the logits.
    total = sum(losses)
    total.backward()
    print(f'gradient = {logit.grad}')
    assert isinstance(logit.grad, torch.Tensor)

In [17]:
# Run the smoke test; it prints the gradient of the summed loss w.r.t. the logits.
test_pg()

gradient = tensor([[ 3.5486e-03,  4.1057e-03, -5.9870e-03, -1.2126e-02,  4.1362e-03,
          4.4808e-03, -1.1184e-02,  4.4520e-03,  4.3979e-03,  2.6895e-03,
          3.2572e-03, -4.0002e-03, -1.3826e-02,  4.4099e-03,  1.5390e-03,
          4.2445e-03,  4.3016e-03,  4.4508e-03,  2.4397e-03,  3.6939e-03,
          4.0700e-03,  4.4068e-03, -3.8916e-02,  4.1086e-03,  4.4745e-03,
          4.4634e-03,  4.3491e-03,  3.8016e-03, -1.1592e-02,  4.4277e-03,
          3.6870e-03,  3.6957e-03],
        [ 2.5218e-03,  1.2896e-03, -5.0744e-03,  3.7747e-03,  3.6985e-03,
          3.9812e-03,  3.0884e-03,  3.6825e-03,  1.9921e-03,  1.6325e-03,
          1.6994e-03,  4.6582e-02, -8.5802e-03,  2.1367e-03,  3.4858e-03,
          2.6416e-03, -5.3362e-02,  5.3370e-04,  3.8763e-03,  3.0482e-03,
          2.3109e-03,  3.3736e-03, -1.4816e-02,  3.7264e-03,  2.3092e-03,
         -3.9912e-02,  3.9392e-03,  3.8367e-03,  2.2680e-03,  3.9189e-03,
          2.5579e-03,  3.8398e-03],
        [ 1.2087e-03, -5.5561