In [1]:
import torch
import time

from torchrl.envs import GymEnv, StepCounter, TransformedEnv
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torchrl.modules import EGreedyModule, MLP, QValueModule
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torch.optim import Adam
from torchrl.objectives import DQNLoss, SoftUpdate
from torchrl._utils import logger as torchrl_logger
from torchrl.record import CSVLogger, VideoRecorder
from torchrl.modules import QValueActor
from torchrl.data import CompositeSpec

torch.manual_seed(0)


<torch._C.Generator at 0x7f12d5f00030>

In [2]:
# Define the environment
env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
env.set_seed(0)



795726461

In [3]:
class MICOMLPNetwork(torch.nn.Module):
    def __init__(self,
                 in_features,
                 activation_class, 
                 encoder_out_features,
                 mlp_out_features,
                 encoder_num_cells = None,
                 mlp_num_cells = None):
        super(MICOMLPNetwork, self).__init__()

        self.activation = activation_class()

        if encoder_num_cells is None:
            encoder_num_cells = []
        layers_sizes = [in_features] + encoder_num_cells + [encoder_out_features]

        self.encoder = torch.nn.ModuleList()
        for i in range(len(layers_sizes) - 1):
            self.encoder.append(torch.nn.Linear(layers_sizes[i], layers_sizes[i+1]))

        if mlp_num_cells is None:
            mlp_num_cells = []

        layers_sizes = [encoder_out_features] + mlp_num_cells + [mlp_out_features.item()]

        self.q_net = torch.nn.ModuleList()
        for i in range(len(layers_sizes) - 1):
            self.q_net.append(torch.nn.Linear(layers_sizes[i], layers_sizes[i+1]))
        
    
    def forward(self, x):
        for i in range(len(self.encoder)):
            x = self.activation(self.encoder[i](x))

        representation = x

        for i in range(len(self.q_net)-1):
            x = self.activation(self.q_net[i](x))

        return self.q_net[-1](x), representation

In [4]:
value_mlp = MICOMLPNetwork(
    in_features=4,
    activation_class=torch.nn.ReLU,
    encoder_out_features=3,
    mlp_out_features=env.action_spec.shape[-1],
    encoder_num_cells=[64],
    mlp_num_cells=[64]
)

value_mlp

MICOMLPNetwork(
  (activation): ReLU()
  (encoder): ModuleList(
    (0): Linear(in_features=4, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=3, bias=True)
  )
  (q_net): ModuleList(
    (0): Linear(in_features=3, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=2, bias=True)
  )
)

In [5]:
import numpy as np

# Specifications
num_observations = 4
cart_position_min = -4.8
cart_position_max = 4.8
cart_velocity_min = -np.inf
cart_velocity_max = np.inf
pole_angle_min = -0.418
pole_angle_max = 0.418
pole_angular_velocity_min = -np.inf
pole_angular_velocity_max = np.inf

def create_batched_random_tensor(n):
    # Creating the batched random tensor
    cart_position = np.random.uniform(cart_position_min, cart_position_max, size=n)
    cart_velocity = np.random.normal(loc=0.0, scale=10.0, size=n)  # Assuming normal distribution with large std deviation
    pole_angle = np.random.uniform(pole_angle_min, pole_angle_max, size=n)
    pole_angular_velocity = np.random.normal(loc=0.0, scale=10.0, size=n)  # Assuming normal distribution with large std deviation

    # Combining into a single tensor of shape (n, 4)
    batched_tensor = np.stack((cart_position, cart_velocity, pole_angle, pole_angular_velocity), axis=-1)
    
    return batched_tensor

# Example usage with batch size n = 5
n = 100
batched_tensor = torch.tensor(create_batched_random_tensor(n), dtype=torch.float32)
print(batched_tensor)

