# torch_DQN_lander.py -- trains a DQN agent (DQN_01_base.Agent) on gym's LunarLander-v2.
import gym
import torch as T
from DQN_01_base import Agent
import numpy as np
if __name__=='__main__':
env=gym.make('LunarLander-v2')
agent=Agent(gamma=0.99,epsilon=1.0,batch_size=64, n_actions=4,
eps_end=0.01,input_dims=[8],lr=0.003)
scores,eps_history=[],[]
n_games=5000
for i in range(n_games):
score=0
done=False
observation=env.reset()
while not done:
action=agent.choose_action(observation)
observation_,reward,done,info=env.step(action)
score+=reward
agent.store_transition(observation,action,reward,observation_,done)
agent.learn()
observation=observation_
scores.append(score)
eps_history.append(agent.epsilon)
avg_score=np.mean([scores[-100:]])
print('episode ',i,'score %.2f' % score,'average score %.2f' % avg_score,
'epsilon %.2f' %agent.epsilon)
if i % 5000==0:
T.save(agent,'whole_agent.pt')
x=[i+1 for i in range(n_games)]