In [None]:
import numpy as np
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
# Blanket-suppress warnings to keep the training logs readable.
# NOTE(review): this also hides genuinely useful warnings (e.g. deprecations
# from dgl/torch); consider narrowing with `category=` / `module=`.
warnings.filterwarnings("ignore")

# Fix random seeds so weight initialisation (and hence every reported metric)
# is reproducible under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# GeniePath Model

## Adaptive Breadth

In [None]:
from dgl.nn import GATConv

In [None]:
class AdaptiveBreadthLayer(nn.Module):
    """Adaptive-breadth step of GeniePath: attention-weighted neighbour aggregation.

    Wraps a DGL ``GATConv`` (with ``tanh`` activation) and averages its output
    over dimension 1, which collapses the per-head axis so each node ends up
    with ``h_dim`` features.

    Args:
        in_dim: width of the incoming node features.
        h_dim: width of each attention head's output.
        num_heads: number of attention heads to average over.
    """

    def __init__(self, in_dim, h_dim, num_heads=1):
        super().__init__()
        self.gat = GATConv(
            in_feats=in_dim,
            out_feats=h_dim,
            num_heads=num_heads,
            activation=torch.tanh,
        )

    def forward(self, g, h):
        # Presumably (N, num_heads, h_dim) per the DGL GATConv docs; the head
        # axis (dim 1) is reduced by averaging.
        per_head = self.gat(g, h)
        return per_head.mean(dim=1)

## Adaptive Depth

