# Graph Convolutional Network for Link Predition
This notebook demonstrates the training of Graph Convolutional Networks for Link Prediction with TigerGraph. Pytorch Geometric's implementation of GCN is used here. We train the model on the Cora dataset from PyG datasets with TigerGraph as the data store. The dataset contains 2708 machine learning papers and 10556 citation links between the papers. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from a dictionary. The dictionary consists of 1433 unique words. Each paper is classified into one of seven classes based on the topic. The goal is to predict the class of each vertex in the graph.

## Data Processing
Here we assume the dataset is already ingested into the TigerGraph database. If not, please refer to the example on data ingestion first.

For each edges, the original attributes include "is_train, is_val", and you may add "is_test" if want have the train/val/test edge sets. Otherwise, for the edgeSplitter, you can just split the train/val sets.

### Connect to TigerGraph

In [16]:
from pyTigerGraph import TigerGraphConnection
import torch

conn = TigerGraphConnection(
    host="http://127.0.0.1", # Change the address to your database server's
    graphname="Cora",
    username="tigergraph",
    password="tigergraph",
    useCert=False
)

In [9]:
conn.getVertexCount('*')

{'Paper': 2708}

In [10]:
conn.getEdgeCount('*')

{'Cite': 10556}

### Train/validation/test split
If there are no is_test in the graph, you can add the edge attribute to TigerGraph or split train and val by 0.8 and 0.2

In [None]:
%%time
splitter = conn.gds.edgeSplitter(is_train=0.8, is_val=0.1, is_test=0.1)

In [None]:
%%time
splitter.run()

## Train on whole graph
Here, we use the full graph for link prediciton. This will **NOT** work when the graph is very large. See the section of Stochastic Mini-Batch Training for real use. However, we still include this example for illustration purpose.

We load the whole graph from TigerGraph which include the feature and split results.

### Construct graph loader and negative edges

In [11]:
graph_loader = conn.gds.graphLoader(
    num_batches=1,
    v_in_feats = ["x"],
    e_extra_feats=["is_train","is_val", "is_test"],
    output_format = "PyG")

In [12]:
data = graph_loader.data

In [13]:
data

Data(edge_index=[2, 10556], is_train=[10556], is_val=[10556], is_test=[10556], x=[2708, 1433])

In [14]:
train_edge_index = data.edge_index[:, data.is_train]
val_edge_index = data.edge_index[:, data.is_val]
test_edge_index = data.edge_index[:, data.is_test]

In [17]:
neg_val_edge = torch.randint(0, data.x.shape[0], val_edge_index.size(), dtype=torch.long)
neg_test_edge = torch.randint(0, data.x.shape[0], test_edge_index.size(), dtype=torch.long)

In [18]:
train_edge_index.shape, val_edge_index.shape, neg_val_edge.shape

(torch.Size([2, 8532]), torch.Size([2, 1007]), torch.Size([2, 1007]))

### Construct GCN Model
We use dot product to measure the similarity of two node in a decode function.

In [19]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout, **kwargs):
        super(GCN, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) # concatenate pos and neg edges
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product 
        return logits


### Get binary labels for positive and negative edges

In [20]:
def get_link_labels(pos_edge_index, neg_edge_index):
    E = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(E, dtype=torch.float)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels

### Define Hyperparameters

In [21]:
# Hyperparameters
hp = {"hidden_dim": 128, "out_dim": 64, "num_layers": 2,
      "dropout": 0.6, "lr": 0.01, "l2_penalty": 5e-4}

### Instantiate Model and optimizer

In [31]:
model = GCN(1433, hp["hidden_dim"], hp["out_dim"], hp["num_layers"], hp["dropout"])
optimizer = torch.optim.Adam(
    model.parameters(), lr=hp["lr"], weight_decay=hp["l2_penalty"]
)

In [32]:
val_labels = get_link_labels(val_edge_index, neg_val_edge)
val_labels

tensor([1., 1., 1.,  ..., 0., 0., 0.])

### Train the model

In [33]:
from sklearn.metrics import roc_auc_score

In [34]:
for epoch in range(30):
    model.train()
    neg_train_edge = torch.randint(0, data.x.shape[0], train_edge_index.size(), dtype=torch.long)
    h = model(data.x.float(), train_edge_index)
    logits = model.decode(h, train_edge_index, neg_train_edge)
    labels = get_link_labels(train_edge_index, neg_train_edge)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    with torch.no_grad():
        val_logits = model.decode(h, val_edge_index, neg_val_edge)
        val_logits = val_logits.sigmoid()
        print('Epoch: {}, training loss: {}, valid roc_auc_score: {}'.format(epoch, loss.item(), roc_auc_score(val_labels, val_logits)))

