model_trainer.py
"""
Train a Model on NHL 94
"""
import os
import sys
import time
import argparse
from common import get_model_file_name, com_print, init_logger, create_output_dir
from models import init_model
from envs import init_env


def parse_cmdline(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default='ppo2')
    parser.add_argument('--nn', type=str, default='CnnPolicy')
    parser.add_argument('--nnsize', type=int, default=256)
    parser.add_argument('--env', type=str, default='NHL941on1-Genesis')
    parser.add_argument('--state', type=str, default=None)
    parser.add_argument('--num_players', type=int, default=1)
    parser.add_argument('--num_env', type=int, default=24)
    parser.add_argument('--num_timesteps', type=int, default=6000000)
    parser.add_argument('--output_basedir', type=str, default='~/OUTPUT')
    parser.add_argument('--load_p1_model', type=str, default='')
    parser.add_argument('--display_width', type=int, default=1440)
    parser.add_argument('--display_height', type=int, default=810)
    # Note: store_true flags combined with default=True are always True;
    # they cannot be switched off from the command line.
    parser.add_argument('--alg_verbose', default=True, action='store_true')
    parser.add_argument('--info_verbose', default=True, action='store_true')
    parser.add_argument('--play', default=False, action='store_true')
    parser.add_argument('--rf', type=str, default='')
    parser.add_argument('--deterministic', default=True, action='store_true')

    print(argv)
    args = parser.parse_args(argv)

    #if args.info_verbose is False:
    #    logger.set_level(logger.DISABLED)

    return args
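
# Example invocation (illustrative only; the flag values below are arbitrary,
# not project-recommended settings):
#
#   python model_trainer.py --alg=ppo2 --nn=CnnPolicy --env=NHL941on1-Genesis \
#       --num_env=8 --num_timesteps=1000000 --output_basedir=~/OUTPUT --play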


class ModelTrainer:
    def __init__(self, args, logger):
        self.args = args

        #if self.args.alg_verbose:
        #    logger.log('========= Init =============')

        self.output_fullpath = create_output_dir(args)
        model_savefile_name = get_model_file_name(args)
        self.model_savepath = os.path.join(self.output_fullpath, model_savefile_name)

        self.env = init_env(self.output_fullpath, args.num_env, args.state, args.num_players, args)
        self.p1_model = init_model(self.output_fullpath, args.load_p1_model, args.alg, args, self.env, logger)

        #if self.args.alg_verbose:
        com_print('OUTPUT PATH: %s' % self.output_fullpath)
        com_print('ENV: %s' % args.env)
        com_print('STATE: %s' % args.state)
        com_print('NN: %s' % args.nn)
        com_print('ALGO: %s' % args.alg)
        com_print('NUM TIMESTEPS: %s' % args.num_timesteps)
        com_print('NUM ENV: %s' % args.num_env)
        com_print('NUM PLAYERS: %s' % args.num_players)

        print(self.env.observation_space)
    def train(self):
        #if self.args.alg_verbose:
        com_print('========= Start Training ==========')

        self.p1_model.learn(total_timesteps=self.args.num_timesteps)

        #if self.args.alg_verbose:
        com_print('========= End Training ==========')

        self.p1_model.save(self.model_savepath)

        #if self.args.alg_verbose:
        com_print('Model saved to: %s' % self.model_savepath)

        return self.model_savepath
    def play(self, args, continuous=True):
        #if self.args.alg_verbose:
        com_print('========= Start Play Loop ==========')

        state = self.env.reset()

        while True:
            self.env.render(mode='human')

            # predict() returns (actions, hidden_states); only the actions
            # are passed to the vectorized env
            p1_actions = self.p1_model.predict(state, deterministic=args.deterministic)
            state, reward, done, info = self.env.step(p1_actions[0])

            time.sleep(0.01)
            #print(reward)

            # `done` is a per-env list, so test the first env's flag
            # (comparing the list itself against True never fires)
            if done[0]:
                if not continuous:
                    return info
                state = self.env.reset()

def main(argv):
    args = parse_cmdline(argv[1:])
    logger = init_logger(args)

    com_print("=========== Params ===========")
    com_print(args)

    trainer = ModelTrainer(args, logger)
    trainer.train()

    if args.play:
        trainer.play(args)


if __name__ == '__main__':
    main(sys.argv)
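
# Minimal programmatic sketch (hypothetical; assumes the same common/models/envs
# helpers are importable and the output basedir is writable):
#
#   args = parse_cmdline(['--alg', 'ppo2', '--num_env', '4', '--num_timesteps', '100000'])
#   logger = init_logger(args)
#   trainer = ModelTrainer(args, logger)
#   saved_path = trainer.train()  # returns the path the model was saved to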