## Reference

- [Basic of Graph Convolution Network](https://www.youtube.com/watch?v=YL1jGgcY78U)
- [Molecular Property(logP) Prediction with GCN](https://www.youtube.com/watch?v=htTt4iPJqMg)

## Graph Neural Network
- GCN (Graph Convolutional Network) Layer
- Skip Connection
- GCN Block
- Readout

In [12]:
import torch
import torch.nn as nn

import argparse

In [13]:
# Create a parser with no declared options and parse an explicitly empty
# argument list, so nothing is read from sys.argv. The result is an empty
# Namespace; presumably hyper-parameters are attached to it as attributes
# elsewhere in the notebook -- confirm.
parser = argparse.ArgumentParser()
args = parser.parse_args(args=[])

### GCNLayer

In [3]:
class GCNLayer(nn.Module):
    """Single graph-convolution layer.

    The forward pass computes ``adj @ linear(x)`` and then optionally applies
    batch normalisation and an activation.

    Args:
        in_dim: Size of the last dimension of the input features.
        out_dim: Size of the last dimension of the output features.
        n_atom: ``num_features`` passed to ``nn.BatchNorm1d``. Presumably the
            number of nodes (atoms) per graph, i.e. dimension 1 of a
            ``(batch, n_atom, dim)`` input -- TODO confirm against the data
            pipeline.
        act: Optional callable applied to the output; ``None`` disables it.
        bn: Truthy to apply batch normalisation in ``forward``.
    """

    def __init__(self, in_dim, out_dim, n_atom, act=None, bn=None):
        super().__init__()

        self.use_bn = bn
        self.linear = nn.Linear(in_dim, out_dim)
        nn.init.xavier_uniform_(self.linear.weight)
        # NOTE: instantiated unconditionally; it is only *applied* in
        # ``forward`` when ``use_bn`` is truthy.
        self.bn = nn.BatchNorm1d(n_atom)
        self.activation = act

    def forward(self, x, adj):
        """Project ``x``, aggregate it with ``adj`` and post-process.

        Args:
            x: Feature tensor whose last dimension is ``in_dim``.
            adj: Tensor left-multiplied onto the projected features
                (presumably the adjacency matrix).

        Returns:
            ``(out, adj)``: the transformed features and the ``adj`` argument
            passed through unchanged so layers can be chained.
        """
        out = torch.matmul(adj, self.linear(x))
        if self.use_bn:
            out = self.bn(out)
        # Identity check: ``!= None`` would invoke a custom ``__ne__``.
        if self.activation is not None:
            out = self.activation(out)
        return out, adj

### Skip Connection

In [4]:
class SkipConnection(nn.Module):
    """Residual addition of a block's input onto its output.

    When ``in_dim`` and ``out_dim`` differ, the input is first passed through
    a bias-free linear layer so the two tensors can be summed.

    Args:
        in_dim: Size of the last dimension of the block input.
        out_dim: Size of the last dimension of the block output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        # Always created; only used in ``forward`` when the sizes differ.
        self.linear = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, in_x, out_x):
        """Return ``in_x + out_x``, projecting ``in_x`` first if sizes differ."""
        needs_projection = self.in_dim != self.out_dim
        shortcut = self.linear(in_x) if needs_projection else in_x
        return shortcut + out_x

### GCN Block

In [5]:
class GCNBlock(nn.Module):
    """Stack of ``GCNLayer`` modules with an optional skip connection.

    Args:
        n_layer: Number of ``GCNLayer`` modules in the block.
        in_dim: Feature size fed to the first layer.
        hidden_dim: Feature size between layers.
        out_dim: Feature size produced by the last layer.
        n_atom: Forwarded to every ``GCNLayer``.
        bn: Forwarded to every ``GCNLayer`` as its ``bn`` flag.
        sc: ``"sc"`` to add a ``SkipConnection(in_dim, out_dim)`` around the
            stack, ``"no"`` for none.
        act: Optional activation. Given to every layer except the last, and
            applied once more to the block output.

    Raises:
        ValueError: If ``sc`` is neither ``"sc"`` nor ``"no"``.
    """

    def __init__(self, n_layer, in_dim, hidden_dim, out_dim, n_atom, bn=True, sc="sc", act=None):
        # Must be *called*: nn.Module refuses submodule assignment before
        # its own __init__ has run.
        super().__init__()

        last = n_layer - 1
        self.layers = nn.ModuleList()
        for i in range(n_layer):
            self.layers.append(GCNLayer(
                in_dim if i == 0 else hidden_dim,
                out_dim if i == last else hidden_dim,
                n_atom,
                # The last layer stays linear; the block applies ``act``
                # itself after the skip connection.
                None if i == last else act,
                bn,
            ))

        if sc == "sc":
            self.sc = SkipConnection(in_dim, out_dim)
        elif sc == "no":
            self.sc = None
        else:
            # Explicit raise instead of ``assert``, which is stripped under -O.
            raise ValueError("Wrong sc type.")

        self.activation = act

    def forward(self, x, adj):
        """Run the layer stack, then the skip connection and activation.

        Args:
            x: Input feature tensor.
            adj: Tensor passed to every layer alongside the features.

        Returns:
            ``(out, adj)``: the block output and the ``adj`` returned by the
            last layer (the input ``adj`` itself when there are no layers).
        """
        identity = x
        # Seeding ``out`` with ``x`` keeps an empty stack well-defined.
        out = x
        for layer in self.layers:
            out, adj = layer(out, adj)
        if self.sc is not None:
            out = self.sc(identity, out)
        if self.activation is not None:
            out = self.activation(out)
        return out, adj

### Readout

In [6]:
class Readout(nn.Module):
    """Linear projection followed by a sum over dimension 1.

    Args:
        in_dim: Size of the last dimension of the input features.
        out_dim: Size of the last dimension after the linear projection.
        act: Optional callable applied to the summed output; ``None``
            disables it.
    """

    def __init__(self, in_dim, out_dim, act=None):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.linear = nn.Linear(self.in_dim, self.out_dim)
        nn.init.xavier_uniform_(self.linear.weight)

        self.activation = act

    def forward(self, x):
        """Project ``x`` and collapse dimension 1 by summation.

        Args:
            x: Feature tensor whose last dimension is ``in_dim``. Presumably
                shaped ``(batch, n_atom, in_dim)`` so the sum runs over the
                nodes of each graph -- TODO confirm.

        Returns:
            The projected tensor with dimension 1 summed out, after the
            optional activation.
        """
        out = torch.sum(self.linear(x), 1)
        if self.activation is not None:
            out = self.activation(out)
        return out

### Predictor

In [7]:
class Predictor(nn.Module):
    """Fully connected layer with an optional activation.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        act: Optional callable applied to the linear output; ``None``
            disables it.
    """

    def __init__(self, in_dim, out_dim, act=None):
        super().__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim

        self.linear = nn.Linear(in_dim, out_dim)
        nn.init.xavier_uniform_(self.linear.weight)

        self.activation = act

    def forward(self, x):
        """Return ``act(linear(x))``, or ``linear(x)`` when no activation is set."""
        out = self.linear(x)
        # Identity check: ``!= None`` would invoke a custom ``__ne__``.
        if self.activation is not None:
            out = self.activation(out)
        return out

### GCN Network

In [9]:
class GCNNet(nn.Module):
    """Full network: GCN blocks, a readout and three predictor layers.

    Args:
        args: Object exposing the hyper-parameters as attributes. The
            constructor reads ``n_block``, ``n_layer``, ``in_dim``,
            ``hidden_dim``, ``n_atom``, ``bn``, ``sc``, ``act``,
            ``pred_dim1``, ``pred_dim2``, ``pred_dim3`` and ``out_dim``.
    """

    def __init__(self, args):
        super().__init__()

        self.blocks = nn.ModuleList()
        for i in range(args.n_block):
            # Only the first block consumes the raw input features; every
            # later block maps hidden_dim -> hidden_dim.
            self.blocks.append(GCNBlock(args.n_layer,
                                        args.in_dim if i == 0 else args.hidden_dim,
                                        args.hidden_dim,
                                        args.hidden_dim,
                                        args.n_atom,
                                        args.bn,
                                        args.sc,
                                        args.act))
        self.readout = Readout(args.hidden_dim,
                               args.pred_dim1,
                               act=nn.ReLU())
        self.pred1 = Predictor(args.pred_dim1,
                               args.pred_dim2,
                               act=nn.ReLU())
        self.pred2 = Predictor(args.pred_dim2,
                               args.pred_dim3,
                               act=nn.Tanh())
        # Final layer has no activation.
        self.pred3 = Predictor(args.pred_dim3,
                               args.out_dim)

    def forward(self, x, adj):
        """Run the blocks, the readout and the three predictors in order.

        Args:
            x: Input feature tensor given to the first block.
            adj: Tensor passed to every block alongside the features.

        Returns:
            The output of ``pred3``.
        """
        # Seeding ``out`` with ``x`` avoids an unbound local when the
        # network has no blocks.
        out = x
        for block in self.blocks:
            out, adj = block(out, adj)
        out = self.readout(out)
        out = self.pred1(out)
        out = self.pred2(out)
        out = self.pred3(out)
        return out