In [13]:
import gymnasium as gym
import torch
from dataclasses import dataclass

@dataclass
class Config:
    env_id: str = "CartPole-v1"

    total_timesteps: int = 250_000
    """ total number of timesteps collected for the training """
    num_steps: int = 5_000
    """ number of steps per rollout (between updates) """

    pi_lr: float = 1e-2

In [14]:
env = gym.vector.SyncVectorEnv([lambda: gym.make("FrozenLake-v1")])

In [15]:
observation, info = env.reset()
done = True

for t in range(1000):
    if done:
        print(t)
        print(observation)
        print(done)
        print("---------------")

    action = env.action_space.sample()
    observation, reward, done, truncated, info = env.step(action)

    


0
[0]
True
---------------
3
[0]
[ True]
---------------
13
[0]
[ True]
---------------
18
[0]
[ True]
---------------
22
[0]
[ True]
---------------
31
[0]
[ True]
---------------
38
[0]
[ True]
---------------
41
[0]
[ True]
---------------
50
[0]
[ True]
---------------
57
[0]
[ True]
---------------
70
[0]
[ True]
---------------
89
[0]
[ True]
---------------
92
[0]
[ True]
---------------
99
[0]
[ True]
---------------
103
[0]
[ True]
---------------
113
[0]
[ True]
---------------
121
[0]
[ True]
---------------
135
[0]
[ True]
---------------
139
[0]
[ True]
---------------
151
[0]
[ True]
---------------
157
[0]
[ True]
---------------
159
[0]
[ True]
---------------
172
[0]
[ True]
---------------
176
[0]
[ True]
---------------
181
[0]
[ True]
---------------
183
[0]
[ True]
---------------
187
[0]
[ True]
---------------
199
[0]
[ True]
---------------
201
[0]
[ True]
---------------
204
[0]
[ True]
---------------
209
[0]
[ True]
---------------
236
[0]
[ True]
-----------

In [122]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np

class DebugEnv(gym.Env):
    def __init__(self):
        super(DebugEnv, self).__init__()
        # Action space: Choose 0 or 1 (no consequence)
        self.action_space = spaces.Discrete(2)
        # Observation space: An integer representing the state
        self.observation_space = spaces.Discrete(1000)  # A large number for demonstration
        self.state = 0
        self.done = False

    def reset(self, seed=None, options=None):
        # Reset state to 0 at the beginning of an episode
        self.state = 0
        self.done = False
        return self.state, {}  # Observation and info (empty for now)

    def step(self, action):
        # Increment state
        self.state += 1
        
        # Reward is randomly 0 or 1
        reward = np.random.choice([0, 1])
        
        # End the episode randomly
        if np.random.rand() > 0.9 and self.state>=5:  # 20% chance to end at each timestep
            self.done = True
            reward = 10

        if self.state==1:
            reward = 2
        
        return self.state, reward, self.done, False, {}  # state, reward, done, truncated, info

    def render(self):
        # Simple render function to print the state
        print(f"State: {self.state}")

# To use this environment for debugging, register and use it like this:

# Register the custom environment
gym.envs.registration.register(
    id='DebugEnv-v0',
    entry_point=DebugEnv,
)

In [123]:
device = "cpu"
config = Config(total_timesteps=100, num_steps=25)

env = gym.vector.SyncVectorEnv([lambda: gym.make('DebugEnv-v0')])

In [124]:
batch_obs = torch.zeros((config.num_steps,) + env.observation_space.shape).to(device) # not used
#batch_logprobs = torch.zeros(config.num_steps).to(device)
batch_actions = torch.zeros(config.num_steps).to(device)
batch_rewards = torch.zeros(config.num_steps).to(device)
batch_dones = torch.zeros(config.num_steps).to(device)

observation, _ = env.reset()
observation = torch.Tensor(observation).to(device)
done = torch.ones(1)

for t in range(config.num_steps):
    batch_obs[t] = observation
    batch_dones[t] = done

    with torch.no_grad():
        action = env.action_space.sample()

    # env step
    observation, reward, terminated, truncated, info = env.step(action)
    done = np.logical_or(terminated, truncated)
    observation, done = torch.Tensor(observation).to(device), torch.Tensor(done).to(device)

    #batch_logprobs[t] = logprob
    batch_actions[t] = torch.Tensor(action)
    batch_rewards[t] = torch.tensor(reward)

batch_returns = torch.zeros(config.num_steps)
curr_ret = 0
last_ep_idx = config.num_steps-1
for t in reversed(range(config.num_steps)):
    curr_ret += batch_rewards[t]

    if batch_dones[t]:
        batch_returns[t:last_ep_idx+1] = curr_ret
        curr_ret = 0
        last_ep_idx = t-1

In [125]:
torch.stack([batch_obs.squeeze(1), batch_dones, batch_rewards, batch_returns]).T

