In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np
import copy

In [2]:
import gym
import highway_env
from matplotlib import pyplot as plt
from gym.spaces import Discrete, Box

In [3]:
import random
from keyboard import read_key

In [32]:
if torch.cuda.is_available():  
  dev = "cuda:0"
else:  
  dev = "cpu"

device = torch.device(dev)

In [46]:
env = gym.make("highway-fast-v0")
env.config["duration"] = 100
observation_dimension = env.observation_space.shape[0]
# print("observation_dimension: ", observation_dimension)
n_acts = env.action_space.n

In [47]:
#############################################
####### BUILDING A NEURAL NETWORK ###########
##### REPRESENTING A STOCHASTIC POLICY ######
#############################################

# net_stochastic_policy is a neural network representing a stochastic policy:
# it takes as inputs observations and outputs logits for each action
net_stochastic_policy = nn.Sequential(
    # nn.ReLU(),
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(in_features=36, out_features=36),
    nn.Tanh(),
    nn.Linear(in_features=36, out_features=n_acts)
    )
net_stochastic_policy.to(dev)
old_net_stochastic_policy = nn.Sequential(
    # nn.ReLU(),
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(in_features=36, out_features=36),
    nn.Tanh(),
    nn.Linear(in_features=36, out_features=n_acts)
    )
old_net_stochastic_policy.to(dev)

#### DECOMMENT THE FOLLOWING LINE TO LOAD A PRE-TRAINED MODEL ###
net_stochastic_policy.load_state_dict(torch.load("models PPO/model_2"))
#################################################################

# policy inputs an observation and computes a distribution on actions
def policy(observation):
    if len(observation.shape) == 2:
        observation = observation.unsqueeze(0).unsqueeze(0)
    else:
        observation = observation.unsqueeze(1)
    # print("\ndimension en entrée du réseau : ", observation.shape)
    logits = net_stochastic_policy(observation).squeeze(1)
    # print("\ndimension des logits : ", logits.shape)
    return Categorical(logits=logits)

def old_policy(observation):
    if len(observation.shape) == 2:
        observation = observation.unsqueeze(0).unsqueeze(0)
    else:
        observation = observation.unsqueeze(1)
    # print("\ndimension en entrée du réseau : ", observation.shape)
    logits = old_net_stochastic_policy(observation).squeeze(1)
    # print("\ndimension des logits : ", logits.shape)
    return Categorical(logits=logits)

# choose an action (outputs an int sampled from policy)
def choose_action(observation):
    # print("\ndimension en entrée de choose_action : ", observation.shape)
    observation = torch.as_tensor(observation, dtype=torch.float32).to(dev)
    return policy(observation).sample().item()

# make loss function whose gradient, for the right data, is policy gradient
def compute_loss(batch_observations, batch_actions, batch_weights):
    batch_logprobability = policy(batch_observations).log_prob(batch_actions)
    batch_old_logprobability = old_policy(batch_observations).log_prob(batch_actions)
    # clip ratio to avoid instability of POO
    ratio = (batch_logprobability/batch_old_logprobability).clip(1-learning_rate,1+learning_rate)
    # return -(batch_logprobability * batch_weights).mean()
    return -(batch_logprobability * ratio * batch_weights).mean()

### Constants for training
learning_rate = 1e-2
epochs = 10 # 50
batch_size = 2000 # 5000
##########################

# make optimizer
optimizer = Adam(net_stochastic_policy.parameters(), lr = learning_rate)

In [48]:
#############################################
######### VANILLA POLICY GRADIENT ###########
#############################################

