From ae42aa3ff848bedb08830eb7d4f3f2c72da26a7e Mon Sep 17 00:00:00 2001
From: chenyuanzhao
Date: Tue, 25 Aug 2020 14:42:34 +0800
Subject: [PATCH] feat(nlp): Update bert for MegEngine v1.0

---
 official/nlp/bert/model.py | 73 ++++++++++++--------------------------
 official/nlp/bert/test.py  |  9 ++---
 official/nlp/bert/train.py | 34 +++++++++---------
 3 files changed, 41 insertions(+), 75 deletions(-)

diff --git a/official/nlp/bert/model.py b/official/nlp/bert/model.py
index 879879ff..9eafacac 100644
--- a/official/nlp/bert/model.py
+++ b/official/nlp/bert/model.py
@@ -32,52 +32,24 @@ import megengine.hub as hub
 import numpy as np

 from megengine import Parameter
-from megengine.functional import cross_entropy_with_softmax
+from megengine.functional.loss import cross_entropy
 from megengine.module import Dropout, Embedding, Linear, Module, Sequential
 from megengine.module.activation import Softmax


 def transpose(inp, a, b):
-    cur_shape = list(range(0, len(inp.shape)))
+    cur_shape = list(range(0, inp.ndim))
     cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a]
-    return inp.dimshuffle(*cur_shape)
+    return inp.transpose(cur_shape)


-def matmul(a, b, transpose_b=None):
-    dim = len(b.shape)
-
-    if transpose_b:
-        b = transpose(b, dim - 1, dim - 2)
-
-    if dim > 3:
-        a_shape = list(a.shape)
-        b_shape = list(b.shape)
-        reshape_batch_size = 1
-        for i in a_shape[0 : dim - 2]:
-            reshape_batch_size *= i
-        a = a.reshape(*([reshape_batch_size] + a_shape[dim - 2 : dim]))
-        b = b.reshape(*([reshape_batch_size] + b_shape[dim - 2 : dim]))
-        c = F.batched_matrix_mul(a, b)
-        c = c.reshape(*(a_shape[0 : dim - 1] + b_shape[dim - 1 : dim]))
-        return c
-    elif dim == 3:
-        return F.batched_matrix_mul(a, b)
-    else:
-        return F.matrix_mul(a, b)
-
-def zeros_like(inp):
-    return mge.zeros(inp.shape, dtype=inp.dtype)
-
-def ones_like(inp):
-    return mge.ones(inp.shape, dtype=inp.dtype)
-
 def gelu(x):
     """Implementation of the gelu activation function.
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-    x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    x * 0.5 * (1.0 + F.tanh((F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3)))))
     Also see https://arxiv.org/abs/1606.08415
     """
-    return x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    return x * 0.5 * (1.0 + F.tanh(F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3))))


 ACT2FN = {"gelu": gelu, "relu": F.relu}
@@ -221,10 +193,10 @@ def forward(self, input_ids, token_type_ids=None):
         seq_length = input_ids.shape[1]

         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)

         position_ids = F.linspace(0, seq_length - 1, seq_length).astype(np.int32)
-        position_ids = F.add_axis(position_ids, 0).broadcast(*input_ids.shape)
+        position_ids = F.broadcast_to(F.expand_dims(position_ids, 0), input_ids.shape)
         words_embeddings = self.word_embeddings(input_ids)

         position_embeddings = self.position_embeddings(position_ids)
@@ -255,12 +227,11 @@ def __init__(self, config):
         self.dropout = Dropout(config.attention_probs_dropout_prob)

     def transpose_for_scores(self, x):
-        new_x_shape = x.shape[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
-        )
-        x = x.reshape(*new_x_shape)
-        return x.dimshuffle(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        x_shape = mge.tensor(x.shape)
+        new_x_shape = F.concat([x_shape[:-1], (self.num_attention_heads, self.attention_head_size)])
+        x = x.reshape(new_x_shape)
+        return x.transpose(0, 2, 1, 3)

     def forward(self, hidden_states, attention_mask):
         mixed_query_layer = self.query(hidden_states)
@@ -272,7 +243,7 @@ def forward(self, hidden_states, attention_mask):
         value_layer = self.transpose_for_scores(mixed_value_layer)

         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = matmul(query_layer, transpose(key_layer, -1, -2))
+        attention_scores = F.matmul(query_layer, transpose(key_layer, -1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
         attention_scores = attention_scores + attention_mask
@@ -284,10 +255,12 @@ def forward(self, hidden_states, attention_mask):
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)

-        context_layer = matmul(attention_probs, value_layer)
-        context_layer = context_layer.dimshuffle(0, 2, 1, 3)
-        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
-        context_layer = context_layer.reshape(*new_context_layer_shape)
+        context_layer = F.matmul(attention_probs, value_layer)
+        context_layer = context_layer.transpose(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        context_shape = mge.tensor(context_layer.shape)
+        new_context_layer_shape = F.concat([context_shape[:-2], self.all_head_size])
+        context_layer = context_layer.reshape(new_context_layer_shape)

         return context_layer

@@ -453,9 +426,9 @@ def forward(
         output_all_encoded_layers=True,
     ):
         if attention_mask is None:
-            attention_mask = ones_like(input_ids)
+            attention_mask = F.ones_like(input_ids)
         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)
         # print('input_ids', input_ids.sum())
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
@@ -463,7 +436,7 @@ def forward(
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
         # print('attention_mask', attention_mask.sum())
-        extended_attention_mask = F.add_axis(attention_mask, (1, 2))
+        extended_attention_mask = F.expand_dims(attention_mask, (1, 2))

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
@@ -554,7 +527,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
         logits = self.classifier(pooled_output)

         if labels is not None:
-            loss = cross_entropy_with_softmax(
+            loss = cross_entropy(
                 logits.reshape(-1, self.num_labels), labels.reshape(-1)
             )
             return logits, loss
diff --git a/official/nlp/bert/test.py b/official/nlp/bert/test.py
index 924d0ada..2f6d3e64 100644
--- a/official/nlp/bert/test.py
+++ b/official/nlp/bert/test.py
@@ -20,7 +20,7 @@
 logger = mge.get_logger(__name__)


-@trace(symbolic=True)
+# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
@@ -28,11 +28,6 @@ def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     return loss, logits, label_ids


-def accuracy(out, labels):
-    outputs = F.argmax(out, axis=1)
-    return F.sum(outputs == labels)
-
-
 def eval(dataloader, net):
     logger.info("***** Running evaluation *****")
     logger.info("batch size = %d", args.eval_batch_size)
@@ -48,7 +43,7 @@ def eval(dataloader, net):
             input_ids, segment_ids, input_mask, label_ids, net=net
         )
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1

diff --git a/official/nlp/bert/train.py b/official/nlp/bert/train.py
index 1b82d84b..5fa9dff5 100644
--- a/official/nlp/bert/train.py
+++ b/official/nlp/bert/train.py
@@ -10,6 +10,7 @@
 import megengine as mge
 import megengine.functional as F
 import megengine.optimizer as optim
+from megengine.autodiff import GradManager
 from megengine.jit import trace
 from tqdm import tqdm

@@ -21,28 +22,24 @@
 logger = mge.get_logger(__name__)


-@trace(symbolic=True)
+# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
-    return loss, logits, label_ids
+    return loss, logits


-@trace(symbolic=True)
-def net_train(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+# @trace(symbolic=True)
+def net_train(input_ids, segment_ids, input_mask, label_ids, gm=None, net=None):
     net.train()
-    results = net(input_ids, segment_ids, input_mask, label_ids)
-    logits, loss = results
-    opt.backward(loss)
+    with gm:
+        results = net(input_ids, segment_ids, input_mask, label_ids)
+        logits, loss = results
+        gm.backward(loss)
     return loss, logits, label_ids


-def accuracy(out, labels):
-    outputs = F.argmax(out, axis=1)
-    return F.sum(outputs == labels)
-
-
 def eval(dataloader, net):
     logger.info("***** Running evaluation *****")
     logger.info("batch size = %d", args.eval_batch_size)
@@ -56,11 +53,11 @@ def eval(dataloader, net):
         batch_size = input_ids.shape[0]
         if batch_size != args.eval_batch_size:
             break
-        loss, logits, label_ids = net_eval(
+        loss, logits = net_eval(
             input_ids, segment_ids, input_mask,
             label_ids, net=net
         )
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1
@@ -79,18 +76,19 @@ def train(dataloader, net, opt):
     logger.info("batch size = %d", args.train_batch_size)
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0

+    gm = GradManager().attach(net.parameters())
+
     for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
         batch_size = input_ids.shape[0]
-        opt.zero_grad()
         loss, logits, label_ids = net_train(
-            input_ids, segment_ids, input_mask, label_ids, opt=opt, net=net
+            input_ids, segment_ids, input_mask, label_ids, gm=gm, net=net
        )
-        optimizer.step()
+        opt.step().clear_grad()
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1
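
Notes on the v1.0 APIs used above (standalone sketches, not part of the patch):

The hand-rolled matmul helper is deleted because v1.0's F.matmul multiplies the
last two dimensions and broadcasts over any leading batch dimensions itself,
which is why the attention hunks call it directly on 4-D tensors. A minimal
sketch of that behavior, with shapes invented for illustration:

import numpy as np
import megengine as mge
import megengine.functional as F

# Multi-head-attention-like shapes: [batch, heads, seq, head_dim].
q = mge.tensor(np.random.randn(2, 12, 16, 64).astype("float32"))
k = mge.tensor(np.random.randn(2, 12, 16, 64).astype("float32"))

# F.matmul handles the 4-D operands directly; no manual reshape of the
# leading dims into a single batch dimension is needed anymore.
scores = F.matmul(q, k.transpose(0, 1, 3, 2))
print(scores.shape)  # (2, 12, 16, 16)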
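The transpose_for_scores rewrite builds the target shape as a tensor so the
reshape stays symbolic under trace, as the in-line comment says. The same
pattern in isolation, assuming a hypothetical 12-head, 64-dim split of a
768-wide hidden state (not the repo's actual config):

import numpy as np
import megengine as mge
import megengine.functional as F

x = mge.tensor(np.random.randn(2, 16, 768).astype("float32"))  # [batch, seq, hidden]
x_shape = mge.tensor(x.shape)                               # shape as a 1-D int tensor
new_shape = F.concat([x_shape[:-1], mge.tensor([12, 64])])  # [batch, seq, heads, head_dim]
y = x.reshape(new_shape).transpose(0, 2, 1, 3)              # -> [batch, heads, seq, head_dim]
print(y.shape)  # (2, 12, 16, 64)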
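F.add_axis and Tensor.broadcast become F.expand_dims and F.broadcast_to in
v1.0; the two call sites above change spelling only, not shape semantics. A
quick shape check with invented sizes:

import numpy as np
import megengine as mge
import megengine.functional as F

attention_mask = mge.tensor(np.ones((2, 16), dtype="float32"))  # [batch, seq]
extended = F.expand_dims(attention_mask, (1, 2))                # [batch, 1, 1, seq]
print(extended.shape)  # (2, 1, 1, 16)

position_ids = F.linspace(0, 15, 16).astype(np.int32)           # [seq]
position_ids = F.broadcast_to(F.expand_dims(position_ids, 0), (2, 16))
print(position_ids.shape)  # (2, 16)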
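Finally, train.py swaps the removed Optimizer.backward()/zero_grad() flow for
autodiff.GradManager: the forward pass and gm.backward(loss) run inside the
recording scope, then the chained opt.step().clear_grad() applies the update
and resets gradients. A self-contained sketch of that training step; SimpleNet
and the random batch are placeholders, not code from this repo:

import numpy as np
import megengine as mge
import megengine.optimizer as optim
from megengine.autodiff import GradManager
from megengine.functional.loss import cross_entropy
from megengine.module import Linear, Module


class SimpleNet(Module):  # placeholder model
    def __init__(self):
        super().__init__()
        self.fc = Linear(16, 2)

    def forward(self, x):
        return self.fc(x)


net = SimpleNet()
gm = GradManager().attach(net.parameters())  # record grads for these params
opt = optim.SGD(net.parameters(), lr=0.1)

data = mge.tensor(np.random.randn(4, 16).astype("float32"))
label = mge.tensor(np.array([0, 1, 1, 0], dtype="int32"))

with gm:                         # forward + backward inside the recording scope
    logits = net(data)
    loss = cross_entropy(logits, label)
    gm.backward(loss)            # writes grads into the attached parameters
opt.step().clear_grad()          # apply the update, then reset grads (as in the patch)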