Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the normalize vector wrapper tests #659

Conversation

pseudo-rnd-thoughts
Copy link
Member

Description

We include tests that the wrappers.vector.NormalizeReward and wrappers.vector.NormalizeObservation wrappers work similarly to wrappers.NormalizeReward and wrappers.NormalizeObservation.
However, the bounds on similarity were too small and caused test failures to occur.

As a result, we fix a bug in wrappers.NormalizeObservation where every dimension was equally normalized rather than individually normalized.

For reference, the code used to test and evaluate this

#!/usr/bin/env python
# coding: utf-8

# In[9]:


import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import gymnasium as gym
from gymnasium.spaces import Box
from tests.testing_env import GenericTestEnv


# In[17]:


# Experiment configuration shared by the cells below.
NUM_ENVS = 3  # sub-environments in each vector env
ENV_ID = "TestEnv-v0"  # id the test environment is registered under
LENGTH = 500  # steps taken per repeat
REPEATS = 100  # independent repeats per experiment

# Register under ENV_ID (rather than repeating the literal) so the id used by
# `gym.make` / `gym.make_vec` below can never drift from the registered one.
gym.register(
    ENV_ID,
    lambda: GenericTestEnv(
        observation_space=Box(low=np.array([0, -10, -5]), high=np.array([10, -5, 10]))
    ),
)

# Build a throwaway env only to read the observation dimensionality, then
# close it instead of leaving it open for the rest of the session.
_probe_env = gym.make(ENV_ID)
obs_shape = _probe_env.observation_space.shape[0]
_probe_env.close()


# In[18]:


# Per-step `obs_rms` mean/var of the vector wrapper, indexed as
# [repeat, step, observation dimension].
vec_mean = np.zeros((REPEATS, LENGTH, obs_shape))
vec_var = np.zeros_like(vec_mean)

for repeat in tqdm(range(REPEATS)):
    envs = gym.wrappers.vector.NormalizeObservationV0(
        gym.make_vec(ENV_ID, NUM_ENVS, vectorization_mode="sync")
    )

    envs.reset()
    for step in range(LENGTH):
        envs.step(envs.action_space.sample())

        # Snapshot the wrapper's statistics after every step.
        running_stats = envs.obs_rms
        vec_mean[repeat, step] = running_stats.mean
        vec_var[repeat, step] = running_stats.var

    envs.close()


# In[19]:


# Per-step `obs_rms` mean/var of the single-env wrapper, indexed as
# [repeat, step, observation dimension].
single_mean = np.zeros((REPEATS, LENGTH, obs_shape))
single_var = np.zeros_like(single_mean)

for repeat in tqdm(range(REPEATS)):
    # Autoreset is applied underneath the normalizer, mirroring the original
    # wrapper order.
    env = gym.wrappers.NormalizeObservationV0(
        gym.wrappers.AutoresetV0(gym.make(ENV_ID))
    )

    env.reset()
    for step in range(LENGTH):
        env.step(env.action_space.sample())

        # Snapshot the wrapper's statistics after every step.
        running_stats = env.obs_rms
        single_mean[repeat, step] = running_stats.mean
        single_var[repeat, step] = running_stats.var

    env.close()


# In[20]:


# One column per observation dimension; the top row shows the vector wrapper
# and the bottom row the single-env wrapper. Each repeat is drawn as one line.
fig, axs = plt.subplots(nrows=2, ncols=obs_shape, figsize=(12, 6))
fig.suptitle("Obs Mean")

axs[0, 0].set_ylabel("Vec")
axs[1, 0].set_ylabel("Single")

steps = np.arange(LENGTH)  # shared x-axis for every line
for dim in range(obs_shape):
    vec_ax, single_ax = axs[0, dim], axs[1, dim]
    for repeat in range(REPEATS):
        vec_ax.plot(steps, vec_mean[repeat, :, dim])
        single_ax.plot(steps, single_mean[repeat, :, dim])

plt.tight_layout()


# In[21]:


# One column per observation dimension; the top row shows the vector wrapper
# and the bottom row the single-env wrapper. Each repeat is drawn as one line.
fig, axs = plt.subplots(nrows=2, ncols=obs_shape, figsize=(12, 6))
fig.suptitle("Obs Var")

axs[0, 0].set_ylabel("Vec")
axs[1, 0].set_ylabel("Single")

