In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'


from poclaps.train.ckpt_cb import load_ckpt
from poclaps.train.ppo import make_train as make_ppo_train
from poclaps.train.ppo import FlattenObservationWrapper, LogWrapper
from poclaps import environments

from pathlib import Path
import yaml


run_dir = Path('outputs/2024-06-14/15-39-35/')


def load_config(run_dir):
    with open(f'{run_dir}/.hydra/config.yaml') as f:
        config = yaml.safe_load(f)
        config['output_dir'] = run_dir
    return config

config = load_config(run_dir)

init_state, train_fn = make_ppo_train(config)

print(config)

def load_pretrained_policy(run_dir, config, ckpt_step=195):
    init_state, _ = make_ppo_train(config)
    ckpt = load_ckpt(run_dir / 'checkpoints', ckpt_step, init_state)
    train_state, *_ = ckpt

    def pretrained_policy(obs):
        return train_state.apply_fn(train_state.vars, obs)

    return pretrained_policy

pretrained_policy = load_pretrained_policy(run_dir, config)
print('Loaded policy checkpoint.')

env, env_params = environments.make(config["env_name"],
                                    **config.get('env_kwargs', {}))
env = FlattenObservationWrapper(env)
env = LogWrapper(env)

I0000 00:00:1719770809.126435 2642041 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


{'algorithm': 'PPO', 'learning_rate': 0.00025, 'num_envs': 4, 'num_steps': 128, 'total_timesteps': 100000.0, 'update_epochs': 4, 'num_minibatches': 4, 'gamma': 0.99, 'gae_lambda': 0.95, 'clip_eps': 0.2, 'ent_coef': 0.01, 'vf_coef': 0.5, 'max_grad_norm': 0.5, 'activation': 'tanh', 'anneal_lr': True, 'seed': 0, 'env_name': 'SimpleGridWorld-v0', 'env_kwargs': {'grid_size': 5, 'max_steps_in_episode': 20}, 'wandb_entity': 'drcope', 'wandb_project': 'ppo-gridworld-example', 'wandb_mode': 'online', 'output_dir': PosixPath('outputs/2024-06-14/15-39-35'), 'num_updates': 195.0, 'minibatch_size': 128}
Loaded policy checkpoint.


In [2]:
from poclaps.simple_gridworld_game import (
    EnvState as SimpleGridWorldEnvState,
    Envvars as SimpleGridWorldEnvParams,
)
import jax
import jax.numpy as jnp
import numpy as np
from flax import struct
from chex import Array
from typing import NamedTuple


class Transition(NamedTuple):
    env_state: struct.PyTreeNode
    done: Array
    action: Array
    message: Array
    reward: Array
    log_prob: Array
    obs: Array
    info: dict
    episode_id: int


class SimpleGridWorldCommPolicy:
    """
    """

    def __init__(self, seed: int, env_params: SimpleGridWorldEnvParams):
        self.seed = seed
        self.env_params = env_params
        self.n_msgs = env_params.grid_size * env_params.grid_size
        grid_indices = list(range(self.n_msgs))
        np.random.seed(seed)
        np.random.shuffle(grid_indices)
        self.msg_map = dict(enumerate(grid_indices))
        self.mapping = jnp.array(list(self.msg_map.values()))

    def get_msg(self, goal_pos: jnp.array) -> int:
        pos_idx = goal_pos[0] * self.env_params.grid_size + goal_pos[1]
        return self.mapping[pos_idx]

comm_policy = SimpleGridWorldCommPolicy(0, env_params)


