Policy-based

In [22]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Training hyperparameters read by the Policy class defined below.
learning_rate = 2e-4  # step size handed to the Adam optimizer
gamma = 0.98  # discount factor applied when accumulating returns

class Policy(nn.Module):
    """Two-layer softmax policy network trained with REINFORCE.

    The network maps a state vector to a probability distribution over
    discrete actions. Rewards and the probabilities of the sampled actions
    are buffered with :meth:`put_data` during an episode and consumed by
    :meth:`train_net` once the episode is over.

    Args:
        state_dim: Size of the input state vector (default 4).
        action_dim: Number of discrete actions (default 2).
        hidden_dim: Width of the hidden layer (default 128).
        lr: Adam learning rate. When ``None`` the module-level
            ``learning_rate`` is used.
        discount: Reward discount factor. When ``None`` the module-level
            ``gamma`` is used.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden_dim=128,
                 lr=None, discount=None):
        super(Policy, self).__init__()
        # Per-episode buffer of (reward, probability of the sampled action).
        self.data = []
        # Fall back to the module-level hyperparameter unless overridden.
        self.discount = gamma if discount is None else discount

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_dim)
        self.optimizer = optim.Adam(
            self.parameters(),
            lr=learning_rate if lr is None else lr,
        )

    def forward(self, x):
        """Return action probabilities for a state or a batch of states.

        The softmax is taken over the last dimension so that both a single
        state of shape ``(state_dim,)`` and a batch of shape
        ``(batch, state_dim)`` yield properly normalized distributions.
        """
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=-1)

    def put_data(self, item):
        """Append one ``(reward, action_probability)`` pair to the buffer."""
        self.data.append(item)

    def train_net(self):
        """Run one REINFORCE update over the buffered episode and clear it.

        The buffer is walked from the last step to the first so the
        discounted return ``R`` can be built incrementally. Gradients of
        ``-log(prob) * R`` are accumulated across all steps and applied in
        a single optimizer step.
        """
        R = 0
        self.optimizer.zero_grad()
        for r, prob in reversed(self.data):  # backward pass over the episode
            R = r + self.discount * R
            loss = -torch.log(prob) * R
            loss.backward()  # accumulates into the parameters' .grad
        self.optimizer.step()
        self.data = []



In [23]:
# Train the policy on CartPole-v1, reporting the average episode score.
env = gym.make('CartPole-v1')
num_states = env.observation_space.shape[0]
print("Size of State Space ->  {}".format(num_states))
num_actions = env.action_space.n
print("Size of Action Space ->  {}".format(num_actions))

pi = Policy()
score = 0.0
print_interval = 100
num_episodes = 500  # raise (e.g. to 10000) for a longer training run
episodes_in_window = 0  # episodes accumulated into `score` since last report

try:
    for n_epi in range(num_episodes):
        # Newer Gym releases return (observation, info) from reset();
        # older ones return the observation alone. Accept both.
        reset_out = env.reset()
        s = reset_out[0] if isinstance(reset_out, tuple) else reset_out
        done = False

        while not done:  # CartPole-v1 is forcibly terminated at 500 steps.
            prob = pi(torch.from_numpy(s).float())
            m = Categorical(prob)
            a = m.sample()
            step_out = env.step(a.item())
            if len(step_out) == 5:
                # Newer API: separate terminated / truncated flags.
                s_prime, r, terminated, truncated, info = step_out
                done = terminated or truncated
            else:
                s_prime, r, done, info = step_out
            pi.put_data((r, prob[a]))
            s = s_prime
            score += r

        pi.train_net()
        episodes_in_window += 1

        if n_epi % print_interval == 0 and n_epi != 0:
            # Divide by the number of episodes actually summed; the first
            # window covers episodes 0..print_interval inclusive.
            print("# of episode :{}, avg score : {}".format(
                n_epi, round(score / episodes_in_window, 2)))
            score = 0.0
            episodes_in_window = 0
finally:
    # Release the environment even if training is interrupted.
    env.close()


Size of State Space ->  4
Size of Action Space ->  2
# of episode :100, avg score : 24.62
# of episode :200, avg score : 32.91
# of episode :300, avg score : 39.6
# of episode :400, avg score : 48.44
