# Holographic pixel Graph Convolutional Network (HpGCN)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dist

## Model Design

### Lattice System

In [162]:
import math
class LatticeSystem(object):
    """ host lattice information and construct graph in hyperbolic space
        
        Args:
        size: number of size along one dimension (assuming square/cubical lattice)
        dimension: dimension of the lattice
        causal_radius: radius of the causal cone across one level 
        scale_resolved: whether to distinguish edges from different levels
    """
    def __init__(self, size:int, dimension:int, causal_radius: float = 1., scale_resolved: bool = True):
        self.size = size
        self.dimension = dimension
        self.shape = [size]*dimension
        self.sites = size**dimension
        self.tree_depth = self.sites.bit_length()
        self.node_init()
        self.reset_causal_graph(causal_radius, scale_resolved)
        
    def __repr__(self):
        return 'LatticeSystem({} grid with tree depth {}\n\t(node_index): {}\n\t(edge_index): {}\n\t(edge_type): {})'.format('x'.join(str(L) for L in self.shape), self.tree_depth, self.node_index, self.edge_index, self.edge_type)
    
    def node_init(self):
        self.node_levels = torch.zeros(self.sites, dtype=torch.int)
        self.node_centers = torch.zeros(self.sites, self.dimension, dtype=torch.float)
        self.node_index = torch.zeros(self.sites, dtype=torch.long)
        def partition(rng: torch.Tensor, dim: int, ind: int, lev: int):
            if rng[dim].sum()%2 == 0:
                self.node_levels[ind] = lev
                self.node_centers[ind] = rng.to(dtype=torch.float).mean(-1)
                mid = rng[dim].sum()//2
                rng1 = rng.clone()
                rng1[dim, 1] = mid
                rng2 = rng.clone()
                rng2[dim, 0] = mid
                partition(rng1, (dim + 1)%self.dimension, 2*ind, lev+1)
                partition(rng2, (dim + 1)%self.dimension, 2*ind + 1, lev+1)
            else:
                self.node_index[ind-self.sites] = rng[:,0].dot(self.size**torch.arange(0,self.dimension).flip(0))
        partition(torch.tensor([[0, self.size]]*self.dimension), 0, 1, 1)
        
    def reset_causal_graph(self, causal_radius: float, scale_resolved: bool = True):
        def discover_causal_connection(z: int):
            # Args: z - level of the source
            source_pos = self.node_centers[2**(z-1):2**z]
            target_pos = self.node_centers[2**z:2**(z+1)]
            diff = source_pos.unsqueeze(0) - target_pos.unsqueeze(1)
            diff = (diff + self.size/2)%self.size - self.size/2
            dist = torch.norm(diff, dim=-1)
            smooth_scale = 2**((self.tree_depth-1-z)/self.dimension)
            mask = dist < causal_radius * smooth_scale
            target_ids, source_ids = torch.nonzero(mask, as_tuple=True)
            step_scale = 2**math.floor((self.tree_depth-1-z)/self.dimension)
            edge_signatures = torch.round(2*diff/step_scale)[target_ids, source_ids].to(dtype=torch.int)
            level_signatures = torch.tensor([[z]]*len(source_ids))
            if scale_resolved:
                signatures = torch.cat((level_signatures, edge_signatures), -1)
            else:
                signatures = edge_signatures
            return (2**(z-1) + source_ids, 2**z + target_ids, signatures)
        level_graded_result = [discover_causal_connection(z) for z in range(1, self.tree_depth-1)]
        source_ids, target_ids, signatures = [torch.cat(tens, 0) for tens in zip(*level_graded_result)]
        signatures = [tuple(signature) for signature in signatures.tolist()]
        distinct_signatures = set(signatures)
        self.edge_type_map = {signature: i + 1 for i, signature in enumerate(distinct_signatures)}
        self.edge_type = torch.tensor([self.edge_type_map[signature] for signature in signatures])
        self.edge_index = torch.stack((source_ids, target_ids), 0)
        return self.edge_index, self.edge_type

In [3]:
LatticeSystem(4, 2)

