In [None]:
import os
import pandas as pd

# edgelist.txt starts directly with data (its first line is the edge
# 1001 -> 9304045), so read it header-less and name the two columns
# explicitly. Letting pandas infer a header would turn that first edge into
# column names and silently drop it from the frame.
edges = pd.read_csv(
    os.path.join(os.pardir, 'data', 'edgelist.txt'),
    sep='\t',
    header=None,
    names=['src', 'trg'],
)
edges.head()

In [None]:
import sys

# Make modules in <parent of the working directory>/scripts importable.
# The helper modules imported in later cells (get_graph, train_test_split,
# gat) presumably live there -- confirm.
project_root = os.path.dirname(os.path.abspath(''))
scripts_dir = os.path.join(project_root, 'scripts')
if scripts_dir not in sys.path:  # guard keeps re-runs from duplicating the entry
    sys.path.append(scripts_dir)

In [None]:
import networkx as nx
from get_graph import get_digraph

# Build the directed graph with the project helper; the annotation documents
# the expected return type.
g: nx.DiGraph = get_digraph()

In [None]:
import dgl

# Convert the networkx graph to a DGL graph. No node_attrs/edge_attrs are
# requested, so only the graph structure is carried over; node features are
# attached separately below.
G = dgl.from_networkx(nx_graph=g)

In [None]:
import numpy as np

# Pre-computed node embeddings, stored one directory above the working directory.
embeddings_path = os.path.join(os.pardir, 'embeddings.npy')
word_embs = np.load(embeddings_path)

In [None]:
import torch

# Node-feature matrix with exactly one row per node of G.
# Rows 0..len(word_embs)-1 come from the loaded embeddings, and all remaining
# nodes get all-zero rows. Assumes embedding row i belongs to node i of G --
# TODO confirm.
# The padding width follows the embedding width instead of a hard-coded 50.
num_nodes = G.num_nodes()
num_missing = num_nodes - len(word_embs)
if num_missing < 0:
    raise ValueError(
        f'embeddings.npy has {len(word_embs)} rows but the graph has only {num_nodes} nodes'
    )
padding = np.zeros((num_missing, word_embs.shape[1]), dtype=word_embs.dtype)
G.ndata['word_embs'] = torch.from_numpy(np.vstack([word_embs, padding])).float()

In [None]:
from train_test_split import make_split

# Split edges into train/validation/test positives, each paired with sampled
# negatives ("*_false"). make_split receives g's sparse adjacency matrix, whose
# row/column order is g.nodes() order.
# NOTE(review): the returned edge indices are later used to index node
# embeddings of the DGL graph G. dgl.from_networkx relabels nodes (reportedly
# in sorted-label order), so the two orders agree only if g's nodes are
# inserted in sorted order -- confirm.
(adj_train, train_edges, train_edges_false,
 val_edges, val_edges_false, test_edges, test_edges_false) = make_split(nx.to_scipy_sparse_array(g))

In [None]:
def make_edge_batch(pos_edges, neg_edges):
    """Stack observed edges and sampled non-edges into one labelled batch.

    Args:
        pos_edges: (src, dst) node-index pairs that exist in the graph.
        neg_edges: (src, dst) node-index pairs sampled as non-edges.

    Returns:
        (pairs, labels):
            pairs: a long tensor of shape (len(pos) + len(neg), 2), positives first.
            labels: a float tensor, 1 for each positive pair and 0 for each negative pair.
    """
    # long dtype so the columns can index the node-embedding matrix;
    # reshape keeps an empty edge list as a (0, 2) tensor so cat still works.
    pos = torch.as_tensor(pos_edges, dtype=torch.long).reshape(-1, 2)
    neg = torch.as_tensor(neg_edges, dtype=torch.long).reshape(-1, 2)
    pairs = torch.cat([pos, neg])
    # Existing edges are the positive class, so a high sigmoid(u . v) score
    # means "edge present".
    labels = torch.cat([torch.ones(len(pos)), torch.zeros(len(neg))])
    return pairs, labels


train, train_label = make_edge_batch(train_edges, train_edges_false)
train_u, train_v = train[:, 0], train[:, 1]

test, test_label = make_edge_batch(test_edges, test_edges_false)
test_u, test_v = test[:, 0], test[:, 1]

In [None]:
from gat import GraphSAGE

# Take the input width from the node-feature matrix so it always matches the
# embedding dimension, instead of repeating a hard-coded 50.
# HIDDEN_DIM is presumably GraphSAGE's hidden/output size -- confirm in gat.py.
HIDDEN_DIM = 64
net = GraphSAGE(G.ndata['word_embs'].shape[1], HIDDEN_DIM)

In [None]:
import torch.nn.functional as F

# Training hyper-parameters.
LEARNING_RATE = 0.001
NUM_EPOCHS = 100
LOG_EVERY = 5

optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
features = G.ndata['word_embs'].float()  # loop-invariant input features

net.train()
all_logits = []  # detached node embeddings from every epoch
for e in range(NUM_EPOCHS):
    # forward: node embeddings, then one dot-product score per candidate pair
    logits = net(G, features)
    score = (logits[train_u] * logits[train_v]).sum(dim=1)
    # BCE on raw scores fuses the sigmoid into the loss, which is numerically
    # more stable than sigmoid followed by binary_cross_entropy.
    loss = F.binary_cross_entropy_with_logits(score, train_label)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    all_logits.append(logits.detach())

    if e % LOG_EVERY == 0:
        print('In epoch {}, loss: {}'.format(e, loss.item()))

In [None]:
# Score the test pairs with the final weights. The `logits` left over from the
# training loop were computed before the last optimizer step, so they do not
# reflect the trained model.
# Evaluate without autograd and in eval mode, then restore the module's
# previous mode so re-running the training cell is unaffected.
was_training = net.training
net.eval()
with torch.no_grad():
    final_logits = net(G, G.ndata['word_embs'].float())
    pred = torch.sigmoid((final_logits[test_u] * final_logits[test_v]).sum(dim=1))
net.train(was_training)

print('Accuracy', ((pred >= 0.5) == test_label).sum().item() / len(pred))