In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

#Hyperparameters
learning_rate = 0.0005
gamma         = 0.98
lmbda         = 0.95
eps_clip      = 0.1
K_epoch       = 3
T_horizon     = 20

class PPO(nn.Module):
    def __init__(self):
        super(PPO, self).__init__()
        self.data = []
        
        self.fc1   = nn.Linear(4,256)
        self.fc_pi = nn.Linear(256,2)
        self.fc_v  = nn.Linear(256,1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def pi(self, x, softmax_dim = 0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob
    
    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
      
    def put_data(self, transition):
        self.data.append(transition)
        
    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, prob_a_lst, done_lst = [], [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, prob_a, done = transition
            
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            prob_a_lst.append([prob_a])
            done_mask = 0 if done else 1
            done_lst.append([done_mask])
            
        s,a,r,s_prime,done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
                                          torch.tensor(r_lst), torch.tensor(s_prime_lst, dtype=torch.float), \
                                          torch.tensor(done_lst, dtype=torch.float), torch.tensor(prob_a_lst)
        self.data = []
        return s, a, r, s_prime, done_mask, prob_a
        
    def train_net(self):
        s, a, r, s_prime, done_mask, prob_a = self.make_batch()

        for i in range(K_epoch):
            td_target = r + gamma * self.v(s_prime) * done_mask
            delta = td_target - self.v(s)
            delta = delta.detach().numpy()

            advantage_lst = []
            advantage = 0.0
            for delta_t in delta[::-1]:
                advantage = gamma * lmbda * advantage + delta_t[0]
                advantage_lst.append([advantage])
            advantage_lst.reverse()
            advantage = torch.tensor(advantage_lst, dtype=torch.float)

            pi = self.pi(s, softmax_dim=1)
            pi_a = pi.gather(1,a)
            ratio = torch.exp(torch.log(pi_a) - torch.log(prob_a))  # a/b == exp(log(a)-log(b))

            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1-eps_clip, 1+eps_clip) * advantage
            loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(self.v(s) , td_target.detach())

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()

In [2]:
env = gym.make('CartPole-v1')
model = PPO()
score = 0.0
print_interval = 20

for n_epi in range(10000):
    s, _ = env.reset()
    done = False
    while not done:
        for t in range(T_horizon):
            prob = model.pi(torch.from_numpy(s).float())
            m = Categorical(prob)
            a = m.sample().item()
            s_prime, r, done, truncated, info = env.step(a)

            model.put_data((s, a, r/100.0, s_prime, prob[a].item(), done))
            s = s_prime

            score += r
            if done:
                break

        model.train_net()

    if n_epi%print_interval==0 and n_epi!=0:
        print("# of episode :{}, avg score : {:.1f}".format(n_epi, score/print_interval))
        score = 0.0

env.close()

  if not isinstance(terminated, (bool, np.bool8)):
  s,a,r,s_prime,done_mask, prob_a = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \


# of episode :20, avg score : 26.1
# of episode :40, avg score : 36.3
# of episode :60, avg score : 68.3
# of episode :80, avg score : 107.2
# of episode :100, avg score : 56.3
# of episode :120, avg score : 156.6
# of episode :140, avg score : 311.4
# of episode :160, avg score : 130.4
# of episode :180, avg score : 196.2
# of episode :200, avg score : 203.8
# of episode :220, avg score : 36.0
# of episode :240, avg score : 62.0
# of episode :260, avg score : 132.6
# of episode :280, avg score : 239.6
# of episode :300, avg score : 167.5
# of episode :320, avg score : 194.3
# of episode :340, avg score : 111.8
# of episode :360, avg score : 116.0
# of episode :380, avg score : 120.2
# of episode :400, avg score : 181.6
# of episode :420, avg score : 185.8
# of episode :440, avg score : 209.7
# of episode :460, avg score : 296.8
# of episode :480, avg score : 282.8
# of episode :500, avg score : 291.2
# of episode :520, avg score : 546.8
# of episode :540, avg score : 483.2
# of episod

# of episode :4380, avg score : 131.8
# of episode :4400, avg score : 147.7
# of episode :4420, avg score : 214.8
# of episode :4440, avg score : 226.7
# of episode :4460, avg score : 224.4
# of episode :4480, avg score : 689.6
# of episode :4500, avg score : 206.8
# of episode :4520, avg score : 148.4
# of episode :4540, avg score : 153.2
# of episode :4560, avg score : 199.0
# of episode :4580, avg score : 543.1
# of episode :4600, avg score : 2144.5
# of episode :4620, avg score : 761.9
# of episode :4640, avg score : 437.9
# of episode :4660, avg score : 241.2
# of episode :4680, avg score : 267.9
# of episode :4700, avg score : 321.6
# of episode :4720, avg score : 460.1
# of episode :4740, avg score : 341.1
# of episode :4760, avg score : 591.8
# of episode :4780, avg score : 259.4
# of episode :4800, avg score : 354.2
# of episode :4820, avg score : 390.1
# of episode :4840, avg score : 197.8
# of episode :4860, avg score : 202.1
# of episode :4880, avg score : 187.7
# of episod

# of episode :8700, avg score : 636.1
# of episode :8720, avg score : 696.0
# of episode :8740, avg score : 334.1
# of episode :8760, avg score : 432.6
# of episode :8780, avg score : 207.2
# of episode :8800, avg score : 185.7
# of episode :8820, avg score : 172.6
# of episode :8840, avg score : 146.5
# of episode :8860, avg score : 130.1
# of episode :8880, avg score : 117.0
# of episode :8900, avg score : 116.2
# of episode :8920, avg score : 129.0
# of episode :8940, avg score : 139.9
# of episode :8960, avg score : 164.8
# of episode :8980, avg score : 205.6
# of episode :9000, avg score : 284.8
# of episode :9020, avg score : 359.4
# of episode :9040, avg score : 377.4
# of episode :9060, avg score : 439.4
# of episode :9080, avg score : 914.7
# of episode :9100, avg score : 792.5
# of episode :9120, avg score : 641.5
# of episode :9140, avg score : 330.0
# of episode :9160, avg score : 539.9
# of episode :9180, avg score : 239.1
# of episode :9200, avg score : 181.8
# of episode