In [1]:
%load_ext autoreload

In [1]:
import functools
import gym
from Config import Config
# from util import train
from Models import ActorCritic
from Networks import cnn_head_model, actor_model, critic_model, head_model
from Memory import Memory
from baselines.common.cmd_util import make_env
from baselines.common.atari_wrappers import wrap_deepmind, make_atari

import matplotlib.pyplot as plt

# env_id = "BreakoutNoFrameskip-v4"
# env = make_atari(env_id)
# env = wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=True, scale=False)

config = Config(gym.make('CartPole-v1'))

config.update_every = 500
config.num_learn = 4
config.win_condition = 230
config.n_episodes = 1000
config.max_t = 700

config.Memory = Memory
config.Model = ActorCritic
# config.head_model = functools.partial(cnn_head_model, config)
config.head_model = functools.partial(head_model, config)
config.actor_model = functools.partial(actor_model, config)
config.critic_model = functools.partial(critic_model, config)


In [4]:
import copy
import gym
import torch
import numpy as np
from collections import deque
from PPO import PPO
from Config import Config
import pdb

def train(config):
    env = copy.deepcopy(config.env)
    steps = 0
    scores_deque = deque(maxlen=100)
    scores = []
    average_scores = []
    max_score = -np.Inf

    agent = PPO(config)

    for i_episode in range(1, config.n_episodes+1):
        state = env.reset()
        score = 0
        for t in range(config.max_t):
            steps += 1

            action, log_prob = agent.act(torch.FloatTensor(state))
            print("Action space shape: {}".format(action.shape))
            print("Action: {}".format(action))
            print("Log Probabilities: {}".format(log_prob))
            print("Action item: {}".format(action.item()))
            next_state, reward, done, _ = env.step(action.item())

            agent.mem.add(torch.FloatTensor(state), action, reward, log_prob, done)

            # Update 
            state = next_state
            score += reward


            if steps >= config.update_every:
                agent.learn(config.num_learn)
                agent.mem.clear()
                steps = 0

            if done:
                break 

        # Book Keeping
        scores_deque.append(score)
        scores.append(score)
        average_scores.append(np.mean(scores_deque))

        if i_episode % 10 == 0:
            print("\rEpisode {}	Average Score: {:.2f}	Score: {:.2f}".format(i_episode, np.mean(scores_deque), score), end="")
        if i_episode % 100 == 0:
            print("\rEpisode {}	Average Score: {:.2f}".format(i_episode, np.mean(scores_deque)))   

        if np.mean(scores_deque) > config.win_condition:
            print("\nEnvironment Solved!")
            break

    return scores, average_scores


In [None]:
scores, average_score = train(config)
plt.plot(scores)
plt.plot(average_score)
plt.show()

Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6826119422912598
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7172805070877075
Action item: 1
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.682884156703949
Action item: 0
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6697466969490051
Action item: 0
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.666181743144989
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7328488826751709
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7204564809799194
Action item: 1
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.670913577079773
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7205312252044678
Action item: 1
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6713144779205322
Action item: 0
Action space 

Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6859999895095825
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7138509154319763
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7020349502563477
Action item: 1
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.7038992643356323
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7038517594337463
Action item: 1
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.7022061347961426
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7051038146018982
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.6856895685195923
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.6780983805656433
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.6764447093009949
Action item: 1
Action spa

Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7362929582595825
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7293131351470947
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7198418378829956
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.7127653956413269
Action item: 1
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6801742315292358
Action item: 0
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6741742491722107
Action item: 0
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6674377918243408
Action item: 0
Action space shape: torch.Size([])
Action: 0
Log Probabilities: -0.6629998087882996
Action item: 0
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.730987012386322
Action item: 1
Action space shape: torch.Size([])
Action: 1
Log Probabilities: -0.702467679977417
Action item: