In [60]:
# Define Neo4j connections
import pandas as pd
from neo4j import GraphDatabase
host = 'bolt://localhost:7687'
user = 'neo4j'
password = 'letmein'
driver = GraphDatabase.driver(host,auth=(user, password))

def run_query(query, params={}):
    with driver.session() as session:
        result = session.run(query, params)
        return pd.DataFrame([r.values() for r in result], columns=result.keys())

In [2]:
def load_frames():
    """
    Loads the nodes and edges from Neo4j.
    :return: nodes and edges frame
    """
    dfn = run_query("""
    MATCH (u:User)
    RETURN u.id AS id, u.age AS age, u.gender AS gender
    """)
    dfn = dfn.set_index("id")

    dfe = run_query("""
    MATCH (s:User)-[:FRIEND]->(t:User)
    RETURN s.id as source, t.id as target
    """)

    return dfn, dfe

In [3]:
from IPython.display import display

dfn, dfe = load_frames()
display(dfe.head())
display(dfn.head())

Unnamed: 0,source,target
0,1,16
1,1,10
2,1,12
3,1,8
4,1,7


Unnamed: 0_level_0,age,gender
id,Unnamed: 1_level_1,Unnamed: 2_level_1
1,26,1
16,23,1
3,29,1
4,26,0
17,27,0


In [4]:
import torch
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = 'cpu'
print(device)

cuda


In [5]:
def transform_nodes(node_frame):
    print("Transforming nodes")
    # sorting the index does not make sense here
    node_index_map = {str(index): i for i, index in enumerate(node_frame.index.unique())}
    gender_series = node_frame["gender"]
    gender_tensor = torch.zeros(len(gender_series), 2, dtype = torch.float)
    for i, v in enumerate(gender_series.values):
        gender_tensor[i, 0 if np.isnan(v) else int(v)] = 1.0
    age_tensor = torch.tensor(node_frame['age'].values, dtype = torch.float).reshape(len(gender_series), -1)
    x = torch.cat((gender_tensor, age_tensor), dim = -1)  # 1x3 tensor
    return x, node_index_map

nodes_x, nodes_mapping = transform_nodes(dfn)

Transforming nodes


In [6]:
def transform_edges(edge_frame, nodes_mapping):
    print("Transforming edges")

    src = [nodes_mapping[src_id] if src_id in nodes_mapping else -1 for src_id in edge_frame.source]
    dst = [nodes_mapping[tgt_id] if tgt_id in nodes_mapping else -1 for tgt_id in edge_frame.target]
    edge_index = torch.tensor([src, dst])

    return edge_index, None

edges_index, edges_label = transform_edges(dfe, nodes_mapping)

Transforming edges


In [7]:
from torch_geometric.data import Data

data = Data(x = nodes_x, edge_index = edges_index, edge_attr = edges_label, y = None).to(device, non_blocking=True)

In [8]:
 def create_node_masks(d):
        print("Creating classification masks")
        amount = len(d.x)
        # actually the index to the nodes
        nums = np.arange(amount)
        np.random.shuffle(nums)

        train_size = int(amount * 0.7)
        test_size = int(amount * 0.85) - train_size
        val_size = amount - train_size - test_size

        train_set = nums[0:train_size]
        test_set = nums[train_size:train_size + test_size]
        val_set = nums[train_size + test_size:]

        assert len(train_set) + len(test_set) + len(val_set) == amount, "The split should be coherent."

        train_mask = torch.zeros(amount, dtype = torch.long, device = device)
        for i in train_set:
            train_mask[i] = 1.

        test_mask = torch.zeros(amount, dtype = torch.long, device = device)
        for i in test_set:
            test_mask[i] = 1.

        val_mask = torch.zeros(amount, dtype = torch.long, device = device)
        for i in val_set:
            val_mask[i] = 1.

        d.train_mask = train_mask
        d.test_mask = test_mask
        d.val_mask = val_mask
        
create_node_masks(data)

Creating classification masks


In [9]:
import torch_geometric.transforms as T

