In [2]:
import json
import torch
from typing import List, Optional, Tuple, Union

In [3]:
def get_advantages_and_returns(
    values: torch.Tensor,
    rewards: torch.Tensor,
    action_mask: torch.Tensor,
    gamma: float,
    lambd: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute per-step advantages and returns from rewards and value estimates.

    Follows the formulation of the original PPO paper
    (https://arxiv.org/abs/1707.06347). The rewards may already include a
    KL-divergence penalty term.

    The backward recursion used is:
        delta_t = R_t + γ * V_{t+1} - V_t        (V past the last step is 0)
        Adv_t   = delta_t + γ * λ * Adv_{t+1}
        Ret_t   = Adv_t + V_t

    Unrolled, the first advantage reads:
        Adv1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
              - V1 + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

    and the first return reads:
        Ret1 =  R1 + γ * λ * R2     + γ^2 * λ^2 * R3       + ...
                   + γ * (1 - λ) V2 + γ^2 * λ * (1 - λ) V3 + ...

    Args:
        values: Tensor of shape (batch_size, response_size) with value estimates.
        rewards: Tensor of shape (batch_size, response_size) with per-step rewards.
        action_mask: Tensor multiplied element-wise into both ``values`` and
            ``rewards`` before the recursion, so zero entries blank out those
            positions.
        gamma: Discount factor γ.
        lambd: Trace-decay factor λ.

    Returns:
        A pair ``(advantages, returns)``, each of shape
        (batch_size, response_size). Only ``advantages`` is detached from the
        autograd graph; ``returns`` is handed back as computed.
    """
    # The horizon is read from the rewards as they were passed in.
    horizon = rewards.size(1)

    # Blank out masked positions without modifying the caller's tensors.
    masked_values = action_mask * values
    masked_rewards = action_mask * rewards

    # Walk backwards in time, writing each step's estimate into its own slot.
    gae_per_step = [None] * horizon
    running_gae = 0
    for step in range(horizon - 1, -1, -1):
        if step < horizon - 1:
            bootstrap = masked_values[:, step + 1]
        else:
            # Nothing follows the final step, so its bootstrap value is zero.
            bootstrap = 0.0
        td_error = masked_rewards[:, step] + gamma * bootstrap - masked_values[:, step]
        running_gae = td_error + gamma * lambd * running_gae
        gae_per_step[step] = running_gae

    advantages = torch.stack(gae_per_step, dim=1)
    returns = advantages + masked_values
    return advantages.detach(), returns


In [8]:
def _compute_advantages(
        valid_values: torch.FloatTensor,
        rewards: torch.FloatTensor,
        shifted_labels_mask: torch.LongTensor,
        gamma: float,
        lambd: float,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Compute the advantages from the values and rewards.

        Args:
            valid_values (`torch.FloatTensor`):
                The values of the responses, shape (`batch_size`, `max_seq_len-1`)
            rewards (`torch.FloatTensor`):
                The rewards of the responses, shape (`batch_size`, `max_seq_len-1`)
            shifted_labels_mask (`torch.LongTensor`):
                Left Shifted by 1 Mask for the labels (i.e. actions), shape (`batch_size`, `max_seq_len-1`)

        Returns:
            `torch.FloatTensor`: The advantages of the responses, shape (`batch_size`, `max_seq_len-1`)
            `torch.FloatTensor`: The returns of the responses, shape (`batch_size`, `max_seq_len-1`)
        """
        lastgaelam = 0
        advantages_reversed = []
        actions_seq_len = rewards.shape[-1]

        # Make sure invalid rewards are masked
        rewards *= shifted_labels_mask


        for t in reversed(range(actions_seq_len)):
            next_state_values = (
                valid_values[:, t + 1] if t < (actions_seq_len - 1) else 0.0
            )
            delta = (
                rewards[:, t]
                + gamma * next_state_values
                - valid_values[:, t]
            )
            lastgaelam = (
                delta + gamma * lambd * lastgaelam
            )
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1]).transpose(0, 1)
        assert advantages.shape == rewards.shape

        returns = advantages + valid_values
        return advantages.detach(), returns.detach()

