# Cross Entropy Method

## CartPole Example

In [1]:
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self, obs_size: int, hidden_size: int, n_actions: int) -> None:
        super().__init__()
        
        self.net = nn.Sequential(
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
    

In [2]:
import gymnasium as gym
from gymnasium.wrappers import NumpyToTorch
import torch.optim as optim

from examples.cross_entropy import save_net_spec
from velora.utils import load_config

config = load_config("config/cp_ce.yaml")
env: gym.Env = NumpyToTorch(gym.make("CartPole-v1"))

obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n

net = SimpleNet(obs_size, config.model.hidden_size, n_actions)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), **config.optimizer)

save_net_spec(net, "saved/")

In [None]:
import wandb

wandb.login()

In [3]:
from examples.cross_entropy import train_cartpole

train_cartpole(env, net, loss, optimizer, config, run_idx=3)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: achronus (achronus-uk). Use `wandb login --relogin` to force relogin


0: loss=0.708, reward_mean=16.9, rw_bound=18.5
1: loss=0.687, reward_mean=20.2, rw_bound=20.5
2: loss=0.684, reward_mean=28.1, rw_bound=29.5
3: loss=0.649, reward_mean=34.6, rw_bound=36.5
4: loss=0.664, reward_mean=36.4, rw_bound=40.5
5: loss=0.645, reward_mean=50.9, rw_bound=63.0
6: loss=0.643, reward_mean=47.9, rw_bound=55.5
7: loss=0.640, reward_mean=57.4, rw_bound=54.0
8: loss=0.632, reward_mean=39.6, rw_bound=43.0
9: loss=0.626, reward_mean=53.1, rw_bound=66.5
10: loss=0.619, reward_mean=56.5, rw_bound=65.0
11: loss=0.612, reward_mean=48.6, rw_bound=58.0
12: loss=0.630, reward_mean=50.5, rw_bound=60.0
13: loss=0.610, reward_mean=53.8, rw_bound=49.5
14: loss=0.609, reward_mean=57.7, rw_bound=79.0
15: loss=0.616, reward_mean=95.3, rw_bound=113.0
16: loss=0.610, reward_mean=82.3, rw_bound=102.0
17: loss=0.592, reward_mean=79.1, rw_bound=91.0
18: loss=0.598, reward_mean=92.7, rw_bound=106.0
19: loss=0.586, reward_mean=110.8, rw_bound=128.5
20: loss=0.592, reward_mean=105.2, rw_bound=1

0,1
ep_idx,▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
loss,█▇▇▅▆▅▄▄▄▃▃▃▄▃▃▃▃▂▂▁▂▂▁▁▁▁
reward_bound,▁▁▁▂▂▂▂▂▂▃▃▂▂▂▃▄▄▃▄▅▄▅▅▆▇█
reward_mean,▁▁▁▂▂▂▂▂▂▂▂▂▂▂▃▄▃▃▄▄▄▄▅▅▆█

0,1
ep_idx,25.0
loss,0.5815
reward_bound,228.5
reward_mean,206.8125


In [None]:
import numpy as np

class DiscreteObservationsToVector(gym.ObservationWrapper):
    """One hot encodes a discrete observation space."""
    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

        assert isinstance(env.observation_space, gym.spaces.Discrete)

        shape = (env.observation_space.n,)
        self.observation_space: gym.spaces.Box = gym.spaces.Box(0.0, 1.0, shape, dtype=np.float32)

    def observation(self, observation: int) -> np.ndarray:
        res = np.copy(self.observation_space.low)
        res[observation] = 1.0
        return res
    
    def step(self, action):
        # Fixes bug with 'NumPyToTorch' wrapper
        action = action.item() if isinstance(action, np.ndarray) else action
        observation, reward, terminated, truncated, info = self.env.step(action)
        return self.observation(observation), reward, terminated, truncated, info

## FrozenLake Example

In [None]:
HIDDEN_SIZE = 128
BATCH_SIZE = 100
GAMMA = 0.9

env2: gym.Env = NumpyToTorch(DiscreteObservationsToVector(gym.make("FrozenLake-v1", is_slippery=False)))

obs_size = env2.observation_space.shape[0]
n_actions = env2.action_space.n

net2 = SimpleNet(obs_size, HIDDEN_SIZE, n_actions)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net2.parameters(), lr=0.01)

