# Deep Cross-Entropy Method

In [None]:
import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
n_sessions = 100      # number of training iterations; each one plays a whole batch of sessions
percentile = 50       # reward percentile used as the elite cutoff when filtering sessions
learning_rate = 0.01  # step size for the Adam optimizer

BATCH_SIZE = 100  # sessions played per training iteration
# Training stops once the mean reward of a batch exceeds this value.
# NOTE(review): 195 is the documented "solved" threshold for CartPole-v0, while the
# environment created below is CartPole-v1 (documented threshold 475) — confirm intent.
STOP_VALUE_SCORE = 195

In [None]:
class DeepCEM(nn.Module):
    """Two-layer perceptron policy mapping a state vector to action logits.

    Args:
        n_states: Size of the input state vector.
        n_actions: Number of discrete actions (size of the output).
        hidden_size: Width of the single hidden layer. Defaults to 200.
    """

    def __init__(self, n_states, n_actions, hidden_size=200):
        super().__init__()
        self.fc1 = nn.Linear(n_states, hidden_size)
        self.fc2 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Return raw action scores (logits) for the given states.

        No softmax is applied here: callers that need probabilities must
        normalise the output themselves, and ``nn.CrossEntropyLoss`` consumes
        the logits directly.
        """
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


def generate_batch(env, t_max=5_000, model=None, batch_size=None):
    """Play sessions with the current policy and record states, actions, rewards.

    Args:
        env: Environment following the classic gym API used in this file:
            ``reset()`` returns the initial state and ``step(a)`` returns
            ``(state, reward, done, info)``. ``render()`` is called on every
            step of the first session.
        t_max: Maximum number of steps per session.
        model: Callable mapping a ``(1, n_states)`` float tensor to action
            logits. Defaults to the module-level ``nn_cem``.
        batch_size: Number of sessions to play. Defaults to the module-level
            ``BATCH_SIZE``.

    Returns:
        ``(batch_states, batch_actions, batch_rewards)`` where the first two
        are lists (one entry per session) of per-step lists and the last is
        a list with the total reward of each session. Every session is
        recorded, including those cut off at ``t_max``.
    """
    if model is None:
        model = nn_cem
    if batch_size is None:
        batch_size = BATCH_SIZE

    batch_states, batch_actions, batch_rewards = [], [], []

    for b in range(batch_size):
        s = env.reset()
        total_reward = 0.
        states, actions = [], []

        for _ in range(t_max):
            # Only the first session of the batch is visualised.
            if b == 0:
                env.render()

            # Build the (1, n_states) input from one contiguous array; sampling
            # needs no gradients, so skip building the autograd graph.
            s_v = torch.as_tensor(np.asarray(s, dtype=np.float32)).unsqueeze(0)
            with torch.no_grad():
                act_probs = F.softmax(model(s_v), dim=1)[0].numpy()
            a = np.random.choice(len(act_probs), p=act_probs)

            new_s, r, done, _ = env.step(a)

            states.append(s)
            actions.append(a)
            total_reward += r

            s = new_s

            if done:
                break

        # Record the session whether it terminated or ran into t_max;
        # previously a session that never reported ``done`` was dropped.
        batch_states.append(states)
        batch_actions.append(actions)
        batch_rewards.append(total_reward)

    return batch_states, batch_actions, batch_rewards


def filter_batch(states, actions, rewards, percentile=70):
    """Select the transitions of the best-rewarded ("elite") sessions.

    Args:
        states: Per-session lists of states.
        actions: Per-session lists of actions, aligned with ``states``.
        rewards: Total reward of each session.
        percentile: Reward percentile (0-100) used as the elite cutoff.

    Returns:
        ``(elite_states, elite_actions)``: flat lists holding every state and
        the matching action from sessions whose reward is at least the
        ``percentile``-th percentile of ``rewards``. Both are empty when
        ``rewards`` is empty.
    """
    if len(rewards) == 0:
        return [], []

    reward_threshold = np.percentile(rewards, percentile)

    elite_states, elite_actions = [], []
    for session_states, session_actions, reward in zip(states, actions, rewards):
        # ``>=`` keeps sessions that tie with the threshold; a strict ``>``
        # selects nothing when all rewards are equal (e.g. every session
        # reaches the episode cap).
        if reward >= reward_threshold:
            elite_states.extend(session_states)
            elite_actions.extend(session_actions)

    return elite_states, elite_actions

In [None]:
# Create the environment the policy is trained on.
env = gym.make('CartPole-v1')

# Length of the observation vector and number of discrete actions;
# these set the input and output sizes of the policy network.
n_states = env.observation_space.shape[0]
n_actions = env.action_space.n
print(f"n_states={n_states}, n_actions={n_actions}")

In [None]:
# Policy network; its forward pass returns raw logits (no softmax applied).
nn_cem = DeepCEM(n_states, n_actions)
# CrossEntropyLoss takes the raw logits together with integer action indices.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=nn_cem.parameters(), lr=learning_rate)

In [None]:
# Cross-entropy method training loop: play a batch with the current policy,
# keep the best sessions, and fit the network to imitate their actions.
for i in range(n_sessions):

    batch_states, batch_actions, batch_rewards = generate_batch(env)
    elite_states, elite_actions = filter_batch(
        batch_states, batch_actions, batch_rewards, percentile
    )

    # filter_batch can return no transitions (e.g. when no session clears the
    # threshold); an update on an empty batch would fail or yield NaN, so the
    # optimisation step is skipped while reporting and the stop check still run.
    if len(elite_states) > 0:
        optimizer.zero_grad()
        # Stack into one ndarray first: building a tensor directly from a
        # list of ndarrays is very slow.
        tensor_states = torch.as_tensor(
            np.asarray(elite_states), dtype=torch.float32
        )
        tensor_actions = torch.as_tensor(
            np.asarray(elite_actions), dtype=torch.long
        )

        predicted_actions = nn_cem(tensor_states)

        loss_value = criterion(predicted_actions, tensor_actions)
        loss_value.backward()
        optimizer.step()
        loss_text = f"{loss_value.item():.3f}"
    else:
        loss_text = "n/a"

    mean_reward = np.mean(batch_rewards)
    threshold = np.percentile(batch_rewards, percentile)

    print(f"{i}: loss={loss_text}, reward_mean={mean_reward:.1f}, reward_threshold={threshold:.1f}")

    if mean_reward > STOP_VALUE_SCORE:
        print('Congratulations, you\'ve solved this challenge!')
        break