In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import gymnasium as gym
import numpy as np
import random
from pathlib import Path
import matplotlib.pyplot as plt
import ale_py
import gymnasium as gym
# Register ale_py's environments with gymnasium so gym.make can resolve ALE ids.
gym.register_envs(ale_py)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [14]:
from gymnasium import spaces

# Training environment: Alien without built-in frame skipping. Only one env is
# created here so no emulator instance is left open and unreachable.
env = gym.make("AlienNoFrameskip-v4")
# Downscale frames to 64x64, the spatial size the RSSM encoder is built for.
env = gym.wrappers.ResizeObservation(env, (64, 64))


def transform_obs(obs):
    """Reorder an (H, W, C) frame to (C, H, W) float32 and divide by 255.

    Assumes pixel values in 0-255, so the result lies in [0, 1].
    """
    chw = np.transpose(obs, (2, 0, 1))
    return chw.astype(np.float32) / 255.0

# Observation space after `transform_obs`: channels-first float32 frames in [0, 1].
new_obs_space = spaces.Box(
    low=0.0,
    high=1.0,
    shape=(3, 64, 64),
    dtype=np.float32,
)

# Apply transform_obs to every observation and advertise the new space downstream.
env = gym.wrappers.TransformObservation(
    env, func=transform_obs, observation_space=new_obs_space
)

In [15]:
obs, info = env.reset()
# Sanity-check the wrapper stack before training: CHW layout, float32 values,
# and every value inside the declared observation space (bounds [0, 1]).
assert obs.shape == (3, 64, 64), f"Expected (3, 64, 64), got {obs.shape}"
assert obs.dtype == np.float32, f"Expected float32, got {obs.dtype}"
assert env.observation_space.contains(obs), "Observation lies outside env.observation_space"

In [16]:
import sys
import os
# Make the local world-models fork importable; the `models.*` imports that follow
# are expected to resolve from it. Resolved relative to the current working directory.
local_repo_path = os.path.abspath(os.path.join(os.getcwd(), "../world-models-fork"))
if local_repo_path not in sys.path:
    sys.path.append(local_repo_path)
from models.rssm import RSSM

In [None]:
from models.models import EncoderCNN, DecoderCNN, RewardModel
from models.dynamics import DynamicsModel

# Dimensions must match the values the RSSM checkpoint was trained with.
action_dim = 18
hidden_size = 1024
state_size = 32
embedding_dim = 1024
image_shape = (3, 64, 64)
RSSM_CHECKPOINT = Path("checkpoints/rssm_best.pth")

encoder = EncoderCNN(3, embedding_dim, (image_shape[1], image_shape[2])).to(device)
decoder = DecoderCNN(hidden_size, state_size, embedding_dim, True, image_shape).to(device)
reward_model = RewardModel(hidden_size, state_size).to(device)
dynamics = DynamicsModel(hidden_size, action_dim, state_size, embedding_dim).to(device)

rssm = RSSM(encoder, decoder, reward_model, dynamics, hidden_size, state_size, action_dim, embedding_dim, device)

# Only the encoder weights are restored from the checkpoint; decoder,
# reward_model and dynamics keep their fresh initialization.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(RSSM_CHECKPOINT, map_location=device)
rssm.encoder.load_state_dict(checkpoint["encoder"])
# The encoder is used as a fixed feature extractor: make that explicit so no
# optimizer can update it, independent of any no_grad context downstream.
rssm.encoder.requires_grad_(False)
rssm.eval()
print("RSSM encoder loaded.")


RSSM encoder loaded.


In [18]:
class RSSMFeatureExtractor(nn.Module):
    """Wraps the RSSM encoder as a gradient-free feature extractor.

    ``forward`` runs the encoder under ``torch.no_grad()``, so only modules
    applied to its output (e.g. a policy head) can receive gradients.
    """

    def __init__(self, rssm):
        super().__init__()
        # Shares (does not copy) the encoder module held by ``rssm``.
        self.encoder = rssm.encoder

    def _input_device(self):
        """Device of the encoder's weights; CPU if the encoder has no parameters."""
        first_param = next(self.encoder.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, obs):
        """Encode one observation (or a 4-D batch) and return the embedding.

        Args:
            obs: array-like or tensor. A 4-D input is treated as an already
                batched stack of observations; any other rank is treated as a
                single observation and gets a batch dimension added/removed
                around the encoder call.

        Returns:
            Encoder output without gradient tracking: batch dim squeezed away
            for a single observation, kept for a batched input.
        """
        # as_tensor avoids the copy-construct warning when obs is already a
        # tensor, and placing it on the encoder's own device avoids relying on
        # a notebook-global device variable.
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=self._input_device())
        batched = obs_t.dim() == 4
        if not batched:
            obs_t = obs_t.unsqueeze(0)
        with torch.no_grad():
            z = self.encoder(obs_t)
        return z if batched else z.squeeze(0)

