In [1]:
import os.path as osp

import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv
from torch_geometric.utils import train_test_split_edges

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [38]:
# load the Cora dataset
dataset = 'Cora'
path = osp.join('.', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]
print(dataset.data)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


In [39]:
# use train_test_split_edges to create neg and positive edges
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data)
print(data)


Data(x=[2708, 1433], val_pos_edge_index=[2, 263], test_pos_edge_index=[2, 527], train_pos_edge_index=[2, 8976], train_neg_adj_mask=[2708, 2708], val_neg_edge_index=[2, 263], test_neg_edge_index=[2, 527])




In [169]:
data.train_pos_edge_index[:, :20]


tensor([[   0,    0,    1,    1,    1,    2,    2,    2,    2,    2,    3,    4,
            4,    4,    4,    4,    5,    5,    6,    6],
        [ 633, 2582,    2,  652,  654,    1,  332, 1454, 1666, 1986, 2544, 1016,
         1256, 1761, 2175, 2176, 1629, 2546, 1042, 1416]])

In [182]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 128)
        self.conv2 = GCNConv(128, 64)

    def encode(self):
        x = self.conv1(data.x, data.train_pos_edge_index) # convolution 1
        x = x.relu()
        return self.conv2(x, data.train_pos_edge_index) # convolution 2

    def decode(self, z, pos_edge_index, neg_edge_index, test=False): # only pos and neg edges
        # if not test:
        #     print('pos_edge_index', pos_edge_index.shape)
        #     print('neg_edge_index', neg_edge_index.shape)
            # print('pos_edge_index', pos_edge_index)
            # print('neg_edge_index', neg_edge_index)

        edge_index = torch.cat([neg_edge_index,pos_edge_index], dim=-1) # concatenate pos and neg edges
        # if not test:
        #     print('edge_index 0 shape:', edge_index[0].shape)
        #     # print("Edge index 0: ", edge_index[0])
        #     print("Edge index 1 shape: ", edge_index[1].shape)
        #     # print("Edge index 1: ", edge_index[1])
        #     print("z shape: ", z.shape)

        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)  # dot product 
        # if not test:
        #     print('logits shape:', logits.shape)
        #     print('logits', logits)
        #     4/0
        return logits

    def decode_all(self, z): 
        prob_adj = z @ z.t() # get adj NxN
        return (prob_adj > 0).nonzero(as_tuple=False).t() # get predicted edge_list 

In [186]:
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [187]:
def get_link_labels(neg_edge_index, pos_edge_index):
    # returns a tensor:
    # [1,1,1,1,...,0,0,0,0,0,..] with the number of ones is equel to the lenght of pos_edge_index
    # and the number of zeros is equal to the length of neg_edge_index
    E =  + neg_edge_index.size(1) + pos_edge_index.size(1)
    link_labels = torch.zeros(E, dtype=torch.float, device=device)
    link_labels[:neg_edge_index.size(1)] = 0.
    link_labels[neg_edge_index.size(1):] = 1.
    return link_labels


def train():
    model.train()

    # for el in batch_size:

    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index, #positive edges
        num_nodes=data.num_nodes, # number of nodes
        num_neg_samples=data.train_pos_edge_index.size(1)) # number of neg_sample equal to number of pos_edges

    optimizer.zero_grad()
    
    z = model.encode() #encode
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index) # decode
    
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss


@torch.no_grad()
def test():
    model.eval()
    perfs = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']

        z = model.encode() # encode train
        link_logits = model.decode(z, pos_edge_index, neg_edge_index, test=True) # decode test or val
        link_probs = link_logits.sigmoid() # apply sigmoid
        
        link_labels = get_link_labels(neg_edge_index, pos_edge_index) # get link
        
        perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu())) #compute roc_auc score
    return perfs

In [188]:
best_val_perf = test_perf = 0
for epoch in range(1, 101):
    train_loss = train()
    val_perf, tmp_test_perf = test()
    if val_perf > best_val_perf:
        best_val_perf = val_perf
        test_perf = tmp_test_perf
    log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    if epoch % 10 == 0:
        print(log.format(epoch, train_loss, best_val_perf, test_perf))

Epoch: 010, Loss: 0.6858, Val: 0.7038, Test: 0.7385
Epoch: 020, Loss: 0.6301, Val: 0.7038, Test: 0.7385
Epoch: 030, Loss: 0.5614, Val: 0.7305, Test: 0.7376
Epoch: 040, Loss: 0.5160, Val: 0.8032, Test: 0.8014
Epoch: 050, Loss: 0.4817, Val: 0.8511, Test: 0.8561
Epoch: 060, Loss: 0.4665, Val: 0.8694, Test: 0.8751
Epoch: 070, Loss: 0.4630, Val: 0.8723, Test: 0.8767
Epoch: 080, Loss: 0.4622, Val: 0.8723, Test: 0.8767
Epoch: 090, Loss: 0.4470, Val: 0.8753, Test: 0.8816
Epoch: 100, Loss: 0.4435, Val: 0.8763, Test: 0.8816


In [175]:
t1 = 10*torch.rand((2,5))
t2 = 10*torch.rand((2,5))
t1 = t1.round()
t2 = -t2.round()
a = torch.cat([t1,t2], dim=-1)
a.shape

t1[1]

tensor([ 9.,  8.,  4.,  5., 10.])

In [143]:
nodes = torch.rand(2704, 64)
n1 = nodes[[12,13,14]]
n2 = nodes[[15,16,17]]
ressss = n1*n2
ressss.sum(-1)

tensor([17.6733, 15.0368, 14.3092])