In [1]:
import torch
from torch.distributions import Categorical
import gym
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque

In [2]:
# --- Environment ---------------------------------------------------------------
env = gym.make("CartPole-v1")
input_size = env.observation_space.shape[0]   # length of the observation vector
output_size = env.action_space.n              # number of discrete actions

# --- Training hyper-parameters -------------------------------------------------
num_batches = 2000      # training iterations; each ends in one gradient step
num_episodes = 10000    # size of the per-episode reward buffer `episodes_rewards`
GAMMA = 1               # discount factor γ for rewards-to-go (1 = undiscounted)
learning_rate = 1e-3    # Adam step size for both networks
batch_size = 1000       # minimum number of environment steps gathered per update
SEED = 0                # single seed for torch / numpy / the environment

# --- Reproducibility -----------------------------------------------------------
# Seed every RNG the run touches so a Restart & Run All reproduces the curves.
torch.manual_seed(SEED)
np.random.seed(SEED)
if hasattr(env, "seed"):
    # Legacy gym API (matching the 4-tuple `env.step` used in this notebook).
    env.seed(SEED)
env.action_space.seed(SEED)

writer = SummaryWriter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class policy_network(nn.Module):
    """Stochastic policy pi(a|s) over a discrete action space.

    An MLP with two tanh hidden layers of width 64 whose softmax output is a
    probability vector over ``output_dims`` actions.

    Args:
        input_dims: length of the observation vector.
        output_dims: number of discrete actions.
    """

    def __init__(self, input_dims, output_dims):
        super().__init__()
        # Attribute name and layer order are kept so state_dict keys are unchanged.
        self.SeqLayer = nn.Sequential(
            nn.Linear(input_dims, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, output_dims),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return action probabilities for a batch ``x`` of shape (batch, input_dims).

        The softmax is taken over dim 1, so ``x`` needs a leading batch dimension.
        """
        return self.SeqLayer(x)

    def act(self, state):
        """Sample one action for a single (unbatched) observation.

        Args:
            state: one observation as a 1-D array-like (e.g. a numpy array).

        Returns:
            ``(action, log_prob)``: the sampled action index as a Python int and
            its log-probability as a tensor of shape (1,). The tensor stays
            attached to the autograd graph, which the policy-gradient loss needs.
        """
        # Place the input on whatever device the parameters live on, instead of
        # relying on a notebook-level global `device`.
        param_device = next(self.parameters()).device
        state_T = torch.as_tensor(
            state, dtype=torch.float32, device=param_device
        ).unsqueeze(0)  # add the batch dimension expected by forward()
        m = Categorical(self.forward(state_T))
        action_sample = m.sample()
        return action_sample.item(), m.log_prob(action_sample)

class Value_Approx(nn.Module):
    """State-value approximator V(s): an MLP with two tanh hidden layers of width 64.

    Args:
        input_dims: length of the input feature vector.
        output_dims: number of outputs (this notebook builds it with 1, a scalar value).

    The last layer has no activation, so the outputs are unbounded.
    """

    def __init__(self, input_dims, output_dims):
        super().__init__()
        hidden_width = 64
        # `SeqLayer` name and layer order preserved so state_dict keys stay stable.
        self.SeqLayer = nn.Sequential(
            nn.Linear(input_dims, hidden_width), nn.Tanh(),
            nn.Linear(hidden_width, hidden_width), nn.Tanh(),
            nn.Linear(hidden_width, output_dims),
        )

    def forward(self, x):
        """Return the raw (un-activated) network output for ``x``."""
        return self.SeqLayer(x)


# Actor (policy) and critic (state-value baseline), both moved to the target device.
policy = policy_network(input_size, output_size).to(device)
value_function_estimator = Value_Approx(input_size, 1).to(device)

# NOTE(review): despite its name, `critic_optimizer` updates the *policy* (actor)
# parameters; the critic is trained by `value_optimizer`. The name is kept
# because the training loop refers to it.
critic_optimizer = torch.optim.Adam(params=policy.parameters(), lr=learning_rate)
value_optimizer = torch.optim.Adam(
    params=value_function_estimator.parameters(), lr=learning_rate
)

# Mean-squared error used to regress the critic's predictions.
loss_fn = nn.MSELoss(reduction="mean")

In [5]:
def RewardsToGo(rewards, baseline, gamma=1.0):
    """Return the baseline-subtracted (discounted) reward-to-go for each step.

    For step ``t`` the result is ``G_t - b_t``, where
    ``G_t = sum_{k >= t} gamma**(k - t) * rewards[k]``.

    Args:
        rewards: per-step rewards of one episode, in time order.
        baseline: per-step baseline values aligned with ``rewards``; each entry may
            be a single-element tensor (possibly requiring grad) or a plain number.
            Only its numeric value is used, so no gradient flows through it.
        gamma: discount factor. The default ``1.0`` gives plain undiscounted sums.

    Returns:
        A list of Python floats, one per reward.

    Raises:
        IndexError: if ``baseline`` is shorter than ``rewards``.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    # Accumulate backwards (G_t = r_t + gamma * G_{t+1}) in O(n) instead of
    # re-summing the tail rewards[t:] for every t, which is O(n^2).
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        advantages[t] = running - float(baseline[t])
    return advantages

In [8]:
# Training loop: REINFORCE with a learned state-value baseline.
# Each iteration gathers whole episodes until at least `batch_size` steps are
# collected, then takes one gradient step on the policy and one on the critic:
#   policy loss : -mean_t[ log pi(a_t|s_t) * (G_t - V(s_t)) ]   (baseline detached)
#   value loss  :  MSE(V(s_t), G_t)
# where G_t is the GAMMA-discounted reward-to-go from step t to the episode end.

# Total reward of every finished episode; index k holds episode k (0-based).
# The buffer is grown on demand so a long run cannot overflow it.
episodes_rewards = np.zeros(num_episodes)
ep_counter = 0  # number of finished episodes so far

for _ in tqdm(range(num_batches)):
    log_probs = []             # log pi(a_t|s_t), tensors of shape (1,) with grad
    baseline_predictions = []  # V(s_t), tensors of shape (1,) with grad
    batch_returns = []         # G_t as Python floats, aligned with the lists above

    # --- Collect whole episodes until the batch holds >= batch_size steps ---
    while True:
        state = env.reset()
        rewards = []
        done = False
        while not done:
            action, log_prob = policy.act(state)
            state_T = torch.from_numpy(state).float().to(device)
            baseline_predictions.append(value_function_estimator(state_T))
            log_probs.append(log_prob)
            state, R, done, info = env.step(action)
            rewards.append(R)

        # Rewards-to-go in one backward pass: G_t = r_t + GAMMA * G_{t+1}.
        episode_returns = []
        running = 0.0
        for r in reversed(rewards):
            running = r + GAMMA * running
            episode_returns.append(running)
        episode_returns.reverse()
        batch_returns.extend(episode_returns)

        total_reward = sum(rewards)
        if ep_counter >= len(episodes_rewards):
            # Double the buffer instead of raising IndexError on long runs.
            episodes_rewards = np.concatenate(
                [episodes_rewards, np.zeros(max(len(episodes_rewards), 1))]
            )
        episodes_rewards[ep_counter] = total_reward  # write first, then count
        ep_counter += 1
        writer.add_scalar("Episode Total Reward", total_reward, ep_counter)

        if len(log_probs) >= batch_size:
            break

    # --- One update step for both networks ---
    batch_probs = torch.cat(log_probs)
    baseline_T = torch.cat(baseline_predictions)
    returns_T = torch.as_tensor(batch_returns, dtype=torch.float32, device=device)
    # The baseline only reduces policy-gradient variance, so it is detached here
    # and the policy loss sends no gradient into the critic.
    advantages_T = returns_T - baseline_T.detach()
    policy_loss = -(batch_probs * advantages_T).mean()
    # The critic's regression target is the observed return G_t itself, not the
    # advantage G_t - V(s_t).
    value_loss = loss_fn(baseline_T, returns_T)
    loss = policy_loss + value_loss
    # `critic_optimizer` was built from policy.parameters(), i.e. it updates the actor.
    critic_optimizer.zero_grad()
    value_optimizer.zero_grad()
    loss.backward()
    critic_optimizer.step()
    value_optimizer.step()

100%|██████████| 2000/2000 [1:20:10<00:00,  2.41s/it]
