
Commit

update
morvanzhou committed Sep 3, 2020
1 parent b645528 commit 9ea6d74
Showing 2 changed files with 28 additions and 62 deletions.
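In short: BERT.py now builds on the GPT class (imported together with the shared export_attention helper), the output heads are renamed from o_mlm/o_nsp to task_mlm/task_nsp, pad_mask becomes mask, the old main() and the BERT-specific export_attention() are replaced by a reusable train(model, data, step, name) function, random_mask_or_replace takes an explicit batch_size, and the hyperparameter constants move into the __main__ block. GPT.py picks up a single-line addition in its __main__ section.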
89 changes: 27 additions & 62 deletions BERT.py
@@ -4,19 +4,13 @@
import utils
import time
from transformer import Encoder
import pickle
from GPT import export_attention, GPT
import os

MODEL_DIM = 256
N_LAYER = 4
BATCH_SIZE = 12
LEARNING_RATE = 1e-4
MASK_RATE = 0.15


class BERT(keras.Model):
class BERT(GPT):
def __init__(self, model_dim, max_len, n_layer, n_head, n_vocab, lr, max_seg=3, drop_rate=0.1, padding_idx=0):
super().__init__()
super().__init__(model_dim, max_len, n_layer, n_head, n_vocab, lr, max_seg, drop_rate, padding_idx)
self.padding_idx = padding_idx
self.n_vocab = n_vocab
self.max_len = max_len
@@ -44,22 +38,15 @@ def __init__(self, model_dim, max_len, n_layer, n_head, n_vocab, lr, max_seg=3,
name="pos", shape=[1, max_len, model_dim], dtype=tf.float32, # [1, step, dim]
initializer=keras.initializers.RandomNormal(0., 0.01))
self.encoder = Encoder(n_head, model_dim, drop_rate, n_layer)
self.o_mlm = keras.layers.Dense(n_vocab)
self.o_nsp = keras.layers.Dense(2)
self.task_mlm = keras.layers.Dense(n_vocab)
self.task_nsp = keras.layers.Dense(2)

self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
self.opt = keras.optimizers.Adam(lr)

def __call__(self, seqs, segs, training=False):
embed = self.input_emb(seqs, segs) # [n, step, dim]
z = self.encoder(embed, training=training, mask=self.pad_mask(seqs))
mlm_logits = self.o_mlm(z) # [n, step, n_vocab]
nsp_logits = self.o_nsp(tf.reshape(z, [z.shape[0], -1])) # [n, n_cls]
return mlm_logits, nsp_logits

def step(self, seqs, segs, seqs_, loss_mask, nsp_labels):
with tf.GradientTape() as tape:
mlm_logits, nsp_logits = self(seqs, segs, training=True)
mlm_logits, nsp_logits = self.call(seqs, segs, training=True)
mlm_loss_batch = tf.boolean_mask(self.cross_entropy(seqs_, mlm_logits), loss_mask)
mlm_loss = tf.reduce_mean(mlm_loss_batch)
nsp_loss = tf.reduce_mean(self.cross_entropy(nsp_labels, nsp_logits))
@@ -68,20 +55,10 @@ def step(self, seqs, segs, seqs_, loss_mask, nsp_labels):
self.opt.apply_gradients(zip(grads, self.trainable_variables))
return loss, mlm_logits

def input_emb(self, seqs, segs):
return self.word_emb(seqs) + self.segment_emb(segs) + self.position_emb # [n, step, dim]

def pad_mask(self, seqs):
def mask(self, seqs):
mask = tf.cast(tf.math.equal(seqs, self.padding_idx), tf.float32)
return mask[:, tf.newaxis, tf.newaxis, :] # [n, 1, 1, step]

@property
def attentions(self):
attentions = {
"encoder": [l.mh.attention.numpy() for l in self.encoder.ls],
}
return attentions


def _get_loss_mask(len_arange, seq, pad_id):
rand_id = np.random.choice(len_arange, size=max(2, int(MASK_RATE * len(len_arange))), replace=False)
Expand All @@ -107,8 +84,8 @@ def do_nothing(seq, len_arange, pad_id):
return loss_mask


def random_mask_or_replace(data, arange):
seqs, segs, xlen, nsp_labels = data.sample(BATCH_SIZE)
def random_mask_or_replace(data, arange, batch_size):
seqs, segs, xlen, nsp_labels = data.sample(batch_size)
seqs_ = seqs.copy()
p = np.random.random()
if p < 0.7:
@@ -137,19 +114,13 @@ def random_mask_or_replace(data, arange):
return seqs, segs, seqs_, loss_mask, xlen, nsp_labels


def main():
# get and process data
data = utils.MRPCData("./MRPC", 2000)
print("num word: ", data.num_word)
model = BERT(
model_dim=MODEL_DIM, max_len=data.max_len, n_layer=N_LAYER, n_head=4, n_vocab=data.num_word,
lr=LEARNING_RATE, max_seg=data.num_seg, drop_rate=0.1, padding_idx=data.v2i["<PAD>"])
def train(model, data, step=10000, name="bert"):
t0 = time.time()
arange = np.arange(0, data.max_len)
for t in range(10000):
seqs, segs, seqs_, loss_mask, xlen, nsp_labels = random_mask_or_replace(data, arange)
for t in range(step):
seqs, segs, seqs_, loss_mask, xlen, nsp_labels = random_mask_or_replace(data, arange, 16)
loss, pred = model.step(seqs, segs, seqs_, loss_mask, nsp_labels)
if t % 20 == 0:
if t % 100 == 0:
pred = pred[0].numpy().argmax(axis=1)
t1 = time.time()
print(
@@ -162,27 +133,21 @@ def main():
"\n| prd word: ", [data.i2v[i] for i in pred*loss_mask[0] if i != data.v2i["<PAD>"]],
)
t0 = t1
os.makedirs("./visual/models/bert", exist_ok=True)
model.save_weights("./visual/models/bert/model.ckpt")


def export_attention():
data = utils.MRPCData("./MRPC", 2000)
print("num word: ", data.num_word)
model = BERT(
model_dim=MODEL_DIM, max_len=data.max_len, n_layer=N_LAYER, n_head=4, n_vocab=data.num_word,
lr=LEARNING_RATE, max_seg=data.num_seg, drop_rate=0.1, padding_idx=data.v2i["<PAD>"])
model.load_weights("./visual/models/bert/model.ckpt").expect_partial()

# save attention matrix for visualization
seqs, segs, xlen, nsp_labels = data.sample(1)
model(seqs, segs, False)
data = {"src": [data.i2v[i] for i in seqs[0]], "attentions": model.attentions}
with open("./visual/tmp/bert_attention_matrix.pkl", "wb") as f:
pickle.dump(data, f)
os.makedirs("./visual/models/%s" % name, exist_ok=True)
model.save_weights("./visual/models/%s/model.ckpt" % name)


if __name__ == "__main__":
# main()
export_attention()
MODEL_DIM = 256
N_LAYER = 4
LEARNING_RATE = 1e-4
MASK_RATE = 0.15

d = utils.MRPCData("./MRPC", 2000)
print("num word: ", d.num_word)
m = BERT(
model_dim=MODEL_DIM, max_len=d.max_len, n_layer=N_LAYER, n_head=4, n_vocab=d.num_word,
lr=LEARNING_RATE, max_seg=d.num_seg, drop_rate=0.2, padding_idx=d.v2i["<PAD>"])
train(m, d, step=5000, name="bert")
export_attention(m, d, "bert")
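
For readers skimming the diff, here is a minimal, self-contained sketch of the subclassing pattern BERT.py moves to: the child model reuses the parent's layers and call(), and overrides mask() to return a padding mask of shape [n, 1, 1, step] instead of a causal mask. TinyGPT/TinyBERT and the causal-mask detail are simplified assumptions inferred from how the classes are used in this diff, not code copied from the repository.

import numpy as np
import tensorflow as tf
from tensorflow import keras


class TinyGPT(keras.Model):
    # stand-in for GPT: unidirectional LM, so it masks future positions
    def __init__(self, model_dim, n_vocab, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.word_emb = keras.layers.Embedding(n_vocab, model_dim)
        self.task_mlm = keras.layers.Dense(n_vocab)

    def call(self, seqs, training=False):
        x = self.word_emb(seqs)          # [n, step, dim]
        attn_mask = self.mask(seqs)      # the only piece the subclass changes
        # the real models feed x and attn_mask through a Transformer Encoder;
        # that part is omitted to keep the sketch short
        return self.task_mlm(x)          # [n, step, n_vocab]

    def mask(self, seqs):
        # causal (look-ahead) mask: position i cannot attend to j > i
        step = tf.shape(seqs)[1]
        return 1.0 - tf.linalg.band_part(tf.ones((step, step)), -1, 0)


class TinyBERT(TinyGPT):
    # mirrors BERT(GPT) in this commit: same layers, different mask
    def mask(self, seqs):
        # padding mask only: hide <PAD> tokens, keep attention bidirectional
        pad = tf.cast(tf.math.equal(seqs, self.padding_idx), tf.float32)
        return pad[:, tf.newaxis, tf.newaxis, :]   # [n, 1, 1, step]


if __name__ == "__main__":
    toy = TinyBERT(model_dim=8, n_vocab=20)
    logits = toy(np.array([[3, 7, 2, 0, 0]]))      # 0 is the padding id
    print(logits.shape)                            # (1, 5, 20)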

1 change: 1 addition & 0 deletions GPT.py
@@ -117,6 +117,7 @@ def export_attention(model, data, name="gpt"):
MODEL_DIM = 256
N_LAYER = 4
LEARNING_RATE = 1e-4

d = utils.MRPCData("./MRPC", 2000)
print("num word: ", d.num_word)
m = GPT(
