In [1]:
import os.path

import torch
from torch.distributions import Categorical
import gym
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch as T
import torch.optim as optim
import torch.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import os

writer = SummaryWriter()

In [2]:
class PPOMemory:
    def __init__(self, batch_size):
        self.states = []
        self.probs = []
        self.vals = []
        self.actions = []
        self.rewards = []
        self.dones = []

        self.batch_size = batch_size

    def generate_batches(self):
        n_states = len(self.states)
        batch_start = np.arange(0, n_states, self.batch_size)
        indices = np.arange(n_states, dtype=np.int64)
        np.random.shuffle(indices)
        batches = [indices[i:i+self.batch_size] for i in batch_start]


        return np.array(self.states),\
                np.array(self.actions),\
                np.array(self.probs),\
                np.array(self.vals),\
                np.array(self.rewards),\
                np.array(self.dones),\
                batches

    def store_memory(self, state, action, probs, vals, reward, done):
        self.states.append(state)
        self.actions.append(action)
        self.probs.append(probs)
        self.vals.append(vals)
        self.rewards.append(reward)
        self.dones.append(done)

    def clear_memory(self):
        self.states = []
        self.probs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.vals = []

class ActorNetwork(nn.Module):
    def __init__(self, n_actions, input_dims, alpha,
            fc1_dims=256, fc2_dims=256, chkpt_dir='tmp/ppo'):
        super(ActorNetwork, self).__init__()

        self.checkpoint_file = os.path.join(chkpt_dir, 'actor_torch_ppo')
        self.actor = nn.Sequential(
                nn.Linear(*input_dims, fc1_dims),
                nn.ReLU(),
                nn.Linear(fc1_dims, fc2_dims),
                nn.ReLU(),
                nn.Linear(fc2_dims, n_actions),
                nn.Softmax(dim=-1)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        dist = self.actor(state)
        dist = Categorical(dist)
        entropy = dist.entropy()

        return dist, entropy

    def save_checkpoint(self):
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        self.load_state_dict(T.load(self.checkpoint_file))

class CriticNetwork(nn.Module):
    def __init__(self, input_dims, alpha, fc1_dims=256, fc2_dims=256,
            chkpt_dir='tmp/ppo'):
        super(CriticNetwork, self).__init__()

        self.checkpoint_file = os.path.join(chkpt_dir, 'critic_torch_ppo')
        self.critic = nn.Sequential(
                nn.Linear(*input_dims, fc1_dims),
                nn.ReLU(),
                nn.Linear(fc1_dims, fc2_dims),
                nn.ReLU(),
                nn.Linear(fc2_dims, 1)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state):
        value = self.critic(state)

        return value

    def save_checkpoint(self):
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        self.load_state_dict(T.load(self.checkpoint_file))

class Agent:
    def __init__(self, n_actions, input_dims, gamma=0.99, alpha=0.0003, gae_lambda=0.95,
            policy_clip=0.2, batch_size=64, n_epochs=10):
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda

        self.actor = ActorNetwork(n_actions, input_dims, alpha)
        self.critic = CriticNetwork(input_dims, alpha)
        self.memory = PPOMemory(batch_size)
        self.MSE = nn.MSELoss()

    def remember(self, state, action, probs, vals, reward, done):
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def save_models(self):
        print('... saving models ...')
        self.actor.save_checkpoint()
        self.critic.save_checkpoint()

    def load_models(self):
        print('... loading models ...')
        self.actor.load_checkpoint()
        self.critic.load_checkpoint()

    def choose_action(self, observation):
        state = T.tensor([observation], dtype=T.float).to(self.actor.device)

        dist, entropy = self.actor(state)
        value = self.critic(state)
        action = dist.sample()

        probs = T.squeeze(dist.log_prob(action)).item()
        action = T.squeeze(action).item()
        value = T.squeeze(value).item()

        return action, probs, value

    def learn(self):
        for _ in range(self.n_epochs):
            state_arr, action_arr, old_prob_arr, vals_arr,\
            reward_arr, dones_arr, batches = \
                    self.memory.generate_batches()

            discounted_reward = 0
            rewards = []
            for r, is_terminal in zip(reversed(reward_arr), reversed(dones_arr)):
                if is_terminal:
                    discounted_reward = 0
                discounted_reward = r + (self.gamma * discounted_reward)
                rewards.insert(0, discounted_reward)

            values = T.from_numpy(vals_arr).float().to(self.actor.device)

            rewards = T.tensor(rewards, dtype=torch.float32).to(self.actor.device)
            rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-7)
            advantage = rewards - values
            for batch in batches:
                states = T.tensor(state_arr[batch], dtype=T.float).to(self.actor.device)
                old_probs = T.tensor(old_prob_arr[batch]).to(self.actor.device)
                actions = T.tensor(action_arr[batch]).to(self.actor.device)


                dist, entropy = self.actor(states)
                critic_value = self.critic(states)

                critic_value = T.squeeze(critic_value)

                new_probs = dist.log_prob(actions)
                prob_ratio = (new_probs - old_probs).exp()
                weighted_probs = advantage[batch] * prob_ratio
                weighted_clipped_probs = T.clamp(prob_ratio, 1-self.policy_clip,
                        1+self.policy_clip)*advantage[batch]
                actor_loss = -T.min(weighted_probs, weighted_clipped_probs).mean()

                critic_loss = self.MSE(critic_value, rewards[batch])

                total_loss = actor_loss + 0.5*critic_loss-0.01*entropy.mean()
                self.actor.optimizer.zero_grad()
                self.critic.optimizer.zero_grad()
                total_loss.backward()
                self.actor.optimizer.step()
                self.critic.optimizer.step()

        self.memory.clear_memory()


