Skip to content

Commit

Permalink
Merge pull request #123 from yuehhua/refactor
Browse files Browse the repository at this point in the history
Refactor graph convolutional layers
  • Loading branch information
yuehhua committed Nov 5, 2020
2 parents 3a46a95 + 56c5222 commit 14468b1
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ end

message(g::GraphConv, x_i, x_j::AbstractVector, e_ij) = g.weight2 * x_j
update(g::GraphConv, m::AbstractVector, x::AbstractVector) = g.σ.(g.weight1*x .+ m .+ g.bias)
function (g::GraphConv)(X::AbstractMatrix)
@assert has_graph(g.fg) "A GraphConv created without a graph must be given a FeaturedGraph as an input."
fg = FeaturedGraph(graph(g.fg), X)
fg_ = g(fg)
node_feature(fg_)
function (gc::GraphConv)(X::AbstractMatrix)
@assert has_graph(gc.fg) "A GraphConv created without a graph must be given a FeaturedGraph as an input."
g = graph(gc.fg)
_, X = propagate(gc, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, :add)
X
end
(g::GraphConv)(fg::FeaturedGraph) = propagate(g, fg, :add)

Expand Down Expand Up @@ -305,11 +305,11 @@ function update_batch_vertex(g::GATConv, M::AbstractMatrix, X::AbstractMatrix, u
return M .+ g.bias
end

function (g::GATConv)(X::AbstractMatrix)
@assert has_graph(g.fg) "A GATConv created without a graph must be given a FeaturedGraph as an input."
fg = FeaturedGraph(graph(g.fg), X)
fg_ = g(fg)
node_feature(fg_)
function (gat::GATConv)(X::AbstractMatrix)
@assert has_graph(gat.fg) "A GATConv created without a graph must be given a FeaturedGraph as an input."
g = graph(gat.fg)
_, X = propagate(gat, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, :add)
X
end
(g::GATConv)(fg::FeaturedGraph) = propagate(g, fg, :add)

Expand Down Expand Up @@ -366,26 +366,29 @@ end
message(g::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
update(g::GatedGraphConv, m::AbstractVector, x) = m

function (g::GatedGraphConv)(X::AbstractMatrix)
@assert has_graph(g.fg) "A GraphConv created without a graph must be given a FeaturedGraph as an input."
fg = FeaturedGraph(graph(g.fg), X)
fg_ = g(fg)
node_feature(fg_)
function (ggc::GatedGraphConv)(X::AbstractMatrix{T}) where {T<:Real}
@assert has_graph(ggc.fg) "A GraphConv created without a graph must be given a FeaturedGraph as an input."
ggc(adjacency_list(ggc.fg), X)
end

function (g::GatedGraphConv{V,T})(fg::FeaturedGraph) where {V,T<:Real}
H = node_feature(fg)
function (ggc::GatedGraphConv{V,T})(fg::FeaturedGraph) where {V,T<:Real}
g = graph(fg)
H = ggc(adjacency_list(g), node_feature(fg))
FeaturedGraph(g, H)
end

function (ggc::GatedGraphConv)(adj::AbstractVector{T}, X::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
H = X
m, n = size(H)
@assert (m <= g.out_ch) "number of input features must less or equals to output features."
(m < g.out_ch) && (H = vcat(H, zeros(T, g.out_ch - m, n)))

for i = 1:g.num_layers
M = view(g.weight, :, :, i) * H
fg_ = propagate(g, FeaturedGraph(graph(fg), M), g.aggr)
M = node_feature(fg_)
H, _ = g.gru(H, M) # BUG: FluxML/Flux.jl#1381
@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
(m < ggc.out_ch) && (H = vcat(H, zeros(S, ggc.out_ch - m, n)))

for i = 1:ggc.num_layers
M = view(ggc.weight, :, :, i) * H
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(adj)), M, :add)
H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381
end
FeaturedGraph(graph(fg), H)
H
end

function Base.show(io::IO, l::GatedGraphConv)
Expand Down Expand Up @@ -431,9 +434,9 @@ update(e::EdgeConv, m::AbstractVector, x) = m

function (e::EdgeConv)(X::AbstractMatrix)
@assert has_graph(e.fg) "A EdgeConv created without a graph must be given a FeaturedGraph as an input."
fg = FeaturedGraph(graph(e.fg), X)
fg_ = e(fg)
node_feature(fg_)
g = graph(e.fg)
_, X = propagate(e, adjacency_list(g), Fill(0.f0, 0, ne(g)), X, e.aggr)
X
end

(e::EdgeConv)(fg::FeaturedGraph) = propagate(e, fg, e.aggr)
Expand Down

0 comments on commit 14468b1

Please sign in to comment.