transform = T.Compose([
    T.ToUndirected(merge = True),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val = 0.0005, num_test = 0.0001, is_undirected = True, add_negative_train_samples = False),
])
transform(data)

(Data(x=[1099121, 3], edge_index=[2, 21575162], train_mask=[1099121], test_mask=[1099121], val_mask=[1099121], edge_label=[10787581], edge_label_index=[2, 10787581]),
 Data(x=[1099121, 3], edge_index=[2, 21575162], train_mask=[1099121], test_mask=[1099121], val_mask=[1099121], edge_label=[10794], edge_label_index=[2, 10794]),
 Data(x=[1099121, 3], edge_index=[2, 21585956], train_mask=[1099121], test_mask=[1099121], val_mask=[1099121], edge_label=[2158], edge_label_index=[2, 2158]))

In [10]:
from torch_geometric.loader import NeighborLoader

# the larger the batch size the faster things will be
batch_size = 2048

# define batch loaders for the three sets
train_loader = NeighborLoader(data, num_neighbors = [10] * 2, shuffle = True, input_nodes = data.train_mask, batch_size = batch_size)
val_loader = NeighborLoader(data, num_neighbors = [10] * 2, input_nodes = data.val_mask, batch_size = batch_size)
test_loader = NeighborLoader(data, num_neighbors = [10] * 2, input_nodes = data.test_mask, batch_size = batch_size)

In [11]:
from datetime import datetime
from tqdm import tqdm
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv
from sklearn.metrics import roc_auc_score, f1_score
import os


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        # chaining two convolutions with a standard relu activation

        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        # cosine similarity
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim = -1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple = False).t()

In [12]:
model = Net(data.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params = model.parameters(), lr = 0.01)
# BCELoss creates a criterion that measures the Binary Cross Entropy between the target and the output.
criterion = torch.nn.BCEWithLogitsLoss()

In [13]:
def train():
    """
    Single epoch model training in batches.
    :return: total loss for the epoch
    """
    model.train()
    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch = batch.to(device)
        batch_size = batch.batch_size
        z = model.encode(batch.x, batch.edge_index)
        neg_edge_index = negative_sampling(edge_index = batch.edge_index, num_nodes = batch.num_nodes, num_neg_samples = None, method = 'sparse')
        edge_label_index = torch.cat([batch.edge_index, neg_edge_index], dim = -1, )
        edge_label = torch.cat([torch.ones(batch.edge_index.size(1)), torch.zeros(neg_edge_index.size(1))], dim = 0).to(device)
        out = model.decode(z, edge_label_index).view(-1)
        # loss = criterion(out[:batch_size], edge_label[:batch_size])
        loss = criterion(out, edge_label)
        # standard torch mechanics here
        loss.backward()
        optimizer.step()
        total_examples += batch_size
        total_loss += float(loss) * batch_size
    return total_loss / total_examples

In [40]:
@torch.no_grad()
def test(loader):
    """
    Evalutes the model on the test set.
    :param loader: the batch loader
    :return: a score
    """
    model.eval()
    scores = []
    threshold = torch.tensor([0.7]).to(device)
    for batch in tqdm(loader):
        batch.to(device)
        z = model.encode(batch.x, batch.edge_index)
        out = model.decode(z, batch.edge_index).view(-1).sigmoid()
        pred = (out > threshold).float() * 1
        score = f1_score(np.ones(batch.edge_index.size(1)), pred.cpu().numpy())
        scores.append(score)
    return np.average(scores)

