-
Notifications
You must be signed in to change notification settings - Fork 1
/
clgen.py
314 lines (262 loc) · 9.49 KB
/
clgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Created on Mon Mar 4 11:40:47 2019
# @author: deepfuzz
# (Kept as comments so that the description below is the module docstring.)
"""CLgen: a deep learning program generator.
The core operations of CLgen are:
1. Preprocess and encode a corpus of handwritten example programs.
2. Define and train a machine learning model on the corpus.
3. Sample the trained model to generate new programs.
This program automates the execution of all three stages of the pipeline.
The pipeline can be interrupted and resumed at any time. Results are cached
across runs. Please note that many of the steps in the pipeline are extremely
compute intensive and highly parallelized. If configured with CUDA support,
any NVIDIA GPUs will be used to improve performance where possible.
Made with \033[1;31m♥\033[0;0m by Chris Cummins <chrisc.101@gmail.com>.
https://chriscummins.cc/clgen
"""
import cProfile
import contextlib
import os
import pathlib
import shutil
import sys
import traceback
import typing
from absl import app
from absl import flags
from absl import logging
from clgen import errors
from clgen import samplers
from clgen.models import models
from clgen.models import pretrained
from clgen.proto import clgen_pb2
from clgen.proto import model_pb2
from labm8 import pbutil
from labm8 import prof
# Command-line flags for the CLgen driver. Every flag is registered on the
# global absl flag registry when this module is imported.
FLAGS = flags.FLAGS

# --- Input -----------------------------------------------------------------
flags.DEFINE_string(
    name='config',
    default=None,
    help='Path to a Instance proto file.',
)
flags.DEFINE_integer(
    name='min_samples',
    default=-1,
    help='The minimum number of samples to make.',
)

# --- Early exit / inspection ------------------------------------------------
flags.DEFINE_string(
    name='stop_after',
    default=None,
    help='Stop CLgen early. Valid options are: "corpus", or "train".',
)
flags.DEFINE_string(
    name='print_cache_path',
    default=None,
    help=('Print the directory of a cache and exit. Valid options are: "corpus", '
          '"model", or "sampler".'),
)
flags.DEFINE_bool(
    name='print_preprocessed',
    default=False,
    help='Print the pre-processed corpus to stdout and exit.',
)
flags.DEFINE_string(
    name='export_model',
    default=None,
    help=('Path to export a trained TensorFlow model to. This exports all of the '
          'files required for sampling to specified directory. The directory can '
          'then be used as the pretrained_model field of an Instance proto config.'),
)

# --- Diagnostics -------------------------------------------------------------
flags.DEFINE_bool(
    name='clgen_debug',
    default=False,
    help=('Enable a debugging mode of CLgen python runtime. When enabled, errors '
          'which may otherwise be caught lead to program crashes and stack traces.'),
)
flags.DEFINE_bool(
    name='clgen_profiling',
    default=False,
    help='Enable CLgen self profiling. Profiling results be logged.',
)
flags.DEFINE_bool(
    name='visualize',
    default=False,
    help='Enable Loss vs Epoch Visualization while training',
)

