# A2C

## Definition

In [4]:
from collections import namedtuple
import torch
import torch.nn.functional as F

# Input bundle for ``a2c_error``; field meanings are documented in its docstring.
a2c_data = namedtuple('a2c_data', ['logits', 'action', 'value', 'adv', 'return_', 'weight'])
# Output bundle of ``a2c_error``: three scalar loss tensors.
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def a2c_error(data: a2c_data) -> a2c_loss:
    """Compute the A2C policy, value and entropy losses for a batch.

    Args:
        data: an ``a2c_data`` tuple with fields
            - logits: unnormalized action log-probabilities, used to build a
              ``torch.distributions.Categorical``. Presumably shaped
              (batch, num_actions) -- not checked here.
            - action: indices of the actions that were taken.
            - value: critic estimates for the current states.
            - adv: advantage estimates that scale the policy-gradient term.
            - return_: regression targets for ``value``.
            - weight: per-sample loss weights, or ``None`` to use
              ``torch.ones_like(value)`` (every sample weighted 1).

    Returns:
        An ``a2c_loss`` tuple ``(policy_loss, value_loss, entropy_loss)``.
        Each is a scalar tensor equal to the batch mean of the weight-scaled
        per-sample term (the mean is not normalized by the weight sum).
        ``entropy_loss`` is returned as a positive entropy; the caller chooses
        its sign and coefficient when forming a total objective.
    """
    logits, action, value, adv, return_, weight = data

    # Missing weights mean every sample counts equally.
    if weight is None:
        weight = torch.ones_like(value)

    dist = torch.distributions.categorical.Categorical(logits=logits)
    log_prob = dist.log_prob(action)

    # Policy-gradient term: increase the log-probability of actions in
    # proportion to their advantage.
    # NOTE(review): ``adv`` is used as given (not detached). If it was computed
    # from ``value`` with autograd enabled, gradients would also flow into the
    # critic through this term -- callers presumably pass a detached tensor.
    policy_loss = -(log_prob * adv * weight).mean()

    # Weight-scaled squared error between the critic estimate and its target.
    value_loss = torch.mean(weight * (return_ - value) ** 2)

    # Weight-scaled mean entropy of the action distribution.
    entropy_loss = (dist.entropy() * weight).mean()

    return a2c_loss(policy_loss, value_loss, entropy_loss)


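下面说明 `a2c_data` 中各个输入字段的含义：
- logits：策略网络输出的未归一化对数概率，代码据此构造动作的分类分布（`Categorical(logits=...)`）
- action：从该策略分布中采样得到、实际执行的动作
- value：价值网络对当前状态给出的估计 $V(s_t)$
- adv：优势函数估计，例如用单步 TD 误差 $A_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ 求得
- return_：从当前时刻起的真实（折扣）reward 累积（例如一个 episode 走完后计算），作为价值函数的回归目标
- weight：每个样本的损失权重；为 None 时默认全为 1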

In [5]:
def test_a2c():
    """Smoke-test ``a2c_error``: its losses must backpropagate into both the
    policy logits and the value estimates, which start without gradients."""
    batch, n_act = 4, 32

    # Random inputs; only the logits and values are autograd leaves.
    logits = torch.randn(batch, n_act).requires_grad_(True)
    sampled = torch.randint(0, n_act, size=(batch,))
    values = torch.randn(batch).requires_grad_(True)
    advantages = torch.rand(batch)
    targets = torch.randn(batch) * 2

    losses = a2c_error(a2c_data(logits, sampled, values, advantages, targets, None))

    # Nothing has been backpropagated yet.
    for leaf in (logits, values):
        assert leaf.grad is None

    total = losses.policy_loss + losses.value_loss + losses.entropy_loss
    total.backward()

    # Both leaves must now carry a gradient tensor.
    for leaf in (logits, values):
        assert isinstance(leaf.grad, torch.Tensor)


In [6]:
# Run the A2C loss smoke test; an AssertionError means a gradient check failed.
test_a2c()