# Day 28 - The Cross-Entropy Method

* The cross-entropy method is model-free, policy-based, and on-policy
* The method itself is quite simple
    1. Play $N$ episodes
    2. Calculate the return of each episode and set a boundary—usually the 50th or 70th percentile of those returns
    3. Discard the episodes whose return is below the boundary
    4. Perform supervised learning of the policy, using the remaining episodes
    5. Stop once the results are satisfactory; otherwise, go back to step 1

## The cross-entropy method on CartPole

In [1]:
import numpy as np
import gymnasium as gym
from dataclasses import dataclass
import typing as tt
from torch.utils.tensorboard.writer import SummaryWriter

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Device used for the network, the loss module and every tensor created below.
# Kept on the CPU: this is extremely slow on the GPU, due to all the data transfer.
device = "cpu"

In [3]:
# Hyperparameters for the cross-entropy method.
HIDDEN_SIZE = 128  # width of the hidden layer handed to Net
BATCH_SIZE = 32  # number of episodes played per training iteration
PERCENTILE = 70  # reward percentile used as the cutoff; episodes below it are discarded

In [4]:
class Net(nn.Module):
    """Small fully connected policy network: Linear -> ReLU -> Linear.

    The network outputs one raw, unnormalised score per action. No softmax
    is applied here, so callers turn the scores into probabilities (or feed
    them to a loss that expects raw scores) themselves.
    """

    def __init__(self, obs_size: int, hidden_size: int, n_actions: int):
        """Build the layer stack.

        Args:
            obs_size: Number of features in one observation.
            hidden_size: Number of units in the single hidden layer.
            n_actions: Number of output scores (one per action).
        """
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw action scores for the observations in ``x``."""
        scores = self.net(x)
        return scores


@dataclass
class EpisodeStep:
    """One agent decision inside an episode: what it saw and what it did."""

    # Observation that was fed to the network when the action was chosen.
    observation: np.ndarray
    # Index of the action the agent sampled for that observation.
    action: int


@dataclass
class Episode:
    """A finished episode: its total reward and the steps that produced it."""

    # Sum of the rewards received over the whole episode (undiscounted).
    reward: float
    # Steps of the episode, in the order they were taken.
    steps: tt.List[EpisodeStep]

In [5]:
def iterate_batches(env: gym.Env, net: Net, batch_size: int) -> tt.Generator[tt.List[Episode], None, None]:
    """Play episodes with the current policy and yield them in batches.

    This is an infinite generator: it keeps stepping ``env`` and yields a
    list of exactly ``batch_size`` finished episodes every time that many
    have been collected. ``net`` is used as-is on every step, so any weight
    update the caller performs between two batches is picked up by the
    following rollouts.

    Args:
        env: Environment to interact with; it is reset here and after every
            finished episode.
        net: Policy network returning raw action scores for a batch of
            observations.
        batch_size: Number of complete episodes per yielded batch.

    Yields:
        Lists of ``batch_size`` ``Episode`` objects.
    """
    batch: tt.List[Episode] = []
    episode_reward = 0.0
    episode_steps: tt.List[EpisodeStep] = []
    obs, _ = env.reset()
    # The network outputs raw scores; softmax turns them into probabilities.
    sm = nn.Softmax(dim=1).to(device=device)

    while True:
        obs_v = torch.tensor(obs, dtype=torch.float32, device=device)
        # Rollouts are inference only: skip autograd bookkeeping instead of
        # building a graph on every step and discarding it afterwards.
        with torch.no_grad():
            # unsqueeze(0) adds the batch dimension the network expects.
            act_probs_v = sm(net(obs_v.unsqueeze(0)))
        act_probs = act_probs_v.cpu().numpy()[0]
        # Sample (rather than argmax) so the policy keeps exploring.
        # int() stores a plain Python int, matching EpisodeStep.action.
        action = int(np.random.choice(len(act_probs), p=act_probs))
        next_obs, reward, is_done, is_trunc, _ = env.step(action)
        episode_reward += float(reward)
        # Record the observation the action was chosen for, not the result.
        episode_steps.append(EpisodeStep(observation=obs, action=action))
        if is_done or is_trunc:
            batch.append(Episode(reward=episode_reward, steps=episode_steps))
            episode_reward = 0.0
            episode_steps = []
            next_obs, _ = env.reset()
            if len(batch) == batch_size:
                yield batch
                batch = []

        obs = next_obs

In [6]:
def filter_batch(batch: tt.List[Episode], percentile: float) -> tt.Tuple[torch.FloatTensor, torch.LongTensor, float, float]:
    """Keep the best episodes of a batch and flatten them into training data.

    Args:
        batch: Finished episodes to choose from.
        percentile: Percentile (0-100) of the episode rewards used as the
            cutoff; episodes with a reward below it are dropped.

    Returns:
        A tuple ``(observations, actions, reward_bound, reward_mean)``:
        the stacked observations and the matching actions of every step of
        the kept episodes, the reward cutoff that was applied, and the mean
        reward over the whole (unfiltered) batch.
    """
    episode_rewards = [episode.reward for episode in batch]
    reward_bound = float(np.percentile(episode_rewards, percentile))
    reward_mean = float(np.mean(episode_rewards))

    # Flatten the steps of every episode that is not below the cutoff,
    # preserving episode order and step order.
    kept_steps = [
        step
        for episode in batch
        if not (episode.reward < reward_bound)
        for step in episode.steps
    ]
    train_obs: tt.List[np.ndarray] = [step.observation for step in kept_steps]
    train_act: tt.List[int] = [step.action for step in kept_steps]

    # One row per step for the observations, one entry per step for actions.
    train_obs_v = torch.FloatTensor(np.vstack(train_obs))
    train_act_v = torch.LongTensor(train_act)

    return train_obs_v, train_act_v, reward_bound, reward_mean

In [7]:
# Create the CartPole environment with off-screen ("rgb_array") rendering.
env = gym.make("CartPole-v1", render_mode="rgb_array")
# Folder the recordings go to; also read by the video-display cell further down.
video_folder = "./DRL/videos/cartpole-CEM"
# Wrap the environment so that episodes are recorded as video files in video_folder.
env = gym.wrappers.RecordVideo(env, video_folder=video_folder)
# Derive the network's input/output sizes from the environment's spaces.
# The asserts fail fast (and narrow the types) if the spaces are not the expected kind.
assert env.observation_space.shape is not None
obs_size = env.observation_space.shape[0]
assert isinstance(env.action_space, gym.spaces.Discrete)
n_actions = int(env.action_space.n)

  logger.warn(


In [8]:
# Policy network, loss, optimizer and TensorBoard writer for the training run.
net = Net(obs_size, HIDDEN_SIZE, n_actions).to(device=device)
print(net)
# CrossEntropyLoss takes the raw scores produced by Net (no softmax needed).
objective = nn.CrossEntropyLoss().to(device=device)
optimizer = optim.Adam(params=net.parameters(), lr=1e-1)
writer = SummaryWriter(comment="-cartpole")

# try/finally guarantees the writer is flushed and closed even when training
# is interrupted (e.g. the cell is stopped by hand) or an iteration raises.
try:
    for iter_no, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
        # Keep only the best episodes and train the policy to imitate them.
        obs_v, acts_v, reward_b, reward_m = filter_batch(batch, PERCENTILE)
        optimizer.zero_grad()
        action_scores_v = net(obs_v.to(device=device, non_blocking=True))
        loss_v = objective(action_scores_v, acts_v.to(device=device))
        loss_v.backward()
        optimizer.step()
        # Read the scalar once; it is used for both the printout and the log.
        loss_value = loss_v.item()
        # "\r" rewrites the same console line on every iteration.
        print(
            "{0}: loss={1:.3f}, reward_mean={2:.1f}, rw_bound={3:.1f}".format(
            iter_no, loss_value, reward_m, reward_b),
            end="\t\t\r"
        )
        writer.add_scalar("loss", loss_value, iter_no)
        writer.add_scalar("reward_bound", reward_b, iter_no)
        writer.add_scalar("reward_mean", reward_m, iter_no)
        # Stop once the mean reward of a whole batch exceeds 499.9.
        if reward_m > 499.9:
            print("\nSolved!")
            break
finally:
    writer.close()

Net(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)
36: loss=0.106, reward_mean=500.0, rw_bound=500.0		
Solved!


In [9]:
from IPython.display import Video
import os
from pathlib import Path

# Collect only the recorded videos; anything else in the folder is not
# something Video() can play and must not be picked as "most recent".
video_files = list(Path(video_folder).glob("*.mp4"))
if not video_files:
    raise FileNotFoundError(f"No .mp4 recordings found in {video_folder!r}")

# Pick the file with the newest modification time (no need to sort them all).
most_recent_video = max(video_files, key=os.path.getmtime).name
print(most_recent_video)

# Display the most recent video; this has to stay the last expression of the
# cell so that the notebook renders it.
Video(url=video_folder + "/" + most_recent_video)

rl-video-episode-1000.mp4
