In [1]:
from collections import defaultdict
from matplotlib import pyplot as plt
import torch
from torch import nn
import torchrl
import torchrl.envs as torch_envs
from tqdm import tqdm
import gymnasium as gym
import tensordict
from tensordict import nn as dict_nn
import torchsummary

from spaceship_env import SpaceshipEnv

  from pkg_resources import resource_stream, resource_exists


In [2]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [3]:
device = torch.device("cuda")
lr = 3e-4
max_grad_norm = 1.0

frames_per_batch = 1000
total_frames = 1_000_000

sub_batch_size = 64
num_epochs = 10
clip_epsilon = 0.2
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

In [4]:
def make_norm_transforms(env: gym.Env):
    transforms = []
    for key, space in env.observation_space.items():
        if key in ["position", "velocity", "rotation"]:
            transforms.append(torch_envs.transforms.ObservationNorm(loc=space.low, scale=1 / (space.high-space.low), in_keys=key, out_keys=key, standard_normal=False))
    return torch_envs.transforms.Compose(*transforms)
        

In [5]:
make_norm_transforms(SpaceshipEnv())

Compose(
        ObservationNorm(keys=['position']),
        ObservationNorm(keys=['rotation']),
        ObservationNorm(keys=['velocity']))

In [6]:
env = SpaceshipEnv()
for _ in range(5):
    print(env.step(2))

({'position': array([100.  ,  99.94]), 'velocity': array([ 0.  , -0.06]), 'rotation': array([0, 1, 0, 0]), 'step_count': 1}, -0.02, False, False, None)
({'position': array([99.94, 99.88]), 'velocity': array([-0.06, -0.06]), 'rotation': array([0, 0, 1, 0]), 'step_count': 2}, -0.02, False, False, None)
({'position': array([99.88, 99.88]), 'velocity': array([-0.06,  0.  ]), 'rotation': array([0, 0, 0, 1]), 'step_count': 3}, -0.02, False, False, None)
({'position': array([99.88, 99.88]), 'velocity': array([0., 0.]), 'rotation': array([1, 0, 0, 0]), 'step_count': 4}, -0.02, False, False, None)
({'position': array([99.88, 99.82]), 'velocity': array([ 0.  , -0.06]), 'rotation': array([0, 1, 0, 0]), 'step_count': 5}, -0.02, False, False, None)


In [7]:
gym.register('Spaceship_Target', entry_point="spaceship_env:SpaceshipEnv")

env = torch_envs.GymEnv('Spaceship_Target', device=device)
print(env.observation_spec.keys())
env = torch_envs.transforms.TransformedEnv(base_env=env, 
                                             transform=torch_envs.Compose([
                                                 make_norm_transforms(env),
                                                 torch_envs.transforms.CatTensors(["position", "velocity", "rotation"], "observation")
                                                 ]))

logged_env = torch_envs.GymEnv('Spaceship_Target', device=torch.device('cuda'), return_pixels=True)
logged_env = torch_envs.transforms.TransformedEnv(base_env=logged_env, 
                                             transform=torch_envs.Compose([
                                                 make_norm_transforms(env),
                                                 torch_envs.transforms.CatTensors(["position", "velocity", "rotation"], "observation")
                                                 ]))

_CompositeSpecKeysView(keys=['position', 'rotation', 'step_count', 'velocity'])


In [8]:
logged_env.observation_spec.keys()

_CompositeSpecKeysView(keys=['step_count', 'pixels', 'observation'])

In [9]:
env.rollout(3)['observation']

tensor([[ 7.1429e-02,  1.0000e-01,  1.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 7.1429e-02,  9.9940e-02,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00, -6.0000e-04],
        [ 7.1471e-02,  9.9880e-02,  1.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  6.0000e-04, -6.0000e-04]], device='cuda:0')

In [10]:
from torchrl.modules.tensordict_module import ProbabilisticActor

actor = nn.Sequential(
    nn.Linear(env.observation_spec['observation'].shape[0], 128, device=device),
    nn.Tanh(),
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.LazyLinear(3, device=device),
)

policy_module = dict_nn.TensorDictModule(actor, in_keys=["observation"], out_keys=["logits"])
policy_module = ProbabilisticActor(module=policy_module,
                                    spec=env.action_spec,
                                    in_keys=["logits"],
                                    distribution_class=torch.distributions.OneHotCategorical,
                                    return_log_prob=True)

