In [22]:
import torch
from torch.distributions import Categorical
import gym
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque

In [23]:
# --- Environment and the network sizes derived from it ---------------------
env = gym.make("CartPole-v1")
input_size = env.observation_space.shape[0]   # length of one observation vector
output_size = env.action_space.n              # number of discrete actions

# --- Training hyperparameters ------------------------------------------------
num_batches = 2_000      # not referenced by the other visible cells
num_episodes = 10_000    # number of episodes the training loop iterates over
batch_size = 1_000       # transitions required before an update is triggered
learning_rate = 0.001    # step size handed to both Adam optimizers
# NOTE(review): Qvals_calculator has its own GAMMA=1 default and the visible
# training loop never passes this global, so changing it here has no effect
# unless the call site is updated too.
GAMMA = 1

# --- Runtime infrastructure --------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter()  # TensorBoard writer used for per-episode reward logging

In [24]:
class policy_network(nn.Module):
    """Stochastic policy: maps an observation to a distribution over actions.

    The network is a two-hidden-layer MLP (64 tanh units each) followed by a
    softmax, so ``forward`` returns action probabilities that sum to one along
    the last dimension.

    Args:
        input_dims: Size of a single observation vector.
        output_dims: Number of discrete actions.
    """

    def __init__(self, input_dims, output_dims):
        super(policy_network, self).__init__()
        self.SeqLayer = nn.Sequential(
            nn.Linear(input_dims, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, output_dims),
            # Normalise over the action axis. dim=-1 equals dim=1 for batched
            # (N, input_dims) input and also works for a single unbatched state.
            nn.Softmax(dim = -1)
        )

    def forward(self, x):
        """Return action probabilities for ``x`` (shape ``(..., input_dims)``)."""
        return self.SeqLayer(x)

    def act(self, state):
        """Sample one action for a single observation.

        Args:
            state: One observation (NumPy array or anything
                ``torch.as_tensor`` accepts) of length ``input_dims``.

        Returns:
            Tuple ``(action, log_prob)`` where ``action`` is a Python int and
            ``log_prob`` is the Python float log-probability of that action.
            Both are detached from autograd, so the returned log-probability
            cannot be used to back-propagate into the policy.
        """
        # Use the device the weights live on instead of a notebook-level global,
        # so the module keeps working if it is moved or reused elsewhere.
        module_device = next(self.parameters()).device
        state_T = torch.as_tensor(state, dtype=torch.float32, device=module_device).unsqueeze(0)
        # Only plain Python numbers leave this method, so building an autograd
        # graph here would be wasted work.
        with torch.no_grad():
            m = Categorical(self(state_T))
            action_sample = m.sample()
            log_prob = m.log_prob(action_sample)
        return action_sample.item(), log_prob.item()