def vanilla_policy_gradient():
    for i in range(epochs):
        batch_observations = [] 
        batch_actions = []      
        batch_weights = []
        batch_weights_old = []
        batch_returns = []      
        batch_lengths = []      

        observation = env.reset()
        # print("observation actual dimension: ", observation.shape)
        # print("observation: ", observation)
        done = False            
        rewards_in_episode = []            # list for rewards in the current episode

        # First step: collect experience by simulating the environment with current policy
        while True:
            # print("vroum")
            batch_observations.append(observation.copy())

            # act in the environment
            action = choose_action(observation)
            observation, reward, done, _ = env.step(action)

            # save action, reward
            batch_actions.append(action)
            rewards_in_episode.append(reward)

            if done:
                # print("\ncrash")
                # if episode is over, record info about episode
                episode_return, episode_length = sum(rewards_in_episode), len(rewards_in_episode)
                batch_returns.append(episode_return)
                batch_lengths.append(episode_length)

                # the weight for each logprobability(action|observation)
                batch_weights += [episode_return] * episode_length

                # reset episode-specific variables
                observation, done, rewards_in_episode = env.reset(), False, []

                # end experience loop if we have enough of it
                if len(batch_observations) > batch_size:
                    break

        # Step second: update the policy
        # if i==0:
        #     batch_weights_old = batch_weights.copy()
        # ratio = np.divide(batch_weights, batch_weights_old)
        # if ratio.all() > 1+learning_rate:
        #     ratio = 1+learning_rate
        # if ratio.all() < 1-learning_rate:
        #     ratio = 1-learning_rate
        # J = batch_weights*ratio
        # take a single policy gradient update step
        old_net_dict = copy.deepcopy(net_stochastic_policy.state_dict())
        if(i == 0):
            old_net_stochastic_policy.load_state_dict(old_net_dict)
        optimizer.zero_grad()
        batch_loss = compute_loss(torch.as_tensor(batch_observations, dtype=torch.float32).to(dev),
                                  torch.as_tensor(batch_actions, dtype=torch.int32).to(dev),
                                  torch.as_tensor(batch_weights, dtype=torch.float32).to(dev)
                                  )
        batch_loss.backward()
        optimizer.step()
        if i>0:
            old_net_stochastic_policy.load_state_dict(old_net_dict)

        print('epoch: %3d \t loss: %.3f \t return: %.3f \t episode_length: %.3f'%
                (i, batch_loss, np.mean(batch_returns), np.mean(batch_lengths)))

In [36]:
vanilla_policy_gradient()

epoch:   0 	 loss: 19.411 	 return: 8.217 	 episode_length: 10.822
epoch:   1 	 loss: 21.024 	 return: 9.126 	 episode_length: 12.164
epoch:   2 	 loss: 20.731 	 return: 8.929 	 episode_length: 12.006
epoch:   3 	 loss: 23.982 	 return: 11.004 	 episode_length: 14.765
epoch:   4 	 loss: 22.899 	 return: 10.536 	 episode_length: 14.355
epoch:   5 	 loss: 24.091 	 return: 11.841 	 episode_length: 16.104
epoch:   6 	 loss: 25.094 	 return: 13.368 	 episode_length: 18.400
epoch:   7 	 loss: 25.677 	 return: 14.662 	 episode_length: 20.253
epoch:   8 	 loss: 25.709 	 return: 16.005 	 episode_length: 22.333
epoch:   9 	 loss: 25.105 	 return: 16.864 	 episode_length: 23.512


In [38]:
#### DECOMMENT TO SAVE THE MODEL, BE CAREFUL OF ERASING ANOTHER ONE ###
# torch.save(net_stochastic_policy.state_dict(), "models PPO/model_2")

In [44]:
###### EVALUATION ############

def run_episode(env, render = False):
    obs = env.reset()
    total_reward = 0
    done = False
    while not done:
        if render:
            env.render()
        action = choose_action(obs)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        if done:
            break
    print("episode done")
    return total_reward

policy_scores = [run_episode(env) for _ in range(20)] #100
print("Average score of the policy: ", np.mean(policy_scores))

episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
Average score of the policy:  9.536407168804105


In [49]:
###### DEMONSTRATION ############

go = True
while go:
    for _ in range(5):
        run_episode(env, True)
    print("Press r to restart simulation, otherwise another key")
    if read_key() != "r":
        go = False

env.close()

episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