def rollout_with_msgs(env, policy, comm_policy, steps, n_envs=4, rng=None, rollout_state=None):

    @jax.jit
    def _env_step(rollout_state, _):
        env_state, last_obs, rng, ep_ids = rollout_state

        # SELECT ACTION
        rng, _rng = jax.random.split(rng)
        pi, _ = policy(last_obs)
        action = pi.sample(seed=_rng)
        log_prob = pi.log_prob(action)

        # STEP ENV
        rng, _rng = jax.random.split(rng)
        rng_step = jax.random.split(_rng, n_envs)
        obsv, env_state, reward, done, info = jax.vmap(
            env.step, in_axes=(0, 0, 0, None)
        )(rng_step, env_state, action, env_params)
        ep_ids = jnp.where(done, ep_ids + n_envs, ep_ids)

        msg = jax.lax.map(
            lambda g: comm_policy.get_msg(g),
            env_state.env_state.goal_pos
        )

        transition = Transition(
            env_state, done, action, msg, reward, log_prob, last_obs, info, ep_ids
        )
        rollout_state = (env_state, obsv, rng, ep_ids)
        return rollout_state, transition

    if rollout_state is None:
        if rng is None:
            rng = jax.random.PRNGKey(0)
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, n_envs)
        obsv, env_state = jax.vmap(env.reset,
                                   in_axes=(0, None))(reset_rng, env_params)
        ep_ids = jnp.arange(n_envs)
        rollout_state = (env_state, obsv, rng, ep_ids)

    rollout_state, traj_batch = jax.lax.scan(
        _env_step, rollout_state, None, steps
    )

    metrics = {}

    metrics['mean_reward'] = (
        (traj_batch.info["returned_episode_returns"] * traj_batch.info["returned_episode"]).sum()
        / traj_batch.info["returned_episode"].sum()
    )

    metrics['mean_episode_len'] = (
        (traj_batch.info["returned_episode_lengths"] * traj_batch.info["returned_episode"]).sum()
        / traj_batch.info["returned_episode"].sum()
    )

    metrics['n_episodes'] = traj_batch.info["returned_episode"].sum()

    return rollout_state, traj_batch, metrics

In [3]:
import flax.linen as nn
import jax
import functools


class ScannedRNN(nn.Module):
    hidden_size: int = 128

    @functools.partial(
        nn.scan,
        variable_broadcast="vars",
        in_axes=0,
        out_axes=0,
        split_rngs={"params": False},
    )
    @nn.compact
    def __call__(self, carry, x):
        """Applies the module."""
        rnn_state = carry
        ins, resets = x
        rnn_state = jnp.where(
            resets[:, np.newaxis],
            self.initialize_carry(*rnn_state.shape),
            rnn_state,
        )
        new_rnn_state, y = nn.GRUCell(features=self.hidden_size)(rnn_state, ins)
        return new_rnn_state, y

    @staticmethod
    def initialize_carry(n_envs, hidden_size):
        # Use a dummy key since the default state init fn is just zeros.
        cell = nn.GRUCell(features=hidden_size)
        return cell.initialize_carry(jax.random.PRNGKey(0), (n_envs, hidden_size))


class ScannedBiRNN(nn.Module):
    hidden_size: int = 128

    @nn.compact
    def __call__(self, carry, inputs):
        forward_carry, backward_carry = carry
        forward_carry, forward_embs = ScannedRNN(self.hidden_size)(
            forward_carry, inputs
        )
        feats, resets = inputs
        backward_inputs = (feats[::-1], resets[::-1])
        backward_carry, backward_embs = ScannedRNN(self.hidden_size)(
            backward_carry, backward_inputs
        )
        carry = (forward_carry, backward_carry)
        embs = jnp.concatenate([forward_embs, backward_embs], axis=-1)
        return carry, embs

    @staticmethod
    def initialize_carry(n_envs, hidden_size):
        return (
            ScannedRNN.initialize_carry(n_envs, hidden_size),
            ScannedRNN.initialize_carry(n_envs, hidden_size)
        )

In [21]:
class ObsModel(nn.Module):
    obs_size: int
    hidden_size: int = 128

    @nn.compact
    def __call__(self, inputs: tuple, train: bool=False):
        _, noise = inputs
        # feats = jnp.concatenate([actions, messages, noise], axis=-1)
        feats = noise
        # *_, dones, obs = inputs
        # feats = obs

        x = nn.Dense(self.hidden_size)(feats)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x)

        x = nn.Dense(self.hidden_size)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x)

        # _, n_envs, _ = x.shape
        # carry = ScannedBiRNN.initialize_carry(n_envs, self.hidden_size)
        # bi_lstm = ScannedBiRNN(self.hidden_size)
        # _, x = bi_lstm(carry, (x, dones))

        x = nn.Dense(self.hidden_size)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x)

        obs_preds = nn.Dense(self.obs_size)(x)
        obs_preds = nn.sigmoid(obs_preds)
        return obs_preds


class Discriminator(nn.Module):
    hidden_size: int = 128

    @nn.compact
    def __call__(self, x):
        if isinstance(x, tuple):
            x, *_ = x
        x = nn.Dense(self.hidden_size)(x)
        x = nn.leaky_relu(x)
        # x = nn.Dropout(0.3)(x)
        x = nn.Dense(self.hidden_size)(x)
        x = nn.leaky_relu(x)
        # x = nn.Dropout(0.3)(x)
        x = nn.Dense(self.hidden_size)(x)
        x = nn.leaky_relu(x)
        # x = nn.Dropout(0.3)(x)
        x = nn.Dense(self.hidden_size)(x)
        x = nn.leaky_relu(x)
        # x = nn.Dropout(0.3)(x)
        y = nn.Dense(1)(x)
        return y.squeeze()


