Skip to content

Commit

Permalink
Merge pull request #130 from alexwwang/tf.keras-version
Browse files Browse the repository at this point in the history
🐛 fix typo caused bug in DPCNN & modify its unit test to fit its using scenario.
  • Loading branch information
BrikerMan committed Jul 3, 2019
2 parents 2340d6e + b245574 commit 3faaf8c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
8 changes: 7 additions & 1 deletion kashgari/tasks/classification/dpcnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@


class DPCNN_Model(BaseClassificationModel):
'''
This implementation of DPCNN requires a clear declared sequence length.
So sequences input in should be padded or cut to a given length in advance.
'''

@classmethod
def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
Expand Down Expand Up @@ -152,7 +156,7 @@ def build_model_arc(self):
L.Dense(output_dim, **config['activation'])
]

tensor_out = embed_model.inputs
tensor_out = embed_model.output

# build region tensors
for layer in layers_region:
Expand All @@ -162,6 +166,8 @@ def build_model_arc(self):
tensor_out = self.conv_block(tensor_out, **config['conv_block'])
# build the above pyramid layers while `steps > 2`
seq_len = tensor_out.shape[1].value
if seq_len is None:
raise ValueError('`sequence_length` should be explicitly assigned, but it is `None`.')
for i in range(floor(log2(seq_len)) - 2):
tensor_out = self.resnet_block(tensor_out, stage=i + 1,
**config['resnet_block'])
Expand Down
2 changes: 1 addition & 1 deletion tests/classification/test_bi_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_custom_hyper_params(self):
if isinstance(value, bool):
pass
elif isinstance(value, int):
hyper_params[layer][key] = value + 15 if value > 64 else value
hyper_params[layer][key] = value + 15 if value >= 64 else value
model = self.model_class(embedding=w2v_embedding_variable_len,
hyper_parameters=hyper_params)
model.fit(valid_x, valid_y, epochs=1)
Expand Down
14 changes: 14 additions & 0 deletions tests/classification/test_dpcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,19 @@ def setUpClass(cls):
cls.model_class = DPCNN_Model


def test_custom_hyper_params(self):
hyper_params = self.model_class.get_default_hyper_parameters()

for layer, config in hyper_params.items():
for key, value in config.items():
if isinstance(value, bool):
pass
elif isinstance(value, int):
hyper_params[layer][key] = value + 15 if value >= 64 else value
model = self.model_class(embedding=base.w2v_embedding,
hyper_parameters=hyper_params)
model.fit(base.valid_x, base.valid_y, epochs=1)
assert True

if __name__ == "__main__":
print("Hello world")

0 comments on commit 3faaf8c

Please sign in to comment.