73 changes: 23 additions & 50 deletions official/nlp/bert/model.py
@@ -32,52 +32,24 @@
 import megengine.hub as hub
 import numpy as np
 from megengine import Parameter
-from megengine.functional import cross_entropy_with_softmax
+from megengine.functional.loss import cross_entropy
 from megengine.module import Dropout, Embedding, Linear, Module, Sequential
 from megengine.module.activation import Softmax


 def transpose(inp, a, b):
-    cur_shape = list(range(0, len(inp.shape)))
+    cur_shape = list(range(0, inp.ndim))
     cur_shape[a], cur_shape[b] = cur_shape[b], cur_shape[a]
-    return inp.dimshuffle(*cur_shape)
+    return inp.transpose(cur_shape)


-def matmul(a, b, transpose_b=None):
-    dim = len(b.shape)
-
-    if transpose_b:
-        b = transpose(b, dim - 1, dim - 2)
-
-    if dim > 3:
-        a_shape = list(a.shape)
-        b_shape = list(b.shape)
-        reshape_batch_size = 1
-        for i in a_shape[0 : dim - 2]:
-            reshape_batch_size *= i
-        a = a.reshape(*([reshape_batch_size] + a_shape[dim - 2 : dim]))
-        b = b.reshape(*([reshape_batch_size] + b_shape[dim - 2 : dim]))
-        c = F.batched_matrix_mul(a, b)
-        c = c.reshape(*(a_shape[0 : dim - 1] + b_shape[dim - 1 : dim]))
-        return c
-    elif dim == 3:
-        return F.batched_matrix_mul(a, b)
-    else:
-        return F.matrix_mul(a, b)
-
-def zeros_like(inp):
-    return mge.zeros(inp.shape, dtype=inp.dtype)
-
-def ones_like(inp):
-    return mge.ones(inp.shape, dtype=inp.dtype)
-
 def gelu(x):
     """Implementation of the gelu activation function.
     For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-    x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    x * 0.5 * (1.0 + F.tanh((F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3)))))
     Also see https://arxiv.org/abs/1606.08415
     """
-    return x * 0.5 * (1.0 + F.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * (x ** 3)))))
+    return x * 0.5 * (1.0 + F.tanh(F.sqrt(2 / math.pi) * (x + 0.044715 * (x ** 3))))


 ACT2FN = {"gelu": gelu, "relu": F.relu}
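
Note: the bulk of the deletions in this hunk are hand-rolled shims that MegEngine 1.0 provides natively. F.matmul dispatches to batched matrix multiplication for inputs with more than two dimensions (replacing the manual reshape around F.batched_matrix_mul), and F.zeros_like / F.ones_like replace the local helpers. A minimal sketch of the new built-ins, with illustrative shapes:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    a = mge.tensor(np.random.rand(2, 12, 128, 64).astype("float32"))
    b = mge.tensor(np.random.rand(2, 12, 64, 128).astype("float32"))
    c = F.matmul(a, b)      # shape (2, 12, 128, 128); leading batch dims handled natively
    z = F.zeros_like(a)     # replaces the removed zeros_like helper
    o = F.ones_like(a)      # replaces the removed ones_like helper
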
@@ -221,10 +193,10 @@ def forward(self, input_ids, token_type_ids=None):
         seq_length = input_ids.shape[1]

         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)

         position_ids = F.linspace(0, seq_length - 1, seq_length).astype(np.int32)
-        position_ids = F.add_axis(position_ids, 0).broadcast(*input_ids.shape)
+        position_ids = F.broadcast_to(F.expand_dims(position_ids, 0), input_ids.shape)
         words_embeddings = self.word_embeddings(input_ids)

         position_embeddings = self.position_embeddings(position_ids)
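
Note: F.expand_dims and F.broadcast_to are the 1.0 names for the old F.add_axis / Tensor.broadcast pair. A minimal sketch of the position-id construction above, with an illustrative batch of 4 and sequence length of 10:

    import numpy as np
    import megengine.functional as F

    position_ids = F.linspace(0, 9, 10).astype(np.int32)   # shape (10,)
    batched = F.expand_dims(position_ids, 0)               # shape (1, 10)
    tiled = F.broadcast_to(batched, (4, 10))               # shape (4, 10), one row per batch item
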
@@ -255,12 +227,11 @@ def __init__(self, config):
         self.dropout = Dropout(config.attention_probs_dropout_prob)

     def transpose_for_scores(self, x):
-        new_x_shape = x.shape[:-1] + (
-            self.num_attention_heads,
-            self.attention_head_size,
-        )
-        x = x.reshape(*new_x_shape)
-        return x.dimshuffle(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        x_shape = mge.tensor(x.shape)
+        new_x_shape = F.concat([x_shape[:-1], (self.num_attention_heads, self.attention_head_size)])
+        x = x.reshape(new_x_shape)
+        return x.transpose(0, 2, 1, 3)

     def forward(self, hidden_states, attention_mask):
         mixed_query_layer = self.query(hidden_states)
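
Note: "symbolic shapes" refers to running under megengine.jit.trace, where x.shape may be a lazily evaluated expression rather than a concrete tuple; doing the shape arithmetic with tensors (mge.tensor plus F.concat) keeps the reshape traceable. A sketch of the head-splitting this method performs, mirroring the pattern above with illustrative sizes:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    num_heads, head_size = 12, 64
    x = mge.tensor(np.zeros((2, 128, num_heads * head_size), dtype="float32"))

    x_shape = mge.tensor(x.shape)                                 # shape as a tensor: (2, 128, 768)
    new_shape = F.concat([x_shape[:-1], (num_heads, head_size)])  # (2, 128, 12, 64)
    x = x.reshape(new_shape).transpose(0, 2, 1, 3)                # (batch, heads, seq, head_size)
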
@@ -272,7 +243,7 @@ def forward(self, hidden_states, attention_mask):
         value_layer = self.transpose_for_scores(mixed_value_layer)

         # Take the dot product between "query" and "key" to get the raw attention scores.
-        attention_scores = matmul(query_layer, transpose(key_layer, -1, -2))
+        attention_scores = F.matmul(query_layer, transpose(key_layer, -1, -2))
         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
         # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
         attention_scores = attention_scores + attention_mask
@@ -284,10 +255,12 @@
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)

-        context_layer = matmul(attention_probs, value_layer)
-        context_layer = context_layer.dimshuffle(0, 2, 1, 3)
-        new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
-        context_layer = context_layer.reshape(*new_context_layer_shape)
+        context_layer = F.matmul(attention_probs, value_layer)
+        context_layer = context_layer.transpose(0, 2, 1, 3)
+        # using symbolic shapes to make trace happy
+        context_shape = mge.tensor(context_layer.shape)
+        new_context_layer_shape = F.concat([context_shape[:-2], self.all_head_size])
+        context_layer = context_layer.reshape(new_context_layer_shape)
         return context_layer
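
Note: these two hunks together are standard scaled dot-product attention in the new functional API. A self-contained sketch of the computation, with illustrative shapes and names (not the file's exact code):

    import math
    import numpy as np
    import megengine as mge
    import megengine.functional as F

    b, h, s, d = 2, 12, 128, 64
    q = mge.tensor(np.random.rand(b, h, s, d).astype("float32"))
    k = mge.tensor(np.random.rand(b, h, s, d).astype("float32"))
    v = mge.tensor(np.random.rand(b, h, s, d).astype("float32"))
    mask = mge.tensor(np.zeros((b, 1, 1, s), dtype="float32"))  # additive: 0 = attend, -10000 = masked

    scores = F.matmul(q, k, transpose_b=True) / math.sqrt(d)    # (b, h, s, s)
    probs = F.softmax(scores + mask, axis=-1)
    context = F.matmul(probs, v).transpose(0, 2, 1, 3).reshape(b, s, h * d)
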


@@ -453,17 +426,17 @@ def forward(
         output_all_encoded_layers=True,
     ):
         if attention_mask is None:
-            attention_mask = ones_like(input_ids)
+            attention_mask = F.ones_like(input_ids)
         if token_type_ids is None:
-            token_type_ids = zeros_like(input_ids)
+            token_type_ids = F.zeros_like(input_ids)
         # print('input_ids', input_ids.sum())
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
         # this attention mask is more simple than the triangular masking of causal attention
         # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
         # print('attention_mask', attention_mask.sum())
-        extended_attention_mask = F.add_axis(attention_mask, (1, 2))
+        extended_attention_mask = F.expand_dims(attention_mask, (1, 2))

         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
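
Note: this is the usual BERT mask bookkeeping: the 2D padding mask becomes a 4D additive bias that broadcasts across heads and query positions. A minimal sketch (the -10000.0 scaling is the operation the surrounding comments describe):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    attention_mask = mge.tensor(np.array([[1, 1, 1, 0]], dtype="float32"))  # 1 = token, 0 = padding
    extended = F.expand_dims(attention_mask, (1, 2))                        # (batch, 1, 1, seq)
    extended = (1.0 - extended) * -10000.0                                  # 0.0 to attend, -10000.0 to mask
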
@@ -554,7 +527,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
         logits = self.classifier(pooled_output)

         if labels is not None:
-            loss = cross_entropy_with_softmax(
+            loss = cross_entropy(
                 logits.reshape(-1, self.num_labels), labels.reshape(-1)
             )
         return logits, loss
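
Note: megengine.functional.loss.cross_entropy subsumes the removed cross_entropy_with_softmax; as of MegEngine 1.x it applies softmax to the logits itself (its with_logits argument defaults to True), so the call site only needed the rename. A minimal sketch:

    import numpy as np
    import megengine as mge
    from megengine.functional.loss import cross_entropy

    logits = mge.tensor(np.random.rand(8, 2).astype("float32"))          # (batch, num_labels)
    labels = mge.tensor(np.random.randint(0, 2, (8,)).astype("int32"))
    loss = cross_entropy(logits, labels)                                 # softmax fused into the loss
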
9 changes: 2 additions & 7 deletions official/nlp/bert/test.py
@@ -20,19 +20,14 @@
 logger = mge.get_logger(__name__)


-@trace(symbolic=True)
+# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
     return loss, logits, label_ids


-def accuracy(out, labels):
-    outputs = F.argmax(out, axis=1)
-    return F.sum(outputs == labels)
-
-
 def eval(dataloader, net):
     logger.info("***** Running evaluation *****")
     logger.info("batch size = %d", args.eval_batch_size)
@@ -48,7 +43,7 @@ def eval(dataloader, net):
             input_ids, segment_ids, input_mask, label_ids, net=net
         )
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1

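Note: the removed accuracy() helper summed correct predictions, while F.topk_accuracy returns the fraction of correct top-k predictions (topk defaults to 1), hence the * batch_size to recover a count that eval() later divides by total_examples. A minimal sketch:

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    logits = mge.tensor(np.array([[0.1, 0.9], [0.8, 0.2]], dtype="float32"))
    labels = mge.tensor(np.array([1, 0], dtype="int32"))
    acc = F.topk_accuracy(logits, labels, topk=1)   # 1.0 here: both predictions correct
    correct = acc * logits.shape[0]                 # back to a raw count for accumulation
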
34 changes: 16 additions & 18 deletions official/nlp/bert/train.py
@@ -10,6 +10,7 @@
 import megengine as mge
 import megengine.functional as F
 import megengine.optimizer as optim
+from megengine.autodiff import GradManager
 from megengine.jit import trace
 from tqdm import tqdm
@@ -21,28 +22,24 @@
 logger = mge.get_logger(__name__)


-@trace(symbolic=True)
+# @trace(symbolic=True)
 def net_eval(input_ids, segment_ids, input_mask, label_ids, net=None):
     net.eval()
     results = net(input_ids, segment_ids, input_mask, label_ids)
     logits, loss = results
-    return loss, logits, label_ids
+    return loss, logits


-@trace(symbolic=True)
-def net_train(input_ids, segment_ids, input_mask, label_ids, opt=None, net=None):
+# @trace(symbolic=True)
+def net_train(input_ids, segment_ids, input_mask, label_ids, gm=None, net=None):
     net.train()
-    results = net(input_ids, segment_ids, input_mask, label_ids)
-    logits, loss = results
-    opt.backward(loss)
+    with gm:
+        results = net(input_ids, segment_ids, input_mask, label_ids)
+        logits, loss = results
+        gm.backward(loss)
     return loss, logits, label_ids


-def accuracy(out, labels):
-    outputs = F.argmax(out, axis=1)
-    return F.sum(outputs == labels)
-
-
 def eval(dataloader, net):
     logger.info("***** Running evaluation *****")
     logger.info("batch size = %d", args.eval_batch_size)
@@ -56,11 +53,11 @@ def eval(dataloader, net):
         batch_size = input_ids.shape[0]
         if batch_size != args.eval_batch_size:
             break
-        loss, logits, label_ids = net_eval(
+        loss, logits = net_eval(
             input_ids, segment_ids, input_mask, label_ids, net=net
         )
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1
@@ -79,18 +76,19 @@ def train(dataloader, net, opt):
     logger.info("batch size = %d", args.train_batch_size)
     sum_loss, sum_accuracy, total_steps, total_examples = 0, 0, 0, 0

+    gm = GradManager().attach(net.parameters())
+
     for _, batch in enumerate(tqdm(dataloader, desc="Iteration")):
         input_ids, input_mask, segment_ids, label_ids = tuple(
             mge.tensor(t) for t in batch
         )
         batch_size = input_ids.shape[0]
-        opt.zero_grad()
         loss, logits, label_ids = net_train(
-            input_ids, segment_ids, input_mask, label_ids, opt=opt, net=net
+            input_ids, segment_ids, input_mask, label_ids, gm=gm, net=net
         )
-        optimizer.step()
+        opt.step().clear_grad()
         sum_loss += loss.mean().item()
-        sum_accuracy += accuracy(logits, label_ids)
+        sum_accuracy += F.topk_accuracy(logits, label_ids) * batch_size
         total_examples += batch_size
         total_steps += 1

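Note: the removed optimizer.step() referenced a name that does not exist in train() (the parameter is opt), so the port also fixes a latent NameError. Gradient clearing moves from before the forward pass (opt.zero_grad()) to after the update, chained as opt.step().clear_grad() (step() returning the optimizer, as the new code relies on). The per-iteration skeleton, with the argument order as in train.py:

    # schematic; see train() above for the full bookkeeping
    for _, batch in enumerate(dataloader):
        input_ids, input_mask, segment_ids, label_ids = tuple(mge.tensor(t) for t in batch)
        loss, logits, label_ids = net_train(input_ids, segment_ids, input_mask, label_ids, gm=gm, net=net)
        opt.step().clear_grad()   # update parameters, then drop grads for the next batch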