In [13]:
import numpy as np

In [12]:
def log_actions(action: str, path: str = "actions.log"):
    """Append one line describing ``action`` to a plain-text log file.

    Args:
        action: message to record; a trailing newline is added.
        path: file to append to (opened in append mode, so it is created if
            missing). Defaults to ``"actions.log"`` in the working directory.
    """
    # Explicit encoding: the default depends on the platform locale, so
    # non-ASCII messages would otherwise be written differently per machine.
    with open(path, "a", encoding="utf-8") as f:
        f.write(action + "\n")

In [14]:
class Node:
    """A graph node holding a vector plus its own key/query/value projections.

    Each projection is a ``(dim, dim)`` matrix multiplied with ``data``.
    """

    def __init__(self, dim: int = 20):
        """Draw random data and projection weights from the global NumPy RNG.

        Args:
            dim: length of the stored vector; the three weight matrices are
                ``(dim, dim)``. Defaults to 20, the original fixed size.
        """
        # the vector that is stored in the node
        self.data = np.random.randn(dim)

        # the weights governing how this node interacts with other nodes
        self.wkey = np.random.randn(dim, dim)
        self.wquery = np.random.randn(dim, dim)
        self.wvalue = np.random.randn(dim, dim)
        log_actions(f"Node created with data {self.data}")

    def key(self):
        """Return ``wkey @ data``: what this node has to offer."""
        return self.wkey @ self.data

    def query(self):
        """Return ``wquery @ data``: what this node is looking for."""
        return self.wquery @ self.data

    def value(self):
        """Return ``wvalue @ data``: what this node reveals to others."""
        return self.wvalue @ self.data

In [15]:
class Graph:
    """Random directed graph of ``Node`` objects with one-step attention updates.

    Edges are ``(source_index, target_index)`` pairs sampled uniformly with
    replacement, so self-loops and duplicate edges can occur.
    """

    def __init__(self, n_nodes=10, n_edges=40):
        """Create ``n_nodes`` random nodes joined by ``n_edges`` random edges.

        Args:
            n_nodes: number of nodes (default 10, the original fixed size).
            n_edges: number of directed edges (default 40).
        """
        self.nodes = [Node() for _ in range(n_nodes)]

        def randi():
            return np.random.randint(len(self.nodes))

        # source index is drawn before target index, one pair per edge
        self.edges = [(randi(), randi()) for _ in range(n_edges)]

        log_actions(f"Graph created with {len(self.nodes)} nodes and {len(self.edges)} edges")

    @staticmethod
    def _softmax(scores):
        """Numerically stable softmax over a 1-D array of scores."""
        # shifting by the max leaves the result unchanged but keeps exp() from
        # overflowing to inf (and the ratio to nan) for large dot products
        exps = np.exp(scores - np.max(scores))
        return exps / exps.sum()

    def run(self):
        """Apply one synchronous round of attention-style message passing.

        For every node with at least one incoming edge, each source node is
        scored by ``query @ key``; the scores are softmax-normalised and the
        weighted sum of the sources' values is added to the node's ``data``.
        Nodes with no incoming edges are left unchanged. All projections are
        computed from pre-update data before any node is modified.
        """
        # incoming[i] lists the source indices of edges ending at node i
        # (duplicate edges are kept, so they carry extra weight)
        incoming = [[] for _ in self.nodes]
        for src, dst in self.edges:
            incoming[dst].append(src)

        # evaluate each node's key/value once, from the current data
        keys = [node.key() for node in self.nodes]
        values = [node.value() for node in self.nodes]

        # keep each update paired with its own node: nodes without inputs
        # produce no update, so a bare list would misalign with self.nodes
        updates = []
        for node, sources in zip(self.nodes, incoming):
            if not sources:
                continue
            query = node.query()
            scores = np.array([query @ keys[j] for j in sources])
            weights = self._softmax(scores)
            update = weights @ np.stack([values[j] for j in sources])
            updates.append((node, update))

        for node, update in updates:
            node.data = node.data + update



In [16]:
# Seed the legacy global NumPy RNG (Node and Graph draw from np.random) so
# Restart & Run All rebuilds the same graph and prints the same vectors.
SEED = 0
np.random.seed(SEED)
graph = Graph()
graph.run()

In [17]:
# Show every node's vector after one round of message passing.
node_vectors = [node.data for node in graph.nodes]
for vector in node_vectors:
    print(vector)

[ 0.34980428 -8.55576913  9.0326494  -1.11787839  5.47553906 -5.92661741
 -0.04154986 -0.98786874  3.09899036  1.37551059 -5.24542785  0.18765526
  8.81740656  6.21005609  4.73918136  9.82351925 -4.9224487  -1.7196883
  0.57575593 -7.72066211]
[ 5.87292921 -2.36133427 -3.1973296   0.18512551  7.61649843 -0.39547014
  4.3745727  -0.14832297 -4.18022875  3.06017197  4.50714099  1.55945377
  6.11558694 -5.12928901  4.27900318 -3.90178027  1.75024987  1.32372836
  4.77864828  4.22777459]
[-0.21966258 -3.57766546 -0.59257447 -2.77776391  0.54390535 -6.85424971
 -0.84302688 -8.16723918  6.85945154 -1.61620675 -0.12685021  0.02567127
  2.93191602  2.03226816 -2.38240417  0.984404   -0.64782734 -2.02829612
  2.38097766 -2.53934162]
[-0.61025965 -7.51727355  6.23311601 -1.35134273  6.37955509 -6.99405095
  1.90698605 -0.60711774  2.63339999  2.80530698 -2.23144661 -0.58365874
 10.74699615  6.70597942  3.70135305  9.1405145  -2.17853582  0.5123895
  1.98337525 -7.76351175]
[-5.26324456 -9.586107