# CS224w GraphSage notebook and code
The code will download the ogbn-mag dataset for use.

The code uses a custom version of DGL that is available [here](https://github.com/ali6947/dgl)

Alternatively, you can replace the `subgraph.py` of your own DGL installation (for example "/lib/python3.10/site-packages/dgl/subgraph.py") with the `subgraph.py` from the repository above. The path to `subgraph.py` in your DGL installation will be different if you use Windows or a virtual environment.

Parts of this notebook's code have been adapted from this [tutorial](https://docs.dgl.ai/en/0.8.x/tutorials/blitz/4_link_predict.html)

The notebook downloads the ogbn-mag dataset and trains our GraphSAGE model for link prediction.

## Imports


In [None]:
from ogb.nodeproppred import DglNodePropPredDataset
import dgl
import torch
import numpy as np
from dgl import AddReverse, Compose, ToSimple
from dgl.nn import SAGEConv
from tqdm import tqdm
from numpy.random import default_rng
from copy import deepcopy
import networkx as nx

## Dataset download and graph setup

In [None]:
# Load ogbn-mag (downloaded into 'dataset/' by the OGB loader) and unpack the
# single heterogeneous graph together with its node labels.
dataset = DglNodePropPredDataset(name = "ogbn-mag", root = 'dataset/')
graph,labels=dataset[0]
# Official train/valid/test split; each entry is indexed by node type below
# (only the 'paper' entry of the validation split is used in this cell).
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]

# The task is co-author prediction, so keep only validation papers that have
# more than one incoming 'writes' edge (i.e. more than one author).
# NOTE(review): this replaces the original index container with a Python list.
valid_idx['paper']=[x for x in valid_idx['paper'] if len(graph.predecessors(x,etype='writes'))>1]
# Number of hops used to build the subgraph around each seed paper, both
# during training and during evaluation.
k_hops=5
# Training graph = full graph minus the (filtered) validation papers, so that
# no validation paper leaks into training.
# NOTE(review): test_idx papers are not removed here — confirm this is intended.
train_graph= dgl.remove_nodes(graph, valid_idx['paper'],ntype='paper')

## Train time subgraph sampler