# --- Sampling ----------------------------------------------------------------
# Example: --sampling_technique "topK 3"
#          --sampling_technique "nucleus 3"
#          --sampling_technique "beam 3"
# NOTE(review): this flag and --visualize are not read anywhere in this
# module; presumably they are consumed elsewhere through FLAGS -- confirm.
flags.DEFINE_string(
    name='sampling_technique',
    default="default",
    help=' Valid options are: "topK K-value" or "nucleus P-value" or "beam beamwidthvalue" or "default".',
)
class Instance(object):
    """A CLgen instance: a model (trained or pretrained) paired with a sampler.

    Attributes are set by the constructor:
      working_dir: Absolute path exported as $CLGEN_CACHE while a session is
        active, or None when the config does not set one.
      model: A models.Model or a pretrained.PreTrainedModel.
      sampler: A samplers.Sampler.
    """

    def __init__(self, config: 'clgen_pb2.Instance'):
        """Instantiate an instance.

        Args:
          config: An Instance proto.

        Raises:
          UserError: If the instance proto contains invalid values, is missing
            a model or sampler fields.
        """
        try:
            pbutil.AssertFieldIsSet(config, 'model_specification')
            pbutil.AssertFieldIsSet(config, 'sampler')
        except pbutil.ProtoValueError as e:
            # Chain the cause so the original proto error stays visible.
            raise errors.UserError(e) from e

        self.working_dir: typing.Optional[pathlib.Path] = None
        if config.HasField('working_dir'):
            # Expand both $VARS and a leading ~ before making the path absolute.
            self.working_dir = pathlib.Path(
                os.path.expandvars(config.working_dir)).expanduser().absolute()

        # Enter a session so that the cache paths are set relative to any
        # requested working directory.
        with self.Session():
            if config.HasField('model'):
                self.model = models.Model(config.model)
            else:
                self.model = pretrained.PreTrainedModel(
                    pathlib.Path(config.pretrained_model))
            self.sampler = samplers.Sampler(config.sampler)

    @contextlib.contextmanager
    def Session(self) -> typing.Iterator['Instance']:
        """Scoped $CLGEN_CACHE value.

        While the context is active, $CLGEN_CACHE is set to the working
        directory, if there is one. On exit the previous state is restored,
        including when the body raises, and a previously unset variable is
        unset again rather than left as an empty string.

        Yields:
          This instance.
        """
        if self.working_dir is None:
            # Nothing to override; leave the environment untouched.
            yield self
            return

        previous = os.environ.get('CLGEN_CACHE')
        os.environ['CLGEN_CACHE'] = str(self.working_dir)
        try:
            yield self
        finally:
            if previous is None:
                os.environ.pop('CLGEN_CACHE', None)
            else:
                os.environ['CLGEN_CACHE'] = previous

    def Train(self, *args, **kwargs) -> None:
        """Train the model inside a session; arguments are forwarded as-is."""
        with self.Session():
            self.model.Train(*args, **kwargs)

    def Sample(self, *args, **kwargs) -> 'typing.List[model_pb2.Sample]':
        """Sample the model with this instance's sampler inside a session.

        Extra arguments are forwarded to the model's Sample() after the sampler.
        """
        with self.Session():
            return self.model.Sample(self.sampler, *args, **kwargs)

    def ToProto(self) -> 'clgen_pb2.Instance':
        """Get the proto config for the instance."""
        config = clgen_pb2.Instance()
        # Only record a working directory that was actually configured;
        # str(None) would otherwise be stored as the literal path "None".
        if self.working_dir is not None:
            config.working_dir = str(self.working_dir)
        # NOTE(review): this always fills the `model` field, even when the
        # instance was built from `pretrained_model` -- confirm that is intended.
        config.model.CopyFrom(self.model.config)
        config.sampler.CopyFrom(self.sampler.config)
        return config

    @classmethod
    def FromFile(cls, path: pathlib.Path) -> 'Instance':
        """Build an instance from an Instance proto file at `path`."""
        return cls(pbutil.FromFile(path, clgen_pb2.Instance()))
def Flush():
    """Flush the logging module first, then stdout, then stderr."""
    for sink in (logging, sys.stdout, sys.stderr):
        sink.flush()
def LogException(exception: Exception):
    """Log an error with a bug-report pointer, then exit with status 1.

    Args:
      exception: The exception to report; its message and type name are logged.
    """
    template = (
        "%s (%s)\n"
        "Please report bugs at <https://github.com/ChrisCummins/phd/issues>"
    )
    kind = type(exception).__name__
    logging.error(template, exception, kind)
    sys.exit(1)
def LogExceptionWithStackTrace(exception: Exception):
    """Log an error with a short stack trace, then exit with status 1.

    The trace is taken from the exception currently being handled
    (sys.exc_info()), not from the `exception` argument. The outermost frame
    is dropped, at most five frames are kept, and they are listed innermost
    first.

    Args:
      exception: The exception whose message and type name are logged.
    """
    max_rows = 5  # number of rows in traceback
    _, _, active_tb = sys.exc_info()
    # Fetch one extra frame because the first (outermost) one is discarded.
    frames = traceback.extract_tb(active_tb, limit=max_rows + 1)[1:]

    rows = []
    for position, (filename, lineno, fnname, _) in enumerate(
            reversed(frames), start=1):
        # TODO(cec): Report filename relative to PhD root.
        location = f'{filename}:{lineno}'
        rows.append(f' #{position} {location: <18} {fnname}()')

    logging.error(
        "%s (%s)\n"
        "stacktrace:\n"
        "%s\n"
        "Please report bugs at <https://github.com/ChrisCummins/phd/issues>",
        exception, type(exception).__name__, "\n".join(rows))
    sys.exit(1)
