# Holographic Pixel Graph Convolutional Network

In [1]:
%run "model.py"

## Motivation

Generative models have been applied to solve Statistical Mechanics problems by learning to model the Boltzmann distribution of microscopic configurations given the energy functional. This project aims to explore the combination of both **flow-based** and **autoregressive** models together with ideas from **renormalization group** and **holography** to tackle **critical** (scale-free) spin systems.

### Technical Goals

* Develop **hierarchical flow-based and autoregressive models** and combine them in the hyperbolic space (for hierarchical flows, see [NeuralRG](https://arxiv.org/pdf/1802.02840.pdf) and related works).
  * The hierarchical flow learns the holographic mapping (i.e. a wavelet transformation) that brings spin configurations to their wavelet encodings (like Haar wavelet encodings) in the holographic bulk.
  * The autoregressive model learns the base distribution of the wavelet encodings in the holographic bulk, instead of the distribution of spin configurations on the holographic boundary.

* Generalize [discrete flow](https://arxiv.org/pdf/1905.10347.pdf) and [integer flow](https://arxiv.org/pdf/1905.07376.pdf) to generic **non-Abelian discrete groups** (the previously proposed XOR and Mod-K transforms correspond to the $\mathbb{Z}_2$ and $\mathbb{Z}_K$ groups). Some future ideas [*not implemented yet; the current project uses a fixed transformation (e.g. the Haar wavelet)*]:
  * The forward computation lets two group elements perform a controlled transform, such as $(g_1,g_2)\to(g_1,g_1 g_2)$ (which can use left- or right-multiplication, inversion, conjugation, ...). Non-Abelian groups may provide sufficient scrambling, and Abelian groups can be embedded in non-Abelian groups.
  * Fuzzy group: work with distributions over group elements, which are continuous and can flow, similar in spirit to the Gumbel-softmax trick.
  
* Develop a more flexible autoregressive model based on message passing on **directed causal graphs** in the holographic space, using graph convolutional network (GCN) techniques. This allows us to apply the autoregressive model in the holographic bulk, where the causal relations depend on the flow-based model and can be quite involved. GCNs give us the flexibility to determine the neural network connectivity dynamically.

### Scientific Goals

* **Analyze the RG Flow (Down Sampling)**: autoregressive models are good at marginalizing causal dependants, which correspond to the UV degrees of freedom. This can be used to analyze the RG flow after training.
* **Analyze the Scaling Behavior (Up Sampling)**: if we impose parameter sharing across RG scales, the model could learn a scale-invariant transformation rule that generalizes to larger systems. This would allow us to perform finite-size scaling analysis by up sampling.
* **Probe Operator Scaling Dimensions**: use the wormhole idea in the holographic space by resampling UV latent variables. This provides a novel approach to obtain scaling dimensions without a notion of spacetime.
* **Speed up Monte Carlo(?)**: to be determined by measuring the dynamic exponent.
* **Application - $S_N$ Models**: $S_N$ spin models have $N!$ spin states (corresponding to the $S_N$ group elements) on each site, a number that grows quickly with $N$. Local updates will be inefficient in this case; will a hierarchical autoregressive model be more efficient? These models are important for understanding **entanglement transitions** in random quantum circuits.

## Model Design

### Architecture Overview

The model consists of the following parts:

<img src="./image/model.png" alt="model" width="360"/>

* A generative model $p(x)$ consisting of
  * A base model $p(z)$ realized as an **autoregressive model**, which uses graph convolutional network techniques to compute conditional distributions on a directed causal graph.
  * A stack of transformations containing
     * A **bijective encoding** (between one-hot and categorical representations)
     * A **renormalization group (RG) transformation** realized as a flow model (but currently fixed to be Haar wavelet transformation in this project).
* An **energy model** $E(x)$ must be provided as input to drive the training.
* All these modules are based on information provided by the infrastructure layer which contains:
  * A **group model** to provide basic functions of group operation and group function evaluation.
  * A **lattice model** to provide indexing of nodes and to construct the causal graph in the holographic bulk.
  
Finally, the model is trained to minimize the variational free energy.
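For reference, writing $p(x)$ for the model distribution and assuming the inverse temperature is absorbed into $E(x)$ (as the couplings used below suggest, e.g. $J\approx 0.4407$ at the Ising critical point), the variational free energy and its relation to the exact free energy $-\ln Z$ read

$$F[p]=\mathbb{E}_{x\sim p}\big[E(x)+\ln p(x)\big]=-\ln Z+D_{\mathrm{KL}}\Big(p(x)\,\Big\|\,\frac{e^{-E(x)}}{Z}\Big)\ \ge\ -\ln Z,\qquad Z=\sum_x e^{-E(x)}.$$

Since the KL divergence vanishes only when $p(x)=e^{-E(x)}/Z$, minimizing $F[p]$ drives the model toward the Boltzmann distribution.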

### Group

`Group` represents a group $G$ specified by its multiplication table. Group elements are labeled by integers ranging from 0 to $|G|-1$ (the order of the group minus one). The element 0 is always treated as the identity element. `Group` provides methods to:
* perform element-wise group *multiplication* and *inversion* for Torch tensors.
* perform element-wise evaluation of *group functions* $f:G\to \mathbb{R}$.

#### Generic Discrete Group

Create an $S_3$ group.

In [2]:
G = Group(torch.tensor([[0, 1, 2, 3, 4, 5],
                        [1, 0, 4, 5, 2, 3],
                        [2, 3, 0, 1, 5, 4],
                        [3, 2, 5, 4, 0, 1],
                        [4, 5, 1, 0, 3, 2],
                        [5, 4, 3, 2, 1, 0]]))

Multiply two tensors element-wise following the group multiplication rule.

In [3]:
a = torch.tensor([[0, 1, 2], [3, 4, 5]])
b = torch.tensor([[5, 4, 3], [2, 1, 0]])
G.mul(a, b)

tensor([[5, 2, 1],
        [5, 5, 5]])

Group product of the elements of a tensor along the given dimension (here `dim=0`, i.e. `a[0] * a[1]` column by column).

In [4]:
G.prod(a, dim=0)

tensor([3, 2, 4])

Group inversion of all elements.

In [5]:
G.inv(a)

tensor([[0, 1, 2],
        [4, 3, 5]])

Evaluate a group function given by a value table `val_table` (default function: group delta function).

In [6]:
G.val(a)

tensor([[1., 0., 0.],
        [0., 0., 0.]])

In [7]:
G.val(a, val_table=torch.tensor([1.,0.,0.,-0.5,-0.5,0.]))

tensor([[ 1.0000,  0.0000,  0.0000],
        [-0.5000, -0.5000,  0.0000]])

#### Symmetric Group

The symmetric group can be constructed more conveniently:

In [8]:
SymmetricGroup(3)

Group(6 elements)

In [9]:
SymmetricGroup(3).mul_table

tensor([[0, 1, 2, 3, 4, 5],
        [1, 0, 4, 5, 2, 3],
        [2, 3, 0, 1, 5, 4],
        [3, 2, 5, 4, 0, 1],
        [4, 5, 1, 0, 3, 2],
        [5, 4, 3, 2, 1, 0]])

### Lattice

`Lattice` represents a $d$-dimensional regular lattice of size $L$ (containing $L^d$ sites in total).

#### Site and Node Index Systems
Each physical site at coordinate $(x_0,\cdots,x_{d-1})$ with $x_k=0,\cdots,L-1$ is labeled by the integer index (C-format ordering), called the **site index**,
$$i=\sum_{k=0}^{d-1}L^{d-1-k}x_k.$$

The renormalization group transformation is organized on a binary tree, which coarse-grains along each direction cyclically. The information flow under the RG transformation follows this binary tree, which forms an [H-tree](https://en.wikipedia.org/wiki/H_tree) fractal in space. The leaves of the binary tree are the lattice sites, but they follow a different ordering when the binary tree is stored in a 1D heap list. The 1D heap list index of a leaf is called the **node index**.

<img src="./image/index_systems.png" alt="index_systems" width="380"/>

The **node index** rearranges a lattice of any dimension into a 1D array, such that the RG scheme is fixed by the canonical *binary heap structure*. In this way, the RG transformation can be applied universally to lattices of any *size* (as long as $L=2^n$) and any *dimension*.

The conversion between site and node indices is provided by the `Lattice` class. `Lattice.node_index` provides the mapping from node to site index (i.e. it lists the site index of each node in an array).

In [31]:
Lattice(4, 2).node_index

tensor([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
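Judging from the printed output, the node order looks like a bit-interleaving (Morton/Z-order-like) ordering: the lowest bit of the node index moves the last coordinate, the next bit the previous coordinate, and so on cyclically. The function below is a hypothetical reconstruction, not the code in `model.py`; it reproduces the mapping above for `Lattice(4, 2)`, but whether it agrees with the real implementation for other sizes and dimensions is an assumption.

```python
def node_index(L, d):
    """Hypothetical node -> site map by bit interleaving (assumes L = 2**n)."""
    n = L.bit_length() - 1
    assert 2 ** n == L, "L must be a power of 2"
    sites = []
    for a in range(L ** d):
        x = [0] * d
        for b in range(n * d):              # bits of the node index, low to high
            if (a >> b) & 1:
                axis = d - 1 - (b % d)      # cycle through the axes, last axis first
                x[axis] |= 1 << (b // d)    # bit position within that axis
        i = 0
        for k in range(d):                  # C-order site index sum_k L^(d-1-k) x_k
            i = i * L + x[k]
        sites.append(i)
    return sites

print(node_index(4, 2))
# [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
```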

#### Causal Graph

The autoregressive model uses conditional probabilities to model the joint probability. Each conditional probability encodes the causal influence that a random variable receives from the variables it is conditioned on. The causal relations form a *directed graph*, called the **causal graph**. Since we have mapped the RG transformation universally onto a binary tree, we only need to analyze the causal relations on the tree.

<img src="./image/causal_graph.png" alt="causal_graph" width="400"/>

The following causal relations are considered:
* 0: self
* 1: child
* 2: sibling
* 3: nephew
* 4: cousin
* 5: grandchild

`Lattice.causal_graph()` returns a $3\times E$ tensor listing the sources (first row), targets (second row), and types (third row) of all $E$ edges. This data is used to construct the graph convolutional layers.

In [25]:
Lattice(4, 2).causal_graph()

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  1,  1,  2,
          2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  2,  4,  6,  8, 10, 12, 14,
          3,  3,  2,  2,  5,  5,  4,  4,  7,  7,  6,  6,  4,  5,  4,  5,  8,  9,
          8,  9, 12, 13, 12, 13,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  2,  3,  4,
          5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  3,  5,  7,  9, 11, 13, 15,
          4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  6,  6,  7,  7, 10, 10,
         11, 11, 14, 14, 15, 15,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,
          3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5]])
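Reading off the printed edge list, each edge type follows a simple index rule on the heap, where node $a\ge 1$ has parent $\lfloor a/2\rfloor$ and children $2a, 2a+1$ (node 0 gets no edges). The sketch below is a hypothetical reconstruction that reproduces the 16-node output above; it is not the code in `model.py`, and whether the real `Lattice.causal_graph` uses the same rules on deeper trees is an assumption.

```python
import torch

def causal_graph(N):
    """Hypothetical (3, E) edge table (source, target, type) on a heap of N nodes.

    Types: 0 self, 1 child, 2 sibling, 3 nephew, 4 cousin, 5 grandchild.
    """
    edges = []
    edges += [(a, a, 0) for a in range(1, N)]              # self
    edges += [(a // 2, a, 1) for a in range(2, N)]         # parent -> child
    edges += [(a, a + 1, 2) for a in range(2, N, 2)]       # left sibling -> right sibling
    edges += [((a // 2) ^ 1, a, 3) for a in range(4, N)]   # parent's sibling -> nephew
    for a in range(4, N):
        p = a // 2
        if p % 2 == 1:            # parent p is a right child, so p ^ 1 is its left sibling
            u = p ^ 1
            edges += [(2 * u, a, 4), (2 * u + 1, a, 4)]    # children of u -> cousin a
    edges += [(a // 4, a, 5) for a in range(4, N)]         # grandparent -> grandchild
    return torch.tensor(edges).T

print(causal_graph(16).shape)  # torch.Size([3, 72])
```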

### Energy Model

We need an energy model to describe the statistical mechanics system. `EnergyModel` provides the function to evaluate the energy of a configuration.

#### Construction

Consider a 2D Ising model on a square lattice
$$H= -J \sum_{i}(\sigma_i\sigma_{i+\hat{x}} + \sigma_i\sigma_{i+\hat{y}}).$$
The Hamiltonian can be typed in as follows (see the following subsection for an explanation of the notation):

In [11]:
H = lambda J: -J*(TwoBody(torch.tensor([1.,-1.]), (1,0)) 
                  + TwoBody(torch.tensor([1.,-1.]), (0,1)))

The Hamiltonian at this point is just an abstract notation. It must be combined with the specific `Group` and `Lattice` to form a concrete energy model. The energy model itself is a torch module (without any trainable parameters), which can be used to evaluate the energy of any spin configuration.

In [19]:
energy = EnergyModel(H(0.5), SymmetricGroup(2), Lattice(4, 2))
energy

EnergyModel(
  (group): Group(2 elements)
  (lattice): Lattice(4x4 grid with tree depth 5)
  (energy): EnergyTerms(
    (0): TwoBody(G -> [-0.5, 0.5] across (1, 0))
    (1): TwoBody(G -> [-0.5, 0.5] across (0, 1))
  )
)

Let us generate some spin configurations.

In [22]:
x = torch.randint(2, (2, 4, 4))
x

tensor([[[1, 1, 1, 1],
         [1, 0, 0, 1],
         [0, 1, 0, 0],
         [1, 1, 1, 1]],

        [[1, 0, 0, 0],
         [0, 0, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]])

Evaluate the energy of these spin configurations with the energy model.

In [23]:
energy(x)

tensor([-4., -6.])
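As a cross-check, the energies above can be reproduced by evaluating the two-body terms directly. The sketch below is not the `EnergyModel` implementation: it assumes periodic boundary conditions (consistent with the printed values) and that a `TwoBody` term with value table `val_table` contributes $\sum_i E_2(g_i^{-1}g_j)$ with $j=i+\text{shift}$. The helper name `two_body_energy` is hypothetical.

```python
import torch

def two_body_energy(x, mul_table, val_table, shift):
    """Sketch of sum_i E_2(g_i^{-1} g_j), j = i + shift, on a periodic 2D lattice."""
    # inverse table: inv[g] is the h with g * h = identity (element 0)
    inv = torch.nonzero(mul_table == 0)[:, 1]
    # neighbour g_j of every site (periodic boundary conditions assumed)
    xj = torch.roll(x, shifts=tuple(-s for s in shift), dims=(-2, -1))
    g = mul_table[inv[x], xj]                  # g_i^{-1} g_j, element-wise
    return val_table[g].sum(dim=(-2, -1))      # one energy per configuration

mul_z2 = torch.tensor([[0, 1], [1, 0]])        # S_2 (= Z_2) multiplication table
sigma = torch.tensor([1., -1.])                # identity -> +1, flip -> -1
x = torch.tensor([[[1, 1, 1, 1], [1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 1, 1]],
                  [[1, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]])

def ising(x, J):
    return -J * (two_body_energy(x, mul_z2, sigma, (1, 0))
                 + two_body_energy(x, mul_z2, sigma, (0, 1)))

print(ising(x, 0.5))  # tensor([-4., -6.])
```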

We can use `EnergyModel.update(H)` to assign a new Hamiltonian to the energy model (without changing the group and the lattice). This will be useful for annealing the system during training.

In [24]:
energy.update(H(1.))
energy(x)

tensor([ -8., -12.])

#### Hamiltonian Scripting System

To make formulating Hamiltonians intuitive, we introduce a scripting system. Physical Hamiltonians are sums of local energy terms. Each energy term is a subclass of `nn.Module`, and each Hamiltonian is a subclass of `nn.ModuleList` (which contains the collection of energy terms). In this way, the evaluation of the total energy can be distributed over the energy terms in parallel.

We introduce two kinds of energy terms
* `OnSite`: on-site energy term $E_1(g_i)$,
* `TwoBody`: two-body interaction term $E_2(g_i,g_j)$.

More types of interaction terms can be introduced under this framework if necessary. These energy terms are group functions: $E_1:G\to\mathbb{R}$, $E_2:G\times G\to\mathbb{R}$. These group functions can be specified by value tables, which enumerate the value that each group element maps to. For example, for the $\mathbb{Z}_2=\{0,1\}$ group ($0$-identity, $1$-flip), if we want to specify
$$E_1(0)=+1, E_1(1)=-1,$$
the value table is $[+1,-1]$. Such a term can be created as follows:

In [25]:
OnSite(torch.tensor([1.,-1.]))

OnSite(G -> [1.0, -1.0])

We assume that the two-body term always takes the form
$$E_2(g_i,g_j)=E_2(g_i^{-1}g_j),$$
so that it only requires a single-variable group function, using the same value-table representation. For example,

In [26]:
TwoBody(torch.tensor([1.,-1.]), (1,0))

TwoBody(G -> [1.0, -1.0] across (1, 0))

The two-body term also carries a second argument to specify the relative direction from site-$i$ to site-$j$. If the value table is not specified, the default group function will be used:
* For generic `Group`, the default group function is the delta function (like Potts model), which maps the identity element to 1 and the others to 0.
* For `SymmetricGroup`, the default group function is the cycle counting function (count the number of permutation cycles).

We can add, subtract, scalar-multiply, and negate energy terms. Energy terms added together are represented as a collection of terms in a list (`nn.ModuleList`), which corresponds to a Hamiltonian. For example,

In [27]:
-2.8 * OnSite() + 5.2 * (TwoBody(shifts=(1,0)) + TwoBody(shifts=(0,1)))

EnergyTerms(
  (0): TwoBody(5.2 across (1, 0))
  (1): TwoBody(5.2 across (0, 1))
  (2): OnSite(-2.8)
)

### Haar Transformation

`HaarTransform` is a bijective map between the spin configuration and the Haar wavelet encoding. It realizes a fixed version of the **invertible RG transform** (**holographic mapping**). In the future, it could be replaced by a trainable **discrete flow** model.

#### Decoding Map (Generation Flow)

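The decoding map takes the wavelet components $z$ to the spin configuration $x$ following
$$x_i = \prod_a z_a^{w_{ai}},$$
where the product is ordered with $a$ increasing from left to right (which matters for non-Abelian groups), and $w_{ai}=0,1$ is the Haar wavelet matrix (row $a$ labels the wavelet component, column $i$ the site), which is given by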

In [28]:
HaarTransform(SymmetricGroup(2), Lattice(4, 2)).wavelet

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int32)

The decoding (generation) process can be viewed as starting from the trivial configuration $x_i=0$ (the identity) for all $i$ and walking through the latent variables $z_a$ in order of decreasing $a$ (from the UV to the IR): if $w_{ai}=0$, do nothing; if $w_{ai}=1$, left-multiply $x_i$ by $z_a$ according to the group multiplication (note: $i$ in $x_i$ is the site index).

#### Encoding Map (Renormalization Flow)

The encoding map takes the spin configuration $x$ to the wavelet components $z$ following the renormalization procedure. The first step is to rearrange the spins $x_i$ in **node index** order and define
$$y_a^{(0)} = x_{i(a)},$$
where the index map $i(a)$ is given by `Lattice.node_index`. Then iterate the RG transformation as follows (let $N$ be the total number of nodes in the binary heap tree):
* for $k=1,2,\cdots,\log_2 N$
  * for $q=0,1,\cdots,2^{-k}N-1$
$$\begin{split}
\text{IR: }y_q^{(k)}&\leftarrow y_{2q}^{(k-1)},\\
\text{UV: }y_{2^{-k}N+q}^{(k)}&\leftarrow \big(y_{2q}^{(k-1)}\big)^{-1}y_{2q+1}^{(k-1)}
\end{split}$$

The UV components produced at earlier steps ($a\ge 2^{1-k}N$) are carried over unchanged, $y_a^{(k)}=y_a^{(k-1)}$. The final outcome is $z=y^{(\log_2 N)}$.

#### Example

Consider the $S_3$ group and generate a random wavelet configuration:

In [29]:
G = SymmetricGroup(3)
z = torch.randint(G.order, (16,))
z

tensor([3, 0, 3, 2, 0, 4, 0, 3, 1, 2, 2, 0, 4, 4, 3, 4])

Transform the 16 wavelet components into a spin configuration on a 4 × 4 lattice.

In [30]:
ht = HaarTransform(G, Lattice(4, 2))
ht(z)

tensor([[3, 2, 4, 1],
        [3, 5, 3, 3],
        [3, 0, 5, 2],
        [3, 0, 2, 5]])

Transform back and verify that the encoder and decoder are inverse to each other.

In [31]:
ht.inv(ht(z))

tensor([3, 0, 3, 2, 0, 4, 0, 3, 1, 2, 2, 0, 4, 4, 3, 4])
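Both maps can be sketched in a few lines of PyTorch working directly on the node-ordered array and a multiplication table. This is a hypothetical re-implementation, not `HaarTransform` itself: it assumes `mul_table[g, h]` is the product $gh$ used in the formulas above and hard-codes the `node_index` printed earlier for `Lattice(4, 2)`. Under these assumptions it reproduces the $S_3$ example above.

```python
import torch

def haar_encode(y, mul, inv):
    """RG flow: group elements in node order, shape (..., N) -> wavelet components z."""
    y = y.clone()
    n = y.shape[-1] // 2                         # number of pairs at the current level
    while n >= 1:
        left = y[..., 0:2 * n:2].clone()         # y_{2q}
        right = y[..., 1:2 * n:2].clone()        # y_{2q+1}
        y[..., 0:n] = left                       # IR: y_q <- y_{2q}
        y[..., n:2 * n] = mul[inv[left], right]  # UV: y_{n+q} <- y_{2q}^{-1} y_{2q+1}
        n //= 2
    return y

def haar_decode(z, mul):
    """Generation flow: undo haar_encode level by level, starting from the IR."""
    y = z.clone()
    n = 1
    while n < y.shape[-1]:
        ir = y[..., 0:n].clone()
        uv = y[..., n:2 * n].clone()
        y[..., 0:2 * n:2] = ir                   # y_{2q} <- y_q
        y[..., 1:2 * n:2] = mul[ir, uv]          # y_{2q+1} <- y_q * y_{n+q}
        n *= 2
    return y

# S_3 multiplication table from the Group example and node -> site map of Lattice(4, 2)
mul = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 0, 4, 5, 2, 3], [2, 3, 0, 1, 5, 4],
                    [3, 2, 5, 4, 0, 1], [4, 5, 1, 0, 3, 2], [5, 4, 3, 2, 1, 0]])
inv = torch.nonzero(mul == 0)[:, 1]              # inverse of each element
node_index = torch.tensor([0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15])

z = torch.tensor([3, 0, 3, 2, 0, 4, 0, 3, 1, 2, 2, 0, 4, 4, 3, 4])
y = haar_decode(z, mul)                          # spins in node order
x = torch.empty_like(y)
x[node_index] = y                                # node order -> site order
print(x.reshape(4, 4))                           # same as ht(z) above
print(haar_encode(x[node_index], mul, inv))      # recovers z
```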

### One-Hot Categorical Transformation

`OneHotCategoricalTransform` is a **bijective embedding** that converts between group elements and their one-hot embeddings. It serves as the interface between the RG transformation (which works with group elements for efficiency) and the autoregressive model (which works with one-hot embeddings for training performance).

#### Example

In [32]:
z_cat = torch.randint(6, (2, 3))
z_cat

tensor([[4, 1, 1],
        [0, 5, 3]])

In [33]:
oc = OneHotCategoricalTransform(6)
z_emb = oc.inv(z_cat)
z_emb

tensor([[[0., 0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 1., 0., 0.]]])

In [34]:
oc(z_emb)

tensor([[4, 1, 1],
        [0, 5, 3]])
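Both directions of this bijection have standard PyTorch counterparts. The sketch below assumes that `OneHotCategoricalTransform` behaves like `torch.nn.functional.one_hot` followed by a cast to float (the `inv` direction) and `argmax` over the last axis (the forward direction); this matches the printed example but is not taken from `model.py`.

```python
import torch
import torch.nn.functional as F

z_cat = torch.tensor([[4, 1, 1], [0, 5, 3]])
z_emb = F.one_hot(z_cat, num_classes=6).float()  # categorical -> one-hot, shape (2, 3, 6)
z_back = z_emb.argmax(dim=-1)                    # one-hot -> categorical
print(torch.equal(z_back, z_cat))                # True
```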

### Graph Convolution Layer

`GraphConv` provides a graph convolution layer for a given graph structure. It creates separate linear layers for different types of edges and performs the affine transformation along the edge direction

$$y_a=\sum_{b\to a}\big(W_{t(b\to a)}x_b+ B_{t(b\to a)}\big),$$

where $t(b\to a)$ denotes the type of the edge $b\to a$ and the summation goes through all edges in the directed graph. $W$ and $B$ are trainable weights and biases that depend on the edge type. `GraphConv` is implemented using `torch_geometric.nn.MessagePassing`.
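The semantics of this layer can be illustrated without `torch_geometric`. The class below is a hypothetical pure-PyTorch sketch (the name `TypedGraphConv` and its internals are illustrative, not the `GraphConv` in `model.py`): it keeps one `nn.Linear` per edge type, sends $W_t x_b + B_t$ along every edge $b\to a$, and sums the messages at the target with `index_add_`. It also mimics the optional source-node masking described further down; treating edge type 0 as the self-loop follows the causal-graph convention above.

```python
import torch
import torch.nn as nn

class TypedGraphConv(nn.Module):
    """Hypothetical sketch of y_a = sum_{b->a} (W_t x_b + B_t), one affine map per edge type t."""

    def __init__(self, graph, in_features, out_features, bias=True, self_loop=True):
        super().__init__()
        if not self_loop:                         # assumption: type 0 is the self-loop
            graph = graph[:, graph[2] != 0]
        self.register_buffer("graph", graph)      # rows: source, target, type
        n_types = int(graph[2].max()) + 1
        self.linears = nn.ModuleList(
            nn.Linear(in_features, out_features, bias=bias) for _ in range(n_types))
        self.out_features = out_features

    def forward(self, x, source=None):
        src, dst, typ = self.graph
        if source is not None:                    # keep only edges going out of `source`
            keep = src == source
            src, dst, typ = src[keep], dst[keep], typ[keep]
        y = x.new_zeros(x.shape[0], self.out_features)
        for t, linear in enumerate(self.linears):
            sel = typ == t
            # message W_t x_b + B_t along every edge b -> a of type t, summed into a
            y.index_add_(0, dst[sel], linear(x[src[sel]]))
        return y
```

Because the bias is attached to each edge, a node with no incoming edges (such as node 0 in the causal graph) receives exactly zero, consistent with the first row of the `GraphConv` output printed in the examples.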

#### Generic Non-Linear Causal Maps

We can create generic **non-linear causal maps** (maps that respect causality) by stacking multiple graph convolution layers (on the causal graph) and non-linear activation layers

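$$\begin{split}
z_a^{(0)} &= z_a,\\
z_a^{(k+1)} &= \phi\Big(\sum_{b\to a} \big(W^{(k)}_{t(b\to a)}z_b^{(k)}+B^{(k)}_{t(b\to a)}\big)\Big),
\end{split}$$
where $\phi$ is the non-linear activation and each layer $k$ has its own weights $W^{(k)}$ and biases $B^{(k)}$.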

After enough layers (in the idealized limit of infinitely many), causal influences propagate along all directed paths of the causal graph, so that the output $z^{(\infty)}$ is a causal map

$$z_a^{(\infty)} = f(\{z_b\}_{b\to\cdots\to a}).$$

Suppose $z_a$ is the one-hot encoding of the input configuration, if we treat $z_a^{(\infty)}$ as a score function, we can create a model for the conditional distribution

$$\ln p(z_a|\{z_b\}_{b\to\cdots\to a}) = z_a\cdot\ln\text{softmax}[z_a^{(\infty)}(\{z_b\}_{b\to\cdots\to a})],$$

which can be combined to establish an autoregressive model

$$p(z)=\prod_{a}p(z_a|\{z_b\}_{b\to\cdots\to a}).$$

<img src="./image/causal_graph.png" alt="causal_graph" width="400"/>

*Comments*:
* Note that the first layer must not have self-connections (otherwise $z_a^{(\infty)}$ would depend on $z_a$ itself, and the conditional distribution of $z_a$ would see the very variable it is supposed to predict), but all subsequent layers are allowed to have self-connections.
* Node 0 is special: it is always sampled independently from the uniform distribution (one can see that it corresponds to the global symmetry of the spin model). It also has no causal relations with other nodes (a consequence of the Goldstone theorem: the order parameter should have zero excitation energy, thus it cannot interact with other modes and hence cannot establish any causal relation).
* Edges of different colors indicate different types of causal relations. 
* A normalization layer is added after each hidden graph convolution to speed up training (the model printed in the Holographic Pixel GCN section uses `LayerNorm`).
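To make the two formulas concrete, here is a minimal sketch built around a hypothetical `score_fn` that plays the role of the stacked causal graph convolutions, mapping one-hot configurations of shape `(batch, nodes, K)` to scores $z^{(\infty)}$ of the same shape. Evaluating $\ln p(z)$ takes a single pass, whereas sampling visits the nodes one at a time in index order. That order is a valid ancestral order here because every edge in the printed causal graph points from a lower to a higher (or the same) node index. The helpers `log_prob` and `ancestral_sample` are illustrative names, not the `AutoregressiveModel` API.

```python
import math
import torch
import torch.nn.functional as F

def log_prob(z, scores):
    """ln p(z) = sum_a z_a . ln softmax(score_a), for one-hot z of shape (batch, nodes, K)."""
    return (z * torch.log_softmax(scores, dim=-1)).sum(dim=(-2, -1))

def ancestral_sample(score_fn, batch, nodes, K):
    """Draw z node by node; score_fn is a hypothetical causal map (batch, nodes, K) -> same shape."""
    z = torch.zeros(batch, nodes, K)
    for a in range(nodes):                  # node index order = ancestral order
        logits = score_fn(z)[:, a]          # must only depend on the ancestors of node a
        idx = torch.distributions.Categorical(logits=logits).sample()
        z[:, a] = F.one_hot(idx, K).float()
    return z

# toy check: constant scores make every conditional uniform over K states
uniform = lambda z: torch.zeros_like(z)
z = ancestral_sample(uniform, batch=2, nodes=16, K=6)
print(log_prob(z, uniform(z)))              # both entries equal -16 ln 6
print(-16 * math.log(6))
```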

#### Implementation

`GraphConv(causal_graph, in_features, out_features, bias = True, self_loop = True)`

Arguments:
* `causal_graph`: a $3\times E$ integer tensor stacking the source indices, target indices, and edge types of all edges. It can be generated by `Lattice.causal_graph()`.
* `in_features`: number of input features on each node
* `out_features`: number of output features on each node
* `bias`: whether to learn the bias (edge type dependent)
* `self_loop`: whether to allow self connection

#### Examples

Create a graph convolution layer

In [34]:
GraphConv(Lattice(4, 2).causal_graph(), 2, 3)

GraphConv(edge_types=6, in_features=2, out_features=3, bias=True, self_loop=True)

It provides the `forward` method to evaluate the forward pass, which is realized by calling `MessagePassing.propagate`.

In [35]:
gc = GraphConv(Lattice(4, 2).causal_graph(), 2, 3)
x = torch.randn(16, 2)
gc(x)

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 1.2622,  0.2945, -0.9587],
        [ 0.1706, -0.8889,  0.3528],
        [ 0.6945, -0.1077,  0.3019],
        [-0.3917, -0.3284, -0.3799],
        [-0.6608, -0.7630,  0.3102],
        [ 1.0474, -0.5188, -0.9339],
        [ 0.9330, -0.0557, -0.9502],
        [ 0.4415, -1.2299,  0.2887],
        [ 0.3944, -0.8794,  0.1608],
        [ 0.9661,  0.8026, -0.8039],
        [ 1.3950,  1.4764, -0.9321],
        [ 0.2601, -0.0661, -0.1173],
        [-0.1682,  0.0967, -0.2770],
        [-0.6308,  1.1954,  0.3197],
        [-0.5837,  1.4543,  0.7457]], grad_fn=<ScatterAddBackward>)

For autoregressive sampling, the `forward` method can also propagate *from* a specific node, given as the second argument. This is realized by masking out all edges that do not go out from that node.

In [36]:
gc(x, 3)

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.5347, -0.1380, -0.3850],
        [ 0.0544, -0.4091, -0.5442],
        [ 0.0544, -0.4091, -0.5442],
        [ 0.3125, -0.0503, -0.3516],
        [ 0.3125, -0.0503, -0.3516],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [-0.4237,  0.4151,  0.4241],
        [-0.4237,  0.4151,  0.4241],
        [-0.4237,  0.4151,  0.4241],
        [-0.4237,  0.4151,  0.4241]], grad_fn=<ScatterAddBackward>)

#### Update Causal Graph

One can update the causal graph with the `update_graph` method, passing the extended edge tensor.

In [39]:
new_graph = torch.cat([gc.graph, torch.tensor([
    [ 8, 8, 8, 8, 9, 9, 9, 9,10,10,10,10,11,11,11,11],
    [12,13,14,15,12,13,14,15,12,13,14,15,12,13,14,15],
    [ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6]])], -1)
gc.update_graph(new_graph)

GraphConv(edge_types=7, in_features=2, out_features=3, bias=True, self_loop=True)

New linear maps are added as needed, while the existing linear maps are left unchanged (so that their parameters are preserved).

### Autoregressive Model

Autoregressive models are often named "pixel-..." after their element-wise (pixel-by-pixel) sampling, where "..." names the neural network architecture used to model the conditional distributions. Since we model the conditional distributions with a graph convolutional network (GCN), it seems fair to call the resulting autoregressive model a pixel-GCN.

#### Advantages [*Yet to be demonstrated*]

The major innovation is to put the pixel-GCN in the holographic bulk and use it to model the distribution of the Haar wavelet encodings. What could be the advantage of this approach?

* **Resolve the criticality**: the holographic mapping brings a scale-free system to a local system (with an emergent scale set by the hyperbolic radius and the critical exponent). This can be seen from the correlation function of two spins at distance $r$ on the holographic boundary
$$C(r)\sim r^{-\alpha} \sim e^{-d/\xi},$$
where $d=R\ln r$ is the geodesic distance through the bulk and $\xi=R/\alpha$ is an emergent length scale (indeed $e^{-d/\xi}=e^{-\alpha\ln r}=r^{-\alpha}$). The complexity of modeling correlations at all scales is reduced to modeling correlations locally in the bulk. This argument justifies our assumption that only a limited number of local causal relations need to be considered.

* **Shorten the causal chain**: conventional approaches like pixel-CNN have unnatural causal structures (why should a single pixel causally depend on its upper half-plane?). The natural way to generate an image is to paint the outline first and then add the details. In this way, the scale itself becomes the emergent time of the generation process, which imposes a natural causal structure in the holographic bulk. A remarkable feature is that the holographic bulk has a hyperbolic (tree-like) geometry, such that **time is short**: the causal chain has length at most $\sim\log L$ (logarithmic in system size), and the causal cone has limited width (like the past light cone in an expanding universe: looking backwards in time, light cannot keep up with the contraction of the universe). This greatly reduces the model complexity for large systems and enables more efficient sampling and evaluation.

#### Implementation

`AutoregressiveModel` is both a torch Module and a torch Distribution, which can be used for both forward evaluation and sampling. It is constructed by

`AutoregressiveModel(lattice, features, nonlinearity = 'Tanh', bias = True)`

Arguments:
* `lattice`: a `Lattice` instance, which provides information about the lattice and the method to construct the causal graph.
* `features`: a list of integers specifying the number of features from the first to the last layer of the GCN.
* `nonlinearity`: nonlinear activation to use.
* `bias`: whether to learn the bias.

#### Sampling Methods

Create an autoregressive model and generate some samples with the `sample` method. This method generates samples in a `torch.no_grad()` context, so that memory consumption is reduced.

In [17]:
ar = AutoregressiveModel(Lattice(2, 2), [2, 4, 2])
x = ar.sample(1)
x

tensor([[[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]]])

The log probability of the samples can then be evaluated by the `log_prob` method (in a single forward pass).

In [21]:
ar.log_prob(x)

tensor([-3.2476], grad_fn=<SumBackward1>)

The `sample_with_log_prob` method returns the samples and their log probabilities simultaneously by generating samples with gradient tracking enabled. In this case, the log probability is obtained from the node-wise forward passes.

In [19]:
ar.sample_with_log_prob(1)

(tensor([[[1., 0.],
          [0., 1.],
          [0., 1.],
          [0., 1.]]]), tensor([-2.2215], grad_fn=<SumBackward1>))

Reparametrized sampling is supported by the `rsample` method using the Gumbel-softmax trick. More specifically, `rsample(sample_size, tau, hard = False)` returns relaxed one-hot samples, where `tau` sets the softmax temperature and `hard` determines whether the samples are discretized as one-hot vectors (while still being differentiated as if they were soft). [*Nevertheless, this functionality will not be used in this project.*]

In [11]:
ar.rsample(1, tau=0.1)

tensor([[[4.1223e-12, 1.0000e+00],
         [9.8222e-01, 1.7780e-02],
         [2.0688e-12, 1.0000e+00],
         [9.9397e-01, 6.0309e-03]]], grad_fn=<CopySlices>)
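PyTorch ships the same relaxation as `torch.nn.functional.gumbel_softmax`. The snippet below only illustrates the `tau` and `hard` semantics on random logits; whether `AutoregressiveModel.rsample` calls this function internally is an assumption.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 2, requires_grad=True)       # scores for 4 nodes, 2 states
soft = F.gumbel_softmax(logits, tau=0.1)             # relaxed samples, rows sum to 1
hard = F.gumbel_softmax(logits, tau=0.1, hard=True)  # one-hot forward, soft gradients backward
hard[:, 0].sum().backward()                          # gradient flows despite discretization
print(soft.sum(-1), hard, logits.grad is not None)
```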

#### Cache System

Behind the sampling procedure is a cache system that works with recursive node-wise forward passes. The cache tensors can be accessed through the internal method `_sample` (for debugging). Check that the node-wise forward passes are equivalent to the one-pass forward (within round-off error):

In [27]:
ar = AutoregressiveModel(Lattice(4, 2), [2, 4, 2])
with torch.no_grad():
    cache = ar._sample(1)
ar(cache[0]) - cache[-1]

tensor([[[ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 2.9802e-08, -2.9802e-08],
         [ 0.0000e+00,  1.4901e-08],
         [-1.1921e-07,  5.9605e-08],
         [ 1.1921e-07,  0.0000e+00],
         [ 2.9802e-08, -5.9605e-08],
         [ 5.9605e-08, -1.1921e-07],
         [ 1.1921e-07,  0.0000e+00],
         [ 0.0000e+00,  1.1921e-07],
         [-5.9605e-08,  5.9605e-08],
         [ 0.0000e+00, -8.9407e-08],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00]]], grad_fn=<SubBackward0>)

#### Causal Connections

Check the causal dependence by evaluating the Jacobian matrix and collecting the indices of its nonzero entries as (target, source) pairs (the determinant of each 2 × 2 feature block serves as the dependence indicator).

In [28]:
x = cache[0].squeeze().clone()
j = torch.autograd.functional.jacobian(ar, x).permute(0,2,1,3).det()
torch.nonzero(j, as_tuple=True)

(tensor([ 2,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,  7,  7,  7,
          7,  7,  7,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10,
         10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12,
         13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15,
         15, 15, 15, 15, 15, 15, 15, 15, 15]),
 tensor([ 1,  1,  2,  1,  2,  3,  1,  2,  3,  4,  1,  2,  3,  4,  5,  1,  2,  3,
          4,  5,  6,  1,  2,  3,  4,  5,  1,  2,  3,  4,  5,  8,  1,  2,  3,  4,
          5,  8,  9,  1,  2,  3,  4,  5,  8,  9, 10,  1,  2,  3,  4,  5,  6,  7,
          1,  2,  3,  4,  5,  6,  7, 12,  1,  2,  3,  4,  5,  6,  7, 12, 13,  1,
          2,  3,  4,  5,  6,  7, 12, 13, 14]))

Compared with the original (fully connected) autoregressive model, some connections outside the causal cone are removed. For example, the groups {8,9,10,11} and {12,13,14,15} have no direct mutual connections (their correlations are mediated by other nodes). More precisely, they can still have nonzero mutual information, but the model assigns them zero *conditional* mutual information given nodes 1 to 7. This weakens the model, but it may be the price to pay for efficiency.

| model     | connections     | parameters    |
|-----------|-----------------|---------------|
| original  | $\sim N^2$      | $\sim N^2$    |
| pixel-GCN | $\sim N \log N$ | $\sim \log N$ |
| pixel-CNN | $\sim N$        | $\sim 1$      |
| pixel-RNN | $\sim N$        | $\sim 1$      |

Their performance should be further compared in the future.

#### Update Causal Graph

`AutoregressiveModel` also has the `update_graph` method, which updates the causal graph of every `GraphConv` layer. We can add the second cousins and second nephews to the causal graph, which exhausts all possible causal relations among the 15 nodes (node 0 aside).

In [29]:
new_graph = torch.cat([ar.graph,
    torch.tensor([
        [ 8, 8, 8, 8, 9, 9, 9, 9,10,10,10,10,11,11,11,11, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7],
        [12,13,14,15,12,13,14,15,12,13,14,15,12,13,14,15,12,13,14,15,12,13,14,15, 8, 9,10,11, 8, 9,10,11],
        [ 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
    ])], -1)
ar.update_graph(new_graph)
x = cache[0].squeeze().clone()
j = torch.autograd.functional.jacobian(ar, x).permute(0,2,1,3).det()
torch.nonzero(j, as_tuple=True)

(tensor([ 2,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  6,  7,  7,  7,
          7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,
         10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11,
         11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13,
         13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
         14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15]),
 tensor([ 1,  1,  2,  1,  2,  3,  1,  2,  3,  4,  1,  2,  3,  4,  5,  1,  2,  3,
          4,  5,  6,  1,  2,  3,  4,  5,  6,  7,  1,  2,  3,  4,  5,  6,  7,  8,
          1,  2,  3,  4,  5,  6,  7,  8,  9,  1,  2,  3,  4,  5,  6,  7,  8,  9,
         10,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,  1,  2,  3,  4,  5,  6,
          7,  8,  9, 10, 11, 12,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12,
         13,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]))
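The extra edges in In [29] are written out by hand as rows of (source, target, edge type). Each block is simply every pairing between two groups of nodes, so it can also be generated with `torch.cartesian_prod`. A minimal sketch; `bipartite_edges` is a hypothetical helper, and edge types 6 and 7 are taken from the cell above:

```python
import torch

def bipartite_edges(sources, targets, edge_type):
    """All (source, target) pairs between two node groups, tagged with one edge type.

    Returns a (3, len(sources) * len(targets)) long tensor with rows
    (source, target, edge_type), in the same layout as the causal graph above.
    """
    pairs = torch.cartesian_prod(sources, targets).t()   # (2, n_src * n_tgt), sources vary slowest
    types = torch.full((1, pairs.size(1)), edge_type, dtype=torch.long)
    return torch.cat([pairs, types], dim=0)

extra_edges = torch.cat([
    bipartite_edges(torch.arange(8, 12), torch.arange(12, 16), 6),  # 2nd cousins
    bipartite_edges(torch.arange(4, 6), torch.arange(12, 16), 7),   # 2nd nephews of 4, 5
    bipartite_edges(torch.arange(6, 8), torch.arange(8, 12), 7),    # 2nd nephews of 6, 7
], dim=-1)
```

Generating the blocks this way avoids miscounting indices when the lattice grows.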

### Holographic Pixel GCN

We put all the components together into the interface module.

`HolographicPixelGCN(energy, hidden_features, nonlinearity = 'Tanh', bias = True)`

Arguments:
* `energy`: an `EnergyModel` class which contains Hamiltonian, group and lattice information.
* `hidden_features`: a list with the number of features of each hidden layer (the input and output features are determined by the order of the group and need not be specified here).
* `nonlinearity`: nonlinear activation layer to use.
* `bias`: whether to learn the bias.

Create a holographic pixel-GCN model for the 4×4 Ising model at its critical coupling and print its components.

In [30]:
H = lambda J: -J*(TwoBody(torch.tensor([1.,-1.]), (1,0)) 
                  + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model = HolographicPixelGCN(
            EnergyModel(
                H(0.440686793), # Ising critical point
                SymmetricGroup(2), 
                Lattice(4, 2)), 
            hidden_features = [4, 4])
model

HolographicPixelGCN(
  (energy): EnergyModel(
    (group): Group(2 elements)
    (lattice): Lattice(4x4 grid with tree depth 5)
    (energy): EnergyTerms(
      (0): TwoBody(G -> [-0.44068679213523865, 0.44068679213523865] across (1, 0))
      (1): TwoBody(G -> [-0.44068679213523865, 0.44068679213523865] across (0, 1))
    )
  )
  (base_dist): AutoregressiveModel(
    (layers): ModuleList(
      (0): GraphConv(edge_types=6, in_features=2, out_features=4, bias=True, self_loop=False)
      (1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
      (2): Tanh()
      (3): GraphConv(edge_types=6, in_features=4, out_features=4, bias=True, self_loop=True)
      (4): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
      (5): Tanh()
      (6): GraphConv(edge_types=6, in_features=4, out_features=2, bias=True, self_loop=True)
    )
  )
)

Draw samples from the model.

In [31]:
x = model.sample(2)
x

tensor([[[0, 0, 0, 0],
         [1, 0, 0, 0],
         [0, 1, 1, 1],
         [0, 1, 0, 1]],

        [[0, 1, 1, 0],
         [0, 0, 0, 1],
         [0, 1, 1, 1],
         [1, 1, 0, 0]]])

Evaluate log probabilities of the samples.

In [32]:
model.log_prob(x)

tensor([-13.0147, -11.5663], grad_fn=<AddBackward0>)

Evaluate energies of the samples.

In [33]:
model.energy(x)

tensor([5.9605e-08, 1.7627e+00])

Transform the samples to their Haar wavelet encodings.

In [34]:
model.haar.inv(x)

tensor([[0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1],
        [0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]])
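The implementation of `model.haar` is not shown in this section. To illustrate the kind of invertible map involved, here is a hypothetical one-dimensional $\mathbb{Z}_2$ analogue built from the controlled transform $(g_1,g_2)\to(g_1,g_1g_2)$ mentioned in the goals (for $\mathbb{Z}_2$, group multiplication is XOR): each level keeps one spin of every pair as the coarse variable, stores the XOR of the pair as the detail, and recurses on the coarse variables. The 2D lattice ordering used by `model.haar` is not reproduced; `z2_haar` and `z2_haar_inv` are illustrative names.

```python
import torch

def z2_haar(x):
    """Hypothetical 1D Z2 Haar-like transform of a length-2^k bit vector (last dim).

    Each level maps a pair (a, b) -> coarse a, detail a XOR b, then recurses on
    the coarse part. Output layout: [top coarse bit, coarsest details, ..., finest details].
    """
    details = []
    coarse = x
    while coarse.size(-1) > 1:
        a, b = coarse[..., 0::2], coarse[..., 1::2]
        details.insert(0, a ^ b)        # detail bits at this scale
        coarse = a                      # left element of each pair is the coarse variable
    return torch.cat([coarse] + details, dim=-1)

def z2_haar_inv(z):
    """Inverse of z2_haar: rebuild the bits from coarse to fine scales."""
    n = z.size(-1)
    coarse = z[..., :1]
    pos = 1
    while pos < n:
        d = z[..., pos:2 * pos]         # details for the next finer scale
        b = coarse ^ d                  # recover b from a and a XOR b
        coarse = torch.stack([coarse, b], dim=-1).flatten(-2)  # interleave a, b
        pos *= 2
    return coarse
```

For a uniform configuration all detail bits vanish and only the top coarse bit survives, which is the sense in which the bulk encoding separates scales.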

## Model Training

### Loss Function

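**Reverse KL with log-derivative trick.** The goal is to minimize the difference between the model distribution $q_\theta(x)$ and the target distribution $p(x) = e^{-E(x)}/Z$, with $Z=\sum_x e^{-E(x)}$, by minimizing the reverse KL divergence (see [Wu, Wang, Zhang 2019](https://arxiv.org/pdf/1809.10606.pdf) for more details). Since $\ln Z$ does not depend on $\theta$, we take as the loss the KL divergence shifted by $\ln Z$, i.e. the variational free energy:

$$\begin{split}\mathcal{L}&=\mathsf{KL}(q_\theta\|p)-\ln Z\\
&=\sum_{x} q_\theta(x) \ln \frac{q_\theta(x)}{p(x)}-\ln Z\\
&=\sum_{x}q_\theta(x)(E(x)+\ln q_\theta(x)).
\end{split}$$

Because $\mathsf{KL}(q_\theta\|p)\geq 0$, we have $\mathcal{L}\geq -\ln Z$: the free energy estimated during training is an upper bound on the exact free energy $-\ln Z$.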

All the parameter dependence is in the model distribution $q_\theta$. The gradient of the loss function with respect to the parameters is given by

$$\begin{split}\partial_\theta\mathcal{L}&= \partial_\theta \sum_{x}q_\theta(x)(E(x)+\ln q_\theta(x))\\
&= \sum_{x}[(\partial_\theta q_\theta(x))(E(x)+\ln q_\theta(x))+q_\theta(x)\partial_\theta \ln q_\theta(x)]\\
\end{split}$$

The last term can be dropped because

$$\sum_x q_\theta(x)\partial_\theta \ln q_\theta(x) = \sum_x \partial_\theta q_\theta(x)=\partial_\theta\sum_x q_\theta(x)=\partial_\theta 1 = 0.$$

The remaining term reads

$$\begin{split}\partial_\theta\mathcal{L}&= \sum_{x}(\partial_\theta q_\theta(x))(E(x)+\ln q_\theta(x))\\
&= \sum_{x}(\partial_\theta q_\theta(x))R(x)\\
&= \mathbb{E}_{x\sim q_\theta}\left[R(x)\,\partial_\theta \ln q_\theta(x)\right].
\end{split}$$

Here $R(x)=E(x)+\ln q_\theta(x)$ plays the role of a reward signal in reinforcement learning (more precisely a cost, since it is minimized). The gradient signal $\partial_\theta \ln q_\theta(x)$ is weighted by $R(x)$: when $R(x)$ is large for a configuration $x$, gradient descent decreases the log-likelihood $\ln q_\theta(x)$ of that configuration, so the optimization reduces the variational free energy. $R(x)$ itself is treated as a constant, which is why the training code computes it with `log_prob.detach()`.

However, with a finite batch the last term does not vanish exactly. Instead of simply dropping it, we introduce a Lagrange multiplier that cancels the unphysical part of the gradient signal, namely the part that tries to change the normalization of $q_\theta$. This amounts to subtracting a baseline $b=\mathbb{E}_{x\sim q_\theta} R(x)$ from $R(x)$, where $b$ is estimated by the batch mean. Since $\mathbb{E}_{x\sim q_\theta}\partial_\theta \ln q_\theta(x)=0$, the baseline does not change the expected gradient, but it reduces the variance of the gradient estimate.
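A small numerical check of the derivation above, on a hypothetical toy model (not the pixel-GCN) where $q_\theta$ is a softmax over $K=8$ configurations with random energies. The baseline-subtracted score-function gradient, evaluated exactly as a sum over $x$, matches the autograd gradient of $\mathcal{L}$, and its Monte Carlo estimate with sampled configurations approaches it. The sketch averages over the batch, whereas the notebook's training loop sums; that only rescales the gradient by the batch size.

```python
import torch

torch.manual_seed(0)
K = 8                                                      # toy number of configurations
E = torch.randn(K, dtype=torch.float64)                    # toy energies E(x)
theta = torch.randn(K, dtype=torch.float64, requires_grad=True)  # logits of q_theta

def variational_free_energy(theta):
    """L = sum_x q(x) (E(x) + ln q(x)), computed exactly."""
    log_q = torch.log_softmax(theta, dim=0)
    return (log_q.exp() * (E + log_q)).sum()

# 1) exact gradient of L by autograd
exact_grad, = torch.autograd.grad(variational_free_energy(theta), theta)

# 2) score-function gradient in expectation: sum_x q(x) (R(x) - b) d ln q(x)
log_q = torch.log_softmax(theta, dim=0)
q = log_q.exp().detach()
R = E + log_q.detach()                 # R(x) = E(x) + ln q(x), treated as a constant
b = (q * R).sum()                      # baseline b = E_q[R]
score_grad, = torch.autograd.grad((q * (R - b) * log_q).sum(), theta)

# 3) Monte Carlo estimate, same form as the training loop (batch mean as baseline)
x = torch.multinomial(q, 100000, replacement=True)
log_q_x = torch.log_softmax(theta, dim=0)[x]
R_x = E[x] + log_q_x.detach()
mc_grad, = torch.autograd.grad((log_q_x * (R_x - R_x.mean())).mean(), theta)
```

`score_grad` agrees with `exact_grad` up to floating-point error, while `mc_grad` fluctuates around it with an error that shrinks as the number of samples grows.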

### Training

Set up a model and link to an optimizer.

In [3]:
H = lambda J: -J*(TwoBody(torch.tensor([1.,-1.]), (1,0)) 
                  + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model = HolographicPixelGCN(
            EnergyModel(
                H(0.440686793), # Ising critical point
                SymmetricGroup(2), 
                Lattice(4, 2)), 
            hidden_features = [4, 4])
optimizer = optim.Adam(model.parameters(), lr=0.02)

Start training

In [4]:
batch_size = 100
train_loss = 0.
free_energy = 0.
echo = 100  # report averages every `echo` epochs
for epoch in range(2000):
    x = model.sample(batch_size)
    log_prob = model.log_prob(x)
    energy = model.energy(x)
    free = energy + log_prob.detach()   # R(x) = E(x) + ln q(x), no gradient through it
    meanfree = free.mean()              # baseline b: batch estimate of the variational free energy
    loss = torch.sum(log_prob * (free - meanfree))  # surrogate whose gradient is the score-function estimator
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f}'.format(epoch+1, train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

  100 loss: -14.0713, free energy: -14.6249
  200 loss:  -3.5949, free energy: -15.0638
  300 loss:  -0.9001, free energy: -15.0858
  400 loss:  -0.6680, free energy: -15.1225
  500 loss:   3.4754, free energy: -15.1830
  600 loss:   2.0200, free energy: -15.2467
  700 loss:   2.2074, free energy: -15.2875
  800 loss:  -0.5333, free energy: -15.3057
  900 loss:   2.0551, free energy: -15.3201
 1000 loss:   0.3552, free energy: -15.3422
 1100 loss:  -0.1164, free energy: -15.3555
 1200 loss:  -3.3130, free energy: -15.3668
 1300 loss:  -2.8403, free energy: -15.3817
 1400 loss:  -1.8921, free energy: -15.3973
 1500 loss:   0.0031, free energy: -15.3881
 1600 loss:  -1.1260, free energy: -15.3982
 1700 loss:   1.2248, free energy: -15.4102
 1800 loss:   0.2574, free energy: -15.4096
 1900 loss:  -3.0939, free energy: -15.4141
 2000 loss:   0.8727, free energy: -15.4063


The model converges to a free energy of about -15.41, while the exact value is -15.52, a relative error of about 0.7%. This suggests that the representation power of this pixel-GCN may have saturated. What causes the mismatch, and how can it be reduced?
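For a 4×4 lattice the exact reference value $-\ln Z$ can be obtained by brute-force enumeration of all $2^{16}=65536$ configurations. A minimal sketch, assuming $E(s)=-J\sum_{\langle ij\rangle}s_is_j$ with $s=\pm1$ on nearest-neighbour bonds (consistent with the `TwoBody` terms giving $-J$ for aligned and $+J$ for anti-aligned neighbours) and assuming periodic boundary conditions, which this section does not confirm for `Lattice(4, 2)`. `exact_free_energy` is a hypothetical helper.

```python
import math
import torch

def exact_free_energy(L, J):
    """Return -ln Z for an L x L Ising model by brute-force enumeration.

    Assumes E(s) = -J * sum_<ij> s_i s_j with s = +/-1 on nearest-neighbour
    bonds and periodic boundary conditions. Only feasible for small L,
    since there are 2**(L*L) configurations.
    """
    n = L * L
    codes = torch.arange(2 ** n, dtype=torch.long).unsqueeze(-1)        # (2^n, 1)
    masks = torch.tensor([1 << k for k in range(n)], dtype=torch.long)   # (n,)
    bits = (codes & masks) != 0                                          # (2^n, n) booleans
    s = (bits.to(torch.float64) * 2 - 1).view(-1, L, L)                  # spins +/-1
    # each site couples to its horizontal and vertical neighbour (periodic wrap via roll)
    bonds = (s * s.roll(1, dims=-1)).sum(dim=(-2, -1)) \
          + (s * s.roll(1, dims=-2)).sum(dim=(-2, -1))
    energy = -J * bonds
    return -torch.logsumexp(-energy, dim=0).item()

print(exact_free_energy(4, 0.440686793))  # critical coupling used in the notebook
```

At $J=0$ the result reduces to $-16\ln 2$, which is a quick sanity check of the enumeration.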

### Attempts to Improve

#### Extend the Causal Graph

One conjecture is that the gap is caused by the missing causal connections. We complete the causal connections and retrain.

In [5]:
ext_causal_graph = {
    'cousin2': torch.tensor([
        [ 8, 8, 8, 8, 9, 9, 9, 9,10,10,10,10,11,11,11,11],
        [12,13,14,15,12,13,14,15,12,13,14,15,12,13,14,15]
    ]),
    'neiphew2': torch.tensor([
        [ 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7],
        [12,13,14,15,12,13,14,15, 8, 9,10,11, 8, 9,10,11]
    ])}
model.base_dist.update_causal_graph(ext_causal_graph)
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [7]:
batch_size = 100
train_loss = 0.
free_energy = 0.
echo = 100
for epoch in range(2000):
    x = model.sample(batch_size)
    log_prob = model.log_prob(x)
    energy = model.energy(x)
    free = energy + log_prob.detach()
    meanfree = free.mean()
    loss = torch.sum(log_prob * (free - meanfree))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f}'.format(epoch+1, train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

  100 loss:   2.9610, free energy: -15.2145
  200 loss:   3.1487, free energy: -15.3894
  300 loss:  -0.2074, free energy: -15.3915
  400 loss:   2.5028, free energy: -15.4043
  500 loss:   0.2075, free energy: -15.4074
  600 loss:  -0.1916, free energy: -15.4164
  700 loss:   2.3962, free energy: -15.4139
  800 loss:   1.4960, free energy: -15.3982
  900 loss:   1.1358, free energy: -15.4248
 1000 loss:   2.3010, free energy: -15.4249
 1100 loss:   0.3801, free energy: -15.4155
 1200 loss:  -0.8659, free energy: -15.4177
 1300 loss:   2.2450, free energy: -15.4218
 1400 loss:  -0.6914, free energy: -15.4220
 1500 loss:  -0.5840, free energy: -15.4234
 1600 loss:   2.9715, free energy: -15.4283
 1700 loss:   0.4065, free energy: -15.4166
 1800 loss:   2.1872, free energy: -15.4075
 1900 loss:   1.2261, free energy: -15.4252
 2000 loss:  -0.5291, free energy: -15.4327


No obvious improvement is observed. Probably the causal connections are not the issue.

#### Expand the Features

Another possibility is to increase the number of hidden features.

In [19]:
model = HolographicPixelGCN(
            EnergyModel(
                H(0.440686793), # Ising critical point
                SymmetricGroup(2), 
                Lattice(4, 2)), 
            hidden_features = [16, 16])
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [20]:
batch_size = 100
train_loss = 0.
free_energy = 0.
echo = 100
for epoch in range(2000):
    x = model.sample(batch_size)
    log_prob = model.log_prob(x)
    energy = model.energy(x)
    free = energy + log_prob.detach()
    meanfree = free.mean()
    loss = torch.sum(log_prob * (free - meanfree))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f}'.format(epoch+1, train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

  100 loss:  -1.6530, free energy: -14.8921
  200 loss:   6.8747, free energy: -15.0805
  300 loss:   8.4251, free energy: -15.2173
  400 loss:   4.8030, free energy: -15.3102
  500 loss:   0.0790, free energy: -15.3567
  600 loss:   0.7821, free energy: -15.3910
  700 loss:   5.5580, free energy: -15.3815
  800 loss:   1.5039, free energy: -15.4283
  900 loss:   2.9880, free energy: -15.4282
 1000 loss:   1.4602, free energy: -15.4300
 1100 loss:   4.1522, free energy: -15.4113
 1200 loss:   3.0402, free energy: -15.4297
 1300 loss:   1.7957, free energy: -15.4347
 1400 loss:   1.3020, free energy: -15.4447
 1500 loss:   2.6516, free energy: -15.4409
 1600 loss:   1.7364, free energy: -15.4311
 1700 loss:   2.9771, free energy: -15.4360
 1800 loss:   0.8413, free energy: -15.4538
 1900 loss:   1.0908, free energy: -15.4469
 2000 loss:   1.2062, free energy: -15.4505


The improvement is still marginal (about -15.45 after 2000 epochs). Increasing the number of features further could even make the performance worse.

#### Increase the Batch Size

In [6]:
model = HolographicPixelGCN(
            EnergyModel(
                H(0.440686793), # Ising critical point
                SymmetricGroup(2), 
                Lattice(4, 2)), 
            hidden_features = [16, 16])
optimizer = optim.Adam(model.parameters(), lr=0.02)

In [7]:
batch_size = 500
train_loss = 0.
free_energy = 0.
echo = 100
for epoch in range(2000):
    x = model.sample(batch_size)
    log_prob = model.log_prob(x)
    energy = model.energy(x)
    free = energy + log_prob.detach()
    meanfree = free.mean()
    loss = torch.sum(log_prob * (free - meanfree))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('{:5} loss: {:8.4f}, free energy: {:8.4f}'.format(epoch+1, train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

  100 loss: -20.0954, free energy: -14.9410
  200 loss:  36.8592, free energy: -15.2483
  300 loss:  -4.1415, free energy: -15.3942
  400 loss:   8.7836, free energy: -15.4259
  500 loss:   5.9250, free energy: -15.4433
  600 loss:   3.8869, free energy: -15.4481
  700 loss:   6.1446, free energy: -15.4507
  800 loss:  12.8854, free energy: -15.4482
  900 loss:   7.4954, free energy: -15.4498
 1000 loss:   6.0012, free energy: -15.4559
 1100 loss:   3.2535, free energy: -15.4617
 1200 loss:  12.3520, free energy: -15.4550
 1300 loss:   1.4665, free energy: -15.4573
 1400 loss:   6.7205, free energy: -15.4573
 1500 loss:   6.3586, free energy: -15.4623
 1600 loss:   4.1372, free energy: -15.4618
 1700 loss:   6.8854, free energy: -15.4591
 1800 loss:   6.8244, free energy: -15.4563
 1900 loss:   5.3451, free energy: -15.4611
 2000 loss:   2.4827, free energy: -15.4632


Increasing the batch size to 500 improves the result slightly further (about -15.46). But this is somewhat of a cheat, since the model is allowed to sample five times as many configurations per step.

### Profiling

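We profile a single training iteration with snakeviz, here on an 8×8 lattice (`Lattice(8, 2)`). Sampling takes most of the time, presumably because autoregressive sampling has to generate the nodes sequentially, whereas `log_prob` can evaluate all nodes at once. How can sampling be made faster?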
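Outside Jupyter, the same kind of profile can be collected with the standard-library `cProfile` and `pstats` modules. A minimal sketch; `profile_call` is a hypothetical helper, and in this notebook one would pass it a function wrapping one training iteration. Calling `profiler.dump_stats(path)` instead would write a file that the `snakeviz` command-line tool can open.

```python
import cProfile
import io
import pstats

def profile_call(fn, *args, top=10, **kwargs):
    """Profile a single call fn(*args, **kwargs) and return (result, report text).

    The report lists the `top` entries sorted by cumulative time.
    """
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn(*args, **kwargs)
    finally:
        profiler.disable()
    buf = io.StringIO()
    pstats.Stats(profiler, stream=buf).sort_stats('cumulative').print_stats(top)
    return result, buf.getvalue()

# usage with a stand-in workload
result, report = profile_call(lambda n: sum(i * i for i in range(n)), 10000)
print(report)
```

Sorting by cumulative time makes it easy to see whether `sample` or `log_prob` dominates a training step.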

In [48]:
%run "model2.py"
%load_ext snakeviz
H = lambda J: -J*(TwoBody(torch.tensor([1.,-1.]), (1,0))+ TwoBody(torch.tensor([1.,-1.]), (0,1)))
model = HolographicPixelGCN(EnergyModel(H(0.440686793), SymmetricGroup(2), Lattice(8, 2)), hidden_features = [16, 16])
optimizer = optim.Adam(model.parameters(), lr=0.02)



In [49]:
%%snakeviz
batch_size = 100
x = model.sample(batch_size)
log_prob = model.log_prob(x)
energy = model.energy(x)
free = energy + log_prob.detach()
meanfree = free.mean()
loss = torch.sum(log_prob * (free - meanfree))
optimizer.zero_grad()
loss.backward()
optimizer.step()




In [90]:
%run "model2.py"
Lattice(4, 2).causal_graph()

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  1,  1,  2,
          2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  2,  4,  6,  8, 10, 12, 14,
          3,  3,  2,  2,  5,  5,  4,  4,  7,  7,  6,  6,  4,  5,  4,  5,  8,  9,
          8,  9, 12, 13, 12, 13,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3],
        [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  2,  3,  4,
          5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  3,  5,  7,  9, 11, 13, 15,
          4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,  6,  6,  7,  7, 10, 10,
         11, 11, 14, 14, 15, 15,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,
          3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,
          4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5]])
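Read column by column, the causal graph above looks like (source, target, edge type) triples: type 0 holds the self-loops 1→1, …, 15→15, and every other type connects a lower-indexed node to a higher-indexed one (for example, type 1 pairs 1→2, 1→3, 2→4, …). Under that interpretation, which is inferred from the output rather than documented, a quick sanity check is that non-self-loop edges always point forward in sampling order. A hypothetical helper that also counts the edge types:

```python
import torch

def check_causal_graph(graph):
    """Sanity-check a (3, num_edges) tensor with rows (source, target, edge_type).

    Assumes type 0 marks self-loops and all other types must point from an
    earlier node to a later one. Returns (is_valid, number_of_edge_types).
    """
    src, dst, etype = graph
    self_loop = etype == 0
    self_loops_ok = bool((src[self_loop] == dst[self_loop]).all())
    forward_ok = bool((src[~self_loop] < dst[~self_loop]).all())
    return self_loops_ok and forward_ok, int(etype.max()) + 1

# toy graph: self-loops on 1 and 2, plus edges 1->2 and 1->3 of type 1
print(check_causal_graph(torch.tensor([[1, 2, 1, 1], [1, 2, 2, 3], [0, 0, 1, 1]])))
```

For the graph above, the number of edge types is 6, matching `edge_types=6` in the printed `GraphConv` layers.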

In [58]:
range(Lattice(4, 2).causal_graph()[-1].max()+1)

range(0, 6)

In [13]:
%run "model2.py"
gc = GraphConv(Lattice(4, 2).causal_graph(), 2, 3)
x = torch.randn(16, 2)
gc(x)

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.3012,  0.2776,  0.3127],
        [ 0.3009,  0.5534,  0.7169],
        [ 0.5736,  0.1682,  0.5310],
        [ 0.4237,  0.2469,  0.1553],
        [-0.0298, -0.3841, -0.4074],
        [-0.3373,  0.3048,  0.1673],
        [ 0.0491,  0.2723,  0.3801],
        [ 0.2279,  0.0229, -0.0648],
        [-0.5026,  0.0404,  0.3990],
        [-0.2315,  0.5415, -0.4645],
        [ 0.3633,  0.1021, -0.0606],
        [-0.0474,  0.0462, -0.6494],
        [ 0.1576, -0.0630, -0.3810],
        [-0.8195,  0.8951, -0.5223],
        [-0.9188,  0.6666, -0.3999]], grad_fn=<ScatterAddBackward>)

In [22]:
gc.weight.data

tensor([[[-0.2217, -0.1109],
         [-0.1819,  0.3944],
         [-0.2872,  0.1643]],

        [[-0.1599,  0.1846],
         [-0.1460,  0.0961],
         [-0.1569, -0.3170]],

        [[ 0.2782,  0.1886],
         [-0.1437,  0.2142],
         [ 0.3294, -0.1635]],

        [[-0.2238, -0.1495],
         [ 0.2997, -0.2304],
         [ 0.3807, -0.2237]],

        [[ 0.1207, -0.3932],
         [ 0.3417,  0.2015],
         [-0.3207, -0.1033]],

        [[-0.0936, -0.4016],
         [ 0.0579,  0.2707],
         [-0.1408, -0.3832]]])
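The printed `gc.weight` has shape (6, 3, 2), i.e. one 3×2 (out × in) matrix per edge type, and the output's `grad_fn=<ScatterAddBackward>` suggests that edge messages are summed into their target nodes with a scatter-add; node 0, which has no incoming edge, gets a zero row. A minimal sketch of such typed message passing under those assumptions; `typed_graph_conv` is hypothetical and not the `GraphConv` defined in `model2.py`.

```python
import torch

def typed_graph_conv(x, graph, weight, bias=None):
    """Hypothetical typed-edge graph convolution.

    x:      (..., num_nodes, in_features) node features
    graph:  (3, num_edges) long tensor with rows (source, target, edge_type)
    weight: (edge_types, out_features, in_features)
    Returns (..., num_nodes, out_features); nodes without incoming edges
    receive only the bias (zeros if bias is None).
    """
    src, dst, etype = graph
    # message on each edge: W[edge_type] @ x[source]
    messages = torch.einsum('eoi,...ei->...eo', weight[etype], x[..., src, :])
    out = x.new_zeros(*x.shape[:-1], weight.size(1))
    # sum the incoming messages at each target node
    index = dst.view(*([1] * (x.dim() - 2)), -1, 1).expand_as(messages)
    out = out.scatter_add(-2, index, messages)
    if bias is not None:
        out = out + bias
    return out
```

Gathering source features, applying per-edge-type weights, and scatter-adding into targets keeps the layer linear in the number of edges, which is where the $\sim N\log N$ connection count of pixel-GCN enters.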

In [117]:
import torch
x = torch.tensor([[[1,1],[2,2],[3,3],[5,5]],[[6,6],[7,7],[4,4],[0,0]]])
i = torch.tensor([[[0,0],[0,0],[1,1],[1,1]],[[0,0],[0,0],[1,1],[1,1]]])
y = torch.tensor([[[0,0],[0,0]],[[0,0],[0,0]]])
y.scatter_add(-2, i, x)

tensor([[[ 3,  3],
         [ 8,  8]],

        [[13, 13],
         [ 4,  4]]])

In [2]:
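import torch
x = torch.tensor([[[1,1],[2,2],[3,3],[5,5]],[[6,6],[7,7],[4,4],[0,0]]])
i = torch.tensor([0,0,1,1])  # 1-D index over the node dimension, broadcast by torch_scatter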

In [3]:
import torch_scatter
y = torch_scatter.scatter_add(x, i, dim =-2)
y

tensor([[[ 3,  3],
         [ 8,  8]],

        [[13, 13],
         [ 4,  4]]])
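The same reduction is also available in core PyTorch without the `torch_scatter` dependency: `Tensor.index_add` takes a 1-D index along one dimension and sums the corresponding slices. A sketch reproducing the result above:

```python
import torch

x = torch.tensor([[[1, 1], [2, 2], [3, 3], [5, 5]],
                  [[6, 6], [7, 7], [4, 4], [0, 0]]])
i = torch.tensor([0, 0, 1, 1])
# index_add sums x[:, j, :] into y[:, i[j], :] along dim 1
y = torch.zeros(2, 2, 2, dtype=x.dtype).index_add(1, i, x)
print(y)
```

Unlike `Tensor.scatter_add`, the index here does not need to be expanded to the shape of `x`.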

In [99]:
x=torch.tensor([[1.,2.,3.],[4.,5.,6.]], requires_grad=True)
print(x[...,[2]].tanh())
x.scatter(-1, torch.tensor([[2]]), x[:,[2]].tanh())


tensor([[0.9951],
        [1.0000]], grad_fn=<TanhBackward>)


tensor([[1.0000, 2.0000, 0.9951],
        [4.0000, 5.0000, 6.0000]], grad_fn=<ScatterBackward0>)

In [109]:
torch.empty([1]*x.dim(), dtype=torch.long).fill_(2)

tensor([[2]])

In [118]:
torch.tensor(3).view([1]*3).expand(y.size())

tensor([[[3, 3],
         [3, 3]],

        [[3, 3],
         [3, 3]]])
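In the In [99] cell, `x.scatter` replaced only the entry in the first row, because the index `torch.tensor([[2]])` has shape (1, 1) and `scatter` writes only at the positions the index covers. The `fill_` and `expand` experiments above appear to work toward an index with one entry per row; with such an index, the nonlinearity is applied to column 2 in every row and gradients still flow through the `tanh` branch. A sketch:

```python
import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], requires_grad=True)
col = 2
# one index entry per row, so every row gets its column `col` replaced
index = torch.full((x.size(0), 1), col, dtype=torch.long)
y = x.scatter(-1, index, x[:, [col]].tanh())
y.sum().backward()
print(y)
print(x.grad)  # column `col` receives 1 - tanh(x)^2, the other columns receive 1
```

Because `scatter` overwrites the original entries, the gradient with respect to the replaced values of `x` comes only through `tanh`.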