steps = np.arange(LENGTH)  # shared x-axis for every line
for dim in range(obs_shape):
    vec_ax, single_ax = axs[0, dim], axs[1, dim]
    for repeat in range(REPEATS):
        vec_ax.plot(steps, vec_var[repeat, :, dim])
        single_ax.plot(steps, single_var[repeat, :, dim])

plt.tight_layout()


# In[3]:


## Testing with a seed
env = gym.wrappers.AutoresetV0(gym.make(ENV_ID))

# Draw one fixed action sequence so both passes below replay the same actions.
env.action_space.seed(123)
actions = [env.action_space.sample() for _ in range(LENGTH)]

# Pass 1: record the raw observations returned by each step.
obs = np.zeros((LENGTH, obs_shape))
env.reset(seed=123)
for step, action in enumerate(actions):
    step_obs, _, _, _, _ = env.step(action)
    obs[step] = step_obs

# Pass 2: reset with the same seed *before* wrapping, then replay the actions
# through the normalizing wrapper.
env.reset(seed=123)
env = gym.wrappers.NormalizeObservationV0(env)
for action in actions:
    env.step(action)

actual_mean, actual_var = np.mean(obs, axis=0), np.var(obs, axis=0)
print(f'Actual mean={actual_mean}, var={actual_var}')
print(f'Normalize mean={env.obs_rms.mean}, var={env.obs_rms.var}')


# In[29]:


from gymnasium.vector import SyncVectorEnv


def thunk():
    """Create a new ``GenericTestEnv`` with a 3-dimensional float32 ``Box`` observation space."""
    low = np.array([0, -10, -5], dtype=np.float32)
    high = np.array([10, -5, 10], dtype=np.float32)
    return GenericTestEnv(observation_space=Box(low=low, high=high))


# Number of independent comparisons, and vector-env steps taken in each one.
RTOL_REPEATS = 250
RTOL_STEPS = 250

# rtols[k, 0] / rtols[k, 1] hold the per-dimension relative difference of the
# mean / variance for comparison k (3 observation dimensions, see `thunk`).
rtols = np.zeros((RTOL_REPEATS, 2, 3))
for i in tqdm(range(RTOL_REPEATS)):
    vec_env = SyncVectorEnv([thunk for _ in range(NUM_ENVS)])
    vec_env = gym.wrappers.vector.NormalizeObservationV0(vec_env)

    vec_env.reset()
    for _ in range(RTOL_STEPS):
        vec_env.step(vec_env.action_space.sample())

    env = gym.wrappers.AutoresetV0(thunk())
    env = gym.wrappers.NormalizeObservationV0(env)
    env.reset()
    # The single env takes NUM_ENVS times as many steps, so it performs as
    # many steps in total as the vector env's sub-environments combined.
    for _ in range(RTOL_STEPS * NUM_ENVS):
        env.step(env.action_space.sample())

    # Relative difference of the single-env statistics w.r.t. the vector env's.
    mean_rtol = np.abs(env.obs_rms.mean - vec_env.obs_rms.mean) / np.abs(vec_env.obs_rms.mean)
    var_rtol = np.abs(env.obs_rms.var - vec_env.obs_rms.var) / np.abs(vec_env.obs_rms.var)
    # print(f'{mean_rtol=}, {var_rtol=}')
    rtols[i] = np.array([mean_rtol, var_rtol])

    # Left disabled on purpose: candidate tolerance checks whose `rtol` the
    # recorded values above are meant to help choose.
    # assert np.allclose(env.obs_rms.mean, vec_env.obs_rms.mean, rtol=0.07)
    # assert np.allclose(env.obs_rms.var, vec_env.obs_rms.var, rtol=0.07)

    # Release both environments once their statistics have been read, matching
    # the earlier loops; otherwise every iteration leaves two envs open.
    vec_env.close()
    env.close()


# In[30]:


# Histograms of the recorded relative differences: the top row is the mean,
# the bottom row the variance, with one column per observation dimension.
fig, axs = plt.subplots(nrows=2, ncols=3)

for stat_row, label in enumerate(("Mean", "Var")):
    axs[stat_row, 0].set_ylabel(label)

for dim in range(3):
    for stat_row in range(2):
        axs[stat_row, dim].hist(rtols[:, stat_row, dim], bins=25)

@pseudo-rnd-thoughts pseudo-rnd-thoughts merged commit 02b7d6d into Farama-Foundation:v1.0.0 Aug 10, 2023
6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant