In [25]:
from tensordict import TensorDict
import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import TDLambdaEstimator



# Example from https://pytorch.org/tensordict/stable/reference/generated/tensordict.nn.TensorDictModule.html

# from tensordict import TensorDict
# value_net = TensorDictModule(
#     nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"]
# )
# module = TDLambdaEstimator(
#     gamma=0.98,
#     lmbda=0.94,
#     value_network=value_net,
# )
# obs, next_obs = torch.randn(2, 1, 10, 3)
# reward = torch.randn(1, 10, 1)
# done = torch.zeros(1, 10, 1, dtype=torch.bool)
# terminated = torch.zeros(1, 10, 1, dtype=torch.bool)
# tensordict = TensorDict({"obs": obs, "next": {"obs": next_obs, "done": done, "reward": reward, "terminated": terminated}}, [1, 10])
# _ = module(tensordict)
# assert "advantage" in tensordict.keys()



# Minimal TD(lambda) advantage-estimation demo on synthetic data.
# Sizes are named once so the value network's input width and the mock data
# cannot drift apart (they were previously repeated as literals).
OBS_DIM = 3  # features per observation; must equal the value net's input size
BATCH_SHAPE = (1, 10)  # TensorDict batch_size; every entry below has these leading dims
SEED = 0

# Fix the RNG so the network init and mock data (and thus the printed values)
# are the same on every run / kernel restart.
torch.manual_seed(SEED)

# Value network: reads the "obs" entry and writes a 1-element "state_value".
value_net = TensorDictModule(
    nn.Linear(OBS_DIM, 1), in_keys=["obs"], out_keys=["state_value"]
)

# TD(lambda) estimator. Per TorchRL's TDLambdaEstimator docs, `gamma` is the
# discount factor and `lmbda` the TD(lambda) weighting coefficient.
module = TDLambdaEstimator(
    gamma=0.98,
    lmbda=0.94,
    value_network=value_net,
)

# Mock data: a single randn call produces both current and next observations.
obs, next_obs = torch.randn(2, *BATCH_SHAPE, OBS_DIM)
reward = torch.randn(*BATCH_SHAPE, 1)
# All-False flags: no step in this mock data is marked done or terminated.
done = torch.zeros(*BATCH_SHAPE, 1, dtype=torch.bool)
terminated = torch.zeros(*BATCH_SHAPE, 1, dtype=torch.bool)

# "obs" is the observation fed to value_net (its in_key). The nested "next"
# entry holds the following observation plus the reward and done/terminated
# flags associated with each step.
tensordict = TensorDict(
    {
        "obs": obs,
        "next": {
            "obs": next_obs,
            "done": done,
            "reward": reward,
            "terminated": terminated,
        },
    },
    batch_size=BATCH_SHAPE,
)

# Run the estimator; it writes its results into `tensordict` in place.
# (Not assigned to `_`, which would clobber IPython's output-history variable.)
module(tensordict)

# TensorDict does not support `key in td` directly, hence `.keys()`.
assert "advantage" in tensordict.keys()

print(tensordict["state_value"])  # value_net's output for each step's "obs"
print(tensordict["obs"])  # the (seeded random) observations fed to value_net


tensor([[[-0.1888],
         [-1.0039],
         [-0.2972],
         [ 0.3838],
         [-0.0213],
         [ 1.1974],
         [ 0.5711],
         [-0.0019],
         [ 0.4399],
         [-0.4321]]])
tensor([[[-0.2702, -0.5811, -0.3897],
         [-0.8238,  1.1683, -2.2404],
         [ 0.7017,  0.0621, -1.4522],
         [-0.8791,  0.3105,  0.9278],
         [ 0.1898, -0.4703, -0.3810],
         [ 1.3271, -0.0940,  1.2741],
         [ 0.5811,  0.6501,  0.2461],
         [ 0.4324, -1.7821, -0.1039],
         [-0.6848, -0.6500,  1.2065],
         [-0.0964, -0.0170, -1.1792]]])
