In [26]:
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
    
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, NNConv
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score

from torch_geometric.utils import negative_sampling
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges

In [27]:
%run dataset_ep.ipynb

DATASET LOADED
Data(x=[43, 1], edge_index=[2, 45], edge_attr=[45, 1], y=[45, 16])


Processing...
Done!


In [28]:
torch.manual_seed(42)
dataset = dataset_ep.shuffle()

train_dataset = dataset[:int(len(dataset) * 0.65)]
test_dataset = dataset[int(len(dataset) * 0.35):]

print(f'Number of graphs total: {len(dataset)}')
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
print(dataset[0])
print(len(dataset))

Number of graphs total: 54
Number of training graphs: 35
Number of test graphs: 36
Data(x=[48, 1], edge_index=[2, 51], edge_attr=[51, 1], y=[51, 16])
54


In [29]:
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data

NUM_GRAPHS_PER_BATCH = len(dataset)
NUM_HIDDEN_CHANNELS = 64

train_loader = DataLoader(train_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=False)

In [33]:
x = []
edge_index = []
edge_attr = []
y = []
for batch in train_loader:
    x = batch.x
    edge_index = batch.edge_index
    edge_attr = batch.edge_attr
    y = batch.y
    print(batch)
    
# print(x)
# print(edge_index)
# print(edge_attr)
# print(y)

data = Data(x=x, 
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y,
            ) 

print(data)

DataBatch(x=[1006, 1], edge_index=[2, 1053], edge_attr=[1053, 1], y=[1053, 16], batch=[1006], ptr=[36])
Data(x=[1006, 1], edge_index=[2, 1053], edge_attr=[1053, 1], y=[1053, 16])


In [34]:
# use train_test_split_edges to create neg and positive edges
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
print(data)

Data(x=[1006, 1], val_pos_edge_index=[2, 29], val_pos_edge_attr=[29, 1], test_pos_edge_index=[2, 58], test_pos_edge_attr=[58, 1], train_pos_edge_index=[2, 1000], train_pos_edge_attr=[1000, 1], train_neg_adj_mask=[1006, 1006], val_neg_edge_index=[2, 29], test_neg_edge_index=[2, 58])


In [37]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 128)
        self.conv2 = GCNConv(128, 64)

    def encode(self):
        x = self.conv1(data.x, data.train_pos_edge_index) # convolution 1
        x = x.relu()
        return self.conv2(x, data.train_pos_edge_index) # convolution 2

    def decode(self, z, pos_edge_index, neg_edge_index): # only pos and neg edges
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) # concatenate pos and neg edges
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product 
        return logits

    def decode_all(self, z): 
        prob_adj = z @ z.t() # get adj NxN
        return (prob_adj > 0).nonzero(as_tuple=False).t() # get predicted edge_list 

In [43]:
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [42]:

def get_link_labels(pos_edge_index, neg_edge_index):
    # returns a tensor:
    # [1,1,1,1,...,0,0,0,0,0,..] with the number of ones is equel to the lenght of pos_edge_index
    # and the number of zeros is equal to the length of neg_edge_index
    E = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(E, dtype=torch.float, device=device)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels


def train():
    model.train()

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index, #positive edges
        num_nodes=data.num_nodes, # number of nodes
        num_neg_samples=data.train_pos_edge_index.size(1)) # number of neg_sample equal to number of pos_edges

    optimizer.zero_grad()
    
    z = model.encode() #encode
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index) # decode
    
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test():
    model.eval()
    perfs = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']

        z = model.encode() # encode train
        link_logits = model.decode(z, pos_edge_index, neg_edge_index) # decode test or val
        link_probs = link_logits.sigmoid() # apply sigmoid
        
        link_labels = get_link_labels(pos_edge_index, neg_edge_index) # get link
        
        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu())) #compute roc_auc score
    return perfs


In [44]:

best_val_perf = test_perf = 0
for epoch in range(1, 1001):
    train_loss = train()
    val_perf, tmp_test_perf = test()
    if val_perf > best_val_perf:
        best_val_perf = val_perf
        test_perf = tmp_test_perf
    log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    if epoch % 10 == 0:
        print(log.format(epoch, train_loss, best_val_perf, test_perf))



Epoch: 010, Loss: 4.1557, Val: 0.5327, Test: 0.5593
Epoch: 020, Loss: 1.5846, Val: 0.5327, Test: 0.5593
Epoch: 030, Loss: 0.8188, Val: 0.5327, Test: 0.5593
Epoch: 040, Loss: 0.7282, Val: 0.5386, Test: 0.5664
Epoch: 050, Loss: 0.6963, Val: 0.5458, Test: 0.5670
Epoch: 060, Loss: 0.6902, Val: 0.5755, Test: 0.5771
Epoch: 070, Loss: 0.6887, Val: 0.5815, Test: 0.5870
Epoch: 080, Loss: 0.6889, Val: 0.5815, Test: 0.5870
Epoch: 090, Loss: 0.6874, Val: 0.5886, Test: 0.6101
Epoch: 100, Loss: 0.6875, Val: 0.5886, Test: 0.6101
Epoch: 110, Loss: 0.6861, Val: 0.5886, Test: 0.6101
Epoch: 120, Loss: 0.6831, Val: 0.5981, Test: 0.6271
Epoch: 130, Loss: 0.6818, Val: 0.6005, Test: 0.6304
Epoch: 140, Loss: 0.6805, Val: 0.6052, Test: 0.6378
Epoch: 150, Loss: 0.6817, Val: 0.6052, Test: 0.6378
Epoch: 160, Loss: 0.6737, Val: 0.6100, Test: 0.6476
Epoch: 170, Loss: 0.6717, Val: 0.6159, Test: 0.6473
Epoch: 180, Loss: 0.6730, Val: 0.6159, Test: 0.6473
Epoch: 190, Loss: 0.6668, Val: 0.6159, Test: 0.6473
Epoch: 200, 

In [45]:
z = model.encode()
final_edge_index = model.decode_all(z)
print(final_edge_index.shape)

torch.Size([2, 588560])


In [46]:
torch.save(model, 'models/model_ep.pth')

In [49]:
loaded_model = torch.load('models/model_ep.pth')
test_acc = test()
print(test_acc)

[0.6159334126040428, 0.656807372175981]