In [3]:
if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    N=20
    batch_size = 5
    n_epochs = 4
    alpha = 0.0003
    agent = Agent(n_actions = env.action_space.n, batch_size = batch_size, alpha = alpha,
                  n_epochs = n_epochs, input_dims = env.observation_space.shape)
    n_games = 3000
    figure_file = 'plots/cartpole.png'

    best_score = env.reward_range[0]
    score_history = []

    learn_iters = 0
    avg_score = 0
    n_steps = 0

    for i in tqdm(range(n_games)):
        observation = env.reset()
        done = False
        score = 0
        while not done:
            action, prob, val = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            n_steps +=1
            score += reward
            agent.remember(observation, action, prob, val, reward, done)
            if n_steps % N == 0:
                agent.learn()
                learn_iters+=1
            observation = observation_
        score_history.append(score)
        avg_score = np.mean(score_history[-100:])
        writer.add_scalar("Episode total reward", score_history[i], i)

        if avg_score > best_score:
            best_score = avg_score
            agent.save_models()

        #print('episode ', i, ' score %.if ' % score, " avg score %.1f " % avg_score,
        #      " time steps", n_steps, " learning_steps ", learn_iters)

  state = T.tensor([observation], dtype=T.float).to(self.actor.device)
  0%|          | 1/3000 [00:01<1:23:15,  1.67s/it]

... saving models ...


  0%|          | 3/3000 [00:01<26:07,  1.91it/s]  

... saving models ...


  0%|          | 5/3000 [00:02<20:49,  2.40it/s]

... saving models ...


  1%|          | 26/3000 [00:06<14:01,  3.53it/s]

... saving models ...


  1%|          | 27/3000 [00:07<18:41,  2.65it/s]

... saving models ...


  1%|          | 28/3000 [00:07<16:30,  3.00it/s]

... saving models ...


  1%|          | 29/3000 [00:08<20:15,  2.44it/s]

... saving models ...


  1%|          | 30/3000 [00:09<23:01,  2.15it/s]

... saving models ...


  1%|          | 31/3000 [00:09<21:32,  2.30it/s]

... saving models ...


  1%|          | 32/3000 [00:09<18:44,  2.64it/s]

... saving models ...


  1%|          | 33/3000 [00:10<32:45,  1.51it/s]

... saving models ...


  1%|          | 34/3000 [00:11<33:06,  1.49it/s]

