In [None]:
import math

import torch
from tensordict import TensorDict

from fabricrl.data.buffers import ReplayBuffer

In [None]:
# Sizes used throughout the replay-buffer cells.
# buf_size and n_envs are the two positional arguments handed to ReplayBuffer
# (presumably capacity and number of parallel environments — confirm against
# fabricrl.data.buffers); seq_len is the length of each chunk built below.
buf_size = 16
seq_len = 2
n_envs = 3
rb = ReplayBuffer(buf_size, n_envs)

In [None]:
# One chunk of random data with leading dims (seq_len, n_envs):
#   observations -> uniform floats, 4 features
#   actions      -> integers in [0, 4)
#   rewards      -> uniform floats
#   dones        -> integers in {0, 1}
data = TensorDict(
    {
        "observations": torch.rand(seq_len, n_envs, 4),
        "actions": torch.randint(0, 4, (seq_len, n_envs, 1)),
        "rewards": torch.rand(seq_len, n_envs, 1),
        "dones": torch.randint(0, 2, (seq_len, n_envs, 1)),
    },
    batch_size=[seq_len, n_envs],
)
# Show the chunk shape next to the shape of the buffer's private `_buf` storage.
data.shape, rb._buf.shape

In [None]:
# Keep adding the same chunk until the buffer's private `_full` flag becomes truthy.
# NOTE(review): relies on a private attribute; the loop never ends if `_full` is never set.
while not rb._full:
    rb.add(data)

In [None]:
# Draw from the filled buffer; 4 is presumably the batch size — confirm against ReplayBuffer.sample.
rb.sample(4)

In [None]:
# Collapse all batch dimensions of the stored buffer into one leading dimension.
rb.buffer.view(math.prod(rb.buffer.shape), -1)

In [None]:
# Attach a new "returns" entry. The leading dims reuse the buffer configuration
# instead of repeating the literals 16 and 3, so the cell keeps working when
# buf_size or n_envs is changed.
rb["returns"] = torch.rand(buf_size, n_envs, 2)

In [None]:
# Flatten the buffer again after adding "returns".
# NOTE(review): this cell sizes the view from rb.shape while the earlier one used
# rb.buffer.shape — presumably identical; confirm.
rb.buffer.view(math.prod(rb.shape), -1)

In [None]:
import numpy as np
import torch

In [None]:
# Toy input 0..15 and a boolean mask flagging positions 4, 7, 9 and 14.
arr_in = torch.arange(16)
mask = torch.tensor([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0]).bool()
arr_in

In [None]:
# Segment-wise reverse cumulative sum via padding:
# split arr_in right after every flagged index, pad the pieces into a rectangle,
# then flip -> cumsum -> flip back so each row accumulates from its right end.
pad_arr_in = torch.nn.utils.rnn.pad_sequence(
    torch.tensor_split(arr_in.float(), mask.nonzero().view(-1) + 1), batch_first=True, padding_value=0
)
flip_pad_arr_in = pad_arr_in.fliplr()
cs = flip_pad_arr_in.cumsum(dim=1).fliplr()
# Keep only the non-zero entries, flattened row by row; this discards the padding
# (and would also discard any genuine zero sum).
cs[cs.nonzero(as_tuple=True)]

In [None]:
# Plain (unsegmented) cumulative sum over the whole input; rebinds `cs`.
cs = arr_in.cumsum(dim=0)
cs

In [None]:
# Running total at the flagged positions only, 0 everywhere else.
torch.where(mask, cs, 0)

In [None]:
# Carry the total seen at the most recent flagged position forward (running max),
# then shift right by one so the offset only applies *after* the flagged position.
acc = torch.cummax(torch.where(mask, cs, 0), 0)[0].roll(1, 0)
# roll wraps the last element round to index 0; nothing precedes index 0, so reset it.
acc[0] = 0
acc

In [None]:
# Cumulative sum that restarts right after every flagged position.
cs - acc

