In [21]:
import torch
from torch.distributions import Categorical
import gym
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.functional as F
from torch.utils.tensorboard import SummaryWriter
from collections import deque

In [22]:
# --- Experiment configuration ------------------------------------------------
# Fixed seed so that weight initialisation and action sampling are repeatable
# on a fresh kernel (Restart Kernel -> Run All).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

env = gym.make("CartPole-v1")
# Legacy gym releases expose env.seed(); releases that removed it seed through
# reset(seed=...) instead, in which case this guard simply skips the call.
if hasattr(env, "seed"):
    env.seed(SEED)

num_batches = 2000      # not referenced in the cells visible here
num_episodes = 10000    # number of training episodes run by the training loop
GAMMA = 1               # discount factor (1 = undiscounted episodic return)
learning_rate = 1e-5    # step size passed to the value-network optimizer
batch_size = 1000       # not referenced in the cells visible here
writer = SummaryWriter()  # TensorBoard logger (default log directory)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_size = env.observation_space.shape[0]   # length of a state vector
output_size = env.action_space.n              # number of discrete actions

In [23]:
class policy_network(nn.Module):
    """Stochastic policy over discrete actions.

    A small MLP (two hidden layers of 64 tanh units) whose output is a
    probability vector over ``output_dims`` actions.
    """

    def __init__(self, input_dims, output_dims):
        """Build the network.

        Args:
            input_dims: number of features in a state vector.
            output_dims: number of discrete actions.
        """
        super(policy_network, self).__init__()
        self.SeqLayer = nn.Sequential(
            nn.Linear(input_dims, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, output_dims),
            # Normalise over the last (action) axis. For the batched (N, D)
            # inputs used by `act` this is the same as dim=1, but it also
            # works for an unbatched (D,) state instead of raising.
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        """Return action probabilities for ``x``.

        Args:
            x: float tensor whose last axis has ``input_dims`` features.

        Returns:
            Tensor with the same leading shape as ``x`` and a last axis of
            size ``output_dims`` that sums to 1.
        """
        return self.SeqLayer(x)

    def act(self, state):
        """Sample an action for a single state.

        Args:
            state: one observation (array-like of ``input_dims`` numbers).

        Returns:
            ``(action, log_prob)`` where ``action`` is a Python int and
            ``log_prob`` is a shape-``(1,)`` tensor that stays attached to the
            autograd graph so it can be used in a policy-gradient loss.
        """
        # Use the device the network actually lives on rather than relying on
        # a module-level global that may be out of sync with this instance.
        model_device = next(self.parameters()).device
        state_T = torch.as_tensor(
            state, dtype=torch.float32, device=model_device
        ).unsqueeze(0)  # add the batch axis -> (1, input_dims)
        action_distributions = self.forward(state_T)
        m = Categorical(action_distributions)
        action_sample = m.sample()
        return action_sample.item(), m.log_prob(action_sample)

class Value_Approx(nn.Module):
    """State-value approximator: an MLP with two 64-unit tanh hidden layers.

    The final layer is linear (no activation), so outputs are unbounded.
    """

    def __init__(self, input_dims, output_dims):
        """Create the layer stack.

        Args:
            input_dims: number of input features.
            output_dims: number of output units.
        """
        super().__init__()
        hidden_units = 64
        # Layers are created in the same order as they are applied, so the
        # parameter initialisation order matches the layer order.
        layers = [
            nn.Linear(input_dims, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units),
            nn.Tanh(),
            nn.Linear(hidden_units, output_dims),
        ]
        self.SeqLayer = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the raw output."""
        return self.SeqLayer(x)


# Actor and critic networks. The construction order (policy first) is kept
# because it determines how the RNG is consumed for weight initialisation.
policy = policy_network(input_size, output_size).to(device)
value_function_estimator = Value_Approx(input_size, 1).to(device)

# Squared-error loss used to fit the value network.
loss_fn = nn.MSELoss()

# NOTE: despite its name, `critic_optimizer` updates the *policy* (actor)
# parameters; `value_optimizer` updates the value network (critic).
value_optimizer = torch.optim.Adam(
    value_function_estimator.parameters(),
    lr=learning_rate,
)
critic_optimizer = torch.optim.Adam(
    policy.parameters(),
    lr=1e-3,
)

In [24]:
def AgentLearning(state, reward, next_state, done, log_prob):
    """Perform one online actor-critic update from a single transition.

    The value network is regressed towards the one-step bootstrapped target
    ``reward + GAMMA * V(next_state)`` (just ``reward`` when ``done``), and the
    policy is updated with the TD error as the advantage estimate.

    Args:
        state: observation before the action (NumPy array).
        reward: scalar reward received for the transition.
        next_state: observation after the action (NumPy array).
        done: whether the episode ended on this transition.
        log_prob: log-probability of the taken action, still attached to the
            policy's autograd graph (as returned by ``policy.act``).

    Returns:
        None. Steps ``value_optimizer`` and ``critic_optimizer`` in place.
    """
    value_optimizer.zero_grad()
    critic_optimizer.zero_grad()

    state_T = torch.from_numpy(state).float().to(device)
    next_state_T = torch.from_numpy(next_state).float().to(device)

    critic_value = value_function_estimator(state_T)

    # The bootstrapped target is a fixed regression label: no gradient may
    # flow through V(next_state), otherwise the critic loss also moves the
    # target it is chasing.
    with torch.no_grad():
        critic_next_value = value_function_estimator(next_state_T)
        target = reward + GAMMA * critic_next_value * (1 - int(done))

    # The advantage must be a constant for the actor loss; leaving it attached
    # would let the policy loss push gradients into the value network.
    TD_error = target - critic_value.detach()

    critic_loss = loss_fn(critic_value, target)
    policy_loss = -log_prob * TD_error

    # The two losses touch disjoint parameter sets, so one backward pass over
    # their sum gives each optimizer exactly its own gradients.
    (critic_loss + policy_loss.sum()).backward()
    value_optimizer.step()
    critic_optimizer.step()

In [25]:
# Online training: the agent is updated after every environment step, and the
# undiscounted return of each episode is stored and logged to TensorBoard.
episodes_rewards = np.zeros(num_episodes)
for i in tqdm(range(num_episodes)):
    done, score = False, 0
    state = env.reset()
    while not done:
        # Sample an action, then advance the environment by one step.
        action, log_prob = policy.act(state)
        next_state, reward, done, info = env.step(action)

        # Learn from this single transition before moving on.
        AgentLearning(state, reward, next_state, done, log_prob)

        score += reward
        state = next_state

    episodes_rewards[i] = score
    writer.add_scalar("Episode total reward", episodes_rewards[i], i)

 30%|██▉       | 2991/10000 [13:20<31:15,  3.74it/s]  

KeyboardInterrupt

