In [37]:
# Register `%r` as a short alias for `%run model.py`, then invoke it right away.
# NOTE(review): model.py is presumably where FlatLattice, TreeLattice, Model,
# PixelGNN, etc. used by the cells below come from -- confirm.
%alias_magic r run -p "model.py"
%r

Created `%r` as an alias for `%run model.py`.


# Pixel Graph Neural Network

## Lattice System

### Flat Lattice

Create a flat lattice. List the nodes and their center coordinates.
- **relevant nodes**: the preceding nodes within a given radius.

In [217]:
# Build a flat lattice and tabulate each node of type 'lat': its index, its
# center coordinates, and the indices returned by latt.relevant_nodes for it
# at radius 2.
latt = FlatLattice(4, 2)
print('ind   center   relevant_nodes')
for node in latt.nodes:
    # Compare by value, not identity: `is 'lat'` only works when the string
    # happens to be interned and raises a SyntaxWarning on Python 3.8+.
    if node.type == 'lat':
        print(' {:2d} {} {}'.format(node.ind,
                                 node.center.tolist(),
                                 [other.ind for other in latt.relevant_nodes(node, 2.)]))

ind   center   relevant_nodes
  0 [0.0, 0.0] []
  1 [0.0, 1.0] [0]
  2 [0.0, 2.0] [1]
  3 [0.0, 3.0] [2, 0]
  4 [1.0, 0.0] [1, 3, 0]
  5 [1.0, 1.0] [1, 2, 0, 4]
  6 [1.0, 2.0] [2, 1, 5, 3]
  7 [1.0, 3.0] [6, 0, 2, 3, 4]
  8 [2.0, 0.0] [4, 7, 5]
  9 [2.0, 1.0] [4, 6, 5, 8]
 10 [2.0, 2.0] [9, 6, 7, 5]
 11 [2.0, 3.0] [6, 10, 8, 4, 7]
 12 [3.0, 0.0] [0, 3, 1, 11, 8, 9]
 13 [3.0, 1.0] [12, 0, 2, 10, 1, 8, 9]
 14 [3.0, 2.0] [2, 10, 3, 1, 11, 13, 9]
 15 [3.0, 3.0] [12, 14, 0, 2, 10, 3, 11, 8]


The nodes are arranged on a square lattice with periodic boundary conditions.

<img src="./image/flat.png" alt="flat" width="180"/>

Construct the causal graph from the lattice.

In [216]:
# Derive the causal graph from the current lattice, then display its edge-type
# dictionary, per-node source depths and dense adjacency matrix as one tuple.
graph = latt.causal_graph()
(
    graph.type_dict,
    graph.source_depths,
    graph.adjacency_matrix().to_dense(),
)

