In [8]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np

In [9]:
import gym
import highway_env
from matplotlib import pyplot as plt
from gym.spaces import Discrete, Box

In [10]:
import random
from keyboard import read_key

In [11]:
# Use the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"

# torch.device handle built from the same device string.
device = torch.device(dev)

In [12]:
env = gym.make("highway-fast-v0")
# Shape of a single observation as reported by the environment.
# NOTE(review): the network below treats an observation as a one-channel 2-D
# array, so this assumes a 2-D observation space — confirm for other configs.
observation_shape = env.observation_space.shape
observation_dimension = observation_shape[0]
n_acts = env.action_space.n

#############################################
####### BUILDING A NEURAL NETWORK ###########
##### REPRESENTING A STOCHASTIC POLICY ######
#############################################

# Architecture constants (same values the network was originally built with).
conv_channels = 4
conv_kernel = 3
hidden_units = 36

# A Conv2d with stride 1 and no padding shrinks each spatial dimension by
# (kernel_size - 1). Deriving the flattened size from the observation shape
# replaces the hard-coded 36, which only matches one observation shape
# (4 channels * 3 * 3, i.e. a 5x5 observation).
conv_out_rows = observation_shape[0] - (conv_kernel - 1)
conv_out_cols = observation_shape[1] - (conv_kernel - 1)
n_flat_features = conv_channels * conv_out_rows * conv_out_cols

# net_stochastic_policy is a neural network representing a stochastic policy:
# it takes as inputs observations and outputs logits for each action
net_stochastic_policy = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=conv_channels, kernel_size=conv_kernel),
        nn.Flatten(),
        nn.Linear(in_features=n_flat_features, out_features=hidden_units),
        nn.Tanh(),
        nn.Linear(in_features=hidden_units, out_features=n_acts)
        )
net_stochastic_policy.to(dev)

#### UNCOMMENT THE FOLLOWING LINE TO LOAD A PRE-TRAINED MODEL ###
# net_stochastic_policy.load_state_dict(torch.load("models/model_1"))
#################################################################

# policy inputs an observation and computes a distribution on actions
def policy(observation):
    """Build the action distribution given by the policy network.

    Args:
        observation: tensor holding either one 2-D observation or a stack of
            observations with the batch dimension first.

    Returns:
        A ``Categorical`` distribution parameterised by the logits that
        ``net_stochastic_policy`` outputs, with one row of logits per
        observation (a single observation yields a batch of size 1).
    """
    if observation.dim() == 2:
        # Single observation: add the batch and channel dimensions.
        observation = observation.unsqueeze(0).unsqueeze(0)
    else:
        # Batch of observations: insert the channel dimension after the batch.
        observation = observation.unsqueeze(1)
    # The network output is used as-is. The former `.squeeze(1)` did nothing
    # when there are several actions and would have removed the action
    # dimension if the action space had a single entry.
    logits = net_stochastic_policy(observation)
    return Categorical(logits=logits)

# choose an action (outputs an int sampled from policy)
def choose_action(observation):
    """Sample one action for a single observation.

    Args:
        observation: array-like observation; it is converted to a float32
            tensor and moved to the device named by ``dev``.

    Returns:
        The sampled action as a Python int.
    """
    observation = torch.as_tensor(observation, dtype=torch.float32).to(dev)
    # Only the sampled index is needed here, so skip autograd bookkeeping for
    # the forward pass; the loss recomputes log-probabilities with gradients.
    with torch.no_grad():
        return policy(observation).sample().item()

# surrogate loss: its gradient, on the right data, is the policy gradient
def compute_loss(batch_observations, batch_actions, batch_weights):
    """Return the negated mean of weight * log-probability of the taken actions.

    Args:
        batch_observations: observations passed to ``policy``.
        batch_actions: actions whose log-probability is evaluated.
        batch_weights: per-step weights multiplied with the log-probabilities.
    """
    action_distribution = policy(batch_observations)
    log_probs = action_distribution.log_prob(batch_actions)
    weighted_log_probs = log_probs * batch_weights
    return -weighted_log_probs.mean()

# ---------------- Training hyper-parameters ----------------
# NOTE(review): the logged loss reaches 0.000 within a few epochs, which may
# indicate that this Adam step size is aggressive — confirm before reuse.
learning_rate = 5e-2
epochs = 15        # one optimizer step per epoch (alternative: 50)
batch_size = 1000  # collect more than this many steps per epoch (alternative: 5000)
# ------------------------------------------------------------

# Adam over every parameter of the policy network
optimizer = Adam(params=net_stochastic_policy.parameters(), lr=learning_rate)

In [13]:
#############################################
######### VANILLA POLICY GRADIENT ###########
#############################################