LatticeSystem(4x4 grid with tree depth 5
	(node_index): tensor([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
	(edge_index): tensor([[ 1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7],
        [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
	(edge_type): tensor([1, 5, 6, 3, 6, 3, 4, 2, 4, 2, 4, 2, 4, 2]))

#### Real Space
Consider the physical lattice of shape $L\times\cdots\times L = L^d$ of size $L$ in $d$ dimensional space. Each physical site is natually labeled by its real space coordinate $i=(i_0,\cdots,i_{d-1})$ with each $i_a= 0,\cdots,L-1$. Instead of labeling sites by coordinate, we can also use the logical index (flattened index)
$$i=L^{d-1}i_0+L^{d-2}i_1+\cdots+L i_{d-2}+i_{d-1}=\sum_{k=0}^{d-1}L^{d-1-k} i_k.$$
Data stored in this order (also known as the C-format) can be viewed as a high dimensional tensor naturally.

#### H-Tree Structure

The real space lattice is reduced under RG by binary coarse graining along each different axis cyclicly. Under RG the information flows along a H-tree fractal. Following is an example in 2D.

![](./image/H-tree.png)

There are two types of units:
* **site**: physical site on the real space lattice, indexed by $i$ (in red)
* **node**: Haar wavelet in the hyperbolic space, indexed by $q$ (in black). Nodes lives on a binary tree (node 0 is not shown in the above figure). The nodes of different **levels** are colored differently. Each node has a **center** position (when it is projected from the holographic bulk to the boundary).

`LetticeSystem` computes these pieces of information upon initialization.

In [4]:
latt = LatticeSystem(4, 2)
latt.node_index, latt.node_levels, latt.node_centers

(tensor([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15]),
 tensor([0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4], dtype=torch.int32),
 tensor([[0.0000, 0.0000],
         [2.0000, 2.0000],
         [1.0000, 2.0000],
         [3.0000, 2.0000],
         [1.0000, 1.0000],
         [1.0000, 3.0000],
         [3.0000, 1.0000],
         [3.0000, 3.0000],
         [0.5000, 1.0000],
         [1.5000, 1.0000],
         [0.5000, 3.0000],
         [1.5000, 3.0000],
         [2.5000, 1.0000],
         [3.5000, 1.0000],
         [2.5000, 3.0000],
         [3.5000, 3.0000]]))

* `node_index` is a list of index of sites in the node ordering, such that every RG step corresponds to fusing neighboring nodes.
* `node_levels` are levels of nodes.
* `node_centers` are center positions of nodes.

#### Causal Graph Structure

The H-tree is just a backbone of the hyperbolic space. The actural causal connection can be more extended. The causal relation among nodes forms a **graph**. Two nodes are causally related if the are within a certain radius in the hyperbolic space (the radius can be specified as a hyperparameter). The causal influence always flows from IR to UV in the generative process (so the graph is *directed*).

Starting from a node $q$, denote its level as $z_q$ and its center as $x_q$. A 0th-order connection is simply a self-loop. A 1st-order connection is from $(z_q,x_q)$ to $(z_{q'}=z_q+1, x_{q'})$ such that $|x_q-x_{q'}|< 2^{\lfloor (\zeta-1-z_q)/d\rfloor}r$, where $r$ is a given radius and $\zeta=\log_2 L^d+1$ is the depth of the tree.

Nodes of the same level $z$ are labeled by a contineous range of indices: $q=2^{z-1},\cdots, 2^z-1$.

In [5]:
for z in range(1, latt.tree_depth):
    print(list(range(2**(z-1), 2**z)))

[1]
[2, 3]
[4, 5, 6, 7]
[8, 9, 10, 11, 12, 13, 14, 15]


We can iterate through different levels and calculate the rescaled distance, based on which we can pick out the causal connections.

`LatticeSystem` provides the method `reset_causal_graph(radius, scale_resolved=True)` to build/reset the causal graph, given the causality radius. The option `scale_resolved` can be used to switch on or off the scale resolution. The method will return a tuple of edge index and edge type. Example:

In [6]:
latt = LatticeSystem(4, 2)
latt.reset_causal_graph(1., scale_resolved=True)

(tensor([[ 1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7],
         [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]]),
 tensor([1, 5, 6, 3, 6, 3, 4, 2, 4, 2, 4, 2, 4, 2]))

In [7]:
latt = LatticeSystem(4, 2)
latt.reset_causal_graph(1., scale_resolved=False)

(tensor([[ 1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7],
         [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]]),
 tensor([3, 4, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4]))

Less edge types will be assigned if `scale_resolved` is off, because edges at different scales can now be idendentified.

### Group Algebra

`Group` represents a group specified by the multiplication table. Group elements will be labeled by integers (ranging from 0 to the order of the group). The element 0 is always treated as the identity element of the group. `Group` provides methods to perform element-wise group multiplication for Torch tensors.

In [163]:
class Group(object):
    """Represent a group, providing multiplication and inverse operation.
    
    Args:
    mul_table: multiplication table as a tensor, e.g. Z2 group: tensor([[0,1],[1,0]])
    """
    def __init__(self, mul_table: torch.Tensor):
        super(Group, self).__init__()
        self.mul_table = mul_table
        self.order = mul_table.size(0) # number of group elements
        gs, ginvs = torch.nonzero(self.mul_table == 0, as_tuple=True)
        self.inv_table = torch.gather(ginvs, 0, gs)
    
    def __iter__(self):
        return iter(range(self.order))
    
    def __repr__(self):
        return 'Group({} elements)'.format(self.order)
    
    def inv(self, input: torch.Tensor):
        return torch.gather(self.inv_table.expand(input.size()[:-1]+(-1,)), -1, input)
    
    def mul(self, input1: torch.Tensor, input2: torch.Tensor):
        output = input1 * self.order + input2
        return torch.gather(self.mul_table.flatten().expand(output.size()[:-1]+(-1,)), -1, output)
    
    def prod(self, input, dim: int, keepdim: bool = False):
        input_size = input.size()
        flat_mul_table = self.mul_table.flatten().expand(input_size[:dim]+input_size[dim+1:-1]+(-1,))
        output = input.select(dim, 0)
        for i in range(1, input.size(dim)):
            output = output * self.order + input.select(dim, i)
            output = torch.gather(flat_mul_table, -1, output)
        if keepdim:
            output = output.unsqueeze(dim)
        return output
    
    def val(self, input, val_table = None):
        if val_table is None:
            val_table = torch.zeros(self.order)
            val_table[0] = 1.
        elif len(val_table) != self.order:
            raise ValueError('Group function value table must be of the same size as the group order, expect {} got {}.'.format(self.order, len(val_table)))
        return torch.gather(val_table.expand(input.size()[:-1]+(-1,)), -1, input)

Create a $S_3$ group.

In [9]:
G = Group(torch.tensor([[0,1,2,3,4,5],[1,0,3,2,5,4],[2,4,0,5,1,3],[3,5,1,4,0,2],[4,2,5,0,3,1],[5,3,4,1,2,0]]))

Multiplying two tensors element-wise following the group multiplication rule.

In [10]:
a = torch.tensor([[0,1,2],[3,4,5]])
b = torch.tensor([[5,4,3],[2,1,0]])
G.mul(a, b)

tensor([[5, 5, 5],
        [1, 2, 5]])

Product of each row of a tensor in the given dimension.

In [11]:
G.prod(a, dim=0)

tensor([3, 5, 3])

Group inversion of all elements.

In [12]:
G.inv(a)

tensor([[0, 1, 2],
        [4, 3, 5]])

Evaluate a group function given by a value table `val_table` (default function: group delta function).

In [13]:
G.val(a)

tensor([[1., 0., 0.],
        [0., 0., 0.]])

In [14]:
G.val(a, val_table=torch.tensor([1.,0.,0.,-0.5,-0.5,0.]))

tensor([[ 1.0000,  0.0000,  0.0000],
        [-0.5000, -0.5000,  0.0000]])

### Haar Transformation

Lattice system supports Haar tranformation that implements the holographic mapping.

In [164]:
class HaarTransform(dist.Transform):
    """ Haar wavelet transformation (bijective)
        transformation takes real space configurations x to wavelet space encoding y
    
        Args:
        group: a group structure for each unit
        lattice: a lattice system containing information of the group and lattice shape
    """
    def __init__(self, group: Group, lattice: LatticeSystem):
        super(HaarTransform, self).__init__()
        self.group = group
        self.lattice = lattice
        self.bijective = True
        self.make_wavelet()
        
    # construct Haar wavelet basis
    def make_wavelet(self):
        self.wavelet = torch.zeros(torch.Size([self.lattice.sites, self.lattice.sites]), dtype=torch.int)
        self.wavelet[0] = 1
        for z in range(1,self.lattice.tree_depth):
            block_size = 2**(z-1)
            for q in range(block_size):
                node_range = 2**(self.lattice.tree_depth-1-z) * torch.tensor([2*q+1,2*q+2])
                nodes = torch.arange(*node_range)
                sites = self.lattice.node_index[nodes]
                self.wavelet[block_size + q, sites] = 1 
                
    def _call(self, x):
        y = self.group.prod(x.unsqueeze(-1) * self.wavelet, -2)
        return y.view(x.size()[:-1]+torch.Size(self.lattice.shape))
    
    def _inverse(self, y):
        x = y.flatten(-self.lattice.dimension)[...,self.lattice.node_index]
        def renormalize(x):
            if x.size(-1) > 1:
                x0 = x[...,0::2]
                x1 = x[...,1::2]
                return torch.cat((renormalize(x0), self.group.mul(self.group.inv(x0), x1)), -1)
            else:
                return x
        return renormalize(x)
    
    def log_abs_det_jacobian(self, x, y):
        return torch.tensor(0.)

#### Generation

The generation process maps from the holographic bulk (elements on the node $g_q$) to the holographic boundary (elements on the site $f_i$). Starting from $g_q$ on the IR side, in the $z$-th generation step ($z=1,2,\cdots,\zeta-1$ where $\zeta=\log_2 L^d+1$ is the depth of the tree)
$$\text{for }q=0,\cdots,2^{z-1}-1, f_{i(2^{\zeta-1-z}(2q+1:2q+2))} = g_{2^{z-1}+q}f_{i(2^{\zeta-1-z}(2q+1:2q+2))}.$$
The rule can be summerized as a wavelet matrix $w_{iq}=0,1$ such that
$$f_i=\prod_{q}g_q^{w_{iq}},$$
where the wavelet matrix is given as follows.

In [16]:
ht = HaarTransform(G, LatticeSystem(4, 2))
ht.wavelet

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int32)

#### Renormalization

The renormalization process maps from the holographic boundary (elements on the site $f_i$) to the holographic bulk (elements on the node $g_q$). The first step is to reshuffle the $g_i$ into a the node order by

$$g_q = f_{i(q)},$$

where $i(q)$ is the node_index mapping. Now working with $g_q$, in the $z$-th RG step ($z=1,2,\cdots,\zeta-1$ where $\zeta=\log_2 L^d+1$ is the depth of the RG tree)
$$\text{for }q=0,\cdots,2^{\zeta-1-z}-1: g_{q} = g_{2q}, g_{2^{\zeta-1-z}+q} = g_{2q}^{-1}g_{2q+1}.$$
The result of the iteration is the Haar transformed configuration.

In [17]:
x = torch.randint(G.order, (1,16))
y = ht(x)
x, y, ht.inv(y)-x

(tensor([[4, 1, 1, 0, 0, 4, 2, 0, 5, 0, 4, 2, 3, 4, 2, 1]]),
 tensor([[[4, 1, 2, 1],
          [4, 4, 1, 3],
          [2, 5, 2, 0],
          [0, 4, 2, 4]]]),
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]))

### One-Hot and Categorical Representations

The sample draw from the autoregressive model are one-hot embeddings. Convert between one-hot and categorical representations.

In [165]:
class OneHotCategoricalTransform(dist.Transform):
    """Convert between one-hot and categorical representations.
    
    Args:
    num_classes: number of classes."""
    def __init__(self, num_classes: int):
        super(OneHotCategoricalTransform, self).__init__()
        self.num_classes = num_classes
        self.bijective = True
    
    def _call(self, x):
        # one-hot to categorical
        return x.max(dim=-1)[1]
    
    def _inverse(self, y):
        # categorical to one-hot
        return F.one_hot(y, self.num_classes).to(dtype=torch.float)
    
    def log_abs_det_jacobian(self, x, y):
        return torch.tensor(0.)

Test:

In [19]:
x = torch.randint(2, (2,4))
x

tensor([[0, 0, 0, 1],
        [1, 1, 0, 0]])

In [20]:
oc = OneHotCategoricalTransform(2)
y = oc.inv(x)
y

tensor([[[1., 0.],
         [1., 0.],
         [1., 0.],
         [0., 1.]],

        [[0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.]]])

In [21]:
oc(y)

tensor([[0, 0, 0, 1],
        [1, 1, 0, 0]])

They can be pack into a Transform.

### Graph Convolution Layer

`GCN` uses `torch_geometric.nn.MessagePassing` to create a graph convolution network of given numbers of input features $f_\text{in}$, output features $f_\text{out}$ and edge features $f_\text{edge}$. 

**Input**: 
* `x`: node feature vectors $x_q$.
* `edge_index`: a tensor of shape $(2,E)$ where $E$ is the number of (active) edges in the graph. `edge_index[0]` gives the source node indices, `edge_index[1]` gives the target node indices correspondingly. Together they specify a directed graph (the causal influence graph).
* `edge_attr`: edge attribution (edge feature vectors) $e_{pq}$.

**output**:
* `y`: a new set of node feature vectors $y_p$, given by
$$y_p = \sum_{q\in\mathcal{N}(p)} e_{pq}^\intercal (W x_q + b),$$
where $W$ is a rank-3 tensor of the shape $(f_\text{edge}, f_\text{out}, f_\text{in})$ and $b$ is a matrix of the shape $(f_\text{edge}, f_\text{out})$. The additive bias $b$ can be turned off by setting `bias=False`.

In [159]:
from torch_geometric.nn import MessagePassing
class GraphConv(MessagePassing):
    """ Graph Convolution layer 
        
        Args:
        in_features: number of input features per node
        out_features: number of output features per node 
        edge_features: number of features per edge
    """
    def __init__(self, in_features: int, out_features: int, edge_features: int, bias: bool = True):
        super(GraphConv, self).__init__(aggr='add')
        self.in_features = in_features
        self.out_features = out_features
        self.edge_features = edge_features
        self.weight = nn.Parameter(torch.Tensor(edge_features, out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(edge_features, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    
    def extra_repr(self):
        return 'in_features={}, out_features={}, edge_features={}, bias={}'.format(
            self.in_features, self.out_features, self.edge_features, self.bias is not None)

    def reset_parameters(self) -> None:
        bound = 1 / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)
        
    def forward(self, x, edge_index, edge_attr):
        # x: shape [..., N, in_features]
        # edge_index: shape [2, E]
        # edge_attr: shape [E, edge_features]
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)
    
    def message(self, x_j, edge_attr):
        # x_j: shape [..., E, in_features]
        # edge_attr: [E, edge_features]
        weight = torch.tensordot(edge_attr, self.weight, dims=1)
        x_j = torch.sum(weight * x_j.unsqueeze(-2), -1)
        if self.bias is not None:
            bias = torch.tensordot(edge_attr, self.bias, dims=1)
            x_j += bias
        return x_j
    
    def forward_from(self, x, i, edge_index, edge_attr):
        mask = (edge_index[0] == i)
        return self(x, edge_index[:, mask], edge_attr[mask, :])

Example:

In [23]:
gc = GraphConv(2, 3, 3)
gc

GraphConv(in_features=2, out_features=3, edge_features=3, bias=True)

In [24]:
x = torch.tensor(
        [[[0., 1.],
          [1., 0.],
          [0., 1.],
          [1., 0.]],
 
         [[1., 0.],
          [1., 0.],
          [1., 0.],
          [0., 1.]],
 
         [[1., 0.],
          [1., 0.],
          [0., 1.],
          [1., 0.]]])
edge_index = torch.tensor([[1, 1, 2],
                           [2, 3, 3]])
edge_attr = torch.tensor([[-0.9325, -1.0142,  1.5556],
                          [-0.7909, -0.4774, -1.2065],
                          [-0.1908, -0.6673,  0.1268]])
gc(x, edge_index, edge_attr)

tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.0559,  1.3136,  0.3057],
         [-0.6175, -0.1689,  0.4042]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.0559,  1.3136,  0.3057],
         [-0.0549,  0.1292,  0.5020]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.0559,  1.3136,  0.3057],
         [-0.6175, -0.1689,  0.4042]]], grad_fn=<ScatterAddBackward>)

