In [None]:
import gymnasium as gym
import numpy as np
import torch
from torch import nn
import time


In [None]:
# ============================
# Agent
# ============================
class DeepCrossEntropyAgent(nn.Module):
    """Stochastic softmax policy trained with the deep cross-entropy method.

    A small MLP maps a state vector to one logit per discrete action. Actions
    are sampled from the softmax of those logits, and the policy is improved
    by supervised cross-entropy on (state, action) pairs taken from elite
    trajectories.

    Args:
        state_dim: Length of the state vector fed to the network.
        num_action: Number of discrete actions (size of the output layer).
        load_state_path: Optional checkpoint path (as written by
            ``update_policy``) whose ``"model_state_dict"`` is loaded into
            the MLP.
    """

    def __init__(self, state_dim, num_action, load_state_path=None):
        super().__init__()
        self.state_dim = state_dim
        self.num_action = num_action

        self.nn_model = self.get_model()

        if load_state_path is not None:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot run arbitrary code;
            # map_location="cpu" lets a checkpoint saved on GPU load here.
            checkpoint = torch.load(
                load_state_path, map_location="cpu", weights_only=True
            )
            self.nn_model.load_state_dict(checkpoint["model_state_dict"])

        self.softmax = nn.Softmax(dim=-1)
        self.optimizer = torch.optim.Adam(self.nn_model.parameters(), lr=0.01)
        self.loss = nn.CrossEntropyLoss()

    def get_model(self):
        """Build the policy MLP: state_dim -> 128 (ReLU) -> num_action logits."""
        return nn.Sequential(
            nn.Linear(self.state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, self.num_action),
        )

    def forward(self, x):
        """Return unnormalized action logits for state(s) ``x``."""
        return self.nn_model(x)

    def get_action(self, state):
        """Sample one action index from the current policy for ``state``."""
        # Pure inference: don't build an autograd graph on every env step.
        with torch.no_grad():
            logits = self.forward(torch.as_tensor(state, dtype=torch.float32))
            probs = self.softmax(logits).numpy()
        return np.random.choice(self.num_action, p=probs)

    def update_policy(self, elite_trajectories, save_path="nn_model_lander.pth"):
        """Take one cross-entropy gradient step on elite (state, action) pairs.

        Args:
            elite_trajectories: Iterable of dicts with ``"states"`` and
                ``"actions"`` lists.
            save_path: File the checkpoint is written to after the step
                (default keeps the original fixed filename). Pass ``None``
                to skip saving.

        Returns:
            The loss value (computed before the step) as a float, or ``None``
            when there are no elite pairs; in that case nothing is updated
            or saved.
        """
        elite_states = []
        elite_actions = []
        for traj in elite_trajectories:
            elite_states.extend(traj["states"])
            elite_actions.extend(traj["actions"])

        if not elite_states:
            return None

        # Stack into one ndarray first: torch.tensor on a list of ndarrays
        # converts element by element, which is very slow and warns.
        states_t = torch.as_tensor(np.asarray(elite_states, dtype=np.float32))
        actions_t = torch.as_tensor(np.asarray(elite_actions, dtype=np.int64))

        self.optimizer.zero_grad()
        loss = self.loss(self.forward(states_t), actions_t)
        loss.backward()
        self.optimizer.step()

        loss_value = loss.item()
        print(f"Loss: {loss_value:.4f}")

        if save_path is not None:
            torch.save(
                {"model_state_dict": self.nn_model.state_dict()},
                save_path,
            )
        return loss_value

# ============================
# Trajectory generation
# ============================
def get_trajectory(agent, env, max_len=400, visualize=False):
    """Play a single episode with ``agent`` and record it.

    Args:
        agent: Object exposing ``get_action(state)``.
        env: Gymnasium-style environment: ``reset()`` returns
            ``(obs, info)`` and ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        max_len: Maximum number of steps to take.
        visualize: If true, call ``env.render()`` and pause 30 ms each step.

    Returns:
        Dict with ``"states"`` (the observation seen before each action),
        ``"actions"`` (the actions taken) and ``"tot_reward"`` (the sum of
        step rewards, ``0`` if no step was taken).
    """
    seen_states = []
    taken_actions = []
    episode_return = 0

    observation, _reset_info = env.reset()

    for _step_idx in range(max_len):
        seen_states.append(observation)
        chosen = agent.get_action(observation)
        taken_actions.append(chosen)

        observation, step_reward, is_terminal, is_truncated, _step_info = env.step(chosen)
        episode_return += step_reward

        if visualize:
            env.render()
            time.sleep(0.03)

        # Stop on either natural termination or a time-limit truncation.
        if is_terminal or is_truncated:
            break

    return {
        "states": seen_states,
        "actions": taken_actions,
        "tot_reward": episode_return,
    }