In [None]:
# NumPy walk-through of a counter that restarts at every flagged position.
# Entries equal to 1 in `mask` mark the positions where the count resets.
arr_in = np.arange(16)
mask = np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0])

# Running count of the unflagged positions (flagged ones contribute 0).
cs = np.cumsum(np.ones(len(arr_in)) * (1 - mask))
print(cs)

# Keep the running count at the flagged positions only, zero elsewhere.
w = np.where(mask == 1, cs, 0.0)
print(w)

# Forward-fill: the running maximum carries each flagged count onwards.
m = np.maximum.accumulate(w)
print(m)

# Removing the carried value restarts the count after every flag.
out = cs - m
print(out)

In [None]:
import torch
from torch import Tensor
from typing import Tuple

# Coefficients passed as `gamma` and `gae_lambda` to both estimators in the cells below.
gamma = 0.99
gae_lambda = 0.95


def conditional_arange(n: int, mask: Tensor) -> Tensor:
    """Return a 0-based step counter that restarts after every flagged position.

    For ``mask = [0, 0, 1, 0, 0, 0, 1, 0]`` the result is
    ``[0, 1, 2, 0, 1, 2, 3, 0]``: the counter keeps running *through* a flagged
    position and drops back to 0 on the element that follows it.

    Args:
        n: number of positions; expected to equal ``mask.shape[0]``.
        mask: 1-D numeric tensor with non-zero entries at the positions after
            which the counter restarts. ``1 - mask`` is evaluated, so a bool
            tensor is not supported.

    Returns:
        1-D tensor of length ``n`` holding the counter values.
    """
    # Shift right by one so the restart lands on the step *after* each flag.
    # torch.roll returns a new tensor, so the caller's mask is left untouched.
    shifted = torch.roll(mask, 1, 0)
    # roll wraps the last flag round to index 0; nothing precedes index 0.
    shifted[0] = 0
    # Allocate the helper on the mask's device; a bare torch.ones(n) lives on the
    # default device and would fail against a mask stored elsewhere.
    ones = torch.ones(n, device=mask.device)
    # 1-based count of the positions that are not restart points.
    cs = (ones * (1 - shifted)).cumsum(dim=0)
    # Count reached at the most recent restart point, carried forward.
    acc = torch.cummax(shifted * cs, 0)[0]
    # Remove the carried offset and convert to 0-based.
    return cs - torch.where(acc > 0, acc - 1, 0) - 1


@torch.no_grad()
def estimate_returns_and_advantages(
    rewards: Tensor,
    values: Tensor,
    dones: Tensor,
    next_done: Tensor,
    next_value: Tensor,
    num_steps: int,
    gamma: float,
    gae_lambda: float,
) -> Tuple[Tensor, Tensor]:
    """Reference advantage estimator: one backward pass over the steps.

    Walking from the last step to the first, it accumulates

        delta_t = rewards[t] + gamma * V_next * nonterminal_next - values[t]
        adv_t   = delta_t + gamma * gae_lambda * nonterminal_next * adv_{t+1}

    where the "next" quantities are ``values[t + 1]`` / ``dones[t + 1]`` and, for
    the final step, ``next_value`` / ``next_done``.

    Args:
        rewards, values, dones: tensors indexed by step along dim 0.
        next_done: done flag following the last step.
        next_value: value estimate following the last step.
        num_steps: number of steps to process (indices ``0 .. num_steps - 1``).
        gamma: discount factor.
        gae_lambda: mixing coefficient applied to the carried advantage.

    Returns:
        ``(returns, advantages)`` with ``returns = advantages + values``.
    """
    advantages = torch.zeros_like(rewards)
    running_adv = 0
    final_step = num_steps - 1
    for step in range(final_step, -1, -1):
        # The step after the final one is described by the explicit next_* inputs.
        at_end = step == final_step
        upcoming_done = next_done if at_end else dones[step + 1]
        upcoming_value = next_value if at_end else values[step + 1]
        alive = torch.logical_not(upcoming_done)
        td_error = rewards[step] + gamma * upcoming_value * alive - values[step]
        running_adv = td_error + gamma * gae_lambda * alive * running_adv
        advantages[step] = running_adv
    return advantages + values, advantages