To be used in the autoregressive model, `GraphConv` also provides a `forward_from` method that allows to forward pass from a single given node. This is implemented by masking out other edges that are not going out from the specific node.

In [25]:
edge_index[0] == 2

tensor([False, False,  True])

In [26]:
gc.forward_from(x, 2, edge_index, edge_attr)

tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [-0.2210,  0.2515,  0.1008]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.3416,  0.5496,  0.1987]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000],
         [-0.2210,  0.2515,  0.1008]]], grad_fn=<ScatterAddBackward>)

### Autoregressive Model

`AutoregressiveModel` uses graph convolutional network (GCN) to model the conditional probability distribution following the causal influence (message passing on a directed graph).

![](./image/architecture.png)

Figure (a) shows the causal graph of 8 nodes in a binary tree. The node 0  is somewhat special, that it is always sampled independently from the uniform distribution, and it does not cast causal influence to other nodes. Edges are colored differently to indicate different types of causal influences. Figure (b) is the multi-layer GCN that respects the causal relation. It is important that the first (bottom) layer should not have self-loop connections, but the remaining layers can have. Edges of different colors will correspond to different linear maps of the node feature vectors. The edge type is first maped to a edge feature vector by an embedding layer, the edge feature vectors are then distributed to and shared by the GCN layers.

The hierachical autoregressive model models the probability of a sample $z$ as
$$\begin{split}
p(z)=&p(z_0)p(z_1)p(z_2|z_1)p(z_3|z_1)\\
&p(z_4|z_1,z_2)p(z_5|z_1,z_2)p(z_6|z_1,z_3)p(z_7|z_1,z_3)
\end{split}$$
The conditional distributions are modeled by neural networks.

What is the advange of modeling the Haar encoding in the holographic space?
* **Resolve the criticality**: the holographic mapping brings a scale-free system to a local system (with an emergent scale set by the hyperbolic radius and the critical exponent). This can be seen from the correlation function 
$$C(r)\sim r^{-\alpha}\to \sim e^{-d/\xi},$$
where $d=R\ln r$ and $\xi = R/\alpha$. The complexity of modeling correlation at all scales on the boundary is reduced to modeling correlations locally in the bulk.
* **Shorten the causal chain**: conventional approach like pixel-CNN has unnatural causal structures (why a single pixel must causally depend on its upper-half-plane?). The natural way to think about generating a image is to follow the reverse process of renormalization group. The scale itself becomes the emergent time direction in the hyperbolic space which defines a more natural causal structure: start paining the outline first, then add the details. A remarkable feature is that *time is short in the holographic bulk*, the causal chain is at most of the length $\sim\log L$ (i.e. logarithmic in system size), and the causal cone has limited width (like the past light cone in an expanding universe, which light can not catch up the collapse of universe if we look backwards). This makes the sampling and generation efficient.

