"""Sample from a trained ReformerLM checkpoint."""
import argparse
import os

import gin
import requests
import trax
import numpy as onp
import jax.numpy as np
from tokenizers import ByteLevelBPETokenizer
from trax.supervised import inputs  # noqa: F401 - ensures trax.supervised.inputs is importable

from configs import train_config, test_config
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Sample from a given model')
    parser.add_argument('--model_folder', type=str, default='model',
                        help='Folder for saving and loading the model')
    parser.add_argument('--tpu',
                        default=False, action='store_true')
    parser.add_argument('--prompt', type=str, default='',
                        help='Prompt for beginning the sampling, e.g. {"title": "Sampling"')
    parser.add_argument('--length', type=int, default=512,
                        help='Maximum length of sample')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='Sampling temperature')
    parser.add_argument('--top_k', type=int, default=0,
                        help='Keep only the top_k most likely tokens at each step (0 disables)')
    parser.add_argument('--exp2', default=False, action='store_true',
                        help='Use exp2 instead of exp during sampling')
    parser.add_argument('--meena_max', default=False, action='store_true',
                        help='Pick the probabilities with the highest max')
    parser.add_argument('--meena_combine', default=False, action='store_true',
                        help='Use all probabilities from all samples (8 on TPU) at once')
    args = parser.parse_args()

tokenizer = ByteLevelBPETokenizer(
    '256bytebpe-res-vocab.json', '256bytebpe-merges.txt')
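# The vocab and merges files above are assumed to live next to this script;
# they define the byte-level BPE vocabulary the checkpoint was trained with,
# so the decoded samples only make sense with these exact files.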
def fake_data(n_devices):
    """Yield dummy (inputs, targets, weights) triples for the Trainer."""
    data = onp.zeros((n_devices, 65536))
    while True:
        yield (data, data, data)
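# fake_data exists only to satisfy the Trainer's input pipeline below: the
# Trainer is used purely to restore checkpointed weights and never takes a
# training step, so an endless stream of zero arrays is enough. The 65536
# presumably mirrors the training sequence length (an assumption based on
# the shape alone).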
def create_fixed_training_schedule(lr=0.001):
    # Yes, it does look unnecessarily nested for passing a single float.
    def FixedTrainingSchedule(*args, **kwargs):
        def learning_rate(step):
            return {'learning_rate': np.asarray(lr, dtype=np.float32)}
        return learning_rate
    return FixedTrainingSchedule
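# A minimal sketch of how the factory above behaves (hypothetical call, not
# executed anywhere in this script):
#
#   schedule = create_fixed_training_schedule(lr=0.001)
#   lr_fn = schedule()   # the wrapper ignores its *args and **kwargs
#   lr_fn(100)           # -> {'learning_rate': 0.001} as a float32 array,
#                        #    regardless of the step number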
def sample_single(model_weights, infer_state, jit_model_infer, length=2048,
                  prompt=None, temperature=1.0, top_k=None, exp2=None,
                  boost_top=None, meena_max=False, meena_combine=False):
    """Sample from the ReformerLM model.

    top_k: e.g. 32 - keep only the top_k most likely tokens at each step.
    exp2: True|False - use exp2 instead of exp when turning logits into probabilities.
    boost_top: tuple of how many tokens to boost and by how much, e.g. (5, 1.1);
        doesn't boost the very top token. Use (5, 1.1, True) to include the top one.
    meena_max: True|False - pick the probabilities with the highest max.
    meena_combine: True|False - use all probabilities from all samples (8 on TPU) at once.
    """
    # Token id 0 is the equivalent of a "start" token.
    cur_inputs = np.zeros((trax.math.device_count(), 1, 1), dtype=np.int32)
    cur_state = infer_state
    rngs = trax.math.random.split(
        trax.math.random.get_prng(0), trax.math.device_count())
    all_samples = []
    if prompt is not None:
        # Replicate the encoded prompt once per device.
        prompt = np.asarray(
            [tokenizer.encode(prompt).ids] * trax.math.device_count())
    for iteration in range(length):
        logits, cur_state = jit_model_infer(
            cur_inputs,
            model_weights,
            cur_state,
            rngs)
        if prompt is not None and iteration < prompt.shape[1]:
            # Feed the prompt tokens first; sampling proper starts once the
            # prompt is exhausted.
            cur_samples = onp.array(prompt[:, iteration], dtype=int)
        else:
            logits = onp.array(logits)[:, 0, 0, :] / temperature
            # Custom settings.
            if exp2:
                probs = onp.exp2(logits)
            else:
                probs = onp.exp(logits)
            del logits
            if top_k:
                # Floor everything outside the top_k tokens to a negligible
                # probability.
                for prob in probs:
                    prob[prob.argsort()[:-top_k]] = 1e-12
                del prob
            if boost_top:
                # By default stop just short of the top token; a third tuple
                # element extends the boost to include it.
                boost_until = -1 if len(boost_top) < 3 else probs.shape[-1]
                for prob in probs:
                    sort = prob.argsort()
                    # Boost all except the top one (unless overridden above).
                    prob[sort[-boost_top[0]:boost_until]] *= boost_top[1]
                del prob
            # Make sure probabilities always add up to 1.
            probs = probs / probs.sum(axis=1, keepdims=True)
            if meena_max:
                # Sample once from the device whose distribution has the
                # highest peak, and reuse that token on every device.
                best_max = 0
                axis = 0
                for i, prob in enumerate(probs):
                    max_ = prob.max()
                    if max_ > best_max:
                        best_max = max_
                        axis = i
                sample = onp.random.choice(probs.shape[-1], p=probs[axis, :])
                cur_samples = [sample for _ in range(probs.shape[0])]
            elif meena_combine:
                sample = onp.random.choice(probs.shape[-1], p=probs[0, :])
                cur_samples = [sample for _ in range(probs.shape[0])]
            else:
                # Sample independently on every device.
                cur_samples = [onp.random.choice(probs.shape[-1], p=probs[i, :])
                               for i in range(probs.shape[0])]
            cur_samples = onp.array(cur_samples, dtype=int)
        all_samples.append(cur_samples)
        cur_inputs = np.array(cur_samples[:, None, None])
    all_samples = onp.stack(all_samples, -1)
    for ids in all_samples:
        print(tokenizer.decode(ids.tolist()).replace('\\n', '\n'))
        print('_____________________')
        if meena_combine or meena_max:
            # All samples are the same for those options.
            break
    return all_samples
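# Side note on the exp2 option above: onp.exp2(x) equals onp.exp(x * ln 2),
# so sampling with exp2 behaves like sampling with exp at temperature
# 1 / ln(2) (about 1.44), i.e. a somewhat flatter distribution.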
def sample(args):
    gin.parse_config(train_config)
    schedule = create_fixed_training_schedule()
    output_dir = os.path.expanduser(f'{args.model_folder}/')
    # The Trainer is only instantiated so that it restores the checkpointed
    # weights from output_dir; it is never stepped.
    trainer = trax.supervised.Trainer(
        model=trax.models.ReformerLM,
        loss_fn=trax.layers.CrossEntropyLoss,
        optimizer=trax.optimizers.Adam,
        lr_schedule=schedule,
        inputs=trax.supervised.inputs.Inputs(fake_data, fake_data),
        output_dir=output_dir,
        has_weights=True)
    gin.parse_config(test_config)
    model_infer = trax.models.ReformerLM(mode='predict')
    # Prepare a jitted copy of the model.
    jit_model_infer = trax.layers.base._accelerate(
        model_infer._forward_internal, trax.math.device_count())
    # Set up the initial state for sampling.
    infer_state = model_infer.new_weights_and_state(
        trax.supervised.trainer_lib.ShapeDtype((1, 1), dtype=np.int32))[1]
    infer_state = trainer._for_n_devices(infer_state)
    model_weights = trainer._opt_state[0][0]
    # Free the trainer and the unjitted model before sampling.
    del trainer
    del model_infer
    sample_single(model_weights=model_weights, infer_state=infer_state,
                  jit_model_infer=jit_model_infer, length=args.length,
                  prompt=args.prompt, temperature=args.temperature,
                  top_k=args.top_k, exp2=args.exp2,
                  meena_combine=args.meena_combine, meena_max=args.meena_max)
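# Note that sample() reaches into trax internals (_accelerate,
# _forward_internal, _for_n_devices, _opt_state). These underscore-prefixed
# attributes match the trax version this script was written against and may
# well differ in other releases.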
def main_sample(args):
    if args.tpu:
        if 'TPU_DRIVER_MODE' not in globals():
            url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(
                ':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
            requests.post(url)
            # Record that the driver has been requested; a plain assignment
            # would only create a function-local name.
            globals()['TPU_DRIVER_MODE'] = 1
        # The following is required to use the TPU driver as JAX's backend.
        from jax.config import config
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + \
            os.environ['COLAB_TPU_ADDR']
        print(config.FLAGS.jax_backend_target)
    sample(args)
if __name__ == '__main__':
    main_sample(args)
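# Example invocations, using only the flags defined above (the checkpoint in
# --model_folder and the two BPE vocab files must already exist):
#
#   python sample.py --model_folder model --length 512 --temperature 1.0
#   python sample.py --tpu --prompt '{"title": "Sampling"' --top_k 32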