In [None]:
class PaperNbrSampler(dgl.dataloading.Sampler):
    """Training-time sampler: one k-hop subgraph per seed paper, batched together.

    ``sample`` is what the DataLoader calls. The two ``sample_*_author_pairs``
    helpers draw author pairs from a batched graph; they are not invoked by
    ``sample`` (pair sampling is currently disabled there) but remain usable
    on the batched graph it returns.
    """

    def __init__(self, num_author_pair,khops):
        """Store the sampler settings.

        Args:
            num_author_pair: number of author pairs to draw per subgraph.
            khops: number of hops passed to ``dgl.khop_subgraph``.
        """
        super().__init__()
        self.num_author_pair = num_author_pair
        self.khops = khops

    def sample_positive_author_pairs(self,graph,num_samples_per_graph):
        """Draw pairs of distinct authors that share at least one paper.

        Args:
            graph: batched heterograph (must expose ``batch_size``).
            num_samples_per_graph: pairs to draw per graph in the batch.

        Returns:
            Integer tensor of shape ``(num_samples_per_graph * batch_size, 2)``
            holding ``(source_author, co_author)`` node IDs of the batched graph.

        Raises:
            ValueError: if pairs are requested but no paper has two or more
                incoming 'writes' edges, in which case rejection sampling
                could never succeed.
        """
        total = num_samples_per_graph * graph.batch_size
        # Node IDs are integers; an integer tensor keeps them exact and lets
        # the result be used directly for indexing.
        positive_pairs = torch.zeros(total, 2, dtype=torch.long)
        if total == 0:
            return positive_pairs

        # The loop below is rejection sampling and would spin forever if no
        # paper had a co-author. (If the graph contained duplicate 'writes'
        # edges this check would be necessary but not sufficient.)
        if not bool((graph.in_degrees(etype='writes') >= 2).any()):
            raise ValueError(
                "cannot sample positive author pairs: no paper has more than one author"
            )

        anodes = graph.nodes(ntype='author')
        filled = 0
        while filled < total:
            source = np.random.choice(anodes)
            papers = graph.successors(source, etype='writes')
            if len(papers) == 0:
                continue  # author without papers: draw again
            paper = np.random.choice(papers)
            # Take the first listed writer of that paper other than `source`.
            for writer in graph.predecessors(paper, etype='writes'):
                if writer != source:
                    positive_pairs[filled, 0] = int(source)
                    positive_pairs[filled, 1] = int(writer)
                    filled += 1
                    break

        return positive_pairs

    def sample_negative_author_pairs(self,graph,sample_per_graph):
        """Draw random author pairs, each pair taken from a single component graph.

        Nodes of the individual graphs are ordered by (graph, node ID in graph)
        and given one overall ordering, so the authors of component ``i`` form a
        contiguous slice of ``graph.nodes('author')``.

        Both columns are drawn independently and with replacement; pairs are
        not checked against real co-authorship and may repeat an author.

        Returns:
            Integer tensor of shape ``(sample_per_graph * num_graphs, 2)``.
        """
        anodes = graph.nodes(ntype='author')
        cntauth = graph.batch_num_nodes('author')
        num_graphs = cntauth.shape[0]
        negative_pairs = torch.zeros(sample_per_graph * num_graphs, 2, dtype=torch.long)
        start = 0
        for i in range(num_graphs):
            stop = start + int(cntauth[i])
            pool = anodes[start:stop]  # authors belonging to component i
            rows = slice(sample_per_graph * i, sample_per_graph * (i + 1))
            negative_pairs[rows, 0] = torch.tensor(np.random.choice(pool, sample_per_graph))
            negative_pairs[rows, 1] = torch.tensor(np.random.choice(pool, sample_per_graph))
            start = stop
        return negative_pairs

    def sample(self,graph,indices):
        """Build the mini-batch for the given seed papers.

        Args:
            graph: the full graph handed over by the DataLoader.
            indices: paper node IDs of the current mini-batch.

        Returns:
            A single batched graph containing one ``khops``-hop subgraph per
            seed paper (first element of ``dgl.khop_subgraph``'s result).
        """
        subgraphs = [
            dgl.khop_subgraph(graph, {'paper': [paper]}, self.khops)[0]
            for paper in indices
        ]
        # Author-pair sampling is intentionally not performed here; only the
        # batched subgraph is returned.
        return dgl.batch(subgraphs)

## Validation time paper sampler for loss computation

In [None]:
class ValPaperNbrSampler(dgl.dataloading.Sampler):
    """Validation sampler producing candidate edges for recall-style evaluation.

    For a single seed paper it hides all but the last listed author and builds
    a graph of candidate author->paper edges (hidden authors plus random
    non-authors) that the model has to rank.
    """

    def __init__(self,khops):
        """Store the number of hops used for the subgraph around the seed paper."""
        super().__init__()
        self.khops=khops

    def sample(self,graph,indices):
        """Build the evaluation inputs for one validation paper.

        Args:
            graph: the full graph handed over by the DataLoader.
            indices: tensor holding exactly one paper node ID.

        Returns:
            A 5-tuple ``(sg_model_inp, sg_to_pred, paper, author_to_del, kept_author)``:
            * ``sg_model_inp``: subgraph for message passing, with the 'writes'
              edges of the held-out authors removed;
            * ``sg_to_pred``: same nodes, but its only edges are the candidate
              'writes' edges (negatives first, then held-out authors);
            * ``paper``: ``invlabel['paper'][0]`` — presumably the seed paper's
              ID inside the subgraph;
            * ``author_to_del``: the held-out authors (all but the last one);
            * ``kept_author``: the last listed author, whose edge stays in
              ``sg_model_inp``.
        """
        assert indices.shape[0]==1, "ValPaperNbrSampler expects batch_size=1"
        sg,invlabel=dgl.khop_subgraph(graph,{'paper':[indices[0]]},self.khops)
        paper_in_subg=invlabel['paper']
        authors=sg.predecessors(paper_in_subg[0],etype='writes')
        # Hide every author except the last one; the hidden ones are the targets.
        author_to_del=authors[:-1]
        edge_ids_to_del=sg.edge_ids(author_to_del,paper_in_subg.repeat(len(author_to_del)),etype='writes')
        sg_model_inp=dgl.remove_edges(sg, edge_ids_to_del,etype='writes')

        # Start from a copy of the subgraph with every edge of every type removed.
        sg_to_pred=deepcopy(sg)
        for etype in sg.etypes:
            sg_to_pred=dgl.remove_edges(sg_to_pred, sg.edges(form='eid',etype=etype),etype=etype)

        all_authors=sg.nodes('author')
        # Negatives must not be authors of the paper at all. Excluding only the
        # held-out authors would allow the kept (true) author to be drawn as a
        # "negative" candidate.
        non_paper_authors=np.setdiff1d(all_authors,authors)
        if non_paper_authors.shape[0]==0:
            # Degenerate subgraph whose only authors are the paper's own:
            # fall back to the wider pool so the candidate set is not reduced
            # to the positives alone.
            non_paper_authors=np.setdiff1d(all_authors,author_to_del)
        # Up to two negatives per held-out author, drawn without replacement.
        neg_authors=np.random.choice(non_paper_authors,min(len(author_to_del)*2,non_paper_authors.shape[0]),replace=False)
        authors_to_add=np.concatenate([neg_authors,author_to_del])
        sg_to_pred.add_edges(authors_to_add,paper_in_subg.repeat(len(authors_to_add)),etype='writes')
        return sg_model_inp,sg_to_pred,invlabel['paper'][0],author_to_del,authors[-1]