N_ACTIONS = env.action_space(env_params).n

def sample_obs_modelling_batch(rng,
                               noise_dim: int = 64,
                               rollout_steps: int = 500,
                               return_traj: bool = False):
    rng, rollout_rng = jax.random.split(rng)
    _, traj_batch, metrics = rollout_with_msgs(
        env, pretrained_policy, comm_policy,
        steps=rollout_steps,
        rng=rollout_rng
    )

    *batch_shape, _ = traj_batch.obs.shape
    batch_dim = np.prod(batch_shape)
    obs_dim = traj_batch.obs.shape[-1]
    obs = traj_batch.obs.reshape((batch_dim, obs_dim))
    noise = jax.random.normal(rng, (batch_dim, noise_dim))
    inputs = (noise, obs)
    if return_traj:
        return inputs, metrics, traj_batch
    return inputs, metrics

In [37]:
hidden_size = 128
obs_space = env.observation_space(env_params)
obs_size = obs_space.shape[0]
discriminator = Discriminator(4 * hidden_size)
obs_model = ObsModel(obs_size, 4 * hidden_size)


from typing import Any, Tuple, Callable
from flax import core
from flax import struct
import optax
from optax.losses import sigmoid_binary_cross_entropy

from poclaps.train.training_cb import TrainerCallback


def create_discriminator_sample(obs_model_vars, batch_inputs):
    # compute obs preds
    *_, obs = batch_inputs
    batch_size, _ = obs.shape
    # batch_size = steps * n_envs
    obs_preds = obs_model.apply(obs_model_vars, batch_inputs)
    # obs_preds = obs_preds.reshape((batch_size, obs_dim))
    # obs_batch = obs.reshape((batch_size, obs_dim))

    # compute inputs and labels
    labels = jnp.concatenate([jnp.ones(batch_size),
                              jnp.zeros(batch_size)])
    obs_inputs = jnp.concatenate([obs, obs_preds], axis=0)

    return obs_inputs, labels


def compute_discriminator_loss(model_vars, inputs):
    discriminator_vars, obs_model_vars = model_vars
    obs_inputs, labels = create_discriminator_sample(obs_model_vars, inputs)
    label_pred = discriminator.apply(discriminator_vars, obs_inputs)
    return sigmoid_binary_cross_entropy(label_pred, labels).mean()


def compute_adv_obs_model_loss(model_vars, inputs):
    discriminator_vars, obs_model_vars = model_vars

    # # compute action preds loss
    obs_preds, updates = obs_model.apply(
        obs_model_vars, inputs, train=True,
        mutable=['batch_stats']
    )
    # pred_action_dist, _ = pretrained_policy(obs_preds)
    # actions, *_ = inputs
    # act_pred_loss = categorical_cross_entropy(
    #     pred_action_dist.logits, actions
    # ).mean()

    # compute adversarial loss
    discr_pred = discriminator.apply(
        discriminator_vars, obs_preds
    )
    discr_labels = jnp.ones_like(discr_pred)
    adv_loss = sigmoid_binary_cross_entropy(
        discr_pred, discr_labels
    ).mean()

    adv_factor = 1.0

    # return act_pred_loss + adv_factor * adv_loss
    return adv_factor * adv_loss, updates['batch_stats']


class AdvObsModellingTrainState(struct.PyTreeNode):
    step: int
    rng: jnp.ndarray
    tx: optax.GradientTransformation = struct.field(pytree_node=False)

    # obs model
    obs_model_apply_fn: Callable = struct.field(pytree_node=False)
    obs_model_vars: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    obs_model_opt_state: optax.OptState = struct.field(pytree_node=True)

    # discriminator
    discriminator_apply_fn: Callable = struct.field(pytree_node=False)
    discriminator_vars: core.FrozenDict[str, Any] = struct.field(pytree_node=True)
    discriminator_opt_state: optax.OptState = struct.field(pytree_node=True)


