This repository has been archived by the owner on Feb 25, 2022. It is now read-only.
/
gpt2.py
500 lines (383 loc) · 21.7 KB
/
gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
"""GPT-like model in Mesh-Tensorflow"""
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
import math
import mesh_tensorflow.transformer as mtf_transformer
from models.utils import parse_inputs
# --------------------------------------------------------------------------------
# LAYERS:

# Unique marker object used as the default value of optional arguments (e.g. the
# `axis` parameter of `scale_norm` / `layer_norm`), so that "argument omitted" can be
# detected with an identity check and told apart from any real value, including None.
sentinel = object()
def exists(x):
    """Return True when `x` is anything other than None.

    Falsy values such as 0, "" or [] still count as existing.
    """
    return not (x is None)
def identity(x, *args, **kwargs):
    """No-op layer: hand back `x` unchanged and ignore every extra argument.

    Extra positional/keyword arguments are accepted so this can stand in for
    layer functions (norms, residual gates) without changing call sites.
    """
    del args, kwargs  # accepted only for signature compatibility
    return x
def is_incremental_inference(context):
    """Report whether a decoding `context` is present and in "incremental" mode.

    Returns False when `context` is None; otherwise compares `context.mode`
    against the string "incremental".
    """
    if not exists(context):
        return False
    return context.mode == "incremental"
def norm(x, axis, epsilon=1e-8):
    """Standardize `x` along `axis` to zero mean and unit variance (no affine part).

    Args:
        x: mtf Tensor to normalize.
        axis: mtf Dimension to reduce over.
        epsilon: added to the variance before the reciprocal square root, for stability.

    Returns:
        (x - mean(x)) * rsqrt(mean((x - mean(x))^2) + epsilon), reduced over `axis`.
    """
    centered = x - mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u")
    variance = mtf.reduce_mean(mtf.square(centered), reduced_dim=axis, name="norm_reduce_mean_s")
    return centered * mtf.rsqrt(variance + epsilon)
def rezero(x, scope, dtype):
    """ReZero residual gate: multiply `x` by a learned scalar initialized to 0.

    Because the scalar "g" (created under `scope`) starts at zero, the gated
    branch contributes nothing at initialization.

    Args:
        x: mtf Tensor to gate.
        scope: variable scope name that holds the gate variable.
        dtype: forwarded as `dtype` to `mtf.get_variable`; in this file `block`
            passes the model's variable dtype.
    """
    with tf.variable_scope(scope):
        gate = mtf.get_variable(
            x.mesh,
            "g",
            [],
            initializer=tf.constant_initializer(0),
            dtype=dtype,
        )
        return x * gate
def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
    """ScaleNorm: standardize `x` along `axis`, then multiply by one learned scalar.

    Args:
        x: mtf Tensor to normalize.
        scope: variable scope holding the scalar gain "g" (initialized to 1).
        variable_dtype: mtf VariableDType supplying master/slice/activation dtypes.
        axis: Dimension to normalize over; defaults to the last dimension of `x`.
        epsilon: stability term forwarded to `norm`.
        params: unused; accepted so every prenorm function shares one call signature.
    """
    reduce_axis = x.shape[-1] if axis is sentinel else axis
    with tf.variable_scope(scope):
        gain = mtf.get_variable(
            x.mesh,
            "g",
            [],
            initializer=tf.constant_initializer(1),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype,
        )
        normalized = norm(x, reduce_axis, epsilon)
        return normalized * gain
def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None):
    """Normalize to mean = 0, std = 1, then apply a per-feature affine transform.

    A gain "g" (initialized to 1) and bias "b" (initialized to 0), both vectors
    over the last dimension of `x`, are created under `scope`.

    Args:
        x: mtf Tensor to normalize.
        scope: variable scope for "g" and "b".
        variable_dtype: mtf VariableDType supplying master/slice/activation dtypes.
        axis: Dimension that is standardized; defaults to the last dimension of `x`.
        epsilon: stability term forwarded to `norm`.
        params: unused; accepted so every prenorm function shares one call signature.
    """
    if axis is sentinel:
        axis = x.shape[-1]
    features = x.shape[-1]
    dtype_kwargs = dict(
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype,
    )
    with tf.variable_scope(scope):
        gain = mtf.get_variable(x.mesh, "g", [features],
                                initializer=tf.constant_initializer(1), **dtype_kwargs)
        shift = mtf.get_variable(x.mesh, "b", [features],
                                 initializer=tf.constant_initializer(0), **dtype_kwargs)
        normalized = norm(x, axis, epsilon)
        return normalized * gain + shift
