# Gradient-based policy

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
import random
from pathlib import Path
import matplotlib.pyplot as plt
import ale_py
import imageio
gym.register_envs(ale_py)

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage, DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback

from gymnasium import spaces


# Torch device string: "cuda" when PyTorch reports a usable GPU, otherwise "cpu".
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
from gymnasium import spaces

# Alien environment (the "NoFrameskip" variant), downscaled to the 64x64 resolution the RSSM
# below is configured for. Only one env is created so no emulator instance is left unclosed.
env = gym.make("AlienNoFrameskip-v4")
env = gym.wrappers.ResizeObservation(env, (64, 64))


def transform_obs(obs):
    """Turn an HWC image frame into a CHW float32 array scaled by 1/255.

    Pixel values are presumably 0-255 (Atari frames after ResizeObservation); the result then
    lies in [0, 1], matching the Box declared for the wrapped env.
    """
    scaled = obs.astype(np.float32) / 255.0
    return scaled.transpose(2, 0, 1)

# The wrapper has to advertise the transformed space: channel-first, float32, values in [0, 1].
new_obs_space = spaces.Box(
    low=0.0,
    high=1.0,
    shape=(3, 64, 64),
    dtype=np.float32,
)
env = gym.wrappers.TransformObservation(
    env, observation_space=new_obs_space, func=transform_obs
)

In [15]:
# Sanity check: the wrapper chain must yield channel-first 64x64 RGB observations.
obs, info = env.reset()
expected_shape = (3, 64, 64)
assert obs.shape == expected_shape, f"Expected (3, 64, 64), got {obs.shape}"

In [16]:
import sys
import os

# Make the local world-models fork importable; it provides the `models` package used below.
local_repo_path = os.path.abspath(os.path.join(os.getcwd(), "../world-models-fork"))
if local_repo_path not in sys.path:
    sys.path.append(local_repo_path)
from models.rssm import RSSM

In [None]:
from models.models import EncoderCNN, DecoderCNN, RewardModel
from models.dynamics import DynamicsModel

# Architecture sizes: these must match the values the checkpoint was trained with.
action_dim = 18
hidden_size = 1024
state_size = 32
embedding_dim = 1024
image_shape = (3, 64, 64)

encoder = EncoderCNN(3, embedding_dim, (image_shape[1], image_shape[2])).to(device)
decoder = DecoderCNN(hidden_size, state_size, embedding_dim, True, image_shape).to(device)
reward_model = RewardModel(hidden_size, state_size).to(device)
dynamics = DynamicsModel(hidden_size, action_dim, state_size, embedding_dim).to(device)

rssm = RSSM(
    encoder, decoder, reward_model, dynamics,
    hidden_size, state_size, action_dim, embedding_dim, device,
)

# Alternative checkpoint: checkpoints/rssm/rssm_checkpoint_epoch_296.pth
checkpoint = torch.load("checkpoints_less_diverse/rssm_best.pth", map_location=device)
for _component in ("encoder", "decoder", "reward_model", "dynamics"):
    getattr(rssm, _component).load_state_dict(checkpoint[_component])
rssm.eval()
print("RSSM encoder loaded.")

RSSM encoder loaded.


In [18]:
class RSSMFeatureExtractor(nn.Module):
    """Frozen wrapper that embeds a single observation with the RSSM encoder.

    ``forward`` adds a batch axis, runs the encoder under ``torch.no_grad()`` (so no gradient
    flows into the encoder) and strips the batch axis again.
    """

    def __init__(self, rssm):
        super().__init__()
        self.encoder = rssm.encoder

    def _input_device(self):
        # Put inputs wherever the encoder's weights live instead of relying on a global device.
        param = next(self.encoder.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def forward(self, obs):
        # as_tensor avoids a redundant copy (and a warning) when obs is already a tensor.
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self._input_device()).unsqueeze(0)
        with torch.no_grad():
            z = self.encoder(obs_t)
        return z.squeeze(0)