# ============================
# Elite trajectories selection
# ============================
def get_elite_trajectories(trajectories, q_param):
    """Select trajectories whose total reward is at or above a reward quantile.

    Args:
        trajectories: Sequence of dicts, each with a numeric ``"tot_reward"``.
        q_param: Quantile in [0, 1] used as the threshold; e.g. 0.6 keeps
            roughly the best 40% of trajectories.

    Returns:
        List of the selected trajectory dicts in their original order
        (empty if ``trajectories`` is empty).
    """
    if not trajectories:
        # np.quantile raises on an empty sequence.
        return []
    rewards = [t["tot_reward"] for t in trajectories]
    threshold = np.quantile(rewards, q_param)
    # ">=" rather than ">" so ties at the threshold are kept: with ">", a
    # batch where every return is identical would yield no elites at all.
    return [t for t in trajectories if t["tot_reward"] >= threshold]

# ============================
# Environment
# ============================
env = gym.make("LunarLander-v3")
# Derive the network sizes from the environment's spaces instead of
# hard-coding them, so they stay correct if the environment id changes.
state_dim = int(env.observation_space.shape[0])  # 8 for LunarLander
num_action = int(env.action_space.n)  # 4 for LunarLander


In [None]:
# ============================
# Training
# ============================
def train(
    num_trajectory=200,
    max_len=300,
    q_param=0.6,
    num_iter=100,
    seed=None,
    train_env=None,
):
    """Train a DeepCrossEntropyAgent with the cross-entropy method.

    Each iteration samples ``num_trajectory`` episodes with the current
    policy, picks elite episodes with ``get_elite_trajectories(..., q_param)``
    and, if any were selected, fits the policy to their (state, action)
    pairs via ``update_policy`` (which also writes a checkpoint).

    Args:
        num_trajectory: Episodes sampled per iteration.
        max_len: Step cap per episode.
        q_param: Reward quantile used as the elite threshold.
        num_iter: Number of sample/update iterations.
        seed: If given, seeds NumPy, PyTorch and the environment so the run
            is reproducible.
        train_env: Environment to roll out in; defaults to the module-level
            ``env``.
    """
    rollout_env = env if train_env is None else train_env

    if seed is not None:
        np.random.seed(seed)  # action sampling uses np.random.choice
        torch.manual_seed(seed)  # network initialisation
        # Seeding one reset is enough: later unseeded resets keep drawing
        # from the environment's (now seeded) RNG.
        rollout_env.reset(seed=seed)

    agent = DeepCrossEntropyAgent(state_dim, num_action)

    for i in range(num_iter):
        t_start = time.perf_counter()
        trajectories = [
            get_trajectory(agent, rollout_env, max_len=max_len)
            for _ in range(num_trajectory)
        ]

        mean_reward = np.mean([t["tot_reward"] for t in trajectories])
        mean_len = np.mean([len(t["states"]) for t in trajectories])

        elite_trajectories = get_elite_trajectories(trajectories, q_param)

        if elite_trajectories:
            agent.update_policy(elite_trajectories)

        print(
            f"Iter {i:03d} | "
            f"Mean Reward: {mean_reward:.2f} | "
            f"Mean Len: {mean_len:.1f} | "
            f"Elite: {len(elite_trajectories)} | "
            f"Iter time: {time.perf_counter() - t_start:.2f}s"
        )


# ============================
# Run
# ============================
train()


In [None]:
# ============================
# Testing
# ============================
def test(max_len=500, load_state_path="nn_model_lander.pth"):
    """Roll out one rendered episode with a policy restored from a checkpoint.

    Args:
        max_len: Step cap for the episode.
        load_state_path: Checkpoint file containing ``"model_state_dict"``.
    """
    # Separate env with on-screen rendering; the training env is headless.
    render_env = gym.make("LunarLander-v3", render_mode="human")
    try:
        agent = DeepCrossEntropyAgent(
            state_dim, num_action, load_state_path=load_state_path
        )
        traj = get_trajectory(agent, render_env, max_len=max_len, visualize=True)
    finally:
        # Release the environment (and its render window) even if loading
        # the checkpoint or the rollout raises.
        render_env.close()
    print("Total reward:", traj["tot_reward"])

test(load_state_path='nn_model_lander.pth')