This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Building custom model over the final embedding layer #19

Closed
njordsir2 opened this issue Feb 25, 2019 · 6 comments


@njordsir2

As I understand it, BERT generates 768-dimensional embeddings for tokens. I am trying to build a multi-class classification model on top of this. My assumption is that the output of the Encoder-12-FeedForward-Norm layer, of shape (None, seq_length, 768), would give these embeddings. This is what I am trying:

from keras.layers import Bidirectional, LSTM, GlobalMaxPool1D, Dense, Dropout
from keras.models import Model
from keras_bert import load_trained_model_from_checkpoint

model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True, seq_len=seq_len)

new_out = Bidirectional(LSTM(50, return_sequences=True,
                             dropout=0.1,
                             recurrent_dropout=0.1))(model.layers[-9].output)
new_out = GlobalMaxPool1D()(new_out)
new_out = Dense(50, activation='relu')(new_out)
new_out = Dropout(0.1)(new_out)
new_out = Dense(6, activation='sigmoid')(new_out)

newModel = Model(model.inputs[:2], new_out)

I get the following error for new_out = GlobalMaxPool1D()(new_out):

TypeError: Layer global_max_pooling1d_11 does not support masking, but was passed an input_mask: Tensor("Encoder-12-FeedForward-Add/All:0", shape=(?, 128), dtype=bool)

I am not sure how masking is involved if I am just using the output of the encoder.

The paper mentions that the output corresponding to just the first [CLS] token should be used for classification. On trying this:

new_out = Lambda(lambda x: x[:,0,:])(model.layers[-9].output)

the model trains (although with poor results).
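
Spelled out as a full head, that attempt looks roughly like the following sketch (an illustration only: it assumes the same Dense/Dropout stack and layer index -9 as the first snippet, and the same model loaded with training=True):

from keras.layers import Lambda, Dense, Dropout
from keras.models import Model

# Sketch only: classify from the hidden state at the [CLS] position.
cls_out = Lambda(lambda x: x[:, 0, :])(model.layers[-9].output)  # (batch, 768)
cls_out = Dense(50, activation='relu')(cls_out)
cls_out = Dropout(0.1)(cls_out)
cls_out = Dense(6, activation='sigmoid')(cls_out)

cls_model = Model(model.inputs[:2], cls_out)
cls_model.compile(loss='binary_crossentropy', optimizer='adam')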

How can the pre-loaded model be used for classification?


BerenLuthien commented Feb 27, 2019

It looks like the author provided a demo:

inputs, output_layer = get_model(..., training=False)  # output_layer is the last feature extraction layer (the last transformer); the input layers and output layer are returned if training is False

Then any classifier can be added on top of this output_layer (which contains the embeddings), such as an LSTM or logistic regression.
Make sure training=False.
You may have to rewrite the token dictionary since your dataset may not be exactly like MRPC.
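
A rough sketch of this suggestion (the token_dict, the LSTM/Dense sizes, and the 6-way sigmoid output are placeholder assumptions, not values from this thread). An LSTM head is used because it supports masking, so the token mask carried by output_layer is consumed before the final Dense:

from keras.layers import LSTM, Dense
from keras.models import Model
from keras_bert import get_model

# Sketch only: with training=False, get_model returns the input layers and the
# output of the last transformer block, as the quoted comment above describes.
inputs, output_layer = get_model(token_num=len(token_dict), training=False)

x = LSTM(64)(output_layer)                  # consumes the mask; no mask is passed on
probs = Dense(6, activation='sigmoid')(x)

classifier = Model(inputs, probs)
classifier.compile(loss='binary_crossentropy', optimizer='adam')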

CyberZHG added a commit that referenced this issue Feb 28, 2019

CyberZHG commented Feb 28, 2019

#7 Sentence Embedding

GlobalMaxPool1D doesn't support masking. The following is a modification that suits this case:

import keras
import keras.backend as K


class MaskedGlobalMaxPool1D(keras.layers.Layer):

    def __init__(self, **kwargs):
        super(MaskedGlobalMaxPool1D, self).__init__(**kwargs)
        self.supports_masking = True

    def compute_mask(self, inputs, mask=None):
        # Consume the mask: downstream layers receive no mask.
        return None

    def compute_output_shape(self, input_shape):
        return input_shape[:-2] + (input_shape[-1],)

    def call(self, inputs, mask=None):
        if mask is not None:
            mask = K.cast(mask, K.floatx())
            # Push masked (padded) positions far below the real values before taking the max.
            inputs -= K.expand_dims((1.0 - mask) * 1e6, axis=-1)
        return K.max(inputs, axis=-2)
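
As a sketch (not from this thread), the masked pooling can be dropped into the BiLSTM head from the original post in place of GlobalMaxPool1D; the Bidirectional LSTM passes the mask through and MaskedGlobalMaxPool1D consumes it:

from keras.layers import Bidirectional, LSTM, Dense, Dropout
from keras.models import Model

# Sketch only: same head as in the original post, with the masked pooling swapped in.
new_out = Bidirectional(LSTM(50, return_sequences=True,
                             dropout=0.1, recurrent_dropout=0.1))(model.layers[-9].output)
new_out = MaskedGlobalMaxPool1D()(new_out)
new_out = Dense(50, activation='relu')(new_out)
new_out = Dropout(0.1)(new_out)
new_out = Dense(6, activation='sigmoid')(new_out)

newModel = Model(model.inputs[:2], new_out)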

I've added a demo for sentence embedding with pooling:

import codecs

import numpy as np
import keras
from keras_bert import load_trained_model_from_checkpoint

model = load_trained_model_from_checkpoint(config_path, checkpoint_path)
pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
model = keras.models.Model(inputs=model.inputs, outputs=pool_layer)
model.summary(line_length=120)

tokens = ['[CLS]', '语', '言', '模', '型', '[SEP]']

token_dict = {}
with codecs.open(dict_path, 'r', 'utf8') as reader:
    for line in reader:
        token = line.strip()
        token_dict[token] = len(token_dict)

token_input = np.asarray([[token_dict[token] for token in tokens] + [0] * (512 - len(tokens))])
seg_input = np.asarray([[0] * len(tokens) + [0] * (512 - len(tokens))])

print('Inputs:', token_input[0][:len(tokens)])

predicts = model.predict([token_input, seg_input])[0]
print('Pooled:', predicts.tolist()[:5])

@njordsir2 (Author)

@CyberZHG Oh sweet! I will check this out. Going through the BERT paper cleared up my masking queries.
In the meantime I got the job done with the official tensorflow-hub module.


BerenLuthien commented Mar 13, 2019

Thanks. The MaskedGlobalMaxPool1D itself works well in the demo you gave, but it looks like it does not work if we add a classification layer (such as Dense) on top of it. This code gives an error:

from keras.layers import Dense
from keras.models import Model
from keras_bert import load_trained_model_from_checkpoint

model = load_trained_model_from_checkpoint('bert_config.json', 'bert_model.ckpt')

def get_custermized_model(model):
    pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
    x = pool_layer
    x = Dense(units=1, activation='sigmoid')(x)
    print(model.inputs[0])
    custermized_model = Model(inputs=model.inputs, outputs=x)
    custermized_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
    return custermized_model

custermized_model = get_custermized_model(model=model)
history = custermized_model.fit(x=X_train, y=train_labels, epochs=1, validation_split=0.3)

InvalidArgumentError: Incompatible shapes: [32] vs. [32,512]
[[{{node metrics_4/acc/mul}} = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"](metrics_4/acc/Mean, metrics_4/acc/Cast_1)]]
[[{{node metrics_4/acc/Mean_2/_1745}} = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_4979_metrics_4/acc/Mean_2", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

Would you help:

  • add a Dense layer on top of MaskedGlobalMaxPool1D and see how it works? Thanks

PS:
I noticed that the token_input in your demo has shape (1, 512), and the above error happens when I feed a batch of data with shape (64, 512), where 64 is the batch size.
However, if I add a dimension and feed (64, 1, 512) to the model, it complains about input shape errors.

CyberZHG added a commit that referenced this issue Mar 13, 2019

CyberZHG commented Mar 13, 2019

I forgot to return a None mask in MaskedGlobalMaxPool1D. I've fixed it and made a release.
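
Presumably the earlier "[32] vs. [32,512]" error came from the boolean (batch, seq_len) mask propagating past the pooling layer and being applied to the per-sample loss and metrics. With compute_mask returning None, a sketch like the following (reusing model and MaskedGlobalMaxPool1D from the demo above; the dummy data is a placeholder, not from the thread) should compile and fit:

import numpy as np
from keras.layers import Dense
from keras.models import Model

# Sketch only: the pooled output now carries no mask, so labels of shape
# (batch,) match the (batch, 1) predictions.
pooled = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
probs = Dense(1, activation='sigmoid')(pooled)
clf = Model(inputs=model.inputs, outputs=probs)
clf.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])

X_dummy = [np.zeros((8, 512)), np.zeros((8, 512))]  # placeholder token and segment ids
y_dummy = np.zeros((8,))
clf.fit(X_dummy, y_dummy, epochs=1)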

@weizhenzhao

@CyberZHG
Hi CyberZHG,

I notice two ways of loading the model in the snippets above:

(1) model = load_trained_model_from_checkpoint(config_path, checkpoint_path)

(2) model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True, seq_len=seq_len)

If I build a classifier with a BiLSTM on top of that, does that mean fine-tuning?

Thanks,
weizhen