def vanilla_policy_gradient():
    """Train ``net_stochastic_policy`` with the basic policy-gradient update.

    For each of ``epochs`` iterations, complete episodes are simulated in
    ``env`` with the current policy until more than ``batch_size`` steps have
    been collected. One optimizer step is then taken on ``compute_loss``,
    where every step of an episode is weighted by that episode's total return.
    A line of statistics (loss, mean return, mean episode length) is printed
    per epoch. Returns nothing.
    """
    for i in range(epochs):
        batch_observations = []   # one observation per collected step
        batch_actions = []        # action taken at each step
        batch_weights = []        # weight applied to each step's log-probability
        batch_returns = []        # total return of each finished episode
        batch_lengths = []        # length of each finished episode

        observation = env.reset()
        rewards_in_episode = []   # rewards of the episode currently running

        # First step: collect experience by simulating the environment with current policy
        while True:
            batch_observations.append(observation.copy())

            # act in the environment
            action = choose_action(observation)
            observation, reward, done, _ = env.step(action)

            # save action, reward
            batch_actions.append(action)
            rewards_in_episode.append(reward)

            if done:
                # if episode is over, record info about episode
                episode_return = sum(rewards_in_episode)
                episode_length = len(rewards_in_episode)
                batch_returns.append(episode_return)
                batch_lengths.append(episode_length)

                # every step of the episode gets the full episode return as weight
                batch_weights += [episode_return] * episode_length

                # reset episode-specific variables
                observation, rewards_in_episode = env.reset(), []

                # only stop on an episode boundary, once enough steps are stored
                if len(batch_observations) > batch_size:
                    break

        # Second step: update the policy with a single gradient step.
        # Stack the observations into one ndarray first: building a tensor
        # straight from a list of ndarrays is very slow and triggers a warning.
        observations_array = np.asarray(batch_observations, dtype=np.float32)

        optimizer.zero_grad()
        batch_loss = compute_loss(torch.as_tensor(observations_array, dtype=torch.float32).to(dev),
                                  torch.as_tensor(batch_actions, dtype=torch.int32).to(dev),
                                  torch.as_tensor(batch_weights, dtype=torch.float32).to(dev)
                                  )
        batch_loss.backward()
        optimizer.step()

        # .item() turns the 0-dim loss tensor into a plain float for formatting
        print('epoch: %3d \t loss: %.3f \t return: %.3f \t episode_length: %.3f'%
                (i, batch_loss.item(), np.mean(batch_returns), np.mean(batch_lengths)))

In [14]:
# Run the training loop; per-epoch statistics are printed by the function itself.
vanilla_policy_gradient()

epoch:   0 	 loss: 19.555 	 return: 8.194 	 episode_length: 11.011
epoch:   1 	 loss: 18.947 	 return: 19.120 	 episode_length: 26.737
epoch:   2 	 loss: 6.709 	 return: 19.947 	 episode_length: 28.657
epoch:   3 	 loss: 0.839 	 return: 20.643 	 episode_length: 29.343
epoch:   4 	 loss: 0.185 	 return: 20.734 	 episode_length: 29.912
epoch:   5 	 loss: 0.005 	 return: 20.620 	 episode_length: 29.429
epoch:   6 	 loss: 0.002 	 return: 20.806 	 episode_length: 29.735
epoch:   7 	 loss: 0.001 	 return: 20.487 	 episode_length: 29.500
epoch:   8 	 loss: 0.000 	 return: 20.975 	 episode_length: 29.618
epoch:   9 	 loss: 0.000 	 return: 20.903 	 episode_length: 30.000
epoch:  10 	 loss: 0.000 	 return: 20.407 	 episode_length: 29.171
epoch:  11 	 loss: 0.000 	 return: 20.863 	 episode_length: 29.882
epoch:  12 	 loss: 0.000 	 return: 20.115 	 episode_length: 29.029
epoch:  13 	 loss: 0.000 	 return: 20.420 	 episode_length: 29.114
epoch:  14 	 loss: 0.000 	 return: 21.118 	 episode_length: 3

In [22]:
# Save the policy network's state_dict to "models/model_2".
# NOTE(review): this cell's execution count (In [22]) is out of sequence with
# the cells around it, so it was presumably run later than its position
# suggests — confirm which weights the file actually holds.
# NOTE(review): presumably requires the "models" directory to exist already — confirm.
torch.save(net_stochastic_policy.state_dict(), "models/model_2")

In [15]:
###### EVALUATION ############

def run_episode(env, render = False):
    """Play one episode with the current policy and return its summed reward.

    Args:
        env: environment providing ``reset``, ``step`` and ``render``.
        render: when true, ``env.render()`` is called before every step.

    Returns:
        Sum of the rewards received until the environment reports ``done``.
    """
    current_obs = env.reset()
    episode_return = 0
    finished = False
    # `finished` starts False, so at least one step is always taken.
    while not finished:
        if render:
            env.render()
        current_obs, step_reward, finished, _ = env.step(choose_action(current_obs))
        episode_return = episode_return + step_reward
    print("episode done")
    return episode_return

# Evaluate the policy over 20 non-rendered episodes (alternative: 100).
policy_scores = [run_episode(env, False) for _ in range(20)]
print("Average score of the policy: ", np.asarray(policy_scores).mean())

episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
Average score of the policy:  21.02022116903634


In [16]:
###### DEMONSTRATION ############

# Render episodes in groups of five; after each group, pressing "r" plays
# another group and any other key ends the demonstration.
go = True
while go:
    for _ in range(5):
        run_episode(env, render=True)
    print("Press r to restart simulation, otherwise another key")
    go = read_key() == "r"

# Close the environment once the demonstration is over.
env.close()

episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key


In [20]:
# Save the policy network's state_dict to "models/model_1", the path named in
# the commented-out load_state_dict line near the network definition.
# NOTE(review): this cell's execution count (In [20]) is out of sequence with
# the cells around it — confirm when it was run relative to training.
# NOTE(review): presumably requires the "models" directory to exist already — confirm.
torch.save(net_stochastic_policy.state_dict(), "models/model_1")