In [19]:
class LinearPolicy(nn.Module):
    """Policy head: one affine map from a latent vector to action logits.

    ``forward`` returns a ``torch.distributions.Categorical`` over actions
    parameterized by those logits.
    """

    def __init__(self, latent_dim, action_dim):
        super().__init__()
        self.linear = nn.Linear(latent_dim, action_dim)

    def forward(self, z):
        return torch.distributions.Categorical(logits=self.linear(z))

In [20]:
# RSSM encoder features (computed under no_grad) feed a linear policy head;
# only the head's parameters are handed to the optimizer.
encoder = RSSMFeatureExtractor(rssm).to(device)
policy = LinearPolicy(latent_dim=embedding_dim, action_dim=env.action_space.n).to(device)
optimizer = optim.Adam(params=policy.parameters(), lr=1e-3)

In [21]:
def _discounted_returns(rewards, gamma):
    """Return [G_0, ..., G_{T-1}] with G_t = r_t + gamma * G_{t+1} (G_T = 0).

    Built back-to-front with append + one reverse: O(T), unlike repeatedly
    inserting at the front of a list, which is O(T^2) for long episodes.
    """
    returns, G = [], 0
    for r in reversed(rewards):
        G = r + gamma * G
        returns.append(G)
    returns.reverse()
    return returns


def run_episode(env, encoder, policy, gamma=0.99, seed=None):
    """Roll out one episode with ``policy`` and build its REINFORCE loss.

    Args:
        env: Gymnasium-style env (``reset(seed=...)`` -> (obs, info);
            ``step(a)`` -> (obs, reward, terminated, truncated, info)).
        encoder: callable mapping a raw observation to a latent tensor.
        policy: callable mapping the latent to a distribution exposing
            ``sample()`` and ``log_prob()``.
        gamma: discount factor for the returns.
        seed: if not None, seeds torch / numpy / ``random`` and the env reset
            so the rollout is reproducible.

    Returns:
        (loss, total_reward): ``loss`` is minus the sum of per-step
        log-probabilities weighted by the discounted returns (standardized
        when the episode has more than one step); ``total_reward`` is the
        undiscounted sum of rewards.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    # A single reset handles both cases: reset(seed=None) does not reseed the env.
    obs, info = env.reset(seed=seed)
    log_probs, rewards = [], []
    done = False
    total_reward = 0

    while not done:
        z = encoder(obs)
        dist = policy(z)
        action = dist.sample()
        log_probs.append(dist.log_prob(action))

        obs, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated
        rewards.append(reward)
        total_reward += reward

    log_probs_t = torch.stack(log_probs)
    # Build returns on the same device as the log-probs instead of a global device.
    returns_t = torch.tensor(
        _discounted_returns(rewards, gamma), dtype=torch.float32, device=log_probs_t.device
    )
    # Standardize as a variance-reduction baseline. Skipped for a one-step
    # episode: the (unbiased) std of a single element is NaN, which would turn
    # the loss — and, after backward/step, the policy weights — into NaN.
    if returns_t.numel() > 1:
        returns_t = (returns_t - returns_t.mean()) / (returns_t.std() + 1e-8)
    loss = -torch.sum(log_probs_t * returns_t)
    return loss, total_reward


In [None]:
num_episodes = 50
reward_history = []

# Nothing below switches the policy to eval mode, so setting it once suffices.
policy.train()
for ep in range(num_episodes):
    # seed=ep makes each episode's rollout reproducible across re-runs.
    loss, total_reward = run_episode(env, encoder, policy, seed=ep)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    reward_history.append(float(total_reward))
    # Deliberately no per-episode torch.cuda.empty_cache(): it only returns
    # cached blocks to the driver (it cannot free memory held by live
    # tensors) and forces fresh allocations next episode.
    print(f"Episode {ep+1}/{num_episodes} | Reward: {total_reward:.1f}")

env.close()

Episode 1/10 | Reward: 80.0
Episode 2/10 | Reward: 110.0
Episode 3/10 | Reward: 100.0
Episode 4/10 | Reward: 120.0
Episode 5/10 | Reward: 140.0
Episode 6/10 | Reward: 100.0
Episode 7/10 | Reward: 150.0
Episode 8/10 | Reward: 140.0
Episode 9/10 | Reward: 100.0
Episode 10/10 | Reward: 120.0


In [None]:
# Quick check: number of logged episodes and the type of the stored rewards.
print(f"{len(reward_history)} {type(reward_history[0])}")

10 <class 'float'>


: 

In [None]:
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: the figure is written to disk

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(reward_history)
ax.set_xlabel("Episode")
ax.set_ylabel("Reward")
# Must name the environment actually trained on (AlienNoFrameskip-v4 in the
# env cell); keep in sync if the environment changes.
ax.set_title("RSSM-based Linear Policy on Alien")
fig.savefig("reward_plot.png")
plt.close(fig)  # release the figure; nothing is shown under the Agg backend
print("Saved to reward_plot.png")