In [166]:
class AutoregressiveModel(nn.Module, dist.Distribution):
    """ Represent a generative model that can generate samples and evaluate log probabilities.
        
        Args:
        nodes: number of units in the model
        features: a list of feature dimensions from the input layer to the output layer
        nonlinearity: activation function to use 
        bias: whether to learn the additive bias in heap linear layers
    """
    
    def __init__(self, lattice: LatticeSystem, edge_features: int, node_features, 
                 nonlinearity: str = 'ReLU', bias: bool = True):
        super(AutoregressiveModel, self).__init__()
        self.lattice = lattice
        self.nodes = self.lattice.sites
        self.edge_index = self.lattice.edge_index
        self.edge_type = self.lattice.edge_type
        self.edge_index_ext, self.edge_type_ext = self.edge_extension()
        self.num_edge_type = self.edge_type.max() + 1
        self.edge_features = edge_features
        self.edge_embedding = nn.Embedding(self.num_edge_type, self.edge_features)
        if isinstance(node_features, int):
            self.node_features = [node_features, node_features]
        else:
            if node_features[0] != node_features[-1]:
                raise ValueError('In features {}, the first and last feature dimensions must be equal.'.format(features))
            self.node_features = node_features
        self.layers = nn.ModuleList()
        for l in range(1, len(self.node_features)):
            if l > 1: 
                self.layers.append(getattr(nn, nonlinearity)())
            self.layers.append(GraphConv(self.node_features[l - 1], self.node_features[l], self.edge_features, bias))
        dist.Distribution.__init__(self, event_shape=torch.Size([self.nodes, self.node_features[0]]))
        self.has_rsample = True
    
    def edge_extension(self):
        node_list = torch.arange(1, self.nodes)
        edge_index = torch.stack((node_list, node_list), 0)
        edge_type = torch.zeros(edge_index.size(-1), dtype=self.edge_type.dtype)
        edge_index_ext = torch.cat((edge_index, self.edge_index), -1)
        edge_type_ext = torch.cat((edge_type, self.edge_type), -1)
        return edge_index_ext, edge_type_ext
    
    def extra_repr(self):
        return '(nodes): {}, (edge_features): {}, (node_features): {}'.format(self.nodes, self.edge_features, self.node_features) + super(AutoregressiveModel, self).extra_repr()
          
    def forward(self, input):
        edge_attr = self.edge_embedding(self.edge_type)
        edge_attr_ext = self.edge_embedding(self.edge_type_ext)
        for l, layer in enumerate(self.layers): # apply layers
            if isinstance(layer, GraphConv): # for graph convolution layers
                if l == 0: # first layer
                    output = layer(input, self.edge_index, edge_attr)
                else: # remaining layers
                    output = layer(output, self.edge_index_ext, edge_attr_ext)
            else: # activation layers
                output = layer(output)
        return output # logits
    
    def log_prob(self, value):
        logits = self(value) # forward pass to get logits
        return torch.sum(value * F.log_softmax(logits, dim=-1), (-2,-1))

    def sampler(self, logits, dim=-1): # simplified from F.gumbel_softmax
        gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels += logits
        index = gumbels.max(dim, keepdim=True)[1]
        return torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)

    def _sample(self, batch_size: int, sampler = None):
        if sampler is None: # use default sampler
            sampler = self.sampler
        # create a list of tensors to cache layer-wise outputs
        cache = [torch.zeros(batch_size, self.nodes, self.node_features[0])]
        for l, layer in enumerate(self.layers):
            if isinstance(layer, GraphConv): # for graph convolution layers
                node_features = layer.out_features
                cache.append(torch.zeros(batch_size, self.nodes, node_features))
            else: # activation layers
                cache.append(torch.zeros(batch_size, self.nodes, node_features))
        # autoregressive batch sampling
        edge_attr = self.edge_embedding(self.edge_type)
        edge_attr_ext = self.edge_embedding(self.edge_type_ext)
        cache[0][..., 0, :] = sampler(cache[0][..., 0, :]) # always sample node 0 uniformly
        for i in range(1, self.nodes):
            for l, layer in enumerate(self.layers):
                if isinstance(layer, GraphConv): # for graph convolution layers
                    if l==0: # first layer
                        cache[l + 1] += layer.forward_from(cache[l], i - 1, self.edge_index, edge_attr)
                    else: # remaining layers
                        cache[l + 1] += layer.forward_from(cache[l], i, self.edge_index_ext, edge_attr_ext)
                else: # activation layers
                    cache[l + 1][..., i, :] = layer(cache[l][..., i, :])
            # the last cache hosts the logit, sample from it 
            cache[0][..., i, :] = sampler(cache[-1][..., i, :])
        return cache # cache[0] hosts the sample
    
    def sample(self, batch_size=1):
        with torch.no_grad():
            cache = self._sample(batch_size)
        return cache[0]
    
    def rsample(self, batch_size=1, tau=None, hard=False):
        if tau is None: # if temperature not given
            tau = 1/(self.features[-1]-1) # set by the out feature dimension
        cache = self._sample(batch_size, lambda x: F.gumbel_softmax(x, tau, hard))
        return cache[0]    

Example: create a hierachical autoregressive model

In [167]:
G = Group(torch.tensor([[0,1],[1,0]]))
latt = LatticeSystem(4, 2)
ar = AutoregressiveModel(latt, 3, [G.order, 3, G.order])

The model contains the following components

In [168]:
ar

AutoregressiveModel(
  (nodes): 16, (edge_features): 3, (node_features): [2, 3, 2]
  (edge_embedding): Embedding(7, 3)
  (layers): ModuleList(
    (0): GraphConv(in_features=2, out_features=3, edge_features=3, bias=True)
    (1): ReLU()
    (2): GraphConv(in_features=3, out_features=2, edge_features=3, bias=True)
  )
)

Verify that the forward pass recovers the logits generated through the autoregressive sampling.

In [169]:
cache = ar._sample(1)
ar(cache[0]) - cache[-1]

tensor([[[0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]]], grad_fn=<SubBackward0>)

### Energy Model

Finally, we need an energy model to describe the Statistical Mechanics system. `EnergyModel` provides the function to evalutate the energy of a configuration.

In [170]:
from phys import *
class EnergyModel(nn.Module):
    """ Energy mdoel that describes the physical system. Provides function to evaluate energy.
    
        Args:
        group: a specifying the group on each site
        lattice: a lattice system containing information of the group and lattice shape
        energy: lattice Hamiltonian in terms of energy terms
    """
    def __init__(self, group: Group, lattice: LatticeSystem, energy: EnergyTerms):
        super(EnergyModel, self).__init__()
        self.group = group
        self.lattice = lattice
        self.energy = energy.on(self.group, self.lattice)
    
    def extra_repr(self):
        return '(lattice): {}'.format(self.lattice) + super(EnergyModel, self).extra_repr()
        
    def forward(self, input):
        return self.energy(input)

Consider a 2D Ising model on a square lattice
$$H= -J \sum_{i}(\sigma_i\sigma_{i+\hat{x}} + \sigma_i\sigma_{i+\hat{y}}).$$
The Hamiltonian can be typed in as (see the following subsection for explaination of the notation)

In [33]:
J = 0.7
H = - J * (TwoBody(torch.tensor([1.,-1.], dtype=float), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.], dtype=float), (0,1)))

The energy model is defined by the lattice system and the Hamiltonian

In [34]:
G = Group(torch.tensor([[0,1],[1,0]]))
latt = LatticeSystem(4, 2)
energy = EnergyModel(G, latt, H)
energy

