-
Notifications
You must be signed in to change notification settings - Fork 0
/
swarm.py
executable file
·105 lines (85 loc) · 3.49 KB
/
swarm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python
# ----------------------------------------------------------------------
# Numenta Platform for Intelligent Computing (NuPIC)
# Copyright (C) 2013, Numenta, Inc. Unless you have an agreement
# with Numenta, Inc., for a separate license for this software code, the
# following terms and conditions apply:
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses.
#
# http://numenta.org/licenses/
# ----------------------------------------------------------------------
"""
Groups together the code dealing with swarming.
(This is a component of the One Hot Gym Prediction Tutorial.)
"""
import os
import pprint
from nupic.swarming import permutations_runner
from swarm_description import SWARM_DESCRIPTION
INPUT_FILE = "ctrl_sim.csv"
# Help text printed by the __main__ entry point. INPUT_FILE is interpolated so
# the text always names the same file that __main__ hands to swarm(), instead
# of a hard-coded name that can drift out of date.
DESCRIPTION = (
  "This script runs a swarm on the input data (%s) and\n"
  "creates a model parameters file in the `model_params` directory containing\n"
  "the best model found by the swarm. Dumps a bunch of crud to stdout because\n"
  "that is just what swarming does at this point. You really don't need to\n"
  "pay any attention to it.\n"
) % INPUT_FILE
def modelParamsToString(modelParams):
  """Return *modelParams* pretty-printed as Python literal text.

  A two-space indent is used for nested containers, which keeps deeply
  nested parameter dicts readable when the text is written into a
  generated ``*_model_params.py`` module.
  """
  return pprint.pformat(modelParams, indent=2)
def writeModelParamsToFile(modelParams, name):
  """Write *modelParams* to ``./model_params/<name>_model_params.py``.

  The generated module contains a single assignment,
  ``MODEL_PARAMS = <pretty-printed modelParams>``.

  :param modelParams: the model parameters to serialise.
  :param name: base name for the output file; spaces and hyphens are
    replaced with underscores.
  :returns: absolute path of the file that was written.
  :raises OSError: if ``model_params`` cannot be created (e.g. a plain file
    of that name exists) or the output file cannot be written.
  """
  cleanName = name.replace(" ", "_").replace("-", "_")
  paramsName = "%s_model_params.py" % cleanName
  outDir = os.path.join(os.getcwd(), 'model_params')
  # Create-then-check instead of check-then-create: tolerates the directory
  # already existing (or appearing concurrently) without a race.
  try:
    os.mkdir(outDir)
  except OSError:
    if not os.path.isdir(outDir):
      raise
  outPath = os.path.join(outDir, paramsName)
  modelParamsString = modelParamsToString(modelParams)
  # Text mode: the payload is a str (binary mode raises TypeError on
  # Python 3). End with a newline so the generated source file is well-formed.
  with open(outPath, "w") as outFile:
    outFile.write("MODEL_PARAMS = \\\n%s\n" % modelParamsString)
  return outPath
def swarmForBestModelParams(swarmConfig, name, maxWorkers=4):
  """Run a swarm over *swarmConfig* and save the best model parameters.

  :param swarmConfig: swarm description, passed unchanged to
    ``permutations_runner.runWithConfig``.
  :param name: label for the swarm run (its ``outputLabel``). Also used,
    after cleaning, to name the generated
    ``model_params/<name>_model_params.py`` file.
  :param maxWorkers: forwarded as the ``maxWorkers`` swarm option.
  :returns: path of the model params file written by
    ``writeModelParamsToFile``.
  :raises OSError: if ``./swarm`` exists but is not a directory, or cannot
    be created.
  """
  # Swarm output and permutation working files both live under ./swarm.
  permWorkDir = os.path.abspath('swarm')
  try:
    os.mkdir(permWorkDir)
  except OSError:
    # An existing directory is fine. A plain file of that name (or a
    # permission error) should fail here, not obscurely inside the swarm.
    if not os.path.isdir(permWorkDir):
      raise
  modelParams = permutations_runner.runWithConfig(
    swarmConfig,
    # NOTE(review): "overwrite" presumably lets a rerun replace earlier
    # results in permWorkDir -- confirm against permutations_runner.
    {"maxWorkers": maxWorkers, "overwrite": True},
    outputLabel=name,
    outDir=permWorkDir,
    permWorkDir=permWorkDir,
    verbosity=0
  )
  return writeModelParamsToFile(modelParams, name)
def printSwarmSizeWarning(size):
  """Print a one-line banner describing how long a swarm of *size* takes.

  :param size: the swarm size setting. "small" and "medium" each get their
    own message; any other value gets the large-swarm message.
  """
  # Compare by value (==), not identity (is): a size string built at runtime,
  # e.g. parsed from JSON, is a distinct object from the literal and would
  # otherwise wrongly fall through to the LARGE message.
  if size == "small":
    print("= THIS IS A DEBUG SWARM. DON'T EXPECT YOUR MODEL RESULTS TO BE GOOD.")
  elif size == "medium":
    print("= Medium swarm. Sit back and relax, this could take awhile.")
  else:
    print("= LARGE SWARM! Might as well load up the Star Wars Trilogy.")
def swarm(filePath):
  """Swarm for the best model params and report where they were written.

  Only the extension-less base name of *filePath* is used here: it labels
  the run and names the generated params file. The swarm itself is
  configured by SWARM_DESCRIPTION.
  NOTE(review): presumably SWARM_DESCRIPTION's stream definition points at
  the same data as filePath -- confirm the two agree.
  """
  name, _extension = os.path.splitext(os.path.basename(filePath))
  banner = "================================================="
  print(banner)
  print("= Swarming on %s data..." % name)
  printSwarmSizeWarning(SWARM_DESCRIPTION["swarmSize"])
  print(banner)
  paramsPath = swarmForBestModelParams(SWARM_DESCRIPTION, name)
  print("\nWrote the following model param files:")
  print("\t%s" % paramsPath)
if __name__ == "__main__":
  # Explain up front what the (noisy) swarm run is going to do, then run it.
  print(DESCRIPTION)
  swarm(INPUT_FILE)