class AdversarialObsModellingTrainer:

    def __init__(self,
                 obs_model: ObsModel,
                 discriminator: Discriminator,
                 config: dict,
                 callback: TrainerCallback = None):
        self.obs_model = obs_model
        self.discriminator = discriminator
        self.callback = callback or TrainerCallback()
        self.config = config

        if 'seed' not in config:
            self.config['seed'] = 0
        if 'learning_rate' not in config:
            self.config['learning_rate'] = 1e-3
        if 'discriminator_updates_per_step' not in config:
            self.config['discriminator_updates_per_step'] = 1
        if 'obs_model_updates_per_step' not in config:
            self.config['obs_model_updates_per_step'] = 1

    def update_obs_model(self, tx, opt_state, model_vars, inputs):
        _, obs_model_vars = model_vars
        grad_fn = jax.value_and_grad(compute_adv_obs_model_loss, has_aux=True)
        (loss, batch_stats), (_, grad) = grad_fn(model_vars, inputs)
        updates, new_opt_state = tx.update(grad, opt_state)
        new_params = optax.apply_updates(obs_model_vars['params'], updates['params'])
        new_variables = {'params': new_params, 'batch_stats': batch_stats}
        return loss, new_variables, new_opt_state

    def update_discriminator(self, tx, opt_state, model_vars, inputs):
        discriminator_vars, _ = model_vars
        grad_fn = jax.value_and_grad(compute_discriminator_loss)
        loss, (grad, _) = grad_fn(model_vars, inputs)
        updates, new_opt_state = tx.update(grad, opt_state)
        new_variables = optax.apply_updates(discriminator_vars, updates)
        return loss, new_variables, new_opt_state

    def train_step(self, train_state: AdvObsModellingTrainState) -> AdvObsModellingTrainState:
        rng, _rng = jax.random.split(train_state.rng)
        inputs, metrics = sample_obs_modelling_batch(_rng)

        metrics = {
            f'sample/{k}': v for k, v in metrics.items()
        }

        def run_obs_model_substep(substep_state, _):
            obs_model_vars, opt_state = substep_state
            model_vars = (train_state.discriminator_vars, obs_model_vars)
            loss, new_params, new_opt_state = self.update_obs_model(
                train_state.tx, opt_state, model_vars, inputs
            )
            return (new_params, new_opt_state), loss
        
        (new_obs_model_vars, new_obs_model_opt_state), obs_model_loss = jax.lax.scan(
            run_obs_model_substep,
            (train_state.obs_model_vars, train_state.obs_model_opt_state),
            None,
            self.config['obs_model_updates_per_step']
        )
        metrics['train/obs_model_loss'] = obs_model_loss.mean()

        def run_discriminator_substep(substep_state, _):
            discr_vars, opt_state = substep_state
            model_vars = (discr_vars, train_state.obs_model_vars)
            loss, new_discr_vars, new_opt_state = self.update_discriminator(
                train_state.tx, opt_state, model_vars, inputs
            )
            return (new_discr_vars, new_opt_state), loss

        (new_discr_vars, new_discr_opt_state), discr_loss = jax.lax.scan(
            run_discriminator_substep,
            (train_state.discriminator_vars, train_state.discriminator_opt_state),
            None,
            self.config['discriminator_updates_per_step']
        )
        metrics['train/discriminator_loss'] = discr_loss.mean()

        new_params = (new_discr_vars, new_obs_model_vars)
        metrics.update(self.compute_eval_metrics(new_params, inputs))

        jax.experimental.io_callback(
            self.callback.on_iteration_end, None,
            train_state.step, train_state, metrics
        )

        train_state = train_state.replace(
            rng=rng,
            step=train_state.step + 1,
            obs_model_vars=new_obs_model_vars,
            obs_model_opt_state=new_obs_model_opt_state,
            discriminator_vars=new_discr_vars,
            discriminator_opt_state=new_discr_opt_state
        )

        return train_state, metrics

    def compute_eval_metrics(self, model_vars: tuple, inputs: tuple) -> dict:
        eval_metrics = {}

        discr_vars, obs_model_vars = model_vars
        
        obs_preds = self.obs_model.apply(obs_model_vars, inputs)
        mean_var = jnp.mean(jnp.var(obs_preds, axis=0))
        eval_metrics['eval/obs_pred_var'] = mean_var

        # noise, obs = inputs
        # eval_metrics['eval/obs_pred_err'] = jnp.square(obs_preds - obs).mean()

        # pred_action_dist, _ = pretrained_policy(obs_preds)
        # eval_metrics['eval/action_pred_acc'] = (
        #     (pred_action_dist.probs.argmax(axis=-1) == actions.argmax(axis=-1)).mean()
        # )

        obs_inputs, labels = create_discriminator_sample(obs_model_vars, inputs)
        label_pred = self.discriminator.apply(discr_vars, obs_inputs)
        eval_metrics['eval/discr_acc'] = (
            ((label_pred > 0) == labels).mean()
        )

        return eval_metrics

    def train(self, n_steps: int) -> Tuple[AdvObsModellingTrainState, dict]:
        optimizer = optax.adam(self.config['learning_rate'])
        rng = jax.random.PRNGKey(self.config['seed'])

        rng, _rng = jax.random.split(rng)
        inputs, _ = sample_obs_modelling_batch(_rng)

        rng, _rng = jax.random.split(rng)
        obs_model_vars = self.obs_model.init(_rng, inputs, train=False)
        obs_model_opt_state = optimizer.init(obs_model_vars)

        rng, _rng = jax.random.split(rng)
        discr_inp, *_ = create_discriminator_sample(obs_model_vars, inputs)
        discr_vars = self.discriminator.init(_rng, discr_inp)
        discr_opt_state = optimizer.init(discr_vars)

        init_train_state = AdvObsModellingTrainState(
            step=0,
            rng=rng,
            tx=optimizer,
            obs_model_apply_fn=self.obs_model.apply,
            obs_model_vars=obs_model_vars,
            obs_model_opt_state=obs_model_opt_state,
            discriminator_apply_fn=self.discriminator.apply,
            discriminator_vars=discr_vars,
            discriminator_opt_state=discr_opt_state,
        )

        self.callback.on_train_begin(self.config)

        try:
            train_results = jax.lax.scan(
                jax.jit(lambda s, _: self.train_step(s)),
                init_train_state, None, n_steps
            )
        finally:
            self.callback.on_train_end(None)

        return train_results

