2 changes: 1 addition & 1 deletion Project.toml
@@ -33,7 +33,7 @@ NNlib = "0.7"
 NNlibCUDA = "0.1"
 Reexport = "1.1"
 Zygote = "0.6"
-julia = "1.6"
+julia = "1.6 - 1.7"
 
 [extras]
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4 changes: 4 additions & 0 deletions src/GeometricFlux.jl
@@ -15,6 +15,9 @@ using LightGraphs
 using Zygote
 
 export
+    # layers/graphlayers
+    AbstractGraphLayer,
+
     # layers/gn
     GraphNet,
 
@@ -55,6 +58,7 @@ include("datasets.jl")
 
 include("utils.jl")
 
+include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")
 
60 changes: 26 additions & 34 deletions src/layers/conv.jl
@@ -16,7 +16,7 @@ Graph convolutional layer.
 The input to the layer is a node feature array `X`
 of size `(num_features, num_nodes)`.
 """
-struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph}
+struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
     weight::A
     bias::B
     σ::F
@@ -42,7 +42,6 @@ function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
 end
 
 (l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GCNConv)(x::AbstractMatrix) = l(l.fg, x)
 
 function Base.show(io::IO, l::GCNConv)
     out, in = size(l.weight)
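Side note for readers of this diff: after this change the bare-matrix entry point is no longer defined per layer but inherited from `AbstractGraphLayer` (added below in `src/layers/graphlayers.jl`). A minimal sketch, assuming a toy 4-node cycle graph and made-up feature sizes, none of which come from this PR:

```julia
using Flux, GeometricFlux

# Assumed toy inputs, for illustration only
adj = [0 1 0 1; 1 0 1 0; 0 1 0 1; 1 0 1 0]
fg  = FeaturedGraph(adj)
gcn = GCNConv(fg, 3=>8, relu)

X = rand(Float32, 3, 4)  # (num_features, num_nodes)
Y = gcn(X)               # dispatches via (l::AbstractGraphLayer)(x) = l(l.fg, x); size (8, 4)
```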
@@ -66,7 +65,7 @@ Chebyshev spectral graph convolutional layer.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph}
+struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
     weight::A
     bias::B
     fg::S
@@ -104,7 +103,6 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
 end
 
 (l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::ChebConv)(x::AbstractMatrix) = l(l.fg, x)
 
 function Base.show(io::IO, l::ChebConv)
     out, in, k = size(l.weight)
@@ -164,7 +162,6 @@ function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
 end
 
 (l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)
 
 function Base.show(io::IO, l::GraphConv)
     in_channel = size(l.weight1, ndims(l.weight1))
@@ -272,7 +269,6 @@ function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
 end
 
 (l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)
 
 function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))
@@ -340,7 +336,6 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
 end
 
 (l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)
 
 
 function Base.show(io::IO, l::GatedGraphConv)
@@ -383,7 +378,6 @@ function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
 end
 
 (l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)
 
 function Base.show(io::IO, l::EdgeConv)
     print(io, "EdgeConv(", l.nn)
@@ -393,34 +387,34 @@ end
 
 
 """
-    GINConv([fg,] nn, [eps])
+    GINConv([fg,] nn, [eps=0])
 
 Graph Isomorphism Network.
 
 # Arguments
 
 - `fg`: Optionally pass in a FeaturedGraph as input.
 - `nn`: A neural network/layer.
-- `eps`: Weighting factor. Default 0.
+- `eps`: Weighting factor.
 
 This layer follows the definition in the original paper,
 Xu et al. (2018), https://arxiv.org/abs/1810.00826.
 """
-struct GINConv{V<:AbstractFeaturedGraph,R<:Real} <: MessagePassing
-    fg::V
+struct GINConv{G,R} <: MessagePassing
+    fg::G
     nn
     eps::R
-end
 
-function GINConv(fg::AbstractFeaturedGraph, nn; eps=0f0)
-    GINConv(fg, nn, eps)
+    function GINConv(fg::G, nn, eps::R=0f0) where {G<:AbstractFeaturedGraph,R<:Real}
+        new{G,R}(fg, nn, eps)
+    end
 end
 
-function GINConv(nn; eps=0f0)
+function GINConv(nn, eps::Real=0f0)
     GINConv(NullGraph(), nn, eps)
 end
 
-Flux.trainable(g::GINConv) = (fg=g.fg,nn=g.nn)
+Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
 
 message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
 update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
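A hedged sketch of the constructor change: `eps` is now positional on both methods, and the inner constructor pins `GINConv{G,R}` to the argument types. The network and graph below are illustrative assumptions, not taken from the PR:

```julia
using Flux, GeometricFlux

adj = [0 1; 1 0]                               # assumed toy graph
nn  = Chain(Dense(3, 8, relu), Dense(8, 8))    # assumed network

gin  = GINConv(FeaturedGraph(adj), nn, 0.1f0)  # previously GINConv(fg, nn, eps=0.1f0)
gin0 = GINConv(nn)                             # NullGraph variant; eps defaults to 0f0
```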
@@ -434,12 +428,11 @@ function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
     X
 end
 
-(l::GINConv)(x::AbstractMatrix) = l(l.fg, x)
 (l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
 
 
 """
-    CGConv([fg,] (node_dim, edge_dim), out, init)
+    CGConv([fg,] (node_dim, edge_dim); init=glorot_uniform, bias=true, as_edge=false)
 
 Crystal Graph Convolutional network. Uses both node and edge features.
 
@@ -451,18 +444,17 @@ Crystal Graph Convolutional network. Uses both node and edge features.
 - `out`: Dimensionality of the output features.
 - `init`: Initialization algorithm for each of the weight matrices
 - `bias`: Whether or not to learn an additive bias parameter.
+- `as_edge`: Whether a single-matrix call `CGConv(M)` treats `M` as edge features (`true`) or node features (`false`).
 
 # Usage
 
 You can call `CGConv` in several different ways:
 
 - Pass a FeaturedGraph: `CGConv(fg)`, returns `FeaturedGraph`
 - Pass both node and edge features: `CGConv(X, E)`
-- Pass one matrix, which can either be node features or edge features: `CGConv(M; edge)`:
-  `edge` is default false, meaning that `M` denotes node features.
+- Pass a single matrix, `CGConv(M)`, interpreted as node or edge features according to the `as_edge` keyword argument.
 """
-struct CGConv{V <: AbstractFeaturedGraph, T,
-              A <: AbstractMatrix{T}, B} <: MessagePassing
+struct CGConv{E, V<:AbstractFeaturedGraph, A<:AbstractMatrix, B} <: MessagePassing
     fg::V
     Wf::A
     Ws::A
@@ -472,18 +464,20 @@ end
 
 @functor CGConv
 
-function CGConv(fg::AbstractFeaturedGraph, dims::NTuple{2,Int};
-                init=glorot_uniform, bias=true)
+function CGConv(fg::G, dims::NTuple{2,Int};
+                init=glorot_uniform, bias=true, as_edge=false) where {G<:AbstractFeaturedGraph}
     node_dim, edge_dim = dims
     Wf = init(node_dim, 2*node_dim + edge_dim)
     Ws = init(node_dim, 2*node_dim + edge_dim)
     bf = Flux.create_bias(Wf, bias, node_dim)
     bs = Flux.create_bias(Ws, bias, node_dim)
-    CGConv(fg, Wf, Ws, bf, bs)
+    T, S = typeof(Wf), typeof(bf)
+
+    CGConv{as_edge,G,T,S}(fg, Wf, Ws, bf, bs)
 end
 
-function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true)
-    CGConv(NullGraph(), dims; init=init, bias=bias)
+function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true, as_edge=false)
+    CGConv(NullGraph(), dims; init=init, bias=bias, as_edge=as_edge)
 end
 
 message(c::CGConv,
@@ -503,10 +497,8 @@ end
 (l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf=l(fg, node_feature(fg),
                                                         edge_feature(fg)),
                                                ef=edge_feature(fg))
-(l::CGConv)(M::AbstractMatrix; as_edge=false) =
-    if as_edge
-        l(l.fg, node_feature(l.fg), M)
-    else
-        l(l.fg, M, edge_feature(l.fg))
-    end
 
 (l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E)
 
+(l::CGConv{true})(M::AbstractMatrix) = l(l.fg, node_feature(l.fg), M)
+(l::CGConv{false})(M::AbstractMatrix) = l(l.fg, M, edge_feature(l.fg))
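The net effect of the `CGConv` changes: `as_edge` moves from a runtime keyword on the call to a type parameter fixed at construction, so the single-matrix method is selected by dispatch on `CGConv{true}`/`CGConv{false}` instead of an `if` branch. A sketch under assumed shapes (5-dim node features on 4 nodes, 2-dim edge features on 8 directed edges; none of this is from the PR):

```julia
using GeometricFlux

adj = [0 1 0 1; 1 0 1 0; 0 1 0 1; 1 0 1 0]   # assumed toy graph
fg  = FeaturedGraph(adj, nf=rand(Float32, 5, 4), ef=rand(Float32, 2, 8))

cgc_n = CGConv(fg, (5, 2))                   # CGConv{false,...}
cgc_e = CGConv(fg, (5, 2), as_edge=true)     # CGConv{true,...}

cgc_n(rand(Float32, 5, 4))  # treated as node features: l(l.fg, M, edge_feature(l.fg))
cgc_e(rand(Float32, 2, 8))  # treated as edge features: l(l.fg, node_feature(l.fg), M)
```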
2 changes: 1 addition & 1 deletion src/layers/gn.jl
@@ -10,7 +10,7 @@ aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2))
 aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2))
 aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2))
 
-abstract type GraphNet end
+abstract type GraphNet <: AbstractGraphLayer end
 
 @inline update_edge(gn::GraphNet, e, vi, vj, u) = e
 @inline update_vertex(gn::GraphNet, ē, vi, u) = vi
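For reference, the `aggregate` helpers visible in this hunk reduce a message matrix column-wise (one entry per node). A small sketch; these are internal helpers, so the calls are qualified, and the values are illustrative:

```julia
using GeometricFlux
using Statistics: mean

X = Float32[1 2 3;
            4 5 6]

GeometricFlux.aggregate(max, X)   # vec(maximum(X, dims=2)) == Float32[3, 6]
GeometricFlux.aggregate(mean, X)  # vec(mean(X, dims=2))    == Float32[2, 5]
```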
3 changes: 3 additions & 0 deletions src/layers/graphlayers.jl
@@ -0,0 +1,3 @@
+abstract type AbstractGraphLayer end
+
+(l::AbstractGraphLayer)(x::AbstractMatrix) = l(l.fg, x)
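This three-line file carries the whole refactor: any subtype of `AbstractGraphLayer` that stores its graph in an `fg` field inherits the bare-matrix call path, which is why the per-layer `(l::XConv)(x::AbstractMatrix)` methods above could be deleted. A hypothetical sketch (the layer below is not part of the PR):

```julia
using GeometricFlux

struct MyLayer{S<:AbstractFeaturedGraph, A<:AbstractMatrix} <: AbstractGraphLayer
    fg::S
    W::A
end

# Graph-aware path; the bare-matrix path l(x) = l(l.fg, x) is inherited.
(l::MyLayer)(fg::FeaturedGraph, x::AbstractMatrix) = l.W * x
```

Note the implicit contract: the fallback only works if every subtype actually has an `fg` field.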
10 changes: 5 additions & 5 deletions test/cuda/conv.jl
@@ -6,7 +6,7 @@ N = 4
 adj = [0 1 0 1;
        1 0 1 0;
        0 1 0 1;
-       1 0 1 0]
+       1 0 1 0] |> gpu
 
 fg = FeaturedGraph(adj)
 
@@ -15,7 +15,7 @@ fg = FeaturedGraph(adj)
 gc = GCNConv(fg, in_channel=>out_channel) |> gpu
 @test size(gc.weight) == (out_channel, in_channel)
 @test size(gc.bias) == (out_channel,)
-@test Array(adjacency_matrix(gc.fg)) == adj
+@test collect(graph(gc.fg)) == Array(adj)
 
 X = rand(in_channel, N) |> gpu
 Y = gc(X)
@@ -35,7 +35,7 @@ fg = FeaturedGraph(adj)
 cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu
 @test size(cc.weight) == (out_channel, in_channel, k)
 @test size(cc.bias) == (out_channel,)
-@test Array(adjacency_matrix(cc.fg)) == adj
+@test collect(graph(cc.fg)) == Array(adj)
 @test cc.k == k
 @test cc.in_channel == in_channel
 @test cc.out_channel == out_channel
@@ -44,8 +44,8 @@ fg = FeaturedGraph(adj)
 Y = cc(X)
 @test size(Y) == (out_channel, N)
 
-# g = Zygote.gradient(x -> sum(cc(x)), X)[1]
-# @test size(g) == size(X)
+g = Zygote.gradient(x -> sum(cc(x)), X)[1]
+@test size(g) == size(X)
 
 # g = Zygote.gradient(model -> sum(model(X)), cc)[1]
 # @test size(g.weight) == size(cc.weight)
2 changes: 1 addition & 1 deletion test/layers/conv.jl
@@ -357,7 +357,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
 eps = 0.001
 
 @testset "layer with graph" begin
-    gc = GINConv(FeaturedGraph(adj), nn, eps=eps)
+    gc = GINConv(FeaturedGraph(adj), nn, eps)
     @test size(gc.nn.layers[1].weight) == (out_channel, in_channel)
     @test size(gc.nn.layers[1].bias) == (out_channel, )
     @test graph(gc.fg) === adj