class ValPaperNbrSamplerLoss(dgl.dataloading.Sampler):
    """Validation sampler producing positive/negative edge graphs for a loss.

    For a single seed paper it hides all but the last listed author and returns
    the message-passing subgraph together with one graph of held-out (positive)
    'writes' edges and one graph of sampled negative 'writes' edges.
    """

    def __init__(self,khops):
        """Store the number of hops used for the subgraph around the seed paper."""
        super().__init__()
        self.khops=khops

    def sample(self,graph,indices):
        """Build the loss inputs for one validation paper.

        Args:
            graph: the full graph handed over by the DataLoader.
            indices: tensor holding exactly one paper node ID.

        Returns:
            A 3-tuple ``(sg_model_inp, sg_pos, sg_neg)``:
            * ``sg_model_inp``: subgraph for message passing, with the 'writes'
              edges of the held-out authors removed;
            * ``sg_pos``: same nodes, only the held-out author->paper edges;
            * ``sg_neg``: same nodes, only sampled negative author->paper edges
              (as many as there are held-out authors, drawn with replacement).
        """
        assert indices.shape[0]==1, "ValPaperNbrSamplerLoss expects batch_size=1"
        sg,invlabel=dgl.khop_subgraph(graph,{'paper':[indices[0]]},self.khops)
        paper_in_subg=invlabel['paper']
        authors=sg.predecessors(paper_in_subg[0],etype='writes')
        # Hide every author except the last one; the hidden ones are the positives.
        author_to_del=authors[:-1]
        edge_ids_to_del=sg.edge_ids(author_to_del,paper_in_subg.repeat(len(author_to_del)),etype='writes')
        sg_model_inp=dgl.remove_edges(sg, edge_ids_to_del,etype='writes')

        # Two copies of the subgraph with every edge of every type removed.
        sg_pos=deepcopy(sg)
        sg_neg=deepcopy(sg)
        for etype in sg.etypes:
            all_eids=sg.edges(form='eid',etype=etype)
            sg_pos=dgl.remove_edges(sg_pos, all_eids,etype=etype)
            sg_neg=dgl.remove_edges(sg_neg, all_eids,etype=etype)

        all_authors=sg.nodes('author')
        # Negatives must not be authors of the paper at all. Excluding only the
        # held-out authors would allow the kept (true) author to be drawn as a
        # negative, i.e. a real edge would be penalised by the loss.
        neg_pool=np.setdiff1d(all_authors,authors)
        if neg_pool.shape[0]==0:
            # Degenerate subgraph whose only authors are the paper's own:
            # fall back to the wider pool rather than failing on an empty one.
            neg_pool=np.setdiff1d(all_authors,author_to_del)
        neg_authors=np.random.choice(neg_pool,len(author_to_del))
        sg_pos.add_edges(author_to_del,paper_in_subg.repeat(len(author_to_del)),etype='writes')
        sg_neg.add_edges(neg_authors,paper_in_subg.repeat(len(neg_authors)),etype='writes')
        return sg_model_inp,sg_pos,sg_neg