EnergyModel(
  (lattice): LatticeSystem(4x4 grid with tree depth 5
  	(node_index): tensor([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
  	(edge_index): tensor([[ 1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7],
          [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
  	(edge_type): tensor([1, 5, 6, 3, 6, 3, 4, 2, 4, 2, 4, 2, 4, 2]))
  (energy): EnergyTerms(
    (0): TwoBody(tensor([-0.7000,  0.7000], dtype=torch.float64) across (1, 0))
    (1): TwoBody(tensor([-0.7000,  0.7000], dtype=torch.float64) across (0, 1))
  )
)

Let us generate some spin configurations.

In [35]:
ar = AutoregressiveModel(latt, 3, [G.order, 3, G.order], bias=False)
oc = OneHotCategoricalTransform(G.order)
ht = HaarTransform(G, latt)
x = ht(oc(ar.sample(3)))
x

tensor([[[1, 0, 1, 0],
         [0, 1, 1, 0],
         [1, 1, 0, 1],
         [0, 1, 0, 1]],

        [[1, 0, 1, 1],
         [0, 1, 0, 0],
         [1, 0, 1, 1],
         [0, 1, 1, 1]],

        [[0, 1, 1, 0],
         [1, 0, 1, 1],
         [0, 0, 0, 1],
         [1, 0, 1, 1]]])

The energy model can be used to evalutate the energy of these spin configuraitons.

In [36]:
energy(x)

tensor([8.4000, 5.6000, 2.8000], dtype=torch.float64)

#### Hamiltonian Scripting System

(See [phys.py](./phys.py) for code details).

In order to facilitate the intuitive formulation of Hamiltonian, we have introduced a scripting system. Physical Hamiltonians are always sum of local energy terms. In this system, each energy term is a subclass of `nn.Module` and each Hamiltonian is a subclass of `nn.ModuleList` (which contains the collection of energy terms). In this way, the evaluation of the total energy of the Hamiltonian can be passed down to each energy terms.

We have introduced two kinds of energy terms
* `OnSite`: on-site energy term $E_1(g_i)$,
* `TwoBody`: two-body interaction term $E_2(g_i,g_j)$.

More complicated interaction terms can be introduced under this framework if necessary. These energy terms are group functions: $E_1:G\to\mathbb{R}$, $E_2:G\times G\to\mathbb{R}$. These group functions can be specified by a value table, which enumerated the value that each group element maps to. For example, for the $\mathbb{Z}_2=\{0,1\}$ group ($0$-identity, $1$-generator), if we want to specify $E_1(g_\sigma)=\sigma$, i.e.
$$E_1(0)=+1, E_1(1)=-1,$$
the value talbe is $[+1,-1]$. Such a term can be created as follows

In [37]:
OnSite(torch.tensor([1.,-1.],dtype=float))

OnSite(tensor([ 1., -1.], dtype=torch.float64))

For two-body term, we assume that it always take the form of
$$E_2(g_i,g_j)=E_2(g_i^{-1}g_j),$$
such that we will only need to a single-variable group function, unsing the same value table representation. For example,

In [38]:
TwoBody(torch.tensor([1.,-1.],dtype=float), (1,0))

TwoBody(tensor([ 1., -1.], dtype=torch.float64) across (1, 0))

The two-body term also carries a second argument to specify the relative direction from site-$i$ to site-$j$. If the value table is not specified, the default group function will be taken to be the delta function (like Potts model), which maps the identity element to 1 and the others to 0.

In this system, we can add, subtract, scalar multiply and negate the energy terms. Energy terms adding together will be represented as a collection of terms in a list (`nn.ModuleList`), which correspond to a Hamiltonian.

In [39]:
-2.8 * OnSite() + 5.2 * (TwoBody(shifts=(1,0)) + TwoBody(shifts=(0,1)))

EnergyTerms(
  (0): TwoBody(5.2)
  (1): TwoBody(5.2)
  (2): OnSite(-2.8)
)

A Hamiltonian `H` needs to be further put on a lattice by calling `H.on(lattice)`. Only after it putting on a lattice, the Hamiltonian has a concrete meaning, and then the energy of the system can be evaluated by the `H.forward(input)` method on a spin configuration as `input`.

### Holographic pixel Graph Convolutional Network (HpGCN)

Putting all components together

In [193]:
class HpGCN(nn.Module, dist.TransformedDistribution):
    """ Combination of hierarchical autoregressive and flow-based model for lattice models.
    
        Args:
        energy: a energy model to learn
        hidden_features: a list of feature dimensions of hidden layers
        nonlinearity: activation function to use 
        bias: whether to learn the additive bias in heap linear layers
    """
    def __init__(self, energy: EnergyModel, edge_features: int, hidden_node_features):
        super(HpGCN, self).__init__()
        self.energy = energy
        self.group = energy.group
        self.lattice = energy.lattice
        self.haar = HaarTransform(self.group, self.lattice)
        self.onecat = OneHotCategoricalTransform(self.group.order)
        node_features = [self.group.order] + hidden_node_features + [self.group.order]
        auto = AutoregressiveModel(self.lattice, edge_features, node_features)
        dist.TransformedDistribution.__init__(self, auto, [self.onecat, self.haar])
        self.transform = dist.ComposeTransform(self.transforms)

Create a holographic pixel flow model

In [173]:
%run "main.py"
G = Group(torch.tensor([[0,1],[1,0]]))
latt = LatticeSystem(4, 2)
J = 0.7
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model = HpGCN(EnergyModel(G, latt, H), 3, [3])
model

HpGCN(
  (energy): EnergyModel(
    (lattice): LatticeSystem(4x4 grid with tree depth 5
    	(node_index): tensor([ 0,  1,  4,  5,  2,  3,  6,  7,  8,  9, 12, 13, 10, 11, 14, 15])
    	(edge_index): tensor([[ 1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7],
            [ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
    	(edge_type): tensor([1, 5, 6, 3, 6, 3, 4, 2, 4, 2, 4, 2, 4, 2]))
    (energy): EnergyTerms(
      (0): TwoBody(tensor([-0.7000,  0.7000]) across (1, 0))
      (1): TwoBody(tensor([-0.7000,  0.7000]) across (0, 1))
    )
  )
  (base_dist): AutoregressiveModel(
    (nodes): 16, (edge_features): 3, (node_features): [2, 3, 2]
    (edge_embedding): Sequential(
      (0): Embedding(7, 3)
      (1): Softmax(dim=-1)
    )
    (layers): ModuleList(
      (0): GraphConv(in_features=2, out_features=3, edge_features=3, bias=True)
      (1): ReLU()
      (2): GraphConv(in_features=3, out_features=2, edge_features=3, bias=True)
    )
  )
)

Draw samples from the model.

In [174]:
x = model.sample(2)
x

tensor([[[1, 1, 1, 0],
         [1, 0, 1, 1],
         [0, 1, 0, 1],
         [1, 1, 0, 0]],

        [[1, 1, 0, 0],
         [1, 1, 1, 0],
         [0, 0, 0, 1],
         [1, 0, 1, 0]]])

Evaluate log probabilities of samples.

In [175]:
model.log_prob(x)

tensor([-15.0476, -11.3792], grad_fn=<AddBackward0>)

Evaluate energies of samples.

In [176]:
model.energy(x)

tensor([2.8000, 5.6000])

Inverse transform samples to the latent space.

In [177]:
model.transform.inv(x)

tensor([[[0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [1., 0.]],

        [[0., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.],
         [0., 1.],
         [1., 0.],
         [0., 1.],
         [0., 1.],
         [0., 1.]]])

## Model Training

### Loss Function

**Reverse KL with log-trick**. The goal is to minimize the difference between the model distribution $q_\theta(x)$ and the target distribution $p(x) \propto e^{-E(x)}$ by minimizing the reverse KL divergence
$$\begin{split}\mathcal{L}&=\mathsf{KL}(q_\theta||p)\\
&=\sum_{x} q_\theta(x) \ln \frac{q_\theta(x)}{p(x)}\\
&=\sum_{x}q_\theta(x)(E(x)+\ln q_\theta(x)). 
\end{split}$$
The parameter dependence is only in $q_\theta$.

The gradient is given by
$$\begin{split}\partial_\theta\mathcal{L}&= \partial_\theta \sum_{x}q_\theta(x)(E(x)+\ln q_\theta(x))\\
&= \sum_{x}[(\partial_\theta q_\theta(x))(E(x)+\ln q_\theta(x))+q_\theta(x)\partial_\theta \ln q_\theta(x)]\\
\end{split}$$
The last term can be dropped because 
$$\sum_x q_\theta(x)\partial_\theta \ln q_\theta(x) = \sum_x \partial_\theta q_\theta(x)=\partial_\theta\sum_x q_\theta(x)=\partial_\theta 1 = 0,$$
the remaining term reads
$$\begin{split}\partial_\theta\mathcal{L}&= \sum_{x}(\partial_\theta q_\theta(x))(E(x)+\ln q_\theta(x))\\
&= \sum_{x}(\partial_\theta q_\theta(x))R(x)\\
&= \mathbb{E}_{x\sim q_\theta}(\partial_\theta \ln q_\theta(x))R(x)\\
\end{split}$$
with a reward signal $R(x)=E(x)+\ln q_\theta(x)$ in the context of reinforcement learning. The gradient signal $\partial_\theta \ln q_\theta(x)$ is weighted by $R(x)$, such that when $R(x)$ is large for a configuration $x$, the gradient descent will decrease the log likelihood $\ln q_\theta(x)$ for that configuration, hence the optimzation will try to reduce the free energy.

However we should not just drop the last term for finite batches, instead we should introduce a Lagrangian multiplier to counter the gradient signal that is towards the direction of violating the normalization condition. This amounts to subtracting $R(x)$ by a baseline value $b=\mathbb{E}_{x\sim q_\theta} R(x)$, which can be estimated within each batch. The baseline subtraction helps to reduce the variance of the gradient.

### Optimization

In [178]:
%run "main.py"
G = Group(torch.tensor([[0,1],[1,0]]))
latt = LatticeSystem(2, 2, 1.)
J = 0.1
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model = HpGCN(EnergyModel(G, latt, H), 5, [3], nonlinearity='ReLU' , bias=True)
#model.base_dist.lattice.edge_type.fill_(1)
optimizer = optim.Adam(model.parameters(), lr=0.01)
batch_size = 100

In [184]:
train_loss = 0.
free_energy = 0.
echo = 100
for epoch in range(500):
    x = model.sample(batch_size)
    log_prob = model.log_prob(x)
    energy = model.energy(x)
    free = energy + log_prob.detach()
    meanfree = free.mean()
    loss = torch.sum(log_prob * (free - meanfree))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('loss: {:.4f}, free energy: {:.4f}'.format(train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

loss: 0.4743, free energy: -5.0326
loss: -1.2254, free energy: -5.0419
loss: -0.5264, free energy: -5.0880
loss: -0.1949, free energy: -5.0958
loss: 1.9242, free energy: -5.0534


In [181]:
J = 0.7
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model.energy.update(H)

EnergyModel(
  (lattice): LatticeSystem(2x2 grid with tree depth 3
  	(node_index): tensor([0, 1, 2, 3])
  	(edge_index): tensor([[1, 1],
          [2, 3]])
  	(edge_type): tensor([1, 2]))
  (energy): EnergyTerms(
    (0): TwoBody(tensor([-0.7000,  0.7000]) across (1, 0))
    (1): TwoBody(tensor([-0.7000,  0.7000]) across (0, 1))
  )
)

In [185]:
import itertools
xs = torch.tensor(list(itertools.product([0,1],repeat=4))).view(-1,2,2)
with torch.no_grad():
    ps = model.log_prob(xs).exp()
for i, p in enumerate(ps):
    print(i, p.item())

0 0.411565363407135
1 0.028368087485432625
2 0.0018196612363681197
3 0.026399722322821617
4 0.02799769677221775
5 0.0019298052648082376
6 0.00012378675455693156
7 0.0017959026154130697
8 0.0017959026154130697
9 0.00012378675455693156
10 0.0019298052648082376
11 0.02799769677221775
12 0.026399722322821617
13 0.0018196612363681197
14 0.028368087485432625
15 0.411565363407135


Performance is not stable. Possible reason:
* Compared with the previous version, the UI/IR mixing is reduced. Increasing the pressure on the hidden layer.
* The link variety increases, this makes the model harder to converge.
* Easily trapped at local minimum when temperature is low (try annealing?)

## Model Improvement

### Gradient Analysis

In [186]:
opt2 = optim.SGD(model.parameters(), lr=0.5)
x = xs[3]
print(x)
opt2.zero_grad()
obj = model.log_prob(xs[3])
print(obj)
obj.backward()
opt2.step()

tensor([[0, 0],
        [1, 1]])
tensor(-3.6344, grad_fn=<AddBackward0>)


In [187]:
for p in model.parameters():
    print(p.grad)

tensor([[ 0.0522,  0.0451,  0.0179,  0.0295, -0.1447],
        [ 0.0026, -0.0071,  0.0008,  0.0020,  0.0016],
        [ 0.0018, -0.0065,  0.0010,  0.0012,  0.0025]])
tensor([[[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.],
         [0., 0.]]])
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
tensor([[[0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.]],

        [[0., 0., 0.],
         [0., 0., 0.]]])
tensor([[-0.0197,  0.0197],
        [ 0.0994, -0.0994],
        [-0.0022,  0.0022],
        [-0.2875,  0.2875],
        [-0.4733,  0.4733]])


The gradient is very sparse. There are only very limited ways to tune the paramters. Model is too rigid.

### Brief Summary

**Key Problems**
* There is no node specific bias!
* There is no three body interaction to capture the correlation of UV nodes conditioned on IR
* There is no regularization of parameters, parameters could do random walk. (But this is simple to solve)

**Future Plan**
* Develop a geometric system, which allows us to systematically capture all the simplicial objects within the local causal range. Instead trying to assign features to edges faces and higher objects, we just need to introduce an embedding for every node, and derive the higher-object feature vector by aggregation (presumably realized by an RNN?). The simplest position encoding for a node is the plan-wave encoding in the hypobolic space (we want to preserve the translation symmetry). Such a feature can be either trainable or not. We hope that translationally equivalent edges and faces will have the same feature, such that parameters can be shared effectively. Then We can use the edge and face features to build a conditional distribution. It always takes the form of $p(x_i|x_{j<i})$. We will sum up the contributions of all higher objects together on the log likelihood level to made a final probability model that is normalized and not over counting.
* Maybe higher object is not really necessary, we just need more edges? Given that attention is all you need!

### Extended Causal Graph + Node Embedding

In [2]:
%run "main.py"

#### Lattice System Upgrade

Upgrade lattice system to provide infomation about extended causal graph and node position encoding.

In [4]:
import math
class LatticeSystem(object):
    """ host lattice information and construct graph in hyperbolic space
        
        Args:
        size: number of size along one dimension (assuming square/cubical lattice)
        dimension: dimension of the lattice
        causal_radius: radius of the causal cone across one level 
        scale_resolved: whether to distinguish edges from different levels
    """
    def __init__(self, size:int, dimension:int):
        self.size = size
        self.dimension = dimension
        self.shape = [size]*dimension
        self.sites = size**dimension
        self.tree_depth = self.sites.bit_length()
        self.node_init()
        
    def __repr__(self):
        return 'LatticeSystem({} grid with tree depth {})'.format('x'.join(str(L) for L in self.shape), self.tree_depth)
    
    def node_init(self):
        self.node_levels = torch.zeros(self.sites, dtype=torch.int)
        self.node_centers = torch.zeros(self.sites, self.dimension, dtype=torch.float)
        self.node_index = torch.zeros(self.sites, dtype=torch.long)
        def partition(rng: torch.Tensor, dim: int, ind: int, lev: int):
            if rng[dim].sum()%2 == 0:
                self.node_levels[ind] = lev
                self.node_centers[ind] = rng.to(dtype=torch.float).mean(-1)
                mid = rng[dim].sum()//2
                rng1 = rng.clone()
                rng1[dim, 1] = mid
                rng2 = rng.clone()
                rng2[dim, 0] = mid
                partition(rng1, (dim + 1)%self.dimension, 2*ind, lev+1)
                partition(rng2, (dim + 1)%self.dimension, 2*ind + 1, lev+1)
            else:
                self.node_index[ind-self.sites] = rng[:,0].dot(self.size**torch.arange(0,self.dimension).flip(0))
        partition(torch.tensor([[0, self.size]]*self.dimension), 0, 1, 1)
        
    def causal_graph(self, causal_radius: float = 1.):
        def discover_causal_connection(z: int):
            # Args: z - level of the source
            source_pos = self.node_centers[2**(z-1):2**z]
            target_pos = self.node_centers[2**z:2**(z+1)]
            diff = source_pos.unsqueeze(0) - target_pos.unsqueeze(1) # difference 
            diff = (diff + self.size/2)%self.size - self.size/2 # assuming periodic boundary
            dist = torch.norm(diff, dim=-1) # distance
            smooth_scale = 2**((self.tree_depth-1-z)/self.dimension)
            mask = dist < causal_radius * smooth_scale
            target_ids, source_ids = torch.nonzero(mask, as_tuple=True)
            return (2**(z-1) + source_ids, 2**z + target_ids)
        level_graded_result = [discover_causal_connection(z) for z in range(1, self.tree_depth-1)]
        # nearest layer causal connections within radius
        adj1 = self.to_adj(torch.stack([torch.cat(tens, 0) for tens in zip(*level_graded_result)]))
        adj2 = adj1 @ adj1
        #adj3 = adj2 @ adj1
        adj11 = torch.tril(adj1 @ adj1.t(), -1)
        adj22 = torch.tril(adj2 @ adj2.t(), -1) # overlap adj11
        adj21 = torch.tril(adj2 @ adj1.t(), -1) # overlap adj1
        edge_index = torch.cat([self.to_edge_index(adj) for adj in 
                                [adj1 + adj21, adj11 + adj22, adj2]], -1)
        return edge_index
    
    def to_adj(self, edge_index):
        ones = torch.ones(edge_index.size(-1), dtype=torch.long)
        return torch.sparse.LongTensor(edge_index.flip(0), ones, torch.Size([self.sites]*2)).to_dense()
    
    def to_edge_index(self, adj):
        target, source = torch.nonzero(adj, as_tuple=True)
        return torch.stack([source, target])
    
    def node_position_encoding(self, node_features: int = None):
        phase = 2*math.pi*self.node_centers/self.size
        levels = self.node_levels.to(dtype=torch.float).unsqueeze(-1)
        encoding = torch.cat([phase.sin(), phase.cos(), levels], -1)
        if node_features is not None:
            pad = node_features - encoding.size(-1)
            if pad < 0:
                raise RuntimeError('Number of node features must be no less than 5, got {}.'.format(node_features))
            zeros = torch.zeros(self.sites, pad)
            encoding = torch.cat([encoding, zeros], -1)
        return encoding

Extend causal relations to include more local relations in the hyperbolic space. 
* 亲子关系
* 婶侄关系
* 姐弟关系（亲+表）
* 祖孙关系

In [5]:
latt = LatticeSystem(4, 2)
latt.causal_graph()

tensor([[ 1,  1,  2,  3,  2,  3,  2,  3,  2,  3,  4,  5,  4,  5,  4,  5,  4,  5,
          6,  7,  6,  7,  6,  7,  6,  7,  2,  4,  4,  5,  4,  5,  6,  8,  8,  9,
          8,  9, 10, 12, 12, 13, 12, 13, 14,  1,  1,  1,  1,  2,  2,  2,  2,  3,
          3,  3,  3],
        [ 2,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  9,  9, 10, 10, 11, 11,
         12, 12, 13, 13, 14, 14, 15, 15,  3,  5,  6,  6,  7,  7,  7,  9, 10, 10,
         11, 11, 11, 13, 14, 14, 15, 15, 15,  4,  5,  6,  7,  8,  9, 10, 11, 12,
         13, 14, 15]])

Use node position encoding to initialize node embedding.

In [6]:
latt = LatticeSystem(4, 2)
latt.node_position_encoding()

tensor([[ 0.0000e+00,  0.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00],
        [-8.7423e-08, -8.7423e-08, -1.0000e+00, -1.0000e+00,  1.0000e+00],
        [ 1.0000e+00, -8.7423e-08, -4.3711e-08, -1.0000e+00,  2.0000e+00],
        [-1.0000e+00, -8.7423e-08,  1.1925e-08, -1.0000e+00,  2.0000e+00],
        [ 1.0000e+00,  1.0000e+00, -4.3711e-08, -4.3711e-08,  3.0000e+00],
        [ 1.0000e+00, -1.0000e+00, -4.3711e-08,  1.1925e-08,  3.0000e+00],
        [-1.0000e+00,  1.0000e+00,  1.1925e-08, -4.3711e-08,  3.0000e+00],
        [-1.0000e+00, -1.0000e+00,  1.1925e-08,  1.1925e-08,  3.0000e+00],
        [ 7.0711e-01,  1.0000e+00,  7.0711e-01, -4.3711e-08,  4.0000e+00],
        [ 7.0711e-01,  1.0000e+00, -7.0711e-01, -4.3711e-08,  4.0000e+00],
        [ 7.0711e-01, -1.0000e+00,  7.0711e-01,  1.1925e-08,  4.0000e+00],
        [ 7.0711e-01, -1.0000e+00, -7.0711e-01,  1.1925e-08,  4.0000e+00],
        [-7.0711e-01,  1.0000e+00, -7.0711e-01, -4.3711e-08,  4.0000e+00],
        [-7.0711e-01,  1.

#### Upgrade GraphConv and AutoregressiveModel

Upgrade from edge-embedding-based to node-embedding-based approach.

In [306]:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
class GraphConv(MessagePassing):
    """ Graph Convolution layer 
        
        Args:
        in_channels: number of input data channels
        out_channels: number of output data channels 
        edge_features: number of features on the edge
        bias: whether to learn an edge depenent bias
    """
    def __init__(self, in_channels: int, out_channels: int, edge_features: int, bias: bool = True):
        super(GraphConv, self).__init__(aggr='add')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_features = edge_features
        self.weight = nn.Parameter(torch.Tensor(edge_features, out_channels, in_channels))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(edge_features, out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    
    def extra_repr(self):
        return 'in_channels={}, out_channels={}, edge_features={}, bias={}'.format(
            self.in_channels, self.out_channels, self.edge_features, self.bias is not None)

    def reset_parameters(self) -> None:
        bound = 1 / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)
        
    def forward(self, input, edge_index, edge_attr):
        # input: shape [..., N, in_channels]
        # edge_index: shape [2, E]
        # edge_attr: shape [E, edge_features]
#        print('forward receiving:\n    input {}\n    edge_index {}\n    edge_attr {}'.format(input.size(), edge_index.size(), edge_attr.size()))
        return self.propagate(edge_index, input=input, edge_attr=edge_attr)
    
    def forward_from(self, input, i, edge_index, edge_attr):
#        print('--------\nforward_from receiving:\n    input {} at {}\n    edge_index {}\n    edge_attr {}'.format(input.size(), i, edge_index.size(), edge_attr.size()))
        mask = (edge_index[0] == i)
#        print('    -> mask found {} active edges to:'.format(mask.sum().item()))
#        print('    ', edge_index[1, mask].tolist())
        return self(input, edge_index[:, mask], edge_attr[mask, :])
    
    def message(self, input_j, edge_attr):
        # input_j: shape [..., E, in_channels]
        # edge_attr_i: [E, edge_features]
#        print('message receiving:\n    input_j {}\n    edge_attr {}'.format(input_j.size(), edge_attr.size())) 
        weight = torch.tensordot(edge_attr, self.weight, dims=1)
        output = torch.sum(weight * input_j.unsqueeze(-2), -1)
        if self.bias is not None:
            bias = torch.tensordot(edge_attr, self.bias, dims=1)
            output += bias
        return output
    


    
class AutoregressiveModel(nn.Module, dist.Distribution):
    """ Represent a generative model that can generate samples and evaluate log probabilities.
        
        Args:
        lattice: lattice system
        channels: a list of channel dimensions from the input layer to the output layer
        node_features: node embedding dimension
        edge_features: edge embedding dimension
        nonlinearity: activation function to use 
        bias: whether to learn the additive bias in heap linear layers
    """
    
    def __init__(self, lattice: LatticeSystem, channels,
                 node_features: int = 5, edge_features: int = 4, 
                 nonlinearity: str = 'ReLU', bias: bool = True):
        super(AutoregressiveModel, self).__init__()
        self.lattice = lattice
        self.nodes = self.lattice.sites
        self.node_index = torch.arange(self.nodes)
        self.edge_index = self.lattice.causal_graph()
        self.self_loop_index = torch.stack([self.node_index[1:], self.node_index[1:]])
        self.node_features = node_features
        self.edge_features = edge_features
        self.node_attr = self.lattice.node_position_encoding()
#        self.node2vec = nn.Embedding(self.nodes, self.node_features,
#                                    _weight = self.lattice.node_position_encoding(self.node_features))
        self.merger = nn.Bilinear(self.node_features, self.node_features, self.edge_features)
        if isinstance(channels, int):
            self.channels = [channels, channels]
        else:
            if channels[0] != channels[-1]:
                raise ValueError('In channels {}, the first and last channel dimensions must be equal.'.format(channels))
            self.channels = channels
        self.layers = nn.ModuleList()
        for l in range(1, len(self.channels)):
            if l > 1: 
                self.layers.append(getattr(nn, nonlinearity)())
            self.layers.append(GraphConv(self.channels[l - 1], self.channels[l], self.edge_features, bias))
        dist.Distribution.__init__(self, event_shape=torch.Size([self.nodes, self.channels[0]]))
        self.has_rsample = True
    
    def extra_repr(self):
        return '(nodes): {}, (channels): {}\n(node_features): {}, (edge_features): {}'.format(self.nodes, self.channels, self.node_features, self.edge_features) + super(AutoregressiveModel, self).extra_repr()
                
    def forward(self, input):
        # prepare edge information
#        node_attr = self.node2vec(self.node_index) # get node embedding
        node_attr = self.node_attr
        edge_attr = self.merger(node_attr[self.edge_index[0], :], node_attr[self.edge_index[1], :])
        self_loop_attr = self.merger(node_attr[1:], node_attr[1:])
        output = input
        for l, layer in enumerate(self.layers): # apply layers
            if isinstance(layer, GraphConv): # for graph convolution layers
                if l == 0:
                    output = layer(output,
                                   self.edge_index,
                                   edge_attr)
                else:
                    output = layer(output,
                                   torch.cat([self.edge_index, self.self_loop_index], -1),
                                   torch.cat([edge_attr, self_loop_attr], 0))
            else: # activation layers
                output = layer(output)
#            print('forward: l {}'.format(l))
#            print(output)
        return output # logits
    
    def log_prob(self, value):
        logits = self(value) # forward pass to get logits
        return torch.sum(value * F.log_softmax(logits, dim=-1), (-2,-1))

    def sampler(self, logits, dim=-1): # simplified from F.gumbel_softmax
        gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
        gumbels += logits
        index = gumbels.max(dim, keepdim=True)[1]
        return torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)

    def _sample(self, batch_size: int, sampler = None):
        if sampler is None: # use default sampler
            sampler = self.sampler
        # create a list of tensors to cache layer-wise outputs
        cache = [torch.zeros(batch_size, self.nodes, self.channels[0])]
        for l, layer in enumerate(self.layers):
            if isinstance(layer, GraphConv): # for graph convolution layers
                channels = layer.out_channels
                cache.append(torch.zeros(batch_size, self.nodes, channels))
            else: # activation layers
                cache.append(torch.zeros(batch_size, self.nodes, channels))
        # prepare edge information
#        node_attr = self.node2vec(self.node_index) # get node embedding
        node_attr = self.node_attr
        edge_attr = self.merger(node_attr[self.edge_index[0], :], node_attr[self.edge_index[1], :])
        self_loop_attr = self.merger(node_attr[1:], node_attr[1:])
        cache[0][..., 0, :] = sampler(cache[0][..., 0, :]) # always sample node 0 uniformly
        for i in range(1, self.nodes):
            for l, layer in enumerate(self.layers):
                if isinstance(layer, GraphConv): # for graph convolution layers
                    if l==0: # first layer
                        cache[l + 1] += layer.forward_from(cache[l], i - 1,
                                                           self.edge_index,
                                                           edge_attr)
                    else: # remaining layers
                        cache[l + 1] += layer.forward_from(cache[l], i,
                                                           torch.cat([self.edge_index, self.self_loop_index], -1),
                                                           torch.cat([edge_attr, self_loop_attr], 0))
                else: # activation layers
                    cache[l + 1][..., i, :] = layer(cache[l][..., i, :])
            # the last cache hosts the logit, sample from it 
            cache[0][..., i, :] = sampler(cache[-1][..., i, :])
#            print('forward from {}'.format(i))
#            print(cache)
        return cache # cache[0] hosts the sample
    
    def sample(self, batch_size=1):
        with torch.no_grad():
            cache = self._sample(batch_size)
        return cache[0]
    
    def rsample(self, batch_size=1, tau=None, hard=False):
        if tau is None: # if temperature not given
            tau = 1/(self.features[-1]-1) # set by the out feature dimension
        cache = self._sample(batch_size, lambda x: F.gumbel_softmax(x, tau, hard))
        return cache[0]
    
class HpGCN(nn.Module, dist.TransformedDistribution):
    """ Combination of hierarchical autoregressive and flow-based model for lattice models.
    
        Args:
        energy: a energy model to learn
        hidden_features: a list of feature dimensions of hidden layers
        nonlinearity: activation function to use 
        bias: whether to learn the additive bias in heap linear layers
    """
    def __init__(self, energy: EnergyModel, hidden_channels,
                 node_features: int = 5, edge_features: int = 4, 
                 nonlinearity: str = 'ReLU', bias: bool = True):
        super(HpGCN, self).__init__()
        self.energy = energy
        self.group = energy.group
        self.lattice = energy.lattice
        self.haar = HaarTransform(self.group, self.lattice)
        self.onecat = OneHotCategoricalTransform(self.group.order)
        channels = [self.group.order] + hidden_channels + [self.group.order]
        auto = AutoregressiveModel(self.lattice, channels, node_features, edge_features, nonlinearity, bias)
        dist.TransformedDistribution.__init__(self, auto, [self.onecat, self.haar])
        self.transform = dist.ComposeTransform(self.transforms) 
    
    
    

In [128]:
G = Group(torch.tensor([[0,1],[1,0]]))
latt = LatticeSystem(2, 2)
ar = AutoregressiveModel(latt, [G.order, 3, G.order])
#print(ar.edge_index)
with torch.no_grad():
    cache = ar._sample(1)
dif = ar(cache[0]) - cache[-1]
obj = torch.norm(dif)
#obj = torch.norm(ar(cache[0]))
print('-----\n',dif,'\n-----')
#with torch.autograd.set_detect_anomaly(True):
#    obj.backward()
#for p in ar.parameters():
#    print(p.size())
#    print(p.grad)

-----
 tensor([[[0., 0.],
         [0., 0.],
         [0., 0.],
         [0., 0.]]], grad_fn=<SubBackward0>) 
-----


In [327]:
G = Group(torch.tensor([[0,1],[1,0]]))
latt = LatticeSystem(4, 2)
J = 0.2
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model = HpGCN(EnergyModel(G, latt, H), [8], edge_features=1, nonlinearity='GELU' , bias=True)
#model.base_dist.lattice.edge_type.fill_(1)
optimizer = optim.Adam(model.parameters(), lr=0.01)
batch_size = 200

In [329]:
train_loss = 0.
free_energy = 0.
echo = 100
for epoch in range(500):
    x = model.sample(batch_size)
    log_prob = model.log_prob(x)
    energy = model.energy(x)
    free = energy + log_prob.detach()
    meanfree = free.mean()
    loss = torch.sum(log_prob * (free - meanfree))
    #l1reg = sum(para.norm(1) for para in model.base_dist.node2vec.parameters())
    #tot_loss = loss + 0.001*l1reg
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    free_energy += meanfree.item()
    if (epoch+1)%echo == 0:
        print('loss: {:.4f}, free energy: {:.4f}'.format(train_loss/echo, free_energy/echo))
        train_loss = 0.
        free_energy = 0.

loss: -1.3872, free energy: -11.3800
loss: -1.2772, free energy: -11.3751
loss: -1.8401, free energy: -11.3883
loss: -0.6389, free energy: -11.3976
loss: -1.4647, free energy: -11.3956


In [320]:
J = 0.2
H = - J * (TwoBody(torch.tensor([1.,-1.]), (1,0)) 
           + TwoBody(torch.tensor([1.,-1.]), (0,1)))
model.energy.update(H)

EnergyModel(
  (lattice): LatticeSystem(4x4 grid with tree depth 5)
  (energy): EnergyTerms(
    (0): TwoBody(tensor([-0.2000,  0.2000]) across (1, 0))
    (1): TwoBody(tensor([-0.2000,  0.2000]) across (0, 1))
  )
)

In [204]:
import itertools
xs = torch.tensor(list(itertools.product([0,1],repeat=4))).view(-1,2,2)
with torch.no_grad():
    ps = model.log_prob(xs).exp()
for i, p in enumerate(ps):
    print(i, p.item())

0 0.4245363175868988
1 0.012525171972811222
2 0.01250420231372118
3 0.012461643666028976
4 0.012529365718364716
5 0.012554225511848927
6 0.0003762754495255649
7 0.012512803077697754
8 0.012512803077697754
9 0.0003762754495255649
10 0.012554225511848927
11 0.012529365718364716
12 0.012461643666028976
13 0.01250420231372118
14 0.012525171972811222
15 0.4245363175868988


In [231]:
for para in model.parameters():
    print(para)

Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0048,  0.0048,  0.0000],
        [-0.3605, -0.2768, -0.6904, -0.9490,  1.2043],
        [ 1.0732,  0.4588,  0.2340, -0.7635,  0.6221],
        [-0.7294,  0.2653,  0.0078, -0.8071,  0.0022],
        [-0.0344,  0.4974,  0.0659,  0.3853,  1.5621],
        [ 0.7111,  0.6773,  0.1137,  0.2339, -0.2674],
        [-0.5738,  0.2787,  0.5201,  0.2133,  0.6567],
        [-0.3591,  0.2673, -0.4366,  0.3973, -0.7102],
        [ 0.1938, -0.6889, -0.1219,  0.2702,  0.7793],
        [-0.3176, -0.3621, -0.4251, -0.2156,  0.3835],
        [ 0.1063, -0.0118,  0.3413,  0.5737, -0.0210],
        [ 0.0875,  0.8893,  0.7122,  0.4104,  0.3317],
        [-0.1232, -0.4791, -0.0235,  0.3320,  0.8548],
        [ 0.8407,  0.0113, -0.2151, -0.1885,  0.9369],
        [-0.2285,  0.2005, -0.1119,  0.8275,  0.3268],
        [ 0.0843,  0.7942,  0.3353,  0.6156,  0.5964]], requires_grad=True)
Parameter containing:
tensor([[[ 1.2061,  0.2099,  0.0906,  0.4469, -0.5653],

In [305]:
import jsonpickle
def export(filename, obj):
    with open('./data/' + filename + '.json', 'w') as outfile:
        outfile.write(jsonpickle.encode(obj))
nodevec = list(para.tolist() for para in model.base_dist.node2vec.parameters())
export('nodevec', nodevec)