In [None]:
class AdaptiveDepthLayer(nn.Module):
    """Adaptive-depth step of GeniePath: an LSTM-style gated memory update.

    Given the previous memory ``c`` and a representation ``h`` with ``in_dim``
    features, computes::

        i = sigmoid(W_i h)      # input gate
        f = sigmoid(W_f h)      # forget gate
        o = sigmoid(W_o h)      # output gate
        c_new = f * c + i * tanh(W_c h)
        h_new = o * tanh(c_new)

    and returns ``(c_new, h_new)``, each with ``out_dim`` features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Creation order (input, forget, output, state) determines how the RNG
        # is consumed, so it is kept fixed for seeded reproducibility.
        self.input_gate = nn.Linear(in_dim, out_dim)
        self.forget_gate = nn.Linear(in_dim, out_dim)
        self.output_gate = nn.Linear(in_dim, out_dim)
        self.state = nn.Linear(in_dim, out_dim)

    def forward(self, c, h):
        keep = torch.sigmoid(self.forget_gate(h))
        write = torch.sigmoid(self.input_gate(h))
        candidate = torch.tanh(self.state(h))
        c_next = keep * c + write * candidate

        reveal = torch.sigmoid(self.output_gate(h))
        return c_next, reveal * torch.tanh(c_next)

## GeniePath

In [None]:
class GeniePath(nn.Module):
    """GeniePath: GNN with adaptive breadth (attention) and adaptive depth (gating).

    Reference: Liu et al., "GeniePath: Graph Neural Networks with Adaptive
    Receptive Paths", AAAI 2019.

    Args:
        in_dim: input feature dimension.
        h_dim: hidden dimension.
        out_dim: output dimension (number of classes / labels).
        depth: number of breadth/depth iterations.
        lazy_mode: if False (standard GeniePath), alternate one breadth layer
            and one depth layer per iteration. If True (GeniePath-lazy), run
            all breadth layers first, then feed each breadth output,
            concatenated with the running depth state ``mu``, through the
            depth layers.

    ``forward(g, x)`` returns raw, unnormalised scores (logits) of shape
    ``(num_nodes, out_dim)``; apply ``log_softmax`` / ``BCEWithLogitsLoss``
    outside the model.
    """

    def __init__(self, in_dim, h_dim, out_dim, depth, lazy_mode=True):
        super(GeniePath, self).__init__()
        self.depth = depth
        self.lazy_mode = lazy_mode

        # Initial projection of the input features into the hidden space.
        self.wx = nn.Linear(in_dim, h_dim)

        # NOTE: the attribute name `breath_fn` (sic) is kept so that existing
        # state_dict checkpoints keep loading.
        self.breath_fn = nn.ModuleList()
        self.depth_fn = nn.ModuleList()
        # Lazy mode gates on concat(breadth output, mu), hence 2 * h_dim inputs.
        depth_in_dim = 2 * h_dim if lazy_mode else h_dim
        for _ in range(depth):
            self.breath_fn.append(AdaptiveBreadthLayer(h_dim, h_dim))
            self.depth_fn.append(AdaptiveDepthLayer(depth_in_dim, h_dim))

        # Output layer.
        self.out_layer = nn.Linear(h_dim, out_dim)

    def forward(self, g, x):
        h0 = self.wx(x)
        if self.lazy_mode:
            h = self._forward_lazy(g, h0)
        else:
            h = self._forward_standard(g, h0)
        # Return raw logits. A ReLU here would clamp every score to >= 0:
        # with BCEWithLogitsLoss, sigmoid(logit) could then never fall below
        # 0.5, and gradients would die for negative pre-activations.
        return self.out_layer(h)

    def _forward_standard(self, g, h0):
        """Standard mode: alternate breadth (attention) and depth (gated) updates."""
        h = h0
        c = torch.zeros_like(h0)
        for breadth, depth in zip(self.breath_fn, self.depth_fn):
            h = breadth(g, h)
            c, h = depth(c, h)
        return h

    def _forward_lazy(self, g, h0):
        """Lazy mode: run all breadth layers, then fold their outputs through the depth layers."""
        h = h0
        collector = []
        for breadth in self.breath_fn:
            h = breadth(g, h)
            collector.append(h)

        mu = h0
        c = torch.zeros_like(h0)
        for h_t, depth in zip(collector, self.depth_fn):
            c, mu = depth(c, torch.cat([h_t, mu], dim=1))
        return mu

# Application on Datasets

## Citation graph (Cora)

In [None]:
from dgl.data import citation_graph  as citegrh
from dgl import DGLGraph

In [None]:
# Load the Cora citation dataset through DGL's `citation_graph` helper.
citegrh_data = citegrh.load_cora()

In [None]:
# Node features, integer class labels and the boolean split masks for Cora.
features = torch.FloatTensor(citegrh_data.features)
labels = torch.LongTensor(citegrh_data.labels)
train_mask, val_mask, test_mask = (
    torch.BoolTensor(mask)
    for mask in (citegrh_data.train_mask, citegrh_data.val_mask, citegrh_data.test_mask)
)

# Build the graph and add a self-loop on every node, so each node attends to
# itself and has in-degree >= 1 (recent DGL GATConv rejects zero in-degree).
g = DGLGraph(citegrh_data.graph).add_self_loop()

In [None]:
# Hyper-parameters for Cora.
in_dim = features.shape[-1]           # input feature width
h_dim = 4                             # hidden width
out_dim = citegrh_data.num_classes    # number of target classes
depth, lazy = 1, True                 # one breadth/depth iteration, GeniePath-lazy

In [None]:
# Instantiate GeniePath and an Adam optimiser with the AMSGrad variant.
net = GeniePath(in_dim=in_dim, h_dim=h_dim, out_dim=out_dim, depth=depth, lazy_mode=lazy)
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3, amsgrad=True)

In [None]:
# Full-batch training loop on Cora. Train nodes drive the loss and validation
# nodes drive model selection; test accuracy is reported at the epoch with the
# best validation accuracy.
num_epochs = 1000
log_every = 50   # print progress every `log_every` epochs to keep the output short

epoch_losses = []   # plain floats, so no autograd graphs are kept alive
best_acc = 0
best_epoch = 0
best_test_acc = 0

for epoch in range(num_epochs):
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])
    epoch_losses.append(loss.item())

    # Predictions reuse the training forward pass (weights before this step).
    pred = logits.argmax(1)

    # Accuracy on training / validation / test nodes.
    train_acc = (pred[train_mask] == labels[train_mask]).float().mean().item()
    val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
    test_acc = (pred[test_mask] == labels[test_mask]).float().mean().item()

    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch
        best_test_acc = test_acc

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print("Epoch {:05d} | Loss {:.4f} | Train: {:.4f} | Val: {:.4f}".format(
            epoch, loss.item(), train_acc, val_acc))

print("Best Val: {:.4f} at epoch {:05d} | Test at that epoch: {:.4f}".format(
    best_acc, best_epoch, best_test_acc))

## Pubmed

In [None]:
from dgl.data import PubmedGraphDataset

In [None]:
# Load the Pubmed citation dataset through DGL's dataset class.
pub_data = PubmedGraphDataset()

In [None]:
# Node features, integer class labels and boolean split masks for Pubmed.
features = torch.FloatTensor(pub_data.features)
labels = torch.LongTensor(pub_data.labels)
train_mask, val_mask, test_mask = (
    torch.BoolTensor(mask)
    for mask in (pub_data.train_mask, pub_data.val_mask, pub_data.test_mask)
)
# Graph with a self-loop added on every node.
g = DGLGraph(pub_data.graph).add_self_loop()

In [None]:
# Hyper-parameters for Pubmed.
in_dim = features.shape[-1]       # input feature width
h_dim = 8                         # hidden width
out_dim = pub_data.num_classes    # number of target classes
depth, lazy = 3, True             # three breadth/depth iterations, GeniePath-lazy

In [None]:
# Instantiate GeniePath and an Adam optimiser with the AMSGrad variant.
net = GeniePath(in_dim=in_dim, h_dim=h_dim, out_dim=out_dim, depth=depth, lazy_mode=lazy)
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-3, amsgrad=True)

In [None]:
# Full-batch training loop on Pubmed. Train nodes drive the loss and
# validation nodes drive model selection; test accuracy is reported at the
# epoch with the best validation accuracy.
num_epochs = 100
log_every = 10   # print progress every `log_every` epochs to keep the output short

epoch_losses = []   # plain floats, so no autograd graphs are kept alive
best_acc = 0
best_epoch = 0
best_test_acc = 0

for epoch in range(num_epochs):
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])
    epoch_losses.append(loss.item())

    # Predictions reuse the training forward pass (weights before this step).
    pred = logits.argmax(1)

    # Accuracy on training / validation / test nodes.
    train_acc = (pred[train_mask] == labels[train_mask]).float().mean().item()
    val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
    test_acc = (pred[test_mask] == labels[test_mask]).float().mean().item()

    if val_acc > best_acc:
        best_acc = val_acc
        best_epoch = epoch
        best_test_acc = test_acc

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print("Epoch {:05d} | Loss {:.4f} | Train: {:.4f} | Val: {:.4f}".format(
            epoch, loss.item(), train_acc, val_acc))

print("Best Val: {:.4f} at epoch {:05d} | Test at that epoch: {:.4f}".format(
    best_acc, best_epoch, best_test_acc))

## PPI (protein-protein interaction)

In [None]:
from dgl.data import PPIDataset
from sklearn.metrics import f1_score

In [None]:
# Load the train / valid / test splits of PPI; the model trains on the
# training split's graph.
data, val_data, test_data = (PPIDataset(mode=split) for split in ("train", "valid", "test"))
g = data.graph

In [None]:
# Multi-label objective: independent sigmoid + binary cross-entropy per label,
# applied directly to the model's raw logits.
loss_op = nn.BCEWithLogitsLoss()

In [None]:
# Hyper-parameters for PPI (multi-label: one output logit per label column).
in_dim = data.features.shape[-1]   # input feature width
out_dim = data.labels.shape[-1]    # number of labels
h_dim, depth, lazy = 8, 3, True    # hidden width, iterations, GeniePath-lazy

In [None]:
# GeniePath model for PPI.
net = GeniePath(in_dim=in_dim, h_dim=h_dim, out_dim=out_dim, depth=depth, lazy_mode=lazy)

In [None]:
# Adam with the AMSGrad variant; a larger learning rate than the citation runs.
optimizer = torch.optim.Adam(params=net.parameters(), lr=1e-2, amsgrad=True)

In [None]:
# Full-batch multi-label training on the PPI training graph. Validation
# micro-F1 drives model selection; test micro-F1 is computed only when
# validation improves, and is reported at the best epoch.
num_epochs = 5000
log_every = 100   # print progress every `log_every` epochs to keep the output short

# Convert inputs once instead of re-allocating tensors every epoch.
train_x = torch.FloatTensor(data.features)
train_y = torch.FloatTensor(data.labels)
val_x = torch.FloatTensor(val_data.features)
test_x = torch.FloatTensor(test_data.features)

epoch_losses = []   # plain floats, so no autograd graphs are kept alive
best_f1 = 0
best_epoch = 0
best_test_f1 = 0

for epoch in range(num_epochs):
    logits = net(g, train_x)
    loss = loss_op(logits, train_y)
    epoch_losses.append(loss.item())

    # Evaluate without building autograd graphs. A logit > 0 means
    # sigmoid probability > 0.5, i.e. the label is predicted present.
    with torch.no_grad():
        pred_val = (net(val_data.graph, val_x).numpy() > 0).astype(int)
    micro_f1_val = f1_score(val_data.labels, pred_val, average='micro')

    if micro_f1_val > best_f1:
        best_f1 = micro_f1_val
        best_epoch = epoch
        with torch.no_grad():
            pred_test = (net(test_data.graph, test_x).numpy() > 0).astype(int)
        best_test_f1 = f1_score(test_data.labels, pred_test, average='micro')

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % log_every == 0 or epoch == num_epochs - 1:
        print("Epoch {:05d} | Loss {:.4f} | Val: {:.4f}".format(
            epoch, loss.item(), micro_f1_val))

print("Best Val micro-F1: {:.4f} at epoch {:05d} | Test micro-F1 at that epoch: {:.4f}".format(
    best_f1, best_epoch, best_test_f1))