In [1]:
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_networkx
from copy import deepcopy
from scipy.sparse import coo_matrix
from sklearn.metrics import roc_auc_score
import itertools
import dgl
from dgl.nn import SAGEConv
import dgl.function as fn

In [2]:
import sys
sys.path.append('../src')
sys.path.append('..')

import src.synthetic as synthetic
import src.transform as transform

In [3]:
def create_train_test_split_edge(data):
    # Create a list of positive and negative edges
    u, v = data.edges()
    u, v = u.numpy(), v.numpy()
    edge_index = np.array((u, v))
    adj = coo_matrix((np.ones(data.num_edges()), edge_index))
    adj_neg = 1 - adj.todense() - np.eye(data.num_nodes())
    neg_u, neg_v = np.where(adj_neg != 0)

    # Create train/test edge split
    test_size = int(np.floor(data.num_edges() * 0.1))
    eids = np.random.permutation(np.arange(data.num_edges())) # Create an array of 'edge IDs'

    train_pos_u, train_pos_v = edge_index[:, eids[test_size:]]
    test_pos_u, test_pos_v   = edge_index[:, eids[:test_size]]

    # Sample an equal amount of negative edges from  the graph, split into train/test
    neg_eids = np.random.choice(len(neg_u), data.num_edges())
    test_neg_u, test_neg_v = (
        neg_u[neg_eids[:test_size]],
        neg_v[neg_eids[:test_size]],
    )
    train_neg_u, train_neg_v = (
        neg_u[neg_eids[test_size:]],
        neg_v[neg_eids[test_size:]],
    )

    # Remove test edges from original graph
    train_g = deepcopy(data)
    train_g.remove_edges(eids[:test_size]) # Remove positive edges from the testing set from the network

    train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=data.num_nodes())
    train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=data.num_nodes())

    test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=data.num_nodes())
    test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=data.num_nodes())

    return train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    )
    return F.binary_cross_entropy_with_logits(scores, labels)


def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)

In [4]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_feats, h_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, "mean")
        self.conv2 = SAGEConv(h_feats, h_feats, "mean")

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h
    

class DotPredictor(torch.nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]

In [18]:
def engineer_features(G):
    # TODO Work on getting this to be more feature agnostic - i.e. take the join of all this stuff and null if not present
    # Also need a stored one-hot 

    # Change type to two features, is_student, and is_org
    G_eng = deepcopy(G)
    _type = np.asarray(list(nx.get_node_attributes(G_eng, 'type').items()))
    is_student = np.asarray(_type[:,1] == 'student', dtype='float32')
    # commitment_limit = list(nx.get_node_attributes(G, 'commitment_limit').values())

    X = np.column_stack([is_student, 1-is_student])
    nx.set_node_attributes(G_eng, dict(zip(_type[:,0], X)), 'X')

    # TODO Add major in as one-hot

    # TODO Add Year in as one-hot


    return G_eng

In [19]:
G = synthetic.synthesize_graph()


In [20]:
G_eng = engineer_features(G)

In [21]:
G = dgl.from_networkx(G_eng, node_attrs=['X']) # TODO Investigate the slowness here

In [22]:
G

Graph(num_nodes=315, num_edges=958,
      ndata_schemes={'X': Scheme(shape=(2,), dtype=torch.float32)}
      edata_schemes={})

In [23]:
train_g, train_pos_g, train_neg_g, test_pos_g, test_neg_g = create_train_test_split_edge(G)

model = GraphSAGE(train_g.ndata["X"].shape[1], 32)
pred = DotPredictor()
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

In [24]:
# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(1001):
    # forward
    h = model(train_g, train_g.ndata["X"])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print("In epoch {}, loss: {}".format(e, loss))

    # ----------- 5. check results ------------------------ #
    if e % 100 == 0:
        with torch.no_grad():
            pos_score = pred(test_pos_g, h)
            neg_score = pred(test_neg_g, h)
            print("AUC", compute_auc(pos_score, neg_score))

In epoch 0, loss: 11.222268104553223
AUC 0.07357340720221606
In epoch 5, loss: 1.07867431640625
In epoch 10, loss: 1.1495293378829956
In epoch 15, loss: 0.961980938911438
In epoch 20, loss: 0.6747779250144958
In epoch 25, loss: 0.5894042253494263
In epoch 30, loss: 0.5782065391540527
In epoch 35, loss: 0.5644545555114746
In epoch 40, loss: 0.5414692163467407
In epoch 45, loss: 0.5274854302406311
In epoch 50, loss: 0.5222107768058777
In epoch 55, loss: 0.5185007452964783
In epoch 60, loss: 0.5156992673873901
In epoch 65, loss: 0.5136780142784119
In epoch 70, loss: 0.5121491551399231
In epoch 75, loss: 0.5113844275474548
In epoch 80, loss: 0.5108682513237
In epoch 85, loss: 0.510433554649353
In epoch 90, loss: 0.5101290345191956
In epoch 95, loss: 0.5099584460258484
In epoch 100, loss: 0.5098488926887512
AUC 0.9401108033240997
In epoch 105, loss: 0.5097494125366211
In epoch 110, loss: 0.5096747875213623
In epoch 115, loss: 0.5096328854560852
In epoch 120, loss: 0.5095926523208618
In epoc