tensor([[ 0.,  1.,  2., 16.],
        [ 1.,  0.,  1., 16.],
        [ 2.,  0.,  1., 16.],
        [ 3.,  0.,  1., 16.],
        [ 4.,  0.,  1., 16.],
        [ 5.,  0., 10., 16.],
        [ 0.,  1.,  2., 19.],
        [ 1.,  0.,  0., 19.],
        [ 2.,  0.,  0., 19.],
        [ 3.,  0.,  0., 19.],
        [ 4.,  0.,  1., 19.],
        [ 5.,  0.,  1., 19.],
        [ 6.,  0.,  1., 19.],
        [ 7.,  0.,  0., 19.],
        [ 8.,  0.,  0., 19.],
        [ 9.,  0.,  1., 19.],
        [10.,  0.,  1., 19.],
        [11.,  0.,  0., 19.],
        [12.,  0.,  1., 19.],
        [13.,  0.,  0., 19.],
        [14.,  0.,  1., 19.],
        [15.,  0., 10., 19.],
        [ 0.,  1.,  2.,  3.],
        [ 1.,  0.,  0.,  3.],
        [ 2.,  0.,  1.,  3.]])

In [1]:
import torch
import numpy as np
import gymnasium as gym

import interaction

from vpg import Config

In [2]:
config = Config(num_steps=25)
env = gym.vector.SyncVectorEnv([lambda: gym.make('DebugEnv-v0')])

gamma = 0.9

In [3]:
batch_obs = torch.zeros((config.num_steps,) + env.single_observation_space.shape).to(config.device)
batch_actions = torch.zeros((config.num_steps,) + env.single_action_space.shape).to(config.device)
batch_rewards = torch.zeros(config.num_steps).to(config.device)
batch_dones = torch.zeros(config.num_steps).to(config.device)

# observation : (1, obs_dim)
# action : (1, action_dim)
# reward : (1,)
# done : (1,)

observation, _ = env.reset()
observation = torch.Tensor(observation).to(config.device)
done = torch.ones(1)

for t in range(config.num_steps):
    batch_obs[t] = observation
    batch_dones[t] = done

    with torch.no_grad():
        action = env.action_space.sample()

    # env step
    observation, reward, terminated, truncated, info = env.step(action)
    done = np.logical_or(terminated, truncated)
    observation, done = torch.Tensor(observation).to(config.device), torch.Tensor(done).to(config.device)

    batch_actions[t] = torch.tensor(action)
    batch_rewards[t] = torch.tensor(reward).to(config.device)

batch_rtg = torch.zeros(config.num_steps)
returns = [] # for logging, one per traj
rtg = 0
for t in reversed(range(config.num_steps)):
    rtg += batch_rewards[t]
    batch_rtg[t] = rtg
    rtg *= gamma

    if batch_dones[t]:
        returns.append(rtg)
        rtg = 0

In [4]:
torch.stack([batch_obs, batch_dones, batch_rewards, batch_rtg]).T

tensor([[ 0.0000,  1.0000,  2.0000,  8.0793],
        [ 1.0000,  0.0000,  0.0000,  6.7548],
        [ 2.0000,  0.0000,  1.0000,  7.5053],
        [ 3.0000,  0.0000,  0.0000,  7.2281],
        [ 4.0000,  0.0000,  1.0000,  8.0312],
        [ 5.0000,  0.0000,  0.0000,  7.8125],
        [ 6.0000,  0.0000,  1.0000,  8.6805],
        [ 7.0000,  0.0000,  1.0000,  8.5339],
        [ 8.0000,  0.0000,  1.0000,  8.3710],
        [ 9.0000,  0.0000,  0.0000,  8.1900],
        [10.0000,  0.0000,  1.0000,  9.1000],
        [11.0000,  0.0000,  0.0000,  9.0000],
        [12.0000,  0.0000, 10.0000, 10.0000],
        [ 0.0000,  1.0000,  2.0000,  8.0346],
        [ 1.0000,  0.0000,  1.0000,  6.7052],
        [ 2.0000,  0.0000,  0.0000,  6.3391],
        [ 3.0000,  0.0000,  1.0000,  7.0434],
        [ 4.0000,  0.0000,  0.0000,  6.7149],
        [ 5.0000,  0.0000,  0.0000,  7.4610],
        [ 6.0000,  0.0000,  1.0000,  8.2900],
        [ 7.0000,  0.0000,  0.0000,  8.1000],
        [ 8.0000,  0.0000,  0.0000

In [5]:
returns

[tensor(1.8000), tensor(7.2312), tensor(7.2713)]

In [6]:
a = b = 3

In [9]:
b

3

In [10]:
for t in reversed(range(5)):
    print(t)

4
3
2
1
0


In [12]:
for t in range(5-1, -1, -1):
    print(t)

4
3
2
1
0
