# Gymnasium Encoder Example
This example shows how to encode observations from Gymnasium's CartPole environment into spike trains.

Prerequisites:
- gymnasium
- gymnasium[classic-control]
- This library (provides the `encoding` package imported below)

In [9]:
!pip install gymnasium 
!pip install gymnasium[classic-control]



In [10]:
import gymnasium as gym
from encoding.gymnasium_bounds_finder import ScalerFactory
from encoding.gymnasium_encoder import GymnasiumEncoder
import numpy as np

In [11]:
cartpole_env = gym.make("CartPole-v1")
# NOTE: the encoder is built around encoding multiple states at once,
# which is handy when running several simulations in batches.
# This example drives a single environment, so each batch holds one state.
batch_size = 1

# reset() returns (observation, info); seeding makes the initial state reproducible.
# Observation layout: [cart position, cart velocity, pole angle (radians), pole angular velocity]
first_observation, reset_info = cartpole_env.reset(seed=1970)
print("First observation:", first_observation)

# Prepare state for encoding: wrap the single observation into a 2-D array of
# shape (batch_size, n_observation_features) so the encoder can treat it as a batch.
state_to_encode = np.array([first_observation])
assert state_to_encode.shape == (batch_size, cartpole_env.observation_space.shape[0])

First observation: [ 0.00662101 -0.02290802 -0.00224132  0.00596699]


In [12]:
# Number of timesteps in the resulting spike train
seq_length = 10

# When True, negative values are separated into their own column, so each
# observation feature gets one column for positive and one for negative spikes.
# Example with exc_inh=True:  Observation: [-0.66] → Spike train: [[0, 1], [0, 0], [0, 1]]
# Example with exc_inh=False: Observation: [-0.66] → Spike train: [[-1], [0], [-1]]
exc_inh = False

# Set up scaling for the observations from known per-feature bounds,
# ordered as [cart position, cart velocity, pole angle, pole angular velocity].
# 2.4 and 0.2095 rad (~12 degrees) are CartPole's termination thresholds for cart
# position and pole angle (see the Gymnasium CartPole docs). The velocities are
# unbounded in the observation space, so ±10 is a chosen bound for them.
scaler_factory = ScalerFactory()
max_values = np.array([2.4, 10, 0.2095, 10])
min_values = -max_values  # bounds are symmetric around zero
scaler = scaler_factory.from_known_values(min_values, max_values)

In [13]:
# Build a rate-coding encoder (step coding disabled). The positional arguments are
# the number of observation features, the batch size, the spike-train length and
# the scaler configured above; `exc_inh` controls the excitatory/inhibitory split.
encoder = GymnasiumEncoder(
    cartpole_env.observation_space.shape[0],
    batch_size,
    seq_length,
    scaler,
    rate_coder=True,
    step_coder=False,
    split_exc_inh=exc_inh,
)

# Turn the first (batched) observation into a spike train and show it.
encoded_first_state = encoder.encode(state_to_encode)
print("Encoded first state:", encoded_first_state, sep="\n")

Encoded first state:
[[[1 1 1 1]]

 [[1 0 0 1]]

 [[0 1 1 0]]

 [[1 0 0 1]]

 [[0 1 1 0]]

 [[1 0 0 1]]

 [[0 1 1 0]]

 [[1 0 0 1]]

 [[0 0 0 0]]

 [[1 1 1 1]]]


In [14]:
# Advance the environment by one step using action 1 and encode the new state.
# step() returns a tuple whose first element is the next observation; the printed
# cart velocity turning positive shows that action 1 pushes the cart to the right.
# NOTE: re-running this cell steps the environment again, so its output changes.
second_observation = cartpole_env.step(1)[0]
print("Second observation:", second_observation)

# Batch the single observation exactly as was done for the first state.
second_state_to_encode = np.array([second_observation])
encoded_second_state = encoder.encode(second_state_to_encode)
print("\nEncoded second state:", encoded_second_state, sep="\n")

Second observation: [ 0.00616285  0.17224601 -0.00212199 -0.28742227]

Encoded second state:
[[[1 1 1 1]]

 [[1 1 0 0]]

 [[0 0 1 1]]

 [[1 1 0 0]]

 [[0 0 1 1]]

 [[1 1 0 0]]

 [[0 0 1 1]]

 [[1 1 0 0]]

 [[0 0 0 0]]

 [[1 1 1 1]]]