@torch.no_grad()
def fast_estimate_returns_and_advantages(
    rewards: Tensor,
    values: Tensor,
    dones: Tensor,
    next_done: Tensor,
    next_value: Tensor,
    gamma: float,
    gae_lambda: float,
) -> Tensor:
    """Loop-free advantage estimate built from cumulative sums along dim 0.

    Vectorised counterpart of ``estimate_returns_and_advantages``. Despite the
    name it returns a single tensor (the one the notebook compares with the
    reference advantages), not a ``(returns, advantages)`` pair.

    Args:
        rewards, values, dones: tensors indexed by step along dim 0, either 2-D
            or 3-D (the 3-D case treats dim 1 as independent columns). ``dones``
            must be numeric, since ``1 - dones`` is evaluated.
        next_done: done flag following the last step.
        next_value: value estimate following the last step.
        gamma: discount factor.
        gae_lambda: mixing coefficient for the accumulated TD errors.

    Returns:
        Tensor shaped like ``rewards`` with the estimated advantages.

    Raises:
        ValueError: if ``rewards`` is neither 2-D nor 3-D.
    """
    # Per-step exponent for (gamma * gae_lambda): a counter that restarts after
    # every done flag, computed independently for each column along dim 1.
    if len(rewards.shape) == 3:
        t_steps = torch.cat(
            [
                conditional_arange(values.shape[0], dones[:, dim, :].view(-1)).view(-1, 1)
                for dim in range(rewards.shape[1])
            ],
            dim=1,
        )
    elif len(rewards.shape) == 2:
        t_steps = conditional_arange(values.shape[0], dones.view(-1)).view(-1, 1)
    else:
        raise ValueError("Shape must be 2 or 3 dimensional")
    gt = torch.pow(gamma * gae_lambda, t_steps.view_as(dones))

    # Shift values/dones one step back in time; torch.roll returns new tensors, so
    # writing the explicit next_* inputs into the last slot does not touch the inputs.
    next_values = torch.roll(values, -1, dims=0)
    next_values[-1] = next_value
    next_dones = torch.roll(dones, -1, dims=0)
    next_dones[-1] = next_done
    deltas = rewards + gamma * next_values * (1 - next_dones) - values

    # Reverse-time cumulative sum of the discounted TD errors, with the total
    # reached at the latest done flag carried along so it can be subtracted.
    cs = torch.flipud(deltas * gt).cumsum(dim=0)
    acc = torch.cummax(torch.flipud(dones) * cs, 0)[0]
    acc[0] = 0
    adv = torch.flipud(cs - acc) / gt

    # Correction applied at done steps. The last step is excluded because
    # adv.roll(-1, 0) would wrap round to step 0 there. The exclusion is done on a
    # private copy: writing `dones[-1] = 0` into the argument itself would silently
    # change the caller's tensor for every later use.
    done_mask = dones.clone()
    done_mask[-1] = 0
    return adv + done_mask * (deltas + gamma * gae_lambda * adv.roll(-1, 0))

In [None]:
num_steps = 256
batch_size = 2
# Inputs use a (num_steps, batch_size, 1) layout. The vmap cells further down map
# over dim 1 of every tensor and then index entry 1 along that dim, which requires
# this batch dimension; previously batch_size was defined but never used.
rewards = torch.rand(num_steps, batch_size, 1).tanh()
values = torch.rand(num_steps, batch_size, 1)
dones = torch.zeros(num_steps, batch_size, 1)
# Flag the first and the last step of batch entry 0 only; entry 1 has no flags.
dones[0, 0] = 1.0
dones[-1, 0] = 1.0
# Done flag / value for the step following the rollout, with the batch on dim 1
# so that the same in_dims=1 mapping applies to them as well.
next_done = torch.zeros(1, batch_size, 1) + 1
next_value = torch.rand(1, batch_size, 1)

