Commit

update
morvanzhou committed Mar 15, 2020
1 parent 6ee65a8 commit 5ab27ab
Showing 3 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion continuous_A3C.py
@@ -111,7 +111,7 @@ def run(self):
  if __name__ == "__main__":
      gnet = Net(N_S, N_A)        # global network
      gnet.share_memory()         # share the global parameters in multiprocessing
-     opt = SharedAdam(gnet.parameters(), lr=0.0002)                        # global optimizer
+     opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.95, 0.999))     # global optimizer
      global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

      # parallel training
22 changes: 11 additions & 11 deletions discrete_A3C.py
@@ -15,9 +15,9 @@
  import os
  os.environ["OMP_NUM_THREADS"] = "1"

- UPDATE_GLOBAL_ITER = 10
+ UPDATE_GLOBAL_ITER = 5
  GAMMA = 0.9
- MAX_EP = 4000
+ MAX_EP = 3000

  env = gym.make('CartPole-v0')
  N_S = env.observation_space.shape[0]
@@ -29,17 +29,17 @@ def __init__(self, s_dim, a_dim):
          super(Net, self).__init__()
          self.s_dim = s_dim
          self.a_dim = a_dim
-         self.pi1 = nn.Linear(s_dim, 200)
-         self.pi2 = nn.Linear(200, a_dim)
-         self.v1 = nn.Linear(s_dim, 100)
-         self.v2 = nn.Linear(100, 1)
+         self.pi1 = nn.Linear(s_dim, 128)
+         self.pi2 = nn.Linear(128, a_dim)
+         self.v1 = nn.Linear(s_dim, 128)
+         self.v2 = nn.Linear(128, 1)
          set_init([self.pi1, self.pi2, self.v1, self.v2])
          self.distribution = torch.distributions.Categorical

      def forward(self, x):
-         pi1 = F.relu6(self.pi1(x))
+         pi1 = torch.tanh(self.pi1(x))
          logits = self.pi2(pi1)
-         v1 = F.relu6(self.v1(x))
+         v1 = torch.tanh(self.v1(x))
          values = self.v2(v1)
          return logits, values
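The actor head above produces logits for the Categorical distribution stored in self.distribution, and the run() hunk further down calls self.lnet.choose_action(...). That method is not part of this commit; purely as an illustration, a minimal sampling sketch (assuming F is torch.nn.functional and s is a (1, s_dim) float tensor) could look like:

    # Illustrative sketch, not in this commit: sample an action from the policy logits.
    def choose_action(self, s):
        self.eval()                              # acting only, no training-mode layers needed
        logits, _ = self.forward(s)              # value head is ignored while acting
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)              # torch.distributions.Categorical
        return m.sample().numpy()[0]             # int action index for env.step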

@@ -67,7 +67,7 @@ def loss_func(self, s, a, v_t):
  class Worker(mp.Process):
      def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
          super(Worker, self).__init__()
-         self.name = 'w%i' % name
+         self.name = 'w%02i' % name
          self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
          self.gnet, self.opt = gnet, opt
          self.lnet = Net(N_S, N_A)           # local network
@@ -80,7 +80,7 @@ def run(self):
              buffer_s, buffer_a, buffer_r = [], [], []
              ep_r = 0.
              while True:
-                 if self.name == 'w0':
+                 if self.name == 'w00':
                      self.env.render()
                  a = self.lnet.choose_action(v_wrap(s[None, :]))
                  s_, r, done, _ = self.env.step(a)
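The buffers filled above feed the targets that loss_func(s, a, v_t) consumes; that bookkeeping is collapsed in this view. A minimal sketch of the usual discounted n-step target computation, assuming buffer_r holds the rewards gathered since the last sync with the global net:

    # Illustrative sketch, not in this commit: discounted n-step value targets.
    v_s_ = 0. if done else self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
    buffer_v_target = []
    for r in buffer_r[::-1]:              # walk the rewards backwards
        v_s_ = r + GAMMA * v_s_           # bootstrap from the next value estimate
        buffer_v_target.append(v_s_)
    buffer_v_target.reverse()             # restore chronological order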
@@ -106,7 +106,7 @@ def run(self):
  if __name__ == "__main__":
      gnet = Net(N_S, N_A)        # global network
      gnet.share_memory()         # share the global parameters in multiprocessing
-     opt = SharedAdam(gnet.parameters(), lr=0.0001)                        # global optimizer
+     opt = SharedAdam(gnet.parameters(), lr=1e-4, betas=(0.92, 0.999))     # global optimizer
      global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

      # parallel training
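The block that follows '# parallel training' is collapsed here. As a rough sketch only, using the Worker(gnet, opt, global_ep, global_ep_r, res_queue, name) signature shown above, it typically spawns one worker per CPU core and drains the shared result queue until a worker signals completion:

    # Illustrative sketch, not in this commit: spawn workers and collect episode rewards.
    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i)
               for i in range(mp.cpu_count())]
    [w.start() for w in workers]
    res = []                              # moving-average episode rewards, for plotting
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
        else:
            break                         # a worker pushed None: training finished
    [w.join() for w in workers]
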
2 changes: 1 addition & 1 deletion shared_adam.py
@@ -6,7 +6,7 @@


  class SharedAdam(torch.optim.Adam):
-     def __init__(self, params, lr=1e-3, betas=(0.9, 0.9), eps=1e-8,
+     def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                   weight_decay=0):
          super(SharedAdam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
          # State initialization
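The body under '# State initialization' is collapsed in this view. The usual pattern, shown here only as a sketch, pre-creates Adam's moment buffers and moves them into shared memory so every worker process updates the same optimizer statistics:

        # Illustrative sketch, not necessarily the exact code in this file.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                # share the moment buffers across worker processes
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()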