In [12]:
# Toy episode: one sequence of seven steps whose only reward (+1) arrives at the
# final step, evaluated without discounting (gamma = lambd = 1).
values = torch.tensor([[0.2, 0.2, 0.5, 0.5, 0, 0, 0]])
rewards = torch.tensor([[0, 0, 0, 0, 0, 0, 1]])
action_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1]])
gamma, lambd = 1, 1
get_advantages_and_returns(
    values=values,
    rewards=rewards,
    action_mask=action_mask,
    gamma=gamma,
    lambd=lambd,
)

(tensor([[0.8000, 0.8000, 0.5000, 0.5000, 1.0000, 1.0000, 1.0000]]),
 tensor([[1., 1., 1., 1., 1., 1., 1.]]))

In [5]:
# Scenario for this cell: 1000 zero-valued steps with a single terminal reward of -1
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    """Average `tensor` over the positions weighted by `mask`.

    Computes ``sum(tensor * mask) / (sum(mask) + 1e-7)``. The small constant
    guards against division by zero, so a mask that selects nothing yields 0
    rather than NaN -- both for the per-dimension and the global reduction.

    Args:
        tensor: Values to average.
        mask: Weights multiplied element-wise into `tensor` (typically 0/1).
        dim: Dimension to reduce over. When ``None`` the reduction runs over
            every element and a scalar tensor is returned.

    Returns:
        The masked mean, reduced along `dim` or over all elements.
    """
    eps = 1e-7  # keeps the denominator non-zero for an all-zero mask
    weighted = tensor * mask
    if dim is not None:
        return weighted.sum(dim=dim) / (mask.sum(dim=dim) + eps)
    return weighted.sum() / (mask.sum() + eps)
# Long-horizon check: a single sequence of 1000 steps with all-zero value
# estimates, where the only non-zero reward (-1) sits on the last step.
values = torch.zeros(1, 1000)
rewards = torch.zeros(1, 1000)
rewards[0, -1] = -1
action_mask = torch.ones(1, 1000)
gamma, lambd = 1, 1
advantage, returns = get_advantages_and_returns(
    values=values,
    rewards=rewards,
    action_mask=action_mask,
    gamma=gamma,
    lambd=lambd,
)
# Per-sequence masked mean of the advantages, then averaged over the batch.
mean = masked_mean(advantage, action_mask, dim=-1).mean()
print(advantage, mean)

tensor([[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,
         -1., -1., -1., -1., -1., -1., -1., -1., -1.

In [33]:
# Three-step episode with a terminal reward of -1, using lambd < 1 so the
# value estimates contribute to the advantages.
values = torch.tensor([[0.5, 0.2, 0]])
rewards = torch.tensor([[0, 0, -1]])
action_mask = torch.tensor([[1, 1, 1]])
gamma, lambd = 1, 0.95
get_advantages_and_returns(
    values=values,
    rewards=rewards,
    action_mask=action_mask,
    gamma=gamma,
    lambd=lambd,
)

(tensor([[-1.3925, -1.1500, -1.0000]]), tensor([[-0.8925, -0.9500, -1.0000]]))

In [38]:
# Create a random 2x3 tensor standing in for value estimates.
values = torch.rand(2, 3)
print(values)
action_mask = torch.tensor([[1, 0, 0], [0, 1, 0]])
print(action_mask)
# Column of the right-most 1 in each mask row: flip the row left-to-right,
# take argmax on the flipped row, then convert back to an index from the left.
eos_indices = (action_mask.shape[1] - 1) - torch.fliplr(action_mask.long()).argmax(dim=1, keepdim=True)
print(eos_indices)
# Overwrite (in place) the entry of `values` at each row's eos index with 0.
values.scatter_(dim=1, index=eos_indices, value=0)
print(values)

tensor([[0.8086, 0.7561, 0.9381],
        [0.8739, 0.8820, 0.5278]])
tensor([[1, 0, 0],
        [0, 1, 0]])
tensor([[0],
        [1]])
tensor([[0.0000, 0.7561, 0.9381],
        [0.8739, 0.0000, 0.5278]])
