In [1]:
import wandb
import os
import uuid
from pathlib import Path

sweep = "cmwbqphs"
# Download the files of every run in the sweep from wandb
api = wandb.Api()
ckpt_dir = Path("./craftax_checkpoints")
# Interpolate `sweep` so that changing the ID above actually changes which
# sweep is fetched (the path previously hard-coded the ID a second time).
ckpt_runs = api.sweep(f"oxwhirl/craftax-ppo-tuning/{sweep}").runs
for ckpt_run in ckpt_runs:
    num_train_interactions = ckpt_run.config["num_train_interactions"]
    reset_type = ckpt_run.config["ppo_reset_on_batch"]
    # Short random suffix appended to the directory name; presumably it keeps
    # runs that share both hyperparameters in separate directories.
    # NOTE(review): the suffix is regenerated on every execution, so re-running
    # this cell downloads into brand-new directories and `exist_ok=True` never
    # gets a chance to skip existing files.
    rand_end_string = str(uuid.uuid4())[:6]
    name = f"craftax_reset_type={reset_type}_num_train_interactions={num_train_interactions}_{rand_end_string}"
    for file in ckpt_run.files():
        file.download(root=(ckpt_dir / name), exist_ok=True)

KeyboardInterrupt: 

In [13]:
from pprint import pprint

# Pick a single run out of the sweep and show its wandb display name.
run_index = 3
run = ckpt_runs[run_index]
pprint(run.name)

'magic-sweep-15'


In [15]:
import jax
from util import *
from experiments.parse_args import parse_args
from agents.agents import get_agent
from environments.rollout import RolloutWrapper

# Command-line flags, written as adjacent string literals that concatenate
# into one space-separated string before being tokenised.
cmd_args = (
    "--agent=ppo --env_name=Craftax-Symbolic-v1 --gae_lambda=0.9 "
    "--layer_width=256 --lr=0.0003 --max_grad_norm=1 --num_agents=1 "
    "--num_env_workers=512 --num_minibatches=4 --num_rollout_steps=64 "
    "--num_train_interactions=32768 --ppo_num_epochs=4 --save_policy --seed=2"
).split(" ")
args = parse_args(cmd_args)

# Environment wrapper and its default parameters
env = RolloutWrapper(args.env_name, args.num_rollout_steps)
env_params = env.default_env_params

# Seeded key: one half is fanned out into per-worker reset keys, the other
# half is handed to get_agent below.
rng = jax.random.PRNGKey(args.seed)
rng, _rng = jax.random.split(rng)
_rng = jax.random.split(_rng, args.num_env_workers)
obsv, env_state = env.batch_reset(_rng, env_params)

# get_agent returns (train_state, aux_train_states, agent_train_step_fn)
agent_bundle = get_agent(
    args,
    rng,
    env.obs_shape,
    env.num_actions,
    env.discrete_actions,
    env.action_lims,
)
train_state, aux_train_states, agent_train_step_fn = agent_bundle
train_state

