# Intro

The World Model system includes:

1. Make Env - script that initializes the environment
2. Rollout - a trajectory $$obs \rightarrow action \rightarrow reward \rightarrow obs$$ repeated over time
3. Ring Buffer - Ring Buffer is the data structure (fixed-size table with overwrite).
4. Replay Buffer - Replay Buffer is the code/abstraction that defines how the agent reads and writes to that Ring Buffer (CRUD, sampling).
...

The core idea of the World Model training loop is that `env.step(action)`, which communicates directly with the environment, is not used directly to train the model.

A rollout, the output of multiple consecutive `env.step(action)` calls, is first written to the Ring Buffer (managed by the Replay Buffer). The input to the model is then sampled from the Ring Buffer (sampling is also defined in the Replay Buffer).

Why do we do this? It's actually quite intuitive: a single `env.step(action)` is a single point in time. Imagine asking someone to tell where a car is and how fast it is moving, given only a single position and time, e.g.

$$ 10.00 - \text{house} $$

No one can say how fast the car is moving, or whether it is moving at all. But if we provide information spread over time, like:

$$ 10.00 - \text{house} $$
$$ 10.15 - \text{shop} $$

We can say that the car covered the distance from the house to the shop in at most 15 minutes. The Ring Buffer holds as many points in time as we define in the `capacity` variable.
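To make the role of `capacity` concrete, here is a minimal sketch of a ring buffer in plain Python. The names `ring`, `ptr` and `size` are illustrative (they mirror the ones used in the Replay Buffer later on), and the integers simply stand in for time steps.

```python
capacity = 3
ring = [None] * capacity   # fixed-size storage
ptr = 0                    # next write position
size = 0                   # how many slots hold valid data

for t in range(5):         # write 5 time steps into 3 slots
    ring[ptr] = t
    ptr = (ptr + 1) % capacity       # wrap around at the end
    size = min(size + 1, capacity)   # never exceeds capacity

print(ring, ptr, size)     # [3, 4, 2] 2 3
```

Steps 0 and 1 were overwritten by steps 3 and 4, so the buffer keeps the `capacity` most recent points, though not necessarily in chronological order.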

---
---

# Environment

## Setup

In [49]:
!pip install gymnasium



## Make Env

Initialization of environment.

In [50]:
import gymnasium as gym

def make_env(id = 'CartPole-v1', seed = 0):
  env = gym.make(id)
  env.reset(seed=seed)
  return env

In [51]:
make_env()

<TimeLimit<OrderEnforcing<PassiveEnvChecker<CartPoleEnv<CartPole-v1>>>>>

---
---

# Data

## Data Generation - Single Rollout

Here is a simple single rollout. It is actually part of the Collect Random script, but it's good to understand how it works before moving on to the Replay Buffer.

Trajectory $$obs \rightarrow action \rightarrow reward \rightarrow obs$$ repeated over time

In [52]:
# from make_env import make_env

env = make_env()
obs, inf = env.reset()

for t in range(10):
  action = env.action_space.sample()
  obs, reward, terminated, truncated, inf = env.step(action)
  print(t, obs, action)
  if terminated or truncated:
    obs, inf = env.reset()

0 [ 0.03215253  0.23624298  0.01112257 -0.2663498 ] 1
1 [ 0.03687739  0.43120444  0.00579557 -0.5555039 ] 1
2 [ 0.04550148  0.23600158 -0.0053145  -0.2610007 ] 0
3 [ 0.05022151  0.0409559  -0.01053452  0.03000124] 0
4 [ 0.05104063 -0.15401341 -0.00993449  0.3193419 ] 0
5 [ 0.04796036 -0.34899247 -0.00354766  0.60887533] 0
6 [ 0.04098051 -0.54406464  0.00862985  0.9004388 ] 0
7 [ 0.03009922 -0.73930246  0.02663863  1.1958218 ] 0
8 [ 0.01531317 -0.93475896  0.05055506  1.4967333 ] 0
9 [-0.00338201 -1.1304578   0.08048972  1.8047632 ] 0
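The printed rows already illustrate the car analogy from the intro. As a rough check, assuming CartPole-v1's default time step of 0.02 s and its default Euler integration, the change in cart position between two consecutive rows divided by the time step should recover the cart velocity reported in the earlier row. The values below are copied from rows 0 and 1 of the output above; `TAU` and the variable names are illustrative.

```python
TAU = 0.02  # assumed CartPole-v1 time step in seconds

# (cart position, cart velocity) copied from rows 0 and 1 of the rollout output
x0, v0 = 0.03215253, 0.23624298
x1 = 0.03687739

v_est = (x1 - x0) / TAU  # finite-difference velocity from two points in time
print(round(v_est, 4), round(v0, 4))  # 0.2362 0.2362
```

CartPole happens to include velocities in its observation, which is why this check is possible. In environments that expose only positions or pixels, sequences of steps are the only source of this information.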


## Data Collection - Replay Buffer

Ring Buffer is the data structure (fixed-size table with overwrite).
Replay Buffer is the code/abstraction that defines how the agent reads and writes to that Ring Buffer (CRUD, sampling).

It consists of `__init__`, which defines the arrays (buffers); `add`, which writes one transition into the buffers; and `sample`, which draws input sequences for the model's training function.

In [53]:
import numpy as np

class ReplayBuffer:
  def __init__(self, capacity, obs_shape, action_shape):
    self.capacity = capacity

    # preallocated fixed-size arrays: one row per stored time step
    self.obs_buffer    = np.zeros((capacity, *obs_shape),    dtype = np.float32)
    self.action_buffer = np.zeros((capacity, *action_shape), dtype = np.float32)
    self.reward_buffer = np.zeros((capacity, ),              dtype = np.float32)
    self.dones_buffer  = np.zeros((capacity, ),              dtype = bool)

    self.ptr  = 0  # next write index, wraps around at capacity
    self.size = 0  # number of valid rows, at most capacity

  def add(self, obs, action, reward, done):
    # done is expected to be `terminated or truncated`, computed by the caller
    self.obs_buffer[self.ptr]    = obs
    self.action_buffer[self.ptr] = action
    self.reward_buffer[self.ptr] = reward
    self.dones_buffer[self.ptr]  = done

    self.ptr  = (self.ptr + 1) % self.capacity
    self.size = min(self.size + 1, self.capacity)

  def sample(self, batch_size, seq_len):
    # random start indices; +1 because randint's upper bound is exclusive
    idxs = np.random.randint(0, self.size - seq_len + 1, size = batch_size)

    # each start index yields a window of seq_len consecutive rows
    obs     = np.stack([self.obs_buffer[i:i+seq_len]    for i in idxs])
    actions = np.stack([self.action_buffer[i:i+seq_len] for i in idxs])
    rewards = np.stack([self.reward_buffer[i:i+seq_len] for i in idxs])
    dones   = np.stack([self.dones_buffer[i:i+seq_len]  for i in idxs])

    return obs, actions, rewards, dones

**Breaking down the code**

### init

In [54]:
capacity = 10
obs_shape = (4,)
action_shape = (1,)

# __init__
obs_buffer             = np.zeros((capacity, *obs_shape), dtype = np.float32)
actions_buffer         = np.zeros((capacity, *action_shape), dtype = np.float32)
rewards_buffer         = np.zeros((capacity,), dtype = np.float32)
dones_buffer           = np.zeros((capacity,), dtype = bool)
# index in Ring Buffer
ptr  = 0
# size of current fill of Ring Buffer
size = 0

### add

In [55]:
for t in range(10):
  action = env.action_space.sample()
  obs, reward, terminated, truncated, inf = env.step(action)

  # add
  obs_buffer[ptr]     = obs
  actions_buffer[ptr] = action
  rewards_buffer[ptr] = reward
  dones_buffer[ptr]   = terminated or truncated

  ptr  = (ptr + 1) % capacity
  size = min(size + 1, capacity)

  if terminated or truncated:
    obs, inf = env.reset()

In [56]:
obs_buffer, obs_buffer.shape

(array([[-2.59911604e-02, -1.32638085e+00,  1.16584994e-01,
          2.12133479e+00],
        [-5.25187775e-02, -1.52245486e+00,  1.59011692e-01,
          2.44764781e+00],
        [-8.29678774e-02, -1.32900453e+00,  2.07964644e-01,
          2.20768571e+00],
        [-1.09547965e-01, -1.52542937e+00,  2.52118349e-01,
          2.55667639e+00],
        [ 5.23264380e-03, -1.52053043e-01,  3.05908322e-02,
          2.52752513e-01],
        [ 2.19158316e-03,  4.26190458e-02,  3.56458835e-02,
         -3.01266965e-02],
        [ 3.04396404e-03,  2.37212166e-01,  3.50433476e-02,
         -3.11353266e-01],
        [ 7.78820738e-03,  4.31817800e-01,  2.88162827e-02,
         -5.92781842e-01],
        [ 1.64245628e-02,  2.36304551e-01,  1.69606451e-02,
         -2.91162938e-01],
        [ 2.11506542e-02,  4.31180596e-01,  1.11373868e-02,
         -5.78448772e-01]], dtype=float32),
 (10, 4))

In [57]:
actions_buffer, actions_buffer.shape

(array([[0.],
        [0.],
        [1.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.],
        [0.],
        [1.]], dtype=float32),
 (10, 1))

In [58]:
rewards_buffer, rewards_buffer.shape

(array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32), (10,))

In [59]:
dones_buffer, dones_buffer.shape

(array([False, False, False,  True, False, False, False, False, False,
        False]),
 (10,))

### sample

In [60]:
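seq_len = 2
batch_size = 3

# random start indices; +1 because randint's upper bound is exclusive
idxs           = np.random.randint(0, size - seq_len + 1, size=batch_size)

# note: the *_buffer names are reused for the sampled batches,
# so re-running this cell would sample from the batch, not the buffer
obs_buffer     = np.stack([obs_buffer[i:i+seq_len] for i in idxs])
actions_buffer = np.stack([actions_buffer[i:i+seq_len] for i in idxs])
rewards_buffer = np.stack([rewards_buffer[i:i+seq_len] for i in idxs])
dones_buffer   = np.stack([dones_buffer[i:i+seq_len] for i in idxs])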

In [61]:
idxs, idxs.shape

(array([7, 7, 0]), (3,))

In [62]:
obs_buffer, obs_buffer.shape

(array([[[ 0.00778821,  0.4318178 ,  0.02881628, -0.59278184],
         [ 0.01642456,  0.23630455,  0.01696065, -0.29116294]],
 
        [[ 0.00778821,  0.4318178 ,  0.02881628, -0.59278184],
         [ 0.01642456,  0.23630455,  0.01696065, -0.29116294]],
 
        [[-0.02599116, -1.3263808 ,  0.11658499,  2.1213348 ],
         [-0.05251878, -1.5224549 ,  0.15901169,  2.4476478 ]]],
       dtype=float32),
 (3, 2, 4))

In [63]:
actions_buffer, actions_buffer.shape

(array([[[1.],
         [0.]],
 
        [[1.],
         [0.]],
 
        [[0.],
         [0.]]], dtype=float32),
 (3, 2, 1))

In [64]:
rewards_buffer, rewards_buffer.shape

(array([[1., 1.],
        [1., 1.],
        [1., 1.]], dtype=float32),
 (3, 2))

In [65]:
dones_buffer, dones_buffer.shape

(array([[False, False],
        [False, False],
        [False, False]]),
 (3, 2))
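The list comprehensions plus `np.stack` can also be written as a single fancy-indexing operation. This is a sketch of an equivalent alternative, not the method used in the class above; `buf`, `windows` and the fixed `idxs` are illustrative stand-ins.

```python
import numpy as np

capacity, seq_len = 10, 2
# stand-in for obs_buffer: 10 rows of 4 features
buf = np.arange(capacity * 4, dtype=np.float32).reshape(capacity, 4)
idxs = np.array([7, 7, 0])  # start indices, same values as sampled above

# (batch, 1) + (seq_len,) broadcasts to a (batch, seq_len) grid of row indices
windows = idxs[:, None] + np.arange(seq_len)
batch_fancy = buf[windows]

# reference result using the original slicing approach
batch_stack = np.stack([buf[i:i + seq_len] for i in idxs])
print(batch_fancy.shape, np.array_equal(batch_fancy, batch_stack))  # (3, 2, 4) True
```

Both forms return the same `(batch, seq_len, features)` array; the indexed version avoids the Python loop, which may matter for large batches.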

## Test - Collect Random

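This script fills the Replay Buffer with transitions from a random policy and then draws a test batch from it. It covers the data-collection part of the training loop; no model is trained yet.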

In [70]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd()))

import numpy as np
# from src.envs.make_env import make_env
# from src.data.replay_buffer import ReplayBuffer

# config
CAPACITY = 10000
STEPS = 1000

env = make_env()
obs, info = env.reset()

obs_shape = obs.shape
action_shape = (1,) if np.isscalar(env.action_space.sample()) else env.action_space.sample().shape

buffer = ReplayBuffer(CAPACITY, obs_shape, action_shape)

for _ in range(STEPS):
    action = env.action_space.sample()
    next_obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated

    buffer.add(obs, action, reward, done)

    obs = next_obs

    if done:
        obs, info = env.reset()

# test sample
batch = buffer.sample(batch_size=4, seq_len=8)
batch

(array([[[ 0.14538473,  0.6041498 , -0.19429037, -1.2399229 ],
         [-0.01840565, -0.03172876,  0.03800981,  0.03123354],
         [-0.01904022,  0.16282807,  0.03863448, -0.24921873],
         [-0.01578366, -0.0328237 ,  0.03365011,  0.05539564],
         [-0.01644013, -0.22841159,  0.03475802,  0.35850263],
         [-0.02100836, -0.42400998,  0.04192808,  0.6619398 ],
         [-0.02948856, -0.61968946,  0.05516687,  0.96752435],
         [-0.04188235, -0.42534995,  0.07451735,  0.6926694 ]],
 
        [[ 0.09671275,  0.59212524, -0.04471171, -0.84491694],
         [ 0.10855526,  0.39764094, -0.06161004, -0.566623  ],
         [ 0.11650807,  0.20343493, -0.0729425 , -0.29396853],
         [ 0.12057677,  0.00942463, -0.07882188, -0.025153  ],
         [ 0.12076526,  0.20558329, -0.07932493, -0.34162706],
         [ 0.12487693,  0.40173897, -0.08615747, -0.6582324 ],
         [ 0.13291171,  0.5979478 , -0.09932213, -0.9767529 ],
         [ 0.14487067,  0.79425126, -0.11885718, -1.

Let's check the shapes:

In [78]:
names = ["obs", "actions", "rewards", "dones"]
for name, tensor in zip(names, batch):
    print(f"{name:8s} -> shape {tensor.shape}")

obs      -> shape (4, 8, 4)
actions  -> shape (4, 8, 1)
rewards  -> shape (4, 8)
dones    -> shape (4, 8)
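Note that the first sequence in the batch printed above jumps abruptly between its first and second rows, which suggests the window crosses an episode reset. The buffer stores episodes back to back, so uniform sampling can return such windows. One possible way to avoid them, sketched here as an assumption rather than as part of the `ReplayBuffer` above, is to keep only start indices whose window contains no `done` flag before its last step. The helper name `valid_starts` is hypothetical; the `dones` values are copied from the small breakdown example.

```python
import numpy as np

def valid_starts(dones, size, seq_len):
    """Return start indices i whose window [i, i + seq_len) stays inside one episode."""
    starts = []
    for i in range(size - seq_len + 1):
        # a done flag on the last step is fine: the episode ends with the window
        if not dones[i:i + seq_len - 1].any():
            starts.append(i)
    return np.array(starts, dtype=np.int64)

dones = np.array([False, False, False, True, False,
                  False, False, False, False, False])
starts = valid_starts(dones, size=len(dones), seq_len=2)
print(starts)  # [0 1 2 4 5 6 7 8]

# draw a batch of start indices from the valid ones only
rng = np.random.default_rng(0)
idxs = rng.choice(starts, size=3)
```

Start index 3 is dropped because its window would pair the last step of one episode with the first step of the next. This sketch ignores the wrap-around point at `ptr` once the buffer is full, where the oldest and newest data sit next to each other.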


---
---

# Model

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.data.replay_buffer import ReplayBuffer
from src.envs.make_env import make_env


# ---------- config ----------
CAPACITY   = 100_000
BATCH_SIZE = 32
SEQ_LEN    = 16
LR         = 3e-4
STEPS      = 10_000
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"


# ---------- world model (MINIMAL) ----------
class WorldModel(nn.Module):
    def __init__(self, obs_dim, action_dim, latent_dim=32):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim)
        )

        self.dynamics = nn.GRU(
            input_size=latent_dim + action_dim,
            hidden_size=latent_dim,
            batch_first=True
        )

        self.decoder = nn.Linear(latent_dim, obs_dim)
        self.reward_head = nn.Linear(latent_dim, 1)

    def forward(self, obs, actions):
        # obs: [B, T, obs_dim]
        B, T, _ = obs.shape

        z = self.encoder(obs.view(B * T, -1))
        z = z.view(B, T, -1)

        x = torch.cat([z, actions], dim=-1)
        h, _ = self.dynamics(x)

        obs_hat = self.decoder(h)
        reward_hat = self.reward_head(h).squeeze(-1)

        return obs_hat, reward_hat


# ---------- training ----------
def main():
    env = make_env()
    obs, _ = env.reset()

    obs_dim = obs.shape[0]
    action_sample = env.action_space.sample()
    action_dim = 1 if np.isscalar(action_sample) else action_sample.shape[0]

    buffer = ReplayBuffer(CAPACITY, (obs_dim,), (action_dim,))

    # fill buffer with random data
    for _ in range(5_000):
        action = env.action_space.sample()
        next_obs, reward, terminated, truncated, _ = env.step(action)
        # add() takes a single done flag, so combine both end-of-episode signals
        buffer.add(obs, action, reward, terminated or truncated)
        obs = next_obs
        if terminated or truncated:
            obs, _ = env.reset()

    model = WorldModel(obs_dim, action_dim).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    mse = nn.MSELoss()

    for step in range(STEPS):
        obs_b, act_b, rew_b, _ = buffer.sample(BATCH_SIZE, SEQ_LEN)

        obs_b = torch.tensor(obs_b, dtype=torch.float32, device=DEVICE)
        act_b = torch.tensor(act_b, dtype=torch.float32, device=DEVICE)
        rew_b = torch.tensor(rew_b, dtype=torch.float32, device=DEVICE)

        obs_hat, rew_hat = model(obs_b, act_b)

        obs_loss = mse(obs_hat, obs_b)
        rew_loss = mse(rew_hat, rew_b)

        loss = obs_loss + rew_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"[{step}] loss={loss.item():.4f}")


if __name__ == "__main__":
    main()
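In the training loop above, `obs_hat` at step t is compared with `obs_b` at the same step t, which the GRU has already received as input. A common variant for world models, offered here as an assumption and not as the author's method, is to predict the next observation instead by shifting the targets one step. The sketch uses NumPy arrays with the batch shapes from the Collect Random test; the same slicing works on PyTorch tensors.

```python
import numpy as np

B, T, obs_dim, action_dim = 4, 8, 4, 1
obs_b = np.random.rand(B, T, obs_dim).astype(np.float32)     # stand-in for sampled observations
act_b = np.random.rand(B, T, action_dim).astype(np.float32)  # stand-in for sampled actions
rew_b = np.ones((B, T), dtype=np.float32)                    # stand-in for sampled rewards

# inputs: steps 0..T-2, targets: steps 1..T-1
obs_in, act_in = obs_b[:, :-1], act_b[:, :-1]
obs_target = obs_b[:, 1:]   # observation that follows each (obs, action) pair
rew_target = rew_b[:, :-1]  # reward stored alongside each (obs, action) pair

print(obs_in.shape, act_in.shape, obs_target.shape, rew_target.shape)
# (4, 7, 4) (4, 7, 1) (4, 7, 4) (4, 7)
```

With these slices the model call would become `model(obs_in, act_in)`, and the two losses would compare against `obs_target` and `rew_target`. Windows that cross an episode reset would still produce a misleading target at the boundary.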