In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import gym

import time
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class CEMAgent(nn.Module):
    def __init__(self, state_n, action_n, lr, opt_f=torch.optim.Adam):
        super(CEMAgent, self).__init__()
        self.state_n = state_n
        self.action_n = action_n
        self.lr = lr
        self.net = nn.Sequential(
            nn.Linear(self.state_n, 5),
            nn.ReLU(),
            nn.Linear(5, self.action_n))
        
        self.softmax = nn.Softmax()
        self.loss_f = nn.CrossEntropyLoss()
        self.opt = opt_f(self.parameters(), lr=self.lr)
    def forward(self, x):
        return self.net(x)
    
    def get_action(self, state):
        state = torch.FloatTensor(state)
        logits = self(state)
        probs = self.softmax(logits)
        action = np.random.choice(self.action_n, p=probs.detach().numpy())
        return action
    def update_policy(self, elite_trajectories):
        elite_states, elite_actions = [], []
        for trajectory in elite_trajectories:
            elite_states.extend(trajectory['states'])
            elite_actions.extend(trajectory['actions'])
            
        elite_states = torch.FloatTensor(elite_states)
        elite_actions = torch.LongTensor(elite_actions)
        loss = self.loss_f(self(elite_states), elite_actions)
        loss.backward()
        self.opt.step()
        self.opt.zero_grad()
        return loss.item()

In [3]:
def get_trajectory(trajectory_len, env, agent):
    trajectory = {'states': [], 'actions': [], 'reward': 0}
    state = env.reset()
    trajectory['states'] += [state]
    for _ in range(trajectory_len):
        action = agent.get_action(state)
        state, reward, done, _ = env.step(action)
        trajectory['actions'] += [action]
        trajectory['reward'] += reward
        if done:
            break
        trajectory['states'] += [state]
    return trajectory

In [4]:
def get_elite_trajectories(trajectories, q_param):
    rewards = [trajectory['reward'] for trajectory in trajectories]
    q_value = np.quantile(rewards, q_param)
    return np.mean(rewards), [trajectory for trajectory in trajectories if trajectory['reward'] > q_value]

In [5]:
def train(epochs, env, agent, traj_per_epoch, traj_len, q_param):
    start = time.perf_counter()
    history = []
    for epoch in range(epochs):
        loss = 0
        trajectories = [get_trajectory(traj_len, env, agent) for _ in range(traj_per_epoch)]
        mean_reward, elite_trajectories = get_elite_trajectories(trajectories, q_param)
        history += [mean_reward]
        if len(elite_trajectories) > 0: 
            loss = agent.update_policy(elite_trajectories)
        print(f'{epoch=}, {loss=}, {mean_reward=}')
    end = time.perf_counter()
    print(f'Training took {round(end-start, 5)}')
    return history

In [6]:
env = gym.make('CartPole-v1')

state_n, action_n = 4, 2
lr = 0.1
agent = CEMAgent(state_n=state_n, action_n=action_n, lr=lr)

epochs = 51
traj_per_epoch = 100
traj_len = 500
q_param = 0.8

history = train(epochs, env, agent, traj_per_epoch, traj_len, q_param)

  deprecation(
  deprecation(
  probs = self.softmax(logits)
  elite_states = torch.FloatTensor(elite_states)


epoch=0, loss=0.6883884072303772, mean_reward=22.4
epoch=1, loss=0.6734998822212219, mean_reward=27.77
epoch=2, loss=0.656034529209137, mean_reward=35.23
epoch=3, loss=0.6410014629364014, mean_reward=39.16
epoch=4, loss=0.6343991756439209, mean_reward=40.24
epoch=5, loss=0.6196318864822388, mean_reward=48.61
epoch=6, loss=0.6039091944694519, mean_reward=49.56
epoch=7, loss=0.5854854583740234, mean_reward=55.22
epoch=8, loss=0.5794401168823242, mean_reward=56.82
epoch=9, loss=0.5746863484382629, mean_reward=61.17
epoch=10, loss=0.5696635842323303, mean_reward=68.91
epoch=11, loss=0.5547099113464355, mean_reward=82.39
epoch=12, loss=0.5460835695266724, mean_reward=78.4
epoch=13, loss=0.5378401875495911, mean_reward=96.82
epoch=14, loss=0.5218889117240906, mean_reward=116.98
epoch=15, loss=0.5323563814163208, mean_reward=178.06
epoch=16, loss=0.5283113121986389, mean_reward=243.79
epoch=17, loss=0.5257488489151001, mean_reward=270.16
epoch=18, loss=0, mean_reward=320.13
epoch=19, loss=0, 