Skip to content

Commit

Permalink
Making number and size of conv. filters configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
tuetschek committed Nov 10, 2015
1 parent 3b85fc5 commit c3be9a5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
14 changes: 8 additions & 6 deletions tgen/rank_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def __init__(self, cfg):
self.dict_formeme = {'UNK_FORMEME': self.UNK_FORMEME}
self.max_tree_len = cfg.get('max_tree_len', 20)

self.cnn_num_filters = cfg.get('cnn_num_filters', 3)
self.cnn_filter_length = cfg.get('cnn_filter_length', 3)

if self.da_emb:
self.dict_slot = {'UNK_SLOT': self.UNK_SLOT}
self.dict_value = {'UNK_VALUE': self.UNK_VALUE}
Expand Down Expand Up @@ -277,13 +280,11 @@ def _init_neural_network(self):
pooling = T.mean

if self.da_emb:
da_layers = self._conv_layers('conv_da', num_conv_layers,
filter_length=3, num_filters=2, pooling=pooling)
da_layers = self._conv_layers('conv_da', num_conv_layers, pooling=pooling)
else:
da_layers = self._id_layers('id_da',
num_conv_layers + (1 if pooling is not None else 0))
tree_layers = self._conv_layers('conv_tree', num_conv_layers,
filter_length=3, num_filters=3, pooling=pooling)
tree_layers = self._conv_layers('conv_tree', num_conv_layers, pooling=pooling)

for da_layer, tree_layer in zip(da_layers, tree_layers):
layers.append([da_layer[0], tree_layer[0]])
Expand Down Expand Up @@ -318,11 +319,12 @@ def _init_neural_network(self):
self.nn = NN(layers, input_shapes, input_types, self.normgrad)
log_info("Network shape:\n\n" + str(self.nn))

def _conv_layers(self, name, num_layers=1, filter_length=3, num_filters=3, pooling=None):
def _conv_layers(self, name, num_layers=1, pooling=None):
ret = []
for i in xrange(num_layers):
ret.append([Conv1D(name + str(i + 1),
filter_length=filter_length, num_filters=num_filters,
filter_length=self.cnn_filter_length,
num_filters=self.cnn_num_filters,
init=self.init, activation=T.tanh)])
if pooling is not None:
ret.append([Pool1D(name + str(i + 1) + 'pool', pooling_func=pooling)])
Expand Down
2 changes: 2 additions & 0 deletions util/describe_experiment.pl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
$nn_shape .= ' E' . ( ( $config_data =~ /'emb_size'\s*:\s*([0-9]*)/ )[0] // 20 );
$nn_shape .= '-N' . ( ( $config_data =~ /'num_hidden_units'\s*:\s*([0-9]*)/ )[0] // 512 );
$nn_shape .= '-A' . ( ( $config_data =~ /'alpha'\s*:\s*([0-9.]+)/ )[0] // 0.1 );
$nn_shape .= '-C' . ( ( $config_data =~ /'cnn_filter_length'\s*:\s*([0-9]+)/ )[0] // 3 )
. '/' . ( ( $config_data =~ /'cnn_num_filters'\s*:\s*([0-9]+)/ )[0] // 3 );
$nn_shape .= '-' . ( ( $config_data =~ /'initialization'\s*:\s*'([^']*)'/ )[0] // 'uniform_glorot10' );

# NN gadgets
Expand Down

0 comments on commit c3be9a5

Please sign in to comment.