In [149]:
import random
import networkx as nx
from torch_geometric.utils.convert import from_networkx
import torch
import numpy as np
from torch.nn import Linear
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import matplotlib.pylab as plt

In [172]:
def gen_data_bfs(n,p):
    source = random.randint(0, n-1)
    y = torch.zeros(n, dtype=torch.int32)
                    
    G = nx.erdos_renyi_graph(n, p)
    nx.set_edge_attributes(G, 0.0, "pred")
    
    # BFS result 
    it = nx.bfs_predecessors(G, source=source)
    for node, pred in iter(it):
        y[node] = pred
        G[node][pred]['pred'] = 1.0
    
    
    # features
    x0 = np.linspace(0,1, n)
    x1 = np.zeros_like(x0)
    x1[source] = 1

    data = from_networkx(G, group_edge_attrs=["pred"])
    data.x = torch.tensor(np.array([x0,x1]), dtype=torch.float32).t()
    data.y = y
    return data

In [173]:
n_dataset = 100
n = 20
p = 0.15
data_list= []
for _ in range(n_dataset):
    data = gen_data_bfs(n,p)
    data_list.append(data)

In [174]:
from models import MPNN

class Reasoning(nn.Module):
    def __init__(self, emb_size, num_node, T):
        super(Reasoning, self).__init__()
        self.T = T 
        self.emb_size = emb_size
        self.num_node = num_node
        self.ENC = Linear(2, self.emb_size)
        self.DEC = Linear(2 * self.emb_size, 1)
        self.mpnn = MPNN(2 * self.emb_size, self.emb_size)
        self.h_0 = torch.zeros((self.num_node, self.emb_size))
        
    def encoder(self, x):
        return self.ENC(x)

    def decoder(self, src, dst):
        x = torch.cat([src, dst], dim=1)
        return self.DEC(x)

    def softmax(self, alpha_ij, edge_index):
        exp_alpha_ij = torch.exp(alpha_ij)
        alpha_ij_norm = torch.zeros_like(alpha_ij)
        for i, idx in enumerate(edge_index[0]):
            index_i = (edge_index[0] == idx).nonzero().squeeze()
            alpha_ij_norm[i] = torch.sum(exp_alpha_ij[index_i])
        return exp_alpha_ij / alpha_ij_norm

    def forward(self, x, edge_index):
        Z = self.encoder(x)
        input = torch.cat([self.h_0, Z], dim=1)
        for t in range(self.T):
            H = self.mpnn(input, edge_index)
            input = torch.cat([H, Z], dim=1)
        src = H[edge_index[0]]
        dst = H[edge_index[1]]
        alpha_ij = self.decoder(src, dst)
        return self.softmax(alpha_ij, edge_index) 