In [49]:
def predictions(max = 1000, threshold = 0.99):
    """
    Creates predictions for the specified run.
    :param run_id: model id
    :param max: the maximum amount of predictions to output
    """
    pred_edges = []

    loader = NeighborLoader(data, num_neighbors = [10] * 2, shuffle = True, input_nodes = None, batch_size = batch_size)
    threshold_tensor = torch.tensor([threshold]).to(device)
    for batch in tqdm(loader):
        batch.to(device)
        z = model.encode(batch.x, batch.edge_index)
        # collecting negative edge tuples ensure that the decode are actual non-existing edges
        neg_edge_index = negative_sampling(edge_index = batch.edge_index, num_nodes = None, num_neg_samples = None, method = 'sparse')
        out = model.decode(z, neg_edge_index).view(-1).sigmoid()
        pred = ((out > threshold_tensor).float() * 1).cpu().numpy()
        found = np.argwhere(pred == 1)
        if found.size > 0:
            edge_tuples = neg_edge_index.t().cpu().numpy()
            select_index = found.reshape(1, found.size)[0]
            edges = edge_tuples[select_index]
            pred_edges += edges.tolist()
            if len(pred_edges) >= max:
                break
    
    return pd.DataFrame.from_dict([{'source': a, 'target': b} for a,b in pred_edges])

In [42]:
def run():
    """
        Run the training and makes predictions.
    """
    run_id = int(datetime.timestamp(datetime.now()))
    start_time = datetime.now()
    epochs = 10
    #with trange(epochs + 1) as t:
    for epoch in range(epochs):
        try:
            #t.set_description('Epoch %i/%i train' % (epoch, epochs))
            loss = train()
            #t.set_description('Epoch %i/%i test' % (epoch, epochs))
            val_acc = test(test_loader)
            #t.set_postfix(loss = loss, accuracy = val_acc)
            print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Acc: {val_acc:.4f}")
        except KeyboardInterrupt:
            break
    torch.save(model.state_dict(), f"model_{run_id}")
    time_elapsed = datetime.now() - start_time
    print("Creating predictions")
    print(f"\nRun {run_id}:")
    print(f"\tEpochs: {epoch}")
    print(f"\tTime: {time_elapsed}")
    print(f"\tAccuracy: {val_acc * 100:.01f}")

In [43]:
run()

100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:12<00:00, 43.54it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 103.52it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:15, 33.33it/s]

Epoch: 000, Loss: 0.5306, Acc: 0.9035


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:12<00:00, 42.16it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 101.66it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:15, 35.41it/s]

Epoch: 001, Loss: 0.5471, Acc: 0.8040


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 37.75it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 102.78it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:15, 35.10it/s]

Epoch: 002, Loss: 0.5342, Acc: 0.9006


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 37.67it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 103.72it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:15, 34.49it/s]

Epoch: 003, Loss: 0.5306, Acc: 0.8986


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 37.40it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 103.96it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:14, 36.05it/s]

Epoch: 004, Loss: 0.5258, Acc: 0.9004


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 37.50it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 102.81it/s]
  1%|█                                                                                 | 7/537 [00:00<00:15, 35.30it/s]

Epoch: 005, Loss: 0.5240, Acc: 0.8875


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 36.81it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 102.52it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:15, 33.34it/s]

Epoch: 006, Loss: 0.5188, Acc: 0.8915


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 37.24it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 102.36it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:16, 32.26it/s]

Epoch: 007, Loss: 0.5039, Acc: 0.8822


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 36.09it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 101.14it/s]
  1%|▌                                                                                 | 4/537 [00:00<00:15, 35.41it/s]

Epoch: 008, Loss: 0.5037, Acc: 0.8831


100%|████████████████████████████████████████████████████████████████████████████████| 537/537 [00:14<00:00, 36.74it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 537/537 [00:05<00:00, 101.92it/s]

Epoch: 009, Loss: 0.4971, Acc: 0.8811
Creating predictions

Run 1640882244:
	Epochs: 9
	Time: 0:03:13.095027
	Accuracy: 88.1





In [51]:
preds_df = predictions()
print(preds_df.head())

  0%|                                                                                          | 0/537 [00:00<?, ?it/s]

   source  target
0    9257   85860
1   18041   83302
2   47626    8697
3   16195   96111
4  108720   15420





In [65]:
# Check how many already exist

run_query("""
UNWIND $data AS row
MATCH (s:User)-[:FRIEND]-(t:User)
WHERE s.id = toString(row[0]) AND t.id = toString(row[1])
RETURN count(*) AS already_exists
""", {'data': preds_df.values.tolist()})


Unnamed: 0,already_exists
0,1