## Initialise the dataloaders below

In [None]:
# Options shared by both loaders: one seed paper per batch, shuffled order,
# loading in the main process, everything kept on the CPU.
_loader_options = dict(batch_size=1, shuffle=True, num_workers=0, device='cpu')

# Training loader: iterates over every paper of the training graph and yields
# the batched k-hop subgraph produced by PaperNbrSampler.
coauth_train_loader = dgl.dataloading.DataLoader(
    train_graph,
    train_graph.nodes('paper'),
    PaperNbrSampler(2, k_hops),
    **_loader_options,
)

# Validation loader used for the validation loss: iterates over the filtered
# validation papers on the full graph.
coauth_val_loader_loss = dgl.dataloading.DataLoader(
    graph,
    valid_idx['paper'],
    ValPaperNbrSamplerLoss(k_hops),
    **_loader_options,
)

## Helper functions

In [None]:
def find_node_ids(node_type_list,num_node_func,node_type,node_ids):
    """Map per-type node IDs to their IDs after homogenising a heterograph.

    The homogeneous ID of a node is its per-type ID shifted by the number of
    nodes of every type that precedes ``node_type`` in ``node_type_list``.

    Args:
        node_type_list: node types in the order used for homogenisation.
        num_node_func: callable taking a node type and returning the number
            of nodes of that type.
        node_type: type shared by all nodes in ``node_ids``.
        node_ids: tensor of per-type node IDs; it is cloned, not modified.

    Returns:
        A new tensor with the shifted IDs.

    Raises:
        IndexError: if ``node_type`` does not occur in ``node_type_list``.
    """
    shifted = node_ids.clone()
    position = 0
    while True:
        current_type = node_type_list[position]
        if current_type == node_type:
            return shifted
        # Every earlier type pushes the IDs of later types up by its node count.
        shifted += num_node_func(current_type)
        position += 1

def get_avg_num_shortest_paths(graph,source,targets):
    """Average, over ``targets``, the number of shortest paths from ``source``.

    The graph is converted with ``graph.to_networkx()`` and the shortest paths
    to each target are enumerated with ``nx.all_shortest_paths``.

    Returns:
        ``np.mean`` of the per-target shortest-path counts.
    """
    nx_graph = graph.to_networkx()
    path_counts = []
    for target in targets:
        shortest = nx.all_shortest_paths(nx_graph, source=source, target=target)
        # Count the paths without materialising them.
        path_counts.append(sum(1 for _ in shortest))
    return np.mean(path_counts)

