In [26]:
import json
import networkx as nx
import dgl
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [27]:
json_file = 'elyzee_data/elyzee_train.json'
graph_networkx = nx.node_link_graph(json.load(open(json_file)))

graph_networkx = graph_networkx.to_directed()

In [49]:
idx = list(graph_networkx.nodes)
# print(idx)

node_0 = graph_networkx.nodes[idx[0]]
print(len(node_0['features']), node_0['affiliation'])

999 fi
7


In [50]:
# plt.figure(figsize=[15,7])
# nx.draw(graph_networkx, with_labels=True, font_weight='bold')

In [51]:
graph_dgl = dgl.from_networkx(graph_networkx, node_attrs=['features'])
print(graph_dgl)
print(graph_dgl.ndata['features'].shape)

Graph(num_nodes=5507, num_edges=19284,
      ndata_schemes={'features': Scheme(shape=(999,), dtype=torch.float32)}
      edata_schemes={})
torch.Size([5507, 999])


In [100]:
labels = [node['affiliation'] for node in graph_networkx.nodes.values()]
num_classes = len(set(labels))
print("number classes :", num_classes)

labels_dict = {label: i for i, label in enumerate(set(labels))}
print(labels_dict)

labels = [labels_dict[label] for label in labels]

labels = F.one_hot(torch.tensor(labels))
labels = labels.float()
print(labels.shape)

number classes : 7
{'fi': 0, 'indetermined': 1, 'ps': 2, 'multi_affiliations': 3, 'lr': 4, 'fn': 5, 'em': 6}
torch.Size([5507, 7])


In [101]:


# Define the message & reduce function
# NOTE: we ignore the GCN's normalization constant c_ij for this tutorial.
def gcn_message(edges):
    # The argument is a batch of edges.
    # This computes a (batch of) message called 'msg' using the source node's feature 'h'.
    return {'msg' : edges.src['h']}

def gcn_reduce(nodes):
    # The argument is a batch of nodes.
    # This computes the new 'h' features by summing received 'msg' in each node's mailbox.
    return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)}

# Define the GCNLayer module
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        # g is the graph and the inputs is the input node features
        # first set the node features
        g.ndata['h'] = inputs
        g.send_and_recv(g.edges(), gcn_message, gcn_reduce)
        # get the result node features
        h = g.ndata.pop('h')
        # perform linear transformation
        return self.linear(h)

In [102]:
# Define a 2-layer GCN model
class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(in_feats, hidden_size)
        self.gcn2 = GCNLayer(hidden_size, num_classes)

    def forward(self, g, inputs):
        h = self.gcn1(g, inputs)
        h = torch.relu(h)
        h = self.gcn2(g, h)
        return h


In [103]:
model = GCN(in_feats=5507, hidden_size=1000, num_classes=7)

In [104]:
inputs = torch.eye(5507)
labeled_nodes = torch.arange(0, 5507).reshape(-1)
print(labeled_nodes.shape)

torch.Size([5507])


In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
all_logits = []
for epoch in range(30):
    logits = model(graph_dgl, inputs)
    # we save the logits for visualization later
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)
    # we only compute loss for labeled nodes
    loss = F.cross_entropy(logp[labeled_nodes], labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))       

Epoch 0 | Loss: 4.1972
Epoch 1 | Loss: 2.7402
Epoch 2 | Loss: 2.2290
Epoch 3 | Loss: 2.0115
Epoch 4 | Loss: 1.9147
Epoch 5 | Loss: 1.8916
Epoch 6 | Loss: 1.8730
Epoch 7 | Loss: 1.8698
Epoch 8 | Loss: 1.8627
Epoch 9 | Loss: 1.8568
Epoch 10 | Loss: 1.8516
Epoch 11 | Loss: 1.8475
Epoch 12 | Loss: 1.8457
Epoch 13 | Loss: 1.8416
Epoch 14 | Loss: 1.8378
Epoch 15 | Loss: 1.8344
Epoch 16 | Loss: 1.8324
Epoch 17 | Loss: 1.8302
Epoch 18 | Loss: 1.8270
Epoch 19 | Loss: 1.8251
Epoch 20 | Loss: 1.8235
Epoch 21 | Loss: 1.8219
Epoch 22 | Loss: 1.8213
