# 🪷 Lotus Demonstration Notebook

To get started with Lotus in Google Colab, install with `pip` and ensure that your runtime has access to a hardware accelerator (GPU or TPU).

In [None]:
# Install Lotus directly from its GitHub repository (-q keeps pip's output quiet)
!pip install -q git+https://github.com/auxeno/lotus

In [None]:
# Standard library
import time

# JAX
import jax
import jax.numpy as jnp

# Show which devices JAX can see; on an accelerated Colab runtime a GPU/TPU is expected
print('JAX device:', jax.devices())

### Train a Single Agent

Train a single PPO agent on the MinAtar Breakout environment.

In [None]:
from lotus import PPO

# Build a PPO agent for MinAtar Breakout (only the environment is specified),
# then train it from a fixed random seed.
seed = 0
agent = PPO.create(env='Breakout-MinAtar')

trained_agent = PPO.train(agent, seed)

In [None]:
from lotus.plotting import plot_results

# Episodic-reward curve for the Breakout run, keyed by algorithm name for the legend
breakout_logs = {'PPO': trained_agent['logs']}
plot_results(breakout_logs, title='Breakout Episodic Reward')

### Train Agents on Multiple Seeds

Train 100 PQN agents on CartPole in parallel, each with a different random seed.

In [None]:
from lotus import PQN

# One PQN configuration, shared by every seed
agent = PQN.create(
    env='CartPole-v1',
    hidden_dims=(32, 32),
    verbose=False
)
num_seeds = 100
seeds = jnp.arange(num_seeds)

# Environment steps per agent, used only for the FPS estimate below.
# NOTE(review): assumed to equal PQN's default training budget — confirm.
steps_per_agent = 1_000_000

# Vectorised training: share `agent` (in_axes=None), map over `seeds` (axis 0)
train_fn = jax.vmap(agent.train, in_axes=(None, 0))

# Start timing (perf_counter is monotonic and high-resolution)
start = time.perf_counter()
trained_agents = train_fn(agent, seeds)
# JAX dispatches work asynchronously; block until the results are actually
# computed so the measurement covers training, not just dispatch.
trained_agents = jax.block_until_ready(trained_agents)
end = time.perf_counter()

elapsed = end - start
print(f'Time taken to train {num_seeds} agents: {elapsed:.1f} seconds')
print(f'FPS: {(num_seeds * steps_per_agent / elapsed):,.1f}')

### Train Agents with Multiple Configurations

Train PQN agents on CartPole in parallel, one for each TD(λ) value.

In [None]:
# λ values to sweep. Kept as Python floats so the plot labels print exactly
# (e.g. 'λ=0.2'); converting float32 jnp scalars back with float() would give
# labels like 'λ=0.20000000298023224'.
td_lambdas = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

def create_agent(td_lambda: float):
    """Create a PQN CartPole agent with the given TD(λ) value."""
    return PQN.create(
        env='CartPole-v1',
        td_lambda=td_lambda,
        hidden_dims=(32, 32),
        verbose=False
    )

# Create one agent per λ value (batched by vmap) and a single shared seed
agents = jax.vmap(create_agent)(jnp.array(td_lambdas))
seed = 0

# Vectorised training: map over the agents (axis 0), share the seed
train_fn = jax.vmap(agents.train, in_axes=(0, None))
trained_agents = train_fn(agents, seed)

# Slice each agent's logs out of the batched results, labelled by its λ
results = {
    f'λ={td_lambda}': jax.tree.map(lambda x: x[i], trained_agents['logs'])
    for i, td_lambda in enumerate(td_lambdas)
}

# Plot results
plot_results(
    results,
    title='PQN CartPole Episodic Reward Varying λ',
    colors='gradient'
)

### Train Agents with Multiple Seeds and Configurations

Train PQN agents on CartPole in parallel for every combination of TD(λ) value and seed (6 λ values × 100 seeds).

In [None]:
# λ values to sweep and number of seeds trained for each λ
td_lambdas = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
num_seeds = 100

def create_agent(td_lambda: float):
    """Create a PQN CartPole agent with the given TD(λ) value."""
    # The parameter must be named td_lambda: the body forwards it to PQN.create,
    # and no global of that name exists (the old `lr` name raised NameError).
    return PQN.create(
        env='CartPole-v1',
        td_lambda=td_lambda,
        hidden_dims=(32, 32),
        verbose=False
    )

# Create agents (one per λ) and seeds
agents = jax.vmap(create_agent)(jnp.array(td_lambdas))
seeds = jnp.arange(num_seeds)

# Vectorised training. The inner vmap maps over agents (axis 0) with one shared
# seed; the outer vmap maps that over seeds with the agents shared. With the
# default out_axes=0, result leaves are stacked as (seed, λ, ...).
train_fn = jax.vmap(agents.train, in_axes=(0, None))
train_fn = jax.vmap(train_fn, in_axes=(None, 0))
trained_agent = train_fn(agents, seeds)