In [None]:
import argparse
import networkx as nx
import numpy as np
import scipy as sp
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import my

In [None]:
# Experiment configuration: every tunable value lives here.
args = argparse.Namespace(
    depth=10,                      # number of GNNModule layers stacked in the GNN
    dense=True,                    # densify adjacency / incidence tensors
    graph='soc-Epinions1-reduced', # name handed to my.read_edgelist
    n_features=8,                  # hidden width of every GNNModule layer
    n_machines=10,                 # number of partitions = GNN output classes
    radius=3,                      # number of adjacency operators (A, A^2, A^4, ...)
    seed=0,                        # seeds torch and numpy for reproducible init / sampling
)
# Parameter init (th.randn) and the objective's sampling (th.multinomial) are random.
th.manual_seed(args.seed)
np.random.seed(args.seed)

In [None]:
class Objective:
    """Stochastic edge-partitioning objective (work in progress).

    Holds the vertex-edge incidence matrix ``s`` of ``g`` (one row per node,
    one column per edge). Each call samples one machine per row of ``x`` and
    returns an element-wise surrogate loss.

    Args:
        g: networkx graph whose edges are being partitioned.
        n_machines: number of partitions; width of the one-hot assignment.
        dense: if true, ``s`` is stored as a dense tensor.
    """

    def __init__(self, g, n_machines, dense):
        # Map node labels to 0..|V|-1 in g.nodes() order, which is the row order
        # of nx.adjacency_matrix(g). Indexing rows by raw labels fails as soon as
        # the labels are not exactly 0..|V|-1 (e.g. after reducing a graph).
        index = {v: i for i, v in enumerate(g.nodes())}
        # Canonical (low, high) index pairs, sorted: the row-major order of the
        # upper-triangular nonzeros of the adjacency matrix.
        pairs = ((index[u], index[v]) for u, v in g.edges())
        edges = sorted((min(p), max(p)) for p in pairs)
        s = sp.sparse.lil_matrix((g.number_of_nodes(), len(edges)))
        for i, (j, k) in enumerate(edges):
            s[j, i] = 1
            s[k, i] = 1
        # lil_matrix defaults to float64; cast to float32 to match the model output.
        self.s = my.sparse_sp2th(s).float()
        if dense:
            self.s = self.s.to_dense()
        self.n_machines = n_machines

    def __call__(self, x):
        """Sample an assignment from ``x`` and return the element-wise loss.

        Args:
            x: non-negative weights for th.multinomial, one sample drawn per row.
               Presumably (|E|, n_machines), rows ordered like the columns of
               ``self.s`` -- TODO confirm.

        Returns:
            ``-(r + b) * log(x)``, not reduced to a scalar.
        """
        y = th.multinomial(x, num_samples=1)
        # Match the incidence matrix dtype so th.mm does not mix float widths.
        y = my.onehot(y, self.n_machines).to(self.s.dtype)
        y = th.mm(self.s, y)  # presumably (|V|, n_machines) incidence counts
        r = th.sum(y)
        # NOTE(review): summing over dim 1 yields each vertex's incidence count,
        # which does not depend on the sampled assignment -- confirm whether a
        # per-machine quantity (dim 0) was intended.
        p = (th.sum(y, 1) + 1) / r
        b = -p * th.log(p)
        # NOTE(review): b has one entry per row of s while log(x) has one row per
        # sample, so this broadcast only works if the shapes happen to align; a
        # REINFORCE-style loss would also normally use only the log-probability of
        # the sampled entries. Confirm the intended reward.
        objective = -(r + b) * th.log(x)
        return objective

In [None]:
class GNNModule(nn.Module):
    """One graph layer with a nonlinear (alpha) and a linear (beta) branch.

    Each branch sums linear maps of x, D x (x scaled by the row sums of
    adj[0]), U x (x averaged over all nodes, shared by every node) and A_j x
    for every operator A_j in ``adj``. The alpha branch goes through
    ``nonlinear`` before its batch norm; the layer returns the sum of both
    batch-normalized branches.

    Args:
        in_features: width of the input node features.
        out_features: width of the output node features.
        adj: list of (|V|, |V|) operator tensors; adj[0] also defines D.
        nonlinear: activation applied to the alpha branch.
    """

    def __init__(self, in_features, out_features, adj, nonlinear):
        super().__init__()
        self.adj = adj
        self.deg = th.sum(adj[0], 1, keepdim=True)  # (|V|, 1) row sums of adj[0]
        new_linear = lambda: nn.Parameter(th.randn(in_features, out_features))
        # *1 acts on x, *2 on D x, *3 on U x, *4[j] on adj[j] @ x.
        self.alpha1, self.alpha2, self.alpha3 = new_linear(), new_linear(), new_linear()
        self.alpha4 = nn.ParameterList([new_linear() for _ in adj])
        self.beta1, self.beta2, self.beta3 = new_linear(), new_linear(), new_linear()
        self.beta4 = nn.ParameterList([new_linear() for _ in adj])
        self.bn_alpha, self.bn_beta = nn.BatchNorm1d(out_features), nn.BatchNorm1d(out_features)
        self.nonlinear = nonlinear

    def _branch(self, x, deg_x, mean_x, adj_x, w_id, w_deg, w_mean, w_adj):
        """Sum of the identity, degree, average and adjacency terms for one branch."""
        # mean_x @ w_mean is (1, out) and broadcasts to every node.
        return (th.mm(x, w_id) + th.mm(deg_x, w_deg) + th.mm(mean_x, w_mean)
                + sum(th.mm(ax, w) for w, ax in zip(w_adj, adj_x)))

    def forward(self, x):
        deg_x = self.deg * x
        # Average operator U: mean over nodes (dim 0), not over features.
        mean_x = th.mean(x, 0, keepdim=True)
        adj_x = [th.mm(a, x) for a in self.adj]
        alpha = self._branch(x, deg_x, mean_x, adj_x,
                             self.alpha1, self.alpha2, self.alpha3, self.alpha4)
        alpha = self.bn_alpha(self.nonlinear(alpha))
        beta = self._branch(x, deg_x, mean_x, adj_x,
                            self.beta1, self.beta2, self.beta3, self.beta4)
        beta = self.bn_beta(beta)
        return alpha + beta

