In [1]:
import jax.numpy as jnp
import numpy as np
import time
import chex
import jax
import gymnasium as gym
import exciting_environments as excenv
import random

## Comparison Gym and ExcEnv:

### Batch_Size = 1 :

In [3]:
BATCH_SIZE=1
env=excenv.make('CartPole-v0',batch_size=BATCH_SIZE)
env_gym= gym.make("CartPole-v1")
env.reset(random_key=jax.random.PRNGKey(9))
env_gym.reset(seed=9)
action_gym=env_gym.action_space.sample()
action_exc=env.action_space.sample(jax.random.PRNGKey(34))
print("Gym:")
%timeit env_gym.step(action_gym)
print("ExcEnv:")
%timeit env.step(action_exc)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Gym:


  logger.warn(


4.01 µs ± 4.73 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
ExcEnv:
8.23 µs ± 10.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Batch_Size = 10:

In [4]:
BATCH_SIZE=10
env=excenv.make('CartPole-v0',batch_size=BATCH_SIZE)
env_gym= gym.vector.make("CartPole-v1",num_envs=BATCH_SIZE)
env.reset(random_key=jax.random.PRNGKey(9))
env_gym.reset(seed=9)
action_gym=env_gym.action_space.sample()
action_exc=env.action_space.sample(jax.random.PRNGKey(34))

In [5]:
print("Gym:")
%timeit env_gym.step(action_gym)

Gym:
217 µs ± 8.42 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
print("ExcEnv:")
%timeit env.step(action_exc)

ExcEnv:
8.38 µs ± 13 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Batch_Size=100

In [7]:
BATCH_SIZE=100
env=excenv.make('CartPole-v0',batch_size=BATCH_SIZE)
#env_gym= gym.vector.make("CartPole-v1",num_envs=BATCH_SIZE)
env.reset(random_key=jax.random.PRNGKey(9))
#env_gym.reset(seed=9)
#action_gym=env_gym.action_space.sample()
action_exc=env.action_space.sample(jax.random.PRNGKey(34))


In [8]:
print("Gym:")
%timeit env_gym.step(action_gym)

Gym:
213 µs ± 3.54 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
print("ExcEnv:")
%timeit env.step(action_exc)

ExcEnv:
11.2 µs ± 43.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Batch_Size > 1 :

#### ExcEnv:

In [12]:
sizes=[10,100,250,500,1000,10000]
for size in sizes:
    env=excenv.make('CartPole-v0',batch_size=size)
    env.reset(random_key=jax.random.PRNGKey(9))
    action_exc=env.action_space.sample(jax.random.PRNGKey(34))
    print(f'''Batch_size={size}''')
    %timeit env.step(action_exc)
    print("\n")

Batch_size=10
8.53 µs ± 34.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Batch_size=100
11.2 µs ± 18.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Batch_size=250
15.2 µs ± 58.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


Batch_size=500
23.3 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Batch_size=1000
39.2 µs ± 56.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Batch_size=10000
384 µs ± 5.97 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)




#### Gym:

#### asynchronous = True

In [8]:
sizes=[10,20,30,40,50]
for size in sizes:
    env_gym=gym.vector.make("CartPole-v1",num_envs=size,asynchronous=True)
    env_gym.reset(seed=9)
    action_gym=env_gym.action_space.sample()
    print(f'''Batch_size={size}''')
    %timeit env_gym.step(action_gym)
    print("\n")

Batch_size=10
219 µs ± 6.43 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Batch_size=20
378 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Batch_size=30
528 µs ± 15.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)




#### asynchronous = False

In [4]:
sizes=[10,20,30,40,50]
for size in sizes:
    env_gym=gym.vector.make("CartPole-v1",num_envs=size,asynchronous=False)
    env_gym.reset(seed=9)
    action_gym=env_gym.action_space.sample()
    print(f'''Batch_size={size}''')
    %timeit env_gym.step(action_gym)
    print("\n")

Batch_size=10
74.5 µs ± 498 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Batch_size=20
136 µs ± 266 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Batch_size=30
197 µs ± 927 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Batch_size=40
256 µs ± 387 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


Batch_size=50
314 µs ± 407 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)




In [6]:
BATCH_SIZE=10
env_gym= gym.vector.make("CartPole-v1",num_envs=BATCH_SIZE,asynchronous=False)
env_gym.reset(seed=9)
action_gym=env_gym.action_space.sample()


In [7]:
print("Gym:")
%timeit env_gym.step(action_gym)

Gym:
73.6 µs ± 254 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
act=jnp.zeros(10).reshape(-1,1)
%timeit env.step(act)

## Simulation Time:

In [88]:
BATCH_SIZE=2
env=excenv.make('CartPole-v0',batch_size=BATCH_SIZE)
env_gym= gym.vector.make("CartPole-v1",num_envs=BATCH_SIZE)

In [89]:
data_batch1 = np.zeros(5000)
act=[env.action_space.sample(jax.random.PRNGKey(random.randint(0,100000))) for _ in range(len(data_batch1))]
act_gym=[env_gym.action_space.sample() for _ in range(len(data_batch1))]
init_gym=env_gym.reset(seed=9)
env.reset(initial_values=init_gym)


(array([[ 0.03702492, -0.02131828,  0.01031481,  0.02775341],
        [ 0.04560017, -0.02923182,  0.03284449, -0.03507179]],
       dtype=float32),
 {})

In [90]:
env.reset(random_key=jax.random.PRNGKey(9))
start_t = time.time()
for i in range(5000):
    #obs,reward,a,b,_ = env.step(env.action_space.sample(jax.random.PRNGKey(random.randint(0,100000))))
    obs,reward,a,b,_ = env.step(act[i])
    #data_batch1[i]=(np.array(obs)[1,0]*np.pi)
ex_time= time.time() -start_t
print(f"Jax Compilation Time: {ex_time} \n")

env_gym.reset(seed=9)
start_t = time.time()
for i in range(5000):
    #obs,reward,a,b,_ = env.step(env.action_space.sample(jax.random.PRNGKey(random.randint(0,100000))))
    obs,reward,a,b,_ = env_gym.step(act_gym[i])
    #data_batch1=jnp.append(data_batch1,obs[0][0]).block_until_ready()
    #data_batch1[i]=(np.array(obs)[1,0]*np.pi)
ex_time= time.time() -start_t
print(f"Gym Compilation Time: {ex_time}")

Jax Compilation Time: 0.1261141300201416 

Gym Compilation Time: 0.5104639530181885


In [None]:
act=jnp.zeros(2).reshape(-1,1)
data_batch1 = np.zeros(5000)
for i in range(5000):
    obs,reward,a,b,_ = env.step(act)
    data_batch1[i]=(np.array(obs)[1,0]*np.pi)

In [None]:
import matplotlib.pyplot as plt

ax =plt.plot(np.sin(data_batch1),np.cos(data_batch1))
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(-1,1)
plt.ylim(-1,1)
plt.plot([0,0], marker='+', ls= 'none')