In [1]:
import torch
from torch.distributions import Categorical
import gym
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque

In [2]:
# --- Experiment configuration ---------------------------------------------
# Everything a reader might tune lives in this cell.

SEED = 0  # fixed so repeated runs start from the same RNG state
torch.manual_seed(SEED)  # drives weight init and Categorical action sampling
np.random.seed(SEED)

env = gym.make("CartPole-v1")
# The training cell uses the legacy gym API (`env.reset()` returning only the
# observation, 4-tuple `env.step`), where the environment is seeded through
# `env.seed`. Guarded so a gym build without that method does not break the cell.
if hasattr(env, "seed"):
    env.seed(SEED)

num_batches = 2000    # number of policy updates (outer training iterations)
GAMMA = 1             # discount factor; NOTE(review): not referenced by the code visible in this notebook
learning_rate = 1e-3  # Adam step size
batch_size = 1000     # minimum number of environment steps gathered per update
writer = SummaryWriter()  # TensorBoard event writer using the default log directory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_size = env.observation_space.shape[0]  # length of one observation vector
output_size = env.action_space.n             # number of discrete actions

In [3]:
class Policy(nn.Module):
    """Small MLP that maps an observation to a probability for each action.

    Args:
        input_dims: Number of features in one observation vector.
        output_dims: Number of discrete actions.
    """

    def __init__(self, input_dims, output_dims):
        super().__init__()
        # Attribute name `Seq` is kept so existing state_dict keys stay valid.
        self.Seq = nn.Sequential(
            nn.Linear(input_dims, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, output_dims),
            # Normalise over the last (action) axis so that both batched
            # (N, input_dims) and unbatched (input_dims,) inputs are accepted;
            # a hard-coded dim=1 raises on 1-D input.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities for `x`; the last axis sums to 1."""
        return self.Seq(x)

    def act(self, state):
        """Sample one action for a single observation.

        Args:
            state: One observation as a NumPy array (or any array-like that
                `torch.as_tensor` accepts) of length `input_dims`.

        Returns:
            A tuple `(action, log_prob)` where `action` is a Python int and
            `log_prob` is a shape-(1,) tensor still attached to the autograd
            graph, so it can be used directly in a policy-gradient loss.
        """
        # Use the device the weights live on instead of a notebook-level
        # global, so the module works wherever it has been moved.
        model_device = next(self.parameters()).device
        state_t = torch.as_tensor(
            state, dtype=torch.float32, device=model_device
        ).unsqueeze(0)  # add the batch axis: (1, input_dims)
        action_probs = self.forward(state_t)
        m = Categorical(action_probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

# Build the policy network, place it on the selected device, and give its
# parameters to an Adam optimizer.
policy = Policy(input_size, output_size)
policy = policy.to(device)
optimizer = torch.optim.Adam(params=policy.parameters(), lr=learning_rate)

In [4]:
# Policy-gradient training. Each outer iteration gathers whole episodes until
# at least `batch_size` environment steps are available, then performs one
# optimizer step on -(log pi(a|s) * episode_return).mean().
for i in tqdm(range(num_batches)):
    saved_logProbs = []  # log-probability of every action taken in this batch
    batch_reward = []    # per-step weight: total reward of the episode the step belongs to
    rewards = []         # rewards of the episode currently being played
    state = env.reset()

    # `batch_reward` only grows when an episode finishes, so the batch always
    # ends on an episode boundary and stays aligned one-to-one with
    # `saved_logProbs`.
    while len(batch_reward) < batch_size:
        action, log_distribution = policy.act(state)
        saved_logProbs.append(log_distribution)
        state, R, done, _ = env.step(action)
        rewards.append(R)
        if done:
            ep_reward = sum(rewards)
            # Every step of the episode is weighted by the full episode return.
            batch_reward.extend([ep_reward] * len(rewards))
            rewards = []
            state = env.reset()

    # Mean over per-step entries, i.e. episode returns weighted by episode length.
    writer.add_scalar("Average reward per Batch", sum(batch_reward) / len(batch_reward), i)

    # Build the weight tensor straight from the list instead of rebinding
    # `batch_reward` to an ndarray first.
    batch_rewards_T = torch.as_tensor(batch_reward, dtype=torch.float32, device=device)
    log_batch = torch.cat(saved_logProbs)
    gradient_loss = -(log_batch * batch_rewards_T).mean()
    optimizer.zero_grad()
    gradient_loss.backward()
    optimizer.step()

# Push buffered scalars to disk so the last batches are not lost if the kernel
# is stopped before the writer's periodic flush.
writer.flush()


100%|██████████| 2000/2000 [1:01:08<00:00,  1.83s/it]