... saving models ...


  1%|          | 35/3000 [00:13<51:40,  1.05s/it]

... saving models ...


  1%|▏         | 40/3000 [00:16<35:27,  1.39it/s]

... saving models ...


  1%|▏         | 41/3000 [00:16<34:16,  1.44it/s]

... saving models ...


  1%|▏         | 42/3000 [00:17<31:41,  1.56it/s]

... saving models ...


  1%|▏         | 43/3000 [00:18<33:47,  1.46it/s]

... saving models ...


  1%|▏         | 44/3000 [00:19<40:11,  1.23it/s]

... saving models ...


  2%|▏         | 45/3000 [00:20<43:36,  1.13it/s]

... saving models ...


  2%|▏         | 46/3000 [00:21<42:45,  1.15it/s]

... saving models ...


  2%|▏         | 47/3000 [00:21<40:16,  1.22it/s]

... saving models ...


  3%|▎         | 94/3000 [00:40<59:48,  1.23s/it]

... saving models ...


  3%|▎         | 95/3000 [00:41<54:35,  1.13s/it]

... saving models ...


  3%|▎         | 96/3000 [00:42<48:19,  1.00it/s]

... saving models ...


  3%|▎         | 97/3000 [00:43<47:18,  1.02it/s]

... saving models ...


  3%|▎         | 99/3000 [00:44<40:44,  1.19it/s]

... saving models ...


  3%|▎         | 100/3000 [00:45<35:17,  1.37it/s]

... saving models ...


  3%|▎         | 101/3000 [00:45<28:00,  1.72it/s]

... saving models ...


  3%|▎         | 102/3000 [00:46<31:46,  1.52it/s]

... saving models ...


  3%|▎         | 104/3000 [00:46<22:03,  2.19it/s]

... saving models ...
... saving models ...


  4%|▎         | 106/3000 [00:47<19:38,  2.46it/s]

... saving models ...


  4%|▎         | 107/3000 [00:47<17:56,  2.69it/s]

... saving models ...


  4%|▎         | 108/3000 [00:48<22:36,  2.13it/s]

... saving models ...


  4%|▎         | 109/3000 [00:49<36:21,  1.33it/s]

... saving models ...


  4%|▎         | 110/3000 [00:50<32:16,  1.49it/s]

... saving models ...


  4%|▎         | 111/3000 [00:50<25:52,  1.86it/s]

... saving models ...


  4%|▎         | 112/3000 [00:51<28:13,  1.71it/s]

... saving models ...


  4%|▍         | 113/3000 [00:51<28:26,  1.69it/s]

... saving models ...


  4%|▍         | 114/3000 [00:52<24:53,  1.93it/s]

... saving models ...


  4%|▍         | 115/3000 [00:52<22:41,  2.12it/s]

... saving models ...


  4%|▍         | 116/3000 [00:52<20:49,  2.31it/s]

... saving models ...


  4%|▍         | 119/3000 [00:54<20:18,  2.37it/s]

... saving models ...


  4%|▍         | 120/3000 [00:54<20:48,  2.31it/s]

... saving models ...


  4%|▍         | 122/3000 [00:54<14:19,  3.35it/s]

... saving models ...


  5%|▍         | 141/3000 [01:05<40:01,  1.19it/s]

... saving models ...


  5%|▍         | 142/3000 [01:05<35:32,  1.34it/s]

... saving models ...


  5%|▍         | 143/3000 [01:06<41:17,  1.15it/s]

... saving models ...


  5%|▌         | 150/3000 [01:11<26:38,  1.78it/s]

... saving models ...


  5%|▌         | 151/3000 [01:11<27:53,  1.70it/s]

... saving models ...


  5%|▌         | 152/3000 [01:12<31:58,  1.48it/s]

... saving models ...


  5%|▌         | 153/3000 [01:13<27:51,  1.70it/s]

... saving models ...


  5%|▌         | 154/3000 [01:13<28:16,  1.68it/s]

