## Pretrained Crafter Gameplay Demos

This demo notebook loads pretrained weights for a **pixel-based** Crafter agent and replays gameplay videos for user-specified instructions.

In [1]:
# GPU device index used for policy inference; it is copied into `args.gpu`
# when the experiment is configured in Step 1.
GPU = 0

In [2]:
import random
import os
import glob
import warnings

import requests
from IPython.display import Video, HTML
import torch

crafter_example = __import__("07_crafter_with_instructions")
import amago
from amago.cli_utils import *
from amago.envs.builtin.crafter_envs import CrafterEnv

#### Step 1: Initialize a new agent with the correct architecture

In [3]:
# Stand in for the training script's command line by parsing a fixed argument list.
parser = ArgumentParser()
add_common_cli(parser)
cli_arguments = [
    "--run_name=crafter_dec23",
    "--buffer_dir=crafter_pretrained_example",
]
args = parser.parse_args(cli_arguments)
args.gpu = GPU
args.no_log = True

# Configuration overrides for the agent, its timestep encoder and the goal embedding.
config = {
    "amago.agent.Agent.reward_multiplier": 10.0,
    "amago.agent.Agent.tstep_encoder_Cls": partial(
        crafter_example.CrafterTstepEncoder, obs_kind="crop"
    ),
    "amago.nets.tstep_encoders.TstepEncoder.goal_emb_Cls": amago.nets.goal_embedders.TokenGoalEmb,
    "amago.nets.goal_embedders.TokenGoalEmb.zero_embedding": False,
    "amago.nets.goal_embedders.TokenGoalEmb.goal_emb_dim": 64,
}

# Request a 3-layer transformer trajectory encoder, then hand the overrides to use_config.
switch_traj_encoder(config, arch="transformer", memory_size=256, layers=3)
use_config(config, args.configs, finalize=False)


def make_env():
    """Return a new CrafterEnv configured for this demo.

    Videos are written to ``crafter_notebook_videos/``, the directory the
    later cells clear and read from.
    """
    return CrafterEnv(
        directed=True,
        k=5,
        min_k=1,
        time_limit=2500,
        obs_kind="crop",
        use_tech_tree=False,
        save_video_to="crafter_notebook_videos/",
    )


# The run name is reused in Step 2 to locate the checkpoint file.
group_name = "crafter_dec23_directed_crafter_crop"
run_name = f"{group_name}_trial_0"
experiment = create_experiment_from_cli(
    args,
    make_train_env=make_env,
    make_val_env=make_env,
    max_seq_len=512,
    traj_save_len=2501,
    stagger_traj_file_lengths=False,
    run_name=run_name,
    group_name=group_name,
    batch_size=18,
    val_timesteps_per_epoch=5000,
    relabel="some",
    goal_importance_sampling=True,
)

# Build the agent; its weights are randomly initialized at this point.
# Note: this architecture is smaller than the one used in the main results.
experiment.start()



 		 AMAGO
        	 -------------------------
        	 Environment Horizon: 2500
        	 Policy Max Sequence Length: 512
        	 Trajectory File Sequence Length: 2501
        	 Mode: Fixed Context with Valid Relabeling (Approximate Meta-RL / POMDPs)
        	 Half Precision: False
        	 Fast Inference: True
        	 Total Parameters: 6,517,894 




#### Step 2: Download and replace parameters with pretrained checkpoint

In [4]:
# Checkpoint from a long pixel-based training run that closely reproduces
# Appendix C5 Table 2 using the public repo.
ckpt_link = "https://utexas.box.com/shared/static/xvkgo02vkp8kn7j80051jbr6224tep9r.pt"

# Save the download at exactly the path it is loaded from below, so the write
# and read locations cannot drift apart if the experiment configuration changes.
ckpt_path = os.path.join(experiment.ckpt_dir, f"{experiment.run_name}_BEST.pt")
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

