In [76]:
def glorot(tensor):
    """Fill ``tensor`` in place with Glorot (Xavier) uniform values.

    Samples are drawn from ``U(-b, b)`` with
    ``b = sqrt(6 / (size(-2) + size(-1)))``. ``None`` is accepted and ignored
    so that optional parameters can be passed straight through.
    """
    if tensor is None:
        return
    fan_sum = tensor.size(-2) + tensor.size(-1)
    bound = math.sqrt(6.0 / fan_sum)
    tensor.data.uniform_(-bound, bound)

        
def zeros(tensor):
    """Set every element of ``tensor`` to zero in place; ``None`` is a no-op."""
    if tensor is None:
        return
    tensor.data.fill_(0)

import math

from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor, PairTensor
import torch
from torch import nn, einsum, broadcast_tensors
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum as sparsesum, mul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.typing import Adj, Size, OptTensor, Tensor

class CoorsNorm(nn.Module):
    def __init__(self, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.fn = nn.Sequential(nn.LayerNorm(1), nn.GELU())

    def forward(self, coors):
        norm = coors.norm(dim = -1, keepdim = True)
        normed_coors = coors / norm.clamp(min = self.eps)
        phase = self.fn(norm)
        return (phase * normed_coors)

    
class PEG_conv(MessagePassing):
    r"""Graph convolution whose messages are gated by positional distance.

    The input ``x`` is the column-wise concatenation of a positional encoding
    (first ``pos_dim`` columns) and node features (remaining columns). When
    ``normalize`` is set the edges are normalised with ``gcn_norm``; each
    message ``x_j`` is then scaled by the edge weight and by
    ``sigmoid(edge_mlp2(edge_mlp1(d_ij)))``, where ``d_ij`` is the squared
    Euclidean distance between the positional encodings of the two end
    points. The aggregated messages are projected by a learned weight matrix
    and the (optionally updated) positional encoding is concatenated back in
    front of the result.

    Args:
        in_feats_dim (int): Size of the input node features.
        pos_dim (int): Size of positional encoding.
        out_feats_dim (int): Size of the output node features.
        edge_mlp_dim (int, optional): Hidden width of the two linear layers
            applied to the squared positional distance. (default: :obj:`32`)
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
            \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
            cached version for further executions.
            This parameter should only be set to :obj:`True` in transductive
            learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        normalize (bool, optional): Whether to add self-loops and compute
            symmetric normalization coefficients on the fly.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        update_coors (bool): Whether to update positional encodings.
        use_formerinfo (bool): Whether to use previous layer's output to update node features.
        norm_coors (bool): Whether to normalize positional encodings. Only used when update_coors = True.
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_feats_dim: int, pos_dim: int, out_feats_dim: int, edge_mlp_dim: int = 32,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, update_coors: bool = False,
                 use_formerinfo: bool = False, norm_coors: bool = True, **kwargs):

        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_feats_dim = in_feats_dim
        self.out_feats_dim = out_feats_dim
        self.pos_dim = pos_dim
        self.update_coors = update_coors
        self.use_formerinfo = use_formerinfo
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self.edge_mlp_dim = edge_mlp_dim
        self.coors_norm = CoorsNorm() if norm_coors else nn.Identity()

        # Normalised graph caches, filled on first forward when cached=True.
        self._cached_edge_index = None
        self._cached_adj_t = None

        # Maps the scalar squared positional distance of an edge to a gate.
        self.edge_mlp1 = nn.Linear(1, edge_mlp_dim)
        self.edge_mlp2 = nn.Linear(edge_mlp_dim, 1)
        # Both projections are always allocated; propagate() uses exactly one
        # of them depending on use_formerinfo.
        self.weight_withformer = Parameter(torch.Tensor(in_feats_dim + in_feats_dim, out_feats_dim))
        self.weight_noformer = Parameter(torch.Tensor(in_feats_dim, out_feats_dim))
        self.coors_mlp = nn.Sequential(
            nn.Linear(pos_dim, pos_dim * 4),
            nn.Linear(pos_dim * 4, 1)
        ) if update_coors else None

        if bias:
            self.bias = Parameter(torch.Tensor(out_feats_dim))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the projection weights and bias and drop the caches."""
        glorot(self.weight_withformer)
        glorot(self.weight_noformer)
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Apply the layer.

        Args:
            x: Node matrix whose first ``pos_dim`` columns are positional
                encodings and whose remaining columns are node features.
            edge_index: Graph connectivity.
            edge_weight: Optional per-edge weights.

        Returns:
            ``[coors_out | hidden_out]`` concatenated along the last
            dimension.
        """
        coors, feats = x[:, :self.pos_dim], x[:, self.pos_dim:]

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, feats.size(self.node_dim),
                        self.improved, self.add_self_loops)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, feats.size(self.node_dim),
                        self.improved, self.add_self_loops)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache
        else:
            print('We normalize the adjacent matrix in PEG.')

        # NOTE(review): the indexing below treats edge_index as a [2, E]
        # tensor; it is unclear from this file whether the SparseTensor
        # branch above is compatible with it -- confirm before relying on it.
        rel_coors = coors[edge_index[0]] - coors[edge_index[1]]
        neighbour_coors = coors[edge_index[1]]
        rel_dist = (rel_coors ** 2).sum(dim=-1, keepdim=True)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        # pos: squared l2 distance between the end points of each edge
        # rel_coors: used in updating positional encodings, not required for PEG
        hidden_out, coors_out = self.propagate(edge_index, x = feats, edge_weight=edge_weight, pos=rel_dist, coors=coors, rel_coors=rel_coors, neighbour_coors = neighbour_coors,
                             size=None)

        if self.bias is not None:
            hidden_out += self.bias

        return torch.cat([coors_out, hidden_out], dim=-1)

    def message(self, x_i: Tensor, x_j: Tensor, edge_weight: OptTensor, pos) -> Tensor:
        """Return the per-edge message.

        Without an edge weight the neighbour features are passed through
        unchanged; otherwise they are scaled by the edge weight and by the
        sigmoid gate computed from ``pos``.
        """
        if edge_weight is None:
            # The positional gate is not applied on this path, so skip the
            # gate computation entirely.
            return x_j
        PE_edge_weight = self.edge_mlp1(pos)
        PE_edge_weight = self.edge_mlp2(PE_edge_weight)
        PE_edge_weight = torch.sigmoid(PE_edge_weight)
        return PE_edge_weight * edge_weight.view(-1, 1) * x_j

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
        """The initial call to start propagating messages.

        Args:
            edge_index: holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional): if none, the size will be inferred
                and assumed to be quadratic.
            **kwargs: Any additional data which is needed to construct and
                aggregate messages, and to update node embeddings. This
                override reads ``x``, ``coors`` and ``neighbour_coors``
                from it directly.

        Returns:
            The result of ``self.update`` applied to the tuple
            ``(hidden_out, coors_out)``.
        """
        # NOTE(review): these dunder helpers are internals of the
        # MessagePassing base class and presumably tie this override to a
        # specific torch_geometric version -- confirm.
        size = self.__check_input__(edge_index, size)
        coll_dict = self.__collect__(self.__user_args__,
                                     edge_index, size, kwargs)
        msg_kwargs = self.inspector.distribute('message', coll_dict)
        aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
        update_kwargs = self.inspector.distribute('update', coll_dict)

        # get messages
        m_ij = self.message(**msg_kwargs)

        m_i = self.aggregate(m_ij, **aggr_kwargs)

        # update coors if specified
        if self.update_coors:
            # NOTE(review): coors_mlp's first layer takes pos_dim inputs
            # while m_ij carries message features, so this path presumably
            # requires in_feats_dim == pos_dim -- confirm.
            coor_wij = self.coors_mlp(m_ij)
            kwargs["neighbour_coors"] = self.coors_norm(kwargs["neighbour_coors"])
            mhat_i = self.aggregate(coor_wij * kwargs["neighbour_coors"], **aggr_kwargs)
            coors_out = kwargs["coors"] + mhat_i
        else:
            coors_out = kwargs["coors"]

        hidden_feats = kwargs["x"]
        if self.use_formerinfo:
            hidden_out = torch.cat([hidden_feats, m_i], dim = -1)
            hidden_out = hidden_out @ self.weight_withformer
        else:
            hidden_out = m_i
            hidden_out = hidden_out @ self.weight_noformer

        # return tuple
        return self.update((hidden_out, coors_out), **update_kwargs)

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        """Fused message + aggregation as a sparse matrix product."""
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        # The layer stores its sizes as in_feats_dim / out_feats_dim; the
        # attributes in_channels / out_channels are never set on it.
        return '{}({}, {})'.format(self.__class__.__name__, self.in_feats_dim,
                                   self.out_feats_dim)

In [1]:
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import to_networkx
from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
import importlib
# "PEG-PYG" contains a hyphen, so it cannot be named in a ``from ... import``
# statement (that is a SyntaxError); load the module by its string name.
PEG_conv = importlib.import_module("PEG-PYG").PEG_conv
from torch_geometric.utils import train_test_split_edges

# Link prediction example for PEG (cora)

In [100]:
# Run on GPU index 7 when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda:7'
else:
    device = 'cpu'
# To force CPU execution, set: device = "cpu"

In [101]:
# Load the Cora Planetoid dataset from ./data/Cora with the NormalizeFeatures
# transform applied.
dataset = 'Cora'
path = osp.join('.', 'data', dataset)
# `dataset` is rebound here from the name string to the loaded dataset object.
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
# The dataset's first (index 0) graph is the one used below.
data = dataset[0]
print(dataset.data)

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])


In [102]:
# Drop the node-level masks and labels; the cells below only do link prediction.
data.train_mask = data.val_mask = data.test_mask = data.y = None
# Split the edges into train/val/test positive and negative sets (see the
# printed Data object for the resulting attributes).
data = train_test_split_edges(data)
print(data)

Data(test_neg_edge_index=[2, 527], test_pos_edge_index=[2, 527], train_neg_adj_mask=[2708, 2708], train_pos_edge_index=[2, 8976], val_neg_edge_index=[2, 263], val_pos_edge_index=[2, 263], x=[2708, 1433])


In [56]:
# Build the training graph used to prepare the positional encodings (PE):
# a deep copy of the original graph whose edges are replaced by the training
# positive edges only.
import copy
train_graph = copy.deepcopy(dataset[0])
train_graph.edge_index = data.train_pos_edge_index
# Convert to a networkx graph for the random walks below.
G = to_networkx(train_graph)

# We use DeepWalk to compute the positional encodings (PE) in this example.

In [103]:
import networkx as nx
import pandas as pd
import numpy as np
import random
from tqdm import tqdm
import itertools
import math
from joblib import Parallel, delayed
from tqdm import trange

In [104]:
def partition_num(num, workers):
    """Split ``num`` units of work into chunk sizes for ``workers`` workers.

    Returns ``workers`` chunks of ``num // workers`` each. When ``num`` is not
    divisible by ``workers`` one extra chunk holding the remainder is
    appended, so the list then has ``workers + 1`` entries.
    """
    chunk, leftover = divmod(num, workers)
    sizes = [chunk] * workers
    if leftover:
        sizes.append(leftover)
    return sizes

In [59]:
#modified from 
class RandomWalker:
    """Generate uniform (DeepWalk-style) random walks over a graph.

    Only the setting ``p == 1 and q == 1`` is implemented; other values are
    stored but rejected when walks are simulated.
    """

    def __init__(self, G, p=1, q=1, use_rejection_sampling=0):
        """
        :param G: Graph object exposing ``nodes()`` and ``neighbors(node)``.
        :param p: Return parameter, controls the likelihood of immediately revisiting a node in the walk.
        :param q: In-out parameter, allows the search to differentiate between "inward" and "outward" nodes.
        :param use_rejection_sampling: Whether to use the rejection sampling strategy in node2vec.
        """
        self.G = G
        self.p = p
        self.q = q
        self.use_rejection_sampling = use_rejection_sampling

    def deepwalk_walk(self, walk_length, start_node):
        """Return one uniform random walk starting at ``start_node``.

        The walk holds at most ``walk_length`` nodes and stops early when it
        reaches a node without neighbours.
        """
        walk = [start_node]
        while len(walk) < walk_length:
            cur_nbrs = list(self.G.neighbors(walk[-1]))
            if not cur_nbrs:
                break
            walk.append(random.choice(cur_nbrs))
        return walk

    def simulate_walks(self, num_walks, walk_length, workers=1, verbose=0):
        """Return ``num_walks`` walks per node, computed by ``workers`` jobs.

        The walk count is split with ``partition_num`` and the per-job
        results are concatenated into one flat list of walks.
        """
        nodes = list(self.G.nodes())

        results = Parallel(n_jobs=workers, verbose=verbose, )(
            delayed(self._simulate_walks)(nodes, num, walk_length) for num in
            partition_num(num_walks, workers))

        return list(itertools.chain(*results))

    def _simulate_walks(self, nodes, num_walks, walk_length,):
        """Run ``num_walks`` passes over ``nodes``, one walk per node per pass.

        ``nodes`` is shuffled in place before every pass.

        :raises NotImplementedError: if ``p`` or ``q`` differs from 1.
        """
        # Fail loudly instead of returning a message string: the caller
        # chains the results together, which would silently turn that string
        # into single-character "walks".
        if self.p != 1 or self.q != 1:
            raise NotImplementedError("only work for DeepWalk")

        walks = []
        for _ in range(num_walks):
            random.shuffle(nodes)
            walks.extend(
                self.deepwalk_walk(walk_length=walk_length, start_node=v)
                for v in nodes)
        return walks

In [60]:
from gensim.models import Word2Vec
import pandas as pd

class DeepWalk:
    """Random-walk corpus plus a skip-gram Word2Vec model trained on it."""

    def __init__(self, graph, walk_length = 80, num_walks = 10, workers=1):
        """Sample the walk corpus for ``graph``; no model is trained yet."""
        self.graph = graph
        self.w2v_model = None
        self._embeddings = {}

        self.walker = RandomWalker(graph, p=1, q=1)
        self.sentences = self.walker.simulate_walks(
            num_walks=num_walks,
            walk_length=walk_length,
            workers=workers,
            verbose=1,
        )

    def train(self, embed_size=128, window_size=5, workers=3, iter=3, **kwargs):
        """Fit Word2Vec on the sampled walks and return the fitted model.

        Extra ``kwargs`` are forwarded to ``Word2Vec``; ``min_count``
        defaults to 0 and the remaining settings below always override
        whatever the caller passed.
        """
        kwargs.setdefault("min_count", 0)
        kwargs.update(
            sentences=self.sentences,
            vector_size=embed_size,
            sg=1,  # skip gram
            hs=1,  # deepwalk use Hierarchical Softmax
            workers=workers,
            window=window_size,
            epochs=iter,
        )

        print("Learning embedding vectors...")
        self.w2v_model = Word2Vec(**kwargs)
        print("Learning embedding vectors done!")

        return self.w2v_model

    def get_embeddings(self,):
        """Return a dict mapping every graph node to its learned vector.

        Prints a message and returns an empty dict when no model has been
        trained yet.
        """
        if self.w2v_model is None:
            print("model not train")
            return {}

        self._embeddings = {}
        vectors = self.w2v_model.wv
        for node in self.graph.nodes():
            self._embeddings[node] = vectors[node]
        return self._embeddings

In [61]:
# Sample the walks, fit Word2Vec, then stack the per-node vectors into one
# array ordered by the integer keys 0 .. len(emb) - 1.
model_emb = DeepWalk(G, walk_length=80, num_walks=10, workers=1)  # init model
model_emb.train(embed_size=128)  # train model
emb = model_emb.get_embeddings()  # get embedding vectors
embeddings = np.array([emb[node_id] for node_id in range(len(emb))])

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    2.5s finished


Learning embedding vectors...
Learning embedding vectors done!


In [85]:
class Net(torch.nn.Module):
    """Two stacked ``PEG_conv`` layers followed by a linear link scorer.

    Every candidate edge is scored from two scalars: the dot product of its
    end points' feature vectors and the squared distance between their
    positional encodings.
    """

    def __init__(self, in_feats_dim, pos_dim, hidden_dim, use_former_information = False, update_coors = False):
        super().__init__()

        self.in_feats_dim = in_feats_dim
        self.hidden_dim = hidden_dim
        self.pos_dim = pos_dim
        self.use_former_information = use_former_information
        self.update_coors = update_coors

        # Settings shared by both convolution layers.
        shared = dict(pos_dim=pos_dim, out_feats_dim=hidden_dim,
                      use_formerinfo=use_former_information,
                      update_coors=update_coors)
        self.conv1 = PEG_conv(in_feats_dim=in_feats_dim, **shared)
        self.conv2 = PEG_conv(in_feats_dim=hidden_dim, **shared)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        # Combines the two per-edge scalars into a single logit.
        self.fc = nn.Linear(2, 1)

    def forward(self, x, pos_edge_index, neg_edge_index):
        """Return one logit per candidate edge, positives first.

        NOTE(review): both convolutions run over ``pos_edge_index``, i.e. the
        same positive edges that are scored afterwards -- confirm this is
        intended when the method is called on held-out edges.
        """
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        hidden = self.conv2(self.conv1(x, pos_edge_index), pos_edge_index)

        src, dst = edge_index[0], edge_index[1]
        pos_enc = hidden[:, :self.pos_dim]
        feats = hidden[:, self.pos_dim:]

        sq_dist = ((pos_enc[src] - pos_enc[dst]) ** 2).sum(dim=-1, keepdim=True)
        dot = (feats[src] * feats[dst]).sum(dim=-1)  # dot product

        pair_inputs = torch.cat(
            [dot.reshape(len(dot), 1), sq_dist.reshape(len(sq_dist), 1)], 1)
        return self.fc(pair_inputs)

    def loss(self, pred, link_label):
        """Binary cross-entropy (with logits) between ``pred`` and ``link_label``."""
        return self.loss_fn(pred, link_label)

In [64]:
# Node feature matrix of the graph; concatenated with the positional
# encodings further below.
node_features = data.x

In [65]:
# Alias for the DeepWalk embeddings.
# NOTE(review): this module-level name does not appear to be read again in
# the cells visible here -- confirm before removing.
positional_encoding = embeddings

In [66]:
# Stack [positional encoding | node features] column-wise.
x = torch.cat((torch.tensor(embeddings), node_features), 1)
# Use .to() rather than .cuda(): `device` falls back to 'cpu' when CUDA is
# unavailable, and .cuda() cannot place a tensor on a CPU device.
x = x.to(device)

In [95]:
def get_link_labels(pos_edge_index, neg_edge_index):
    """Build the binary target vector for a set of candidate links.

    Returns a float tensor ``[1, ..., 1, 0, ..., 0]`` with one ``1`` per
    column of ``pos_edge_index`` followed by one ``0`` per column of
    ``neg_edge_index``, allocated on the device of ``pos_edge_index``.
    """
    num_pos = pos_edge_index.size(1)
    num_neg = neg_edge_index.size(1)
    # Follow the inputs' device instead of a module-level ``device`` global,
    # so the labels always live next to the edges they describe.
    link_labels = torch.zeros(num_pos + num_neg, dtype=torch.float,
                              device=pos_edge_index.device)
    link_labels[:num_pos] = 1.
    return link_labels


def train():
    """Run one optimisation step on the training edges and return the loss.

    Works on the module-level ``model``, ``optimizer``, ``data`` and ``x``.
    """
    model.train()

    pos_edge_index = data.train_pos_edge_index
    # Draw as many negative edges as there are positive training edges.
    neg_edge_index = negative_sampling(
        edge_index=pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1))

    optimizer.zero_grad()

    link_logits = model(x, pos_edge_index, neg_edge_index)
    # Flatten to one logit per edge so the shape matches the labels.
    link_logits = link_logits.reshape(len(link_logits),)
    link_labels = get_link_labels(pos_edge_index, neg_edge_index)

    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    # After the step, clamp the first weight of the final linear layer into
    # [1e-5, 100] without recording the operation for autograd.
    with torch.no_grad():
        model.fc.weight[0][0].clamp_(1e-5,100)
    return loss


@torch.no_grad()
def test():
    """Return ``[val_auc, test_auc]`` for the module-level ``model``.

    NOTE(review): for each split the positive edges being scored are also
    what the model receives as its graph input -- confirm this is intended.
    """
    model.eval()
    perfs = []
    for split in ("val", "test"):
        pos_edge_index = data[f'{split}_pos_edge_index']
        neg_edge_index = data[f'{split}_neg_edge_index']

        # Score the held-out edges and squash the logits to probabilities.
        link_probs = model(x, pos_edge_index, neg_edge_index).sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)

        # ROC-AUC is computed on CPU copies of labels and scores.
        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return perfs


In [96]:
# Build the link-prediction model: 1433 input features, 128-dim positional
# encodings and 128 hidden units, without former-layer information and
# without positional-encoding updates.
model = Net(in_feats_dim = 1433, pos_dim = 128, hidden_dim = 128,
            use_former_information = False, update_coors = False)
    
# Move the graph data and the model to the selected device.
data = data.to(device)    
model = model.to(device)
# Adam over all model parameters with weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr= 0.001, weight_decay= 5e-4)

In [98]:
# Train for 100 epochs. Track the best validation score seen so far and the
# test score from that same epoch; report both every 10 epochs.
log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
best_val_perf = test_perf = 0
for epoch in range(1, 101):
    train_loss = train()
    val_perf, tmp_test_perf = test()
    if val_perf > best_val_perf:
        best_val_perf, test_perf = val_perf, tmp_test_perf
    if epoch % 10 == 0:
        print(log.format(epoch, train_loss, best_val_perf, test_perf))

Epoch: 010, Loss: 0.1034, Val: 0.8961, Test: 0.9276
Epoch: 020, Loss: 0.1062, Val: 0.8961, Test: 0.9276
Epoch: 030, Loss: 0.0999, Val: 0.8961, Test: 0.9276
Epoch: 040, Loss: 0.0947, Val: 0.8961, Test: 0.9276
Epoch: 050, Loss: 0.0875, Val: 0.8961, Test: 0.9276
Epoch: 060, Loss: 0.0812, Val: 0.8961, Test: 0.9276
Epoch: 070, Loss: 0.0836, Val: 0.8961, Test: 0.9276
Epoch: 080, Loss: 0.0758, Val: 0.8961, Test: 0.9276
Epoch: 090, Loss: 0.0769, Val: 0.8961, Test: 0.9276
Epoch: 100, Loss: 0.0740, Val: 0.8961, Test: 0.9276
