diff --git a/configs/default/components/models/feed_forward_network.yml b/configs/default/components/models/feed_forward_network.yml index 55a43b1..b9b80c0 100644 --- a/configs/default/components/models/feed_forward_network.yml +++ b/configs/default/components/models/feed_forward_network.yml @@ -15,6 +15,13 @@ dropout_rate: 0 # If true, output of the last layer will be additionally processed with Log Softmax (LOADED) use_logsoftmax: True +# Number of dimensions of the input, where: +# - 2 means [Batch size, Input size] +# - n means [Batch size, dim 1, ..., dim n-2, Input size] +# The FFN is broadcast over the last (Input size) dimension. +# The sizes of all dimensions but the last are preserved, as the FFN is applied only over the last dimension. +dimensions: 2 + streams: #################################################################### # 2. Keymappings associated with INPUT and OUTPUT streams. diff --git a/ptp/components/models/feed_forward_network.py b/ptp/components/models/feed_forward_network.py index adbc757..5d4dbd0 100644 --- a/ptp/components/models/feed_forward_network.py +++ b/ptp/components/models/feed_forward_network.py @@ -40,6 +40,8 @@ def __init__(self, name, config): self.key_inputs = self.stream_keys["inputs"] self.key_predictions = self.stream_keys["predictions"] + self.dimensions = self.config["dimensions"] + # Retrieve input size from global variables. self.input_size = self.globals["input_size"] if type(self.input_size) == list: @@ -106,7 +108,7 @@ def input_data_definitions(self): :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ return { - self.key_inputs: DataDefinition([-1, self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x INPUT_SIZE]"), + self.key_inputs: DataDefinition(([-1] * (self.dimensions -1)) + [self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x ... 
x INPUT_SIZE]"), } @@ -117,7 +119,7 @@ def output_data_definitions(self): :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`). """ return { - self.key_predictions: DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x PREDICTION_SIZE]") + self.key_predictions: DataDefinition(([-1] * (self.dimensions -1)) + [self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x ... x PREDICTION_SIZE]") } def forward(self, data_dict): @@ -126,13 +128,22 @@ def forward(self, data_dict): :param data_dict: DataDict({'inputs', 'predictions ...}), where: - - inputs: expected inputs [BATCH_SIZE x INPUT_SIZE], - - predictions: returned output with predictions (log_probs) [BATCH_SIZE x NUM_CLASSES] + - inputs: expected inputs [BATCH_SIZE x ... x INPUT_SIZE], + - predictions: returned output with predictions (log_probs) [BATCH_SIZE x ... x NUM_CLASSES] """ # Get inputs. x = data_dict[self.key_inputs] + # Check that the input has the number of dimensions that we expect + assert len(x.shape) == self.dimensions, \ + "Expected " + str(self.dimensions) + " dimensions for input, got " + str(len(x.shape))\ + + " instead. Check number of dimensions in the config." + + # Reshape such that we do a broadcast over the last dimension + origin_shape = x.shape + x = x.contiguous().view(-1, origin_shape[-1]) + # Propagate inputs through all but last layer. for layer in self.layers[:-1]: x = layer(x) @@ -147,5 +158,8 @@ def forward(self, data_dict): if self.use_logsoftmax: x = self.log_softmax(x) + # Restore the input dimensions but the last one (as it's been resized by the FFN) + x = x.view(*origin_shape[0:self.dimensions-1], -1) + # Add predictions to datadict. data_dict.extend({self.key_predictions: x})