7 changes: 7 additions & 0 deletions configs/default/components/models/feed_forward_network.yml
@@ -15,6 +15,13 @@ dropout_rate: 0
# If true, output of the last layer will be additionally processed with Log Softmax (LOADED)
use_logsoftmax: True

+# Number of dimensions, where:
+# - 2 means [Batch size, Input size]
+# - n means [Batch size, dim 1, ..., dim n-2, Input size]
+# The FFN is applied over the last (Input Size) dimension, so all
+# dimension sizes except the last are preserved.
+dimensions: 2

streams:
####################################################################
# 2. Keymappings associated with INPUT and OUTPUT streams.
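Not part of the diff: a minimal sketch of what the shapes described by the new `dimensions` setting mean in practice. The sizes below are hypothetical, and PyTorch's `nn.Linear` stands in for the FFN, since it too maps only the last dimension:

import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, input_size, prediction_size = 32, 10, 128, 7

ffn = nn.Linear(input_size, prediction_size)

# dimensions: 2 -> [Batch size, Input size]
x2 = torch.randn(batch_size, input_size)
print(ffn(x2).shape)   # torch.Size([32, 7])

# dimensions: 3 -> [Batch size, dim 1, Input size]; the FFN is applied
# at every position of dim 1, so all sizes but the last are preserved.
x3 = torch.randn(batch_size, seq_len, input_size)
print(ffn(x3).shape)   # torch.Size([32, 10, 7])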
22 changes: 18 additions & 4 deletions ptp/components/models/feed_forward_network.py
@@ -40,6 +40,8 @@ def __init__(self, name, config):
        self.key_inputs = self.stream_keys["inputs"]
        self.key_predictions = self.stream_keys["predictions"]

+        self.dimensions = self.config["dimensions"]

        # Retrieve input size from global variables.
        self.input_size = self.globals["input_size"]
        if type(self.input_size) == list:
@@ -106,7 +108,7 @@ def input_data_definitions(self):
        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
        """
        return {
-            self.key_inputs: DataDefinition([-1, self.input_size], [torch.Tensor], "Batch of inputs, each represented as index [BATCH_SIZE x INPUT_SIZE]"),
+            self.key_inputs: DataDefinition(([-1] * (self.dimensions - 1)) + [self.input_size], [torch.Tensor], "Batch of inputs, each represented as a vector [BATCH_SIZE x ... x INPUT_SIZE]"),
        }


@@ -117,7 +119,7 @@ def output_data_definitions(self):
        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
        """
        return {
-            self.key_predictions: DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x PREDICTION_SIZE]")
+            self.key_predictions: DataDefinition(([-1] * (self.dimensions - 1)) + [self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as a probability distribution over classes [BATCH_SIZE x ... x PREDICTION_SIZE]")
        }

    def forward(self, data_dict):
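A quick sanity check (hypothetical values only) of the shape lists these two definitions construct:

dimensions, input_size = 3, 128
shape = ([-1] * (dimensions - 1)) + [input_size]
print(shape)   # [-1, -1, 128], i.e. [BATCH_SIZE x dim_1 x INPUT_SIZE]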
@@ -126,13 +128,22 @@ def forward(self, data_dict):

        :param data_dict: DataDict({'inputs', 'predictions', ...}), where:

-            - inputs: expected inputs [BATCH_SIZE x INPUT_SIZE],
-            - predictions: returned output with predictions (log_probs) [BATCH_SIZE x NUM_CLASSES]
+            - inputs: expected inputs [BATCH_SIZE x ... x INPUT_SIZE],
+            - predictions: returned output with predictions (log_probs) [BATCH_SIZE x ... x NUM_CLASSES]
        """

        # Get inputs.
        x = data_dict[self.key_inputs]

+        # Check that the input has the number of dimensions we expect.
+        assert len(x.shape) == self.dimensions, \
+            "Expected " + str(self.dimensions) + " dimensions for input, got " + str(len(x.shape)) \
+            + " instead. Check the number of dimensions in the config."
+
+        # Flatten all leading dimensions so the FFN is applied over the last (Input Size) dimension.
+        origin_shape = x.shape
+        x = x.contiguous().view(-1, origin_shape[-1])

        # Propagate inputs through all but last layer.
        for layer in self.layers[:-1]:
            x = layer(x)
@@ -147,5 +158,8 @@
        if self.use_logsoftmax:
            x = self.log_softmax(x)

+        # Restore all input dimensions except the last one (which was resized by the FFN).
+        x = x.view(*origin_shape[0:self.dimensions - 1], -1)

        # Add predictions to datadict.
        data_dict.extend({self.key_predictions: x})
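For reference, a self-contained sketch of the flatten-apply-restore pattern the new forward() uses. The sizes are hypothetical and a single nn.Linear stands in for the component's layer stack:

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
batch_size, seq_len, input_size, prediction_size = 4, 6, 16, 3
dimensions = 3

# A single linear layer stands in for the component's layer stack.
layer = nn.Linear(input_size, prediction_size)
x = torch.randn(batch_size, seq_len, input_size)
assert len(x.shape) == dimensions

# Flatten every leading dimension into one, keeping the last (Input Size).
origin_shape = x.shape
flat = x.contiguous().view(-1, origin_shape[-1])    # [24, 16]

flat = layer(flat)                                  # [24, 3]

# Restore the leading dimensions; the last one is now PREDICTION_SIZE.
out = flat.view(*origin_shape[0:dimensions - 1], -1)
assert out.shape == (batch_size, seq_len, prediction_size)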