Skip to content

Commit

Permalink
logger saver
Browse files Browse the repository at this point in the history
  • Loading branch information
CUN-bjy committed Nov 22, 2020
1 parent b9003a1 commit 519926a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

from agent.ddpg import ddpgAgent

NUM_EPISODES_ = 5000
NUM_EPISODES_ = 20000

def model_train(pretrained_):
# Create Environments
Expand Down Expand Up @@ -74,6 +74,7 @@ def model_train(pretrained_):
try:
act_range = env.action_space.high
rewards = []; critic_losses = []
max_reward = 0
for epi in range(NUM_EPISODES_):
print("=========EPISODE # %d =========="%epi)
obs = env.reset()
Expand Down Expand Up @@ -105,13 +106,14 @@ def model_train(pretrained_):
print("Episode#%d, steps:%d, rewards:%f"%(epi,t,epi_reward))
agent.replay(1)

# save weights at every 50 iters
if epi%50 == 0:
# save weights at the new records performance
if epi_reward > max_reward:
max_reward = epi_reward
dir_path = "%s/weights"%os.getcwd()
if not os.path.isdir(dir_path):
os.mkdir(dir_path)
path = dir_path+'/'+'gym_ddpg_'
agent.save_weights(path + 'ep%d'%epi)
agent.save_weights(path + 'ep%d_%f'%(epi,max_reward))


# save reward logs
Expand All @@ -132,12 +134,12 @@ def model_train(pretrained_):
if not os.path.isdir(dir_path):
os.mkdir(dir_path)
path = dir_path+'/'+'gym_ddpg_'
agent.save_weights(path +'temp')
agent.save_weights(path +'lastest')
env.close()

# log saver
import pickle
pickle.dump(open(path+'%s.log'%time.Time.now(),'wb'))
pickle.dump(logger,open(path+'%s.log'%time.time(),'wb'))


argparser = argparse.ArgumentParser(
Expand Down

0 comments on commit 519926a

Please sign in to comment.