... saving models ...


  5%|▌         | 155/3000 [01:14<35:33,  1.33it/s]

... saving models ...


  5%|▌         | 156/3000 [01:16<44:23,  1.07it/s]

... saving models ...


  5%|▌         | 157/3000 [01:17<45:15,  1.05it/s]

... saving models ...


  5%|▌         | 158/3000 [01:18<48:59,  1.03s/it]

... saving models ...


  5%|▌         | 159/3000 [01:19<44:51,  1.06it/s]

... saving models ...


  5%|▌         | 160/3000 [01:20<43:40,  1.08it/s]

... saving models ...


  5%|▌         | 161/3000 [01:22<1:02:07,  1.31s/it]

... saving models ...


  5%|▌         | 162/3000 [01:23<1:04:45,  1.37s/it]

... saving models ...


  5%|▌         | 163/3000 [01:25<1:04:44,  1.37s/it]

... saving models ...


  5%|▌         | 164/3000 [01:26<1:00:59,  1.29s/it]

... saving models ...


  6%|▌         | 165/3000 [01:27<53:14,  1.13s/it]  

... saving models ...


  6%|▌         | 166/3000 [01:28<52:50,  1.12s/it]

... saving models ...


  6%|▌         | 167/3000 [01:29<51:27,  1.09s/it]

... saving models ...


  6%|▌         | 168/3000 [01:29<46:46,  1.01it/s]

... saving models ...


  6%|▌         | 169/3000 [01:30<43:26,  1.09it/s]

... saving models ...


  6%|▌         | 170/3000 [01:31<40:52,  1.15it/s]

... saving models ...


  6%|▌         | 171/3000 [01:32<41:07,  1.15it/s]

... saving models ...


  6%|▌         | 172/3000 [01:33<38:53,  1.21it/s]

... saving models ...


  6%|▌         | 173/3000 [01:33<36:14,  1.30it/s]

... saving models ...


  6%|▌         | 174/3000 [01:34<36:07,  1.30it/s]

... saving models ...


  6%|▌         | 175/3000 [01:35<35:33,  1.32it/s]

... saving models ...


  6%|▌         | 176/3000 [01:35<36:54,  1.28it/s]

... saving models ...


  6%|▌         | 177/3000 [01:36<38:17,  1.23it/s]

... saving models ...


  6%|▌         | 178/3000 [01:37<41:11,  1.14it/s]

... saving models ...


  6%|▌         | 179/3000 [01:38<39:09,  1.20it/s]

... saving models ...


  6%|▌         | 180/3000 [01:39<37:27,  1.25it/s]

... saving models ...


  6%|▋         | 188/3000 [01:46<51:14,  1.09s/it]

... saving models ...


  6%|▋         | 189/3000 [01:48<59:18,  1.27s/it]

... saving models ...


  7%|▋         | 196/3000 [01:56<1:17:32,  1.66s/it]

... saving models ...


  7%|▋         | 197/3000 [01:57<1:12:33,  1.55s/it]

... saving models ...


  7%|▋         | 198/3000 [01:58<1:04:02,  1.37s/it]

... saving models ...


  9%|▉         | 281/3000 [02:57<1:44:12,  2.30s/it]

... saving models ...


  9%|▉         | 282/3000 [03:00<1:54:34,  2.53s/it]

... saving models ...


  9%|▉         | 283/3000 [03:03<1:52:38,  2.49s/it]

... saving models ...


  9%|▉         | 284/3000 [03:04<1:32:23,  2.04s/it]

... saving models ...


 10%|▉         | 293/3000 [03:11<46:43,  1.04s/it]  

... saving models ...


 10%|▉         | 294/3000 [03:12<46:32,  1.03s/it]

... saving models ...


 10%|█         | 313/3000 [03:26<30:28,  1.47it/s]

... saving models ...


 10%|█         | 314/3000 [03:27<30:06,  1.49it/s]

... saving models ...


 10%|█         | 315/3000 [03:27<27:53,  1.60it/s]

... saving models ...


 11%|█         | 316/3000 [03:28<25:50,  1.73it/s]

