-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
100 lines (69 loc) · 2.82 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from __future__ import print_function, division
import matplotlib.pyplot as plt
import numpy as np
import gym
from atari_util import PreprocessAtari
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from agent import Agent
from env_pool import EnvPool
from visualization import ImageGenerator
from train import *
def make_env():
    """Build the preprocessed KungFuMaster environment.

    Returns a gym env wrapped by ``PreprocessAtari``: frames are cropped,
    converted to grayscale, and resized to 42x42, with a single stacked frame.
    """
    raw_env = gym.make("KungFuMasterDeterministic-v0")
    # Crop removes the top score bar and left border before resizing.
    wrapped = PreprocessAtari(
        raw_env,
        height=42,
        width=42,
        crop=lambda img: img[60:-30, 15:],
        color=False,
        n_frames=1,
    )
    return wrapped
def flatten_obs(unflattened_obs, n_parallel_games):
    """Stack per-game observations into a single array.

    Parameters
    ----------
    unflattened_obs : sequence of np.ndarray
        One observation array per parallel game.
    n_parallel_games : int
        How many leading entries of ``unflattened_obs`` to stack.

    Returns
    -------
    np.ndarray
        The first ``n_parallel_games`` observations stacked along axis 0.
    """
    # Preserve the original n=1 behavior: the first observation is
    # returned as-is (the loop body never ran).
    if n_parallel_games == 1:
        return unflattened_obs[0]
    # One vstack call instead of vstack-in-a-loop: avoids the accidental
    # O(n^2) re-copying of the growing array on every iteration.
    return np.vstack(unflattened_obs[:n_parallel_games])
def train(agent, env_pool, niters, n_parallel_games, gamma, save_path=None, curiosity=False):
    """Train *agent* by interacting with *env_pool* for *niters* iterations.

    Parameters
    ----------
    agent : Agent
        Agent with ``decisionUnit``, ``memoryUnit`` and ``curiosityUnit``;
        moved to GPU when CUDA is available.
    env_pool : EnvPool
        Pool of parallel environments used to collect 10-step rollouts.
    niters : int
        Number of training iterations.
    n_parallel_games : int
        Number of parallel games (kept for interface compatibility).
    gamma : float
        Discount factor, forwarded to ``train_on_rollout``.
    save_path : str, optional
        Checkpoint directory (checkpoint saving is currently disabled below).
    curiosity : bool
        When True, the curiosity optimizer is stepped alongside the
        decision/memory optimizer.

    Returns
    -------
    list
        Mean evaluation reward recorded every 100 iterations.
    """
    # Fix: `cuda` was not defined in this function's scope; test the
    # runtime capability explicitly instead.
    if torch.cuda.is_available():
        agent.cuda()
    opt_decision = torch.optim.Adam(
        list(agent.decisionUnit.parameters()) + list(agent.memoryUnit.parameters()),
        lr=1e-5,
    )
    opt_curiosity = torch.optim.Adam(agent.curiosityUnit.get_all_params(), lr=1e-3)
    opts = [opt_decision, opt_curiosity] if curiosity else [opt_decision]
    rewards_history = []
    ImageGen = ImageGenerator('./graph.png')
    # Fix: evaluation previously read a module-global `env`; build a
    # dedicated evaluation environment so the function is self-contained.
    eval_env = make_env()
    for i in range(niters):
        # Fix: use the `env_pool` argument; the old body read a global
        # `pool` and the parameter was silently ignored.
        memory = list(env_pool.prev_memory_states)
        rollout_obs, rollout_actions, rollout_rewards, rollout_mask = env_pool.interact(10)
        loss = train_on_rollout(agent, opts, rollout_obs, rollout_actions,
                                rollout_rewards, rollout_mask, memory, gamma, curiosity)
        if i % 100 == 0:
            # NOTE(review): the old unused `test_var` (which forced .cuda()
            # even on CPU-only hosts) has been removed.
            rewards_history.append(np.mean(evaluate(agent, eval_env, n_games=1)))
            ImageGen(rewards_history)
            print(loss)
        # if i % 200 == 0 and save_path:
        #     agent.save_agent(save_path)
    return rewards_history
if __name__ == "__main__":
    # Training hyperparameters.
    n_parallel_games = 5
    gamma = 0.99
    # Probe a single env for the observation/action spec.
    env = make_env()
    obs_shape = env.observation_space.shape
    n_actions = env.action_space.n
    print("Observation shape:", obs_shape)
    print("Num actions:", n_actions)
    print("Action names:", env.env.env.get_action_meanings())
    # `Agent` is already imported at module level; the duplicate local
    # `from agent import Agent` was removed along with the no-op
    # self-assignments of n_parallel_games / gamma.
    agent = Agent(obs_shape, n_actions, n_parallel_games)
    # Fix: `cuda` was an undefined name; check availability explicitly.
    if torch.cuda.is_available():
        agent.cuda()
    chkpt_dir = "./chkpt"
    pool = EnvPool(agent, make_env, n_parallel_games)
    train(agent, pool, 50000, n_parallel_games, gamma, save_path=chkpt_dir, curiosity=True)