In [11]:
import torchsummary
torchsummary.summary(actor, (8,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 128]           1,152
              Tanh-2                  [-1, 128]               0
            Linear-3                  [-1, 128]          16,512
              Tanh-4                  [-1, 128]               0
            Linear-5                  [-1, 128]          16,512
              Tanh-6                  [-1, 128]               0
            Linear-7                    [-1, 3]             387
Total params: 34,563
Trainable params: 34,563
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.13
Estimated Total Size (MB): 0.14
----------------------------------------------------------------


In [12]:
from torchrl.modules import ValueOperator

value_net = nn.Sequential(
    nn.Linear(8, 128, device=device),
    nn.Tanh(),
    nn.LazyLinear(128, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(value_net, in_keys=["observation"])

In [13]:
print(policy_module(env.reset())['action'])
print(value_module(env.reset()))

tensor([1., 0., 0.], device='cuda:0')
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        observation: Tensor(shape=torch.Size([8]), device=cuda:0, dtype=torch.float32, is_shared=True),
        state_value: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float32, is_shared=True),
        step_count: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.float64, is_shared=True),
        terminated: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        truncated: Tensor(shape=torch.Size([1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True)


In [14]:
out = env.rollout(1000, lambda x, y: torch.tensor([1, 0, 0], dtype=torch.long))
print(out['observation'])
sum(out['done']), len(out['observation'])

tensor([[ 7.1514e-02,  9.9880e-02,  0.0000e+00,  ...,  1.0000e+00,
          6.0000e-04,  0.0000e+00],
        [ 7.1514e-02,  9.9880e-02,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 7.1514e-02,  9.9820e-02,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -6.0000e-04],
        ...,
        [ 9.2857e-02,  7.0000e-02,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 9.2857e-02,  6.9940e-02,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -6.0000e-04],
        [ 9.2900e-02,  6.9880e-02,  1.0000e+00,  ...,  0.0000e+00,
          6.0000e-04, -6.0000e-04]], device='cuda:0')


(tensor([0], device='cuda:0'), 1000)

In [15]:
for row in out['observation']:
    print([round(float(num), 3) for num in row])

[0.072, 0.1, 0.0, 0.0, 0.0, 1.0, 0.001, 0.0]
[0.072, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
[0.072, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, -0.001]
[0.072, 0.1, 1.0, 0.0, 0.0, 0.0, 0.001, -0.001]
[0.072, 0.1, 0.0, 0.0, 0.0, 1.0, 0.001, 0.0]
[0.072, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
[0.072, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, -0.001]
[0.072, 0.1, 1.0, 0.0, 0.0, 0.0, 0.001, -0.001]
[0.072, 0.1, 0.0, 0.0, 0.0, 1.0, 0.001, 0.0]
[0.072, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
[0.072, 0.1, 0.0, 1.0, 0.0, 0.0, 0.0, -0.001]
[0.072, 0.1, 1.0, 0.0, 0.0, 0.0, 0.001, -0.001]
[0.072, 0.1, 0.0, 0.0, 0.0, 1.0, 0.001, 0.0]
[0.072, 0.1, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
[0.072, 0.099, 0.0, 1.0, 0.0, 0.0, 0.0, -0.001]
[0.072, 0.099, 1.0, 0.0, 0.0, 0.0, 0.001, -0.001]
[0.072, 0.099, 0.0, 0.0, 0.0, 1.0, 0.001, 0.0]
[0.072, 0.099, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
[0.072, 0.099, 0.0, 1.0, 0.0, 0.0, 0.0, -0.001]
[0.072, 0.099, 1.0, 0.0, 0.0, 0.0, 0.001, -0.001]
[0.072, 0.099, 0.0, 0.0, 0.0, 1.0, 0.001, 0.0]
[0.072, 0.099, 0.0, 0.0, 1.0, 0

In [16]:
import torch.nn.functional as F
def simple_rollout(env, action_index, steps):
    td = env.reset()
    results = []
    
    # Get the number of classes from the environment spec
    # usually env.action_spec.shape[-1] for OneHot specs
    n_actions = env.action_spec.shape[-1]
    
    # Create the One-Hot tensor: e.g., 2 -> [0, 0, 1]
    action_one_hot = F.one_hot(torch.tensor(action_index), n_actions)
    
    for _ in range(steps):
        td['action'] = action_one_hot
        td = env.step(td)
        
        # Note: To print the index (2) instead of the vector, use argmax
        print(f"Action: {td['action'].argmax().item()} | Reward: {td[('next', 'reward')].item()}")
        
        results.append(td.clone())
        td = env.step_mdp(td)
    return torch.stack(results)

simple_rollout(env, torch.tensor(2), 100)

Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.019999999552965164
Action: 2 | Reward: -0.01999999955

  action_one_hot = F.one_hot(torch.tensor(action_index), n_actions)


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([100, 3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        done: Tensor(shape=torch.Size([100, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([100, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([100, 8]), device=cuda:0, dtype=torch.float32, is_shared=True),
                reward: Tensor(shape=torch.Size([100, 1]), device=cuda:0, dtype=torch.float32, is_shared=True),
                step_count: Tensor(shape=torch.Size([100, 1]), device=cuda:0, dtype=torch.float64, is_shared=True),
                terminated: Tensor(shape=torch.Size([100, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                truncated: Tensor(shape=torch.Size([100, 1]), device=cuda:0, dtype=torch.bool, is_shared=True)},
            batch_size=torch.Size([100]),
       

In [17]:
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss

advantage_module = GAE(gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device)
loss_module = ClipPPOLoss(actor_network=policy_module, critic_network=value_module, clip_epsilon=clip_epsilon, entropy_coeff=entropy_eps)
optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, total_frames // frames_per_batch, 0.0)

In [18]:
from torchrl.collectors import SyncDataCollector
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement

collector = SyncDataCollector(env, policy_module, frames_per_batch=frames_per_batch, total_frames=total_frames, device=device)
replay_buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=frames_per_batch), sampler=SamplerWithoutReplacement())

In [19]:
from torchrl.record import VideoRecorder
from torchrl.record.loggers.csv import CSVLogger

logger = CSVLogger(exp_name="Spaceship_Target", log_dir="target1_videos", video_format="mp4")
logged_env = torch_envs.transforms.TransformedEnv(logged_env, VideoRecorder(logger, tag="run_video", in_keys=['pixels']))  # should just use render

  instance: EnvBase = super(_EnvPostInit, self).__call__(*args, **kwargs)


In [20]:
out = logged_env.rollout(1000, lambda x, y, z: torch.tensor([0, 1, 0]))
logged_env.transform[-1].dump()



In [21]:
type(env.render())

numpy.ndarray

In [22]:
for data in collector:
    print(data)
    break

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1000, 3]), device=cuda:0, dtype=torch.float32, is_shared=True),
        action_log_prob: Tensor(shape=torch.Size([1000]), device=cuda:0, dtype=torch.float32, is_shared=True),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([1000]), device=cuda:0, dtype=torch.int64, is_shared=True)},
            batch_size=torch.Size([1000]),
            device=cuda:0,
            is_shared=True),
        done: Tensor(shape=torch.Size([1000, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        logits: Tensor(shape=torch.Size([1000, 3]), device=cuda:0, dtype=torch.float32, is_shared=True),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1000, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
                observation: Tensor(shape=torch.Size([1000, 8]), device=cuda:0, dtype=torch.float32, is_shared=True),
                rew

In [None]:
from tensordict.nn import set_interaction_type, InteractionType

logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

for i, tensordict_data in enumerate(collector):
    for _ in range(num_epochs):
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu()) # not exactly sure why cpu
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
            loss_value.backward()
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()
            
        logs["reward"].append(tensordict_data["next", "reward"].mean().item())
        pbar.update(tensordict_data.numel())
        
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    
    if i % 10 == 0:
        with set_interaction_type(InteractionType.DETERMINISTIC), torch.no_grad():  # magic
            eval_rollout = logged_env.rollout(1000, policy_module)
            logged_env.transform[-1].dump()
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
            
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))
    scheduler.step()


eval cumulative reward: -2.1200 (init: -2.6000), eval step-count: 80.0, average reward=-0.0250 (init=-0.0195), step count (max): 80.0, lr policy:  0.0002: : 3140000it [06:25, 12353.73it/s]                            

KeyboardInterrupt: 

eval cumulative reward: -2.1200 (init: -2.6000), eval step-count: 80.0, average reward=-0.0250 (init=-0.0195), step count (max): 80.0, lr policy:  0.0002: : 3140000it [06:40, 12353.73it/s]

In [None]:
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()