In [2]:
# Define %r as a shortcut for `%run model.py`, then run it once so that
# model.py's definitions (lattices, Model, PixelGNN, ...) are loaded into the
# notebook namespace. Later cells call %r again to reload edits to model.py.
%alias_magic r run -p "model.py"
%r

Created `%r` as an alias for `%run model.py`.


# Pixel Graph Neural Network

## Lattice System

### Flat Lattice

Create a flat lattice. List the nodes and their center coordinates.
- **relevant nodes**: the preceding nodes within a given radius.

In [2]:
# List every lattice ('lat'-type) node of a FlatLattice(4, 2): its index, its
# center coordinates, and the indices of its relevant nodes within radius 2.
latt = FlatLattice(4, 2)
print('ind   center   relevant_nodes')
for node in latt.nodes:
    # Compare strings by value: `is` tests object identity, only works by the
    # accident of string interning, and raises a SyntaxWarning on Python 3.8+.
    if node.type == 'lat':
        # Use a distinct name for the neighbours so `node` is not shadowed.
        relevant = [other.ind for other in latt.relevant_nodes(node, 2.)]
        print(' {:2d} {} {}'.format(node.ind,
                                 node.center.tolist(),
                                 relevant))

ind   center   relevant_nodes
  0 [0.0, 0.0] []
  1 [0.0, 1.0] [0]
  2 [0.0, 2.0] [1]
  3 [0.0, 3.0] [2, 0]
  4 [1.0, 0.0] [0, 3, 1]
  5 [1.0, 1.0] [4, 2, 0, 1]
  6 [1.0, 2.0] [2, 5, 3, 1]
  7 [1.0, 3.0] [0, 2, 3, 4, 6]
  8 [2.0, 0.0] [4, 7, 5]
  9 [2.0, 1.0] [4, 5, 8, 6]
 10 [2.0, 2.0] [9, 7, 5, 6]
 11 [2.0, 3.0] [10, 4, 6, 7, 8]
 12 [3.0, 0.0] [9, 0, 1, 11, 3, 8]
 13 [3.0, 1.0] [9, 0, 1, 10, 2, 12, 8]
 14 [3.0, 2.0] [9, 10, 1, 11, 2, 3, 13]
 15 [3.0, 3.0] [0, 10, 11, 2, 3, 12, 14, 8]


The nodes are arranged on a square lattice with periodic boundary condition.

<img src="./image/flat.png" alt="flat" width="180"/>

Construct the causal graph from the lattice.

In [3]:
# Build the causal graph at the default radius and display three views of it:
# the edge-type table (its keys are the nonzero entries of the adjacency
# matrix), the per-node source depths, and the dense adjacency matrix.
graph = latt.causal_graph()
(
    graph.type_dict,
    graph.source_depths,
    graph.adjacency_matrix().to_dense(),
)

({1: (0, 1), 2: (1, 1)},
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [2, 0, 2, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 2, 0, 2, 1, 1, 0, 1, 0, 0, 0, 0, 0],
         [1, 2, 0, 2, 0, 0, 0, 0, 1, 2, 0, 2, 0, 0, 0, 0],
         [2, 1, 2, 0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0, 0],
         [0, 2, 1, 2, 0, 0, 0, 0, 0, 2, 1, 2, 0, 1, 0, 0],
         [2, 0,

How the causal graph scales with the radius.

In [4]:
# Rebuild the causal graph for a range of radii and report its size, edge-type
# count (via the Graph repr) and maximum depth.
for radius in (1.01, 1.42, 2.01, 2.24, 2.83):
    graph = latt.causal_graph(radius=radius)
    summary = 'r = {}: {}, max_depth = {}'.format(radius, graph, graph.max_depth)
    print(summary)

r = 1.01: Graph(16x16, 32 edges of 1 types), max_depth = 6
r = 1.42: Graph(16x16, 64 edges of 2 types), max_depth = 15
r = 2.01: Graph(16x16, 80 edges of 3 types), max_depth = 15
r = 2.24: Graph(16x16, 112 edges of 4 types), max_depth = 15
r = 2.83: Graph(16x16, 120 edges of 5 types), max_depth = 15


### Tree Lattice

Create a tree lattice. List the nodes and their center coordinates.
- **relevant nodes**: the preceding nodes within the past light-cone of the nodes within a given radius.

In [35]:
# List every lattice ('lat'-type) node of a TreeLattice(4, 2): its index, its
# center coordinates (blank padding when a node has no center), and the indices
# of its relevant nodes for radius 2.
latt = TreeLattice(4, 2)
print('ind   center   relevant_nodes')
for node in latt.nodes:
    # Compare strings by value: `is` tests object identity, only works by the
    # accident of string interning, and raises a SyntaxWarning on Python 3.8+.
    if node.type == 'lat':
        # Use a distinct name for the neighbours so `node` is not shadowed.
        relevant = [other.ind for other in latt.relevant_nodes(node, 2.)]
        print(' {:2d} {} {}'.format(node.ind, 
                                 node.center.tolist() if node.center is not None else '          ',
                                 relevant))

ind   center   relevant_nodes
  0            []
  1 [2.0, 2.0] []
  2 [1.0, 2.0] [1]
  3 [3.0, 2.0] [2, 1]
  4 [1.0, 1.0] [3, 2, 1]
  5 [3.0, 1.0] [4, 3, 2, 1]
  6 [1.0, 3.0] [4, 5, 2, 3, 1]
  7 [3.0, 3.0] [6, 4, 5, 2, 3, 1]
  8 [0.5, 1.0] [6, 4, 5, 2, 3, 1, 7]
  9 [2.5, 1.0] [6, 4, 5, 2, 3, 1, 7, 8]
 10 [0.5, 3.0] [6, 4, 5, 2, 3, 1, 7, 8]
 11 [2.5, 3.0] [6, 4, 5, 9, 2, 3, 1, 10, 7]
 12 [1.5, 1.0] [6, 4, 5, 9, 11, 2, 3, 1, 10, 7, 8]
 13 [3.5, 1.0] [6, 4, 5, 9, 11, 2, 12, 3, 1, 10, 7, 8]
 14 [1.5, 3.0] [6, 4, 5, 9, 11, 2, 12, 3, 1, 10, 7, 8]
 15 [3.5, 3.0] [6, 4, 5, 14, 9, 11, 2, 3, 13, 1, 10, 7, 8]


Latent nodes are arranged on a binary tree representing the hierarchical structure.

<img src="./image/tree.png" alt="tree" width="180"/>

Construct the causal graph from the lattice.

In [36]:
# Build the causal graph of the tree lattice at the default radius and display
# the edge-type table (its keys are the nonzero entries of the adjacency
# matrix), the per-node source depths, and the dense adjacency matrix.
graph = latt.causal_graph()
(
    graph.type_dict,
    graph.source_depths,
    graph.adjacency_matrix().to_dense(),
)

({1: (0, 1, 1),
  2: (0, 3, 3),
  3: (0, 0, 2),
  4: (0, 2, 3),
  5: (2, 1, 1),
  6: (0, 0, 1),
  7: (1, 0, 1),
  8: (0, 1, 3),
  9: (2, 0, 1),
  10: (1, 0, 2),
  11: (0, 0, 3)},
 tensor([0, 0, 1, 2, 2, 3, 2, 3, 3, 4, 3, 4, 5, 5, 5, 5]),
 tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  6,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  3,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  3,  0,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  3,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  3,  0,  7,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0, 11, 10,  0,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0, 11,  0, 10,  0,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0

How the causal graph scales with the radius.

In [37]:
# Rebuild the tree-lattice causal graph for a range of radii and report its
# size, edge-type count (via the Graph repr) and maximum depth.
for radius in (1., 1.2, 1.5, 2., 3.):
    graph = latt.causal_graph(radius=radius)
    summary = 'r = {}: {}, max_depth = {}'.format(radius, graph, graph.max_depth)
    print(summary)

r = 1.0: Graph(16x16, 51 edges of 11 types), max_depth = 5
r = 1.2: Graph(16x16, 67 edges of 14 types), max_depth = 7
r = 1.5: Graph(16x16, 85 edges of 16 types), max_depth = 10
r = 2.0: Graph(16x16, 101 edges of 16 types), max_depth = 12
r = 3.0: Graph(16x16, 105 edges of 16 types), max_depth = 14


## Training

### Flat Lattice

Set up the model and train.

In [27]:
# Coupling used for every run. Equals ln(1 + sqrt(2)) / 2 to the digits given,
# i.e. the critical coupling of the 2D square-lattice Ising model.
J_C = 0.440686793


def H(J):
    """Return the Hamiltonian with coupling J.

    It is -J times the sum of TwoBody([1,0],[1,-1]) and TwoBody([0,1],[1,-1]).
    NOTE(review): TwoBody comes from model.py; presumably a nearest-neighbour
    interaction along each lattice axis (Ising-like with SymmetricGroup(2)).
    Confirm there.
    """
    return -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))


# Physical model on a 4x4 flat lattice, wrapped by the PixelGNN generative model.
base_model = Model(H(J_C), FlatLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(base_model, hidden_features=[8, 8], radius=2.01).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [28]:
import time

# Optimize the model for `steps` Adam updates. Every `echo` steps, print the
# averages over those steps of the loss and of the free-energy mean/std
# returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -1.0314, free energy: -14.0495 ±  1.3977
  200 loss:  -0.2964, free energy: -15.4287 ±  0.4349
  300 loss:  -0.1251, free energy: -15.4752 ±  0.3027
  400 loss:  -0.0872, free energy: -15.4799 ±  0.2845
  500 loss:  -0.0921, free energy: -15.4883 ±  0.2572
  600 loss:  -0.0813, free energy: -15.4918 ±  0.2443
  700 loss:  -0.0691, free energy: -15.4933 ±  0.2378
  800 loss:  -0.0669, free energy: -15.4978 ±  0.2255
  900 loss:  -0.0496, free energy: -15.4873 ±  0.2594
 1000 loss:  -0.0634, free energy: -15.4992 ±  0.2161
 1100 loss:  -0.0598, free energy: -15.4997 ±  0.2098
 1200 loss:  -0.0572, free energy: -15.5002 ±  0.2120
 1300 loss:  -0.0588, free energy: -15.5011 ±  0.2101
 1400 loss:  -0.0527, free energy: -15.4988 ±  0.2149
 1500 loss:  -0.0535, free energy: -15.4987 ±  0.2171
 1600 loss:  -0.0532, free energy: -15.4995 ±  0.2146
 1700 loss:  -0.0554, free energy: -15.5014 ±  0.2067
 1800 loss:  -0.0544, free energy: -15.4999 ±  0.2139
 1900 loss:  -0.0536, free e

Short-range correlation function.

In [29]:
# Short-range correlation: for each site, compare its group element with that
# of the site one step away along each axis. x.roll wraps around, so the
# boundary is periodic. The value is averaged over the samples.
# NOTE(review): model.model.group(elements, values) presumably maps each group
# element to the given value (+1 / -1). Confirm in model.py.
n_samples = 100000
x = model.sample(n_samples)
dims = tuple(range(-model.model.lattice.dimension, 0))
# Loop-invariant: build the value table once rather than once per shift.
spin_values = torch.tensor([1., -1.])
for shift in [(-1, 0), (0, -1)]:
    rolled = model.model.group.inv(x.roll(shift, dims))
    coupled = model.model.group.mul(rolled, x)
    corr = model.model.group(coupled, spin_values)
    print(corr.mean(0))

tensor([[0.7922, 0.8011, 0.7962, 0.7861],
        [0.7895, 0.7803, 0.7835, 0.7775],
        [0.7928, 0.7822, 0.7800, 0.7827],
        [0.7908, 0.7867, 0.7837, 0.7776]], device='cuda:0')
tensor([[0.7961, 0.7879, 0.7702, 0.7936],
        [0.7677, 0.7791, 0.7711, 0.7658],
        [0.7603, 0.7586, 0.7709, 0.7668],
        [0.7767, 0.7789, 0.7767, 0.7747]], device='cuda:0')


### Tree Lattice (With Scale-Invariance)

Set up the model and train.

In [47]:
# Coupling used for every run. Equals ln(1 + sqrt(2)) / 2 to the digits given,
# i.e. the critical coupling of the 2D square-lattice Ising model.
J_C = 0.440686793


def H(J):
    """Return the Hamiltonian with coupling J.

    It is -J times the sum of TwoBody([1,0],[1,-1]) and TwoBody([0,1],[1,-1]).
    NOTE(review): TwoBody comes from model.py; presumably a nearest-neighbour
    interaction along each lattice axis (Ising-like with SymmetricGroup(2)).
    Confirm there.
    """
    return -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))


# Physical model on a tree lattice, wrapped by a scale-invariant PixelGNN.
base_model = Model(H(J_C), TreeLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(base_model, hidden_features=[8, 8], radius=1.5, scale_invariance=True).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [53]:
import time

# Optimize the model for `steps` Adam updates. Every `echo` steps, print the
# averages over those steps of the loss and of the free-energy mean/std
# returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -0.0008, free energy: -15.4908 ±  0.2219
  200 loss:  -0.0008, free energy: -15.4872 ±  0.2299
  300 loss:  -0.0027, free energy: -15.4941 ±  0.2088
  400 loss:  -0.0019, free energy: -15.4902 ±  0.2203
  500 loss:  -0.0045, free energy: -15.4960 ±  0.2018
  600 loss:  -0.0004, free energy: -15.4897 ±  0.2234
  700 loss:  -0.0051, free energy: -15.4962 ±  0.1989
  800 loss:  -0.0036, free energy: -15.4960 ±  0.2006
  900 loss:   0.0119, free energy: -15.4698 ±  0.2665
 1000 loss:  -0.0064, free energy: -15.4926 ±  0.2137
 1100 loss:   0.0005, free energy: -15.4887 ±  0.2241
 1200 loss:  -0.0046, free energy: -15.4963 ±  0.1993
 1300 loss:  -0.0034, free energy: -15.4964 ±  0.1989
 1400 loss:   0.0045, free energy: -15.4602 ±  0.2740
 1500 loss:   0.0059, free energy: -15.4759 ±  0.2633
 1600 loss:  -0.0030, free energy: -15.4902 ±  0.2210
 1700 loss:  -0.0051, free energy: -15.4939 ±  0.2097
 1800 loss:  -0.0056, free energy: -15.4952 ±  0.2075
 1900 loss:   0.0003, free e

Short-range correlation function.

In [52]:
# Short-range correlation: for each site, compare its group element with that
# of the site one step away along each axis. x.roll wraps around, so the
# boundary is periodic. The value is averaged over the samples.
# NOTE(review): model.model.group(elements, values) presumably maps each group
# element to the given value (+1 / -1). Confirm in model.py.
n_samples = 100000
x = model.sample(n_samples)
dims = tuple(range(-model.model.lattice.dimension, 0))
# Loop-invariant: build the value table once rather than once per shift.
spin_values = torch.tensor([1., -1.])
for shift in [(-1, 0), (0, -1)]:
    rolled = model.model.group.inv(x.roll(shift, dims))
    coupled = model.model.group.mul(rolled, x)
    corr = model.model.group(coupled, spin_values)
    print(corr.mean(0))

tensor([[0.8074, 0.8101, 0.7988, 0.8088],
        [0.8057, 0.7937, 0.7995, 0.7878],
        [0.8053, 0.8152, 0.8027, 0.8149],
        [0.8005, 0.7934, 0.7990, 0.7912]], device='cuda:0')
tensor([[0.8197, 0.8057, 0.8142, 0.8069],
        [0.8345, 0.8004, 0.8332, 0.8125],
        [0.8219, 0.8033, 0.8164, 0.8084],
        [0.8338, 0.8065, 0.8413, 0.8100]], device='cuda:0')


### Tree Lattice (No Scale-Invariance)

Set up the model and train.

In [41]:
# Coupling used for every run. Equals ln(1 + sqrt(2)) / 2 to the digits given,
# i.e. the critical coupling of the 2D square-lattice Ising model.
J_C = 0.440686793


def H(J):
    """Return the Hamiltonian with coupling J.

    It is -J times the sum of TwoBody([1,0],[1,-1]) and TwoBody([0,1],[1,-1]).
    NOTE(review): TwoBody comes from model.py; presumably a nearest-neighbour
    interaction along each lattice axis (Ising-like with SymmetricGroup(2)).
    Confirm there.
    """
    return -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))


# Physical model on a tree lattice, wrapped by a PixelGNN without scale invariance.
base_model = Model(H(J_C), TreeLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(base_model, hidden_features=[8, 8], radius=1.5, scale_invariance=False).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [45]:
import time

# Optimize the model for `steps` Adam updates. Every `echo` steps, print the
# averages over those steps of the loss and of the free-energy mean/std
# returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -0.0052, free energy: -15.4751 ±  0.2636
  200 loss:  -0.0032, free energy: -15.4756 ±  0.2536
  300 loss:  -0.0050, free energy: -15.4774 ±  0.2510
  400 loss:  -0.0051, free energy: -15.4774 ±  0.2537
  500 loss:  -0.0042, free energy: -15.4785 ±  0.2533
  600 loss:  -0.0054, free energy: -15.4787 ±  0.2524
  700 loss:  -0.0044, free energy: -15.4775 ±  0.2570
  800 loss:  -0.0058, free energy: -15.4769 ±  0.2528
  900 loss:  -0.0057, free energy: -15.4791 ±  0.2511
 1000 loss:  -0.0059, free energy: -15.4798 ±  0.2482
 1100 loss:  -0.0062, free energy: -15.4797 ±  0.2501
 1200 loss:  -0.0049, free energy: -15.4760 ±  0.2615
 1300 loss:  -0.0059, free energy: -15.4786 ±  0.2559
 1400 loss:  -0.0053, free energy: -15.4803 ±  0.2476
 1500 loss:  -0.0033, free energy: -15.4772 ±  0.2592
 1600 loss:  -0.0071, free energy: -15.4810 ±  0.2455
 1700 loss:  -0.0058, free energy: -15.4815 ±  0.2459
 1800 loss:  -0.0043, free energy: -15.4788 ±  0.2572
 1900 loss:  -0.0032, free e

Short-range correlation function.

In [46]:
# Short-range correlation: for each site, compare its group element with that
# of the site one step away along each axis. x.roll wraps around, so the
# boundary is periodic. The value is averaged over the samples.
# NOTE(review): model.model.group(elements, values) presumably maps each group
# element to the given value (+1 / -1). Confirm in model.py.
n_samples = 100000
x = model.sample(n_samples)
dims = tuple(range(-model.model.lattice.dimension, 0))
# Loop-invariant: build the value table once rather than once per shift.
spin_values = torch.tensor([1., -1.])
for shift in [(-1, 0), (0, -1)]:
    rolled = model.model.group.inv(x.roll(shift, dims))
    coupled = model.model.group.mul(rolled, x)
    corr = model.model.group(coupled, spin_values)
    print(corr.mean(0))

tensor([[0.8051, 0.7828, 0.7994, 0.7838],
        [0.8053, 0.7586, 0.7926, 0.7517],
        [0.7971, 0.7746, 0.7866, 0.7724],
        [0.8098, 0.7715, 0.7993, 0.7634]], device='cuda:0')
tensor([[0.7897, 0.7933, 0.7962, 0.7986],
        [0.7898, 0.7839, 0.7955, 0.7904],
        [0.7833, 0.7845, 0.7829, 0.7892],
        [0.7812, 0.7774, 0.7895, 0.7831]], device='cuda:0')


**Conclusion**
- Although there are more parameters without scale invariance, the model is harder to converge (it is more easily trapped in local minima of fully correlated configurations), and the performance is not improved.
- The translation symmetry breaking still persists in the correlation function.

**Possible Solution:** Consider symmetrizing the probability distribution. To restore the translation symmetry, one idea is to symmetrize the samples by random translations. This amounts to modeling the target probability $p(x)$ by a **mixture model**
$$\tilde{q}_\theta(x) = \frac{1}{N}\sum_a q_\theta(T_a(x)),$$
where $T_a$ denotes the translation operator that translates the configuration $x$ by $a$. $a$ is summed over the translation group of the lattice and $N$ is the order of the translation group.

The loss function becomes
$$\begin{split}\mathcal{L}&=\sum_x \Big(\frac{1}{N}\sum_a q_\theta(T_a(x))\Big)\bigg(E(x)+\ln \Big(\frac{1}{N}\sum_b q_\theta(T_b(x))\Big)\bigg)\\
&=\frac{1}{N}\sum_a\sum_x q_\theta(x)\bigg(E(T_{-a}(x))+\ln \Big(\frac{1}{N}\sum_b q_\theta(T_{b-a}(x))\Big)\bigg)\\
&=\sum_x  q_\theta(x)\bigg(E(x)+\ln \Big(\frac{1}{N}\sum_b q_\theta(T_{b}(x))\Big)\bigg)
\end{split}$$
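Here we have used the following facts:
- the sum over configurations $\sum_x$ is invariant under translations, which allows the substitution $x\to T_{-a}(x)$ in the second line;
- the sum over the translation group $\sum_b$ is invariant under the relabeling $b\to b+a$;
- the energy function is translationally invariant, $E(T_{-a}(x))=E(x)$.

The sum over $a$ then contributes $N$ identical terms, which cancel the prefactor $\frac{1}{N}$. Therefore, we only need to replace the log probability by its symmetrized version. This can be approximated by the log-mean-exp of the log probability over a finite number of random translations.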

## Mixture Model

### Random Space Group

Space-group operations on a point $x\in \mathbb{Z}_n^d$ include:
- translation $x\to x+a$ for $a\in \mathbb{Z}_n^d$,
- inversion $x\to (-1)^s\odot x$ for $s\in \mathbb{Z}_2^d$ (componentwise sign flips),
- permutation of coordinates $(x_1,x_2,\cdots)\to(x_{\pi_1}, x_{\pi_2},\cdots)$ for $\pi\in S_d$.

`randperm(sample_size)` method samples random space group transformation (as site permutation) of the lattice.

In [190]:
# Draw three random space-group transformations of the 16-site lattice, each
# returned as a permutation of the site indices (one row per sample).
demo_lattice = Lattice(4, 2)
demo_lattice.randperm(3)

tensor([[14, 10,  6,  2, 15, 11,  7,  3, 12,  8,  4,  0, 13,  9,  5,  1],
        [ 1, 13,  9,  5,  0, 12,  8,  4,  3, 15, 11,  7,  2, 14, 10,  6],
        [12, 13, 14, 15,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])

In [45]:
%r
H = lambda J: -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))
model = Model(H(0.440686793), TreeLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(model, hidden_features=[8, 8], radius=1.5, mixtures=4).to(device);

In [46]:
# Draw 3 model samples and apply random space-group transformations to them
# with randmix. The saved output has shape (3, 2, 4, 4), i.e. 2 configurations
# per sample (presumably 2 random transforms of each; confirm in model.py).
x = model.sample(3)
xs = model.randmix(x, 2)
# End with the result so the cell displays it; an assignment alone displays nothing.
# NOTE(review): if randmix itself prints its result, this shows it twice, so drop the line then.
xs

tensor([[[[1, 0, 1, 0],
          [0, 0, 0, 1],
          [1, 0, 0, 0],
          [1, 0, 1, 0]],

         [[1, 0, 0, 0],
          [1, 0, 1, 0],
          [1, 0, 1, 0],
          [0, 1, 0, 0]]],


        [[[1, 1, 0, 0],
          [0, 1, 0, 1],
          [1, 1, 1, 0],
          [0, 0, 1, 1]],

         [[1, 0, 1, 1],
          [0, 1, 1, 0],
          [1, 0, 0, 1],
          [0, 1, 0, 1]]],


        [[[1, 0, 1, 0],
          [0, 1, 0, 1],
          [0, 1, 0, 1],
          [0, 1, 0, 1]],

         [[0, 1, 0, 1],
          [0, 1, 0, 1],
          [1, 0, 1, 0],
          [0, 1, 0, 1]]]], device='cuda:0')

### Bin Statistics

Bin statistics is a technique to estimate the expectation of a nonlinear map of random variables. In our problem, we need to evaluate the log probability of the mixture model $\ln(\frac{1}{N}\sum_{b}q_\theta(T_b(x)))$ with a finite number of samples. Let $\bar{q}_\theta$ be the expectation of $q_\theta$ in the large-$N$ limit, and write $q_\theta=\bar{q}_\theta+\delta q_\theta$. Expanding the logarithm to second order in $\frac{1}{N}\sum\delta q_\theta$, the first-order term averages to zero, so we expect
$$\Big\langle\ln\big(\frac{1}{N}\sum q_\theta\big)\Big\rangle=\Big\langle\ln\big(\bar{q}_\theta + \frac{1}{N}\sum \delta q_\theta\big)\Big\rangle=\ln \bar{q}_\theta - \frac{1}{2N}\frac{\text{var}(\delta q_\theta)}{\bar{q}_\theta^2}+\mathcal{O}(N^{-2})$$
Therefore we should estimate the log of mixture probability by
$$\ln \bar{q}_\theta \simeq \ln\big(\frac{1}{N}\sum q_\theta\big)+ \frac{1}{2N}\frac{\text{var}(q_\theta)}{\bar{q}_\theta^2}.$$
In evaluating the subleading term, $q_\theta$ can be rescaled by an overall factor to avoid overflow/underflow; the ratio $\text{var}(q_\theta)/\bar{q}_\theta^2$ is unchanged by such a rescaling.

In [47]:
# Log-probability of each sample in x under the model, without mixing.
log_q = model.log_prob(x)
log_q

tensor([-9.6945, -9.2867, -5.4256], device='cuda:0', grad_fn=<SumBackward1>)

In [72]:
# Log-probability of each sample in x under the symmetrized mixture model.
# NOTE(review): the second argument is presumably the number of random
# space-group transforms used in the estimate. Confirm in model.py.
n_transforms = 4
log_q_mixed = model.log_mixed_prob(x, n_transforms)
log_q_mixed

tensor([-12.2184, -10.4397,  -7.2398], device='cuda:0', grad_fn=<AddBackward0>)

## Training with Mixtures

### Flat Lattice

no mixture

In [73]:
%r
H = lambda J: -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))
model = Model(H(0.440686793), FlatLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(model, hidden_features=[8, 8], radius=2.01).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02);

In [74]:
import time

# Optimize the model for `steps` Adam updates (no mixtures). Every `echo`
# steps, print the averages over those steps of the loss and of the
# free-energy mean/std returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -0.3220, free energy: -14.6379 ±  0.5976
  200 loss:   0.1130, free energy: -15.0297 ±  0.4113
  300 loss:   0.0331, free energy: -15.2847 ±  0.4678
  400 loss:  -0.0740, free energy: -15.4785 ±  0.2889
  500 loss:  -0.0492, free energy: -15.4928 ±  0.2397
  600 loss:  -0.0434, free energy: -15.5005 ±  0.2060
  700 loss:  -0.0256, free energy: -15.4858 ±  0.2451
  800 loss:  -0.0386, free energy: -15.5050 ±  0.1848
  900 loss:  -0.0347, free energy: -15.5062 ±  0.1775
 1000 loss:  -0.0326, free energy: -15.5072 ±  0.1700
 1100 loss:  -0.0259, free energy: -15.5038 ±  0.1844
 1200 loss:  -0.0233, free energy: -15.5037 ±  0.1858
 1300 loss:  -0.0278, free energy: -15.5091 ±  0.1604
 1400 loss:  -0.0273, free energy: -15.5092 ±  0.1580
 1500 loss:  -0.0157, free energy: -15.5002 ±  0.1942
 1600 loss:  -0.0268, free energy: -15.5100 ±  0.1551
 1700 loss:  -0.0247, free energy: -15.5104 ±  0.1505
 1800 loss:  -0.0250, free energy: -15.5096 ±  0.1553
 1900 loss:  -0.0229, free e

2 mixtures

In [75]:
%r
H = lambda J: -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))
model = Model(H(0.440686793), FlatLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(model, hidden_features=[8, 8], radius=2.01).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02);

In [76]:
import time

# Optimize the model for `steps` Adam updates using a 2-component mixture in
# the loss. Every `echo` steps, print the averages over those steps of the loss
# and of the free-energy mean/std returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, mixtures=2, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:   0.9273, free energy: -14.4217 ±  1.5607
  200 loss:  -0.3956, free energy: -14.9241 ±  0.9625
  300 loss:  -0.7220, free energy: -15.3860 ±  0.7543
  400 loss:  -0.2778, free energy: -15.5271 ±  0.3313
  500 loss:  -0.1579, free energy: -15.5233 ±  0.2587
  600 loss:  -0.1072, free energy: -15.5220 ±  0.2335
  700 loss:  -0.0918, free energy: -15.5189 ±  0.2196
  800 loss:  -0.0892, free energy: -15.5210 ±  0.2138
  900 loss:  -0.0564, free energy: -15.5204 ±  0.2210
 1000 loss:  -0.0489, free energy: -15.5174 ±  0.2940
 1100 loss:  -0.0707, free energy: -15.5202 ±  0.1886
 1200 loss:  -0.0562, free energy: -15.5218 ±  0.1809
 1300 loss:  -0.0490, free energy: -15.5217 ±  0.1818
 1400 loss:  -0.0445, free energy: -15.5212 ±  0.1716
 1500 loss:  -0.0414, free energy: -15.5217 ±  0.1743
 1600 loss:  -0.0356, free energy: -15.5208 ±  0.1903
 1700 loss:  -0.0388, free energy: -15.5191 ±  0.1782
 1800 loss:   0.0074, free energy: -15.5172 ±  0.3311
 1900 loss:  -0.0246, free e

### Tree Lattice

2 mixtures, scale-invariant

In [77]:
%r
H = lambda J: -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))
model = Model(H(0.440686793), TreeLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(model, hidden_features=[8, 8], radius=1.5).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02);

In [78]:
import time

# Optimize the model for `steps` Adam updates using a 2-component mixture in
# the loss. Every `echo` steps, print the averages over those steps of the loss
# and of the free-energy mean/std returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, mixtures=2, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -0.3460, free energy: -15.1937 ±  1.5339
  200 loss:  -0.1629, free energy: -15.5608 ±  0.9051
  300 loss:   0.0047, free energy: -15.5330 ±  0.7099
  400 loss:  -0.0880, free energy: -15.5308 ±  0.6101
  500 loss:  -0.0123, free energy: -15.5319 ±  0.5071
  600 loss:  -0.0285, free energy: -15.5335 ±  0.4480
  700 loss:  -0.0214, free energy: -15.5268 ±  0.3983
  800 loss:  -0.0184, free energy: -15.5257 ±  0.3715
  900 loss:  -0.0224, free energy: -15.5262 ±  0.3481
 1000 loss:  -0.0190, free energy: -15.5307 ±  0.3410
 1100 loss:  -0.0135, free energy: -15.5254 ±  0.3330
 1200 loss:  -0.0235, free energy: -15.5277 ±  0.3440
 1300 loss:  -0.0191, free energy: -15.5251 ±  0.3231
 1400 loss:  -0.0134, free energy: -15.5259 ±  0.2944
 1500 loss:  -0.0159, free energy: -15.5281 ±  0.2985
 1600 loss:  -0.0122, free energy: -15.5252 ±  0.2914
 1700 loss:  -0.0214, free energy: -15.5244 ±  0.2825
 1800 loss:  -0.0155, free energy: -15.5202 ±  0.3095
 1900 loss:  -0.0100, free e

4 mixtures, scale-invariant

In [80]:
%r
H = lambda J: -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))
model = Model(H(0.440686793), TreeLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(model, hidden_features=[8, 8], radius=1.5).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02);

In [81]:
import time

# Optimize the model for `steps` Adam updates using a 4-component mixture in
# the loss. Every `echo` steps, print the averages over those steps of the loss
# and of the free-energy mean/std returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, mixtures=4, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -0.3825, free energy: -15.3044 ±  1.2170
  200 loss:  -0.0540, free energy: -15.5257 ±  0.4823
  300 loss:   0.0003, free energy: -15.5078 ±  0.3442
  400 loss:   0.0071, free energy: -15.5129 ±  0.3115
  500 loss:  -0.0018, free energy: -15.5215 ±  0.2993
  600 loss:  -0.0081, free energy: -15.5197 ±  0.2747
  700 loss:  -0.0010, free energy: -15.5167 ±  0.2644
  800 loss:  -0.0068, free energy: -15.5206 ±  0.2535
  900 loss:  -0.0042, free energy: -15.5233 ±  0.2364
 1000 loss:  -0.0050, free energy: -15.5221 ±  0.2313
 1100 loss:  -0.0087, free energy: -15.5224 ±  0.2202
 1200 loss:  -0.0089, free energy: -15.5189 ±  0.2153
 1300 loss:   0.0002, free energy: -15.5210 ±  0.2350
 1400 loss:  -0.0051, free energy: -15.5231 ±  0.2060
 1500 loss:  -0.0008, free energy: -15.5208 ±  0.2086
 1600 loss:  -0.0048, free energy: -15.5231 ±  0.2086
 1700 loss:  -0.0049, free energy: -15.5186 ±  0.1963
 1800 loss:  -0.0020, free energy: -15.5195 ±  0.1952
 1900 loss:  -0.0046, free e

16 mixtures, scale-invariant

In [83]:
%r
H = lambda J: -J*(TwoBody([1,0],[1,-1]) + TwoBody([0,1],[1,-1]))
model = Model(H(0.440686793), TreeLattice(4, 2), SymmetricGroup(2))
model = PixelGNN(model, hidden_features=[8, 8], radius=1.5).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.02);

In [84]:
import time

# Optimize the model for `steps` Adam updates using a 16-component mixture in
# the loss. Every `echo` steps, print the averages over those steps of the loss
# and of the free-energy mean/std returned by model.loss, then reset the accumulators.
batch_size = 5000
steps = 2000
echo = 100
cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
# perf_counter is monotonic and high-resolution, which makes it the documented
# clock for measuring durations. time.time() can jump if the system clock is adjusted.
tic = time.perf_counter()
for k in range(steps):
    loss, meanfree, stdfree = model.loss(batch_size, mixtures=16, return_statistics=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    cum_loss += loss.item()  # .item() keeps the running sum off the autograd graph
    cum_meanfree += meanfree
    cum_stdfree += stdfree
    if (k + 1) % echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f} ±{:8.4f}'.format(k+1, cum_loss/echo, cum_meanfree/echo, cum_stdfree/echo))
        cum_loss, cum_meanfree, cum_stdfree = 0., 0., 0.
toc = time.perf_counter()
print('total time: {:6.2f}s, {:6.2f}ms per step.'.format(toc-tic, (toc-tic)/steps*1000))

  100 loss:  -0.8319, free energy: -15.0163 ±  1.0294
  200 loss:  -0.0243, free energy: -15.5077 ±  0.3929
  300 loss:  -0.0091, free energy: -15.5093 ±  0.3110
  400 loss:   0.0055, free energy: -15.5146 ±  0.2755
  500 loss:   0.0035, free energy: -15.5175 ±  0.2530
  600 loss:   0.0048, free energy: -15.5162 ±  0.2436
  700 loss:   0.0027, free energy: -15.5113 ±  0.2447
  800 loss:   0.0025, free energy: -15.5185 ±  0.1948
  900 loss:   0.0036, free energy: -15.5152 ±  0.1894
 1000 loss:   0.0047, free energy: -15.5190 ±  0.1747
 1100 loss:   0.0023, free energy: -15.5216 ±  0.1672
 1200 loss:   0.0032, free energy: -15.5208 ±  0.1606
 1300 loss:   0.0079, free energy: -15.5199 ±  0.1554
 1400 loss:   0.0048, free energy: -15.5220 ±  0.1560
 1500 loss:   0.0019, free energy: -15.5219 ±  0.1416
 1600 loss:   0.0040, free energy: -15.5203 ±  0.1449
 1700 loss:   0.0025, free energy: -15.5184 ±  0.1480
 1800 loss:   0.0050, free energy: -15.5211 ±  0.1387
 1900 loss:   0.0041, free e