({1: (0, 1), 2: (1, 1)},
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 0, 2, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 2, 0, 2, 1, 1, 0, 1, 0, 0, 0, 0, 0],
         [1, 2, 0, 2, 0, 0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0],
         [2, 1, 2, 0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0, 0],
         [0, 2, 1, 2, 0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0],
         [2, 0,

How the causal graph scales with the radius.

In [208]:
# Rebuild the causal graph at several radii and report each graph's summary
# together with its maximum depth.
radii = [1.01, 1.42, 2.01, 2.24, 2.83]
summary = 'r = {}: {}, max_depth = {}'
for radius in radii:
    graph = latt.causal_graph(radius=radius)
    print(summary.format(radius, graph, graph.max_depth))

r = 1.01: Graph(16x16, 32 edges of 1 types), max_depth = 6
r = 1.42: Graph(16x16, 64 edges of 2 types), max_depth = 15
r = 2.01: Graph(16x16, 80 edges of 3 types), max_depth = 15
r = 2.24: Graph(16x16, 112 edges of 4 types), max_depth = 15
r = 2.83: Graph(16x16, 120 edges of 5 types), max_depth = 15


### Tree Lattice

Create a tree lattice. List the nodes and their center coordinates.
- **relevant nodes**: the preceding nodes within the past light-cone of nodes in a given radius.

In [228]:
# Build a tree lattice and tabulate each node of type 'lat': its index, its
# center coordinates (blank padding when the node has no center), and the
# indices returned by latt.relevant_nodes for it at radius 2.
latt = TreeLattice(4, 2)
print('ind   center   relevant_nodes')
for node in latt.nodes:
    # Compare by value, not identity: `is 'lat'` only works when the string
    # happens to be interned and raises a SyntaxWarning on Python 3.8+.
    if node.type == 'lat':
        print(' {:2d} {} {}'.format(node.ind,
                                 node.center.tolist() if node.center is not None else '          ',
                                 [other.ind for other in latt.relevant_nodes(node, 2.)]))

ind   center   relevant_nodes
  0            []
  1 [2.0, 2.0] []
  2 [1.0, 2.0] [1]
  3 [3.0, 2.0] [2, 1]
  4 [1.0, 1.0] [2, 1, 3]
  5 [3.0, 1.0] [4, 2, 1, 3]
  6 [1.0, 3.0] [4, 5, 1, 3, 2]
  7 [3.0, 3.0] [6, 4, 5, 1, 3, 2]
  8 [0.5, 1.0] [6, 4, 5, 1, 3, 7, 2]
  9 [2.5, 1.0] [6, 4, 5, 1, 8, 3, 7, 2]
 10 [0.5, 3.0] [6, 4, 5, 1, 8, 3, 7, 2]
 11 [2.5, 3.0] [6, 4, 5, 1, 10, 9, 3, 7, 2]
 12 [1.5, 1.0] [6, 11, 4, 5, 1, 10, 9, 8, 3, 7, 2]
 13 [3.5, 1.0] [6, 11, 4, 5, 1, 10, 9, 12, 8, 3, 7, 2]
 14 [1.5, 3.0] [6, 11, 4, 5, 1, 10, 9, 12, 8, 3, 7, 2]
 15 [3.5, 3.0] [6, 11, 14, 4, 5, 1, 10, 13, 9, 8, 3, 7, 2]


Latent nodes are arranged on a binary tree representing the hierarchical structure.

<img src="./image/tree.png" alt="tree" width="180"/>

Construct the causal graph from the lattice.

In [229]:
# Derive the causal graph from the current lattice, then display its edge-type
# dictionary, per-node source depths and dense adjacency matrix as one tuple.
graph = latt.causal_graph()
(
    graph.type_dict,
    graph.source_depths,
    graph.adjacency_matrix().to_dense(),
)

({1: (0, 1), 2: (1, 3), 3: (3, 3), 4: (1, 1), 5: (2, 3), 6: (0, 3), 7: (0, 2)},
 tensor([0, 0, 1, 2, 2, 3, 2, 3, 3, 4, 3, 4, 5, 5, 5, 5]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 7, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 7, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 7, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 7, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 6, 7, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 6, 0, 7, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 6, 7, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 6, 0, 7, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 6, 7, 2, 1, 5, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0],
         [0, 6, 2, 7, 5, 1, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0],
         [0, 6, 7, 2, 0, 0, 1, 5, 0

How the causal graph scales with the radius.

In [233]:
# Rebuild the causal graph at several radii and report each graph's summary
# together with its maximum depth.
radii = [1., 1.2, 1.5, 2., 3.]
summary = 'r = {}: {}, max_depth = {}'
for radius in radii:
    graph = latt.causal_graph(radius=radius)
    print(summary.format(radius, graph, graph.max_depth))

r = 1.0: Graph(16x16, 51 edges of 7 types), max_depth = 5
r = 1.2: Graph(16x16, 67 edges of 9 types), max_depth = 7
r = 1.5: Graph(16x16, 85 edges of 9 types), max_depth = 10
r = 2.0: Graph(16x16, 101 edges of 9 types), max_depth = 12
r = 3.0: Graph(16x16, 105 edges of 9 types), max_depth = 14


## Training

### Flat Lattice

In [249]:
def H(J):
    """Return -J times the sum of the two TwoBody terms along [1,0] and [0,1]."""
    term_x = TwoBody([1,0],[1,-1])
    term_y = TwoBody([0,1],[1,-1])
    return -J*(term_x + term_y)

# NOTE(review): 0.440686793 looks like ln(1 + sqrt(2))/2, the critical coupling
# of the 2D Ising model -- confirm.
energy_model = Model(H(0.440686793), FlatLattice(4, 2), SymmetricGroup(2))
# Wrap the energy model in the PixelGNN and move it to the target device.
model = PixelGNN(energy_model, hidden_features=[8, 8], radius=2.01).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [251]:
# Run 2000 optimisation steps; every `echo` steps print the running averages
# of the loss and of the free-energy mean/std, then reset the accumulators.
batch_size = 100
echo = 100
report = '{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'
cum_loss = cum_meanfree = cum_stdfree = 0.
for step in range(1, 2001):
    loss, meanfree, stdfree = model.loss(batch_size, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if step % echo != 0:
        continue
    print(report.format(step, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
    cum_loss = cum_meanfree = cum_stdfree = 0.

  100 loss:  -0.5880, free energy: -15.1891 ±  0.6904
  200 loss:  -0.1124, free energy: -15.3628 ±  0.4837
  300 loss:  -0.0965, free energy: -15.4377 ±  0.3748
  400 loss:  -0.0559, free energy: -15.4739 ±  0.3153
  500 loss:  -0.0609, free energy: -15.4852 ±  0.2773
  600 loss:  -0.0234, free energy: -15.4324 ±  0.3450
  700 loss:  -0.0432, free energy: -15.4757 ±  0.2760
  800 loss:  -0.0358, free energy: -15.4859 ±  0.2527
  900 loss:  -0.0530, free energy: -15.4946 ±  0.2280
 1000 loss:  -0.0095, free energy: -15.4715 ±  0.2847
 1100 loss:  -0.0347, free energy: -15.4877 ±  0.2383
 1200 loss:  -0.0175, free energy: -15.4862 ±  0.2493
 1300 loss:  -0.0349, free energy: -15.4929 ±  0.2224
 1400 loss:  -0.0259, free energy: -15.4935 ±  0.2243
 1500 loss:  -0.0251, free energy: -15.5040 ±  0.1834
 1600 loss:  -0.0326, free energy: -15.5052 ±  0.1859
 1700 loss:  -0.0102, free energy: -15.4927 ±  0.2491
 1800 loss:  -0.0216, free energy: -15.5051 ±  0.1816
 1900 loss:  -0.0085, free e

### Tree Lattice

In [247]:
def H(J):
    """Return -J times the sum of the two TwoBody terms along [1,0] and [0,1]."""
    term_x = TwoBody([1,0],[1,-1])
    term_y = TwoBody([0,1],[1,-1])
    return -J*(term_x + term_y)

# NOTE(review): 0.440686793 looks like ln(1 + sqrt(2))/2, the critical coupling
# of the 2D Ising model -- confirm.
energy_model = Model(H(0.440686793), TreeLattice(4, 2), SymmetricGroup(2))
# Wrap the energy model in the PixelGNN and move it to the target device.
model = PixelGNN(energy_model, hidden_features=[8, 8], radius=1.5).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [248]:
# Run 2000 optimisation steps; every `echo` steps print the running averages
# of the loss and of the free-energy mean/std, then reset the accumulators.
batch_size = 100
echo = 100
report = '{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'
cum_loss = cum_meanfree = cum_stdfree = 0.
for step in range(1, 2001):
    loss, meanfree, stdfree = model.loss(batch_size, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if step % echo != 0:
        continue
    print(report.format(step, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
    cum_loss = cum_meanfree = cum_stdfree = 0.

  100 loss:  -0.5822, free energy: -14.7509 ±  0.8652
  200 loss:   0.0133, free energy: -15.1743 ±  0.3668
  300 loss:   0.0120, free energy: -15.2134 ±  0.3346
  400 loss:   0.0141, free energy: -15.2364 ±  0.3650
  500 loss:   0.0168, free energy: -15.2325 ±  0.3552
  600 loss:   0.0310, free energy: -15.2542 ±  0.3379
  700 loss:   0.0260, free energy: -15.2759 ±  0.4870
  800 loss:   0.0259, free energy: -15.3284 ±  0.4301
  900 loss:   0.0409, free energy: -15.3497 ±  0.3789
 1000 loss:   0.0042, free energy: -15.3600 ±  0.3767
 1100 loss:   0.0166, free energy: -15.3793 ±  0.4057
 1200 loss:   0.0183, free energy: -15.3926 ±  0.4032
 1300 loss:   0.0132, free energy: -15.4008 ±  0.3883
 1400 loss:   0.0123, free energy: -15.4034 ±  0.3944
 1500 loss:  -0.0026, free energy: -15.4017 ±  0.3867
 1600 loss:   0.0021, free energy: -15.4106 ±  0.3700
 1700 loss:   0.0163, free energy: -15.4219 ±  0.3569
 1800 loss:   0.0180, free energy: -15.4117 ±  0.3622
 1900 loss:  -0.0117, free e