In [1]:
%matplotlib inline
import networkx as nx
from networkx.drawing.nx_agraph import write_dot, graphviz_layout
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from coolname import generate_slug as name

In [None]:
class graph_neural_network():
    """A neural network stored as a directed acyclic graph.

    Nodes of ``self.G`` are neurons and edges are connections. The methods
    read and write these attributes:

    * node ``'bias'``      -- bias of the neuron (must be set by the caller
      for every node that has incoming edges),
    * edge ``'weight'``    -- weight of the connection,
    * node ``'layer'``, ``'idx'``, ``'flat_idx'`` -- bookkeeping written by
      ``get_layers``.

    ``self.outputs`` is the collection of node names treated as network
    outputs.
    """

    def __init__(self,n_in,n_out):
        """Create an empty network.

        Args:
            n_in: requested number of input neurons.
            n_out: requested number of output neurons.

        The counts are only recorded on the instance; no nodes are created
        here. The caller is expected to populate ``self.G`` and
        ``self.outputs``.
        """
        # NOTE(review): the original assignment to ``self.outputs`` was left
        # unfinished; an empty list is the neutral starting value -- confirm
        # whether nodes should be pre-created from n_in / n_out.
        self.n_in = n_in
        self.n_out = n_out
        self.G = nx.DiGraph()
        self.outputs = []

    def get_layers(self):
        """Group the nodes into layers and index them.

        A node's layer is one more than the largest layer among its
        predecessors (0 for nodes without incoming edges). As a side effect
        every node gets three attributes: ``'layer'``, ``'idx'`` (position
        inside its layer) and ``'flat_idx'`` (position in the concatenation
        of all layers, in layer order).

        Returns:
            list of lists of node names, one inner list per layer.
        """
        G = self.G
        max_layer = 0
        # Topological order guarantees predecessors are labelled first.
        for n in nx.topological_sort(G):
            parent_layers = [G.nodes[src]['layer'] for src, _ in G.in_edges(n)]
            G.nodes[n]['layer'] = max(parent_layers + [-1]) + 1
            max_layer = max(max_layer, G.nodes[n]['layer'])

        layers = [[] for _ in range(max_layer + 1)]
        for n in G:
            layers[G.nodes[n]['layer']].append(n)

        flat_idx = 0
        for layer in layers:
            for j, n in enumerate(layer):
                G.nodes[n]['idx'] = j
                G.nodes[n]['flat_idx'] = flat_idx
                flat_idx += 1
        return layers

    def get_output_idxes(self):
        """Return the ``'flat_idx'`` of every node listed in ``self.outputs``.

        The result follows the iteration order of the graph, not the order
        of ``self.outputs``. ``get_layers`` must have run first so that
        ``'flat_idx'`` is present on the nodes.
        """
        G = self.G
        outputs = self.outputs
        return [G.nodes[n]['flat_idx'] for n in G if n in outputs]

    def graph_to_weights(self):
        """Export the graph as dense per-layer matrices.

        For layer ``i + 1`` the matrices have one row per node of that layer
        and one column per node of layers ``0..i`` (addressed by
        ``'flat_idx'``), so connections may skip layers. A final layer is
        appended that copies the output neurons out of the full activation
        vector.

        Returns:
            (weights, biases, mask): three lists of numpy arrays of equal
            length. ``mask[k]`` has a 1 wherever ``weights[k]`` corresponds
            to an existing edge and 0 elsewhere.
        """
        G = self.G

        layers = self.get_layers()
        mask = []
        weights = []
        biases = []

        n_nodes = 0
        for i in range(len(layers) - 1):
            n_nodes += len(layers[i])
            targets = layers[i + 1]
            layer_mask = np.zeros((len(targets), n_nodes))
            layer_weights = np.zeros((len(targets), n_nodes))
            layer_biases = np.zeros(len(targets))

            for j, node1 in enumerate(targets):
                layer_biases[j] = G.nodes[node1]['bias']
                v = G.nodes[node1]['idx']
                for node0, _ in G.in_edges(node1):
                    u = G.nodes[node0]['flat_idx']
                    layer_mask[v, u] = 1
                    layer_weights[v, u] = G[node0][node1]['weight']

            mask.append(layer_mask)
            weights.append(layer_weights)
            biases.append(layer_biases)

        n_nodes += len(layers[-1])
        out_idxes = self.get_output_idxes()
        lastlayer = np.zeros((len(out_idxes), n_nodes))
        for v, u in enumerate(out_idxes):  # identity layer mapping to output neurons
            lastlayer[v, u] = 1

        # Separate copies so that editing the weights never alters the mask.
        mask.append(lastlayer)
        weights.append(lastlayer.copy())
        biases.append(np.zeros(len(out_idxes)))

        return weights, biases, mask

    def weights_to_graph(self,weights,biases):
        """Write per-layer matrices back onto the graph.

        Inverse of ``graph_to_weights``: for every existing edge the
        ``'weight'`` attribute is read from ``weights`` and every non-input
        node's ``'bias'`` from ``biases``. Entries that do not correspond to
        an edge, and the trailing output-selection layer, are ignored.

        Args:
            weights: list of 2-D arrays laid out as by ``graph_to_weights``.
            biases: list of 1-D arrays laid out as by ``graph_to_weights``.
        """
        G = self.G
        layers = self.get_layers()
        for i in range(len(layers) - 1):
            for j, node1 in enumerate(layers[i + 1]):
                G.nodes[node1]['bias'] = biases[i][j]
                v = G.nodes[node1]['idx']
                for node0, _ in G.in_edges(node1):
                    u = G.nodes[node0]['flat_idx']
                    G[node0][node1]['weight'] = weights[i][v, u]
                    