In [None]:
class LinearPolicy(nn.Module):
    """Policy head mapping a latent vector to unnormalised action logits.

    Despite the name this is a two-layer MLP: Linear(latent_dim -> hidden_dim), ReLU,
    Linear(hidden_dim -> action_dim). No softmax is applied.
    """

    def __init__(self, latent_dim, action_dim, hidden_dim=512):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        logits = self.net(z)
        return logits

In [None]:
# Frozen RSSM encoder feeding a trainable policy head; the optimiser only sees the head's weights.
encoder = RSSMFeatureExtractor(rssm).to(device)
policy = LinearPolicy(embedding_dim, action_dim).to(device)  # env.action_space.n
optimizer = optim.Adam(policy.parameters(), lr=1e-4)

In [None]:
def _discounted_returns(rewards, gamma):
    """Return [G_0, ..., G_{T-1}] where G_t = r_t + gamma * G_{t+1} and G_T = 0."""
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()  # O(n) overall, instead of O(n^2) repeated insert(0, ...)
    return returns


def _standardize(returns):
    """Zero-mean / unit-std normalisation of a 1-D returns tensor.

    torch's sample std of a single element is NaN, which would make the loss and every gradient
    NaN (and poison the optimiser state). A one-step episode is therefore only mean-centred.
    """
    centred = returns - returns.mean()
    if returns.numel() < 2:
        return centred
    return centred / (returns.std() + 1e-8)


def _save_episode_gifs(raw_frames, recon_frames, fps=15):
    """Write raw, reconstructed and side-by-side comparison GIFs to the working directory."""
    imageio.mimsave("rssm_raw.gif", raw_frames, fps=fps)
    imageio.mimsave("rssm_recon.gif", recon_frames, fps=fps)
    paired = [np.hstack((f, r)) for f, r in zip(raw_frames, recon_frames)]
    imageio.mimsave("rssm_compare.gif", paired, fps=fps)
    print("Saved: rssm_raw.gif, rssm_recon.gif, rssm_compare.gif")


