In [17]:
import torch
from tqdm import tqdm

In [18]:
# Name of the compute device.
# NOTE(review): none of the visible cells pass this to a tensor or to the
# model, so it currently has no effect — confirm whether it is still needed.
device = 'cpu'

In [19]:
# Load graph data

# Load the adjacency matrix from file.
# weights_only=True restricts unpickling to tensors and other plain data, so a
# tampered file cannot execute arbitrary code on load.
# NOTE(review): the stray `A = torch.load('data.pt')` line in this cell's
# recorded output looks like the tail of PyTorch's warning about the implicit
# weights_only default; passing the argument explicitly should remove it.
A = torch.load('data.pt', weights_only=True)

# Only the strict upper triangle of A is read below, i.e. the graph is treated
# as undirected. That indexing is only meaningful for a square matrix.
if A.ndim != 2 or A.shape[0] != A.shape[1]:
    raise ValueError(f'Expected a square adjacency matrix, got shape {tuple(A.shape)}')

# Get number of nodes
n_nodes = A.shape[0]

# Number of un-ordered node pairs (possible links)
n_pairs = n_nodes*(n_nodes-1)//2

# Get indices of all un-ordered node pairs excluding self-links (shape: 2 x n_pairs)
idx_all_pairs = torch.triu_indices(n_nodes,n_nodes,1)

# Collect all links/non-links in a list (shape: n_pairs)
target = A[idx_all_pairs[0],idx_all_pairs[1]]

print(A)
print(n_pairs)
print(idx_all_pairs, idx_all_pairs.shape)
print(target, target.shape)

tensor([[0., 0., 0.,  ..., 1., 0., 1.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [1., 0., 0.,  ..., 0., 1., 0.]])
19900
tensor([[  0,   0,   0,  ..., 197, 197, 198],
        [  1,   2,   3,  ..., 198, 199, 199]]) torch.Size([2, 19900])
tensor([0., 0., 1.,  ..., 0., 0., 1.]) torch.Size([19900])


  A = torch.load('data.pt')


In [20]:
# Random train/validation split over all node pairs: each pair goes to the
# training set with probability 0.8, otherwise to the validation set.
# A dedicated, seeded generator makes the split identical on every run, so
# re-executing the cell cannot move pairs between the two sets and the
# reported losses stay reproducible.
split_generator = torch.Generator().manual_seed(0)
idx_train = torch.rand(n_pairs, generator=split_generator) < 0.8
idx_validation = torch.logical_not(idx_train)

In [None]:
# Shallow node embedding
class Shallow(torch.nn.Module):
    '''Shallow embedding model for link prediction.

    Every node owns one learnable embedding vector; a single learnable bias
    is shared by all node pairs.

    Args:
        n_nodes (int): Number of nodes in the graph
        embedding_dim (int): Dimension of the embedding
    '''
    def __init__(self, n_nodes, embedding_dim):
        super().__init__()
        # Embedding table: one row of learnable coordinates per node
        self.embedding = torch.nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)
        # Shared offset added to every pair score, initialised to zero
        self.bias = torch.nn.Parameter(torch.zeros(1))

    def forward(self, rx, tx):
        '''Returns the link probability of each node pair (rx[k], tx[k]).

        rx and tx are equally long sequences of node indices. For pair k the
        output is sigmoid(<z_rx[k], z_tx[k]> + bias), where z_u is row u of
        the embedding table and <., .> is the dot product.
        '''
        # Look up the embedding rows of both endpoints of every pair
        z_rx = self.embedding.weight[rx]
        z_tx = self.embedding.weight[tx]

        # Dot product per pair, shifted by the shared bias
        score = torch.sum(z_rx * z_tx, dim=1) + self.bias

        # Squash the score into a probability
        return torch.sigmoid(score)

# Embedding dimension
embedding_dim = 2

# Seed the global RNG so the randomly initialised embedding table is the same
# on every run; without it the fitted model and its losses change between runs
torch.manual_seed(0)

# Instantiate the model
model = Shallow(n_nodes, embedding_dim)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Loss function: binary cross entropy on probabilities (the model's forward
# already applies the sigmoid, so it must not be applied again here)
cross_entropy = torch.nn.BCELoss()

In [22]:
# Fit the model
# Number of gradient steps
max_step = 1000

# The training pairs and their labels do not change between steps, so select
# them once instead of repeating the boolean-mask indexing on every iteration
rx_train = idx_all_pairs[0, idx_train]
tx_train = idx_all_pairs[1, idx_train]
target_train = target[idx_train]

# Optimization loop (full-batch: every step uses all training pairs)
for i in (progress_bar := tqdm(range(max_step))):
    # Compute probability of each training link
    link_probability = model(rx_train, tx_train)

    # Cross entropy loss
    loss = cross_entropy(link_probability, target_train)

    # Gradient step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Display loss on progress bar (value computed before this step's update)
    progress_bar.set_description(f'Loss = {loss.item():.3f}')

Loss = 0.413: 100%|██████████| 1000/1000 [00:02<00:00, 493.91it/s]


In [23]:
# Compute validation error
# Evaluation only: no_grad avoids recording an autograd graph that is never
# back-propagated through
with torch.no_grad():
    link_probability = model(idx_all_pairs[0, idx_validation], idx_all_pairs[1, idx_validation])
    loss = cross_entropy(link_probability, target[idx_validation])
print(f'Validation loss = {loss.item():.3f}')

Validation loss = 0.450


In [24]:
# Save final estimated link probabilities
# Inference only: under no_grad the result carries no autograd history, so the
# file holds plain probabilities rather than a tensor that requires grad.
# One value is stored per un-ordered node pair, in the order of idx_all_pairs.
with torch.no_grad():
    link_probability = model(idx_all_pairs[0], idx_all_pairs[1])
torch.save(link_probability, 'link_probability.pt')