In [3]:
import numpy as np

class Node:
    def __init__(self):
        # the vector stored at this node
        self.data = np.random.randn(20)

        # weights governing how this node interacts with other nodes
        self.wkey = np.random.randn(20, 20)
        self.wquery = np.random.randn(20,20)
        self.wvalue = np.random.randn(20, 20)

    def key(self):
        # what do I have?
        return self.wkey @ self.data

    def query(self):
        # what am I looking for?
        return self.wquery @ self.data

    def value(self):
        # what do I publicly reveal/broadcast to others?
        return self.wvalue @ self.data

class Graph:
    def __init__(self):
        # make 10 nodes
        self.nodes = [Node() for _ in range(10)]
        # make 40 edges
        randi = lambda: np.random.randint(len(self.nodes))
        self.edges = [[randi(), randi()] for _ in range(40)]

    def run(self):
        updates = []
        for i,n in enumerate(self.nodes):

            # what is this node looking for?
            q = n.query()

            # find all edges that are input to this node
            inputs = [self.nodes[ifrom] for (ifrom,ito) in self.edges if ito == i]
            if len(inputs) == 0:
                continue # ignore
            # gather their keys, i.e. what they hold
            keys = [m.key() for m in inputs]
            # calculate the compatibilities
            scores = [k.dot(q) for k in keys]
            # softmax them so they sum to 1
            scores = np.exp(scores)
            scores = scores / np.sum(scores)
            # gather the appropriate values with a weighted sum
            values = [m.value() for m in inputs]
            update = sum([s * v for s, v in zip(scores, values)])
            updates.append(update)

        for n,u in zip(self.nodes, updates):
            n.data = n.data + u # residual connection

In [4]:
g = Graph()

In [5]:
g.run()

In [6]:
g.nodes[0].value()

array([-19.19484004,   6.97836972,  12.94827538, -11.9027061 ,
         9.16783912,   9.00492918,   2.59671444,  22.25372667,
         3.52819385,  -9.37056147,  11.16274127, -14.75526086,
       -27.26950209,   4.17288922,   5.75706181, -22.99143236,
        11.26445524,  -9.86872412,  -2.27152432,  -3.83321625])

In [8]:
g.nodes[0].data

array([ -1.85336748,  -0.78992899,   1.44651486, -10.05951407,
         4.8522011 ,   1.53458908,   2.0938737 ,  -0.23831768,
        -1.53342003,   0.35576812,  -2.62942549,   3.51083561,
        -3.00156765,   0.85751808,   3.56824856,  -4.91785425,
        -0.13923238,   0.19038518,   1.39853444,  -6.03155581])

In [11]:
g.edges

[[9, 6],
 [6, 5],
 [6, 9],
 [0, 5],
 [7, 6],
 [5, 9],
 [6, 3],
 [3, 6],
 [9, 7],
 [6, 9],
 [5, 9],
 [8, 9],
 [3, 2],
 [1, 8],
 [7, 4],
 [2, 4],
 [7, 1],
 [5, 9],
 [1, 6],
 [9, 6],
 [5, 6],
 [9, 6],
 [2, 4],
 [2, 6],
 [2, 2],
 [7, 1],
 [3, 6],
 [9, 6],
 [6, 8],
 [3, 6],
 [3, 5],
 [3, 2],
 [1, 7],
 [7, 4],
 [1, 1],
 [6, 3],
 [9, 1],
 [7, 6],
 [1, 6],
 [5, 6]]

In [13]:
g.nodes[0].data.shape

(20,)