class EdgeLinear(nn.Module):
    """Per-edge read-out: edge (i, j) gets (z_i + z_j) * adj[i, j].

    ``out_features`` independent ``nn.Linear(in_features, 1)`` maps produce
    node scores z (one column each). Each edge's score is the sum of its two
    endpoint scores, scaled by its adjacency value. This equals entry (i, j)
    of ``z * adj + z.t() * adj``, computed without |V| x |V| intermediates.
    Output rows follow the row-major order of the upper-triangular (i <= j)
    nonzeros of ``adj``.

    Args:
        in_features: width of the node features given to forward.
        out_features: number of output columns (one nn.Linear each).
        adj: (|V|, |V|) adjacency tensor, dense or sparse COO. Only the upper
            triangle is read; presumably symmetric (undirected graph) -- TODO confirm.
    """

    def __init__(self, in_features, out_features, adj):
        super().__init__()
        self.linear_list = nn.ModuleList([nn.Linear(in_features, 1) for _ in range(out_features)])
        self.adj = adj
        if adj.is_sparse:
            adj = adj.coalesce()  # coalesce sorts indices lexicographically
            index, values = adj.indices(), adj.values()
        else:
            index = adj.nonzero().t()  # (2, nnz), row-major order
            values = adj[index[0], index[1]]
        upper = index[0] <= index[1]  # keep each undirected edge once
        self.rows, self.cols = index[0][upper], index[1][upper]
        self.edge_weight = values[upper].unsqueeze(1)  # (|E|, 1)

    def forward(self, x):
        # Node scores, one column per output feature: (|V|, out_features).
        z = th.cat([linear(x) for linear in self.linear_list], 1)
        # (|E|, out_features): symmetric endpoint sum scaled by the edge weight.
        return (z[self.rows] + z[self.cols]) * self.edge_weight

class GNN(nn.Module):
    """GNNModule stack followed by an EdgeLinear read-out and a row softmax.

    Args:
        features: layer widths; each consecutive pair (m, n) builds one
            GNNModule(m, n, ...), and features[-1] feeds the read-out.
        n_classes: number of output columns of the read-out.
        adj: scipy sparse adjacency, converted with my.sparse_sp2th and cast
            to float32.
        radius: number of adjacency operators handed to every layer; the
            j-th operator is adj ** (2 ** j) via repeated squaring.
        nonlinear: activation shared by all GNNModule layers.
        dense: if true, the adjacency tensor is densified first.
    """

    def __init__(self, features, n_classes, adj, radius, nonlinear, dense):
        super().__init__()
        adj = my.sparse_sp2th(adj).float()
        if dense:
            adj = adj.to_dense()
        # Repeated squaring: A, A^2, A^4, ...
        # NOTE(review): these are raw walk counts, not binarized min(1, A^k);
        # clamp here if a 0/1 operator is intended -- confirm.
        powers = [adj]
        for _ in range(radius - 1):
            powers.append(th.mm(powers[-1], powers[-1]))
        # TODO nn.Sequential
        layers = [GNNModule(n_in, n_out, powers, nonlinear)
                  for n_in, n_out in zip(features[:-1], features[1:])]
        self.module_list = nn.ModuleList(layers)
        self.linear = EdgeLinear(features[-1], n_classes, adj)

    def forward(self, x):
        h = x
        for layer in self.module_list:
            h = layer(h)
        scores = self.linear(h)
        return F.softmax(scores, 1)

In [None]:
g = my.read_edgelist(args.graph)
# nx.adj_matrix was removed in networkx 3.0; adjacency_matrix exists in old and new releases.
adj = nx.adjacency_matrix(g)
# The objective's partition count must match the GNN's output width (was hard-coded to 10).
objective = Objective(g, args.n_machines, args.dense)
gnn = GNN((1,) + (args.n_features,) * args.depth, args.n_machines, adj, args.radius, F.relu, args.dense)

In [None]:
# Input features: one column with the row sums of adj (node degrees for an
# unweighted graph). adj.sum(1) is an (|V|, 1) np.matrix for scipy sparse
# matrices but a 1-D array for scipy sparse arrays (networkx >= 3), so force
# the (|V|, 1) shape the first GNNModule (in_features=1) expects.
x = th.from_numpy(np.asarray(adj.sum(1)).reshape(-1, 1)).float()
# backward() needs a scalar; the objective is element-wise, so reduce it first.
loss = objective(gnn(x)).mean()
loss.backward()