In [1]:
import jraph
import torch
import numpy as np
import networkx as nx
import community as community_louvain
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import to_networkx
from networkx import karate_club_graph, to_numpy_matrix
from absl import app
import jax
import jax.numpy as jnp
import haiku as hk
from absl import app
import logging
from jax.interpreters.xla import DeviceArray
from jax.experimental import optimizers

In [2]:
class KarateClub(InMemoryDataset):

    def __init__(self, transform=None):
        super(KarateClub, self).__init__('.', transform, None, None)

        G = nx.karate_club_graph()

        x = torch.eye(G.number_of_nodes(), dtype=torch.float)
        order = sorted(list(G.nodes()))
        adjency = to_numpy_matrix(G,nodelist=order)
        
        adj = nx.to_scipy_sparse_matrix(G).tocoo()
        row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long)
        col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long)
        edge_index = torch.stack([row, col], dim=0)

        # Compute communities.
        partition = community_louvain.best_partition(G)
        y = torch.tensor([partition[i] for i in range(G.number_of_nodes())])
        senders = edge_index[0]
        receivers = edge_index[1]
        # Select a single training node for each community
        # (we just use the first one).
        train_mask = torch.zeros(y.size(0), dtype=torch.bool)
        num_nodes = G.number_of_nodes()
        for i in range(int(y.max()) + 1):
            train_mask[(y == i).nonzero(as_tuple=False)[0]] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                    adjency = adjency)

        self.data, self.slices = self.collate([data])

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
dataset = KarateClub()
dataset = dataset[0]
dataset

Data(adjency=[[0. 1. 1. ... 1. 0. 0.]
 [1. 0. 1. ... 0. 0. 0.]
 [1. 1. 0. ... 0. 1. 0.]
 ...
 [1. 0. 0. ... 0. 1. 1.]
 [0. 0. 1. ... 1. 0. 1.]
 [0. 0. 0. ... 1. 1. 0.]], edge_index=[2, 156], train_mask=[34], x=[34, 34], y=[34])

In [3]:
def get_zacharys_karate_club() -> jraph.GraphsTuple:
    
    social_graph = dataset.edge_index.t()
    n_club_members = dataset.num_nodes
    return jraph.GraphsTuple(
    n_node=jnp.asarray([n_club_members]),
    n_edge=jnp.asarray([len(social_graph)]),
      # One-hot encoding for nodes.
    nodes=jnp.eye(n_club_members),
      # No edge features.
    edges=None,
    globals=None,
    senders=jnp.asarray([edge[0] for edge in social_graph]),
    receivers=jnp.asarray([edge[1] for edge in social_graph]))

In [4]:
get_zacharys_karate_club()



GraphsTuple(nodes=DeviceArray([[1., 0., 0., ..., 0., 0., 0.],
             [0., 1., 0., ..., 0., 0., 0.],
             [0., 0., 1., ..., 0., 0., 0.],
             ...,
             [0., 0., 0., ..., 1., 0., 0.],
             [0., 0., 0., ..., 0., 1., 0.],
             [0., 0., 0., ..., 0., 0., 1.]], dtype=float32), edges=None, receivers=DeviceArray([ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21,
             31,  0,  2,  3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,
              9, 13, 27, 28, 32,  0,  1,  2,  7, 12, 13,  0,  6, 10,  0,
              6, 10, 16,  0,  4,  5, 16,  0,  1,  2,  3,  0,  2, 30, 32,
             33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,  3, 33, 32,
             33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
              1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29,
             33,  2, 23, 24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32,
             33,  0, 24, 25, 28, 32, 33,  2,  8, 14, 15, 18, 20, 22, 23,
    

In [5]:
def get_ground_truth_assignments_for_zacharys_karate_club() -> jnp.ndarray:
    
    return jnp.array([np.array(dataset.y)])

In [6]:
get_ground_truth_assignments_for_zacharys_karate_club()

DeviceArray([[0, 0, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0, 0, 2, 2, 1, 0, 2, 0,
              2, 0, 2, 2, 3, 3, 2, 2, 3, 2, 2, 3, 2, 2]], dtype=int32)

In [7]:
def network_definition(graph: jraph.GraphsTuple) -> jraph.ArrayTree:
    

    gn = jraph.GraphConvolution(
      update_node_fn=lambda n: jax.nn.relu(hk.Linear(5)(n)),
      add_self_edges=True)
    graph = gn(graph)

    gn = jraph.GraphConvolution(
      update_node_fn=hk.Linear(4))
    graph = gn(graph)
    return graph.nodes

In [8]:
network = hk.without_apply_rng(hk.transform(network_definition))
zacharys_karate_club = get_zacharys_karate_club()
labels = get_ground_truth_assignments_for_zacharys_karate_club()
params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

In [9]:
def prediction_loss(params):
    
    decoded_nodes = network.apply(params, zacharys_karate_club)
    # We interpret the decoded nodes as a pair of logits for each node.
    log_prob = jax.nn.log_softmax(decoded_nodes)
    # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
    # and John A (Node 33).
    return -(log_prob[0, 0] + log_prob[33, 1])

In [10]:
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(params)
def step(step, opt_state):
    
    value, grads = jax.value_and_grad(prediction_loss)(get_params(opt_state))
    opt_state = opt_update(step, grads, opt_state)
    return value, opt_state

In [11]:
def accuracy(params):
    decoded_nodes = network.apply(params, zacharys_karate_club)
    return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels)

In [12]:
for shag in range(500):
    
    
    val, opt_state = step(shag, opt_state)

In [13]:
val


DeviceArray(0.16148417, dtype=float32)