class Value_Approx(nn.Module):
    """MLP value approximator with two 64-unit tanh hidden layers.

    Unlike the policy network there is no output non-linearity: the final
    linear layer's activations are returned as-is.

    Args:
        input_dims: Number of input features per sample.
        output_dims: Number of output units.
    """

    def __init__(self, input_dims, output_dims):
        super().__init__()
        hidden_width = 64
        # Layers are created in input-to-output order and stored under
        # ``SeqLayer`` so parameter names (``SeqLayer.0.weight`` ...) are stable.
        stack = [
            nn.Linear(input_dims, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, hidden_width),
            nn.Tanh(),
            nn.Linear(hidden_width, output_dims),
        ]
        self.SeqLayer = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw output."""
        return self.SeqLayer(x)


# Actor: the policy network and the optimizer that updates its weights.
# NOTE: despite its name, `critic_optimizer` is built from policy.parameters(),
# i.e. it optimizes the policy, not the value network. The name is kept because
# other cells refer to it.
policy = policy_network(input_size, output_size).to(device)
critic_optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)

# Critic: scalar state-value estimator and its own optimizer.
value_function_estimator = Value_Approx(input_size, 1).to(device)
value_optimizer = torch.optim.Adam(value_function_estimator.parameters(), lr=learning_rate)

# Mean-squared-error criterion (not referenced by the other visible cells).
loss_fn = nn.MSELoss()

In [25]:
def Qvals_calculator(rewards, last_state, GAMMA = 1):
    """Compute discounted return targets for one episode's rewards.

    Works backwards through ``rewards`` using the recursion
    ``G_t = rewards[t] + GAMMA * G_{t+1}``, seeded with a bootstrap value for
    the state that follows the final reward.

    Args:
        rewards: Sequence of per-step rewards, in time order.
        last_state: Observation following the last reward. Its value, as
            predicted by the module-level ``value_function_estimator``, seeds
            the recursion. Pass ``None`` to bootstrap from zero (e.g. when the
            episode ended in a true terminal state).
        GAMMA: Discount factor applied per step.

    Returns:
        ``np.ndarray`` of float64 with ``len(rewards)`` entries; element ``t``
        is the discounted return from step ``t`` onward.
    """
    if last_state is None:
        running_return = 0.0
    else:
        # The bootstrap value is only a regression/advantage target: evaluate it
        # without autograd and turn it into a plain float so it can be stored
        # in a NumPy array. `.item()` expects a single state (one value).
        with torch.no_grad():
            state_T = torch.as_tensor(last_state, dtype=torch.float32)
            running_return = value_function_estimator(state_T).item()

    # Explicit float dtype so integer rewards cannot truncate the returns.
    Qvals = np.zeros(len(rewards), dtype=np.float64)
    for t in reversed(range(len(rewards))):
        # Accumulate backwards; the previous code re-used the bootstrap value at
        # every step, which discarded all rewards after step t.
        running_return = rewards[t] + GAMMA * running_return
        Qvals[t] = running_return
    return Qvals

In [30]:
def AgentLearning(state, reward, next_state, done, log_prob, qvals, action=None):
    """Run one actor-critic update on a batch of transitions.

    Args:
        state: Sequence of observations, one per transition.
        reward: Per-step rewards (accepted for interface compatibility; unused).
        next_state: Successor observations (accepted for compatibility; unused).
        done: Episode-end flags (accepted for compatibility; unused).
        log_prob: Log-probabilities recorded at action time as plain floats.
            They carry no autograd graph and are only used when ``action`` is
            not supplied.
        qvals: Return targets, one per transition, aligned with ``state``.
        action: Optional sequence of the integer actions that were taken. When
            given, log-probabilities are recomputed from the module-level
            ``policy`` so the policy loss can back-propagate into it. When
            omitted, the recorded floats are used and the policy receives no
            gradient.

    Raises:
        ValueError: If ``state`` and ``qvals`` have different lengths.
    """
    states_T = torch.as_tensor(np.asarray(state), dtype=torch.float32, device=device)
    # Cast targets to float32 so they match the value network's output dtype.
    returns_T = torch.as_tensor(np.asarray(qvals), dtype=torch.float32, device=device)
    if states_T.shape[0] != returns_T.shape[0]:
        raise ValueError(
            "state and qvals must have the same length, got "
            f"{states_T.shape[0]} and {returns_T.shape[0]}"
        )

    if action is not None:
        actions_T = torch.as_tensor(np.asarray(action), dtype=torch.long, device=device)
        # Recompute log pi(a|s) with gradients: the values stored during the
        # rollout are Python floats and cannot carry a graph back to the policy.
        log_probs = Categorical(policy(states_T)).log_prob(actions_T)
    else:
        log_probs = torch.as_tensor(np.asarray(log_prob), dtype=torch.float32, device=device)

    critic_value = value_function_estimator(states_T).squeeze(-1)
    Advantage = returns_T - critic_value
    critic_loss = 0.5 * Advantage.pow(2).mean()
    # Detach the advantage so the policy loss only updates the policy; the
    # value network is trained solely through critic_loss.
    policy_loss = -(log_probs * Advantage.detach()).mean()

    value_optimizer.zero_grad()
    critic_optimizer.zero_grad()
    (critic_loss + policy_loss).backward()
    value_optimizer.step()
    critic_optimizer.step()  # optimizer over policy.parameters(), despite its name


episodes_rewards = np.zeros(num_episodes)
# Rollout buffers. maxlen follows batch_size, so the update threshold below is
# always reachable; once full, the oldest transitions are dropped from every
# buffer alike.
batch_states = deque(maxlen = batch_size)
batch_actions = deque(maxlen = batch_size)
batch_rewards = deque(maxlen = batch_size)
batch_next_states = deque(maxlen = batch_size)
batch_dones = deque(maxlen = batch_size)
batch_logs = deque(maxlen = batch_size)
batch_qvals = deque(maxlen = batch_size)
batch_buffers = (
    batch_states, batch_actions, batch_rewards, batch_next_states,
    batch_dones, batch_logs, batch_qvals,
)
ep_reward = []
try:
    for i in tqdm(range(num_episodes)):
        done = False
        state = env.reset()
        score = 0
        while not done:
            action, log_prob = policy.act(state)
            next_state, reward, done, info = env.step(action)
            batch_states.append(state)
            batch_actions.append(action)
            batch_rewards.append(reward)
            batch_next_states.append(next_state)
            batch_dones.append(done)
            batch_logs.append(log_prob)
            ep_reward.append(reward)
            if done:
                # Return targets are appended once per finished episode, so
                # batch_qvals only lines up with the other buffers here.
                # NOTE(review): the final state's predicted value is always
                # used as the bootstrap, including when the episode ended in
                # failure rather than by time limit - confirm this is intended.
                last_state = torch.from_numpy(next_state).float().to(device)
                batch_qvals.extend(Qvals_calculator(ep_reward, last_state))
                ep_reward = []
                if len(batch_logs) >= batch_size:
                    AgentLearning(
                        batch_states, batch_rewards, batch_next_states,
                        batch_dones, batch_logs, batch_qvals,
                        action=batch_actions,
                    )
                    # On-policy data is stale after an update: start afresh.
                    for buffer in batch_buffers:
                        buffer.clear()
            state = next_state
            score += reward
        episodes_rewards[i] = score
        writer.add_scalar("Episode total reward", episodes_rewards[i], i)
finally:
    # Make sure logged scalars reach disk even if training is interrupted.
    writer.flush()

 20%|█▉        | 1992/10000 [00:55<03:44, 35.66it/s]


KeyboardInterrupt: 