In [None]:
def final_returns(batch: Episodes, gamma: float = 0.9) -> torch.FloatTensor:
    """
    Returns a tensor of final discounted returns for each episode.
    Calculated as: final_reward * (gamma ** (episode_length - 1))

    Args:
        gamma: Discount factor (default: 0.9)

    Returns:
        torch.FloatTensor: A tensor containing the final discounted return for each episode.
    """
    returns = [ep.score() * (gamma ** len(ep)) for ep in batch]
    return torch.tensor(returns, dtype=torch.float32)


def filter_batch2(batch: Episodes, percentile: int) -> tuple[Episodes, torch.Tensor,torch.Tensor, float]:
    ep_returns = final_returns(batch, GAMMA)
    reward_bound = np.percentile(ep_returns.numpy(), percentile)
    reward_mean = batch.scores().mean(dtype=torch.float32).item()

    best_batches = Episodes()

    for ep, disc_reward in zip(batch, ep_returns):
        if disc_reward >= reward_bound and disc_reward != 0.:
            best_batches.add(ep)

    return best_batches, best_batches.observations(), best_batches.actions(), reward_bound, reward_mean

In [None]:
def train2(env: gym.Env, net: nn.Module) -> None:
    solve_threshold = 0.8
    
    wb = WeightsAndBiases(
        project_name="CrossEntropyFrozenLake", 
        config={
            "env": "FrozenLake-v1",
            "slippery": False,
            "hidden_size": HIDDEN_SIZE,
            "batch_size": BATCH_SIZE,
            "percentile": PERCENTILE,
            "gamma": GAMMA,
            "solve_threshold": solve_threshold,
            "architecture": net
        }
    )
    wb.init()
    full_batch = Episodes()

    for i_batch, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
        full_batch, obs, acts, reward_bound, reward_mean = filter_batch2(batch + full_batch, PERCENTILE)

        if not full_batch:
            continue

        full_batch = full_batch[-500:]

        optimizer.zero_grad()
        action_scores_v = net(obs)
        loss_v: torch.Tensor = loss(action_scores_v, acts)
        loss_v.backward()
        optimizer.step()

        print(f"{i_batch}: loss={loss_v.item():.3f}, reward_mean={reward_mean:.3f}, rw_bound={reward_bound:.3f}, batch={len(full_batch)}")
        wb.log({
            "ep_idx": i_batch,
            "loss": loss_v,
            "reward_mean": reward_mean,
            "reward_bound": reward_bound,
            "batch": len(full_batch),
        })

        if reward_mean > solve_threshold:
            print("Solved!")
            wb.finish()
            break

train2(env2, net2)

0: loss=1.384, reward_mean=0.010, rw_bound=0.000, batch=1
1: loss=1.311, reward_mean=0.010, rw_bound=0.000, batch=1
2: loss=1.285, reward_mean=0.040, rw_bound=0.000, batch=4
3: loss=1.259, reward_mean=0.106, rw_bound=0.000, batch=11
4: loss=1.226, reward_mean=0.153, rw_bound=0.000, batch=17
5: loss=1.172, reward_mean=0.222, rw_bound=0.000, batch=26
6: loss=1.099, reward_mean=0.325, rw_bound=0.117, batch=38
7: loss=1.020, reward_mean=0.435, rw_bound=0.282, batch=46
8: loss=0.921, reward_mean=0.486, rw_bound=0.349, batch=50
9: loss=0.802, reward_mean=0.547, rw_bound=0.430, batch=52
10: loss=0.690, reward_mean=0.599, rw_bound=0.478, batch=52
11: loss=0.630, reward_mean=0.586, rw_bound=0.478, batch=67
12: loss=0.505, reward_mean=0.707, rw_bound=0.531, batch=51
13: loss=0.428, reward_mean=0.709, rw_bound=0.531, batch=69
14: loss=0.374, reward_mean=0.769, rw_bound=0.531, batch=100
15: loss=0.318, reward_mean=0.795, rw_bound=0.531, batch=137
16: loss=0.275, reward_mean=0.840, rw_bound=0.531, 

0,1
batch,▁▁▁▁▂▂▂▃▃▃▃▄▃▄▅▆█
ep_idx,▁▁▂▂▃▃▄▄▅▅▅▆▆▇▇██
loss,██▇▇▇▇▆▆▅▄▄▃▂▂▂▁▁
reward_bound,▁▁▁▁▁▁▃▅▆▇▇▇█████
reward_mean,▁▁▁▂▂▃▄▅▅▆▆▆▇▇▇██

0,1
batch,181.0
ep_idx,16.0
loss,0.27474
reward_bound,0.53144
reward_mean,0.83966