... saving models ...


 11%|█         | 317/3000 [03:28<23:06,  1.93it/s]

... saving models ...


 11%|█         | 318/3000 [03:29<21:19,  2.10it/s]

... saving models ...


 11%|█         | 319/3000 [03:29<20:59,  2.13it/s]

... saving models ...


 11%|█         | 320/3000 [03:30<21:18,  2.10it/s]

... saving models ...


 11%|█         | 321/3000 [03:30<22:48,  1.96it/s]

... saving models ...


 11%|█         | 322/3000 [03:31<22:45,  1.96it/s]

... saving models ...


 11%|█         | 323/3000 [03:31<24:53,  1.79it/s]

... saving models ...


 11%|█         | 324/3000 [03:32<25:30,  1.75it/s]

... saving models ...


 11%|█         | 326/3000 [03:33<25:05,  1.78it/s]

... saving models ...


 11%|█         | 327/3000 [03:34<26:13,  1.70it/s]

... saving models ...


 11%|█         | 328/3000 [03:34<26:41,  1.67it/s]

... saving models ...


 11%|█         | 329/3000 [03:35<27:22,  1.63it/s]

... saving models ...


 11%|█         | 330/3000 [03:35<25:57,  1.71it/s]

... saving models ...


 11%|█         | 331/3000 [03:36<26:17,  1.69it/s]

... saving models ...


 11%|█         | 332/3000 [03:36<23:46,  1.87it/s]

... saving models ...


 11%|█         | 333/3000 [03:37<20:28,  2.17it/s]

... saving models ...


 11%|█         | 335/3000 [03:37<15:29,  2.87it/s]

... saving models ...
... saving models ...


 18%|█▊        | 544/3000 [05:23<45:28,  1.11s/it]  

... saving models ...


 18%|█▊        | 545/3000 [05:24<46:37,  1.14s/it]

... saving models ...


 18%|█▊        | 546/3000 [05:25<47:24,  1.16s/it]

... saving models ...


 18%|█▊        | 547/3000 [05:27<48:04,  1.18s/it]

... saving models ...


 18%|█▊        | 548/3000 [05:28<48:52,  1.20s/it]

... saving models ...


 18%|█▊        | 549/3000 [05:29<50:38,  1.24s/it]

... saving models ...


 18%|█▊        | 550/3000 [05:30<50:49,  1.24s/it]

... saving models ...


 18%|█▊        | 551/3000 [05:32<56:27,  1.38s/it]

... saving models ...


 18%|█▊        | 552/3000 [05:35<1:16:54,  1.89s/it]

... saving models ...


 18%|█▊        | 553/3000 [05:38<1:31:47,  2.25s/it]

... saving models ...


 18%|█▊        | 554/3000 [05:41<1:32:42,  2.27s/it]

... saving models ...


 18%|█▊        | 555/3000 [05:43<1:28:56,  2.18s/it]

... saving models ...


 19%|█▊        | 556/3000 [05:44<1:20:11,  1.97s/it]

... saving models ...


 19%|█▊        | 557/3000 [05:45<1:12:32,  1.78s/it]

... saving models ...


 19%|█▊        | 558/3000 [05:47<1:05:34,  1.61s/it]

... saving models ...


 19%|█▊        | 559/3000 [05:48<58:30,  1.44s/it]  

... saving models ...


 19%|█▊        | 560/3000 [05:49<51:37,  1.27s/it]

... saving models ...


 19%|█▊        | 562/3000 [05:50<45:46,  1.13s/it]

... saving models ...


 19%|█▉        | 563/3000 [05:52<47:14,  1.16s/it]

... saving models ...


 19%|█▉        | 564/3000 [05:53<49:21,  1.22s/it]

... saving models ...


 19%|█▉        | 565/3000 [05:54<51:12,  1.26s/it]

... saving models ...


 19%|█▉        | 566/3000 [05:56<52:07,  1.28s/it]

... saving models ...


 23%|██▎       | 700/3000 [07:08<23:28,  1.63it/s]


KeyboardInterrupt: 