# Cross-Entropy Method Example (CartPole)

In [1]:
# Width of the single hidden layer of the policy network (``Net``).
HIDDEN_SIZE: int = 128
# Number of complete episodes collected before each training step.
BATCH_SIZE: int = 16
# Episodes scoring at or above this percentile of their batch are kept for training.
PERCENTILE: int = 70

In [2]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """Two-layer MLP mapping an observation vector to one raw score per action.

    Args:
        obs_size: number of input features per observation.
        hidden_size: width of the hidden ReLU layer.
        n_actions: number of output scores (one per discrete action).
    """

    def __init__(self, obs_size: int, hidden_size: int, n_actions: int) -> None:
        super().__init__()
        layers: list[nn.Module] = [
            nn.Linear(obs_size, hidden_size),   # input -> hidden
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),  # hidden -> one score per action
        ]
        # A single Sequential stored as ``net`` keeps state_dict keys as ``net.<i>.*``.
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalised action scores; no softmax is applied here."""
        logits = self.net(x)
        return logits
    

In [7]:
from typing import Generator
import gymnasium as gym
import numpy as np

from velora import History, Trajectory, Episodes

from torch.distributions import Categorical


def iterate_batches(env: gym.Env, net: nn.Module, batch_size: int) -> Generator[Episodes, None, None]:
    """Play ``env`` forever with a stochastic policy from ``net``, yielding batches of episodes.

    Actions are sampled from the categorical distribution defined by the network's
    output scores. Every finished episode (terminated or truncated) is appended to the
    current batch; once the batch holds ``batch_size`` episodes it is yielded and a
    fresh batch is started. The generator never stops on its own.

    Args:
        env: environment whose observations are tensors that ``net`` accepts
            (e.g. one wrapped in ``NumpyToTorch``).
        net: policy network returning one score per discrete action.
        batch_size: number of complete episodes per yielded batch; must be >= 1.

    Yields:
        An ``Episodes`` collection holding exactly ``batch_size`` episodes.

    Raises:
        ValueError: if ``batch_size`` < 1 (otherwise the loop would never yield).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    batch = Episodes()
    episode = History()
    obs, _ = env.reset()

    while True:
        # Acting only needs a sample, not gradients: avoid building an autograd
        # graph for every environment step.
        with torch.no_grad():
            action_scores = net(obs.unsqueeze(0))  # add a batch dimension of 1
        # Categorical(logits=...) applies the softmax internally.
        action = Categorical(logits=action_scores).sample().item()

        next_obs, reward, terminated, truncated, _ = env.step(action)

        episode.add(
            Trajectory(
                action=action,
                observation=obs,
                reward=float(reward),
            )
        )

        if terminated or truncated:
            batch.add(episode)
            episode = History()
            # The next episode starts from a fresh reset observation.
            next_obs, _ = env.reset()

            if len(batch) == batch_size:
                yield batch
                batch = Episodes()

        obs = next_obs


In [8]:
def filter_batch(batch: Episodes, percentile: int) -> tuple[torch.Tensor, torch.Tensor, float, float]:
    """Keep the elite episodes of ``batch`` and return them as training data.

    An episode is elite when its score is at least the ``percentile``-th
    percentile of all episode scores in the batch.

    Args:
        batch: episodes collected by the current policy.
        percentile: cut-off percentile in [0, 100], passed to ``np.percentile``.

    Returns:
        ``(observations, actions, reward_bound, reward_mean)``. The first two come
        from the elite episodes only. The last two are statistics of the full batch:
        the percentile cut-off and the mean score.
    """
    scores = batch.scores().numpy()
    reward_mean = np.mean(scores).item()
    reward_bound = np.percentile(scores, percentile).item()

    elite = Episodes()
    for episode in batch:
        if episode.score() >= reward_bound:
            elite.add(episode)

    return elite.observations(), elite.actions(), reward_bound, reward_mean

In [12]:
from gymnasium.wrappers import NumpyToTorch
import torch.optim as optim

# CartPole wrapped so that observations are returned as torch tensors.
env: gym.Env = NumpyToTorch(gym.make("CartPole-v1"))

# The network's input/output sizes come from the environment's spaces.
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.n
net = Net(obs_size=obs_size, hidden_size=HIDDEN_SIZE, n_actions=n_actions)

# Net outputs raw scores, which is what CrossEntropyLoss expects.
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

In [13]:
# Gymnasium registers CartPole-v1 with reward_threshold=475 (episodes capped at
# 500 steps); 195 is the CartPole-v0 threshold (200-step episodes).
solve_threshold = 475

for i_batch, batch in enumerate(iterate_batches(env, net, BATCH_SIZE)):
    # Elite episodes become training data; reward stats describe the whole batch.
    obs_v, acts_v, reward_b, reward_m = filter_batch(batch, PERCENTILE)

    if len(obs_v) == 0:
        continue  # nothing to train on

    # One supervised step: push the net toward the actions taken in elite episodes.
    optimizer.zero_grad()
    action_scores_v = net(obs_v)
    loss_v: torch.Tensor = loss(action_scores_v, acts_v)
    loss_v.backward()
    optimizer.step()

    print(f"{i_batch}: loss={loss_v.item():.3f}, reward_mean={reward_m:.1f}, rw_bound={reward_b:.1f}")

    if reward_m > solve_threshold:
        print("Solved!")
        break

0: loss=0.688, reward_mean=19.7, rw_bound=20.0
1: loss=0.683, reward_mean=21.6, rw_bound=22.5
2: loss=0.640, reward_mean=14.9, rw_bound=15.0
3: loss=0.681, reward_mean=18.6, rw_bound=21.0
4: loss=0.623, reward_mean=13.7, rw_bound=14.5
5: loss=0.688, reward_mean=19.7, rw_bound=21.5
6: loss=0.659, reward_mean=19.7, rw_bound=25.5
7: loss=0.688, reward_mean=17.6, rw_bound=19.5
8: loss=0.658, reward_mean=18.7, rw_bound=22.5
9: loss=0.693, reward_mean=25.8, rw_bound=25.5
10: loss=0.657, reward_mean=24.0, rw_bound=30.0
11: loss=0.669, reward_mean=23.2, rw_bound=23.0
12: loss=0.667, reward_mean=34.9, rw_bound=44.0
13: loss=0.670, reward_mean=28.1, rw_bound=35.0
14: loss=0.657, reward_mean=28.9, rw_bound=32.0
15: loss=0.650, reward_mean=33.7, rw_bound=37.5
16: loss=0.658, reward_mean=27.4, rw_bound=32.5
17: loss=0.640, reward_mean=45.9, rw_bound=52.0
18: loss=0.640, reward_mean=33.5, rw_bound=33.0
19: loss=0.631, reward_mean=42.1, rw_bound=45.0
20: loss=0.631, reward_mean=32.9, rw_bound=35.5
21