In [None]:
import torch
from scipy.linalg import circulant
from collections import deque

def compute_gae(
    next_value: list,
    rewards: list,
    masks: list,
    values: list,
    gamma: float,
    tau: float,
):
    """Reference (loop-based) Generalized Advantage Estimation.

    Walks the trajectory backwards, accumulating the discounted sum of
    TD residuals, and returns the advantage plus the value estimate for
    every step (i.e. the GAE *returns*, not the raw advantages).

    Args:
        next_value: bootstrap value for the state after the last step. It is
            appended to ``values`` as a single element.
        rewards: per-step rewards; its length defines the number of steps.
        masks: per-step continuation factors that multiply both the
            bootstrapped next value and the running advantage.
        values: per-step value estimates (must be a list; it is concatenated
            with ``[next_value]``).
        gamma: discount factor.
        tau: GAE smoothing factor.

    Returns:
        A list with one entry per step, ordered from first step to last.
    """
    # New list; the caller's ``values`` is not mutated.
    bootstrapped = values + [next_value]

    running_advantage = 0
    collected = []  # filled last-step-first, reversed before returning

    for t in range(len(rewards) - 1, -1, -1):
        td_residual = rewards[t] + gamma * bootstrapped[t + 1] * masks[t] - bootstrapped[t]
        running_advantage = td_residual + gamma * tau * masks[t] * running_advantage
        collected.append(running_advantage + bootstrapped[t])

    collected.reverse()
    return collected

def compute_gae_test(values: torch.Tensor, next_value: torch.Tensor, rewards: torch.Tensor, dones: torch.Tensor, gamma: float, tau: float):
    """Vectorised GAE returns (advantage + value) without a Python loop.

    Builds an upper-triangular coefficient matrix whose entry ``[i, j]`` is
    ``(gamma * tau) ** (j - i)`` and applies it to the TD residuals. An entry
    is zeroed when any step in ``[i, j)`` has a done flag, so advantages do
    not leak across episode boundaries. This matches the ``masks`` factor in
    the loop-based recurrence of ``compute_gae``.

    Args:
        values: value estimates, time along dim 0; trailing dims (if any) are
            treated as independent trajectories.
        next_value: value of the state following each step, broadcastable to
            ``values`` (the notebook passes ``values`` shifted by one step
            with the bootstrap value appended).
        rewards: per-step rewards, broadcastable to ``values``.
        dones: per-step termination flags; assumed to be 0/1 (any non-zero
            entry is treated as "done" when cutting the accumulation).
        gamma: discount factor.
        tau: GAE smoothing factor.

    Returns:
        Tensor shaped like the TD residuals holding advantage + value per step.
    """
    L = values.shape[0]
    device = values.device
    delta = rewards + gamma * (1 - dones) * next_value - values

    # offsets[i, j] = j - i; only j >= i (current and future steps) contribute.
    steps = torch.arange(L, device=device)
    offsets = steps.unsqueeze(0) - steps.unsqueeze(1)
    causal = offsets >= 0
    decay = torch.pow(
        torch.tensor(gamma * tau, dtype=delta.dtype, device=device),
        offsets.clamp(min=0).to(delta.dtype),
    )

    # done_before[t] counts done flags strictly before step t, so
    # done_before[j] - done_before[i] is the number of dones in [i, j).
    done_flags = (dones != 0).to(torch.long).expand_as(delta)
    done_before = done_flags.cumsum(dim=0) - done_flags
    crossed = done_before.unsqueeze(0) - done_before.unsqueeze(1)

    # Broadcast the (L, L) time matrices over any trailing batch dims.
    extra_dims = (1,) * (delta.dim() - 1)
    alive = (crossed == 0) & causal.view(L, L, *extra_dims)
    coef = decay.view(L, L, *extra_dims) * alive.to(delta.dtype)

    # advantages[i] = sum_j coef[i, j] * delta[j]
    advantages = (coef * delta.unsqueeze(0)).sum(dim=1)
    return advantages + values

In [None]:
# Draw the random fixtures from a locally seeded generator so that the
# printed inputs (and the results computed from them) are reproducible
# across kernel restarts without touching the global RNG state.
fixture_rng = torch.Generator().manual_seed(0)

next_value = torch.rand(1, generator=fixture_rng)
rewards = torch.rand(4, 1, generator=fixture_rng)
# Four steps; only the last one is flagged as done.
dones = torch.tensor([[0],[0],[0],[1]])
values = torch.rand(4, 1, generator=fixture_rng)
print(next_value)
print(rewards)
print(dones)
print(values)

In [None]:
# Loop-based reference: tensors are flattened to plain Python values, and
# the continuation masks are the complement of the done flags.
print(
    compute_gae(
        next_value=next_value.squeeze().tolist(),
        rewards=rewards.squeeze().tolist(),
        masks=(1 - dones).squeeze().tolist(),
        values=values.squeeze().tolist(),
        gamma=0.99,
        tau=0.9,
    )
)

In [None]:
# Vectorised version on the same inputs. The per-step "next value" is
# `values` with the bootstrap value appended, keeping the last four rows.
print(
    compute_gae_test(
        values=values,
        next_value=torch.cat([values, next_value.unsqueeze(0)], dim=0)[-4:],
        rewards=rewards,
        dones=dones,
        gamma=0.99,
        tau=0.9,
    )
)

In [1]:
import torch

# Inspect the top-level structure of a saved training checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
# map_location="cpu" lets the keys be inspected on a machine without the
# device the checkpoint was saved from.
ckpt = torch.load(
    "output/sarcoma+mvalues+no_invalid+freeze+obj_ptr+double_dqn+vanilla/2025-11-19-10-19-13/epoch_20.pth",
    map_location="cpu",
)
print(ckpt.keys())

dict_keys(['model', 'q_agent'])