def run_episode(env, encoder, policy, rssm, gamma=0.99, seed=None, save_video=True):
    """Roll out one episode with a sampled policy and return its REINFORCE loss.

    Args:
        env: Gymnasium env yielding CHW float observations in [0, 1].
        encoder: callable mapping one observation to a latent vector (e.g. RSSMFeatureExtractor).
        policy: module mapping that latent to action logits.
        rssm: RSSM used only to render reconstructions when ``save_video`` is True.
        gamma: discount factor for the returns.
        seed: if given, seeds torch / numpy / random and the env reset.
        save_video: if True, write rssm_raw.gif, rssm_recon.gif and rssm_compare.gif. When False
            the RSSM rollout and frame capture are skipped entirely.

    Returns:
        (loss, total_reward): the scalar loss -sum_t log pi(a_t | o_t) * G_hat_t with standardised
        returns G_hat, and the undiscounted episode return as a float.

    Reads the module-level ``hidden_size``, ``state_size`` and ``device``.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # env.reset(seed=None) behaves like a plain reset, so one call covers both cases.
    obs, info = env.reset(seed=seed)
    log_probs, rewards = [], []
    total_reward = 0.0
    raw_frames, recon_frames = [], []

    # RSSM recurrent state; it only drives the reconstruction video.
    h = torch.zeros(1, hidden_size, device=device)
    s = torch.zeros(1, state_size, device=device)

    done = False
    while not done:
        z = encoder(obs)
        logits = policy(z)

        # Mask logits for actions this env does not support (out-of-place, keeps autograd clean).
        n_valid = env.action_space.n
        n_logits = logits.shape[-1]
        if n_valid < n_logits:
            invalid = torch.arange(n_logits, device=logits.device) >= n_valid
            logits = logits.masked_fill(invalid, -1e9)

        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))

        if save_video:
            raw_frames.append((obs * 255).astype(np.uint8).transpose(1, 2, 0))
            one_hot = torch.nn.functional.one_hot(
                torch.tensor([action.item()], device=device), rssm.action_dim
            ).float()
            with torch.no_grad():
                h, s = rssm.step(h, s, one_hot, z.unsqueeze(0))
                decoded = rssm.decoder(h, s)
            recon = np.clip(decoded[0].permute(1, 2, 0).cpu().numpy(), 0.0, 1.0)
            recon_frames.append((recon * 255).astype(np.uint8))

        obs, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        rewards.append(reward)
        total_reward += reward

    if save_video:
        _save_episode_gifs(raw_frames, recon_frames)

    returns = torch.tensor(_discounted_returns(rewards, gamma), dtype=torch.float32, device=device)
    returns = _standardize(returns)

    loss = -torch.sum(torch.stack(log_probs) * returns)
    return loss, float(total_reward)


In [None]:
num_episodes = 1
reward_history = []

for ep in range(num_episodes):
    policy.train()
    # run_episode overwrites the same three GIF files on every call, so only the final
    # episode is rendered; earlier episodes skip frame capture and GIF encoding.
    is_last_episode = ep == num_episodes - 1
    loss, total_reward = run_episode(
        env, encoder, policy, rssm, seed=ep, save_video=is_last_episode
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    reward_history.append(float(total_reward))
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print(f"Episode {ep+1}/{num_episodes} | Reward: {total_reward:.1f}")

env.close()

Episode 1/10 | Reward: 80.0
Episode 2/10 | Reward: 110.0
Episode 3/10 | Reward: 100.0
Episode 4/10 | Reward: 120.0
Episode 5/10 | Reward: 140.0
Episode 6/10 | Reward: 100.0
Episode 7/10 | Reward: 150.0
Episode 8/10 | Reward: 140.0
Episode 9/10 | Reward: 100.0
Episode 10/10 | Reward: 120.0


In [None]:
print(f"{len(reward_history)} {type(reward_history[0])}")

In [None]:
import matplotlib
matplotlib.use("Agg")

import matplotlib.pyplot as plt

# Per-episode return of the policy-gradient (REINFORCE-style) loop, saved to disk.
# The environment here is Alien, so the title says so.
fig = plt.figure()
plt.plot(reward_history)
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("RSSM-based Linear Policy on Alien")
plt.savefig("reward_plot.png")
plt.close(fig)  # release the figure; with the Agg backend it is never displayed anyway
print("Saved to reward_plot.png")

# Stable-Baselines3 library

In [None]:
from stable_baselines3 import PPO
from gymnasium import spaces

# Recompute the torch device string for this section: GPU when available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Rebuild the RSSM for the Stable-Baselines experiments and restore ALL of its sub-modules.
# RSSMZPlusHWrapper calls rssm.dynamics, so leaving the decoder / reward model / dynamics
# randomly initialised would silently feed untrained hidden states to the agent.
checkpoint = torch.load("checkpoints_less_diverse/rssm_best.pth", map_location=device)

action_dim = 18
hidden_size = 1024
state_size = 32
embedding_dim = 1024
image_shape = (3, 64, 64)

encoder_cnn = EncoderCNN(3, embedding_dim, image_shape[1:]).to(device)
decoder = DecoderCNN(hidden_size, state_size, embedding_dim, True, image_shape).to(device)
reward_model = RewardModel(hidden_size, state_size).to(device)
dynamics = DynamicsModel(hidden_size, action_dim, state_size, embedding_dim).to(device)

rssm = RSSM(encoder_cnn, decoder, reward_model, dynamics,
            hidden_size, state_size, action_dim, embedding_dim, device)

for _component in ("encoder", "decoder", "reward_model", "dynamics"):
    getattr(rssm, _component).load_state_dict(checkpoint[_component])
rssm.eval()


In [None]:
class RSSMEncodingWrapper(gym.ObservationWrapper):
    """Replace each raw frame with its RSSM encoder embedding (a flat float32 vector).

    Frames are expected as HWC images with 0-255 pixel values (e.g. from ResizeObservation on an
    Atari env). They are converted to CHW and scaled to [0, 1] before encoding, matching the
    preprocessing used by ``transform_obs`` and ``RSSMFrameStackWrapper`` in this notebook.
    """

    def __init__(self, env, encoder, embedding_dim, device):
        super().__init__(env)
        self.encoder = encoder
        self.device = device

        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(embedding_dim,), dtype=np.float32
        )

    def observation(self, obs):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        # Scale to [0, 1]: the other encoder call sites feed normalised pixels, not raw 0-255.
        obs_t = obs_t.permute(2, 0, 1).unsqueeze(0) / 255.0

        with torch.no_grad():
            z = self.encoder(obs_t)

        return z.squeeze(0).cpu().numpy().astype(np.float32)


## Stack latent vectors

In [None]:
from collections import deque
from gymnasium import spaces

class RSSMFrameStackWrapper(gym.ObservationWrapper):
    """Keep a rolling window of the last ``n_frames`` observations and emit them concatenated.

    With ``use_encoding=True`` every frame goes through ``encoder`` (after HWC -> CHW and 1/255
    scaling) and the window is a flat vector of ``embedding_dim * n_frames`` floats. With
    ``use_encoding=False`` the frames are converted to CHW, scaled by 1/255, and stacked along
    the channel axis, giving shape ``(n_frames * c, h, w)``.
    """

    def __init__(self, env, encoder, embedding_dim, device, n_frames=4, use_encoding=True):
        super().__init__(env)
        self.encoder = encoder
        self.device = device
        self.n_frames = n_frames
        self.embedding_dim = embedding_dim
        self.use_encoding = use_encoding
        self.buffer = deque(maxlen=n_frames)

        if use_encoding:
            self.observation_space = spaces.Box(
                low=-np.inf, high=np.inf, shape=(embedding_dim * n_frames,), dtype=np.float32
            )
        else:
            height, width, channels = env.observation_space.shape
            self.observation_space = spaces.Box(
                low=0.0, high=1.0, shape=(n_frames * channels, height, width), dtype=np.float32
            )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        first = self.encode_obs(obs)
        # Fill the whole window with the first frame so the stacked shape is fixed from step 0.
        self.buffer.clear()
        self.buffer.extend([first] * self.n_frames)
        return self.get_stacked(), info

    def observation(self, obs):
        self.buffer.append(self.encode_obs(obs))
        return self.get_stacked()

    def encode_obs(self, obs):
        if obs.ndim == 4:  # drop a leading frame axis if one is present
            obs = obs[0]

        if not self.use_encoding:
            return np.transpose(obs, (2, 0, 1)).astype(np.float32) / 255.0

        pixels = torch.tensor(obs, dtype=torch.float32, device=self.device)
        pixels = pixels.permute(2, 0, 1).unsqueeze(0) / 255.0
        with torch.no_grad():
            latent = self.encoder(pixels)
        return latent.squeeze(0).cpu().numpy().astype(np.float32)

    def get_stacked(self):
        return np.concatenate(tuple(self.buffer), axis=0)


#### PPO using encodings

In [None]:
from stable_baselines3.common.monitor import Monitor

# Alien frames resized to 64x64, encoded by the RSSM encoder, 6 latents stacked; Monitor adds
# the per-episode stats that the reward callbacks read.
env = Monitor(
    RSSMFrameStackWrapper(
        gym.wrappers.ResizeObservation(gym.make("AlienNoFrameskip-v4"), (64, 64)),
        rssm.encoder,
        embedding_dim,
        device,
        n_frames=6,
    )
)

In [None]:
from stable_baselines3.common.callbacks import BaseCallback
import matplotlib.pyplot as plt

class RewardPlotCallback(BaseCallback):
    """Record the return of every finished episode found in the step ``infos``.

    The ``"episode"`` entry is the one a Monitor wrapper adds when an episode ends.
    """

    def __init__(self):
        super().__init__()
        self.episode_rewards = []

    def _on_step(self) -> bool:
        infos = self.locals.get("infos", [])
        self.episode_rewards.extend(
            info["episode"]["r"] for info in infos if "episode" in info
        )
        return True

In [None]:
# PPO (MlpPolicy, CPU) on the stacked-latent env for 3M environment steps.
callback = RewardPlotCallback()

model = PPO(
    "MlpPolicy",
    env,
    learning_rate=3e-4,
    device="cpu",
    verbose=0,
)

model.learn(callback=callback, total_timesteps=3_000_000)
model.save("ppo_rssm_latent_policy_stack_latents_3million")

In [None]:
plt.plot(callback.episode_rewards)
plt.gca().set(xlabel="Episode", ylabel="Reward", title="Training reward")
plt.savefig("ppo_rssm_reward_plot_stack_latents_3million.png")
plt.show()


#### PPO using encodings - approach 2

In [None]:
def make_encoded_env(n_stack=4, encoder_module=None):
    """Return a zero-argument factory that builds one RSSM-latent Alien env for a VecEnv.

    Args:
        n_stack: number of consecutive latent vectors concatenated into each observation.
        encoder_module: module mapping a (1, 3, 64, 64) float batch to latents. Defaults to
            ``rssm.encoder``, looked up when the env is built. Do not pass an
            ``RSSMFeatureExtractor``: it adds its own batch axis on top of the one
            ``RSSMFrameStackWrapper`` already adds, producing a 5-D input.
    """
    from stable_baselines3.common.monitor import Monitor

    def _make():
        env = gym.make("AlienNoFrameskip-v4")

        env = gym.wrappers.AtariPreprocessing(
            env,
            screen_size=64,
            grayscale_obs=False,
            scale_obs=False
        )

        env = RSSMFrameStackWrapper(
            env,
            encoder=encoder_module if encoder_module is not None else rssm.encoder,
            embedding_dim=embedding_dim,
            device=device,
            n_frames=n_stack,
            use_encoding=True
        )

        # SB3 only auto-adds Monitor for non-vectorised envs; inside a DummyVecEnv it must be
        # added here or info["episode"] never appears and reward callbacks record nothing.
        return Monitor(env)

    return _make


In [None]:

class RewardRecorderCallback(BaseCallback):
    """Append the return of each episode that finished during the last env step."""

    def __init__(self):
        super().__init__()
        self.episode_rewards = []

    def _on_step(self) -> bool:
        finished = [info["episode"]["r"] for info in self.locals["infos"] if "episode" in info]
        self.episode_rewards += finished
        return True


In [None]:
n_envs = 8
latent_stack = 4

# 8 independent latent-stacking Alien envs stepped in lockstep.
env = DummyVecEnv([make_encoded_env(latent_stack) for _ in range(n_envs)])

callback = RewardRecorderCallback()

model = PPO(
    "MlpPolicy",
    env,
    # rollout / optimisation schedule
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    learning_rate=2.5e-4,
    # return estimation
    gamma=0.99,
    gae_lambda=0.95,
    # PPO objective
    clip_range=0.1,
    ent_coef=0.01,
    verbose=1,
    device=device,
)

model.learn(callback=callback, total_timesteps=3_000_000)

model.save("ppo_alien_latent_2")


In [None]:
plt.figure(figsize=(10, 5))
plt.plot(callback.episode_rewards)
plt.gca().set(
    xlabel="Episode",
    ylabel="Reward",
    title="Training Rewards (PPO on RSSM Latents)",
)
plt.grid(True)
plt.savefig("alien_training_rewards_latent_2.png")
plt.show()


#### PPO using direct observations

In [None]:
# Raw-pixel variant: 6 frames scaled to [0, 1] and stacked on the channel axis (no encoding).
env = Monitor(
    RSSMFrameStackWrapper(
        gym.wrappers.ResizeObservation(gym.make("AlienNoFrameskip-v4"), (64, 64)),
        rssm.encoder,
        embedding_dim,
        device,
        n_frames=6,
        use_encoding=False,
    )
)

In [None]:
# Baseline: PPO MlpPolicy on the current (raw-pixel) env, 1M steps on CPU.
obs_callback = RewardPlotCallback()

obs_model = PPO("MlpPolicy", env, learning_rate=3e-4, device="cpu", verbose=0)

obs_model.learn(callback=obs_callback, total_timesteps=1_000_000)
obs_model.save("ppo_rssm_latent_policy_stack_latents_using_observations")


In [None]:
plt.plot(obs_callback.episode_rewards)
plt.gca().set(xlabel="Episode", ylabel="Reward", title="Training reward")
plt.savefig("ppo_rssm_reward_plot_stack_latents_using_observations.png")
plt.show()


#### PPO using direct observations (single frame, CNN policy)

In [None]:
# Single raw frame per observation: CHW, float32, scaled to [0, 1].
env = Monitor(
    RSSMFrameStackWrapper(
        gym.wrappers.ResizeObservation(gym.make("AlienNoFrameskip-v4"), (64, 64)),
        rssm.encoder,
        embedding_dim,
        device,
        n_frames=1,
        use_encoding=False,
    )
)

In [None]:
one_obs_callback = RewardPlotCallback()

# Observations are already floats in [0, 1], so SB3 must not divide them by 255 again.
one_obs_model = PPO(
    "CnnPolicy",
    env,
    policy_kwargs={"normalize_images": False},
    learning_rate=3e-4,
    device="cpu",
    verbose=0,
)

one_obs_model.learn(callback=one_obs_callback, total_timesteps=1_000_000)
one_obs_model.save("ppo_rssm_latent_policy_stack_latents_using_one_observation")


In [None]:

plt.plot(one_obs_callback.episode_rewards)
plt.gca().set(xlabel="Episode", ylabel="Reward", title="Training reward")
plt.savefig("ppo_rssm_reward_plot_stack_latents_using_one_observation.png")
plt.show()


In [None]:
class RewardRecorderCallback(BaseCallback):
    """Collect returns of completed episodes across all sub-environments of a VecEnv."""

    def __init__(self):
        super().__init__()
        self.episode_rewards = []

    def _on_step(self) -> bool:
        # ``infos`` holds one dict per sub-environment.
        for env_info in self.locals["infos"]:
            if "episode" not in env_info:
                continue
            self.episode_rewards.append(env_info["episode"]["r"])
        return True

In [None]:
# Pixel baseline: SB3's make_atari_env with 8 parallel envs, frames stacked 4 deep.
env = VecFrameStack(make_atari_env("AlienNoFrameskip-v4", n_envs=8, seed=0), n_stack=4)

reward_callback = RewardRecorderCallback()

In [None]:
# Pixel-based PPO baseline (CnnPolicy on 4 stacked Atari frames). Uses the `device` chosen
# from torch.cuda.is_available() instead of hard-coding "cuda", which fails on CPU-only hosts.
model = PPO(
    policy="CnnPolicy",
    env=env,
    learning_rate=2.5e-4,
    n_steps=128,
    batch_size=256,
    n_epochs=4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.1,
    ent_coef=0.01,
    device=device,
    verbose=1,
)

In [None]:
# 3M environment steps; finished-episode returns are collected by reward_callback.
model.learn(callback=reward_callback, total_timesteps=3_000_000)

model.save("ppo_alien")


In [None]:
plt.figure(figsize=(10, 5))
plt.plot(reward_callback.episode_rewards)
plt.gca().set(
    xlabel="Episode",
    ylabel="Reward",
    title="Training Rewards for PPO on Alien",
)
plt.grid(True)
plt.savefig("alien_training_rewards.png")
plt.show()

## Use [z, h] (encoder embedding + RSSM hidden state) as the observation

In [None]:
class RSSMZPlusHWrapper(gym.Wrapper):
    """Expose concat[z, h] to the agent.

    z is the encoder embedding of the current frame. h is the RSSM's deterministic hidden
    state, advanced by ``rssm.dynamics`` with each (action, embedding) pair.

    Frames are expected as HWC images with 0-255 pixel values (e.g. from ResizeObservation).
    They are scaled to [0, 1] before encoding, as the other encoder call sites in this
    notebook do.
    """

    def __init__(self, env, rssm, embedding_dim, hidden_size, state_size, action_dim, device):
        super().__init__(env)

        self.rssm = rssm
        self.encoder = rssm.encoder
        self.device = device

        self.embedding_dim = embedding_dim
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.action_dim = action_dim

        self.h = None  # deterministic RNN hidden state
        self.s = None  # stochastic state

        # PPO receives concat[z, h]
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(embedding_dim + hidden_size,),
            dtype=np.float32
        )

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)

        self.h = torch.zeros(1, self.hidden_size, device=self.device)
        self.s = torch.zeros(1, self.state_size, device=self.device)

        z = self.encode_obs(obs)

        return self.concat_zh(z, self.h), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)

        z = self.encode_obs(obs)

        a_onehot = self.one_hot(action)

        # NOTE(review): the single transition is duplicated into a length-2 sequence, presumably
        # because rssm.dynamics needs seq_len >= 2 — confirm against DynamicsModel.
        actions = a_onehot.unsqueeze(0).unsqueeze(1)
        actions = torch.cat([actions, actions], dim=1).to(self.device)

        obs_seq = z.unsqueeze(0).unsqueeze(1)
        obs_seq = torch.cat([obs_seq, obs_seq], dim=1).to(self.device)

        with torch.no_grad():
            h_seq, prior_states, post_states, _, _, _, _ = self.rssm.dynamics(
                self.h,
                self.s,
                actions,
                obs_seq
            )

        self.h = h_seq[:, -1]
        self.s = post_states[:, -1]

        return self.concat_zh(z, self.h), reward, terminated, truncated, info

    def encode_obs(self, obs):
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self.device)
        # HWC -> CHW, add a batch axis, and scale raw 0-255 pixels to [0, 1].
        obs_t = obs_t.permute(2, 0, 1).unsqueeze(0) / 255.0
        with torch.no_grad():
            z = self.encoder(obs_t)
        return z.squeeze(0).cpu()

    def one_hot(self, action):
        v = torch.zeros(self.action_dim, device=self.device)
        v[action] = 1.0
        return v

    def concat_zh(self, z, h):
        z = z if isinstance(z, torch.Tensor) else torch.tensor(z)
        return torch.cat([z, h.squeeze(0).cpu()], dim=-1).numpy().astype(np.float32)


In [None]:
# Alien env whose observations are concat[z, h] from the RSSM, with Monitor for episode stats.
env = Monitor(
    RSSMZPlusHWrapper(
        gym.wrappers.ResizeObservation(gym.make("AlienNoFrameskip-v4"), (64, 64)),
        rssm=rssm,
        embedding_dim=embedding_dim,
        hidden_size=hidden_size,
        state_size=state_size,
        action_dim=action_dim,
        device=device,
    )
)

In [None]:
# Short PPO run (1k steps, CPU) on the [z, h] observations.
callback = RewardPlotCallback()

model = PPO("MlpPolicy", env, learning_rate=3e-4, device="cpu", verbose=0)

model.learn(callback=callback, total_timesteps=1_000)
model.save("ppo_rssm_latent_policy_z_h")


In [None]:

plt.plot(callback.episode_rewards)
plt.gca().set(xlabel="Episode", ylabel="Reward", title="Training reward")
plt.savefig("ppo_rssm_reward_plot_z_h.png")
plt.show()