In [175]:
lr = 0.001
n_epochs = 30
criterion = nn.BCELoss()
model = Reasoning(128, data.num_nodes, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataloader = DataLoader(data_list, batch_size=1)

def train(model, optimizer, criterion, dataloader, n_epochs):
    loss_list = []
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        loss = 0.0
        for data in dataloader:
            yhat = model(data.x, data.edge_index)
            y = data.edge_attr
            loss += criterion(yhat, y)
        loss_list.append(loss.cpu().detach().numpy())
        loss.backward()
        optimizer.step()
    return np.array(loss_list)

#loss, yhat = train(model, optimizer, criterion, dataloader, n_epochs)

In [176]:
data = data_list[0]

In [177]:
data

Data(edge_index=[2, 68], edge_attr=[68, 1], num_nodes=20, x=[20, 2], y=[20])

In [181]:
lr = 0.001
n_epochs = 30
criterion = nn.BCELoss()
model = Reasoning(128, data.num_nodes, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataloader = DataLoader(data_list, batch_size=1)
data = data_list[0]
y = data.edge_attr
x = data.x
edge_index = data.edge_index
ytrue = data.y

for epoch in range(30):
    optimizer.zero_grad()
    loss = 0.0
    yhat = model(data.x, data.edge_index)
    loss = criterion(yhat, y)
    print(loss)
    loss.backward()
    optimizer.step()

tensor(0.9337, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.9097, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8909, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8772, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8769, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8749, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8670, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8629, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8619, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8615, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8608, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8601, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8585, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8557, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8525, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8496, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8475, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8458, grad_fn=<BinaryCrossEntropyBackward0>)
tensor(0.8440, grad_fn=<Bina

In [182]:
torch.cat([yhat,y], dim=1)

tensor([[0.2284, 1.0000],
        [0.1727, 1.0000],
        [0.2787, 0.0000],
        [0.2684, 1.0000],
        [0.0517, 1.0000],
        [0.8201, 1.0000],
        [0.1799, 0.0000],
        [0.7207, 1.0000],
        [0.2793, 0.0000],
        [0.2393, 0.0000],
        [0.5784, 1.0000],
        [0.1823, 1.0000],
        [0.2153, 0.0000],
        [0.7847, 1.0000],
        [1.0000, 1.0000],
        [0.3580, 0.0000],
        [0.3897, 1.0000],
        [0.0785, 1.0000],
        [0.1214, 1.0000],
        [0.0524, 1.0000],
        [0.2172, 1.0000],
        [0.2455, 1.0000],
        [0.1774, 1.0000],
        [0.0709, 1.0000],
        [0.1731, 1.0000],
        [0.1159, 1.0000],
        [1.0000, 1.0000],
        [0.1733, 1.0000],
        [0.4036, 1.0000],
        [0.1210, 0.0000],
        [0.0778, 0.0000],
        [0.1254, 1.0000],
        [0.0990, 1.0000],
        [0.2952, 0.0000],
        [0.3603, 1.0000],
        [0.0904, 0.0000],
        [0.2541, 0.0000],
        [0.4463, 1.0000],
        [0.3

In [256]:
def pred(alpha_ij, edge_index, n_node):
    y = torch.zeros(n_node)
    for node in range(n_node):
        index_i = (edge_index[0] == node).nonzero().squeeze()
        if index_i.nelement() == 1:
            y[node] = edge_index[1,index_i]
        else:
            #print(index_i, edge_index[0, index_i])
            elem = torch.argmax(alpha_ij[index_i])
            print("node src", node, "dest", edge_index[1,index_i], alpha_ij[index_i], edge_index[1, index_i[elem]])
            #print("elem", elem, index_i)
            #print("index", index_i[elem])
            y[node] = edge_index[1, index_i[elem]]
            #print("res", node, y[node])
    return y

In [257]:
pred(yhat, edge_index, data.num_nodes)

node src 0 dest tensor([ 1,  2,  6,  7, 12]) tensor([[0.2284],
        [0.1727],
        [0.2787],
        [0.2684],
        [0.0517]], grad_fn=<IndexBackward0>) tensor(6)
node src 1 dest tensor([ 0, 10]) tensor([[0.8201],
        [0.1799]], grad_fn=<IndexBackward0>) tensor(0)
node src 2 dest tensor([0, 3]) tensor([[0.7207],
        [0.2793]], grad_fn=<IndexBackward0>) tensor(0)
node src 3 dest tensor([ 2,  8, 18]) tensor([[0.2393],
        [0.5784],
        [0.1823]], grad_fn=<IndexBackward0>) tensor(8)
node src 4 dest tensor([15, 18]) tensor([[0.2153],
        [0.7847]], grad_fn=<IndexBackward0>) tensor(18)
node src 6 dest tensor([ 0,  7, 10, 13, 15]) tensor([[0.3580],
        [0.3897],
        [0.0785],
        [0.1214],
        [0.0524]], grad_fn=<IndexBackward0>) tensor(7)
node src 7 dest tensor([ 0,  6,  9, 11, 17, 18]) tensor([[0.2172],
        [0.2455],
        [0.1774],
        [0.0709],
        [0.1731],
        [0.1159]], grad_fn=<IndexBackward0>) tensor(6)
node src 9 dest t

tensor([ 6.,  0.,  0.,  8., 18.,  9.,  7.,  6.,  3.,  7.,  6.,  7.,  0.,  6.,
        17.,  6.,  9.,  7.,  7.,  9.])

In [258]:
data.y

tensor([ 7,  0,  0, 18, 18,  9,  7,  0,  3,  7,  6,  7,  0,  6, 17,  6,  9,  7,
         7,  9], dtype=torch.int32)

In [222]:
ytrue

tensor([ 7,  0,  0, 18, 18,  9,  7,  0,  3,  7,  6,  7,  0,  6, 17,  6,  9,  7,
         7,  9], dtype=torch.int32)