# Policy Gradients

In [1]:
import random

import pandas as pd
import gym
from gym import wrappers, logger
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

from a2c import A2C, A2CBatch
from main import  init_environment, run_agent_on_environment

# Enable IPython's autoreload extension in mode 2: every imported module is
# reloaded before a cell executes, so edits to the local `a2c` and `main`
# modules take effect without restarting the kernel.
%load_ext autoreload

%autoreload 2
# Disabled: would force matplotlib's "TkAgg" backend. Left commented out.
# matplotlib.use("TkAgg")

# 1 - Online A2C

## CartPole

In [5]:
# Environment setup: create the CartPole environment with seed 0.
env_id = 'CartPole-v1'
env, envx = init_environment(env_id, seed=0)

# Online A2C agent. Input/output sizes are read from the environment's
# observation and action spaces; the remaining values are hyperparameters.
a2c_params = dict(
    dim_input=env.observation_space.shape[0],
    dim_output=env.action_space.n,
    gamma=0.99,
    alpha=0.7,
    layers=[200],
    lr_V=0.001,
    lr_pi=0.001,
)
agent = A2C(**a2c_params)

# Run 1000 episodes; a reward summary is printed every 10 episodes.
# NOTE(review): iter_show and name_file presumably control rendering
# frequency and the results file -- confirm against main.py.
run_params = dict(
    max_episode=1000,
    iter_print=10,
    iter_show=100,
    name_file='a2c_cartpole.csv',
)
run_agent_on_environment(agent, env, envx, **run_params)