tensor([[-4.0787e+00, -1.0307e+01, -2.3199e-01,  9.0902e+00],
        [-3.3921e+00, -4.1854e+00,  6.8255e-02,  2.4537e+01],
        [-9.2820e-01,  8.3175e+00,  2.8329e-01, -1.2202e+01],
        [ 4.1594e+00, -9.1213e+00,  1.6487e-01, -5.2256e+00],
        [-9.8915e-01,  7.0261e+00,  2.7696e-01,  1.5996e+01],
        [-2.2121e+00,  9.5004e+00, -2.2218e-01,  1.5953e+01],
        [-4.1936e+00,  6.0429e+00,  8.0565e-02,  3.2016e+00],
        [-5.9195e-01,  8.8343e-01, -3.4044e-01, -3.9493e+00],
        [-1.0859e+00, -1.0469e+01, -1.0958e-01, -2.9635e+01],
        [-2.8919e+00,  5.4180e+00,  3.4800e-01,  8.0432e-01],
        [-4.2849e+00,  1.6165e+01, -1.0406e-01,  3.1885e+00],
        [ 3.6752e+00,  5.1540e-01, -2.2495e-01,  3.2402e+00],
        [-3.5336e+00, -7.8563e+00, -3.3527e-01, -6.3229e+00],
        [ 3.8183e+00,  1.2409e+01, -9.1591e-02,  1.3265e+01],
        [ 4.3202e+00, -1.8075e+01, -1.6755e-01, -1.6818e+00],
        [ 4.4210e+00,  6.1093e-01,  8.4307e-02, -1.8480e+01],
        

In [6]:
# import torch.nn as nn
# import torch.optim as optim

# # Initialize the network
# # network = MICOMLPNetwork(in_features=4,
# #                          activation_class=torch.nn.ReLU, 
# #                          encoder_out_features=8,
# #                          mlp_out_features=1,
# #                          encoder_num_cells=[16],
# #                          mlp_num_cells=[8])

# network = MICOMLPNetwork(
#     in_features=4,
#     activation_class=torch.nn.ReLU,
#     encoder_out_features=2,
#     mlp_out_features=env.action_spec.shape[-1],
#     encoder_num_cells=[128, ],
#     mlp_num_cells=[64]
# )

# # Define dummy target tensors for losses
# target_representation = torch.randn(100, 2)  # Assuming the representation has 8 features
# target_q_values = torch.randn(100, 2)  # Assuming the Q-values have 1 feature

# # Define loss functions
# criterion_representation = nn.MSELoss()
# criterion_q_values = nn.MSELoss()

# # Define an optimizer
# optimizer = optim.Adam(network.parameters(), lr=0.001)

# # Forward pass
# q_values, representation = network(batched_tensor)

# # Compute the losses
# loss_representation = criterion_representation(representation, target_representation)
# loss_q_values = criterion_q_values(q_values, target_q_values)

# # Sum the losses
# total_loss = loss_representation + loss_q_values

# # Backward pass and optimization
# optimizer.zero_grad()
# total_loss.backward()
# optimizer.step()

# # Print losses
# print(f"Total Loss: {total_loss.item()}, Loss Representation: {loss_representation.item()}, Loss Q-Values: {loss_q_values.item()}")

In [7]:
# # Print the gradients of the total_loss
# for name, param in network.named_parameters():
#     print(name, param.grad)

In [8]:
# Print the maximum of the model parameters
def print_maximum_weights(model):
    print("Maximum of the model parameters per layer")#
    weights = []
    for p in model.parameters():
        print(torch.max(p))
        weights.append(torch.max(p))
    return weights

def print_maximum_grads(model):
    print("Maximum of the model gradients per layer")#
    max_grads = []
    for p in model.parameters():
        if p.grad is not None:
            print(torch.max(p.grad))
            max_grads.append(torch.max(p.grad))
    return max_grads

def print_target_value_weights(loss_module):
    with loss_module.target_value_network_params.to_module(loss_module.value_network):
        return print_maximum_weights(loss_module.value_network)    

def print_value_weights(loss_module):
    with loss_module.value_network_params.to_module(loss_module.value_network):
        return print_maximum_weights(loss_module.value_network)
        # print_maximum_grads(loss_module.value_network)    

def print_value_grads(loss_module):
    with loss_module.value_network_params.to_module(loss_module.value_network):
        # print_maximum_weights(loss_module.value_network)
        return print_maximum_grads(loss_module.value_network)   

# Print the maximum of the model parameters
print_target_value_weights(loss_module)

print_value_weights(loss_module)

NameError: name 'loss_module' is not defined

In [None]:
# toy_example = torch.tensor(batched_tensor, dtype=torch.float32)
# q_values, representation = value_mlp(toy_example)
# print(q_values)
# print(representation)

tensor([[-0.0592,  0.1763],
        [ 0.0442,  0.2388],
        [-0.0479,  0.1664],
        [-0.3772,  0.2997],
        [-0.0151,  0.1565],
        [ 0.4427, -0.1130],
        [ 0.0363,  0.2191],
        [ 0.1308,  0.0657],
        [ 0.4101, -0.0557],
        [-0.2128,  0.2514],
        [-0.1447,  0.2496],
        [ 0.1500,  0.0343],
        [ 0.0926,  0.2491],
        [-0.4388,  0.4532],
        [ 0.2027,  0.0300],
        [ 0.3296, -0.0926],
        [ 0.1172,  0.2605],
        [ 0.1412,  0.0377],
        [-0.0290,  0.2214],
        [-0.1543,  0.2412],
        [ 0.7057, -0.3148],
        [ 0.2393,  0.1276],
        [ 0.2538,  0.0212],
        [ 0.1700,  0.0091],
        [-0.0832,  0.2328],
        [ 0.4549, -0.0923],
        [-0.0784,  0.1776],
        [-0.5939,  0.5909],
        [-0.0076,  0.2350],
        [-0.3963,  0.3171],
        [ 0.0748,  0.2182],
        [ 0.1462,  0.0295],
        [-0.0683,  0.2520],
        [ 0.3068, -0.0792],
        [-1.2523,  0.9355],
        [ 0.8156, -0

In [None]:
value_net = Mod(value_mlp, 
                in_keys=["observation"], 
                out_keys=["action_value", "representation"])
value_net

TensorDictModule(
    module=MICOMLPNetwork(
      (activation): ReLU()
      (encoder): ModuleList(
        (0): Linear(in_features=4, out_features=64, bias=True)
        (1): Linear(in_features=64, out_features=3, bias=True)
      )
      (q_net): ModuleList(
        (0): Linear(in_features=3, out_features=64, bias=True)
        (1): Linear(in_features=64, out_features=2, bias=True)
      )
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['action_value', 'representation'])

In [None]:
# policy = Seq(value_net, 
#              QValueModule(spec=env.action_spec))
# policy

policy = QValueActor(
    module=value_net,
    spec=CompositeSpec(action= env.specs["input_spec", "full_action_spec", "action"]),
    in_keys=["observation"],
)
policy

QValueActor(
    module=ModuleList(
      (0): TensorDictModule(
          module=MICOMLPNetwork(
            (activation): ReLU()
            (encoder): ModuleList(
              (0): Linear(in_features=4, out_features=64, bias=True)
              (1): Linear(in_features=64, out_features=3, bias=True)
            )
            (q_net): ModuleList(
              (0): Linear(in_features=3, out_features=64, bias=True)
              (1): Linear(in_features=64, out_features=2, bias=True)
            )
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['action_value', 'representation'])
      (1): QValueModule()
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['representation', 'action', 'action_value', 'chosen_action_value'])

In [None]:
# Define the exploration step (e-greedy policy)
exploration_module = EGreedyModule(
    env.action_spec, 
    annealing_num_steps=100_000, 
    eps_init=0.1,
)
policy_explore = Seq(policy, 
                     exploration_module)
policy_explore

TensorDictSequential(
    module=ModuleList(
      (0): QValueActor(
          module=ModuleList(
            (0): TensorDictModule(
                module=MICOMLPNetwork(
                  (activation): ReLU()
                  (encoder): ModuleList(
                    (0): Linear(in_features=4, out_features=64, bias=True)
                    (1): Linear(in_features=64, out_features=3, bias=True)
                  )
                  (q_net): ModuleList(
                    (0): Linear(in_features=3, out_features=64, bias=True)
                    (1): Linear(in_features=64, out_features=2, bias=True)
                  )
                ),
                device=cpu,
                in_keys=['observation'],
                out_keys=['action_value', 'representation'])
            (1): QValueModule()
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['representation', 'action', 'action_value', 'chosen_action_value'])
      (1): EGreedyModule()
   

In [None]:
# Define how to collect the data (experiences)
init_rand_steps = 5000 # warm-up steps
frames_per_batch = 100
optim_steps = 10
replay_capacity = 100_000

# NOTE: collector will gather rollouts continously
# If the current trajectory ends, it will start a new one
# NOTE: the rollout gotten from the collector is a dictionary
# that defines the sate and next state as a tensor with a batch dimension in the begining
# for example a rollout of 10 steps will have a tensor of observation of 10 in the batch dimension
# and the next will also have 10 which are all the tensors of the next state
# Practically, next is as you will shift the tensor of observation by one step
# collector = SyncDataCollector(
#     env,
#     policy_explore,
#     frames_per_batch=frames_per_batch,
#     total_frames=500_100,
#     init_random_frames=init_rand_steps,
# )
# rb = ReplayBuffer(storage=LazyTensorStorage(replay_capacity))

In [None]:
# Define the recording and logging
path = "./training_loop"
logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4")
video_recorder = VideoRecorder(logger, tag="video")
record_env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder
)



In [None]:
# collector = SyncDataCollector(
#     env,
#     policy_explore,
#     frames_per_batch=10,
#     total_frames=500_100,
#     init_random_frames=10000,
# )

collector = SyncDataCollector(
    create_env_fn=env,
    policy=policy_explore,
    frames_per_batch=10,
    total_frames=100,
    device="cpu",
    storing_device="cpu",
    max_frames_per_traj=-1
)
# NOTE: IMPORTANTISIMO en las primeras iteraciones no se usa la policy, entonces representation se configura
# a zero, por lo que el primer batch de datos no tiene representation
# Tengo que hacer el warm-up de otra manera (ojo con esto)

for data in collector:
    print(data['representation'])
    break

tensor([[0.1717, 0.1063, 0.0000],
        [0.1425, 0.0952, 0.0061],
        [0.0912, 0.0871, 0.0460],
        [0.0490, 0.0752, 0.0896],
        [0.0580, 0.0906, 0.1324],
        [0.0857, 0.0904, 0.1998],
        [0.1152, 0.0904, 0.2713],
        [0.1495, 0.0832, 0.3370],
        [0.1411, 0.0853, 0.3108],
        [0.1740, 0.0766, 0.3788]])


In [None]:
data['representation']

tensor([[0.1717, 0.1063, 0.0000],
        [0.1425, 0.0952, 0.0061],
        [0.0912, 0.0871, 0.0460],
        [0.0490, 0.0752, 0.0896],
        [0.0580, 0.0906, 0.1324],
        [0.0857, 0.0904, 0.1998],
        [0.1152, 0.0904, 0.2713],
        [0.1495, 0.0832, 0.3370],
        [0.1411, 0.0853, 0.3108],
        [0.1740, 0.0766, 0.3788]])

In [None]:
from torchrl.data import SliceSampler
from torchrl.data import TensorDictReplayBuffer

size = 100
rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=SliceSampler(traj_key=("collector","traj_ids"), slice_len=2),
    batch_size=10,
)
rb

TensorDictReplayBuffer(
    storage=LazyTensorStorage(
        data=<empty>, 
        shape=None, 
        len=0, 
        max_size=100), 
    sampler=SliceSampler(num_slices=None, slice_len=2, end_key=('next', 'done'), traj_key=('collector', 'traj_ids'), truncated_key=('next', 'truncated'), strict_length=True), 
    writer=TensorDictRoundRobinWriter(cursor=0, full_storage=False), 
    batch_size=10, 
    collate_fn=<function _collate_id at 0x7fe9fa97ba60>)

In [None]:
data['collector','traj_ids']

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [None]:
rb.extend(data)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [None]:
sample = rb.sample(10)

In [None]:
sample

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10

In [None]:
sample['step_count']

tensor([[7],
        [8],
        [0],
        [1],
        [7],
        [8],
        [0],
        [1],
        [5],
        [6]])

In [None]:
sample["observation"]

tensor([[ 0.1190,  1.4084, -0.1088, -2.0614],
        [ 0.1472,  1.2145, -0.1500, -1.8043],
        [ 0.0313,  0.0413,  0.0107,  0.0229],
        [ 0.0322,  0.2362,  0.0111, -0.2663],
        [ 0.1190,  1.4084, -0.1088, -2.0614],
        [ 0.1472,  1.2145, -0.1500, -1.8043],
        [ 0.0313,  0.0413,  0.0107,  0.0229],
        [ 0.0322,  0.2362,  0.0111, -0.2663],
        [ 0.0745,  1.0168, -0.0451, -1.4403],
        [ 0.0948,  1.2125, -0.0739, -1.7467]])

In [None]:
sample["next", "observation"]

tensor([[ 0.1472,  1.2145, -0.1500, -1.8043],
        [ 0.1715,  1.4110, -0.1861, -2.1396],
        [ 0.0322,  0.2362,  0.0111, -0.2663],
        [ 0.0369,  0.4312,  0.0058, -0.5555],
        [ 0.1472,  1.2145, -0.1500, -1.8043],
        [ 0.1715,  1.4110, -0.1861, -2.1396],
        [ 0.0322,  0.2362,  0.0111, -0.2663],
        [ 0.0369,  0.4312,  0.0058, -0.5555],
        [ 0.0948,  1.2125, -0.0739, -1.7467],
        [ 0.1190,  1.4084, -0.1088, -2.0614]])

In [None]:
sample['representation']

tensor([[0.1495, 0.0832, 0.3370],
        [0.1411, 0.0853, 0.3108],
        [0.1717, 0.1063, 0.0000],
        [0.1425, 0.0952, 0.0061],
        [0.1495, 0.0832, 0.3370],
        [0.1411, 0.0853, 0.3108],
        [0.1717, 0.1063, 0.0000],
        [0.1425, 0.0952, 0.0061],
        [0.0857, 0.0904, 0.1998],
        [0.1152, 0.0904, 0.2713]])

In [None]:
sample

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([10, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([10

In [None]:
sample['representation']

first_states = sample[0::2] # even rows
second_states = sample[1::2] # odd rows (or next states)

print(first_states)
print(second_states)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([5]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 4]), de

In [None]:
first_states['representation']

tensor([[0.1495, 0.0832, 0.3370],
        [0.1717, 0.1063, 0.0000],
        [0.1495, 0.0832, 0.3370],
        [0.1717, 0.1063, 0.0000],
        [0.0857, 0.0904, 0.1998]])

In [None]:
second_states['representation']

tensor([[0.1411, 0.0853, 0.3108],
        [0.1425, 0.0952, 0.0061],
        [0.1411, 0.0853, 0.3108],
        [0.1425, 0.0952, 0.0061],
        [0.1152, 0.0904, 0.2713]])

In [None]:
import torch

a = torch.tensor([1, 2, 3])
repeated_a = torch.Tensor.repeat(a, (2, 3))

print(repeated_a)

a = torch.tensor([1, 2, 3])
tiled_a = torch.tile(a, (2, 3))
print(tiled_a)

tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3]])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3]])


In [None]:
second_states['representation']

tensor([[0.1411, 0.0853, 0.3108],
        [0.1425, 0.0952, 0.0061],
        [0.1411, 0.0853, 0.3108],
        [0.1425, 0.0952, 0.0061],
        [0.1152, 0.0904, 0.2713]])

In [None]:
second_states['representation'].shape # batch, rep_dim

repeated_rep = torch.tile(second_states['representation'], (1,1,5)).view(5,5,3)
repeated_rep

tensor([[[0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108]],

        [[0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061]],

        [[0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108]],

        [[0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061]],

        [[0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713]]])

In [None]:
second_states['representation'].shape

torch.Size([5, 3])

In [None]:
def squarify(x):
    # Squarify will take the input and adds a new dimension between the batch and the representation
    # so that the representation is repeated along the new dimension
    # To visualize thing of x as a matrix of batch_size x representation_dim
    # and squarify will place that matrix in a lateral way and repeat it along the new dimension j

    # NOTE: after squarify if you pick a i-th row all the elements (j-th index) in that row will be the same
    batch_size = x.shape[0]
    if len(x.shape) > 1:
        representation_dim = x.shape[-1]
        return x.tile((batch_size,)).view(batch_size, batch_size, representation_dim)
    return x.tile((batch_size,)).view(batch_size, batch_size)

In [None]:
squarify(second_states['next','reward']).squeeze(-1)

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [None]:
squarify(second_states['representation'])

tensor([[[0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108]],

        [[0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061]],

        [[0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108],
         [0.1411, 0.0853, 0.3108]],

        [[0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061],
         [0.1425, 0.0952, 0.0061]],

        [[0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713],
         [0.1152, 0.0904, 0.2713]]])

In [None]:
def representation_distances(first_representations, second_representations,
                             distance_fn, beta=0.1,
                             return_distance_components=False):
  """Compute distances between representations.
     In the paper, it corresponds to the calculation of the U term
     for each pair of representations in the batch (all-vs-all).

  This will compute the distances between two representations.

  Args:
    first_representations: first set of representations to use.
    second_representations: second set of representations to use.
    distance_fn: function to use for computing representation distances.
    beta: float, weight given to cosine distance between representations.
    return_distance_components: bool, whether to return the components used for
      computing the distance.

  Returns:
    The distances between representations, combining the average of the norm of
    the representations and the distance given by distance_fn.
  """
  batch_size = first_representations.shape[0]
  representation_dim = first_representations.shape[-1]

  # Squarify the representations and reshape them to make a pair-waise comparison with vmap
  first_squared_reps = squarify(first_representations)
  first_squared_reps = torch.reshape(first_squared_reps,
                                   [batch_size**2, representation_dim])
  
  # Squarify the representations and reshape them to make a pair-waise comparison with vmap
  # However, we now need to permute (transpose) the dimension 0, 1 to alternate the values
  # so that we have the pair-wise comparisons of all-vs-all
  second_squared_reps = squarify(second_representations)
  second_squared_reps = torch.permute(second_squared_reps, dims=(1, 0, 2))
  second_squared_reps = torch.reshape(second_squared_reps,
                                    [batch_size**2, representation_dim])
  
  # vmap will calculate the pairwise distance_fn along the dimension specified
  # in in_axes. In this case, will take the dim 0 of the first_squared_reps and
  # the dim 0 of the second_squared_reps and apply the distance
  # It vertorize the process of calculating the distance between all the pairs

  # NOTE: base distance corresponds to the second term in the U calculation in the paper
  # It calculates the angle between the representations in the paper
  # Check what function is using
  base_distances = torch.vmap(distance_fn, in_dims=(0, 0))(first_squared_reps,
                                                         second_squared_reps)
  base_distances = base_distances
  print(base_distances.shape)
  # Sum along the second dimension and normalize the distance
  # NOTE: this is practically the first term of U in the paper
  norm_average = 0.5 * (torch.sum(torch.square(first_squared_reps), -1) +
                        torch.sum(torch.square(second_squared_reps), -1))
  
  print(norm_average.shape)
  if return_distance_components:
    return norm_average + beta * base_distances, norm_average, base_distances
  return norm_average + beta * base_distances

EPSILON = 1e-9

def _sqrt(x):
  # zeros like instead of zeros
  # It is because vmap works with a weird way of broadcasting
  # and a weird structure based on tensors
  tol = torch.zeros_like(x)
  return torch.sqrt(torch.maximum(x, tol))


def cosine_distance(x, y):
  # NOTE: the cosine similarity is not calculate directly for 
  # instabilities observed when using `jnp.arccos`, but I'm using torch
  # so I don't know if I will need to do this
  numerator = torch.sum(x * y)
  denominator = torch.sqrt(torch.sum(x**2)) * torch.sqrt(torch.sum(y**2))
  cos_similarity = numerator / (denominator + EPSILON)

  # cos_similarity = cos(theta)

  # NOTE: From, the Pythagorean trigometric identity
  # sin^2(theta) + cos^2(theta) = 1
  # you can get sin(theta) = sqrt(1 - cos^2(theta))
  # and the arctan2(sin(theta), cos(theta)) = theta
  return torch.arctan2(_sqrt(1. - cos_similarity**2), cos_similarity)

distances = representation_distances(first_states['representation'], 
                                     second_states['representation'], 
                                     cosine_distance)
distances

torch.Size([25])
torch.Size([25])


tensor([0.1357, 0.1928, 0.1357, 0.1928, 0.1266, 0.1906, 0.0400, 0.1906, 0.0400,
        0.1758, 0.1357, 0.1928, 0.1357, 0.1928, 0.1266, 0.1906, 0.0400, 0.1906,
        0.0400, 0.1758, 0.1046, 0.1417, 0.1046, 0.1417, 0.0849])

In [None]:
# NOTE: check in the main code if the output of this requires grad and the 
# other output must require grad
# @torch.no_grad()
def target_distances(representations, rewards, distance_fn, cumulative_gamma):
  """Target distance using the metric operator. This is the T in the paper :D"""
  next_state_similarities = representation_distances(
      representations, representations, distance_fn)
  squared_rews = squarify(rewards).squeeze(-1)
  squared_rews_transp = squared_rews.T
  squared_rews = squared_rews.reshape((squared_rews.shape[0]**2))
  squared_rews_transp = squared_rews_transp.reshape(
      (squared_rews_transp.shape[0]**2))
  reward_diffs = torch.abs(squared_rews - squared_rews_transp)
  return reward_diffs + cumulative_gamma * next_state_similarities

t_distances = target_distances(first_states['representation'], first_states['next','reward'], cosine_distance, cumulative_gamma = 0.9)

torch.Size([25])
torch.Size([25])


In [None]:
t_distances.requires_grad

False

In [None]:
# NOTE: Checar la distancia euclidean entre la representacion target y la representacion con
# la politica actual

# NOTE: en el repositorio de MICO, la distancia target es calculada con
# la target network (que es una copia de la politica actual) osea mis representaciones guardadas
# la distancia online por otra parte es calculada con una representacion con red actual
# y una representacion target


collector.policy(first_states['observation'])

(tensor([[0.1495, 0.0832, 0.3370],
         [0.1717, 0.1063, 0.0000],
         [0.1495, 0.0832, 0.3370],
         [0.1717, 0.1063, 0.0000],
         [0.0857, 0.0904, 0.1998]]),
 tensor([[0.0215, 0.2401],
         [0.0836, 0.2269],
         [0.0215, 0.2401],
         [0.0836, 0.2269],
         [0.0543, 0.2416]]),
 tensor([[0.2401],
         [0.2269],
         [0.2401],
         [0.2269],
         [0.2416]]),
 tensor([[0, 1],
         [0, 1],
         [0, 1],
         [0, 1],
         [0, 1]]))

In [None]:

# Make the components
# Policy

from dqn_mico_er.utils_cartpole import make_dqn_model, make_env
from dqn_mico_er.custom_modules import MICODQNLoss

from tensordict.nn import TensorDictSequential
from torchrl.data.replay_buffers.samplers import RandomSampler, PrioritizedSampler, PrioritizedSliceSampler
from torchrl.objectives import DQNLoss, HardUpdate

# load condig_cartpole.yaml
import yaml
from omegaconf import OmegaConf

with open("dqn_mico_er/config_cartpole.yaml") as f:
    cfg = OmegaConf.create(yaml.safe_load(f))

model = make_dqn_model("CartPole-v1", cfg.policy)

greedy_module = EGreedyModule(
    annealing_num_steps=cfg.collector.annealing_frames,
    eps_init=cfg.collector.eps_start,
    eps_end=cfg.collector.eps_end,
    spec=model.spec,
)
model_explore = TensorDictSequential(
    model,
    greedy_module,
) #.to(device)

# Create the collector
# NOTE: init_random_frames: Number of frames 
# for which the policy is ignored before it is called.
collector = SyncDataCollector(
    create_env_fn=make_env(cfg.env.env_name, "cpu", cfg.env.seed),
    policy=model_explore,
    frames_per_batch=cfg.collector.frames_per_batch,
    total_frames=cfg.collector.total_frames,
    device="cpu",
    storing_device="cpu",
    max_frames_per_traj=-1,
    init_random_frames=cfg.collector.init_random_frames,
)

# Create the replay buffer
if cfg.buffer.prioritized_replay:
    print("Using Prioritized Replay Buffer")
    sampler = PrioritizedSliceSampler(
        max_capacity=cfg.buffer.buffer_size, 
        alpha=cfg.buffer.alpha, 
        beta=cfg.buffer.beta, 
        traj_key=("collector","traj_ids"), 
        slice_len=2)
else:
    sampler = SliceSampler(
        traj_key=("collector","traj_ids"), 
        slice_len=2)
    
replay_buffer = TensorDictReplayBuffer(
    pin_memory=False,
    prefetch=10,
    storage=LazyTensorStorage(
        max_size=cfg.buffer.buffer_size,
        device="cpu",
    ),
    batch_size=cfg.buffer.batch_size,
    sampler = sampler
)

# Create the loss module
loss_module = MICODQNLoss(
    value_network=model,
    loss_function="l2", 
    delay_value=True, # delay_value=True means we will use a target network
    mico_gamma=cfg.loss.mico_gamma,
    mico_beta=cfg.loss.mico_beta,
    mico_weight=cfg.loss.mico_weight,
)

loss_module.make_value_estimator(gamma=cfg.loss.gamma) # only to change the gamma value
loss_module = loss_module #.to(device)
target_net_updater = HardUpdate(
    loss_module, value_network_update_interval=cfg.loss.hard_update_freq
)

# Create the optimizer
optimizer = torch.optim.Adam(loss_module.parameters(), lr=cfg.optim.lr)



In [None]:
# Print the maximum of the model parameters
def print_maximum_weights(model):
    print("Maximum of the model parameters per layer")#
    weights = []
    for p in model.parameters():
        print(torch.max(p))
        weights.append(torch.max(p))
    return weights

def print_maximum_grads(model):
    print("Maximum of the model gradients per layer")#
    max_grads = []
    for p in model.parameters():
        if p.grad is not None:
            print(torch.max(p.grad))
            max_grads.append(torch.max(p.grad))
    return max_grads

def print_target_value_weights(loss_module):
    with loss_module.target_value_network_params.to_module(loss_module.value_network):
        return print_maximum_weights(loss_module.value_network)    

def print_value_weights(loss_module):
    with loss_module.value_network_params.to_module(loss_module.value_network):
        return print_maximum_weights(loss_module.value_network)
        # print_maximum_grads(loss_module.value_network)    

def print_value_grads(loss_module):
    with loss_module.value_network_params.to_module(loss_module.value_network):
        # print_maximum_weights(loss_module.value_network)
        return print_maximum_grads(loss_module.value_network)   

# Print the maximum of the model parameters
print_target_value_weights(loss_module)

print_value_weights(loss_module)

Maximum of the model parameters per layer
tensor(0.4998)
tensor(0.4991)
tensor(0.0913)
tensor(0.0879)
tensor(0.1083)
tensor(0.0750)
tensor(0.6998)
tensor(0.6956)
tensor(0.0913)
tensor(0.0894)
tensor(0.1087)
tensor(-0.0087)
Maximum of the model parameters per layer
tensor(0.4998, grad_fn=<MaxBackward1>)
tensor(0.4991, grad_fn=<MaxBackward1>)
tensor(0.0913, grad_fn=<MaxBackward1>)
tensor(0.0879, grad_fn=<MaxBackward1>)
tensor(0.1083, grad_fn=<MaxBackward1>)
tensor(0.0750, grad_fn=<MaxBackward1>)
tensor(0.6998, grad_fn=<MaxBackward1>)
tensor(0.6956, grad_fn=<MaxBackward1>)
tensor(0.0913, grad_fn=<MaxBackward1>)
tensor(0.0894, grad_fn=<MaxBackward1>)
tensor(0.1087, grad_fn=<MaxBackward1>)
tensor(-0.0087, grad_fn=<MaxBackward1>)


[tensor(0.4998, grad_fn=<MaxBackward1>),
 tensor(0.4991, grad_fn=<MaxBackward1>),
 tensor(0.0913, grad_fn=<MaxBackward1>),
 tensor(0.0879, grad_fn=<MaxBackward1>),
 tensor(0.1083, grad_fn=<MaxBackward1>),
 tensor(0.0750, grad_fn=<MaxBackward1>),
 tensor(0.6998, grad_fn=<MaxBackward1>),
 tensor(0.6956, grad_fn=<MaxBackward1>),
 tensor(0.0913, grad_fn=<MaxBackward1>),
 tensor(0.0894, grad_fn=<MaxBackward1>),
 tensor(0.1087, grad_fn=<MaxBackward1>),
 tensor(-0.0087, grad_fn=<MaxBackward1>)]

In [1]:
# Main loop
collected_frames = 0
total_episodes = 0
start_time = time.time()
num_updates = cfg.loss.num_updates
batch_size = cfg.buffer.batch_size
test_interval = cfg.logger.test_interval
num_test_episodes = cfg.logger.num_test_episodes
frames_per_batch = cfg.collector.frames_per_batch
# pbar = tqdm.tqdm(total=cfg.collector.total_frames)
init_random_frames = cfg.collector.init_random_frames
sampling_start = time.time()
q_losses = torch.zeros(num_updates) #, device=device)

for i, data in enumerate(collector):

        # NOTE: This reshape must be for frame data (maybe)
        data = data.reshape(-1)
        current_frames = data.numel()
        replay_buffer.extend(data)
        collected_frames += current_frames
        greedy_module.step(current_frames)

        # Get the number of episodes
        total_episodes += data["next", "done"].sum()

        # Get and log training rewards and episode lengths
        # Collect the episode rewards and lengths in average over the
        # transitions in the current data batch
        episode_rewards = data["next", "episode_reward"][data["next", "done"]]


        # Warmup phase (due to the continue statement)
        # Additionally This help us to keep a track of the collected_frames
        # after the init_random_frames
        if collected_frames < init_random_frames:
            continue

        # optimization steps
        training_start = time.time()
        for j in range(num_updates):
            sampled_tensordict = replay_buffer.sample(batch_size)
            # TODO: check if the sample is already in the device
            sampled_tensordict = sampled_tensordict #.to(device)

            # Also the loss module will use the current and target model to get the q-values
            loss_td = loss_module(sampled_tensordict)
            q_loss = loss_td["td_loss"]
            optimizer.zero_grad()
            q_loss.backward()
            optimizer.step()

            # Update the priorities
            if cfg.buffer.prioritized_replay:
                replay_buffer.update_priority(index=sampled_tensordict['index'], priority = sampled_tensordict['td_error'])

            # NOTE: This is only one step (after n-updated steps defined before)
            # the target will update
            target_net_updater.step()
            q_losses[j].copy_(q_loss.detach())
        training_time = time.time() - training_start

        # Get and log evaluation rewards and eval time
        # NOTE: As I'm using only the model and not the model_explore that will deterministic I think
        # with torch.no_grad(): #, set_exploration_type(ExplorationType.DETERMINISTIC):

        #     # NOTE: Check how we are using the frames here because it seems that I am dividing 
        #     # 10 for 50000
        #     prev_test_frame = ((i - 1) * frames_per_batch) // test_interval
        #     cur_test_frame = (i * frames_per_batch) // test_interval
        #     final = current_frames >= collector.total_frames

        #     # compara prev_test_frame < cur_test_frame is the same as current_frames % test_interval == 0
        #     if (i >= 1 and (prev_test_frame < cur_test_frame)) or final:
        #         model.eval()
        #         eval_start = time.time()
        #         test_rewards = eval_model(model, test_env, num_test_episodes)
        #         eval_time = time.time() - eval_start
        #         model.train()
        #         log_info.update(
        #             {
        #                 "eval/reward": test_rewards,
        #                 "eval/eval_time": eval_time,
        #             }
        #         )

        # Log all the information

        # update weights of the inference policy
        # NOTE: Updates the policy weights if the policy of the data 
        # collector and the trained policy live on different devices.
        collector.update_policy_weights_()
        sampling_start = time.time()

collector.shutdown()

NameError: name 'time' is not defined

In [28]:
loss = DQNLoss(value_network=policy, 
               action_space=env.action_spec, 
               delay_value=True) # delay_value=True means we will use a target network
optim = Adam(loss.parameters(), lr=0.02)

# eps: will be used to update the target network as 
# \theta_t = \theta_{t-1} * \epsilon + \theta_t * (1-\epsilon)
# where eps = 1 is hard update
updater = SoftUpdate(loss, eps=0.99)

In [34]:
loss.value_network_params

TensorDictParams(params=TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                activation: TensorDict(
                                    fields={
                                    },
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False),
                                encoder: TensorDict(
                                    fields={
                                        0: TensorDict(
                                            fields={
                                                bias: Parameter(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                                                weight: Parameter(shape=torch.Size([64, 4]), device=cpu, dtype

In [33]:
loss.target_value_network_params

TensorDictParams(params=TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                activation: TensorDict(
                                    fields={
                                    },
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False),
                                encoder: TensorDict(
                                    fields={
                                        0: TensorDict(
                                            fields={
                                                bias: Parameter(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                                                weight: Parameter(shape=torch.Size([64, 4]), device=cpu, dtype

In [80]:
from tensordict import TensorDict
from torchrl.data import SliceSampler
from torchrl.data import LazyMemmapStorage

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
obs = torch.randn((3, 4, 5)).expand(10, 3, 4, 5)
data = TensorDict(
    {
        "episode": episode,
        "obs": obs,
        "act": torch.randn((20,)).expand(10, 20),
        "other": torch.randn((20, 50)).expand(10, 20, 50),
        "steps": steps,
    },
    [10],
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

episode are grouped tensor([3, 3, 4, 4, 2, 2, 1, 1], dtype=torch.int32)
steps are successive tensor([0, 1, 0, 1, 0, 1, 0, 1])


In [82]:
episode

tensor([1, 1, 1, 2, 2, 3, 3, 4, 4, 4], dtype=torch.int32)

In [81]:
rb

TensorDictReplayBuffer(
    storage=LazyMemmapStorage(
        data=TensorDict(
            fields={
                act: MemoryMappedTensor(shape=torch.Size([10, 20]), device=cpu, dtype=torch.float32, is_shared=False),
                episode: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.int32, is_shared=False),
                index: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
                obs: MemoryMappedTensor(shape=torch.Size([10, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
                other: MemoryMappedTensor(shape=torch.Size([10, 20, 50]), device=cpu, dtype=torch.float32, is_shared=False),
                steps: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False), 
        shape=torch.Size([10]), 
        len=10, 
        max_size=100), 
    sampler=S

In [234]:
import datetime

current_date = datetime.datetime.now()
date_str = current_date.strftime("%Y_%m_%d-%H_%M_%S")  # Includes date and time
date_str

'2024_07_23-17_34_50'

In [235]:
total_count = 0
total_episodes = 0
t0 = time.time()
for i, data in enumerate(collector):
    # Write data in replay buffer
    rb.extend(data)
    max_length = rb[:]["next", "step_count"].max() # From all the next steps get the max step count
    if len(rb) > init_rand_steps: # wam-up steps
        # Optim loop (we do several optim steps
        # per batch collected for efficiency)
        for _ in range(optim_steps):
            sample = rb.sample(128) # sample a batch of 128 (repetition is allowed)
            # print(sample)
            break
            loss_vals = loss(sample)
            loss_vals["loss"].backward()
            optim.step()
            optim.zero_grad()
            # Update exploration factor
            # NOTE: Why I am updating the exploration factor here? 
            # I'm considering practically that I did 100 (or n) iteractions in the environment time optim_steps
            exploration_module.step(data.numel()) # data.numel() returns the number of elements in the data
            # Update target params each optimisation step
            updater.step()
            if i % 10:
                torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}")
            total_count += data.numel()
            total_episodes += data["next", "done"].sum() # sum the number of done episodes
    
    if max_length > 200:
        break

t1 = time.time()

torchrl_logger.info(
    f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s."
)

2024-07-23 17:34:53,433 [torchrl][INFO] solved after 0 steps, 0 episodes and in 2.57519268989563s.


In [242]:
sample[0]

TensorDict(
    fields={
        _weight: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_sha

In [236]:
sample

TensorDict(
    fields={
        _weight: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.float32, is_shared=False),
        action: Tensor(shape=torch.Size([128, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        action_value: Tensor(shape=torch.Size([128, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        chosen_action_value: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([128]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([128, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        index: Tensor(shape=torch.Size([128]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([128, 1]),

In [8]:
record_env.rollout(max_steps=1000, policy=policy)
video_recorder.dump()

In [5]:
import random

# Generate and print 10 random seeds
random_seeds = [random.randint(0, 1000000) for _ in range(10)]
print(random_seeds)

[118398, 676190, 786456, 171936, 887739, 919409, 711872, 442081, 189061, 117840]
