# Context Sampling In CARL

Let's take a look at how we can sample contexts and use them in CARL environments. We'll use CARLBraxAnt for a short demonstration.

In [2]:
from carl.context.context_space import NormalFloatContextFeature
from carl.context.sampler import ContextSampler
from carl.envs import CARLBraxAnt



Each environment has an associated context space. Even before you instantiate the environment, it lets you look up which context features are available and what their default values and bounds are.

In [3]:
print(f"Context feature names for Ant: {CARLBraxAnt.get_context_space().context_feature_names}")
print(f"Default context for Ant: {CARLBraxAnt.get_context_space().get_default_context()}")
print(f"Context value bounds for friction in Ant: {CARLBraxAnt.get_context_space().get_lower_and_upper_bound('friction')}")

Context feature names for Ant: ['gravity', 'friction', 'elasticity', 'ang_damping', 'mass_torso', 'viscosity']
Default context for Ant: {'gravity': -9.8, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}
Context value bounds for friction in Ant: (0.0, 100.0)
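Conceptually, a context space bundles three things per feature: its name, its default value, and its allowed range. The following is a minimal, dependency-free sketch of that idea. The class `SimpleContextSpace` is hypothetical; its method names only mirror the calls used above and say nothing about CARL's actual implementation. The defaults and the friction bounds are copied from the output above, and treating features without listed bounds as unbounded is an assumption of the sketch.

```python
import math
from dataclasses import dataclass, field


@dataclass
class SimpleContextSpace:
    """Hypothetical stand-in for a context space: defaults plus bounds."""

    defaults: dict[str, float]
    # Assumption: features without an explicit entry are treated as unbounded.
    bounds: dict[str, tuple[float, float]] = field(default_factory=dict)

    @property
    def context_feature_names(self) -> list[str]:
        return list(self.defaults)

    def get_default_context(self) -> dict[str, float]:
        # Return a copy so callers cannot mutate the stored defaults.
        return dict(self.defaults)

    def get_lower_and_upper_bound(self, name: str) -> tuple[float, float]:
        return self.bounds.get(name, (-math.inf, math.inf))

    def contains(self, context: dict[str, float]) -> bool:
        """True if every feature is known and its value lies within bounds."""
        for name, value in context.items():
            if name not in self.defaults:
                return False
            lower, upper = self.get_lower_and_upper_bound(name)
            if not lower <= value <= upper:
                return False
        return True


# Defaults and the friction bounds are taken from the Ant output above.
space = SimpleContextSpace(
    defaults={"gravity": -9.8, "friction": 1.0, "elasticity": 0.0,
              "ang_damping": -0.05, "mass_torso": 10.0, "viscosity": 0.0},
    bounds={"friction": (0.0, 100.0)},
)
print(space.get_lower_and_upper_bound("friction"))  # (0.0, 100.0)
print(space.contains({"friction": 150.0}))  # False
```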


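We can use the built-in context sampler to generate context values for training. Here we sample the 'gravity' context feature from a normal distribution with mean 9.8 and standard deviation 1. Note that the default gravity printed above is -9.8, so these samples have the opposite sign; use `mu=-9.8` if you want values centered on the default. The context space makes sure the sampled values stay within the environment's bounds. Let's start with 5 contexts for now.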

In [4]:
seed = 0
context_distributions = [NormalFloatContextFeature("gravity", mu=9.8, sigma=1)]
context_sampler = ContextSampler(
    context_distributions=context_distributions,
    context_space=CARLBraxAnt.get_context_space(),
    seed=seed,
)
contexts = context_sampler.sample_contexts(n_contexts=5)
print(contexts)

{0: {'gravity': 11.564052345967665}, 1: {'gravity': 10.200157208367225}, 2: {'gravity': 10.77873798410574}, 3: {'gravity': 12.04089319920146}, 4: {'gravity': 11.667557990149968}}
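For a single normally distributed feature, the sampler's job can be sketched with the standard library alone. This is an illustration under stated assumptions, not CARL's implementation: `sample_normal_contexts` is a hypothetical helper, it uses Python's `random` module (so it will not reproduce the exact values above), and it enforces the bounds by clipping, which is only one possible way to stay within them. The gravity bounds in the usage line are placeholders.

```python
import random


def sample_normal_contexts(
    name: str,
    mu: float,
    sigma: float,
    bounds: tuple[float, float],
    n_contexts: int,
    seed: int,
) -> dict[int, dict[str, float]]:
    """Hypothetical sketch: draw one feature from N(mu, sigma^2), clipped to bounds."""
    rng = random.Random(seed)  # local generator, so a fixed seed is reproducible
    lower, upper = bounds
    contexts = {}
    for context_id in range(n_contexts):
        value = rng.gauss(mu, sigma)
        # Assumption: out-of-range draws are clipped rather than resampled.
        contexts[context_id] = {name: min(max(value, lower), upper)}
    return contexts


# Placeholder bounds; the real ones come from the environment's context space.
contexts = sample_normal_contexts("gravity", mu=9.8, sigma=1.0,
                                  bounds=(0.0, 20.0), n_contexts=5, seed=0)
print(contexts)
```

The result has the same shape as the sampler output above: a dictionary from context ID to a partial context containing only the sampled feature.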


To use the contexts during training, we simply pass them to the environment:

In [5]:
env = CARLBraxAnt(contexts=contexts)
print(f"Full context set: {env.contexts}")
env.reset()
print(f"Current context ID: {env.context_id}")
print(f"Current context: {env.context}")



Full context set: {0: {'gravity': 11.564052345967665, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}, 1: {'gravity': 10.200157208367225, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}, 2: {'gravity': 10.77873798410574, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}, 3: {'gravity': 12.04089319920146, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}, 4: {'gravity': 11.667557990149968, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}}
Current context ID: 0
Current context: {'gravity': 11.564052345967665}
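The full context set shows that each sampled context, which only specified 'gravity', has been completed with the default values of all other features. A minimal sketch of that completion step follows; `fill_defaults` is a hypothetical helper, and the dictionary merge is an assumption about how this could be done, not CARL's code.

```python
def fill_defaults(
    contexts: dict[int, dict[str, float]],
    defaults: dict[str, float],
) -> dict[int, dict[str, float]]:
    """Hypothetical sketch: complete each partial context with default values."""
    # In {**defaults, **ctx}, keys from ctx win, so sampled values override defaults.
    return {cid: {**defaults, **ctx} for cid, ctx in contexts.items()}


defaults = {"gravity": -9.8, "friction": 1.0, "elasticity": 0.0,
            "ang_damping": -0.05, "mass_torso": 10.0, "viscosity": 0.0}
partial = {0: {"gravity": 11.564052345967665}, 1: {"gravity": 10.200157208367225}}
full = fill_defaults(partial, defaults)
print(full[0])
# {'gravity': 11.564052345967665, 'friction': 1.0, 'elasticity': 0.0, 'ang_damping': -0.05, 'mass_torso': 10.0, 'viscosity': 0.0}
```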


If we don't specify a context selector, a reset will automatically switch the context to the next one in our context set.

In [6]:
env.reset()
print(f"Current context ID: {env.context_id}")
print(f"Current context: {env.context}")

Current context ID: 1
Current context: {'gravity': 10.200157208367225}
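Switching to the next context on every reset is a round-robin policy. Below is a dependency-free sketch of such a selector, assuming it steps through the context IDs in order and wraps around at the end. `RoundRobinSketch` is a hypothetical name; CARL's own selectors may differ in details such as which ID they start from.

```python
from itertools import cycle


class RoundRobinSketch:
    """Hypothetical selector: returns context IDs in order, wrapping around."""

    def __init__(self, contexts: dict[int, dict[str, float]]) -> None:
        self.contexts = contexts
        self._ids = cycle(list(contexts))  # snapshot of the IDs, repeated forever

    def select(self) -> tuple[int, dict[str, float]]:
        context_id = next(self._ids)
        return context_id, self.contexts[context_id]


selector = RoundRobinSketch({i: {"gravity": 9.8 + i} for i in range(5)})
print([selector.select()[0] for _ in range(7)])  # [0, 1, 2, 3, 4, 0, 1]
```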


We can also manually set the context by using its ID:

In [7]:
env.context_id = 4
print(f"Current context ID: {env.context_id}")
print(f"Current context: {env.context}")

Current context ID: 4
Current context: {'gravity': 11.667557990149968}
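Setting `context_id` directly amounts to looking the new ID up in the context set. One way to model this is a property setter that validates the ID before switching. `ContextHolderSketch` below is hypothetical, and rejecting unknown IDs with a `KeyError` is a design choice of the sketch, not a statement about what CARL does.

```python
class ContextHolderSketch:
    """Hypothetical object whose current context follows its context_id."""

    def __init__(self, contexts: dict[int, dict[str, float]]) -> None:
        self.contexts = contexts
        self._context_id = next(iter(contexts))  # start at the first ID

    @property
    def context_id(self) -> int:
        return self._context_id

    @context_id.setter
    def context_id(self, new_id: int) -> None:
        # Reject IDs that are not in the context set instead of failing later.
        if new_id not in self.contexts:
            raise KeyError(f"Unknown context ID: {new_id}")
        self._context_id = new_id

    @property
    def context(self) -> dict[str, float]:
        return self.contexts[self._context_id]


holder = ContextHolderSketch({i: {"gravity": 9.8 + i} for i in range(5)})
holder.context_id = 4
print(holder.context_id, holder.context)
```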


Apart from the context, CARLBraxAnt is designed to work like any other Gymnasium environment, so your training loops don't have to change.

In [8]:
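done = False
while not done:
    # Sample a random action from the environment's action space
    action = env.action_space.sample()
    # Gymnasium step API: observation, reward, terminated, truncated, info
    state, reward, terminated, truncated, info = env.step(action)
    # The episode ends when it terminates or is truncated (e.g. by a time limit)
    done = terminated or truncated
    env.render()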

ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * most axes (56 of them) had size 1, e.g. axis 0 of argument state.pipeline_state.q of type float32[1,15];
  * one axis had size 8: axis 0 of argument action of type float32[8]
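The loop in the cell above follows the standard Gymnasium pattern: `step` returns `(observation, reward, terminated, truncated, info)`, and an episode ends as soon as either flag is true. The sketch below runs that pattern over several episodes against a hypothetical stub environment (`StubContextEnv`, not part of CARL or Gymnasium) whose `reset` advances to the next context, mirroring the behavior shown earlier. The episode length, observations, and rewards are placeholders.

```python
import random


class StubContextEnv:
    """Hypothetical stand-in with Gymnasium-style reset/step signatures."""

    def __init__(self, contexts: dict[int, dict[str, float]], max_steps: int = 3) -> None:
        self.contexts = contexts
        self.ids = list(contexts)
        self.context_id = self.ids[-1]  # so the first reset selects the first ID
        self.max_steps = max_steps
        self.t = 0

    def reset(self, *, seed=None, options=None):
        # seed/options are accepted for API compatibility but unused here.
        next_index = (self.ids.index(self.context_id) + 1) % len(self.ids)
        self.context_id = self.ids[next_index]  # round-robin context switch
        self.t = 0
        return 0.0, {"context_id": self.context_id}

    def step(self, action):
        self.t += 1
        obs, reward = float(self.t), 1.0
        terminated = False  # placeholder: the stub never reaches a terminal state
        truncated = self.t >= self.max_steps  # time limit ends the episode
        return obs, reward, terminated, truncated, {"context_id": self.context_id}


stub_env = StubContextEnv({i: {"gravity": 9.8 + i} for i in range(3)})
rng = random.Random(0)
visited = []
for episode in range(4):
    obs, info = stub_env.reset()
    visited.append(info["context_id"])
    done = False
    while not done:
        action = rng.uniform(-1.0, 1.0)  # random action, like action_space.sample()
        obs, reward, terminated, truncated, info = stub_env.step(action)
        done = terminated or truncated
print(visited)  # [0, 1, 2, 0]
```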