using DataStructures: nlargest

@doc raw"""
    GlobalPool(aggr)

Global pooling layer for graph neural networks.
It takes a graph and the node features as inputs
and performs the operation

```math
\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i
```

where ``V`` is the set of nodes of the input graph and
the type of aggregation represented by ``\square`` is selected by the `aggr` argument.
Commonly used aggregations are `mean`, `max`, and `+`.

See also [`reduce_nodes`](@ref).

# Examples

```julia
using Flux, GraphNeuralNetworks, Graphs
using Statistics: mean

pool = GlobalPool(mean)

g = GNNGraph(erdos_renyi(10, 4))
X = rand(32, 10)
pool(g, X) # => 32x1 matrix

g = Flux.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
X = rand(32, 50)
pool(g, X) # => 32x5 matrix
```
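
A `GlobalPool` layer is often used as the final readout of a graph-level model,
for example inside a `GNNChain`. A minimal sketch (the layer sizes are arbitrary):

```julia
using Flux, GraphNeuralNetworks
using Statistics: mean

nin, nh, nout = 32, 64, 3
model = GNNChain(GCNConv(nin => nh, relu),
                 GlobalPool(mean),   # one nh-dimensional embedding per graph
                 Dense(nh, nout))
```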
"""

struct GlobalPool{F} <: GNNLayer
    aggr::F
end

function (l::GlobalPool)(g::GNNGraph, x::AbstractArray)
    return reduce_nodes(l.aggr, g, x)
end
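
# Called on a graph alone, the layer pools `node_features(g)` and returns a new
# graph storing the result as a graph-level feature in `gdata`.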
(l::GlobalPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))

@doc raw"""
    GlobalAttentionPool(fgate, ffeat=identity)

Global soft attention layer from the [Gated Graph Sequence Neural
Networks](https://arxiv.org/abs/1511.05493) paper

```math
\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)
```

where the coefficients ``\alpha_i`` are given by a [`softmax_nodes`](@ref)
operation:

```math
\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}}
                {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.
```

# Arguments

- `fgate`: The function ``f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}``.
           It is typically expressed by a neural network.

- `ffeat`: The function ``f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}``.
           It is typically expressed by a neural network.

# Examples

```julia
using Flux, GraphNeuralNetworks, Graphs

chin = 6
chout = 5

fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
pool = GlobalAttentionPool(fgate, ffeat)

g = Flux.batch([GNNGraph(random_regular_graph(10, 4),
                         ndata=rand(Float32, chin, 10))
                for i in 1:3])

u = pool(g, g.ndata.x)

@assert size(u) == (chout, g.num_graphs)
```
"""
struct GlobalAttentionPool{G,F}
    fgate::G
    ffeat::F
end

@functor GlobalAttentionPool
GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

function (l::GlobalAttentionPool)(g::GNNGraph, x::AbstractArray)
    α = softmax_nodes(g, l.fgate(x))   # attention coefficients, normalized within each graph
    feats = α .* l.ffeat(x)            # weight the (transformed) node features
    u = reduce_nodes(+, g, feats)      # sum over the nodes of each graph
    return u
end

(l::GlobalAttentionPool)(g::GNNGraph) = GNNGraph(g, gdata=l(g, node_features(g)))
"""
TopKPool(adj, k, in_channel)
Top-k pooling layer.
# Arguments
- `adj`: Adjacency matrix of a graph.
- `k`: Top-k nodes are selected to pool together.
- `in_channel`: The dimension of input channel.
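
# Examples

A minimal usage sketch (the graph size and feature dimensions are arbitrary):

```julia
using Flux, GraphNeuralNetworks

adj = rand(0:1, 10, 10)      # dense adjacency matrix of a 10-node graph
pool = TopKPool(adj, 4, 7)   # keep the 4 top-scoring nodes; 7 input channels
X = rand(Float32, 7, 10)     # one 7-dimensional feature vector per node
Y = pool(X)                  # => 7×4 matrix with the gated features of the selected nodes
```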
"""

struct TopKPool{T,S}
    A::AbstractMatrix{T}
    k::Int
    p::AbstractVector{S}
    Ã::AbstractMatrix{T}
end

function TopKPool(adj::AbstractMatrix, k::Int, in_channel::Int; init=glorot_uniform)
    TopKPool(adj, k, init(in_channel), similar(adj, k, k))
end

function (t::TopKPool)(X::AbstractArray)
    # Score each node by projecting its features onto the learnable vector `p`.
    y = t.p' * X / norm(t.p)
    idx = topk_index(y, t.k)
    # Pooled adjacency matrix and sigmoid-gated features of the selected nodes.
    t.Ã .= view(t.A, idx, idx)
    X_ = view(X, :, idx) .* σ.(view(y, idx)')
    return X_
end

function topk_index(y::AbstractVector, k::Int)
    v = nlargest(k, y)   # the k largest scores
    return collect(1:length(y))[y .>= v[end]]
end

# `p' * X` above yields a row vector (an `Adjoint`), so forward it to the vector method.
topk_index(y::Adjoint, k::Int) = topk_index(y', k)