This is a fresh start on the experiments with Stackelberg actor-critic. Only write down what is sure to be right!

In [None]:
from environments import ENV_NAMES

from models.critic import Critic, PixelCritic
from models.discrete_actor import DiscreteActor, DiscretePixelActor
import models.params
from models.params import DynParam

import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import functools
import gymnax
import jax
from jax import grad, jacfwd, jacrev
import jax.numpy as jnp
import optax
from jax import flatten_util
from jax.scipy.sparse.linalg import cg
import json


from algos.core.env_config import Hyperparams
# from algos.core.env_config import ENV_CONFIG
from algos.StackelbergRL.understanding_gradients import cosine_similarity, project_B_onto_A

Run one rollout and return the transitions and the last observation, which are needed for the advantage calculation.

In [None]:
@flax.struct.dataclass
class Transition:
    """One environment step as recorded by ``run_rollout``.

    Instances are built positionally as
    ``Transition(observation, action, reward, done)``, so the field order
    below must not change.
    """
    # Observation the action was sampled from (i.e. before the step).
    observation: jnp.ndarray
    # Action sampled from the actor's distribution for that observation.
    action: jnp.ndarray
    # Reward returned by ``env.step`` for this action.
    reward: jnp.ndarray
    # ``done`` flag returned by ``env.step`` for this action.
    done: jnp.ndarray

def run_rollout(env, env_params, length, actor_state, rng_key):
    """Roll the actor's policy forward for ``length`` environment steps.

    The environment is reset once at the start. Every step samples an action
    from the distribution returned by ``actor_state.apply_fn`` and feeds it to
    ``env.step``.

    Args:
        env: Environment exposing ``reset(key, params)`` and
            ``step(key, state, action, params)``.
        env_params: Parameters forwarded to both environment calls.
        length: Number of steps to scan over.
        actor_state: Provides ``apply_fn`` and ``params`` for the policy.
        rng_key: PRNG key; split once for the reset and again on every step.

    Returns:
        ``(transitions, last_observation)``: a ``Transition`` whose fields are
        stacked over the ``length`` steps, and the observation produced by
        the final step.
    """
    loop_key, reset_key = jax.random.split(rng_key, 2)
    first_observation, first_env_state = env.reset(reset_key, env_params)

    def _advance(carry, _):
        """Sample one action, step the environment, and record the transition."""
        policy_state, sim_state, current_obs, key = carry
        key, sample_key, env_key = jax.random.split(key, 3)

        # Query the policy and draw an action for the current observation.
        policy = policy_state.apply_fn(policy_state.params, current_obs)
        chosen_action = policy.sample(seed=sample_key)

        # The fifth value returned by env.step is not used.
        following_obs, following_sim_state, reward, done, _unused = env.step(
            env_key, sim_state, chosen_action, env_params,
        )
        record = Transition(current_obs, chosen_action, reward, done)
        return ((policy_state, following_sim_state, following_obs, key), record)

    final_carry, transitions = jax.lax.scan(
        _advance,
        init=(actor_state, first_env_state, first_observation, loop_key),
        length=length,
    )
    _, _, last_observation, _ = final_carry
    return (transitions, last_observation)

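Calculate the advantages using GAE and the targets using Monte Carlo, as indicated by the Ratliff paper. Note: the code below currently uses `advantages + values` as the targets rather than Monte Carlo returns.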

In [None]:

@jax.jit
def calc_values(critic_state, transitions, last_observation, discount_rate, gae_lambda=0.97):
    """Compute normalized GAE advantages and critic regression targets.

    Args:
        critic_state: Provides ``apply_fn`` and ``params`` for the critic.
        transitions: Rollout data; ``observation``, ``reward`` and ``done``
            are read, each stacked along the leading (time) axis.
        last_observation: Observation following the final transition; its
            critic value bootstraps the backward recursion.
        discount_rate: Discount factor applied to the next value.
        gae_lambda: GAE smoothing coefficient.

    Returns:
        ``(advantages, targets)``. ``advantages`` are the GAE estimates
        standardized to zero mean and (approximately) unit standard
        deviation over the whole rollout. ``targets`` are the
        *un-normalized* GAE estimates plus the critic values.
    """
    critic_fn = critic_state.apply_fn
    critic_params = critic_state.params

    # Critic value of every visited observation, plus the bootstrap value.
    values = jax.vmap(critic_fn, in_axes=(None, 0))(critic_params, transitions.observation)
    bootstrap_value = critic_fn(critic_params, last_observation)

    def _backward_step(carry, step_inputs):
        """One step of the GAE recursion, run from the last transition to the first."""
        running_gae, value_ahead = carry
        done, value, reward = step_inputs
        # A finished episode cuts both the bootstrap and the accumulated GAE.
        td_error = reward + discount_rate * value_ahead * (1 - done) - value
        running_gae = td_error + discount_rate * gae_lambda * (1 - done) * running_gae
        return (running_gae, value), running_gae

    _, raw_advantages = jax.lax.scan(
        _backward_step,
        init=(jnp.zeros_like(bootstrap_value), bootstrap_value),
        xs=(transitions.done, values, transitions.reward),
        reverse=True,
        unroll=16,
    )

    # Targets are built from the raw advantages, before normalization.
    targets = raw_advantages + values

    normalized_advantages = (raw_advantages - raw_advantages.mean()) / (raw_advantages.std() + 1e-8)
    return (normalized_advantages, targets)

The update functions for both critic and actor

In [None]:

# follower-objective
def target_loss(params, transitions, targets, critic_state):
    """Mean squared error between the critic's predictions and ``targets``.

    Args:
        params: Critic parameters to evaluate (passed explicitly so the loss
            can be differentiated with respect to them).
        transitions: Rollout data; only ``observation`` is read.
        targets: Regression targets, one per observation.
        critic_state: Provides the critic's ``apply_fn``.

    Returns:
        Scalar mean of the squared differences ``targets - values``.
    """
    # NOTE(review): `targets` and the critic outputs are subtracted directly,
    # so their shapes must match; mismatched shapes would broadcast instead.
    batched_critic = jax.vmap(critic_state.apply_fn, in_axes=(None, 0))
    predictions = batched_critic(params, transitions.observation)
    return jnp.mean(jnp.square(targets - predictions))

def update_leaderactor(actor_state, critic_state, transitions, advantages, targets, vanilla=False, lambda_reg=0):
    """Update the actor (leader) with a policy gradient corrected by a critic-dependent term.

    Steps performed:
      1. ``grad_theta_J``: gradient of ``advantage_loss`` w.r.t. the actor params.
      2. ``grad_w_J``: gradient of ``target_loss`` w.r.t. the critic params.
      3. Approximately solve ``(H + lambda_reg * I) x = grad_w_J`` with conjugate
         gradient, where ``H`` is the Hessian of ``target_loss`` w.r.t. the
         (flattened) critic params.
      4. ``final_product``: derivative, along direction ``x`` in critic-parameter
         space, of the actor-parameter gradient of ``leader_f2_loss``.
      5. Apply ``grad_theta_J - final_product`` to the actor, or plain
         ``grad_theta_J`` when ``vanilla`` is true.

    Args:
        actor_state: Actor train state (``apply_fn``, ``params``, ``apply_gradients``).
        critic_state: Critic train state; only read, never updated here.
        transitions: Rollout data; ``observation`` and ``action`` are read.
        advantages: Advantage estimates, one per transition.
        targets: Critic regression targets, one per transition.
        vanilla: If true, the uncorrected gradient is applied. The correction
            term is still computed, because its norm and cosine similarity
            are returned either way.
        lambda_reg: Damping added to the Hessian-vector product.

    Returns:
        ``(actor_state, (hypergradient_norm, final_product_norm, cosine_similarity), actor_loss)``.
    """
    # Define the loss functions
    def advantage_loss(params, transitions, advantages, lambda_reg = 0.01):
        # Policy-gradient surrogate: negative mean of advantage-weighted log-probabilities.
        # NOTE(review): this inner `lambda_reg` is unused (the L2 term below is
        # commented out) and shadows the outer `lambda_reg` argument.
        action_dists = jax.vmap(actor_state.apply_fn, in_axes=(None, 0))(params, transitions.observation)
        log_probs = action_dists.log_prob(transitions.action)

        # l2_loss = lambda_reg * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(params))

        return -jnp.mean(advantages * log_probs)

    def leader_f2_loss(actor_params, critic_params, transitions, targets):
        # Couples actor and critic params; it is differentiated w.r.t. both below.
        # `advantages` is taken from the enclosing function's arguments.
        action_dists = jax.vmap(actor_state.apply_fn, in_axes=(None, 0))(actor_params, transitions.observation)
        log_probs = action_dists.log_prob(transitions.action)
        values = jax.vmap(critic_state.apply_fn, in_axes=(None, 0))(critic_params, transitions.observation)

        # l2_loss = lambda_reg * sum(jnp.sum(jnp.square(p)) for p in jax.tree_util.tree_leaves(critic_params))

        # NOTE(review): only the first transition's residual (index 0) enters
        # this term — confirm this is intended rather than a per-step product.
        return 2 * jnp.mean(log_probs * advantages) * (targets[0] - values[0])

    # Single Gradients
    vgj = jax.value_and_grad(advantage_loss)
    actor_loss, grad_theta_J = vgj(actor_state.params, transitions, advantages)
    grad_w_J = grad(target_loss, 0)(critic_state.params, transitions, targets, critic_state)

    def hvp(v):
        # Damped Hessian-vector product of `target_loss` w.r.t. the flattened
        # critic params: forward-mode derivative of the gradient along `v`,
        # plus `lambda_reg * v`.
        critic_params_flat, unravel_fn = jax.flatten_util.ravel_pytree(critic_state.params)
        def loss_grad_flat(p):
            return jax.flatten_util.ravel_pytree(
                jax.grad(target_loss, argnums=0)(unravel_fn(p), transitions, targets, critic_state)
            )[0]
        # NOTE: this local result shadows the enclosing function's name `hvp`.
        hvp = jax.jvp(loss_grad_flat, (critic_params_flat,), (v,))[1] + lambda_reg * v
        return hvp

    grad_w_J_flat, unflatten_fn = jax.flatten_util.ravel_pytree(grad_w_J)
    def cg_solve(v):
        # Matrix-free solve of `hvp(x) = v`, capped at 10 CG iterations, so the
        # result is in general only an approximate inverse-Hessian product.
        return jax.scipy.sparse.linalg.cg(hvp, v, maxiter=10, tol=1e-10)[0]
    inverse_hvp_flat = cg_solve(grad_w_J_flat)
    # Back to the critic-parameter pytree structure so it can serve as a JVP tangent.
    inverse_hvp = unflatten_fn(inverse_hvp_flat)

    # Gradient of `leader_f2_loss` w.r.t. the actor params (argument 0), as a
    # function of the critic params.
    # Original note (its notation is not defined in this file):
    # 6. Compute mixed gradient and its transpose: [∇²_θ,ν V_s(ν, θ*(ν))]^T
    def mixed_grad_fn(policy_params, critic_params):
        return jax.grad(leader_f2_loss)(policy_params, critic_params, transitions, targets)

    # Directional derivative of that actor gradient w.r.t. the critic params,
    # along `inverse_hvp`. The result has the actor-parameter structure.
    # Original note (its notation is not defined in this file):
    # 7. Compute the final product: [∇²_θ,ν V_s(ν, θ*(ν))]^T * [∇²_θ V_s(ν, θ*(ν))]^(-1) * ∇_θ L_pref(ν)
    # We use JVP to compute this product efficiently
    _, final_product = jax.jvp(
        lambda p: mixed_grad_fn(actor_state.params, p),
        (critic_state.params,),
        (inverse_hvp,)
    )

    # vanilla=True

    # final_product = clip_grad_norm(final_product, 0.2*optax.global_norm(grad_theta_J))
    # Corrected gradient: plain policy gradient minus the correction term.
    hypergradient = jax.tree_util.tree_map(lambda x, y: x - y, grad_theta_J, final_product)
    # `lax.cond` is used for the selection so that `vanilla` may be a traced value.
    hypergradient = jax.lax.cond(vanilla, lambda: grad_theta_J, lambda: hypergradient)
    actor_state = actor_state.apply_gradients(grads=hypergradient)

    # Diagnostics returned to the caller: norm of the applied gradient, norm of
    # the correction term, and the cosine similarity between the correction
    # term and the plain policy gradient.
    hypergradient_norms = optax.global_norm(hypergradient)
    final_product_norms = optax.global_norm(final_product)
    co_sim = cosine_similarity(final_product, grad_theta_J)

    return (actor_state, (hypergradient_norms, final_product_norms, co_sim), actor_loss)

def clip_grad_norm(grad, max_norm):
    """Rescale a gradient pytree so its global L2 norm does not exceed ``max_norm``.

    Gradients whose global norm is already within ``max_norm`` get a scale
    factor of 1; larger ones are scaled by ``max_norm / norm``.

    Args:
        grad: Pytree of gradient arrays.
        max_norm: Maximum allowed global norm.

    Returns:
        A pytree with the same structure as ``grad``.
    """
    norm = optax.global_norm(grad)
    # The factor is capped at 1 so in-range gradients are never rescaled.
    # (Capping at ``max_norm`` instead would shrink every in-range gradient
    # whenever ``max_norm < 1`` and amplify tiny ones whenever ``max_norm > 1``.)
    # The epsilon guards against division by zero for an all-zero gradient.
    factor = jnp.minimum(1.0, max_norm / (norm + 1e-6))
    # `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the supported
    # spelling and is what the rest of this file already uses.
    return jax.tree_util.tree_map(lambda x: x * factor, grad)

def update_critic(critic_state, transitions, targets):
    """Take one gradient step on the critic's ``target_loss``.

    Args:
        critic_state: Critic train state whose ``params`` are differentiated
            and then updated through ``apply_gradients``.
        transitions: Rollout data forwarded to ``target_loss``.
        targets: Regression targets forwarded to ``target_loss``.

    Returns:
        ``(critic_state, loss)``: the updated train state and the loss value
        evaluated at the parameters *before* the update.
    """
    loss, grads = jax.value_and_grad(target_loss)(
        critic_state.params, transitions, targets, critic_state,
    )
    updated_state = critic_state.apply_gradients(grads=grads)
    return (updated_state, loss)

The functions that run a single update and a batch of updates

In [None]:
def calc_episode_rewards(transitions):
    """Running per-episode reward sum at every step of a rollout.

    The sum restarts after each step whose ``done`` flag is set, so the entry
    at a ``done`` step holds the total reward accumulated since the previous
    reset (or since the start of the rollout).

    Args:
        transitions: Rollout data scanned along its leading axis; ``reward``
            and ``done`` are read.

    Returns:
        Array with one running total per step.
    """
    def _accumulate(running_total, step):
        """Add this step's reward and reset the carry if the episode ended."""
        total_so_far = running_total + step.reward
        carried_forward = total_so_far * (1 - step.done)
        return (carried_forward, total_so_far)

    _, running_totals = jax.lax.scan(
        _accumulate,
        init=jnp.float32(0),
        xs=transitions,
    )
    return running_totals

def run_update(env, env_params, actor_state, critic_state, rng_key, hyperparams, vanilla=False, lam=0):
    """Run one training iteration: rollout, one actor update, then nested critic updates.

    The actor is updated first, using the critic as it was when the rollout
    was evaluated. The critic then takes ``hyperparams["nested_updates"]``
    gradient steps on the same rollout and the same targets.

    Args:
        env, env_params: Environment and its parameters, passed to ``run_rollout``.
        actor_state, critic_state: Current train states.
        rng_key: PRNG key; split to obtain the rollout key.
        hyperparams: Read for ``rollout_len``, ``discount_rate`` and
            ``nested_updates``.
        vanilla: Forwarded to ``update_leaderactor``.
        lam: Forwarded to ``update_leaderactor`` as ``lambda_reg``.

    Returns:
        ``(actor_state, critic_state, metrics, rng_key)`` with
        ``metrics = (average_reward, actor_info, actor_loss, critic_loss)``.
    """
    rng_key, rollout_key = jax.random.split(rng_key, 2)

    # Collect data and evaluate it with the current critic.
    transitions, last_observation = run_rollout(
        env, env_params, hyperparams.rollout_len, actor_state, rollout_key,
    )
    advantages, targets = calc_values(
        critic_state, transitions, last_observation, hyperparams.discount_rate,
    )

    # Leader (actor) step.
    actor_state, actor_info, actor_loss = update_leaderactor(
        actor_state, critic_state, transitions, advantages, targets, vanilla, lam,
    )

    # Follower (critic) steps; only the loss of the last step is reported.
    critic_loss = 0
    for _ in range(hyperparams["nested_updates"]):
        critic_state, critic_loss = update_critic(critic_state, transitions, targets)

    # Mean of the episode totals, taken over the steps where an episode ended.
    episode_totals = calc_episode_rewards(transitions)
    episode_ended = transitions.done
    average_reward = jnp.sum(episode_totals * episode_ended) / jnp.sum(episode_ended)

    metrics = (average_reward, actor_info, actor_loss, critic_loss)
    return (actor_state, critic_state, metrics, rng_key)

@functools.partial(jax.jit, static_argnums=0)
def run_batch(env, env_params, actor_state, critic_state, rng_key, hyperparams, vanilla=False, lam=0):
    """Run ``hyperparams.batch_count`` consecutive ``run_update`` iterations under one jit.

    ``env`` is the only static argument of the jitted function.

    Args:
        env, env_params: Environment and its parameters.
        actor_state, critic_state: Train states carried through the scan.
        rng_key: PRNG key carried through the scan.
        hyperparams: Read for ``batch_count`` and forwarded to ``run_update``.
        vanilla, lam: Forwarded unchanged to ``run_update``.

    Returns:
        ``(actor_state, critic_state, batch_metrics, rng_key)`` where
        ``batch_metrics`` holds the per-iteration metrics of ``run_update``
        stacked along a leading axis.
    """
    def _scan_body(carry, _):
        """One ``run_update`` call with the train states threaded through the carry."""
        actor, critic, key = carry
        actor, critic, metrics, key = run_update(
            env, env_params, actor, critic, key, hyperparams, vanilla, lam,
        )
        return ((actor, critic, key), metrics)

    final_carry, batch_metrics = jax.lax.scan(
        _scan_body,
        init=(actor_state, critic_state, rng_key),
        length=hyperparams.batch_count,
    )
    final_actor, final_critic, final_key = final_carry
    return (final_actor, final_critic, batch_metrics, final_key)


Main function

In [None]:
# Per-environment experiment configuration, keyed by the `env_key` passed to `train`.
# Each entry provides the actor/critic model classes, the specs that
# `models.params.init` turns into their constructor arguments, and the
# hyperparameters used by the training loop.
ENV_CONFIG = {
    "cartpole": {
        "actor_model": DiscreteActor,
        # presumably (hidden units, hidden layer count), i.e. two hidden layers
        # with 64 units each — TODO confirm against models.params.init
        "actor_params": [(64, 2), DynParam.ActionCount],
        "critic_model": Critic,
        "critic_params": [(64, 2)],
        "hyperparams": Hyperparams(
            num_updates=500,  # total updates; `train` runs int(num_updates / batch_count) batches
            batch_count=25,  # updates scanned inside one jitted `run_batch` call
            rollout_len=2000,  # environment steps collected per update
            discount_rate=0.99,
            actor_learning_rate=0.0025,
            nested_updates=25,  # critic gradient steps per actor update
            critic_learning_rate=0.008,
            adam_eps=1e-5,  # `eps` passed to both Adam optimizers
        ),
    },
}

def train(env_key, seed, logger, verbose = False, metrics=None, vanilla=False, save_charts=False, description=None, lam=0):
    """Train the actor and critic on one environment and log per-batch metrics.

    Args:
        env_key: Key into ``ENV_CONFIG`` and ``ENV_NAMES``.
        seed: Integer seed for the PRNG key.
        logger: Receives ``set_interval`` once and ``log_metrics`` after every batch.
        verbose: If true, print a summary line after every batch.
        metrics: Unused in this function.
        vanilla: Forwarded to ``run_batch``.
        save_charts: Unused in this function.
        description: Unused in this function.
        lam: Forwarded to ``run_batch``.

    Returns:
        None. Results are reported through ``logger`` (and stdout when
        ``verbose``).
    """
    config = ENV_CONFIG[env_key]
    hyperparams = config["hyperparams"]

    # One key for the training loop, one each for initializing actor and critic.
    rng_key, actor_key, critic_key = jax.random.split(jax.random.key(seed), 3)
    env, env_params = gymnax.make(ENV_NAMES[env_key])
    # Placeholder observation, used only to initialize the model parameters.
    dummy_observation = jnp.empty(env.observation_space(env_params).shape)

    def _make_train_state(model_cls, param_spec, init_key, learning_rate):
        """Build a model from its spec and wrap it in an Adam-optimized TrainState."""
        model = model_cls(*models.params.init(env, env_params, param_spec))
        return TrainState.create(
            apply_fn=jax.jit(model.apply),
            params=model.init(init_key, dummy_observation),
            tx=optax.adam(learning_rate, eps=hyperparams.adam_eps),
        )

    actor_state = _make_train_state(
        config["actor_model"],
        config["actor_params"],
        actor_key,
        hyperparams["actor_learning_rate"],
    )
    critic_state = _make_train_state(
        config["critic_model"],
        config["critic_params"],
        critic_key,
        hyperparams["critic_learning_rate"],
    )

    logger.set_interval(hyperparams.rollout_len)

    # Each `run_batch` call performs `batch_count` updates.
    num_batches = int(hyperparams.num_updates / hyperparams.batch_count)
    for batch_index in range(num_batches):
        actor_state, critic_state, batch_metrics, rng_key = run_batch(
            env, env_params, actor_state, critic_state, rng_key, hyperparams, vanilla, lam,
        )
        reward, actor_info, actor_loss, critic_loss = batch_metrics
        hypergradient_norms, final_product_norms, cosine_similarities = actor_info

        logger.log_metrics({
            "reward": reward,
            "actor_loss": actor_loss,
            "critic_loss": critic_loss,
            "hypergradient": hypergradient_norms,
            "final_product": final_product_norms,
            "cosine_similarities": cosine_similarities,
        })

        if not verbose:
            continue
        updates_done = (batch_index + 1) * hyperparams.batch_count
        print(f"[Update {updates_done}]: Average reward {reward[-1]}, Hypergradient Norm {jnp.mean(hypergradient_norms)}, finalProduct Norm{jnp.mean(final_product_norms)}, cosine similarity {cosine_similarities[-1]}")

Run one experiment!

In [None]:
import argparse
import jax
import os

from algos.StackelbergRL import ratliff, stac_critic, stac_Actor_newGrad, stac_Critic
from algos.baselines import discrete_actor_critic, discrete_ppo, discrete_reinforce, actor_critic_NoNesting
from bilevel_actor_critic import unrolling_actor_redo, lambda_regret

from loggers.chart_logger import ChartLogger
from algos.core.config import ALGO_CONFIG

# ---------------------------------------------------------------------------
# Experiment settings.
#
# This cell was adapted from a command-line script. It used to read `args.*`,
# `algos[...]` and `vanilla`, none of which are defined anywhere in this
# notebook, so it stopped with a NameError before any training started; the
# trailing `train("cartpole", 0, )` call also omitted the required `logger`
# argument. The settings are now plain variables and the `train` function
# defined above is called directly, exactly once.
# ---------------------------------------------------------------------------
algo = "cartpole"  # key looked up in ALGO_CONFIG; NOTE(review): confirm this key exists there
task = "cartpole"  # key into ENV_CONFIG / ENV_NAMES
seed = 0
vanilla = False  # True applies the plain policy gradient instead of the corrected one
lam = 0  # damping forwarded to the actor update
plot = True  # write one chart per metric after training
description_override = ""  # non-empty value replaces the ALGO_CONFIG description

metrics = [
    "reward",
    "actor_loss",
    "critic_loss",
    "hypergradient",
    "final_product",
    "cosine_similarities",
]
logger = ChartLogger(metrics)

config = ALGO_CONFIG[algo]
description = config["description"]
if description_override:
    description = description_override

# Register a title and an output path for every metric chart.
folder_path = f"charts/{algo}/{task}_{description}"
for metric in metrics:
    file_path = f"{folder_path}/{task}_{metric}.png"
    logger.set_info(
        metric,
        f"[{task}] SA2C {metric}",
        file_path,
    )

# Ensure the data directory for the task exists
os.makedirs(f"data/{task}", exist_ok=True)
train(task, seed, logger, verbose=True, vanilla=vanilla, lam=lam)
logger.log_to_csv(f"data/{task}/{algo}_{description}.csv")

# Plot metrics
if plot:
    os.makedirs(folder_path, exist_ok=True)
    for metric in metrics:
        logger.plot_metric(metric)