def construct_all_inputs(graph,authors,papers,authors_neg,papers_neg):
    """Split a homogeneous, bidirected graph into model inputs for link prediction.

    Args:
        graph: homogeneous graph. The reverse of edge ``i`` is assumed to be
            stored at ``i + number_of_edges // 2`` (the caller builds it with
            ``dgl.add_reverse_edges`` — confirm when reusing elsewhere).
        authors, papers: endpoints (homogeneous node IDs) of the positive
            author->paper edges; these edges must exist in ``graph``.
        authors_neg, papers_neg: endpoints of the negative edges.

    Returns:
        ``(graph_inp, graph_pos, graph_neg)`` where
        * ``graph_inp`` is ``graph`` without the positive edges (both
          directions) and is meant for message passing,
        * ``graph_pos`` contains only the positive edges, in both directions,
        * ``graph_neg`` contains only the negative edges, in both directions.
    """
    num_edges = graph.number_of_edges()

    # Remove each positive edge together with its reverse copy.
    forward_ids = graph.edge_ids(authors, papers)
    held_out_ids = torch.cat([forward_ids, forward_ids + num_edges // 2])
    graph_inp = dgl.remove_edges(graph, held_out_ids)

    # Same node set with no edges at all; dgl.add_edges returns a new graph,
    # so this base can be shared by the positive and the negative graph.
    edgeless = dgl.remove_edges(graph, torch.arange(num_edges))

    graph_pos = dgl.add_edges(edgeless, authors, papers)
    graph_pos = dgl.add_edges(graph_pos, papers, authors)

    graph_neg = dgl.add_edges(edgeless, authors_neg, papers_neg)
    graph_neg = dgl.add_edges(graph_neg, papers_neg, authors_neg)

    return graph_inp, graph_pos, graph_neg

def get_author_paper_pairs(graph,neg=False):
    """Sample half of the 'writes' edges of a heterograph as (author, paper) pairs.

    Args:
        graph: heterograph with a 'writes' edge type.
        neg: when False the returned pairs are existing 'writes' edges; when
            True the sampled authors are paired with the papers of the same
            sampled edges in shuffled order.

    Returns:
        ``(authors, papers)``, two tensors with ``num_writes // 2`` entries each.
    """
    writers, written = graph.edges(etype='writes')
    edge_count = writers.shape[0]
    # Pick half of the edges, without replacement.
    chosen = default_rng().choice(edge_count, size=edge_count // 2, replace=False)
    # Shuffling the chosen edge IDs re-pairs authors with other sampled papers.
    paper_index = np.random.permutation(chosen) if neg else chosen
    return writers[chosen], written[paper_index]

## Validation loop function

In [None]:
def validate(model,val_loader,max_batches=51):
    """Compute the mean recall@k of the model on validation papers.

    ``val_loader`` must yield the 5-tuples produced by ``ValPaperNbrSampler``:
    ``(subg_inp, sg_to_pred, paper, authors, conn_author)``. For each paper the
    candidate edges of ``sg_to_pred`` are scored and the top ``k`` are compared
    with the held-out ``authors``, where
    ``k = min(ceil(1.5 * len(authors)), 10, number of candidate edges)``.

    Args:
        model: the link-prediction model; it is put into eval mode.
        val_loader: iterable of validation samples (see above).
        max_batches: number of samples to evaluate (default 51, i.e. the
            previous hard-coded limit).

    Returns:
        Mean over the evaluated samples of
        ``|top-k authors ∩ held-out authors| / |held-out authors|``
        (``np.mean`` of an empty list, i.e. NaN, if nothing was evaluated).
    """
    model.eval()
    recalls=[]
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for i, (subg_inp,sg_to_pred,_paper,authors,_conn_author) in enumerate(val_loader):
            if i>=max_batches:
                break

            # IDs of the held-out authors after homogenisation.
            author_node_homo=find_node_ids(subg_inp.ntypes,subg_inp.num_nodes,'author',authors)

            # Homogenise both graphs and add reverse edges.
            subg_pred_homo=dgl.to_homogeneous(sg_to_pred)
            subg_inp_homo=dgl.to_homogeneous(subg_inp)
            sub_pred_undir=dgl.add_reverse_edges(subg_pred_homo)
            sub_inp_undir=dgl.add_reverse_edges(subg_inp_homo)

            # Node features: one-hot encoding of the node type.
            num_nodes=subg_pred_homo.num_nodes()
            node_feats=torch.zeros((num_nodes,len(subg_inp.ntypes)))
            node_feats[torch.arange(num_nodes),subg_pred_homo.ndata['_TYPE']]=1

            # Score the candidate edges; the second half are the reverse
            # copies of the first half, so only the first half is ranked.
            op=model(sub_inp_undir,None,node_feats,sub_pred_undir)
            op=op[:(op.shape[0]//2)]
            scores=op[:,0] if len(op.shape)>1 else op

            # Never request more items than there are candidates
            # (torch.topk raises otherwise).
            recall_k=min(int(np.ceil(authors.shape[0]*1.5)),10,scores.shape[0])
            pred_edges=torch.topk(scores,recall_k)[1]
            pred_authors=sub_pred_undir.edges()[0][pred_edges]

            recalls.append(np.intersect1d(pred_authors,author_node_homo).shape[0]/authors.shape[0])

    return np.mean(recalls)

def validate_w_loss(model,val_loader,loss_func,max_batches=51):
    """Compute the mean validation loss of the model.

    ``val_loader`` must yield the 3-tuples produced by
    ``ValPaperNbrSamplerLoss``: ``(subg_inp, sg_pos, sg_neg)``.

    Args:
        model: the link-prediction model; it is put into eval mode.
        val_loader: iterable of validation samples (see above).
        loss_func: callable ``loss_func(pos_score, neg_score)`` returning a
            scalar tensor.
        max_batches: number of samples to evaluate (default 51, i.e. the
            previous hard-coded limit).

    Returns:
        Mean of the per-sample losses (``np.mean`` of an empty list, i.e. NaN,
        if nothing was evaluated).
    """
    model.eval()
    losses=[]
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for i, (subg_inp,sg_pos,sg_neg) in enumerate(val_loader):
            if i>=max_batches:
                break

            # Homogenise every graph and add reverse edges.
            subg_pos_homo=dgl.to_homogeneous(sg_pos)
            subg_neg_homo=dgl.to_homogeneous(sg_neg)
            subg_inp_homo=dgl.to_homogeneous(subg_inp)
            sub_pos_undir=dgl.add_reverse_edges(subg_pos_homo)
            sub_neg_undir=dgl.add_reverse_edges(subg_neg_homo)
            sub_inp_undir=dgl.add_reverse_edges(subg_inp_homo)

            # Node features: one-hot encoding of the node type.
            num_nodes=subg_inp_homo.num_nodes()
            node_feats=torch.zeros((num_nodes,len(subg_inp.ntypes)))
            node_feats[torch.arange(num_nodes),subg_inp_homo.ndata['_TYPE']]=1

            pos_score,neg_score=model(sub_inp_undir,sub_neg_undir,node_feats,sub_pos_undir)
            loss = loss_func(pos_score, neg_score)
            losses.append(loss.item())

    return np.mean(losses)

## Model definition

In [None]:
class MLPPredictor(torch.nn.Module):
    """Two-layer MLP that scores an edge from the embeddings of its endpoints."""

    def __init__(self, h_feats):
        """Create the layers; ``h_feats`` is the size of a node embedding."""
        super().__init__()
        # Input is the concatenation of source and destination embeddings.
        self.W1 = torch.nn.Linear(h_feats * 2, h_feats)
        self.W2 = torch.nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """Edge function: return ``{'score': ...}`` with one score per edge."""
        endpoints = torch.cat((edges.src['h'], edges.dst['h']), dim=1)
        hidden = torch.nn.functional.relu(self.W1(endpoints))
        logits = self.W2(hidden)
        # Drop the trailing size-1 dimension produced by W2.
        return {'score': logits.squeeze(1)}

    def forward(self, g, h):
        """Score every edge of ``g`` given node embeddings ``h``.

        The embeddings are attached to the graph only inside a local scope.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            edge_scores = g.edata['score']
            return edge_scores

class Model(torch.nn.Module):
    """Three stacked mean-aggregating SAGEConv layers followed by an MLP edge scorer."""

    def __init__(self, in_features, hidden_features):
        """Create the layers.

        Args:
            in_features: size of the input node features.
            hidden_features: size of the node embeddings of every layer.
        """
        super().__init__()

        self.sage1=SAGEConv(in_features, hidden_features, 'mean')
        self.sage2=SAGEConv(hidden_features, hidden_features, 'mean')
        self.sage3=SAGEConv(hidden_features, hidden_features, 'mean')

        self.pred=MLPPredictor(hidden_features)
        # NOTE(review): `bn` is registered but not used in forward(); it is
        # kept so that the module's parameters/state_dict keys do not change.
        self.bn=torch.nn.BatchNorm1d(hidden_features, eps=1)

    def forward(self, g, neg_g, x,pos_g=None):
        """Embed the nodes of ``g`` and score edges.

        Args:
            g: graph used for message passing; it should not contain the
                edges that are to be predicted (those live in ``pos_g``).
            neg_g: graph of negative edges, or None to score only ``pos_g``.
            x: node features.
            pos_g: graph of positive/candidate edges; defaults to ``g``.

        Returns:
            ``(pos_scores, neg_scores)`` when ``neg_g`` is given, otherwise
            only the scores of the edges of ``pos_g``.
        """
        if pos_g is None:
            pos_g=g

        # NOTE(review): the layers are applied back to back, with no explicit
        # non-linearity in between — confirm this is intended.
        h=self.sage1(g,x)
        h=self.sage2(g,h)
        h=self.sage3(g,h)

        # pos_g is never None at this point, so only neg_g decides the shape
        # of the result.
        pos_score=self.pred(pos_g,h)
        if neg_g is None:
            return pos_score
        return pos_score,self.pred(neg_g,h)

## Loss functions

In [None]:
def compute_loss(pos_score, neg_score):
    """Margin (hinge) loss between positive edges and their negative samples.

    Each positive score is compared with its own ``k`` negatives:
    ``mean(max(0, 1 - pos_i + neg_ij))``.

    Args:
        pos_score: scores of the ``n`` positive edges, shape ``(n,)`` or ``(n, 1)``.
        neg_score: scores of the ``n * k`` negative edges; the ``k`` negatives
            of positive ``i`` are the entries ``i*k .. (i+1)*k - 1``.

    Returns:
        Scalar tensor with the mean hinge loss.
    """
    n_edges = pos_score.shape[0]
    # Make the positives an explicit column. A 1-D pos_score would otherwise
    # broadcast along the wrong axis, comparing every positive with every
    # negative (an n x n matrix) instead of with its own negatives.
    pos = pos_score.reshape(n_edges, 1)
    neg = neg_score.view(n_edges, -1)
    return (1 - pos + neg).clamp(min=0).mean()

def compute_loss_CE(pos_score, neg_score):
    """Binary cross-entropy (with logits) over positive and negative edge scores.

    Positive edges get label 1 and negative edges label 0.

    Args:
        pos_score: logits of the positive edges.
        neg_score: logits of the negative edges (same trailing shape).

    Returns:
        Scalar tensor with the mean binary cross-entropy.
    """
    scores = torch.cat([pos_score, neg_score])
    # Build the labels from the scores themselves so that they share their
    # shape, dtype and device.
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return torch.nn.functional.binary_cross_entropy_with_logits(scores, labels)

## Model initialisation

In [None]:
# The input size must equal the width of the one-hot node-type features built
# for each subgraph (len(subg.ntypes)); presumably 4 node types for ogbn-mag.
model = Model(in_features=4, hidden_features=4)
# Adam with its default hyper-parameters.
opt = torch.optim.Adam(params=model.parameters())

## Training

In [None]:
# Single pass over the training loader: every step trains on the batched k-hop
# subgraph of one seed paper (the loader is built with batch_size=1).
for i, (subg) in (pbar:= tqdm(enumerate(coauth_train_loader))):
    model.train()
    # Positive pairs: half of the subgraph's 'writes' edges.
    # Negative pairs: an independently drawn half of the edges whose papers
    # are shuffled among the sampled authors.
    # NOTE(review): shuffled pairs are not checked against real edges, so a
    # "negative" may occasionally be a true author-paper edge.
    author_pred,paper_pred=get_author_paper_pairs(subg)
    author_pred_neg,paper_pred_neg=get_author_paper_pairs(subg,neg=True)

    # Homogenise the heterograph and add a reverse copy of every edge.
    sub_homo=dgl.to_homogeneous(subg)
    sub_homo_undir=dgl.add_reverse_edges(sub_homo)

    # Node features: one-hot encoding of the node type, one row per node of
    # the homogeneous graph.
    node_feats=torch.zeros((subg.num_nodes(),len(subg.ntypes)))
    node_feats[torch.arange(subg.num_nodes()),sub_homo_undir.ndata['_TYPE']]=1

    # Translate the per-type author/paper IDs into homogeneous node IDs.
    author_pred_homo=find_node_ids(subg.ntypes,subg.num_nodes,'author',author_pred)
    paper_pred_homo=find_node_ids(subg.ntypes,subg.num_nodes,'paper',paper_pred)
    author_pred_neg_homo=find_node_ids(subg.ntypes,subg.num_nodes,'author',author_pred_neg)
    paper_pred_neg_homo=find_node_ids(subg.ntypes,subg.num_nodes,'paper',paper_pred_neg)

    # Message-passing graph (positive edges removed) plus the graphs holding
    # only the positive and only the negative edges.
    input_graph,positive_graph,negative_graph=construct_all_inputs(sub_homo_undir,author_pred_homo,paper_pred_homo,author_pred_neg_homo,paper_pred_neg_homo)

    # Score both edge sets, compute the margin loss and take one optimiser step.
    pos_score, neg_score = model(input_graph, negative_graph, node_feats,positive_graph)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
    pbar.set_description(f"Loss:{loss.item():.4f},Edges:{input_graph.number_of_edges()}")
    # Every 50 steps print the validation loss (validate_w_loss switches the
    # model to eval mode; model.train() above switches it back).
    if (i+1)%50==0:
        print(validate_w_loss(model,coauth_val_loader_loss,compute_loss))