Mean of 10 last rewards for episode 10 : 11.090909090909092 (std 4.273694281288421)
Mean of 10 last rewards for episode 20 : 9.090909090909092 (std 2.9681504940571832)
Mean of 10 last rewards for episode 30 : 8.909090909090908 (std 3.0287874998104876)
Mean of 10 last rewards for episode 40 : 8.909090909090908 (std 2.906248611481051)
Mean of 10 last rewards for episode 50 : 8.818181818181818 (std 3.009626428590336)
Mean of 10 last rewards for episode 60 : 8.454545454545455 (std 2.9034035313947837)
Mean of 10 last rewards for episode 70 : 8.363636363636363 (std 2.7723546694503463)
Mean of 10 last rewards for episode 80 : 8.636363636363637 (std 2.8049542946439114)
Mean of 10 last rewards for episode 90 : 8.818181818181818 (std 2.9485954073139733)
Mean of 10 last rewards for episode 100 : 8.272727272727273 (std 2.699862255439545)
Mean of 10 last rewards for episode 110 : 8.454545454545455 (std 2.7090298955911525)
Mean of 10 last rewards for episode 120 : 8.636363636363637 (std 2.8049542946

Mean of 10 last rewards for episode 990 : 8.363636363636363 (std 2.739367122421702)
Mean of 10 last rewards for episode 1000 : 8.363636363636363 (std 2.67217062849074)


In [13]:
# Environment setup: create the LunarLander environment with a fixed seed.
env_id = 'LunarLander-v2'
SEED_ENVIRONMENT = 42
env, envx = init_environment(env_id, seed=SEED_ENVIRONMENT)

# Online A2C agent. Input/output sizes are read from the environment's
# observation and action spaces; the remaining values are hyperparameters.
a2c_params = dict(
    dim_input=env.observation_space.shape[0],
    dim_output=env.action_space.n,
    gamma=0.99,
    alpha=0.7,
    layers=[30, 30],
    lr_V=0.01,
    lr_pi=0.001,
)
agent = A2C(**a2c_params)

# Run 1000 episodes; a reward summary is printed every 10 episodes.
# NOTE(review): iter_show and name_file presumably control rendering
# frequency and the results file -- confirm against main.py.
run_params = dict(
    max_episode=1000,
    iter_print=10,
    iter_show=100,
    name_file='a2c_lunarlander.csv',
)
run_agent_on_environment(agent, env, envx, **run_params)

Mean of 10 last rewards for episode 10 : -209.4925405328924 (std 142.0468293652938)
Mean of 10 last rewards for episode 20 : -111.38636467673562 (std 101.50882526912838)
Mean of 10 last rewards for episode 30 : -119.04245966131037 (std 41.70649174933107)
Mean of 10 last rewards for episode 40 : -101.08793293346058 (std 61.345528835104695)
Mean of 10 last rewards for episode 50 : -118.52248937433416 (std 43.6990999237616)
Mean of 10 last rewards for episode 60 : -152.61209938742897 (std 67.14826808628662)
Mean of 10 last rewards for episode 70 : -121.34483268044211 (std 49.07001925461506)
Mean of 10 last rewards for episode 80 : -148.31348765980113 (std 64.68029554130692)
Mean of 10 last rewards for episode 90 : -90.35026411576705 (std 58.337197854243605)
Mean of 10 last rewards for episode 100 : -129.2414190118963 (std 51.13575810151807)
Mean of 10 last rewards for episode 110 : -144.3570660677823 (std 104.30813203307079)
Mean of 10 last rewards for episode 120 : -144.83329010009766 (s

Mean of 10 last rewards for episode 970 : -326.65511044588953 (std 203.89327108823272)
Mean of 10 last rewards for episode 980 : -354.6663443825462 (std 247.74135229700667)
Mean of 10 last rewards for episode 990 : -291.7698235945268 (std 222.38492329359548)
Mean of 10 last rewards for episode 1000 : -273.45647915926844 (std 229.53723220971182)


# 2 - Batch A2C

In [105]:
# Environment setup: create the LunarLander environment with a fixed seed.
env_id = 'LunarLander-v2'
SEED_ENVIRONMENT = 42
env, envx = init_environment(env_id, seed=SEED_ENVIRONMENT)

# Show the observation dimension and the number of discrete actions.
obs_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
print(obs_dim, n_actions)

# Batch variant of the A2C agent, sized from the values printed above.
agent = A2CBatch(
    dim_input=obs_dim,
    dim_output=n_actions,
    gamma=0.99,
    alpha=0.7,
    layers=[30, 30],
    lr_V=0.01,
    lr_pi=0.001,
)

# Run 1000 episodes; a reward summary is printed every 10 episodes.
# NOTE(review): iter_show and name_file presumably control rendering
# frequency and the results file -- confirm against main.py.
run_agent_on_environment(
    agent,
    env,
    envx,
    max_episode=1000,
    iter_print=10,
    iter_show=100,
    name_file='a2c_batch_lunarlander.csv',
)

8 4
Mean of 10 last rewards for episode 10 : -443.37635664506394 (std 218.91772012349224)
Mean of 10 last rewards for episode 20 : -550.518349387429 (std 237.8081268301725)
Mean of 10 last rewards for episode 30 : -543.5897438742898 (std 215.96027177761067)
Mean of 10 last rewards for episode 40 : -520.0400945490056 (std 260.1357025517608)
Mean of 10 last rewards for episode 50 : -587.2592329545455 (std 300.7694502696428)
Mean of 10 last rewards for episode 60 : -552.4935385964134 (std 224.93019392376536)
Mean of 10 last rewards for episode 70 : -567.5935890891335 (std 230.2363881511052)
Mean of 10 last rewards for episode 80 : -459.8474148837003 (std 224.95765631572644)
Mean of 10 last rewards for episode 90 : -602.0666170987216 (std 300.2314419911113)
Mean of 10 last rewards for episode 100 : -540.9567232998935 (std 245.33043529032628)
Mean of 10 last rewards for episode 110 : -477.47955599698156 (std 201.72422897444457)
Mean of 10 last rewards for episode 120 : -565.2390830300071 (s

Mean of 10 last rewards for episode 970 : -512.4780162464489 (std 229.8054231945858)
Mean of 10 last rewards for episode 980 : -447.5838068181818 (std 183.64975887590433)
Mean of 10 last rewards for episode 990 : -508.3662802956321 (std 263.6336368550151)
Mean of 10 last rewards for episode 1000 : -513.886191628196 (std 271.08494848390865)
