In [1]:
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.optim import Adam
import numpy as np

In [2]:
import gym
import highway_env
from matplotlib import pyplot as plt
from gym.spaces import Discrete, Box

In [3]:
import random
from keyboard import read_key

In [4]:
# Use the first CUDA GPU when PyTorch reports one as available, else the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"

# torch.device object built from the same identifier string.
device = torch.device(dev)

In [10]:
# Create the "highway-fast-v0" environment and set its "duration" config entry to 100.
env = gym.make("highway-fast-v0")
env.config.update({"duration": 100})

# Size of the first observation axis and number of discrete actions.
observation_dimension = env.observation_space.shape[0]
n_acts = env.action_space.n

In [11]:
#############################################
####### BUILDING A NEURAL NETWORK ###########
##### REPRESENTING A STOCHASTIC POLICY ######
#############################################

# net_stochastic_policy is a neural network representing a stochastic policy:
# it takes observations as input and outputs one logit per action.
# The first Linear layer is hard-wired to 36 input features, so the flattened
# output of the 4-channel, kernel-3 convolution must hold exactly 36 values
# (e.g. a single-channel 5x5 observation -> 4 x 3 x 3).
net_stochastic_policy = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(in_features=36, out_features=36),
    nn.Tanh(),
    nn.Linear(in_features=36, out_features=n_acts)
    )
net_stochastic_policy.to(dev)

#### COMMENT OUT THE FOLLOWING LINE TO TRAIN FROM SCRATCH ####
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint written on a GPU machine still loads on CPU only.
net_stochastic_policy.load_state_dict(torch.load("models/model_2", map_location=dev))
#################################################################
# policy inputs an observation and computes a distribution on actions
def policy(observation):
    """Build the action distribution for one observation or a batch of them.

    Args:
        observation: tensor that is either a single 2-D observation or a
            stack of 2-D observations along dimension 0.

    Returns:
        A Categorical distribution built from the logits produced by
        net_stochastic_policy, one row of logits per observation (a single
        observation yields a batch of size 1).
    """
    if observation.dim() == 2:
        # single observation: add the batch axis and the channel axis
        observation = observation.unsqueeze(0).unsqueeze(0)
    else:
        # batch of observations: insert the channel axis after the batch axis
        observation = observation.unsqueeze(1)
    # The network ends with Flatten + Linear, so its output is already
    # (batch, n_acts). No squeeze is applied: squeezing dimension 1 does
    # nothing with several actions and would wrongly drop the action axis
    # when there is exactly one action.
    logits = net_stochastic_policy(observation)
    return Categorical(logits=logits)

# choose an action (outputs an int sampled from policy)
def choose_action(observation):
    """Sample one action index from the current policy for a single observation.

    Args:
        observation: a single observation in any form accepted by
            torch.as_tensor; it is converted to a float32 tensor on `dev`.

    Returns:
        The sampled action as a Python int.
    """
    observation = torch.as_tensor(observation, dtype=torch.float32).to(dev)
    # Only the sampled index leaves this function, so no gradient is needed:
    # disabling autograd avoids building a graph on every environment step.
    with torch.no_grad():
        return policy(observation).sample().item()

# surrogate loss whose gradient, for the right data, is the policy gradient
def compute_loss(batch_observations, batch_actions, batch_weights):
    """Return the negated mean of the weighted action log-probabilities.

    Args:
        batch_observations: tensor of observations passed to `policy`.
        batch_actions: tensor of action indices, one per observation
            (the training loop passes int32).
        batch_weights: tensor with one weight per action log-probability.

    Returns:
        A scalar tensor: -mean(log_prob(action | observation) * weight).
    """
    action_distribution = policy(batch_observations)
    weighted_logprobability = action_distribution.log_prob(batch_actions) * batch_weights
    return -weighted_logprobability.mean()

### Training hyper-parameters
learning_rate = 1e-2  # step size handed to Adam
epochs = 200          # number of policy updates (alternative noted by the author: 50)
batch_size = 1000     # transition threshold per update (alternative noted by the author: 5000)
##########################

# Adam optimizer over every parameter of the policy network
optimizer = Adam(params=net_stochastic_policy.parameters(), lr=learning_rate)

In [12]:
#############################################
######### VANILLA POLICY GRADIENT ###########
#############################################

def vanilla_policy_gradient():
    """Train the policy network in place with a simple policy-gradient loop.

    For each of `epochs` iterations, whole episodes are rolled out with the
    current stochastic policy until more than `batch_size` transitions have
    been collected. One optimizer step is then taken on `compute_loss`, where
    every action log-probability is weighted by the total return of the
    episode it belongs to. One progress line is printed per epoch.

    Uses the module-level `env`, `choose_action`, `compute_loss`, `optimizer`,
    `epochs`, `batch_size` and `dev`. Returns nothing.
    """
    for i in range(epochs):
        batch_observations = []   # observations the policy acted on
        batch_actions = []        # action sampled for each observation
        batch_weights = []        # weight of each log-probability (episode return)
        batch_returns = []        # total reward of each finished episode
        batch_lengths = []        # number of steps of each finished episode

        observation = env.reset()
        rewards_in_episode = []            # list for rewards in the current episode

        # First step: collect experience by simulating the environment with current policy
        while True:
            batch_observations.append(observation.copy())

            # act in the environment
            action = choose_action(observation)
            observation, reward, done, _ = env.step(action)

            # save action, reward
            batch_actions.append(action)
            rewards_in_episode.append(reward)

            if done:
                # episode is over: record its return and length
                episode_return, episode_length = sum(rewards_in_episode), len(rewards_in_episode)
                batch_returns.append(episode_return)
                batch_lengths.append(episode_length)

                # the weight for each logprobability(action|observation)
                batch_weights += [episode_return] * episode_length

                # The batch only ends on an episode boundary, so every stored
                # transition has a weight. Checking before resetting avoids a
                # reset whose result would be discarded by the next epoch.
                if len(batch_observations) > batch_size:
                    break

                # start a new episode
                observation, rewards_in_episode = env.reset(), []

        # Second step: update the policy with a single policy gradient step
        optimizer.zero_grad()
        # Stack the observations into one ndarray first: converting a Python
        # list of ndarrays straight to a tensor is much slower.
        observations_array = np.asarray(batch_observations, dtype=np.float32)
        batch_loss = compute_loss(torch.as_tensor(observations_array).to(dev),
                                  torch.as_tensor(batch_actions, dtype=torch.int32).to(dev),
                                  torch.as_tensor(batch_weights, dtype=torch.float32).to(dev)
                                  )
        batch_loss.backward()
        optimizer.step()

        print('epoch: %3d \t loss: %.3f \t return: %.3f \t episode_length: %.3f'%
                (i, batch_loss.item(), np.mean(batch_returns), np.mean(batch_lengths)))