In [47]:
from poclaps.train.wandb_cb import WandbCallback


cb = WandbCallback(['OBS-GAN'])
config = {
    'seed': 0,
    'learning_rate': 1e-4,
    'discriminator_updates_per_step': 1,
    'obs_model_updates_per_step': 1,
    'wandb_entity': 'drcope',
    'wandb_project': 'poclaps-obs-modelling',
    'wandb_mode': 'online',
}
aom_trainer = AdversarialObsModellingTrainer(obs_model, discriminator, config, cb)

In [48]:
train_state, metrics = aom_trainer.train(5000)

Logged in to wandb using secrets/wandb_api_key.




VBox(children=(Label(value='0.015 MB of 0.015 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/discr_acc,██▃█▇▆▆▅▄▄▄▄▃▄▃▂▃▃▂▂▂▂▃▂▃▄▃▂▁▂▂▂▁▃▃▂▂▃▂▂
eval/obs_pred_var,▁▄▃▃▅▆▇▇▇▇▇█▇▇███▇███▇█▇█▇█▇▇▇█▇▇▇▇▇█▇██
sample/mean_episode_len,▆▃▅▄▅▄▆▅▅▆▂▃▆▄▃▄▂▃▄▁▅▄▆▄▅▄▅▄▆▄█▆▄▄▆▄▅▄▄▄
sample/mean_reward,▃▆▄▅▄▅▃▄▄▃▇▆▃▅▆▅▇▆▅█▄▅▃▅▄▅▄▅▃▅▁▃▅▅▃▅▄▅▅▅
sample/n_episodes,▂▆▄▅▄▅▃▄▄▃▇▅▃▅▆▅█▆▅█▃▅▂▅▄▄▃▅▂▅▁▃▅▅▃▅▄▄▅▄
train/discriminator_loss,▁▁▇▁▃▃▄▅▅▆▆▅▇▆▇▇▇▆█▇█▇▇▆▇▅▇██▇▇▇█▇▇█▇▆█▇
train/obs_model_loss,▁▁▄▇███▇█▇▇▇█▇▇▇▇▇▆▇▇▇▇▇▆▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇

0,1
eval/discr_acc,0.96275
eval/obs_pred_var,0.09643
sample/mean_episode_len,3.36195
sample/mean_reward,-1.36195
sample/n_episodes,594.0
train/discriminator_loss,0.1027
train/obs_model_loss,5.17709


In [49]:
inputs, _ = sample_obs_modelling_batch(jax.random.PRNGKey(0))
obs_preds = obs_model.apply(train_state.obs_model_vars, inputs)

In [56]:
obs_preds[:4].reshape((-1, 2, 2, 5)).argmax(axis=-1)

Array([[[4, 2],
        [3, 0]],

       [[4, 2],
        [0, 4]],

       [[3, 2],
        [0, 2]],

       [[4, 2],
        [1, 0]]], dtype=int32)

In [57]:
noise, obs = inputs
obs[:3].reshape((-1, 2, 2, 5)).argmax(axis=-1)

Array([[[0, 4],
        [1, 1]],

       [[2, 3],
        [2, 3]],

       [[2, 4],
        [0, 3]]], dtype=int32)