# Fail loudly on a stalled connection or an HTTP error instead of silently
# writing an error page to disk and crashing later inside torch.load.
response = requests.get(ckpt_link, timeout=(10, 120))
response.raise_for_status()
with open(ckpt_path, "wb") as f:
    f.write(response.content)

# Load the checkpoint.
# You would normally load the best checkpoint like this:
#   experiment.load_checkpoint(loading_best=True)
# The manual route below is a workaround for this older checkpoint format.
# NOTE(review): torch.load unpickles the file; only use checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=experiment.DEVICE)
# strict=False tolerates parameters that the old checkpoint does not contain;
# the returned report (shown as the cell output) lists any missing/unexpected keys.
experiment.policy.load_state_dict(ckpt["model_state"], strict=False)

_IncompatibleKeys(missing_keys=['maximized_critics.inp_layer.weight', 'maximized_critics.inp_layer.bias', 'maximized_critics.core_layers.0.weight', 'maximized_critics.core_layers.0.bias', 'maximized_critics.output_layer.weight', 'maximized_critics.output_layer.bias'], unexpected_keys=[])

#### Step 3: Evaluate and visualize

In [5]:
# Specify the task here: a list of up to 5 Crafter achievement names, each
# written with "_" separators between its words. For example:
TASK = ["make_stone_pickaxe", "collect_coal", "travel_40m_40m", "place_stone"]

In [6]:
# Reset the video directory so it only holds clips for the current task.
# Only regular files are deleted; os.remove would raise on a sub-directory.
for stale_path in glob.glob("crafter_notebook_videos/*"):
    if os.path.isfile(stale_path):
        os.remove(stale_path)

# Name given to the evaluation env; the success-rate metric key below embeds it.
EVAL_ENV_NAME = "crafter_eval"


def make_eval_env():
    """Create a demo environment whose instruction is fixed to ``TASK``.

    Each achievement name in ``TASK`` is split on "_" into a list of tokens
    before being passed to ``set_fixed_task``.
    """
    env = make_env()
    env.set_env_name(EVAL_ENV_NAME)
    # manually set the task
    env.set_fixed_task([achievement.split("_") for achievement in TASK])
    return env


experiment.parallel_actors = 6  # adjust as needed!

# Run the evaluation; the environments save their videos to disk.
warnings.filterwarnings("ignore", category=UserWarning)
eval_results = experiment.evaluate_test(make_eval_env, timesteps=3500, render=False)
success = eval_results[f"Average Success Rate in {EVAL_ENV_NAME}"]
# ":.1f" (no space flag) avoids the stray extra space before the percentage.
print(f"\n\nTask \"{', '.join(TASK)}\" Success Rate: {success * 100:.1f}%")

See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
                                                                                                                                               7.06it/s][0m



Task "make_stone_pickaxe, collect_coal, travel_40m_40m, place_stone" Success Rate:  58.0%


In [7]:
# Display gameplay videos in the notebook; run again for a new random sample.
MAX_VIDEOS = 8       # maximum number of clips to show
VIDEOS_PER_ROW = 4   # clips per table row

videos = glob.glob("crafter_notebook_videos/*")
random.shuffle(videos)

# One table cell per selected clip.
video_cells = [
    f"""
    <td>
        <video width=300px alt="Video" controls>
            <source src="{video_path}" type="video/mp4">
        </video>
    </td>
    """
    for video_path in videos[:MAX_VIDEOS]
]

# Chunk the cells into complete rows. Building rows this way never emits a
# dangling empty <tr></tr> when the clip count is a multiple of VIDEOS_PER_ROW.
table_rows = [
    "<tr>" + "".join(video_cells[start:start + VIDEOS_PER_ROW]) + "</tr>"
    for start in range(0, len(video_cells), VIDEOS_PER_ROW)
]

html_str = (
    f"<table><caption style='font-size: 24px'>{', '.join(TASK)}</caption>"
    + "".join(table_rows)
    + "</table>"
)
HTML(html_str)

0,1,2,3
,,,
,,,
,,,
