In [1]:
import gym
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
from torch.distributions import Categorical

In [2]:
# hyperparameters
learning_rate = 0.0002
gamma = 0.98
## number of iteration for taking actions before parameter updates
n_rollout = 10

In [8]:
class ActorCritic(nn.Module):
    def __init__(self):
        super(ActorCritic, self).__init__()
        self.data = []
        
        self.fc1 = nn.Linear(4, 256)
        self.fc_pi = nn.Linear(256, 2)
        self.fc_v = nn.Linear(256, 1)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        
    def pi(self, x, softmax_dim=0):
        x = F.relu(self.fc1(x))
        x = self.fc_pi(x)
        prob = F.softmax(x, dim=softmax_dim)
        return prob
    
    def v(self, x):
        x = F.relu(self.fc1(x))
        v = self.fc_v(x)
        return v
    
    def put_data(self, transition):
        self.data.append(transition)
        
    def make_batch(self):
        s_lst, a_lst, r_lst, s_prime_lst, done_lst = [], [], [], [], []
        for transition in self.data:
            s, a, r, s_prime, done = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r/100.0])
            s_prime_lst.append(s_prime)
            done_mask = 0.0 if done else 1.0
            done_lst.append([done_mask])
        
        s_batch, a_batch, r_batch, s_prime_batch, done_batch = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), torch.tensor(r_lst, dtype=torch.float), torch.tensor(s_prime_lst, dtype=torch.float), torch.tensor(done_lst, dtype=torch.float)
        self.data = []
        return s_batch, a_batch, r_batch, s_prime_batch, done_batch
    
    def train_net(self):
        s, a, r, s_prime, done = self.make_batch()
        td_target = r + gamma * self.v(s_prime) * done      # 'done' variable for making value 0 if done
        delta = td_target - self.v(s)
        
        pi = self.pi(s, softmax_dim=1)
        pi_a = pi.gather(1, a)
        # loss function = policy loss + delta loss 
        # we use detach function not to train fixed values for moving both parameters at the same time. 
        loss = -torch.log(pi_a) * delta.detach() + F.smooth_l1_loss(self.v(s), td_target.detach())
        
        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()
    

In [9]:
env = gym.make('CartPole-v1')

In [10]:
model = ActorCritic()

In [11]:
print_interval = 100
score = 0.0

In [12]:
for n_epi in range(10000):
    done = False
    s, _ = env.reset()
    while not done:
        for i in range(n_rollout):
            prob = model.pi(torch.from_numpy(s).float())
            m = Categorical(prob)
            a = m.sample().item()
            s_prime, r, done, truncated, info = env.step(a)
            model.put_data((s, a, r, s_prime, done))
            
            s = s_prime
            score += r
            
            if done:
                break
        
        model.train_net()
        
    if n_epi%print_interval == 0 and n_epi != 0:
        print('# of episode: {}, avg scroes: {:.1f}'
              .format(n_epi, score/print_interval))
        score = 0.0
        
env.close()

# of episode: 100, avg scroes: 17.9
# of episode: 200, avg scroes: 21.1
# of episode: 300, avg scroes: 26.8
# of episode: 400, avg scroes: 39.3
# of episode: 500, avg scroes: 63.2
# of episode: 600, avg scroes: 110.7
# of episode: 700, avg scroes: 247.7
# of episode: 800, avg scroes: 292.8
# of episode: 900, avg scroes: 312.0
# of episode: 1000, avg scroes: 349.1
# of episode: 1100, avg scroes: 343.5
# of episode: 1200, avg scroes: 272.1
# of episode: 1300, avg scroes: 277.6
# of episode: 1400, avg scroes: 538.2
# of episode: 1500, avg scroes: 1079.1
# of episode: 1600, avg scroes: 443.9
# of episode: 1700, avg scroes: 712.9
# of episode: 1800, avg scroes: 161.8
# of episode: 1900, avg scroes: 181.8
# of episode: 2000, avg scroes: 243.4
# of episode: 2100, avg scroes: 248.3
# of episode: 2200, avg scroes: 531.5
# of episode: 2300, avg scroes: 311.6
# of episode: 2400, avg scroes: 372.8
# of episode: 2500, avg scroes: 632.9
# of episode: 2600, avg scroes: 606.6
# of episode: 2700, avg s

KeyboardInterrupt: 