TrainState(step=0, apply_fn=<bound method Module.apply of ActorCritic(
    # attributes
    width = 256
    num_actions = 17
    activation = 'relu'
)>, params={'params': {'actor_0': {'kernel': Array([[ 0.00392421,  0.02899976, -0.01661767, ..., -0.02550365,
         0.09271209,  0.00108168],
       [ 0.02584682,  0.04291209, -0.02767812, ...,  0.00776166,
         0.00491402,  0.12317209],
       [-0.00770354, -0.03328874, -0.04437829, ..., -0.00889074,
         0.00904444,  0.02584643],
       ...,
       [ 0.04543079,  0.04821817, -0.01019958, ...,  0.027421  ,
        -0.01652528,  0.00430968],
       [-0.06850644, -0.00809783, -0.00156537, ...,  0.02846765,
         0.01694525, -0.02046602],
       [ 0.00968764,  0.0542812 ,  0.0134697 , ..., -0.02364796,
        -0.02413787, -0.02466886]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 

In [1]:
def train(rng, env, args, train_state):
    """Run the training loop starting from a caller-supplied train state.

    Args:
        rng: JAX PRNG key. It is split internally for the environment reset,
            for `get_agent`, and for the scanned training loop.
        env: Environment wrapper. This function uses `default_env_params`,
            `batch_reset`, `batch_rollout`, `obs_shape`, `num_actions`,
            `discrete_actions` and `action_lims`.
        args: Configuration object. This function reads `num_env_workers`,
            `log_dormancy`, `num_train_iters` and `save_policy`, and passes the
            whole object on to `get_agent`.
        train_state: Train state used for rollouts and as the starting point
            of the updates. It replaces the train state `get_agent` creates.

    Returns:
        A dict with:
            "runner_state": final carry of the scan, i.e. the tuple
                (train_state, aux_train_states, env_state, last_obs, rng).
            "metrics": per-iteration metrics, each reduced with `jnp.mean`.
            "loss": per-iteration loss, reduced with `jnp.mean`.
            "policy": only when `args.save_policy` is truthy; the train state
                after the final iteration.
    """
    # --- Initialize environment ---
    env_params = env.default_env_params
    rng, _rng = jax.random.split(rng)
    _rng = jax.random.split(_rng, args.num_env_workers)
    obsv, env_state = env.batch_reset(_rng, env_params)

    # --- Initialize agent train states and step function ---
    rng, _rng = jax.random.split(rng)
    # train_state contains actor (and critic if used) to be used for rollouts,
    # aux_train_states contains all other trainable parameters.
    # The freshly initialised train state is discarded: the caller's
    # `train_state` argument is used instead.
    _, aux_train_states, agent_train_step_fn = get_agent(
        args,
        _rng,
        env.obs_shape,
        env.num_actions,
        env.discrete_actions,
        env.action_lims,
    )

    # --- Execute train loop ---
    def _train_step(runner_state, _):
        """One scan iteration: collect a rollout, update the agent, reduce metrics."""
        train_state, aux_train_states, env_state, last_obs, rng = runner_state
        # --- Collect trajectories ---
        rng, _rng = jax.random.split(rng)
        _rng = jax.random.split(_rng, args.num_env_workers)
        rollout = env.batch_rollout(_rng, train_state, env_params, last_obs, env_state)
        # The rollout carries an extra dormancy entry when dormancy logging is on
        if args.log_dormancy:
            new_env_state, new_last_obs, traj_batch, dormancy = rollout
        else:
            new_env_state, new_last_obs, traj_batch = rollout

        # --- Update agent ---
        rng, _rng = jax.random.split(rng)
        train_state, aux_train_states, loss, metric = agent_train_step_fn(
            train_state, aux_train_states, traj_batch, new_last_obs, _rng
        )

        # --- Aggregate metrics ---
        if args.log_dormancy:
            metric["dormancy"] = dormancy
        # Manually aggregate Craftax achievements - NaN when no episodes end
        # (the denominator `traj_batch.done.sum()` is then zero).
        metric.update(
            {
                k: (v * traj_batch.done).sum() / traj_batch.done.sum()
                for k, v in metric.items()
                if "achievements" in k.lower()
            }
        )
        # `jax.tree_map` is a deprecated alias that newer JAX releases drop;
        # `jax.tree_util.tree_map` is the stable spelling of the same function.
        metric, loss = jax.tree_util.tree_map(jnp.mean, (metric, loss))

        runner_state = (
            train_state,
            aux_train_states,
            new_env_state,
            new_last_obs,
            rng,
        )
        return runner_state, (loss, metric)

    rng, _rng = jax.random.split(rng)
    runner_state = (train_state, aux_train_states, env_state, obsv, _rng)
    runner_state, (loss, metric) = jax.lax.scan(
        _train_step, runner_state, None, args.num_train_iters
    )
    ret = {
        "runner_state": runner_state,
        "metrics": metric,
        "loss": loss,
    }
    if args.save_policy:
        # The outer `train_state` is still the untouched input; the updated
        # state only exists in the scan carry, so return that one.
        ret["policy"] = runner_state[0]
    return ret

In [3]:
from orbax.checkpoint import PyTreeCheckpointer
import jax

from util import *
from experiments.parse_args import parse_args
from agents.agents import get_agent
from environments.rollout import RolloutWrapper
from omegaconf import OmegaConf
from tqdm import tqdm
import wandb
from pathlib import Path

# Each checkpoint is restored into a placeholder train state, which is built
# per run by get_fake_train_state below.
sweep = "cmwbqphs"
ckpt_dir = Path("./craftax_checkpoints")

# Look up all runs that belong to the sweep
api = wandb.Api()
sweep_path = f"oxwhirl/craftax-ppo-tuning/{sweep}"
ckpt_runs = api.sweep(sweep_path).runs


def get_fake_train_state(run):
    """Build a freshly initialised train state from a run's config.

    The result is used as the `item` that a saved checkpoint is restored into.

    Args:
        run: Object whose `config` mapping holds the run's hyperparameters.

    Returns:
        Tuple `(train_state, env, args)`, where `args` is an OmegaConf config
        created from `run.config` with `num_train_iters` set to 1.
    """
    args = OmegaConf.create(run.config)
    args.num_train_iters = 1
    env = RolloutWrapper(args.env_name, args.num_rollout_steps)

    seed_key = jax.random.PRNGKey(args.seed)
    agent_key, reset_key = jax.random.split(seed_key)

    # The reset outputs are not used by anything below; the call is kept as-is.
    worker_keys = jax.random.split(reset_key, args.num_env_workers)
    env.batch_reset(worker_keys, env.default_env_params)

    # Only the first element of get_agent's 3-tuple is needed here
    train_state, _, _ = get_agent(
        args,
        agent_key,
        env.obs_shape,
        env.num_actions,
        env.discrete_actions,
        env.action_lims,
    )
    return train_state, env, args


def download_file(run):
    """Download every file attached to `run` into `ckpt_dir / run.name`.

    Files are fetched with `exist_ok=True`.
    """
    target_dir = ckpt_dir / run.name
    for run_file in run.files():
        run_file.download(root=target_dir, exist_ok=True)


# Base key; a fresh sub-key is split off for each run's call to `train`
rng = jax.random.PRNGKey(0)

ckptr = PyTreeCheckpointer()
for run in tqdm(ckpt_runs):
    # Fetch the run's files, then build a placeholder train state to restore into
    download_file(run)
    fake_train_state, env, args = get_fake_train_state(run)
    # `ckpt_dir` is a relative path, and Orbax expects absolute checkpoint
    # paths, so resolve it before restoring.
    path = (ckpt_dir / run.name / "policy").resolve()
    train_state = ckptr.restore(path, item=fake_train_state)
    rng, _rng = jax.random.split(rng)
    # `train` is defined in an earlier cell.
    # NOTE(review): `ret` is overwritten on every iteration, so only the last
    # run's result is still available after the loop.
    ret = train(_rng, env, args, train_state)

  0%|          | 0/18 [00:00<?, ?it/s]