made SortPooling optional
LeviBorodenko committed Jan 27, 2020
1 parent 4f5d911 commit 76c1acd
Showing 4 changed files with 22 additions and 21 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
@@ -22,3 +22,9 @@ Version 0.3
- added flatten_signals option to DeepGraphConvolution
- publish to pypi and github.
- publish docs.

Version 0.3.1
=============

- Make SortPooling optional in DeepGraphConvolution
- Host docs on github pages
1 change: 1 addition & 0 deletions README.md
@@ -35,6 +35,7 @@ Initiate it with the following parameters:
|`flatten_signals` (default: False) | If `True`, flattens the last 2 dimensions of the output tensor into 1|
|`attention_heads` (default: None) | If given, then instead of using <a href="https://www.codecogs.com/eqnedit.php?latex=D^{-1}E" target="_blank"><img src="https://latex.codecogs.com/gif.latex?D^{-1}E" title="D^{-1}E" /></a> as the transition matrix inside the graph convolutions, we use an attention-based transition matrix, with `dgcnn.attention.AttentionMechanism` as the internal attention mechanism. This sets the number of attention heads used.|
|`attention_units` (default: None) | Also needs to be provided if `attention_heads` is set. This is the size of the internal embedding used by the attention mechanism.|
|`use_sortpooling` (default: True) | Whether to apply SortPooling at the end of the procedure. If `False`, we simply return the concatenated graph convolution outputs.|

Thus, if we have non-temporal graph signals with 10 nodes and 5 features each, and we would like to apply a DGCNN containing 3 graph convolutions with hidden feature dimensions of 10, 5 and 2, followed by SortPooling that keeps the 5 most relevant nodes, then we would run
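A minimal sketch of such a call. Assumptions not confirmed by this diff: that `DeepGraphConvolution` takes the list of hidden feature dimensions as its first argument and the number of kept nodes as `k` (only `self.k` and the keyword parameters above are visible here), and that the layer is called on an `(X, E)` tuple of signals and transition structure:

```python
import tensorflow as tf
from dgcnn.components import DeepGraphConvolution

# one batch of non-temporal graph signals: 10 nodes, 5 features each
X = tf.random.normal((1, 10, 5))
# a (here random) transition structure E over the 10 nodes
E = tf.random.uniform((1, 10, 10))

# 3 graph convolutions with hidden feature dimensions 10, 5 and 2,
# followed by SortPooling keeping the k = 5 most relevant nodes
# (positional signature assumed, see lead-in above)
dgcnn_layer = DeepGraphConvolution([10, 5, 2], k=5)

# expected shape (1, 5, 17): 5 kept nodes, 10 + 5 + 2 = 17 features
Z_pooled = dgcnn_layer((X, E))
```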

19 changes: 0 additions & 19 deletions README.rst

This file was deleted.

17 changes: 15 additions & 2 deletions src/dgcnn/components.py
@@ -71,6 +71,12 @@ class DeepGraphConvolution(layers.Layer):
- (default: None).
use_sortpooling (bool):
- If False, won't apply SortPooling at the end of the procedure.
- (default: True)
Inputs:
- X (tf.Tensor):
@@ -149,6 +155,7 @@ def __init__(
attention_heads: int = None,
attention_units: int = None,
flatten_signals: bool = False,
use_sortpooling: bool = True,
**kwargs
):
super(DeepGraphConvolution, self).__init__()
@@ -176,6 +183,7 @@ def __init__(
self.attention_heads = attention_heads
self.attention_units = attention_units
self.flatten_signals = flatten_signals
self.use_sortpooling = use_sortpooling

# save kwargs to pass them to the graphconv layers
self.kwargs = kwargs
@@ -193,7 +201,8 @@ def build(self, input_shape):
self.convolutions.append(layer)

# initiating SortPooling layer
self.SortPooling = SortPooling(self.k)
if self.use_sortpooling:
self.SortPooling = SortPooling(self.k)

# create AttentionMechanism if required
if self.use_attention:
@@ -234,7 +243,11 @@ def call(self, inputs):
# concat them to (N x sum(c_i)) signal
Z = tf.concat(rec_conv_signals, axis=-1)

# now apply SortPooling
# return the unpooled signals if SortPooling is disabled
if not self.use_sortpooling:
return Z

# apply SortPooling
Z_pooled = self.SortPooling(Z)

if not self.flatten_signals:

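To illustrate the new flag, a sketch of the unpooled path under the same assumed constructor and call signatures as the README sketch above: with `use_sortpooling=False`, `call` returns the concatenated convolution signal of shape (N x sum(c_i)) directly, so no nodes are dropped.

```python
import tensorflow as tf
from dgcnn.components import DeepGraphConvolution

X = tf.random.normal((1, 10, 5))    # 10 nodes, 5 features each
E = tf.random.uniform((1, 10, 10))  # transition structure

# use_sortpooling=False skips SortPooling, so the layer returns the
# concatenated graph convolution outputs Z for all 10 nodes
# (argument names besides use_sortpooling are assumed)
layer = DeepGraphConvolution([10, 5, 2], k=5, use_sortpooling=False)

Z = layer((X, E))  # expected shape (1, 10, 17), since 10 + 5 + 2 = 17
```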