# Hierarchical Autoregressive Model

In [1]:
%run "main.py"

## Model Design

### Heap Linear Layer

Arrange the latent features in a binary heap tree. Each node of the heap tree hosts a feature vector. The `HeapLinear` layer applies a linear transformation that maps each node to its causal dependants on the heap tree. A $k$-heap linear transformation is defined as
$$y_{2^k m+q}^{j}=\sum_{i}x_{m}^{i}W_{k}^{ij}+b_{k}^{j}, \forall q=0,\cdots,2^k-1,$$
where:
* $x$ - input features
* $y$ - output features
* $W$ - weight matrix (depends on $k$)
* $b$ - bias vector (depends on $k$)

Node 0 is special: when calculating its dependants, it should be treated as if $m=\frac{1}{2}$, so that
$$y_{\lfloor 2^{k-1}\rfloor+q}^{j}=\sum_{i}x_{0}^{i}W_{k}^{ij}+b_{k}^{j},\forall q=0,\cdots,\lceil 2^{k-1}\rceil-1$$
In this way, under 0-heap: $0\to\{0\}$, under 1-heap: $0\to\{1\}$, under 2-heap: $0\to\{2,3\}$ and so on. 
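
The index rule above can be checked with a few lines of plain Python. This is only a sketch of the node bookkeeping, with no weights involved; the helper name `heap_dependants` is hypothetical and not part of `main.py`.

```python
import math

def heap_dependants(m: int, k: int) -> list:
    """Nodes reached from node m by a k-heap step (hypothetical helper)."""
    if m == 0:
        # node 0 behaves like m = 1/2: it maps to [floor(2^(k-1)), 2^k)
        start = math.floor(2 ** (k - 1))
        count = math.ceil(2 ** (k - 1))
    else:
        # a regular node m maps to [2^k * m, 2^k * (m + 1))
        start = (2 ** k) * m
        count = 2 ** k
    return list(range(start, start + count))

# list the dependants of nodes 0 and 1 for the first few heap steps
for k in range(3):
    print(k, heap_dependants(0, k), heap_dependants(1, k))
```

For $k\geq 1$ the two branches agree on the block size pattern: node 0 fills exactly the heap level that node 1 and its descendants do not start from.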

Given the input $x$, the output $y$ is calculated by adding up all possible $k$-heap linear transformations of $x$ (including the $*$-heap). The following code defines the `HeapLinear` layer, with printouts of the node maps.

In [82]:
import math
class HeapLinear(nn.Module):
    """Applies a heap linear transformation to the incoming data (assuming binary heap)
    
    Args:
        nodes: number of nodes in the heap tree (best a power of two, 2^n, counting node 0)
        in_features: size of input features at each heap node
        out_features: size of output features at each heap node
        bias: whether to learn an additive bias
        minheap: minimum heap step to start, must be non-negative int (default = 0)
    """
    def __init__(self, nodes: int, in_features: int, out_features: int, bias: bool = True, minheap: int = 0):
        super(HeapLinear, self).__init__()
        self.nodes = nodes
        self.depth = (nodes-1).bit_length()+1 # fast log2 ceiling
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.minheap = minheap
        self.linears = nn.ModuleDict()
        for k in range(minheap, self.depth):
            self.linears[str(k)] = nn.Linear(in_features, out_features, bias)
    
    def forward(self, input: torch.Tensor):
        output = torch.zeros(input.size()[:-1]+(self.out_features,))
        for l in range(self.depth):
            in_node0 = math.floor(2**(l-1))
            in_node1 = 2**l
            block_input = input[..., in_node0:in_node1, :]
            for k, linear in self.linears.items():
                k = int(k)
                if in_node0 == 0:
                    heap_factor = math.ceil(2**(k-1))
                    out_node0 = math.floor(2**(k-1))
                    out_node1 = out_node0 + heap_factor
                else:
                    heap_factor = 2**k
                    out_node0 = in_node0 * heap_factor
                    out_node1 = in_node1 * heap_factor
                if out_node1 <= self.nodes:
                    block_output = linear(block_input).repeat_interleave(heap_factor, dim=-2)
                    output[..., out_node0:out_node1, :] += block_output
                    print('{}-heap:'.format(k),list(range(in_node0,in_node1)),'->',list(range(out_node0,out_node1)))
        return output     
    
    def forward_from(self, input: torch.Tensor, in_node: int):
        output = torch.zeros(input.size()[:-2]+(self.nodes, self.out_features,))
        in_node_dim = input.size(-2)
        if in_node_dim == self.nodes:
            input = input[..., [in_node], :]
        elif in_node_dim != 1:
            raise ValueError('The node dimension must be either {} or 1, got {}.'.format(self.nodes, in_node_dim))
        for k, linear in self.linears.items():
            k = int(k)
            if in_node == 0:
                heap_factor = math.ceil(2**(k-1))
                out_node0 = math.floor(2**(k-1))
                out_node1 = out_node0 + heap_factor
            else:
                heap_factor = 2**k
                out_node0 = in_node * heap_factor
                out_node1 = (in_node + 1) * heap_factor
            if out_node1 <= self.nodes:
                output[..., out_node0:out_node1, :] += linear(input)
                print('{}-heap:'.format(k),[in_node],'->',list(range(out_node0,out_node1)))
        return output

Test:

In [83]:
hl = HeapLinear(16, 2, 3, minheap=1, bias=True)
x = torch.randn(16, 2)

Forward pass maps $x$ to $y$ through all possible heaps.

In [84]:
ya = hl(x)

1-heap: [0] -> [1]
2-heap: [0] -> [2, 3]
3-heap: [0] -> [4, 5, 6, 7]
4-heap: [0] -> [8, 9, 10, 11, 12, 13, 14, 15]
1-heap: [1] -> [2, 3]
2-heap: [1] -> [4, 5, 6, 7]
3-heap: [1] -> [8, 9, 10, 11, 12, 13, 14, 15]
1-heap: [2, 3] -> [4, 5, 6, 7]
2-heap: [2, 3] -> [8, 9, 10, 11, 12, 13, 14, 15]
1-heap: [4, 5, 6, 7] -> [8, 9, 10, 11, 12, 13, 14, 15]


<img src="./image/tree.png" alt="tree" style="width: 400px;"/>

`forward_from` initiates the map from a specific node to all its dependants. This will be used in sampling. If the outputs from all source nodes are summed up, the result should be the same as that of `forward`.

In [85]:
yb = sum(hl.forward_from(x, i) for i in range(16))

1-heap: [0] -> [1]
2-heap: [0] -> [2, 3]
3-heap: [0] -> [4, 5, 6, 7]
4-heap: [0] -> [8, 9, 10, 11, 12, 13, 14, 15]
1-heap: [1] -> [2, 3]
2-heap: [1] -> [4, 5, 6, 7]
3-heap: [1] -> [8, 9, 10, 11, 12, 13, 14, 15]
1-heap: [2] -> [4, 5]
2-heap: [2] -> [8, 9, 10, 11]
1-heap: [3] -> [6, 7]
2-heap: [3] -> [12, 13, 14, 15]
1-heap: [4] -> [8, 9]
1-heap: [5] -> [10, 11]
1-heap: [6] -> [12, 13]
1-heap: [7] -> [14, 15]


Verify that `forward` and `forward_from` results are indeed consistent.

In [86]:
ya - yb

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], grad_fn=<SubBackward0>)

### Autoregressive Model

`AutoregressiveModel` uses heap linear layers to realize the hierarchical causal structure. It provides the functionality to generate samples and calculate log probabilities. Each sample should be a 2D tensor of the shape $(N_\text{unit}, d_\text{states})$, where $N_\text{unit}$ is the number of units (nodes), and $d_\text{states}$ is the number of states for each unit (number of physical features). Samples may come in batches with arbitrary batch shape. The hierarchical autoregressive model models the probability of a sample $x$ as
$$p(x)=p(x_0)p(x_1|x_0)p(x_2|x_0,x_1)p(x_3|x_0,x_1)p(x_4|x_0,x_1,x_2)p(x_5|x_0,x_1,x_2)p(x_6|x_0,x_1,x_3)p(x_7|x_0,x_1,x_3)\cdots$$
The conditional distributions are modeled by neural networks.
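
The conditioning sets in the factorization above follow from the heap rule. The sketch below derives them by brute force, assuming (as in the model code that follows) that the first layer only uses $k\geq 1$ so that no node sees itself. The helper names are hypothetical.

```python
import math

def heap_dependants(m: int, k: int) -> list:
    """Nodes reached from node m by a k-heap step (hypothetical helper)."""
    if m == 0:
        start = math.floor(2 ** (k - 1))
        count = math.ceil(2 ** (k - 1))
    else:
        start = (2 ** k) * m
        count = 2 ** k
    return list(range(start, start + count))

def conditioning_set(n: int, nodes: int) -> list:
    """Nodes whose values enter the conditional distribution of node n."""
    depth = (nodes - 1).bit_length() + 1
    parents = set()
    for m in range(nodes):
        for k in range(1, depth):  # k >= 1: strictly causal steps only
            if n in heap_dependants(m, k):
                parents.add(m)
    return sorted(parents)

# print the conditioning set of every node in an 8-node heap
for n in range(8):
    print(n, conditioning_set(n, 8))
```

Because a dependant of a dependant is again a dependant, stacking further layers with $k\geq 0$ does not enlarge these sets.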

<img src="./image/prob_tree.png" alt="prob_tree" style="width: 700px;"/>

Possible issues:
* The connections are too sparse and the causal cone is too narrow: it is hard for $x_5$ and $x_6$ to establish correlation.
* Why do nodes at the same heap level need to share weights? Can we pass different messages to different children? Currently, all nodes on the same leaf have identical distributions.

In [8]:
%run "auto.py"
class AutoregressiveModel(nn.Module, dist.Distribution):
    """ Represent a generative model that can generate samples and evaluate log probabilities.
        
        Args:
        units: number of units in the model
        features: a list of feature dimensions from the input layer to the output layer
        nonlinearity: activation function to use 
        bias: whether to learn the additive bias in heap linear layers
    """
    
    def __init__(self, units: int, features, nonlinearity: str = 'ReLU', bias: bool = True):
        super(AutoregressiveModel, self).__init__()
        self.units = units
        self.features = features
        if features[0] != features[-1]:
            raise ValueError('In features {}, the first and last feature dimensions must be equal.'.format(features))
        self.layers = nn.ModuleList()
        for l in range(len(features)-1):
            if l == 0: # first heap linear layer must have minheap=1
                self.layers.append(HeapLinear(units, features[0], features[1], bias, minheap = 1))
            else: # remaining heap linear layers have minheap=0 (by default)
                self.layers.append(getattr(nn, nonlinearity)())
                self.layers.append(HeapLinear(units, features[l], features[l+1], bias))
        dist.Distribution.__init__(self, event_shape=torch.Size([units, features[0]]))
        self.has_rsample = True
    
    def extra_repr(self):
        return '(units): {}\n(features): {}'.format(self.units, self.features) + super(AutoregressiveModel, self).extra_repr()
    
    def forward(self, input):
        logits = input # logits as a workspace, initialized to input
        for layer in self.layers: # apply layers
            logits = layer(logits)
        return logits # logits output
    
    def log_prob(self, value):
        logits = self(value) # forward pass to get logits
        return torch.sum(value * F.log_softmax(logits, dim=-1), (-2,-1))
        
    def _sample(self, batch_size: int, sampler = None):
        if sampler is None: # use default sampler
            sampler = self.sampler
        # create a list of tensors to cache layer-wise outputs
        cache = [torch.zeros(batch_size, self.units, feature) for feature in self.features]
        # autoregressive batch sampling
        for i in range(self.units):
            for l in range(len(self.features)-1):
                if l==0: # first linear layer
                    if i > 0:
                        cache[1] += self.layers[0].forward_from(cache[0], i - 1) # heap linear
                else: # remaining layers
                    activation = self.layers[2*l-1](cache[l][..., [i], :]) # element-wise
                    cache[l + 1] += self.layers[2*l].forward_from(activation, i) # heap linear
            # the last record hosts logits 
            cache[0][..., i, :] = sampler(cache[-1][..., i, :])
        return cache # cache[0] hosts the sample
        
    def sampler(self, logits, dim=-1): # simplified from F.gumbel_softmax
        gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels += logits
        index = gumbels.max(dim, keepdim=True)[1]
        return torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
    
    def sample(self, batch_size=1):
        with torch.no_grad():
            cache = self._sample(batch_size)
        return cache[0]
    
    def rsample(self, batch_size=1, tau=None, hard=False):
        if tau is None: # if temperature not given
            tau = 1/(self.features[-1]-1) # set by the out feature dimension
        cache = self._sample(batch_size, lambda x: F.gumbel_softmax(x, tau, hard))
        return cache[0]
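
The `sampler` method above relies on the Gumbel-max trick: adding independent Gumbel noise to the logits and taking the argmax yields a sample from the softmax distribution. The following is a minimal plain-Python sketch of that idea, not the model's tensor code; `gumbel_max_sample` and the chosen logits are illustrative assumptions.

```python
import math
import random

def gumbel_max_sample(logits, rng):
    """Return the index of one categorical sample drawn from softmax(logits)."""
    # -log(Exponential(1)) is a standard Gumbel variate, as in the sampler above
    perturbed = [l - math.log(rng.expovariate(1.0)) for l in logits]
    return max(range(len(logits)), key=lambda i: perturbed[i])

rng = random.Random(0)
logits = [0.0, math.log(3.0)]  # softmax gives probabilities [0.25, 0.75]
draws = [gumbel_max_sample(logits, rng) for _ in range(20000)]
freq1 = sum(draws) / len(draws)  # empirical frequency of state 1
print(round(freq1, 2))
```

The empirical frequency of state 1 should come out close to 0.75, up to sampling noise.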

The model contains the following components:

In [1047]:
ar = AutoregressiveModel(8, [2,3,2], bias=True)
ar

AutoregressiveModel(
  (units): 8
  (features): [2, 3, 2]
  (layers): ModuleList(
    (0): HeapLinear(
      (linears): ModuleDict(
        (1): Linear(in_features=2, out_features=3, bias=True)
        (2): Linear(in_features=2, out_features=3, bias=True)
        (3): Linear(in_features=2, out_features=3, bias=True)
      )
    )
    (1): ReLU()
    (2): HeapLinear(
      (linears): ModuleDict(
        (0): Linear(in_features=3, out_features=2, bias=True)
        (1): Linear(in_features=3, out_features=2, bias=True)
        (2): Linear(in_features=3, out_features=2, bias=True)
        (3): Linear(in_features=3, out_features=2, bias=True)
      )
    )
  )
)

Verify that the forward pass recovers the logits generated through the autoregressive sampling.

In [139]:
cache = ar._sample(1)
ar(cache[0]) - cache[-1]

tensor([[[0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]]], grad_fn=<SubBackward0>)
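
Given the logits, `log_prob` sums over units the log-softmax value at the sampled state. A plain-Python sketch of this formula on nested lists is shown below; the function names mirror the model's methods but are standalone, and the numbers are made up for illustration.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]

def log_prob(one_hot_sample, logits):
    """Sum over units of the log-probability of the selected state."""
    total = 0.0
    for onehot, unit_logits in zip(one_hot_sample, logits):
        logp = log_softmax(unit_logits)
        total += sum(v * lp for v, lp in zip(onehot, logp))
    return total

sample = [[0.0, 1.0], [1.0, 0.0]]            # two units, two states each
logits = [[0.0, math.log(3.0)], [0.0, 0.0]]  # per-unit logits
print(log_prob(sample, logits))              # equals log(0.75) + log(0.5)
```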

### One-Hot and Categorical Representations

The samples drawn from the autoregressive model are one-hot embeddings. `OneHotCategoricalTransform` converts between one-hot and categorical representations.

In [13]:
class OneHotCategoricalTransform(dist.Transform):
    """Convert between one-hot and categorical representations.
    
    Args:
    num_classes: number of classes."""
    def __init__(self, num_classes: int):
        super(OneHotCategoricalTransform, self).__init__()
        self.num_classes = num_classes
        self.bijective = True
    
    def _call(self, x):
        # one-hot to categorical
        return x.max(dim=-1)[1]
    
    def _inverse(self, y):
        # categorical to one-hot
        return F.one_hot(y, self.num_classes).to(dtype=torch.float)
    
    def log_abs_det_jacobian(self, x, y):
        return torch.tensor(0.)

Test:

In [14]:
ar = AutoregressiveModel(4, [2,3,2], bias=False)
x = ar.sample(2)
x

tensor([[[0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.]],

        [[1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.]]])

In [15]:
oc = OneHotCategoricalTransform(2)
y = oc(x)
y

tensor([[1, 0, 1, 0],
        [0, 0, 1, 0]])

In [16]:
oc.inv(y)

tensor([[[0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.]],

        [[1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.]]])

They can be packed into a `Transform`.

### Group Algebra

`Group` represents a group specified by its multiplication table. Group elements are labeled by integers (ranging from 0 to the order of the group minus one). The element 0 is always treated as the identity element of the group. `Group` provides methods to perform element-wise group multiplication for Torch tensors.

In [252]:
class Group(object):
    """Represent a group, providing multiplication and inverse operation.
    
    Args:
    mul_table: multiplication table as a tensor, e.g. Z2 group: tensor([[0,1],[1,0]])
    """
    def __init__(self, mul_table: torch.Tensor):
        super(Group, self).__init__()
        self.mul_table = mul_table
        self.order = mul_table.size(0) # number of group elements
        gs, ginvs = torch.nonzero(self.mul_table == 0, as_tuple=True)
        self.inv_table = torch.gather(ginvs, 0, gs)
    
    def __iter__(self):
        return iter(range(self.order))
    
    def __repr__(self):
        return 'Group({} elements)'.format(self.order)
    
    def inv(self, input: torch.Tensor):
        return torch.gather(self.inv_table.expand(input.size()[:-1]+(-1,)), -1, input)
    
    def mul(self, input1: torch.Tensor, input2: torch.Tensor):
        output = input1 * self.order + input2
        return torch.gather(self.mul_table.flatten().expand(output.size()[:-1]+(-1,)), -1, output)
    
    def prod(self, input, dim: int, keepdim: bool = False):
        input_size = input.size()
        flat_mul_table = self.mul_table.flatten().expand(input_size[:dim]+input_size[dim+1:-1]+(-1,))
        output = input.select(dim, 0)
        for i in range(1, input.size(dim)):
            output = output * self.order + input.select(dim, i)
            output = torch.gather(flat_mul_table, -1, output)
        if keepdim:
            output = output.unsqueeze(dim)
        return output
    
    def val(self, input, val_table = None):
        if val_table is None:
            val_table = torch.zeros(self.order)
            val_table[0] = 1.
        elif len(val_table) != self.order:
            raise ValueError('Group function value table must be of the same size as the group order, expect {} got {}.'.format(self.order, len(val_table)))
        return torch.gather(val_table.expand(input.size()[:-1]+(-1,)), -1, input)

Create an $S_3$ group.

In [155]:
G = Group(torch.tensor([[0,1,2,3,4,5],[1,0,3,2,5,4],[2,4,0,5,1,3],[3,5,1,4,0,2],[4,2,5,0,3,1],[5,3,4,1,2,0]]))

Multiplying two tensors element-wise following the group multiplication rule.

In [133]:
a = torch.tensor([[0,1,2],[3,4,5]])
b = torch.tensor([[5,4,3],[2,1,0]])
G.mul(a, b)

tensor([[5, 5, 5],
        [1, 2, 5]])

Group product of the elements of a tensor along the given dimension.

In [144]:
G.prod(a, dim=0)

tensor([3, 5, 3])

Group inversion of all elements.

In [145]:
G.inv(a)

tensor([[0, 1, 2],
        [4, 3, 5]])
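
The multiplication table can be cross-checked independently of `Group`. The sketch below uses plain lists; `inverse_table` and `is_group` are hypothetical helpers that test the group axioms and recover the inverses shown above.

```python
# S3 multiplication table, copied from the cell that creates G
S3 = [[0, 1, 2, 3, 4, 5],
      [1, 0, 3, 2, 5, 4],
      [2, 4, 0, 5, 1, 3],
      [3, 5, 1, 4, 0, 2],
      [4, 2, 5, 0, 3, 1],
      [5, 3, 4, 1, 2, 0]]

def inverse_table(table):
    """For each g, find h with g*h == 0 (element 0 is the identity)."""
    return [row.index(0) for row in table]

def is_group(table):
    """Check identity, that each row is a permutation, and associativity."""
    n = len(table)
    elems = range(n)
    identity_ok = all(table[0][g] == g and table[g][0] == g for g in elems)
    rows_ok = all(sorted(row) == list(elems) for row in table)
    assoc = all(table[table[a][b]][c] == table[a][table[b][c]]
                for a in elems for b in elems for c in elems)
    return identity_ok and rows_ok and assoc

print(inverse_table(S3), is_group(S3))
```

Elements 3 and 4 are inverses of each other, which matches the `G.inv(a)` output, and `S3[1][2] != S3[2][1]` confirms the group is non-Abelian.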

Evaluate a group function given by a value table `val_table` (default function: group delta function).

In [149]:
G.val(a)

tensor([[1., 0., 0.],
        [0., 0., 0.]])

In [157]:
G.val(a, val_table=torch.tensor([1.,0.,0.,-0.5,-0.5,0.]))

tensor([[ 1.0000,  0.0000,  0.0000],
        [-0.5000, -0.5000,  0.0000]])

### Lattice System

`LatticeSystem(G, L, d)` is only a container that hosts lattice information so that it can be passed together to other classes. It defines the group $G$ on each site, the size $L$ and the dimension $d$ of the lattice, such that the lattice grid has the shape $L\times\cdots\times L$ ($d$ times) and the total number of units (sites) is $L^d$.

In [261]:
class LatticeSystem(object):
    """ a container to host lattice information
        
        Args:
        group: a group that defines multiplication among elements
        size: length of the lattice along one dimension (should be a power of 2)
        dimension: dimension of the lattice
    """
    def __init__(self, group: Group, size: int, dimension: int):
        super(LatticeSystem, self).__init__()
        self.group = group
        self.size = size
        self.dimension = dimension
        self.shape = [size]*dimension
        self.units = size**dimension
        
    def __repr__(self):
        return 'LatticeSystem({} on {} grid)'.format(self.group, 'x'.join(str(L) for L in self.shape))

Example:

In [235]:
latt = LatticeSystem(G, 4, 2)
latt.shape, latt.units

([4, 4], 16)

### Haar Wavelet Basis

`HaarWavelet(latt)` returns the Haar wavelet basis on the lattice. The lattice size $L$ should be $2^n$ so that the bipartition can be carried through down to single sites.

In [239]:
def HaarWavelet(latt: LatticeSystem):
    wav = torch.zeros(torch.Size([latt.units] + latt.shape), dtype=torch.int)
    def partition(rng: torch.Tensor, dim: int, ind: int):
        if rng[dim].sum()%2 == 0:
            mid = rng[dim].sum()//2
            rng1 = rng.clone()
            rng1[dim, 1] = mid
            rng2 = rng.clone()
            rng2[dim, 0] = mid
            w = wav[ind]
            for k in range(rng1.size(0)):
                w = w.narrow(k, rng1[k,0], rng1[k,1]-rng1[k,0])
            w.fill_(1)
            partition(rng1, (dim + 1)%latt.dimension, 2*ind)
            partition(rng2, (dim + 1)%latt.dimension, 2*ind + 1)
    partition(torch.tensor([[0, latt.size]]*latt.dimension), 0, 1)
    wav[0] = 1
    return wav

1D Haar wavelet of size 8:

In [240]:
HaarWavelet(LatticeSystem(G, 8, 1))

tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0]], dtype=torch.int32)
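
The 1D case can be reproduced with plain lists, which makes the heap indexing of the bipartition explicit. This is a simplified sketch (hypothetical function `haar_basis_1d`); it stops at single sites instead of using the parity test of the code above, which is equivalent when the size is a power of 2.

```python
def haar_basis_1d(size: int):
    """Indicator rows of the heap-ordered Haar-like basis on `size` sites."""
    basis = [[0] * size for _ in range(size)]
    basis[0] = [1] * size  # node 0 covers the whole lattice

    def partition(lo: int, hi: int, ind: int):
        if hi - lo < 2:  # a single site cannot be split further
            return
        mid = (lo + hi) // 2
        for site in range(lo, mid):  # node `ind` marks the left half
            basis[ind][site] = 1
        partition(lo, mid, 2 * ind)      # children follow the heap indexing
        partition(mid, hi, 2 * ind + 1)

    partition(0, size, 1)
    return basis

for row in haar_basis_1d(8):
    print(row)
```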

2D Haar wavelet of size $2\times 2$:

In [242]:
HaarWavelet(LatticeSystem(G, 2, 2))

tensor([[[1, 1],
         [1, 1]],

        [[1, 1],
         [0, 0]],

        [[1, 0],
         [0, 0]],

        [[0, 0],
         [1, 0]]], dtype=torch.int32)

### Tensor Slice by Coordinate

`coordinate_select` is a helper function that selects an element from the input by coordinate index. The coordinate can be any iterable.

In [855]:
def coordinate_select(input: torch.Tensor, coordinate, dims = None):
    output = input
    if dims is None:
        dims = range(len(coordinate))
    for dim, i in zip(dims, coordinate):
        output = output.narrow(dim, i, 1)
    return output

Examples:

In [880]:
x = torch.arange(8).reshape(2,2,2)
x

tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])

In [857]:
coordinate_select(x, (0,1,1))

tensor([[[3]]])

In [858]:
coordinate_select(x, (1,0))

tensor([[[4, 5]]])

In [875]:
coordinate_select(x, (1,0), dims=(-2,-1))

tensor([[[2]],

        [[6]]])

### Haar Transformation Bijector

Haar wavelet transformation realized as a bijective transformation.

In [243]:
import itertools
class HaarTransform(dist.Transform):
    """Haar wavelet transformation (bijective)
    transformation takes real space configurations x to wavelet space encoding y
    
    Args:
    lattice: a lattice system containing information of the group and lattice shape
    """
    def __init__(self, lattice: LatticeSystem):
        super(HaarTransform, self).__init__()
        self.lattice = lattice
        self.bijective = True
        self.make_wavelet()
        self.make_plan()
        
    # construct Haar wavelet basis
    def make_wavelet(self):
        self.wavelet = torch.zeros(torch.Size([self.lattice.units]+self.lattice.shape), dtype=torch.int)
        def partition(rng: torch.Tensor, dim: int, ind: int):
            if rng[dim].sum()%2 == 0:
                mid = rng[dim].sum()//2
                rng1 = rng.clone()
                rng1[dim, 1] = mid
                rng2 = rng.clone()
                rng2[dim, 0] = mid
                wave = self.wavelet[ind]
                for k in range(rng1.size(0)):
                    wave = wave.narrow(k, rng1[k,0], rng1[k,1]-rng1[k,0])
                wave.fill_(1)
                partition(rng1, (dim + 1)%self.lattice.dimension, 2*ind)
                partition(rng2, (dim + 1)%self.lattice.dimension, 2*ind + 1)
        partition(torch.tensor([[0, self.lattice.size]]*self.lattice.dimension), 0, 1)
        self.wavelet[0] = 1
    
    # construct solution plan for Haar decomposition
    def make_plan(self):
        levmap = self.wavelet.sum(0)
        self.plan = {i:[] for i in range(self.lattice.units)}
        for spot in zip(*torch.nonzero(self.wavelet, as_tuple = True)):
            self.plan[spot[0].item()].append(tuple(x.item() for x in spot[1:]))
        spot2lev = {spot: coordinate_select(levmap, spot).item() for spot in 
                    itertools.product(*[range(d) for d in self.lattice.shape])}
        self.plan = {i: sorted(spots, key=lambda spot: spot2lev[spot]) for i, spots in self.plan.items()}
        
    def _call(self, x):
        wave = self.wavelet.view(torch.Size([1]*(x.dim()-1))+self.wavelet.size())
        x = x.view(x.size() + torch.Size([1]*self.lattice.dimension))
        return self.lattice.group.prod(x * wave, -(self.lattice.dimension+1))

    def _inverse(self, y):
        y = y.clone() # to avoid modifying the original input
        x = torch.zeros(y.size()[:-self.lattice.dimension]+(self.lattice.units,), dtype=torch.long)
        dims = tuple(range(-self.lattice.dimension,0))
        for i, spots in self.plan.items():
            sol = coordinate_select(y, spots[0], dims)
            x[...,i] = sol.squeeze()
            invsol = self.lattice.group.inv(sol)
            for spot in spots[1:]:
                y_spot = coordinate_select(y, spot, dims)
                y_spot.copy_(self.lattice.group.mul(invsol, y_spot))
        return x
    
    def log_abs_det_jacobian(self, x, y):
        return torch.tensor(0.)

Demonstration. Create an autoregressive model and sample some configurations of Haar encodings.

In [244]:
ar = AutoregressiveModel(4, [2,3,2], bias=False)
oc = OneHotCategoricalTransform(2)
x = oc(ar.sample(3))
x

tensor([[1, 1, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 1]])

Create a Haar transform and decode to spin configurations.

In [246]:
ht = HaarTransform(LatticeSystem(Group(torch.tensor([[0,1],[1,0]])), 2, 2))
y = ht(x)
y

tensor([[[1, 0],
         [1, 1]],

        [[1, 1],
         [1, 0]],

        [[1, 0],
         [1, 0]]])

Encode spin configuration back to Haar encodings.

In [247]:
ht.inv(y)

tensor([[1, 1, 1, 0],
        [0, 1, 0, 1],
        [0, 0, 1, 1]])
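
For the $\mathbb{Z}_2$ group on the $2\times 2$ lattice, the round trip above can be traced by hand, because group multiplication reduces to XOR. The sketch below hard-codes the wavelet masks from the `HaarWavelet` output and performs the inverse by back-substitution; `decode` and `encode` are illustrative names, and the fixed solution order is a simplification of the general `make_plan` logic.

```python
# Wavelet masks for the 2x2 lattice, copied from the HaarWavelet output above
WAVELETS = [
    [[1, 1], [1, 1]],
    [[1, 1], [0, 0]],
    [[1, 0], [0, 0]],
    [[0, 0], [1, 0]],
]

def decode(x):
    """Haar encoding x (4 bits) -> spin configuration y (2x2), for Z2."""
    y = [[0, 0], [0, 0]]
    for xi, wave in zip(x, WAVELETS):
        for r in range(2):
            for c in range(2):
                y[r][c] ^= xi * wave[r][c]  # Z2 multiplication is XOR
    return y

def encode(y):
    """Back-substitution: solve for x starting from the least-covered site."""
    x0 = y[1][1]            # only wavelet 0 covers site (1, 1)
    x1 = y[0][1] ^ x0       # site (0, 1) is covered by wavelets 0 and 1
    x3 = y[1][0] ^ x0       # site (1, 0) is covered by wavelets 0 and 3
    x2 = y[0][0] ^ x0 ^ x1  # site (0, 0) is covered by wavelets 0, 1 and 2
    return [x0, x1, x2, x3]

print(decode([1, 1, 1, 0]), encode(decode([1, 1, 1, 0])))
```

Each site is solved once all wavelets covering it except one are known, which is why the transformation is bijective.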

A non-trivial example: a non-Abelian group on a 3D lattice.

In [249]:
G = Group(torch.tensor([[0,1,2,3,4,5],[1,0,3,2,5,4],[2,4,0,5,1,3],[3,5,1,4,0,2],[4,2,5,0,3,1],[5,3,4,1,2,0]]))
ht = HaarTransform(LatticeSystem(G, 4, 3))
oc = OneHotCategoricalTransform(G.order)
ar = AutoregressiveModel(ht.lattice.units, [G.order,3,G.order], bias=False)
x = oc(ar.sample(1))
x, x - ht.inv(ht(x))

(tensor([[5, 0, 5, 3, 2, 1, 0, 0, 1, 0, 2, 4, 5, 0, 2, 2, 3, 5, 5, 3, 4, 2, 3, 2,
          0, 4, 3, 4, 5, 3, 3, 1, 5, 0, 5, 4, 3, 1, 1, 2, 4, 1, 4, 1, 3, 4, 2, 3,
          3, 1, 0, 2, 1, 3, 3, 4, 5, 3, 1, 0, 1, 4, 2, 3]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

### Energy Model

Finally, we need an energy model to describe the statistical mechanics system. `EnergyModel` provides the function to evaluate the energy of a configuration.

In [13]:
class EnergyModel(nn.Module):
    """ Energy model that describes the physical system. Provides function to evaluate energy.
    
        Args:
        lattice: a lattice system containing information of the group and lattice shape
        energy: lattice Hamiltonian in terms of energy terms
    """
    def __init__(self, lattice: LatticeSystem, energy: EnergyTerms):
        super(EnergyModel, self).__init__()
        self.lattice = lattice
        self.energy = energy.on(lattice)
    
    def extra_repr(self):
        return '(lattice): {}'.format(self.lattice) + super(EnergyModel, self).extra_repr()
        
    def forward(self, input):
        return self.energy(input)

Consider a 2D Ising model on a square lattice
$$H= -J \sum_{i}(\sigma_i\sigma_{i+\hat{x}} + \sigma_i\sigma_{i+\hat{y}}).$$
The Hamiltonian can be typed in as (see the following subsection for an explanation of the notation)

In [349]:
J = 0.7
H = - J * (TwoBody(torch.tensor([1.,-1.], dtype=float), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.], dtype=float), (0,1)))

The energy model is defined by the lattice system and the Hamiltonian

In [350]:
latt = LatticeSystem(Group(torch.tensor([[0,1],[1,0]])), 4, 2)
energy = EnergyModel(latt, H)
energy

EnergyModel(
  (lattice): LatticeSystem(Group(2 elements) on 4x4 grid)
  (hamiltonian): EnergyTerms(
    (0): TwoBody(tensor([-0.7000,  0.7000], dtype=torch.float64) across (1, 0))
    (1): TwoBody(tensor([-0.7000,  0.7000], dtype=torch.float64) across (0, 1))
  )
)

Let us generate some spin configurations.

In [353]:
ar = AutoregressiveModel(latt.units, [latt.group.order,3,latt.group.order], bias=False)
oc = OneHotCategoricalTransform(latt.group.order)
ht = HaarTransform(latt)
x = ht(oc(ar.sample(3)))
x

tensor([[[0, 1, 1, 0],
         [1, 0, 1, 1],
         [1, 0, 0, 0],
         [0, 1, 1, 1]],

        [[1, 0, 1, 1],
         [0, 1, 1, 0],
         [1, 0, 0, 1],
         [1, 1, 1, 0]],

        [[0, 0, 1, 0],
         [1, 0, 1, 0],
         [0, 1, 0, 0],
         [0, 0, 0, 1]]])

The energy model can be used to evaluate the energy of these spin configurations.

In [354]:
energy(x)

tensor([2.8000, 5.6000, 2.8000], dtype=torch.float64)
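
These energies can be cross-checked with a direct evaluation of the Ising Hamiltonian. The sketch below assumes periodic boundary conditions and the mapping $0\to+1$, $1\to-1$; the boundary treatment inside `phys.py` is not shown here, so this is an assumption, and `ising_energy` is a hypothetical helper.

```python
def ising_energy(config, J=0.7):
    """Energy of a Z2 configuration (0 -> spin +1, 1 -> spin -1)."""
    L = len(config)
    spin = [[1 - 2 * g for g in row] for row in config]
    bonds = 0
    for r in range(L):
        for c in range(L):
            # periodic neighbours along the two lattice directions (assumed)
            bonds += spin[r][c] * spin[(r + 1) % L][c]
            bonds += spin[r][c] * spin[r][(c + 1) % L]
    return -J * bonds

# the three sampled configurations printed above
configs = [
    [[0, 1, 1, 0], [1, 0, 1, 1], [1, 0, 0, 0], [0, 1, 1, 1]],
    [[1, 0, 1, 1], [0, 1, 1, 0], [1, 0, 0, 1], [1, 1, 1, 0]],
    [[0, 0, 1, 0], [1, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]],
]
energies = [ising_energy(c) for c in configs]
print([round(e, 4) for e in energies])
```

Under these assumptions the three values agree with the `energy(x)` output above.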

#### Hamiltonian Scripting System

(See [phys.py](./phys.py) for code details).

To facilitate the intuitive formulation of Hamiltonians, we have introduced a scripting system. Physical Hamiltonians are always sums of local energy terms. In this system, each energy term is a subclass of `nn.Module` and each Hamiltonian is a subclass of `nn.ModuleList` (which contains the collection of energy terms). In this way, the evaluation of the total energy of the Hamiltonian can be passed down to each energy term.

We have introduced two kinds of energy terms:
* `OnSite`: on-site energy term $E_1(g_i)$,
* `TwoBody`: two-body interaction term $E_2(g_i,g_j)$.

More complicated interaction terms can be introduced under this framework if necessary. These energy terms are group functions: $E_1:G\to\mathbb{R}$, $E_2:G\times G\to\mathbb{R}$. These group functions can be specified by a value table, which enumerates the value that each group element maps to. For example, for the $\mathbb{Z}_2=\{0,1\}$ group ($0$-identity, $1$-generator), if we want to specify $E_1(g_\sigma)=\sigma$, i.e.
$$E_1(0)=+1, E_1(1)=-1,$$
the value table is $[+1,-1]$. Such a term can be created as follows

In [360]:
OnSite(torch.tensor([1.,-1.],dtype=float))

OnSite(tensor([ 1., -1.], dtype=torch.float64))

For the two-body term, we assume that it always takes the form of
$$E_2(g_i,g_j)=E_2(g_i^{-1}g_j),$$
such that we only need a single-variable group function, using the same value table representation. For example,

In [359]:
TwoBody(torch.tensor([1.,-1.],dtype=float), (1,0))

TwoBody(tensor([ 1., -1.], dtype=torch.float64) across (1, 0))

The two-body term also carries a second argument to specify the relative direction from site-$i$ to site-$j$. If the value table is not specified, the default group function will be taken to be the delta function (like Potts model), which maps the identity element to 1 and the others to 0.
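
The value-table convention for two-body terms can be illustrated without the scripting system. In the sketch below, `two_body` is a hypothetical standalone function (not the `TwoBody` class from `phys.py`) that evaluates $E_2(g_i^{-1}g_j)$ from a multiplication table and a value table.

```python
def two_body(table, val_table, gi, gj):
    """E_2(g_i, g_j) = f(g_i^{-1} g_j) with f given by a value table."""
    gi_inv = table[gi].index(0)  # g_i^{-1}: the h with g_i * h = identity
    return val_table[table[gi_inv][gj]]

Z2 = [[0, 1], [1, 0]]
S3 = [[0, 1, 2, 3, 4, 5],
      [1, 0, 3, 2, 5, 4],
      [2, 4, 0, 5, 1, 3],
      [3, 5, 1, 4, 0, 2],
      [4, 2, 5, 0, 3, 1],
      [5, 3, 4, 1, 2, 0]]

ising = [1.0, -1.0]                     # reproduces sigma_i * sigma_j on Z2
potts = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # default delta function on S3

print([two_body(Z2, ising, a, b) for a in range(2) for b in range(2)])
print(two_body(S3, potts, 3, 3), two_body(S3, potts, 3, 4))
```

With the delta function as value table, the term is nonzero only when the two sites carry the same group element, which is the Potts-like behaviour described above.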

In this system, we can add, subtract, scalar multiply and negate the energy terms. Energy terms added together are represented as a collection of terms in a list (`nn.ModuleList`), which corresponds to a Hamiltonian.

In [4]:
-2.8 * OnSite() + 5.2 * (TwoBody(shifts=(1,0)) + TwoBody(shifts=(0,1)))

EnergyTerms(
  (0): TwoBody(5.2)
  (1): TwoBody(5.2)
  (2): OnSite(-2.8)
)

A Hamiltonian `H` needs to be further put on a lattice by calling `H.on(lattice)`. Only after it is put on a lattice does the Hamiltonian have a concrete meaning; the energy of the system can then be evaluated by calling `H.forward(input)` with a spin configuration as `input`.

### HolographicPixelFlow

There might be a better name, but the idea is that the holographic pixel flow combines components of renormalization group, autoregressive model (pixel-...), and flow-based generative model together to create a probability model for modeling critical statistical mechanics systems. 

In [4]:
class HolographicPixelFlow(nn.Module, dist.TransformedDistribution):
    """ Combination of hierarchical autoregressive and flow-based model for lattice models.
    
        Args:
        model: an energy model to learn
        hidden_features: a list of feature dimensions of hidden layers
        nonlinearity: activation function to use 
        bias: whether to learn the additive bias in heap linear layers
    """
    def __init__(self, model: EnergyModel, hidden_features, nonlinearity: str = 'ReLU', bias: bool = True):
        super(HolographicPixelFlow, self).__init__()
        self.model = model
        self.haar = HaarTransform(model.lattice)
        n = model.lattice.group.order
        self.onecat = OneHotCategoricalTransform(n)
        features = [n] + hidden_features + [n]
        auto = AutoregressiveModel(model.lattice.units, features, nonlinearity, bias)
        dist.TransformedDistribution.__init__(self, auto, [self.onecat, self.haar])
        self.transform = dist.ComposeTransform(self.transforms)
        
    def energy(self, input): # create a shortcut for energy
        return self.model.energy(input)

Create a holographic pixel flow model

In [6]:
%run "main.py"
G = Group(torch.tensor([[0,1],[1,0]]))
J = 0.7
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
hpf = HolographicPixelFlow(EnergyModel(LatticeSystem(G, 4, 2), H), [3], bias = False)
hpf

HolographicPixelFlow(
  (model): EnergyModel(
    (lattice): LatticeSystem(Group(2 elements) on 4x4 grid)
    (energy): EnergyTerms(
      (0): TwoBody(tensor([-0.7000,  0.7000]) across (1, 0))
      (1): TwoBody(tensor([-0.7000,  0.7000]) across (0, 1))
    )
  )
  (base_dist): AutoregressiveModel(
    (units): 16
    (features): [2, 3, 2]
    (layers): ModuleList(
      (0): HeapLinear(
        (linears): ModuleDict(
          (1): Linear(in_features=2, out_features=3, bias=False)
          (2): Linear(in_features=2, out_features=3, bias=False)
          (3): Linear(in_features=2, out_features=3, bias=False)
          (4): Linear(in_features=2, out_features=3, bias=False)
        )
      )
      (1): ReLU()
      (2): HeapLinear(
        (linears): ModuleDict(
          (0): Linear(in_features=3, out_features=2, bias=False)
          (1): Linear(in_features=3, out_features=2, bias=False)
          (2): Linear(in_features=3, out_features=2, bias=False)
          (3): Linear(in_feature

Draw samples from the model.

In [7]:
x = hpf.sample(2)
x

tensor([[[0, 1, 0, 1],
         [1, 0, 0, 0],
         [1, 0, 1, 0],
         [1, 1, 1, 1]],

        [[0, 1, 0, 1],
         [0, 0, 0, 1],
         [1, 0, 1, 0],
         [0, 1, 0, 1]]])

Evaluate log probabilities of samples.

In [8]:
hpf.log_prob(x)

tensor([-10.9535, -10.2614], grad_fn=<AddBackward0>)

Evaluate energies of samples.

In [9]:
hpf.energy(x)

tensor([2.8000, 8.4000])

Inverse transform samples to the latent space.

In [10]:
hpf.transform.inv(x)

tensor([[[0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.]],

        [[0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [0., 1.]]])

## Model Training

### Loss Function

**Reverse KL with log-trick**. The goal is to bring the model distribution $q_\theta(x)$ close to the target distribution $p(x) = e^{-E(x)}/Z$, with $Z=\sum_x e^{-E(x)}$, by minimizing the reverse KL divergence
$$\begin{split}\mathcal{L}&=\mathsf{KL}(q_\theta||p)\\
&=\sum_{x} q_\theta(x) \ln \frac{q_\theta(x)}{p(x)}\\
&=\sum_{x}q_\theta(x)(E(x)+\ln q_\theta(x))+\ln Z. 
\end{split}$$
The constant $\ln Z$ does not depend on $\theta$ and is dropped from here on, so the parameter dependence is only in $q_\theta$.
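
As a quick numerical illustration of this objective, the sketch below evaluates the reverse KL divergence for a toy three-state system and compares it with $\mathbb{E}_{q}[E+\ln q]+\ln Z$, where $Z=\sum_x e^{-E(x)}$ normalizes $p$. The energies `E` and the trial distribution `q` are made-up values chosen only for illustration; they are not quantities from the model.

```python
import math

# Toy system: 3 states with illustrative energies and an arbitrary trial q.
# Both lists are assumptions for this sketch, not taken from the notebook.
E = [0.0, 1.0, 2.0]
q = [0.5, 0.3, 0.2]

# Target distribution p(x) = exp(-E(x)) / Z
Z = sum(math.exp(-e) for e in E)
p = [math.exp(-e) / Z for e in E]

# Reverse KL computed directly from its definition
kl = sum(qi * math.log(qi / pi) for qi, pi in zip(q, p))

# Same quantity written as E_q[E + ln q] - (-ln Z)
f_var = sum(qi * (e + math.log(qi)) for qi, e in zip(q, E))
f_exact = -math.log(Z)

print("KL(q||p)        :", kl)
print("f_var - f_exact :", f_var - f_exact)
```

Because the divergence is non-negative, $\mathbb{E}_q[E+\ln q]$ is always an upper bound on $-\ln Z$, which is why it can be tracked as a free energy during training.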

The gradient is given by
$$\begin{split}\partial_\theta\mathcal{L}&= \partial_\theta \sum_{x}q_\theta(x)(E(x)+\ln q_\theta(x))\\
&= \sum_{x}[(\partial_\theta q_\theta(x))(E(x)+\ln q_\theta(x))+q_\theta(x)\partial_\theta \ln q_\theta(x)]\\
\end{split}$$
The last term can be dropped because 
$$\sum_x q_\theta(x)\partial_\theta \ln q_\theta(x) = \sum_x \partial_\theta q_\theta(x)=\partial_\theta\sum_x q_\theta(x)=\partial_\theta 1 = 0,$$
the remaining term reads
$$\begin{split}\partial_\theta\mathcal{L}&= \sum_{x}(\partial_\theta q_\theta(x))(E(x)+\ln q_\theta(x))\\
&= \sum_{x}(\partial_\theta q_\theta(x))R(x)\\
&= \mathbb{E}_{x\sim q_\theta}\left[(\partial_\theta \ln q_\theta(x))R(x)\right]\\
\end{split}$$
where the last step uses the log-trick $\partial_\theta q_\theta(x)=q_\theta(x)\,\partial_\theta \ln q_\theta(x)$, and $R(x)=E(x)+\ln q_\theta(x)$ plays the role of a reward signal in the context of reinforcement learning. The gradient signal $\partial_\theta \ln q_\theta(x)$ is weighted by $R(x)$: when $R(x)$ is large for a configuration $x$, gradient descent decreases the log-likelihood $\ln q_\theta(x)$ of that configuration, so the optimization reduces the free energy $\mathbb{E}_{x\sim q_\theta}R(x)$.

However, we should not simply drop the last term for finite batches, where it does not vanish exactly. Instead, we introduce a Lagrange multiplier to counter the gradient component that points toward violating the normalization condition. This amounts to subtracting a baseline value $b=\mathbb{E}_{x\sim q_\theta} R(x)$ from $R(x)$, which can be estimated within each batch. The baseline subtraction also helps to reduce the variance of the gradient.
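
The effect of the baseline can be checked exactly on a small example. The sketch below assumes a toy categorical model $q_\theta=\mathrm{softmax}(\theta)$ over four states, with made-up logits `theta` and energies `E` (neither comes from the notebook), and computes the mean and variance of the per-sample estimator $(R(x)-b)\,\partial_\theta\ln q_\theta(x)$ by enumeration rather than by sampling.

```python
import math

# Toy categorical model q_theta(x) = softmax(theta)[x] over 4 states.
# theta and E are illustrative assumptions, not values from the notebook.
theta = [0.5, -0.3, 0.1, 0.0]
E = [10.0, 10.5, 11.0, 9.5]
n = len(theta)

z = sum(math.exp(t) for t in theta)
q = [math.exp(t) / z for t in theta]
R = [E[x] + math.log(q[x]) for x in range(n)]  # R(x) = E(x) + ln q(x)
b = sum(q[x] * R[x] for x in range(n))         # baseline b = E_q[R]

def score(x, k):
    """d ln q(x) / d theta_k for the softmax parametrization."""
    return (1.0 if x == k else 0.0) - q[k]

def moments(baseline):
    """Exact mean and variance of (R(x) - baseline) * score(x, k) under q."""
    mean, var = [], []
    for k in range(n):
        g = [(R[x] - baseline) * score(x, k) for x in range(n)]
        m = sum(q[x] * g[x] for x in range(n))
        mean.append(m)
        var.append(sum(q[x] * g[x] ** 2 for x in range(n)) - m * m)
    return mean, var

mean_plain, var_plain = moments(0.0)  # no baseline
mean_base, var_base = moments(b)      # baseline b = E_q[R]

print("gradient without baseline:", mean_plain)
print("gradient with baseline   :", mean_base)
print("total variance:", sum(var_plain), "->", sum(var_base))
```

In this example the two expected gradients agree, since $\sum_x q_\theta(x)\partial_\theta\ln q_\theta(x)=0$, while the variance drops because the large constant offset in $R(x)$ no longer multiplies the score.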

### Optimization

In [97]:
%run "main.py"
G = Group(torch.tensor([[0,1],[1,0]]))
J = 0.1
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
hpf = HolographicPixelFlow(EnergyModel(LatticeSystem(G, 2, 2), H), [16, 16], bias = False)
optimizer = optim.Adam(hpf.parameters(), lr=0.001)
batch_size = 500

In [93]:
train_loss = 0.
free_energy = 0.
echo = 100
for epoch in range(500):
    x = hpf.sample(batch_size)
    log_prob = hpf.log_prob(x)
    energy = hpf.energy(x)
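    free = energy + log_prob.detach() # R(x) = E(x) + ln q(x), detached so it acts as a fixed weight
    meanfree = free.mean() # batch estimate of the baseline b
    loss = torch.sum(log_prob * (free - meanfree)) # surrogate loss: its gradient is the baseline-subtracted estimator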
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('loss: {:.4f}, free energy: {:.4f}'.format(train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

loss: 0.0629, free energy: -2.8348
loss: -0.0035, free energy: -2.8362
loss: 0.0205, free energy: -2.8353
loss: -0.0243, free energy: -2.8353
loss: 0.0591, free energy: -2.8360


In [94]:
x = hpf.sample(5)
x

tensor([[[0, 1],
         [1, 1]],

        [[0, 0],
         [1, 1]],

        [[0, 0],
         [1, 0]],

        [[1, 1],
         [0, 0]],

        [[1, 0],
         [0, 1]]])

In [65]:
hpf.log_prob(x).exp()

tensor([0.0473, 0.1099, 0.1099, 0.1168, 0.1168], grad_fn=<ExpBackward>)

In [95]:
import itertools
# enumerate all 16 configurations of the 2x2 lattice
xs = torch.tensor(list(itertools.product([0,1],repeat=4))).view(-1,2,2)

In [96]:
with torch.no_grad():
    ps = hpf.log_prob(xs).exp()
for i, p in enumerate(ps):
    print(i, p.item())

0 0.11705994606018066
1 0.04812687262892723
2 0.07097935676574707
3 0.06810590624809265
4 0.047985441982746124
5 0.04294325038790703
6 0.03375312313437462
7 0.07086819410324097
8 0.07097935676574707
9 0.03400873765349388
10 0.04303836077451706
11 0.04812687262892723
12 0.06821898370981216
13 0.07086819410324097
14 0.047985441982746124
15 0.11695203930139542
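
For comparison, the exact Boltzmann distribution on the 2x2 lattice can be enumerated directly. The sketch below does not use `main.py`. It assumes the energy is $-J$ times the sum over sites of the products of each spin with its neighbors in the $(1,0)$ and $(0,1)$ directions under periodic boundary conditions, which reproduces the energies printed for the 4x4 samples earlier but is otherwise an assumption about `TwoBody`. The function `energy` is a hypothetical stand-in, and the configurations are listed in the same order as `xs`.

```python
import itertools
import math

J = 0.1  # coupling used in the optimization cell
L = 2    # linear size of the lattice

def energy(cfg):
    """Assumed energy: -J * sum_r [s(r) s(r+x) + s(r) s(r+y)], periodic."""
    s = [[1 - 2 * v for v in row] for row in cfg]  # map {0, 1} -> {+1, -1}
    total = 0
    for i in range(L):
        for j in range(L):
            total += s[i][j] * s[(i + 1) % L][j]
            total += s[i][j] * s[i][(j + 1) % L]
    return -J * total

# Same ordering as itertools.product([0,1], repeat=4) reshaped to 2x2
configs = [((a, b), (c, d)) for a, b, c, d in itertools.product([0, 1], repeat=4)]
weights = [math.exp(-energy(cfg)) for cfg in configs]
Z = sum(weights)
probs = [w / Z for w in weights]

print("exact free energy -ln Z:", -math.log(Z))
for i, p in enumerate(probs):
    print(i, p)
```

Under this assumption the exact probabilities are about 0.128 for the two uniform configurations (indices 0 and 15), 0.058 for the twelve zero-energy configurations, and 0.026 for the two checkerboards (indices 6 and 9). The exact free energy is about -2.854, slightly below the value of roughly -2.835 reached during training, as expected for a variational upper bound.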


In [88]:
hpf.transform.inv(xs[[0,-1]])

tensor([[[1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.]],

        [[0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.]]])

In [89]:
hpf.base_dist(hpf.transform.inv(xs[[0,-1]]))

tensor([[[ 0.0000,  0.0000],
         [ 0.1514, -0.2736],
         [ 0.0701, -0.3479],
         [ 0.0701, -0.3479]],

        [[ 0.0000,  0.0000],
         [ 0.1498, -0.2706],
         [-0.0920, -0.5900],
         [-0.0920, -0.5900]]], grad_fn=<CopySlices>)