# Lotus Demonstration Notebook

Install Lotus and its requirements.

In [None]:
# Install Lotus straight from its GitHub repository (-q keeps pip's output quiet)
!pip install -q git+https://github.com/auxeno/lotus

Imports

In [None]:
import jax
import jax.numpy as jnp
import time

### Train a Single Agent

Easily train a single agent on MinAtar's Breakout environment.

In [None]:
from lotus import PPO

# A single integer seed drives this training run
seed = 0

# Build a PPO agent for MinAtar Breakout, relying on the library defaults
agent = PPO.create(env='Breakout-MinAtar')

# Train the agent with the seed and keep the resulting trained agent
trained_agent = PPO.train(agent, seed)

### Train Agents on Multiple Seeds

Train 100 agents in parallel, one per random seed, and measure the training throughput.

In [None]:
from lotus import PQN

# Create the agent and one seed per parallel training run
agent = PQN.create(
    env='CartPole-v1',
    hidden_dims=(16, 16),
    verbose=False
    )
num_seeds = 100
seeds = jnp.arange(num_seeds)

# NOTE(review): assumes every run trains for 1M environment steps (presumably
# the library's default training length) — confirm against PQN.create's
# defaults, otherwise the FPS figure below is wrong.
steps_per_seed = 1_000_000

# Start timing (perf_counter is monotonic and intended for measuring intervals)
start = time.perf_counter()

# Vectorised training: share the single agent (None), map over the seeds (axis 0)
train_fn = jax.vmap(agent.train, in_axes=(None, 0))
trained_agents = train_fn(agent, seeds)

# JAX dispatches work asynchronously: the call above can return before the
# computation has finished. Block on the results so the timing covers the
# actual training. Note the timed region also includes tracing (and any JIT
# compilation) of the training function on this first call.
trained_agents = jax.block_until_ready(trained_agents)

# End timing
end = time.perf_counter()
elapsed = end - start
print(f'Time taken: {elapsed:.1f} seconds')
print(f'FPS: {(num_seeds * steps_per_seed / elapsed):,.1f}')

### Train Agents with Multiple Configurations

Train 10 agents in parallel, each with a different learning rate (linearly spaced from 1e-4 to 1e-3).

In [None]:
def make_agent(lr: float):
    """Create a PQN agent for CartPole-v1 with learning rate ``lr``.

    This function is mapped with ``jax.vmap`` over an array of learning rates
    below, so ``lr`` arrives as a traced JAX scalar rather than a Python float.
    """
    return PQN.create(
        env='CartPole-v1',
        learning_rate=lr
    )

# One agent per learning rate: 10 values spaced linearly from 1e-4 to 1e-3
learning_rates = jnp.linspace(1e-4, 1e-3, 10)
agents = jax.vmap(make_agent)(learning_rates)
seed = 0

# Vectorised training: map over the batched agents (axis 0), share one seed
train_fn = jax.vmap(agents.train, in_axes=(0, None))

# The result holds one trained agent per learning rate (batched along axis 0)
trained_agents = train_fn(agents, seed)

### Train Agents with Multiple Seeds and Configurations

Train 10 learning rates × 100 seeds = 1,000 agents in parallel, by nesting two `jax.vmap` calls.

In [None]:
def make_agent(lr: float):
    """Create a PQN agent for CartPole-v1 with learning rate ``lr``.

    This function is mapped with ``jax.vmap`` over an array of learning rates
    below, so ``lr`` arrives as a traced JAX scalar rather than a Python float.
    """
    return PQN.create(
        env='CartPole-v1',
        learning_rate=lr
    )

num_lrs = 10
num_seeds = 100

# Create one agent per learning rate and one seed per training run
learning_rates = jnp.linspace(1e-4, 1e-3, num_lrs)
agents = jax.vmap(make_agent)(learning_rates)
seeds = jnp.arange(num_seeds)

# Vectorised training:
#   inner vmap -> maps over the batched agents (learning rates), shares one seed
#   outer vmap -> shares the batched agents, maps over the seeds
train_fn = jax.vmap(agents.train, in_axes=(0, None))
train_fn = jax.vmap(train_fn, in_axes=(None, 0))

# With vmap's default out_axes=0, every leaf of the result has leading
# dimensions (num_seeds, num_lrs): seeds on axis 0, learning rates on axis 1.
trained_agents = train_fn(agents, seeds)