def RunWithErrorHandling(function_to_run: typing.Callable, *args,
                         **kwargs) -> typing.Any:
    """Run the given function as the main entrypoint to a program.

    If an exception is thrown, print an error message and exit with status 1.
    If FLAGS.clgen_debug is set, exceptions are not caught.

    When profiling is enabled (prof.is_enabled()), the call runs under
    cProfile and the statistics, sorted by 'tottime', are printed once the
    call finishes, whether it returned or raised.

    Args:
      function_to_run: The function to run.
      *args: Arguments to be passed to the function.
      **kwargs: Arguments to be passed to the function.

    Returns:
      The return value of the function when called with the given args.

    Raises:
      app.UsageError: Re-raised unchanged so that app.run() can handle it.
      SystemExit: With status 1 for any other handled failure.
    """
    if FLAGS.clgen_debug:
        # Enable verbose stack traces. See: https://pymotw.com/2/cgitb/
        # cgitb was removed from the standard library in Python 3.13; without
        # it, fall back to the interpreter's normal traceback.
        try:
            import cgitb
        except ImportError:
            pass
        else:
            cgitb.enable(format='text')
        return function_to_run(*args, **kwargs)

    try:
        if prof.is_enabled():
            # runcall() returns the function's own result, unlike
            # cProfile.runctx(), which always returns None.
            profiler = cProfile.Profile()
            try:
                return profiler.runcall(function_to_run, *args, **kwargs)
            finally:
                profiler.print_stats(sort='tottime')
        return function_to_run(*args, **kwargs)
    except app.UsageError:
        # UsageError is handled by the call to app.run(), not here.
        raise
    except errors.UserError as err:
        logging.error("%s (%s)", err, type(err).__name__)
        sys.exit(1)
    except KeyboardInterrupt:
        Flush()
        print("\nReceived keyboard interrupt, terminating", file=sys.stderr)
        sys.exit(1)
    except errors.File404 as e:
        Flush()
        LogException(e)
        # Defensive: guarantees the exit even if the logger returns.
        sys.exit(1)
    except Exception as e:
        Flush()
        LogExceptionWithStackTrace(e)
        # Defensive: guarantees the exit even if the logger returns.
        sys.exit(1)
def DoFlagsAction():
    """Do the action requested by the command line flags.

    Loads the Instance proto named by --config, then performs the first
    matching action: print a cache path, print the pre-processed corpus, stop
    after the corpus or training stage, export the trained model, or (the
    default) sample the model.

    Raises:
      app.UsageError: If --config is missing or is not a file, or if
        --print_cache_path or --stop_after has an unrecognized value.
    """
    config_flag = FLAGS.config
    if not config_flag:
        raise app.UsageError("Missing required argument: '--config'")

    config_path = pathlib.Path(config_flag)
    if not config_path.is_file():
        raise app.UsageError(f"File not found: '{config_path}'")
    config = pbutil.FromFile(config_path, clgen_pb2.Instance())
    # $PWD is pointed at the directory containing the config file.
    os.environ['PWD'] = str(config_path.parent)

    if FLAGS.clgen_profiling:
        prof.enable()

    instance = Instance(config)
    with instance.Session():
        # --print_cache_path: print one cache location and stop.
        cache_kind = FLAGS.print_cache_path
        if cache_kind:
            if cache_kind == 'corpus':
                print(instance.model.corpus.cache.path)
            elif cache_kind == 'model':
                print(instance.model.cache.path)
            elif cache_kind == 'sampler':
                print(instance.model.SamplerCache(instance.sampler))
            else:
                raise app.UsageError(
                    f"Invalid --print_cache_path argument: '{FLAGS.print_cache_path}'")
            return

        if FLAGS.print_preprocessed:
            print(instance.model.corpus.GetTextCorpus(shuffle=False))
            return

        # --stop_after takes precedence over --export_model.
        stage = FLAGS.stop_after
        if stage == 'corpus':
            instance.model.corpus.Create()
            return
        if stage == 'train':
            instance.model.Train()
            logging.info('Model: %s', instance.model.cache.path)
            return
        if stage:
            raise app.UsageError(
                f"Invalid --stop_after argument: '{FLAGS.stop_after}'")

        if FLAGS.export_model:
            instance.model.Train()
            export_dir = pathlib.Path(FLAGS.export_model)
            for path in instance.model.InferenceManifest():
                # Mirror each file's location relative to the model cache.
                relpath = pathlib.Path(
                    os.path.relpath(path, instance.model.cache.path))
                destination = export_dir / relpath
                (export_dir / relpath.parent).mkdir(parents=True, exist_ok=True)
                shutil.copyfile(path, destination)
                print(destination)
            return

        # The default action is to sample the model.
        instance.model.Sample(instance.sampler, FLAGS.min_samples)
def main(argv):
    """Main entry point.

    Args:
      argv: Command line left over after flag parsing; argv[0] is the program
        name, so anything after it is unrecognized.

    Raises:
      app.UsageError: If any positional arguments remain.
    """
    leftover = argv[1:]
    if leftover:
        raise app.UsageError(
            "Unrecognized command line options: '{}'".format(' '.join(leftover)))

    RunWithErrorHandling(DoFlagsAction)


if __name__ == '__main__':
    app.run(main)