## Open notebook in:
| Colab |
|:------|
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nicolepcx/transformers-the-definitive-guide/blob/master/CH07/ch07_STORM.ipynb) |

# About this notebook

This notebook sets up and evaluates the STORM model, which stands for [Stochastic Transformer-based wORld Model](https://openreview.net/pdf?id=WxnrX42rnS). STORM is a model-based reinforcement learning architecture that combines the sequence modeling power of transformers with the stochastic latent representations of variational autoencoders to improve agent performance in complex environments.
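
To make the "stochastic" part concrete: assuming STORM follows the DreamerV2/V3 convention of a discrete latent made of 32 categorical variables with 32 classes each (the `feat_dim=32*32+...` passed to the agent later in this notebook points that way), the encoder's logits are turned into a sampled, flattened one-hot vector. The sketch below is a dependency-free illustration of that sampling step only. It is not the repository's code, and it omits the straight-through gradient trick a trainable model would need; `NUM_VARS`, `NUM_CLASSES`, and `sample_categorical_latent` are hypothetical names.

```python
import math
import random

NUM_VARS = 32     # assumed number of categorical latent variables
NUM_CLASSES = 32  # assumed number of classes per variable


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical_latent(logits, rng=random):
    """Sample one class per variable and return a flattened one-hot vector.

    logits: NUM_VARS lists, each holding NUM_CLASSES floats.
    Returns a list of length NUM_VARS * NUM_CLASSES containing exactly NUM_VARS ones.
    """
    flat = []
    for var_logits in logits:
        probs = softmax(var_logits)
        k = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        one_hot = [0.0] * len(probs)
        one_hot[k] = 1.0
        flat.extend(one_hot)
    return flat


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [[rng.gauss(0.0, 1.0) for _ in range(NUM_CLASSES)] for _ in range(NUM_VARS)]
    z = sample_categorical_latent(logits, rng)
    print(len(z), sum(z))  # 1024 32.0
```

Under this assumption, the 1024-dimensional sample is what gets concatenated with the transformer's hidden state to form the agent's input, which would match `feat_dim = 32*32 + TransformerHiddenDim`.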

Key Features of This Notebook:
- Repository Cloning and Setup: The notebook starts by cloning the STORM repository, installing the necessary dependencies, and setting up the environment for running the model.
- Model Checkpoint Management: It downloads and decompresses the model checkpoints needed to evaluate the STORM model.
- Evaluation Script Creation: The notebook creates a custom Python evaluation script that builds the environment, loads the trained world model and agent, and runs the agent for one episode while recording a video of its actions.
- Execution and Output: The notebook concludes by running the evaluation script and saving the output video, which visualizes the agent's behavior and performance within the selected environment.

Context of STORM:
STORM is a model-based reinforcement learning method: it learns a parameterized simulation model of the environment (a world model) through self-supervised learning on collected experience. The agent then improves its policy largely on trajectories imagined by this world model instead of constantly sampling the real environment, which makes training more sample-efficient.
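
A toy sketch of this "learning in imagination" loop, under heavy simplifying assumptions: the learned world model, actor, and critic are replaced by hypothetical `toy_world_model`, `toy_policy`, and `toy_critic` functions, and the critic target is the TD(λ)-style λ-return familiar from Dreamer-style actor-critic agents. The `gamma` and `lambd` names mirror the arguments passed to `agents.ActorCriticAgent` later in the notebook, but this is an illustration of the idea, not STORM's training code.

```python
from typing import Callable, List, Tuple


def imagine_rollout(
    start_state: float,
    policy: Callable[[float], float],
    world_model: Callable[[float, float], Tuple[float, float, bool]],
    horizon: int,
) -> Tuple[List[float], List[float], List[bool]]:
    """Roll out `policy` inside `world_model` for `horizon` steps without touching a real env."""
    states, rewards, dones = [start_state], [], []
    state = start_state
    for _ in range(horizon):
        action = policy(state)
        state, reward, done = world_model(state, action)
        states.append(state)
        rewards.append(reward)
        dones.append(done)
    return states, rewards, dones


def lambda_returns(rewards, values, dones, gamma, lambd):
    """R_t = r_t + gamma * (1 - d_t) * ((1 - lambd) * V_{t+1} + lambd * R_{t+1}).

    `values` has one more entry than `rewards`; the last entry bootstraps the tail.
    """
    returns = [0.0] * len(rewards)
    next_return = values[-1]
    for t in reversed(range(len(rewards))):
        cont = 0.0 if dones[t] else 1.0  # an episode end cuts off everything after step t
        next_return = rewards[t] + gamma * cont * ((1 - lambd) * values[t + 1] + lambd * next_return)
        returns[t] = next_return
    return returns


# Hypothetical toy components standing in for the learned world model, actor, and critic
def toy_world_model(state, action):
    return state + action, 1.0, False  # next state, reward, episode-end flag


def toy_policy(state):
    return 1.0


def toy_critic(state):
    return 0.0


if __name__ == "__main__":
    states, rewards, dones = imagine_rollout(0.0, toy_policy, toy_world_model, horizon=3)
    values = [toy_critic(s) for s in states]
    print(lambda_returns(rewards, values, dones, gamma=0.99, lambd=0.95))
```

Setting `lambd=0` reduces the target to a one-step TD target, while `lambd=1` gives the discounted sum of imagined rewards bootstrapped only at the horizon.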

Performance:
STORM achieves a mean human-normalized score of 126.7% on the Atari 100k benchmark, a new record among methods that do not use lookahead search, and its efficient training process makes the approach practical to use.
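
On Atari 100k, results are usually reported as human-normalized scores: each game's raw score is rescaled so that a random policy maps to 0% and a human reference player to 100%, and a figure such as 126.7% is the mean of these values over the benchmark's 26 games. The sketch below shows only the arithmetic; the game names and scores are invented placeholders, not published reference values.

```python
from typing import Dict, Tuple


def human_normalized_score(agent_score: float, random_score: float, human_score: float) -> float:
    """0.0 means random-policy level, 1.0 means human-reference level."""
    return (agent_score - random_score) / (human_score - random_score)


def mean_human_normalized(results: Dict[str, Tuple[float, float, float]]) -> float:
    """Mean score over games; `results` maps game -> (agent, random, human) raw scores."""
    scores = [human_normalized_score(*triple) for triple in results.values()]
    return sum(scores) / len(scores)


if __name__ == "__main__":
    # Made-up scores for two hypothetical games, for illustration only
    results = {"GameA": (150.0, 0.0, 100.0), "GameB": (50.0, 10.0, 90.0)}
    print(f"{mean_human_normalized(results):.1%}")  # 100.0%
```

Because it is a mean, a value above 100% does not imply superhuman play on every game: in the example, one game at 150% and one at 50% average to exactly 100%.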




# Install Dependencies

In [None]:
# Clone the repository
!git clone https://github.com/Nicolepcx/STORM.git


Cloning into 'STORM'...
remote: Enumerating objects: 66, done.
remote: Counting objects: 100% (40/40), done.
remote: Compressing objects: 100% (24/24), done.
remote: Total 66 (delta 23), reused 16 (delta 16), pack-reused 26
Receiving objects: 100% (66/66), 516.24 KiB | 1.82 MiB/s, done.
Resolving deltas: 100% (30/30), done.


In [None]:

# Change directory to the cloned repo
%cd STORM


/content/STORM


In [None]:
%ls

agents.py      env_wrapper.py  readme.md         sub_models/     train.sh
config_files/  eval.py         replay_buffer.py  TensorBoard.sh  utils.py
D_TRAJ.7z      eval.sh         requirements.txt  train.py


In [None]:
# Install dependencies
!pip install -r requirements.txt -qqq




# Download the checkpoints from Google Drive

Here you download the model checkpoints from my own STORM training run (an archive of about 3 GB).

In [None]:
!gdown --id 1qbQ5b6cfuQHf-nanfXAKcRxlpZJ1zU7R --output model_checkpoints.tar.gz

Downloading...
From (original): https://drive.google.com/uc?id=1qbQ5b6cfuQHf-nanfXAKcRxlpZJ1zU7R
From (redirected): https://drive.google.com/uc?id=1qbQ5b6cfuQHf-nanfXAKcRxlpZJ1zU7R&confirm=t&uuid=94bd83d9-edc3-4020-a364-c1eb884b7aec
To: /content/STORM/model_checkpoints.tar.gz
100% 3.03G/3.03G [01:35<00:00, 31.7MB/s]


# Decompress the checkpoints

In [None]:

!tar -xzvf model_checkpoints.tar.gz -C .


ckpt/
ckpt/.ipynb_checkpoints/
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_100000.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/world_model_100000.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_97500.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/world_model_97500.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_95000.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/world_model_95000.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_92500.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/world_model_92500.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_90000.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/world_model_90000.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_87500.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/world_model_87500.pth
ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1/agent_85000.pth
ckpt/RoadRunner-life_done-wm_2L512D8
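
The extracted files carry the training step in their names (`agent_100000.pth`, `world_model_100000.pth`, `agent_97500.pth`, ...), so any loader has to build the filename including the step. The helper below is a hypothetical convenience, not part of the STORM repository: it scans a checkpoint directory and returns the newest step for which both the agent and the world model were saved.

```python
import re
from pathlib import Path
from typing import Iterable, List

# Matches names such as world_model_100000.pth and agent_97500.pth
CKPT_PATTERN = re.compile(r"^(world_model|agent)_(\d+)\.pth$")


def checkpoint_steps(filenames: Iterable[str]) -> List[int]:
    """Sorted training steps for which BOTH world_model_<step>.pth and agent_<step>.pth exist."""
    found = {"world_model": set(), "agent": set()}
    for name in filenames:
        match = CKPT_PATTERN.match(name)
        if match:
            found[match.group(1)].add(int(match.group(2)))
    return sorted(found["world_model"] & found["agent"])


def latest_step(ckpt_dir: str) -> int:
    """Return the newest complete checkpoint step in `ckpt_dir`."""
    steps = checkpoint_steps(p.name for p in Path(ckpt_dir).iterdir())
    if not steps:
        raise FileNotFoundError(f"no agent/world_model checkpoint pairs in {ckpt_dir}")
    return steps[-1]


if __name__ == "__main__":
    ckpt_dir = "ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1"  # directory extracted above
    if Path(ckpt_dir).is_dir():
        step = latest_step(ckpt_dir)
        print(step, f"{ckpt_dir}/world_model_{step}.pth", f"{ckpt_dir}/agent_{step}.pth")
```

For the directory extracted above, whose listing starts at step 100000 of a 100k-step run, this should report 100000.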

# Run Evaluation Script
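
The evaluation cell below steps a Gymnasium environment until the episode ends. In Gymnasium, `env.step` returns five values, `(observation, reward, terminated, truncated, info)`, and an episode is over when either flag is true. The stub environment here is hypothetical (no Gymnasium required) and never terminates on its own, only truncates after `max_steps`, which shows why an episode loop should check both flags.

```python
class StubEnv:
    """Hypothetical stand-in for a Gymnasium env: never terminates, truncates after max_steps."""

    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.t = 0

    def reset(self, seed=None):
        self.t = 0
        return 0, {}  # (observation, info), as in Gymnasium

    def step(self, action):
        self.t += 1
        obs, reward = self.t, 1.0
        terminated = False
        truncated = self.t >= self.max_steps
        return obs, reward, terminated, truncated, {}


def run_episode(env, policy, max_loop=1000):
    """Play one episode and return (total_reward, steps)."""
    obs, _ = env.reset()
    total_reward, steps = 0.0, 0
    done = False
    while not done and steps < max_loop:  # max_loop guards against a runaway demo loop
        obs, reward, terminated, truncated, _ = env.step(policy(obs))
        total_reward += reward
        steps += 1
        done = terminated or truncated
    return total_reward, steps


if __name__ == "__main__":
    print(run_episode(StubEnv(max_steps=5), policy=lambda obs: 0))  # (5.0, 5)
```

With `done = terminated` alone, the same loop would only stop at the `max_loop` guard, since the stub episode ends purely through truncation.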

In [None]:
!mkdir -p eval_result

# Make the eval.sh script executable
!chmod +x eval.sh

# Run the eval.sh script
!./eval.sh

# Create the evaluation script
eval_script = """
import gymnasium
import argparse
from utils import load_config
from sub_models.world_models import WorldModel
import agents
import torch
import imageio
import env_wrapper  # STORM's custom wrappers: SeedEnvWrapper, MaxLast2FrameSkipWrapper, LifeLossInfo

def build_single_env(env_name, image_size, seed):
    env = gymnasium.make(env_name, full_action_space=False, render_mode="rgb_array", frameskip=1)
    env = env_wrapper.SeedEnvWrapper(env, seed=seed)
    env = env_wrapper.MaxLast2FrameSkipWrapper(env, skip=4)
    env = gymnasium.wrappers.ResizeObservation(env, shape=image_size)
    env = env_wrapper.LifeLossInfo(env)
    return env

def evaluate(env, world_model, agent, video_path):
    obs, _ = env.reset()
    total_reward = 0
    done = False
    frames = []

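    while not done:
        frames.append(env.render())
        # select_action / encode_obs are the method names this script assumes;
        # check them against agents.py and sub_models/world_models.py in the repo
        action = agent.select_action(world_model.encode_obs(obs))
        # Gymnasium's step returns (obs, reward, terminated, truncated, info);
        # the episode ends when either flag is set
        obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward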

    env.close()
    imageio.mimsave(video_path, frames, fps=30)
    return total_reward

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-n", type=str, required=True)
    parser.add_argument("-config_path", type=str, required=True)
    parser.add_argument("-env_name", type=str, required=True)
    parser.add_argument("-checkpoint_path", type=str, required=True)
    parser.add_argument("-video_path", type=str, required=True, help="Path to save the output video")
    args = parser.parse_args()

    conf = load_config(args.config_path)
    env = build_single_env(args.env_name, conf.BasicSettings.ImageSize, seed=0)
    action_dim = env.action_space.n

    world_model = WorldModel(
        in_channels=conf.Models.WorldModel.InChannels,
        action_dim=action_dim,
        transformer_max_length=conf.Models.WorldModel.TransformerMaxLength,
        transformer_hidden_dim=conf.Models.WorldModel.TransformerHiddenDim,
        transformer_num_layers=conf.Models.WorldModel.TransformerNumLayers,
        transformer_num_heads=conf.Models.WorldModel.TransformerNumHeads
    ).cuda()
    agent = agents.ActorCriticAgent(
        feat_dim=32*32+conf.Models.WorldModel.TransformerHiddenDim,
        num_layers=conf.Models.Agent.NumLayers,
        hidden_dim=conf.Models.Agent.HiddenDim,
        action_dim=action_dim,
        gamma=conf.Models.Agent.Gamma,
        lambd=conf.Models.Agent.Lambda,
        entropy_coef=conf.Models.Agent.EntropyCoef,
    ).cuda()

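    # Checkpoints are saved with the training step as a suffix, e.g. world_model_100000.pth
    step = 100000
    world_model.load_state_dict(torch.load(f"{args.checkpoint_path}/world_model_{step}.pth"))
    agent.load_state_dict(torch.load(f"{args.checkpoint_path}/agent_{step}.pth"))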

    total_reward = evaluate(env, world_model, agent, args.video_path)
    print(f"Total reward: {total_reward}")
   """

# Write the script to its own file so the repository's eval.py (called by eval.sh) is not overwritten
with open("eval_video.py", "w") as file:
    file.write(eval_script)

# Run the evaluation script and save the video
!python eval_video.py -n STORM_eval -config_path config_files/STORM.yaml -env_name ALE/RoadRunner-v5 -checkpoint_path ckpt/RoadRunner-life_done-wm_2L512D8H-100k-seed1 -video_path output_video.mp4


Namespace(config_path='config_files/STORM.yaml', env_name='ALE/RoadRunner-v5', run_name='RoadRunner-life_done-wm_2L512D8H-100k-seed1')
A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
[100000]
  0% 0/1 [00:00<?, ?it/s]Current env: ALE/RoadRunner-v5
Mean reward: 16960.0
100% 1/1 [00:49<00:00, 49.96s/it]
A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
Traceback (most recent call last):
  File "/content/STORM/eval.py", line 66, in <module>
    world_model.load_state_dict(torch.load(f"{args.checkpoint_path}/world_model.pth"))
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 997, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 444, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/usr/local/lib/python3.10/dist-packages/torch/serialization.py", line 425, in __init__
  