In [1]:
import os

import gym
import numpy as np
import torch
import torch.nn as nn
import torchvision

# Seed every RNG the notebook draws from so that Restart & Run All reproduces the run.
SEED = 0
torch.manual_seed(SEED)   # network initialisation and torch.multinomial action sampling
np.random.seed(SEED)

env = gym.make('CartPole-v0')
if hasattr(env, 'seed'):
    # Classic gym API (matches the 4-tuple env.step used below); newer gym versions
    # seed through env.reset(seed=...) instead and have no env.seed.
    env.seed(SEED)
print(env.observation_space)
print(env.action_space)

n_features = env.observation_space.shape[0]  # length of the observation vector
n_actions = env.action_space.n               # number of discrete actions
n_hiddens = 10                               # hidden width used for both actor and critic
print('n_features:%d, n_hiddens:%d, n_actions:%d' % (n_features, n_hiddens, n_actions))

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)
Discrete(2)
n_features:4, n_hiddens:10, n_actions:2


In [5]:
class Actor(nn.Module):
    """Policy network: maps an observation vector to one probability per action.

    Layout: Linear(n_features -> n_hiddens) -> ReLU -> Linear(n_hiddens -> n_actions)
    -> Softmax over the last dimension.
    """

    def __init__(self, n_features, n_hiddens, n_actions):
        super().__init__()
        # Sizes are remembered; n_features is re-checked on every forward pass.
        self.n_features, self.n_hiddens, self.n_actions = n_features, n_hiddens, n_actions

        # These attribute names determine the state_dict keys, so they must not change.
        self.fc1 = nn.Linear(n_features, n_hiddens)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(n_hiddens, n_actions)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return action probabilities for ``x``; its last dimension must be n_features."""
        assert x.shape[-1] == self.n_features
        hidden = self.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return self.softmax(logits)
    
class Critic(nn.Module):
    """State-value network: maps an observation vector to a single value estimate.

    Layout: Linear(n_features -> n_hiddens) -> ReLU -> Linear(n_hiddens -> 1).
    """

    def __init__(self, n_features, n_hiddens, gamma):
        super().__init__()
        self.n_features = n_features
        self.n_hiddens = n_hiddens
        # Stored but never read inside this module; ACNetwork.learn uses its own gamma.
        self.gamma = gamma

        # Registration order (fc1 before fc2) fixes both the init RNG order and state_dict keys.
        self.fc1 = nn.Linear(n_features, n_hiddens)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(n_hiddens, 1)

    def forward(self, x):
        """Return V(x) with a trailing dimension of size 1; last dim of x must be n_features."""
        assert x.shape[-1] == self.n_features
        return self.fc2(self.relu(self.fc1(x)))
    
class ACNetwork:
    """One-step actor-critic agent built from an Actor (policy) and a Critic (state value).

    Both networks are updated online, one transition at a time, each with its own Adam
    optimizer: the critic minimises the squared TD error and the actor is pushed along
    -log pi(a|s) weighted by that (detached) TD error.
    """

    def __init__(self, n_features, n_actions, n_hiddens, lr_a=0.001, lr_c=0.001, gamma=0.95):
        """
        Args:
            n_features: length of the observation vector.
            n_actions: number of discrete actions.
            n_hiddens: hidden width of both the actor and the critic.
            lr_a: Adam learning rate for the actor (default is the original 0.001).
            lr_c: Adam learning rate for the critic (default is the original 0.001).
            gamma: discount factor used in the TD target (default is the original 0.95).
        """
        self.lr_a = lr_a
        self.lr_c = lr_c
        self.gamma = gamma

        self.actor = Actor(n_features, n_hiddens, n_actions)
        self.critic = Critic(n_features, n_hiddens, self.gamma)
        self.actor_op = torch.optim.Adam(self.actor.parameters(), lr=self.lr_a)
        self.critic_op = torch.optim.Adam(self.critic.parameters(), lr=self.lr_c)

    def choose_action(self, s):
        """Sample an action index from the current policy for observation ``s``."""
        # Sampling needs no gradients, so skip building an autograd graph.
        with torch.no_grad():
            act_prob = self.actor(torch.Tensor(s))
            action = torch.multinomial(act_prob, 1).numpy()[0]
        return action

    def learn(self, s, a, r, s_, done=False):
        """Apply one critic update and one actor update for the transition (s, a, r, s_).

        Args:
            s: observation before the action.
            a: index of the action taken.
            r: reward received for the transition.
            s_: observation after the action.
            done: True when ``s_`` is terminal. The TD target then does not bootstrap
                from V(s_). The default False keeps the previous always-bootstrap behaviour.
        """
        val = self.critic(torch.Tensor(s))
        val_ = self.critic(torch.Tensor(s_)).detach()
        # A terminal next state has no future return, so it contributes nothing.
        bootstrap = 0.0 if done else self.gamma * val_
        td_error = r + bootstrap - val
        critic_loss = torch.square(td_error)
        self.critic_op.zero_grad()
        critic_loss.backward()
        self.critic_op.step()

        act_prob = self.actor(torch.Tensor(s))
        # The TD error is detached: it only scales the policy gradient here.
        actor_loss = -torch.log(act_prob[a]) * td_error.detach()
        self.actor_op.zero_grad()
        actor_loss.backward()
        self.actor_op.step()

    def save(self, actor_path, critic_path):
        """Pickle the whole actor and critic modules to the given paths."""
        # NOTE(review): this stores full nn.Module objects, so the Actor/Critic classes
        # must be importable when the files are loaded again.
        torch.save(self.actor, actor_path)
        torch.save(self.critic, critic_path)

    def load(self, actor_path, critic_path):
        """Restore actor and critic weights from checkpoint files.

        The weights are copied into the existing modules instead of replacing them. That
        way ``actor_op`` / ``critic_op`` keep optimising the parameters that are actually
        in use. Each file may hold either a pickled module (as written by ``save``) or a
        plain state_dict. Layer sizes must match this agent's networks.
        """
        self._load_weights(self.actor, actor_path)
        self._load_weights(self.critic, critic_path)

    @staticmethod
    def _load_weights(module, path):
        """Load the checkpoint at ``path`` and copy its parameters into ``module`` in place."""
        # torch.load unpickles the file: only use it on trusted checkpoints.
        # NOTE(review): recent PyTorch defaults torch.load to weights_only=True, which
        # rejects pickled modules; state_dict checkpoints load either way.
        loaded = torch.load(path)
        state = loaded.state_dict() if isinstance(loaded, nn.Module) else loaded
        module.load_state_dict(state)

In [6]:
ac = ACNetwork(n_features, n_actions, n_hiddens)

MAX_STEPS = 200               # episode length cap seen in the logs; ending there is not a failure
FAIL_REWARD = -20             # reward substituted when an episode ends before MAX_STEPS
SOLVED_RUNNING_REWARD = 170   # stop training once the running reward exceeds this
eps = 3000

# Initialise explicitly instead of probing globals(): re-running this cell must start a
# fresh running average, not continue from the previous run's (possibly "solved") value.
running_reward = None
try:
    for ep in range(eps):
        s = env.reset()
        track_r = []
        i = 0
        done = False
        while not done:
            i += 1
            env.render()
            a = ac.choose_action(s)
            s_, r, done, _ = env.step(a)
            if done and i < MAX_STEPS:
                r = FAIL_REWARD
            track_r.append(r)
            ac.learn(s, a, r, s_)
            s = s_

        ep_rs_sum = np.sum(track_r)
        if running_reward is None:
            running_reward = ep_rs_sum
        else:
            running_reward = running_reward * 0.9 + ep_rs_sum * 0.1
        print('episode:', ep, ' reward:', int(running_reward), ' round:', i)
        if int(running_reward) > SOLVED_RUNNING_REWARD:
            break
finally:
    # Close the render window even if training is interrupted.
    env.close()

episode: 0  reward: -5  round: 16
episode: 1  reward: -4  round: 23
episode: 2  reward: -4  round: 15
episode: 3  reward: -4  round: 13
episode: 4  reward: -5  round: 12
episode: 5  reward: -5  round: 17
episode: 6  reward: -5  round: 17
episode: 7  reward: -5  round: 15
episode: 8  reward: -5  round: 10
episode: 9  reward: -5  round: 17
episode: 10  reward: -5  round: 14
episode: 11  reward: -5  round: 14
episode: 12  reward: -5  round: 21
episode: 13  reward: -5  round: 18
episode: 14  reward: -4  round: 24
episode: 15  reward: -4  round: 10
episode: 16  reward: -4  round: 20
episode: 17  reward: -4  round: 20
episode: 18  reward: -4  round: 13
episode: 19  reward: -5  round: 10
episode: 20  reward: -5  round: 10
episode: 21  reward: -5  round: 13
episode: 22  reward: -4  round: 29
episode: 23  reward: -4  round: 15
episode: 24  reward: -5  round: 12
episode: 25  reward: -5  round: 17
episode: 26  reward: -5  round: 14
episode: 27  reward: -5  round: 16
episode: 28  reward: -5  round

episode: 234  reward: -3  round: 17
episode: 235  reward: -2  round: 22
episode: 236  reward: -3  round: 13
episode: 237  reward: -3  round: 11
episode: 238  reward: -3  round: 22
episode: 239  reward: -3  round: 15
episode: 240  reward: -4  round: 11
episode: 241  reward: -3  round: 21
episode: 242  reward: -4  round: 15
episode: 243  reward: -1  round: 42
episode: 244  reward: 3  round: 70
episode: 245  reward: 2  round: 16
episode: 246  reward: 1  round: 14
episode: 247  reward: 0  round: 11
episode: 248  reward: 0  round: 11
episode: 249  reward: 0  round: 27
episode: 250  reward: 0  round: 13
episode: 251  reward: 0  round: 22
episode: 252  reward: 0  round: 23
episode: 253  reward: -1  round: 13
episode: 254  reward: 0  round: 24
episode: 255  reward: 0  round: 21
episode: 256  reward: 0  round: 27
episode: 257  reward: 0  round: 14
episode: 258  reward: 0  round: 23
episode: 259  reward: -1  round: 13
episode: 260  reward: -1  round: 17
episode: 261  reward: -1  round: 20
episod

episode: 465  reward: 35  round: 49
episode: 466  reward: 33  round: 35
episode: 467  reward: 36  round: 82
episode: 468  reward: 33  round: 31
episode: 469  reward: 33  round: 47
episode: 470  reward: 31  round: 40
episode: 471  reward: 29  round: 35
episode: 472  reward: 30  round: 60
episode: 473  reward: 30  round: 45
episode: 474  reward: 29  round: 47
episode: 475  reward: 31  round: 67
episode: 476  reward: 31  round: 59
episode: 477  reward: 36  round: 97
episode: 478  reward: 35  round: 50
episode: 479  reward: 39  round: 94
episode: 480  reward: 42  round: 89
episode: 481  reward: 41  round: 51
episode: 482  reward: 39  round: 42
episode: 483  reward: 36  round: 35
episode: 484  reward: 33  round: 32
episode: 485  reward: 33  round: 52
episode: 486  reward: 34  round: 59
episode: 487  reward: 33  round: 50
episode: 488  reward: 38  round: 106
episode: 489  reward: 36  round: 42
episode: 490  reward: 47  round: 162
episode: 491  reward: 45  round: 47
episode: 492  reward: 47  

episode: 690  reward: 104  round: 122
episode: 691  reward: 99  round: 82
episode: 692  reward: 103  round: 160
episode: 693  reward: 113  round: 200
episode: 694  reward: 122  round: 200
episode: 695  reward: 129  round: 200
episode: 696  reward: 126  round: 122
episode: 697  reward: 134  round: 200
episode: 698  reward: 131  round: 125
episode: 699  reward: 132  round: 165
episode: 700  reward: 135  round: 186
episode: 701  reward: 132  round: 122
episode: 702  reward: 130  round: 132
episode: 703  reward: 127  round: 124
episode: 704  reward: 134  round: 200
episode: 705  reward: 141  round: 200
episode: 706  reward: 143  round: 185
episode: 707  reward: 142  round: 154
episode: 708  reward: 148  round: 200
episode: 709  reward: 144  round: 137
episode: 710  reward: 143  round: 151
episode: 711  reward: 143  round: 163
episode: 712  reward: 145  round: 183
episode: 713  reward: 142  round: 141
episode: 714  reward: 143  round: 168
episode: 715  reward: 137  round: 105
episode: 716  

episode: 910  reward: 129  round: 165
episode: 911  reward: 136  round: 200
episode: 912  reward: 139  round: 180
episode: 913  reward: 145  round: 200
episode: 914  reward: 150  round: 200
episode: 915  reward: 150  round: 171
episode: 916  reward: 152  round: 186
episode: 917  reward: 148  round: 137
episode: 918  reward: 142  round: 106
episode: 919  reward: 138  round: 122
episode: 920  reward: 140  round: 182
episode: 921  reward: 139  round: 149
episode: 922  reward: 145  round: 200
episode: 923  reward: 142  round: 143
episode: 924  reward: 141  round: 151
episode: 925  reward: 147  round: 200
episode: 926  reward: 152  round: 200
episode: 927  reward: 157  round: 200
episode: 928  reward: 161  round: 200
episode: 929  reward: 165  round: 200
episode: 930  reward: 168  round: 200
episode: 931  reward: 172  round: 200


In [7]:
# torch.save does not create missing directories, so make sure model/ exists first.
os.makedirs('model', exist_ok=True)
ac.save('model/AC_Actor.pkl', 'model/AC_Critic.pkl')

In [8]:
ac = ACNetwork(n_features, n_actions, n_hiddens)
ac.load('model/AC_Actor.pkl', 'model/AC_Critic.pkl')

eps = 10
try:
    # Evaluation only: no gradients are needed for these forward passes.
    with torch.no_grad():
        for _ in range(eps):
            s = env.reset()
            step = 0
            done = False
            while not done:
                env.render()
                a = ac.choose_action(s)
                s, _, done, _ = env.step(a)
                step += 1
            print('Total steps:', step)
finally:
    # Close the render window even if evaluation is interrupted.
    env.close()

Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
Total steps: 200