In [44]:
# Run the training loop; it updates net_stochastic_policy in place and prints
# one progress line (loss, mean return, mean episode length) per epoch.
vanilla_policy_gradient()

epoch:   0 	 loss: 21.628 	 return: 9.036 	 episode_length: 12.383
epoch:   1 	 loss: 22.697 	 return: 10.112 	 episode_length: 13.877
epoch:   2 	 loss: 25.223 	 return: 11.685 	 episode_length: 16.190
epoch:   3 	 loss: 26.539 	 return: 14.105 	 episode_length: 19.667
epoch:   4 	 loss: 25.200 	 return: 14.839 	 episode_length: 21.000
epoch:   5 	 loss: 25.079 	 return: 18.807 	 episode_length: 26.447
epoch:   6 	 loss: 22.997 	 return: 17.824 	 episode_length: 25.225
epoch:   7 	 loss: 20.117 	 return: 19.799 	 episode_length: 28.194
epoch:   8 	 loss: 17.901 	 return: 19.466 	 episode_length: 27.595
epoch:   9 	 loss: 17.739 	 return: 21.223 	 episode_length: 29.971
epoch:  10 	 loss: 14.885 	 return: 19.414 	 episode_length: 27.757
epoch:  11 	 loss: 15.296 	 return: 19.977 	 episode_length: 28.333
epoch:  12 	 loss: 13.586 	 return: 20.407 	 episode_length: 28.714
epoch:  13 	 loss: 12.723 	 return: 20.867 	 episode_length: 29.618
epoch:  14 	 loss: 11.307 	 return: 20.520 	 epis

epoch: 122 	 loss: 6.761 	 return: 17.312 	 episode_length: 23.953
epoch: 123 	 loss: 9.135 	 return: 19.414 	 episode_length: 26.789
epoch: 124 	 loss: 9.369 	 return: 17.911 	 episode_length: 25.049
epoch: 125 	 loss: 11.189 	 return: 19.565 	 episode_length: 27.026
epoch: 126 	 loss: 11.726 	 return: 18.392 	 episode_length: 25.744
epoch: 127 	 loss: 12.647 	 return: 20.552 	 episode_length: 28.743
epoch: 128 	 loss: 13.569 	 return: 19.910 	 episode_length: 28.083
epoch: 129 	 loss: 13.663 	 return: 21.163 	 episode_length: 30.000
epoch: 130 	 loss: 14.135 	 return: 21.065 	 episode_length: 29.912
epoch: 131 	 loss: 14.665 	 return: 20.465 	 episode_length: 28.914
epoch: 132 	 loss: 14.565 	 return: 20.578 	 episode_length: 29.257
epoch: 133 	 loss: 14.738 	 return: 20.024 	 episode_length: 28.167
epoch: 134 	 loss: 15.066 	 return: 20.639 	 episode_length: 29.200
epoch: 135 	 loss: 14.081 	 return: 20.580 	 episode_length: 29.257
epoch: 136 	 loss: 13.544 	 return: 20.856 	 episod

In [45]:
#### UNCOMMENT TO SAVE THE MODEL; BE CAREFUL NOT TO OVERWRITE AN EXISTING ONE ###
# torch.save(net_stochastic_policy.state_dict(), "models/model_4")

In [13]:
###### EVALUATION ############

def run_episode(env, render = False):
    """Play one full episode with actions sampled from the policy.

    Args:
        env: environment providing reset(), step(action) and render().
        render: when true, env.render() is called before every step.

    Returns:
        The sum of the rewards collected during the episode.
    """
    observation = env.reset()
    episode_return = 0
    while True:
        if render:
            env.render()
        observation, reward, finished, _ = env.step(choose_action(observation))
        episode_return += reward
        if finished:
            break
    print("episode done")
    return episode_return

# Score the policy on 20 episodes without rendering (the author also noted 100)
# and report the mean episode return.
policy_scores = list(map(run_episode, [env] * 20))
print("Average score of the policy: ", np.mean(policy_scores))

episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
episode done
Average score of the policy:  70.37688783112466


In [14]:
###### DEMONSTRATION ############

# Render the policy in batches of 5 episodes. After each batch, pressing "r"
# starts another batch; any other key ends the demonstration.
go = True
try:
    while go:
        for _ in range(5):
            run_episode(env, True)
        print("Press r to restart simulation, otherwise another key")
        if read_key() != "r":
            go = False
finally:
    # Close the environment even if the loop is interrupted or raises, so its
    # resources are not left open.
    env.close()

episode done
episode done
episode done
episode done
episode done
Press r to restart simulation, otherwise another key
