In [2]:
import torch
import numpy as np
import networkx as nx
import community as community_louvain
from torch_geometric.data import InMemoryDataset, Data
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn.initializers import glorot_normal, glorot_uniform, normal, uniform, zeros
from networkx import karate_club_graph, to_numpy_matrix
from jax.experimental import optimizers
import time


In [3]:
def visualize(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])

    if torch.is_tensor(h):
        h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None:
            plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    else:
        nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                         node_color=color, cmap="Set2")
    plt.show()

In [4]:
class KarateClub(InMemoryDataset):

    def __init__(self, transform=None):
        super(KarateClub, self).__init__('.', transform, None, None)

        G = nx.karate_club_graph()

        x = torch.eye(G.number_of_nodes(), dtype=torch.float)
        order = sorted(list(G.nodes()))
        adjency = to_numpy_matrix(G,nodelist=order)
        
        adj = nx.to_scipy_sparse_matrix(G).tocoo()
        row = torch.from_numpy(adj.row.astype(np.int64)).to(torch.long)
        col = torch.from_numpy(adj.col.astype(np.int64)).to(torch.long)
        edge_index = torch.stack([row, col], dim=0)

        # Compute communities.
        partition = community_louvain.best_partition(G)
        y = torch.tensor([partition[i] for i in range(G.number_of_nodes())])
        senders = edge_index[0]
        receivers = edge_index[1]
        # Select a single training node for each community
        # (we just use the first one).
        train_mask = torch.zeros(y.size(0), dtype=torch.bool)
        num_nodes = G.number_of_nodes()
        for i in range(int(y.max()) + 1):
            train_mask[(y == i).nonzero(as_tuple=False)[0]] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,
                    adjency = adjency)

        self.data, self.slices = self.collate([data])

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)

In [5]:
dataset = KarateClub()
dataset = dataset[0]

In [6]:
out_dim = 12
bias = False
def init_fun(rng, input_shape):
    output_shape = input_shape[:-1] + (out_dim,)
    k1, k2 = random.split(rng)
    W_init, b_init = glorot_uniform(), zeros
    W = W_init(k1, (input_shape[-1], out_dim))
    b = b_init(k2, (out_dim,)) if bias else None
    return output_shape, (W, b)

In [7]:
key = random.PRNGKey(1)



In [8]:
dataset.x.shape

torch.Size([34, 34])

In [9]:
a,b = init_fun(key,dataset.x.shape)

In [10]:
a

torch.Size([34, 12])

In [11]:
b[0].shape

(34, 12)

In [12]:
def apply_fun(params, x, adj, **kwargs):
    W, b = params
    x = np.array(x)
    support = np.dot(x, W)
    out = np.matmul(adj, support)
    if bias:
        out += b
    return out

In [13]:
c = apply_fun(b,dataset.x,dataset.adjency)
c.shape

(34, 12)

In [14]:
def Dropout(rate):
    """
    Layer construction function for a dropout layer with given rate.
    This Dropout layer is modified from stax.experimental.Dropout, to use
    `is_training` as an argument to apply_fun, instead of defining it at
    definition time.
    Arguments:
        rate (float): Probability of keeping and element.
    """
    def init_fun(rng, input_shape):
        return input_shape, ()
    def apply_fun(params, inputs, is_training, **kwargs):
        rng = kwargs.get('rng', None)
        if rng is None:
            msg = ("Dropout layer requires apply_fun to be called with a PRNG key "
                   "argument. That is, instead of `apply_fun(params, inputs)`, call "
                   "it like `apply_fun(params, inputs, rng)` where `rng` is a "
                   "jax.random.PRNGKey value.")
            raise ValueError(msg)
        keep = random.bernoulli(rng, rate, inputs.shape)
        outs = np.where(keep, inputs / rate, 0)
        # if not training, just return inputs and discard any computation done
        out = lax.cond(is_training, outs, lambda x: x, inputs, lambda x: x)
        return out
    return init_fun, apply_fun

In [15]:
def GraphConvolution(out_dim, bias=False):
    """
    Layer constructor function for a Graph Convolution layer 
    as the one in https://arxiv.org/abs/1609.02907
    """
    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W_init, b_init = glorot_uniform(), zeros
        W = W_init(k1, (input_shape[-1], out_dim))
        b = b_init(k2, (out_dim,)) if bias else None
        return output_shape, (W, b)

    def apply_fun(params, x, adj, **kwargs):
        W, b = params
        support = np.dot(x, W)
        out = np.matmul(adj, support)
        if bias:
            out += b
        return out

    return init_fun, apply_fun

In [16]:
gc1_init, gc1_fun = GraphConvolution(4)
_, drop_fun = Dropout(1)
gc2_init, gc2_fun = GraphConvolution(4)

In [17]:
drop_fun

<function __main__.Dropout.<locals>.apply_fun(params, inputs, is_training, **kwargs)>

In [18]:
def GCN(nhid: int, nclass: int, dropout: float):
    """
    This function implements the GCN model that uses 2 Graph Convolutional layers.
    """
    gc1_init, gc1_fun = GraphConvolution(nhid)
    _, drop_fun = Dropout(dropout)
    gc2_init, gc2_fun = GraphConvolution(nclass)

    init_funs = [gc1_init, gc2_init]

    def init_fun(rng, input_shape):
        params = []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2, k3, k4 = random.split(rng, 4)
        x = drop_fun(None, x, is_training=is_training, rng=k1)
        x = gc1_fun(params[0], x, adj, rng=k2)
        x = nn.relu(x)
        x = drop_fun(None, x, is_training=is_training, rng=k3)
        x = gc2_fun(params[1], x, adj, rng=k4)
        x = nn.log_softmax(x)
        return x
    
    return init_fun, apply_fun

In [19]:
rng_key = random.PRNGKey(1)
init_fun, predict_fun = GCN(nhid=4, 
                            nclass=len(dataset.y.unique()),
                            dropout=4)
input_shape = (-1, 34, 34)
rng_key, init_key = random.split(rng_key)
_, init_params = init_fun(init_key, input_shape)

In [20]:
def loss(params, batch):
    """
    The idxes of the batch indicate which nodes are used to compute the loss.
    """
    inputs, targets, adj, is_training, rng, idx = batch
    preds = predict_fun(params, inputs, adj, is_training=is_training, rng=rng)
    ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
    l2_loss = 5e-4 * optimizers.l2_norm(params)
    return ce_loss + l2_loss

In [21]:
opt_init, opt_update, get_params = optimizers.adam(0.001)
opt_state = opt_init(init_params)


In [22]:
def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

In [25]:
batch = (features, labels, adj, True, rng_key, idx_train)
    # update parameters
opt_state = update(epoch, opt_state, batch)


TypeError: bernoulli probability `p` must have a floating dtype, got int32.

In [24]:
params = get_params(opt_state)
eval_batch = (features, labels, adj, False, rng_key, idx_val)
val_acc = accuracy(params, eval_batch)
val_loss = loss(params, eval_batch)
print(f"val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}")

NameError: name 'idx_val' is not defined