def linear_attention(q, k, v):
    """Non-causal linear attention with separate softmaxes on queries and keys.

    q is softmaxed over its feature axis and k over the sequence axis. k and v are
    first contracted into a per-head [features_in, features_out] context, which q
    then reads from, so no sequence x sequence matrix is built.

    Assumes v is laid out as [batch, sequence, heads, features] and that q and k
    carry a dimension named "features_per_head" — TODO confirm with callers.
    """
    v_shape = v.shape
    batch_dim, seq_dim, head_dim, dim_out = v_shape[0], v_shape[1], v_shape[2], v_shape[3]

    # Give q/k their own feature-dimension name so it stays distinct from v's
    # feature dimension inside the einsums below.
    q_in = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k_in = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")
    dim_in = k_in.shape[-1]

    q_weights = mtf.softmax(q_in, dim_in)
    k_weights = mtf.softmax(k_in, seq_dim)

    kv_context = mtf.einsum([k_weights, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
    return mtf.einsum([q_weights, kv_context], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
def causal_linear_attention(q, k, v, epsilon=1e-6):
    """Causal (autoregressive) linear attention.

    q is softmaxed over its feature axis and k is exponentiated. Each position reads
    from the running sum (mtf.cumsum along the sequence) of the k/v outer products,
    divided feature-wise by the running sum of exp(k) plus `epsilon`. Only the
    prefix contributes, and no sequence x sequence matrix is built.

    Assumes v is laid out as [batch, sequence, heads, features] and that q and k
    carry a dimension named "features_per_head" — TODO confirm with callers.

    NOTE(review): exp(k) is not max-shifted, so large key logits could overflow,
    especially in reduced precision.
    """
    v_shape = v.shape
    batch_dim, seq_dim, head_dim, dim_out = v_shape[0], v_shape[1], v_shape[2], v_shape[3]

    # Keep the q/k feature dimension distinct from v's inside the einsums.
    q_in = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
    k_in = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")
    dim_in = k_in.shape[-1]

    q_weights = mtf.softmax(q_in, dim_in)
    k_weights = mtf.exp(k_in)
    running_k = mtf.cumsum(k_weights, seq_dim)

    kv = mtf.einsum([k_weights, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
    running_kv = mtf.cumsum(kv, seq_dim)
    running_kv = running_kv / (running_k + epsilon)

    return mtf.einsum([q_weights, running_kv], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False):
    """Dense projection (with bias) of the last dimension of `x` onto dimension `nf`.

    The kernel init stddev starts at `w_init_stdev` and is optionally shrunk:
      * by 1/sqrt(params["n_layer"]) when params["scale_by_depth"] and `scale` are
        both truthy (meant for the last projection before a residual add);
      * by 1/sqrt(fan_in), where fan_in is the size of x's last dimension, when
        params["scale_by_in"] is truthy.

    Args:
        x: mtf Tensor whose last dimension is projected away.
        scope: name given to the dense layer.
        nf: output Dimension ("number of features").
        w_init_stdev: base standard deviation of the normal kernel initializer.
        variable_dtype: mtf VariableDType for the created variables.
        params: config dict; "scale_by_depth", "scale_by_in" and (conditionally)
            "n_layer" are read, so it must not be None in practice.
        scale: opt-in flag for depth scaling.

    Returns:
        mtf Tensor with the last dimension of `x` replaced by `nf`.
    """
    stddev = w_init_stdev
    if params["scale_by_depth"] and scale:
        stddev *= 1. / math.sqrt(params["n_layer"])
    if params["scale_by_in"]:
        # Dimension is a namedtuple of (name, size)
        stddev *= 1. / math.sqrt(x.shape[-1].size)

    input_dim = x.shape[-1]
    # `scope` is handed to mtf.layers.dense as its `name` (presumably it opens its own
    # variable scope from it) rather than opened here; "conv1d_main" is part of the
    # resulting variable names, so it must stay.
    with tf.variable_scope("conv1d_main"):
        return mtf.layers.dense(
            x,
            new_dims=[nf],
            reduced_dims=[input_dim],
            name=scope,
            use_bias=True,
            kernel_initializer=tf.random_normal_initializer(stddev=stddev),
            variable_dtype=variable_dtype,
        )
def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """Prepend learned memory slots to keys and values (memory k/v from the all-attention paper).

    Two variables, "mem_k" and "mem_v", of shape [num_mem_kv, heads, emb] are
    created (normal init, stddev 1/sqrt(emb)), broadcast over `dim_batch`, relabelled
    onto the "sequence" dimension and concatenated in front of `k` and `v`.

    Assumes k and v have dimensions (batch, sequence, heads, emb) with emb last —
    TODO confirm with callers.

    Returns:
        (k, v) with `num_mem_kv` extra leading positions along "sequence".
    """
    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    def _memory_bank(var_name):
        # One learned [mem, heads, emb] table with its own initializer.
        return mtf.get_variable(
            mesh,
            var_name,
            mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
            initializer=tf.random_normal_initializer(stddev=mem_std),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype,
        )

    banks = [_memory_bank("mem_k"), _memory_bank("mem_v")]
    full_shape = [dim_batch, dim_mem_kv, dim_heads, emb_dim]
    banks = [mtf.broadcast(bank, full_shape) for bank in banks]
    mem_k, mem_v = [mtf.rename_dimension(bank, "mem_kv_sequence", "sequence") for bank in banks]

    extended_k = mtf.concat([mem_k, k], "sequence")
    extended_v = mtf.concat([mem_v, v], "sequence")
    return extended_k, extended_v
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None):
    """Multi-head attention sub-layer.

    Projects `x` to per-head queries/keys/values, runs the attention variant named by
    `attention_type`, projects back to the input shape, adds a learned output bias
    "o_b" and, in train mode, applies residual dropout.

    Args:
        x: input tensor, assumed [batch, seq, n_embd] with the embedding last.
        scope: variable scope for all attention parameters.
        n_state: embedding Dimension; its size must be divisible by params["n_head"].
        attention_type: "local", "global" or "linear".
        params: config dict (reads "n_head", "n_embd", "causal", "mode",
            "attn_dropout", "res_dropout", and optionally "num_mem_kv",
            "local_attention_radius").
        bias: attention-mask bias for "global" attention, or None for no mask.
        dim_seq: sequence Dimension, used to index the current position during
            incremental decoding.
        memory_length_dim: Dimension that keys/values are relabelled to for "global".
        variable_dtype: mtf VariableDType for all created variables.
        context: optional decoding context; when its mode is "incremental", the new
            key/value is written into the cached states at `context.position - 1`.

    Returns:
        Tensor with the same shape as `x`.

    Raises:
        NotImplementedError: for an unknown `attention_type`.
    """
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # Overwrite only the current position in the cached keys/values.
            one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":
                # TODO: pass in fake context
                # No mask unless `bias` is given. Previously `broadcasted_bias` was left
                # unassigned when bias was None, which raised UnboundLocalError below;
                # None is the "no bias" default of mtf's attention().
                broadcasted_bias = None
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
def mlp(x, scope, n_state, *, variable_dtype, params):
    """Position-wise feed-forward sub-layer: linear -> GELU -> linear (+ dropout).

    Expands the last dimension of `x` to `n_state`, applies GELU, then projects back
    to the original last dimension. The output projection passes `scale=True` so
    `linear` may depth-scale its init. Residual dropout is applied only when
    params["mode"] == "train" and params["res_dropout"] > 0.
    """
    with tf.variable_scope(scope):
        out_dim = x.shape[-1]
        hidden = mtf.gelu(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params))
        out = linear(hidden, "c_proj", out_dim, variable_dtype=variable_dtype, params=params, scale=True)
        apply_dropout = params["mode"] == "train" and params["res_dropout"] > 0
        if apply_dropout:
            out = mtf.dropout(out, rate=params["res_dropout"], name="mlp_dropout")
        return out
def mlp_glu(x, scope, n_state, *, variable_dtype, params):
    """Gated feed-forward sub-layer (GELU-gated linear unit).

    Projects the last dimension of `x` to `n_state` and splits the result into two
    equal halves along that dimension. The first half is gated by GELU of the second,
    then projected back to the original last dimension. `n_state` must have an even
    size; `block` passes twice the usual MLP width. Residual dropout is applied only
    when params["mode"] == "train" and params["res_dropout"] > 0.
    """
    with tf.variable_scope(scope):
        nx = x.shape[-1]
        # `linear` takes `variable_dtype` as a required keyword-only argument; it was
        # previously omitted here, so every call raised TypeError.
        h = linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)

        h, gate = mtf.split(h, h.shape[-1], 2)
        h *= mtf.gelu(gate)

        h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
        if params["mode"] == "train" and params["res_dropout"] > 0:
            h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
        return h2
def block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, variable_dtype, context=None):
    """Build one transformer layer and return it as a closure ``fn(x) -> (x, aux_loss)``.

    The layer is pre-norm (layer_norm, scale_norm, or no norm when params["rezero"]
    is set). It has an optional macaron half-MLP, an attention sub-layer (skipped
    when params["attention_types"][layer_num] == "none"), then a dense or MoE MLP.
    Residual branches go through `rezero` when params["rezero"] is set.

    Args:
        params: model config dict.
        scope: variable scope for the layer ("" when parameters are shared).
        layer_num: index of this layer; selects its attention type and MoE usage.
        bias, sequence_dim, memory_length_dim, context: forwarded to `attn`.
        variable_dtype: mtf VariableDType for all created variables.

    Returns:
        A callable taking the residual stream `x` and returning the updated stream
        plus an auxiliary loss (the MoE layer's aux loss, or a scalar zero).
    """
    use_mlp_glu = params["mlp_glu"] == True
    use_scale_norm = params["scalenorm"] == True
    use_moe = exists(params["moe_layers"]) and (layer_num in params["moe_layers"])
    use_rezero = params["rezero"] == True
    macaron_attention = params["macaron"] == True

    def _dense_mlp(inp, mlp_scope, nx):
        # Dense feed-forward; the GLU variant doubles the hidden width because
        # mlp_glu splits its hidden layer into value and gate halves.
        mlp_fn = mlp_glu if use_mlp_glu else mlp
        intermediate_size = nx.size * 4 * (1 if not use_mlp_glu else 2)
        # Define intermediate layer of mlp - to split
        dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size)
        return mlp_fn(inp, mlp_scope, dim_intermediate_expanded, variable_dtype=variable_dtype, params=params)

    def fn(x):
        with tf.variable_scope(scope):
            nx = x.shape[-1]  # Grab last dimension from input

            if use_rezero:
                prenorm = identity
            elif use_scale_norm:
                prenorm = scale_norm
            else:
                prenorm = layer_norm

            pre_residual_fn = rezero if use_rezero else identity

            attention_type = params["attention_types"][layer_num]

            if macaron_attention:
                mult = 0.5
                m = _dense_mlp(x, "mlp_macaron", nx)
                x = x + (m * mult)
            else:
                mult = 1

            if attention_type != "none":
                res_x = prenorm(x, "norm_1", variable_dtype=variable_dtype, params=params)
                a = attn(res_x, "attn", nx, attention_type=attention_type,
                         params=params, bias=bias, dim_seq=sequence_dim, memory_length_dim=memory_length_dim,
                         variable_dtype=variable_dtype, context=context)
            else:
                # No attention in this layer, so the attention branch must contribute
                # nothing. Previously `a = x`, which turned the residual update into
                # x + x. Zeros still go through pre_residual_fn, so the rezero gate
                # variable is still created and checkpoints keep the same variables.
                a = mtf.zeros_like(x)

            x = x + pre_residual_fn(a, "norm_rezero_1", dtype=variable_dtype)

            res_x = prenorm(x, "norm_2", variable_dtype=variable_dtype, params=params)

            if use_moe:
                moe_params = mtf.transformer.moe.HParams()
                mtf.transformer.moe.set_default_moe_hparams(moe_params)

                # Override defaults
                for k, v in params["moe_params"].items():
                    moe_params.add_hparam(k, v)

                moe_train = params["mode"] == "train"

                m, aux_loss = mtf.transformer.moe.transformer_moe_layer_v1(res_x, x.shape[-1], moe_params,
                                                                           train=moe_train,
                                                                           mesh_shape=params["mesh_shape"],
                                                                           layout=params["layout"],
                                                                           variable_dtype=variable_dtype)
            else:
                m = _dense_mlp(res_x, "mlp", nx)
                aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype)

            x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype)
            return x, aux_loss

    return fn
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
    """Build an axial positional-embedding table of shape [axial_dim, embd_dim].

    params["axial_pos_emb"] = (n1, n2) factors n1 * n2 positions into a grid. Two
    learned tables, [n1, embd] ("axial_wpe_1") and [n2, embd] ("axial_wpe_2"), are
    broadcast to [n1, n2, embd], averaged, and flattened to [n1 * n2, embd].
    """
    rows, cols = params["axial_pos_emb"]
    flat_dim = mtf.Dimension("axial_dim", rows * cols)
    grid_dims = [mtf.Dimension(f"axial_dim_{i}", size) for i, size in enumerate((rows, cols))]

    def _axis_table(var_name, axis_dim):
        # One learned [axis, embd] table with its own normal initializer.
        return mtf.get_variable(
            mesh,
            var_name,
            mtf.Shape([axis_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=variable_dtype.master_dtype,
            slice_dtype=variable_dtype.slice_dtype,
            activation_dtype=variable_dtype.activation_dtype,
        )

    row_table = _axis_table("axial_wpe_1", grid_dims[0])
    col_table = _axis_table("axial_wpe_2", grid_dims[1])

    grid_shape = [grid_dims[0], grid_dims[1], embd_dim]
    row_table, col_table = [mtf.broadcast(t, grid_shape) for t in (row_table, col_table)]

    averaged = (row_table + col_table) / 2
    return mtf.reshape(averaged, [flat_dim, embd_dim])
# --------------------------------------------------------------------------------
# MODEL:
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT style model implemented in mesh tensorflow.

    Embeds tokens with "wte", adds learned positional embeddings ("wpe", or axial
    embeddings when params["axial_pos_emb"] is set), runs params["n_layer"]
    transformer blocks, and projects to vocabulary logits. The projection uses the
    tied "wte" matrix, or a separate linear layer when params["no_weight_tie"] is set.

    Args:
        mtf_features: mtf feature dict; "labels" is read in train mode.
        other_features: dict with at least "attn_bias" and "memory_length_dim".
        params: model config dict.
        mesh: mtf Mesh on which variables are created.
        variable_dtype: mtf VariableDType for all created variables.
        context: optional decoding context; in "incremental" mode only the token at
            `context.position - 1` is processed.

    Returns:
        (logits, loss, loss_batch). In "train" mode, loss is the mean cross-entropy
        (with z-loss) plus the summed auxiliary block losses, divided by
        params["num_microbatches"] and cast to the slice dtype; loss_batch is the
        unreduced cross-entropy. Otherwise loss and loss_batch are None. Logits are
        cast to the master dtype.
    """
    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode: keep only the token at the current position
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        position_indices = mtf.range(mesh, sequence_dim, tf.int64) if not is_incremental_inference(context) else (
            context.position - 1)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MOE models)

    # Loop-invariant settings, computed once instead of on every layer.
    share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
    # Gradient checkpointing only applies in train mode.
    recompute_grad = bool(params["recompute_grad"]) and params["mode"] == "train"

    for layer in range(params["n_layer"]):
        # attn blocks; an empty scope makes every layer reuse the same variables
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        # NOTE(review): unlike the tied branch below, this path applies no final
        # layer norm ("ln_f") before the output projection — confirm this is
        # intended. Adding it would change the checkpoint's variable set.
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)
        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] == "train":
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        with tf.variable_scope("xentropy_final"):
            loss_batch = mtf.layers.softmax_cross_entropy_with_logits(logits=logits, targets=labels,
                                                                      vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training)
        # Make sure labels with padding tokens are not counted in the loss
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id), loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]
        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)

    return logits, loss, loss_batch