Epoch: 0, training loss: 0.652077853679657, valid roc_auc_score: 0.8222595752276272
Epoch: 1, training loss: 1.0911240577697754, valid roc_auc_score: 0.813158930189764
Epoch: 2, training loss: 1.2042229175567627, valid roc_auc_score: 0.7523507246691234
Epoch: 3, training loss: 0.7218725085258484, valid roc_auc_score: 0.7803331002742472
Epoch: 4, training loss: 0.6301491856575012, valid roc_auc_score: 0.8178519972900719
Epoch: 5, training loss: 0.6423941254615784, valid roc_auc_score: 0.8162573997903454
Epoch: 6, training loss: 0.6524781584739685, valid roc_auc_score: 0.8052441252838867
Epoch: 7, training loss: 0.6515461802482605, valid roc_auc_score: 0.8135050673093707
Epoch: 8, training loss: 0.6498074531555176, valid roc_auc_score: 0.8139863063816444
Epoch: 9, training loss: 0.6412517428398132, valid roc_auc_score: 0.8221377862410988
Epoch: 10, training loss: 0.6362589597702026, valid roc_auc_score: 0.8369270124027538
Epoch: 11, training loss: 0.6336613893508911, valid roc_auc_score:

## Stochastic Batch Training
For stochastic batch training, we split the training edges into batches. At each specific batch, to do the link prediction, we need to know the neighbor graphs for the each pair of nodes that has a edge.

We use the edgeNeighborLoader, which can load the neghbors of the pair nodes of a edge and has the same parameters as neighborLoader(). The result of a batch is, for example,

`Data(edge_index=[2, 6917], is_train=[6917], is_val=[6917], is_test=[6917], is_seed=[6917], x=[2188, 1433], y=[2188])`

where `is_seed` indicates whether each edge is a seed edge or not


In [36]:
# Hyperparameters
hp = {"hidden_dim": 128, "out_dim": 64, "num_layers": 2,
      "dropout": 0.6, "lr": 0.01, "l2_penalty": 5e-4}

In [37]:
model = GCN(1433, hp["hidden_dim"], hp["out_dim"], hp["num_layers"], hp["dropout"])
optimizer = torch.optim.Adam(
    model.parameters(), lr=hp["lr"], weight_decay=hp["l2_penalty"]
)

### Construct the edge_neighbor_loader for train/val/test edges

In [38]:
train_edge_neighbor_loader = conn.gds.edgeNeighborLoader(
    v_in_feats=["x"],
    v_out_labels=["y"],
    num_batches=5,
    e_extra_feats=["is_train","is_val", "is_test"],
    output_format="PyG",
    num_neighbors=10,
    num_hops=2,
    filter_by="is_train",
    shuffle=False,
)

In [40]:
val_edge_neighbor_loader = conn.gds.edgeNeighborLoader(
    v_in_feats=["x"],
    v_out_labels=["y"],
    num_batches=5,
    e_extra_feats=["is_train","is_val", "is_test"],
    output_format="PyG",
    num_neighbors=10,
    num_hops=2,
    filter_by="is_val",
    shuffle=False,
)

In [41]:
for epoch in range(30):
    model.train()
    total_loss = 0
    for bid, batch in enumerate(train_edge_neighbor_loader):
        # get the training edges and negative edges sampled in the same batch
        train_edges = batch.edge_index[:, batch.is_seed]
        neg_train_edges = torch.randint(0, batch.x.shape[0], train_edges.size(), dtype=torch.long)
        # The graph only include the edges whose is_train is True
        train_graph_edges = batch.edge_index[:, batch.is_train]
        h = model(batch.x.float(), train_graph_edges)
        logits = model.decode(h, train_edges, neg_train_edges)
        labels = get_link_labels(train_edges, neg_train_edges)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    model.eval()
    all_labels = []
    all_logits = []
    for batch in val_edge_neighbor_loader:
        val_edges = batch.edge_index[:, batch.is_seed]
        neg_val_edges = torch.randint(0, batch.x.shape[0], val_edges.size(), dtype=torch.long)
        # Need to use the train edge for GCN
        val_graph_edges = batch.edge_index[:, batch.is_train]
        with torch.no_grad():
            h = model(batch.x.float(), val_graph_edges)
            logits = model.decode(h, val_edges, neg_val_edges)
            labels = get_link_labels(val_edges, neg_val_edges)
            logits = logits.sigmoid()
            all_labels.extend(labels)
            all_logits.extend(logits)
    print('Epoch: {}, training loss: {}, valid roc_auc_score: {}'.format(epoch, total_loss, roc_auc_score(all_labels, all_logits)))
    

Epoch: 0, training loss: 3.0831108689308167, valid roc_auc_score: 0.8723227378558631
Epoch: 1, training loss: 2.9560742378234863, valid roc_auc_score: 0.909797258317892
Epoch: 2, training loss: 2.6730599403381348, valid roc_auc_score: 0.928123788889886
Epoch: 3, training loss: 2.4813827872276306, valid roc_auc_score: 0.934900581727313
Epoch: 4, training loss: 2.358451873064041, valid roc_auc_score: 0.940606420399803
Epoch: 5, training loss: 2.3540602326393127, valid roc_auc_score: 0.9591291939541384
Epoch: 6, training loss: 2.264620155096054, valid roc_auc_score: 0.9612563100994135
Epoch: 7, training loss: 2.265827476978302, valid roc_auc_score: 0.9633173544868149
Epoch: 8, training loss: 2.223086267709732, valid roc_auc_score: 0.9682091299335635
Epoch: 9, training loss: 2.237872064113617, valid roc_auc_score: 0.9637492862770931
Epoch: 10, training loss: 2.2264839708805084, valid roc_auc_score: 0.9678171370417011
Epoch: 11, training loss: 2.2111732065677643, valid roc_auc_score: 0.9553