In [None]:
# Reference result from the step-by-step loop. `r` is (returns, advantages),
# so r[1] shown below is the advantages, flattened for display.
r = estimate_returns_and_advantages(rewards, values, dones, next_done, next_value, num_steps, gamma, gae_lambda)
r[1].view(-1)

In [None]:
# Same inputs through the loop-free variant, which returns a single tensor.
fr = fast_estimate_returns_and_advantages(rewards, values, dones, next_done, next_value, gamma, gae_lambda)
fr.view(-1)

In [None]:
# The loop-free result must match the reference advantages (assert_close default tolerances).
torch.testing.assert_close(r[1], fr)

In [None]:
from torch import func as fc

In [None]:
# Vectorise the loop-free estimator over dim 1 of every tensor argument; gamma and
# gae_lambda (in_dims=None) are shared, and the mapped dim is put back at dim 1.
v = fc.vmap(fast_estimate_returns_and_advantages, in_dims=(1, 1, 1, 1, 1, None, None), out_dims=1)

In [None]:
# Run the vmapped estimator on the same inputs as the two cells above.
vr = v(rewards, values, dones, next_done, next_value, gamma, gae_lambda)

In [None]:
# Compare the vmapped result with the reference advantages for entry 1 along dim 1.
# NOTE(review): this indexing needs 3-D tensors with at least two entries on dim 1,
# and entry 0 is not checked here.
torch.testing.assert_close(r[1][:, 1, :], vr[:, 1, :])

In [None]:
# Minimal demo of the reset-counter trick: eight ones and a mask flagging positions 0, 3 and 6.
a = torch.ones(8)
m = torch.tensor([1, 0, 0, 1, 0, 0, 1, 0])

In [None]:
# Running count 1..8 over all positions; rebinds `cs`.
cs = a.cumsum(0)
cs

In [None]:
# Count reached at the most recent flagged position, carried forward by the running max.
torch.cummax(m * cs, 0)[0]

In [None]:
# Subtracting the carried count gives a counter that is 0 at every flagged position
# and increases by one on each following step.
cs - torch.cummax(m * cs, 0)[0]

In [None]:
import torch
import torch.nn as nn
from tensordict.nn import dispatch
from tensordict import TensorDict


class MyModule(nn.Module):
    """Toy module used to try out ``tensordict.nn.dispatch`` on two methods.

    ``forward`` uses the bare decorator together with the class-level key lists,
    while ``test`` passes its key lists to the decorator explicitly.
    """

    # Key lists for the bare ``@dispatch`` on ``forward`` — presumably picked up
    # through the decorator's default attribute names; confirm against tensordict.
    in_keys = ["a"]
    out_keys = ["b"]

    @dispatch
    def forward(self, tensordict, c, d=1, e=2):
        """Print the extra arguments, write ``b = a + 1`` into ``tensordict`` and return it."""
        print("c:", c, "d:", d, "e:", e)
        tensordict["b"] = tensordict["a"] + 1
        return tensordict

    @dispatch(source=["c"], dest=["d"])
    def test(self, tensordict, a=42, b=None):
        """Print ``a`` and ``b``, write ``d = c + 1`` into ``tensordict``; return ``(tensordict, a)``."""
        print(a, b)
        tensordict["d"] = tensordict["c"] + 1
        return tensordict, a

In [None]:
# Instantiate the toy module; rebinds `m`, which held a mask tensor in the cells above.
m = MyModule()

In [None]:
# Call forward with a TensorDict holding key "a" plus the extra keyword arguments c, d and e.
a = TensorDict({"a": torch.zeros(16, 1)}, batch_size=[16])
o = m(a, c="abracadabra", d=None, e=4)

In [None]:
# Call the second dispatched method with a TensorDict holding key "c".
# The variable `a` (a TensorDict) is passed positionally while the keyword a=23
# fills the method's own `a` parameter.
a = TensorDict({"c": torch.zeros(16, 1)}, batch_size=[16])
o = m.test(a, a=23)