def zero_grad(self, grad_input, grad_output):
    """Module backward hook that masks one of the incoming gradients.

    The third entry of ``grad_input`` is multiplied in place by the
    transpose of ``self.mask``; the other entries are passed through.

    Args:
        self: the module the hook is attached to; must carry a 2-D ``mask``.
        grad_input: sequence of gradients with at least three entries.
        grad_output: unused.

    Returns:
        tuple with the same entries as ``grad_input`` (entry 2 masked).
    """
    # NOTE(review): assumes grad_input[2] is the gradient of the transposed
    # weight matrix -- this position depends on the PyTorch version; confirm.
    grads = list(grad_input)
    masked = grads[2]
    masked *= self.mask.t()
    grads[2] = masked
    return tuple(grads)

class SparseNet(nn.Module):
    """Stack of linear layers whose connectivity is restricted by masks.

    Each layer receives the concatenation of the network input and the
    outputs of all previous layers. During training the weight gradient of
    every layer is multiplied by that layer's mask, so entries that are 0 in
    the mask receive no gradient. The last layer is frozen entirely.
    """

    def __init__(self, wieghts, biases, masks):
        """Build the layers from dense numpy matrices.

        Args:
            wieghts: sequence of 2-D arrays, one per layer, shaped
                (out_features, in_features). The misspelled name is kept so
                existing keyword callers continue to work.
            biases: sequence of 1-D arrays, one per layer, of length
                out_features.
            masks: sequence of arrays shaped like the corresponding weight
                matrix; 1 keeps a weight trainable, 0 blocks its gradient.
        """
        super(SparseNet, self).__init__()
        self.layers = nn.ModuleList()
        self.masks = masks
        for i, (w, b, m) in enumerate(zip(wieghts, biases, masks)):
            # np.array(...) always copies, so the tensors never alias the
            # caller's arrays.
            w32 = np.array(w, dtype=np.float32)
            b32 = np.array(b, dtype=np.float32)
            m32 = np.array(m, dtype=np.float32)

            layer = nn.Linear(w32.shape[1], w32.shape[0])
            layer.weight.data = torch.from_numpy(w32)
            layer.bias.data = torch.from_numpy(b32)
            layer.mask = torch.from_numpy(m32)
            layer.num = i
            # A tensor hook sees exactly the weight gradient, independent of
            # how autograd orders the inputs of the linear op.
            layer.weight.register_hook(self._make_mask_hook(layer.mask))
            self.layers.append(layer)

        # Setting ``requires_grad`` on the Module itself has no effect; the
        # flag has to be cleared on each parameter.
        if len(self.layers) > 0:
            for param in self.layers[-1].parameters():
                param.requires_grad = False

    @staticmethod
    def _make_mask_hook(mask):
        """Return a gradient hook that multiplies the gradient by ``mask``."""
        def hook(grad):
            return grad * mask.to(grad.device)
        return hook

    def forward(self, x):
        """Run the network.

        Args:
            x: 2-D tensor (batch, features); features must match the input
                width of the first layer.

        Returns:
            Output of the last layer (no ReLU applied to it). With no layers
            the input is returned unchanged.
        """
        y = x
        for l in self.layers:
            y = l(x)
            # Later layers see everything computed so far, after ReLU.
            x = torch.cat((x, y), 1)
            x = F.relu(x)
        return y

    def dumpweights(self):
        """Return ``(weights, biases)`` as lists of float64 numpy arrays."""
        weights = []
        biases = []
        for l in self.layers:
            weights.append(l.weight.detach().cpu().numpy().astype(np.float64))
            biases.append(l.bias.detach().cpu().numpy().astype(np.float64))
        return weights, biases