Load packages.

In [1]:
import math
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import torch.optim as optim
import torch_scatter

Determine CPU or GPU device.

In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

## Graph

Graph object stores the adjacency matrix in a sparse form.
- **Depth assignment algorithm**. Let $A$ be the adjacency matrix. $u$ is an one-hot vector encoding the active vertices. Initially, $u$ is set to all one. $d$ is the depth vector, initially assigned to all zero.
  - $u'= \text{bool}(A u > 0)$ gives the target vertices under adjacency map,
  - if $\Vert u'\Vert_1=\Vert u\Vert_1$ stop, otherwise $d=d+u'$.

In [3]:
class Graph(object):
    """ Host graph information and enables graph expansion
        
        Args:
        dims: (target_dim, source_dim) number of target/source variables
        indices: index list of adjacency matrix [2, edge_num]
        edge_types: edge type list of adjacency matrix [edge_num]
        source_depths (optional): depth assignments of source variables
    """
    def __init__(self, dims:int or tuple, indices, edge_types, source_depths=None):
        if isinstance(dims, int):
            self.dims = (dims, dims)
        elif isinstance(dims, tuple):
            self.dims = dims
        self.indices = indices
        self.edge_types = edge_types
        self.max_edge_type = edge_types.max().item()
        if source_depths is None:
            self.source_depths = self.get_depth_assignment()
        else:
            self.source_depths = source_depths
        self.edge_depths = self.source_depths[self.indices[1,:]]
        self.max_depth = self.source_depths.max().item()
        
    def __repr__(self):
        return 'Graph({}, {} edges of {} types)'.format('x'.join(str(v) for v in self.dims), self.edge_types.shape[0], self.max_edge_type)
    
    def adjacency_matrix(self):
        return torch.sparse_coo_tensor(self.indices, self.edge_types, self.dims)
    
    def get_depth_assignment(self):
        assert self.dims[0] == self.dims[1], 'get_depths can only be called with square adjacency matrix.'
        dvec = torch.zeros(self.dims[0], dtype=torch.long)
        uvec = torch.ones(self.dims[0], dtype=torch.long)
        adjmat = self.adjacency_matrix()
        while True:
            uvec_new = (adjmat @ uvec > 0).long()
            if uvec_new.sum() == uvec.sum():
                break
            uvec = uvec_new
            dvec += uvec
        if uvec.sum() != 0: # there are nodes trapped in loops
            raise Warning('When assigning depth, discover the following vertices trapped in loops: {}'.format(torch.nonzero(uvec,as_tuple=True)[0].tolist()))
        return dvec
    
    def add_self_loops(self, start = 0):
        # start: the sarting node from which on the self-loop should be added
        assert self.dims[0] == self.dims[1], 'add_self_loops can only be called with square adjacency matrix.'
        loops = torch.arange(start, self.dims[0])
        indices_prepend = torch.stack([loops, loops])
        edge_types_prepend = torch.ones(loops.shape, dtype=torch.long)
        indices = torch.cat([indices_prepend, self.indices], -1)
        edge_types = torch.cat([edge_types_prepend, self.edge_types+1], -1)
        return Graph(self.dims, indices, edge_types, self.source_depths)
    
    def expand(self, target_dim, source_dim):
        # prepare views
        indices = self.indices.view(2,-1,1,1)
        edge_types = self.edge_types.view(-1,1,1)
        target_inds = torch.arange(target_dim).view(-1,1)
        source_inds = torch.arange(source_dim).view(1,-1)
        # calculate indices extension
        target_inds_ext = indices[0,...] * target_dim + target_inds
        source_inds_ext = indices[1,...] * source_dim + source_inds
        # calculate edge type extension
        edge_types_ext = ((edge_types - 1) * target_dim + target_inds) * source_dim + source_inds + 1
        # expand and flatten tensor
        target_inds_ext = target_inds_ext.expand(edge_types_ext.shape).flatten()
        source_inds_ext = source_inds_ext.expand(edge_types_ext.shape).flatten()
        edge_types_ext = edge_types_ext.flatten()
        # expand depths (to the source side)
        source_depths_ext = self.source_depths.repeat_interleave(source_dim)
        dims_ext = (self.dims[0] * target_dim, self.dims[1] * source_dim)
        indices_ext = torch.stack([target_inds_ext, source_inds_ext])
        return Graph(dims_ext, indices_ext, edge_types_ext, source_depths_ext)
    
    def sparse_matrix(self, vector, depth = None):
        if depth is None:
            indices = self.indices
            edge_types = self.edge_types
        else:
            select = self.edge_depths == depth
            indices = self.indices[:, select]
            edge_types = self.edge_types[select]
        return torch.sparse_coo_tensor(indices, vector[edge_types-1], self.dims)

Example:

In [207]:
graph = Lattice(4, 2).causal_graph()
graph.adjacency_matrix().to_dense()

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 2, 5, 1, 6, 0, 0, 4, 7, 0, 0, 0, 0, 0, 0],
        [0, 3, 5, 2, 6, 1, 0, 0, 7, 4, 0, 0, 0, 0, 0, 0],
        [0, 3, 2, 5, 0, 0, 1, 6, 0, 0, 4, 7, 0, 0, 0, 0],
        [0, 3, 5, 2, 0, 0, 6, 1, 0, 0, 7, 4, 0, 0, 0, 0]])

Depth assignment: a vector $d$, s.t. $d_i$ is the depth of the $i$th node. The depth assingment for both the target and the source sides are kept.

In [208]:
graph.source_depths

tensor([0, 0, 1, 2, 2, 3, 2, 3, 3, 4, 3, 4, 5, 5, 5, 5])

In [209]:
graph.indices

tensor([[ 2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7,  8,  8,  8,  9,  9,  9, 10,
         10, 10, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13,
         13, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15],
        [ 1,  2,  1,  2,  1,  1,  3,  2,  1,  1,  3,  2,  4,  1,  5,  1,  3,  2,
          1,  6,  7,  1,  3,  2,  1,  9,  3,  5,  4,  8,  2,  1,  9,  3,  5,  4,
          8,  2,  7,  1, 11,  6,  3, 10,  7,  2,  1, 11,  6,  3, 10]])

In [210]:
graph.edge_depths

tensor([0, 1, 0, 1, 0, 0, 2, 1, 0, 0, 2, 1, 2, 0, 3, 0, 2, 1, 0, 2, 3, 0, 2, 1,
        0, 4, 2, 3, 2, 3, 1, 0, 4, 2, 3, 2, 3, 1, 3, 0, 4, 2, 2, 3, 3, 1, 0, 4,
        2, 2, 3])

Add self loops.

In [211]:
graph_sl = graph.add_self_loops()
graph_sl.adjacency_matrix().to_dense()

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 2, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 3, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 4, 3, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 4, 0, 3, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 4, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 4, 0, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 4, 3, 6, 2, 7, 0, 0, 5, 8, 0, 0, 1, 0, 0, 0],
        [0, 4, 6, 3, 7, 2, 0, 0, 8, 5, 0, 0, 0, 1, 0, 0],
        [0, 4, 3, 6, 0, 0, 2, 7, 0, 0, 5, 8, 0, 0, 1, 0],
        [0, 4, 6, 3, 0, 0, 7, 2, 0, 0, 8, 5, 0, 0, 0, 1]])

Graph extension.

In [212]:
graph = Lattice(2, 2).causal_graph()
graph.source_depths, graph.adjacency_matrix().to_dense()

(tensor([0, 0, 1, 2]),
 tensor([[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 1, 0, 0],
         [0, 1, 2, 0]]))

In [213]:
graph_ext = graph.expand(2, 3)
graph_ext.source_depths, graph_ext.adjacency_matrix().to_dense()

(tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
 tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  1,  2,  3,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  4,  5,  6,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  1,  2,  3,  7,  8,  9,  0,  0,  0],
         [ 0,  0,  0,  4,  5,  6, 10, 11, 12,  0,  0,  0]]))

In [214]:
graph_ext = graph.add_self_loops().expand(2, 3)
graph_ext.source_depths, graph_ext.adjacency_matrix().to_dense()

(tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
 tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  1,  2,  3,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  4,  5,  6,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  7,  8,  9,  1,  2,  3,  0,  0,  0],
         [ 0,  0,  0, 10, 11, 12,  4,  5,  6,  0,  0,  0],
         [ 0,  0,  0,  7,  8,  9, 13, 14, 15,  1,  2,  3],
         [ 0,  0,  0, 10, 11, 12, 16, 17, 18,  4,  5,  6]]))

Create depth-specific weight matrix from weight vector. Data moves to GPU automatically.

In [215]:
weight_vector = (torch.randn(graph_ext.max_edge_type)*10).round().to(device)

tensor([-12.,  -6.,  -4.,   3.,  -4.,  -2.,  -5.,   3.,   0.,   2., -11.,  15.,
          3.,   2.,  -7.,   0.,  -1., -17.], device='cuda:0')

In [216]:
graph_ext.sparse_matrix(weight_vector, 0).to_dense().round()

tensor([[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0., -12.,  -6.,  -4.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   3.,  -4.,  -2.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,  -5.,   3.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   2., -11.,  15.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,  -5.,   3.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [  0.,   0.,   0.,   2., -11.,  15.,   0.,   0.,   0.,   0.,   0.,   0.]],
       device='cuda:0')

## Graph Convolution

### GraphConvLayer

In [4]:
class GraphConvLayer(nn.Module):
    """ Graph Convolution layer 
        
        Args:
        graph: graph object
        in_features: number of input features (per node)
        out_features: number of output features (per node)
        bias: whether to learn an edge-depenent bias
        self_loop: whether to include self loops in message passing
    """
    def __init__(self, graph:Graph, in_features:int, out_features:int,
                 bias:bool = True, self_loop:bool or int = True):
        super(GraphConvLayer, self).__init__()
        self.self_loop = self_loop
        if isinstance(self.self_loop, bool):
            if self.self_loop:
                self.graph = graph.add_self_loops()
            else:
                self.graph = graph
        else:
            self.graph = graph.add_self_loops(start=self.self_loop)
        self.in_features = in_features
        self.out_features = out_features
        self.weight_graph = self.graph.expand(self.out_features, self.in_features)
        self.weight_vector = nn.Parameter(torch.Tensor(self.weight_graph.max_edge_type))
        self.bias = bias
        if self.bias:
            self.bias_graph = self.graph.expand(out_features, 1)
            self.bias_vector = nn.Parameter(torch.Tensor(self.bias_graph.max_edge_type))        
        self.reset_parameters()
        (self.target_dim, self.source_dim) = self.weight_graph.dims
    
    def reset_parameters(self):
        bound = 1 / math.sqrt(self.in_features)
        nn.init.uniform_(self.weight_vector, -bound, bound)
        if self.bias:
            nn.init.uniform_(self.bias_vector, -bound, bound)
    
    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, self_loop={}\n{}'.format(
            self.in_features, self.out_features, self.bias, self.self_loop, self.graph)
    
    def forward(self, x, depth=None):
        weight_matrix = self.weight_graph.sparse_matrix(self.weight_vector, depth)
        bias_matrix = self.bias_graph.sparse_matrix(self.bias_vector, depth)
        unit = torch.ones((bias_matrix.shape[1], 1), dtype=bias_matrix.dtype, device=bias_matrix.device)
        return weight_matrix @ x + bias_matrix @ unit

Example:

In [49]:
gcl = GraphConvLayer(Lattice(2, 2).causal_graph(), 3, 2).to(device)

GraphConvLayer(
  in_features=3, out_features=2, bias=True, self_loop=True
  Graph(4x4, 7 edges of 3 types)
)

create some input. shape: [num_vertices * in_features, batch_size]

In [50]:
x = torch.randn(gcl.source_dim,5).to(device)

tensor([[-1.2756,  1.2161,  0.0269, -1.0970,  0.4550],
        [ 1.2512,  0.1701, -0.0775,  0.6105,  0.3584],
        [ 0.2962, -1.4163, -0.3514, -1.3215,  1.0920],
        [-2.3260, -1.8017,  1.3213,  0.9404,  0.8980],
        [ 0.0933,  1.0067, -0.5500, -0.0405, -0.3438],
        [-1.2374,  1.1529,  1.2877, -0.7187,  0.4455],
        [ 0.6294,  0.4120,  0.8329,  1.0898, -0.7481],
        [-1.0627, -1.3954,  0.4898,  1.3784, -0.1808],
        [ 0.1091, -0.1274,  0.8931,  1.6168,  0.3910],
        [-1.6585,  0.1301, -0.8432, -2.2592,  1.1184],
        [ 1.1861,  0.6768,  0.1555,  1.3814, -1.3598],
        [-0.2948,  2.9379, -0.8041, -0.4929,  0.9091]], device='cuda:0')

forward all at once.

In [51]:
gcl(x)

tensor([[-1.5503, -0.4548, -0.5375, -1.4977, -0.3024],
        [-0.6669,  0.3509, -0.1501, -0.5594, -0.0157],
        [-1.6627, -1.4377,  0.5147, -0.3018,  0.0803],
        [-1.0524, -0.9006,  0.3106,  0.2208,  0.1689],
        [-0.3420,  0.3581, -0.8189, -0.9046, -1.1280],
        [ 1.9121,  0.2078, -0.1099,  0.6334, -0.3435],
        [-2.2384,  0.0512, -1.5212, -2.6627,  0.4128],
        [ 1.4476,  0.6067, -0.8818, -1.0748,  1.0055]], device='cuda:0',
       grad_fn=<AddBackward0>)

forward by depth.

In [54]:
gcl(x, 1)

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.2783,  0.3193, -0.2737, -0.4814, -0.6108],
        [ 0.0641, -0.0190,  0.1403,  0.2319, -0.4773],
        [ 0.1614,  0.2088,  0.0901, -0.1014,  0.1561],
        [ 0.4022,  0.5706, -0.1531, -0.4443,  0.6398]], device='cuda:0',
       grad_fn=<AddBackward0>)

### GraphConvNet

In [55]:
class GraphConvNet(nn.Module):
    """ Graph Convolution network 
        
        Args:
        graph: graph object
        features: a list of numbers of features (per node) across layers
        bias: whether to learn an edge-depenent bias
        nonlinearity: nonlinear activation to use
    """
    def __init__(self, graph:Graph, features, bias:bool = True, nonlinearity:str = 'Tanh'):
        super(GraphConvNet, self).__init__()
        self.graph = graph
        self.features = features
        self.layers = nn.ModuleList()
        for l in range(1, len(self.features)):
            if l == 1: # the first layer should not have self loops
                self.layers.append(GraphConvLayer(self.graph, self.features[0], self.features[1], bias, self_loop=False))
            else: # remaining layers are normal
                self.layers.append(getattr(nn, nonlinearity)()) # activatioin layer
                self.layers.append(GraphConvLayer(self.graph, self.features[l - 1], self.features[l], bias, self_loop=1))
                
    def forward(self, input, depth=None, cache=None):
        # input: [..., nodes, features]
        in_shape = input.shape
        batch_dim = torch.tensor(in_shape[:-2]).prod()
        input_dim = torch.tensor(in_shape[-2:]).prod()
        x = input.view((batch_dim, input_dim)).T
        if depth is None:
            for layer in self.layers:
                x = layer(x)
        else: # depth-specific forward
            if cache is None: # if cache not exist, prepare cache
                cache = [x]
                for layer in self.layers:
                    if isinstance(layer, GraphConvLayer):
                        target_dim = layer.target_dim
                    cache.append(torch.zeros((target_dim, batch_dim), device=x.device))
            else: # if cache exist, load x to cache[0]
                cache[0] = x
            # cache is ready, start forwarding
            for l, layer in enumerate(self.layers):
                if isinstance(layer, GraphConvLayer):
                    if l == 0: # first layer should forward from the previous depth
                        cache[l+1] = cache[l+1] + layer(cache[l], depth - 1)
                    else: # remaining layer forward from the current depth
                        cache[l+1] = cache[l+1] + layer(cache[l], depth)
                else:
                    cache[l+1] = layer(cache[l])
            x = cache[-1] # last cache hosts output
        out_shape = in_shape[:-1]+(self.features[-1],)
        output = x.T.view(out_shape)
        if cache is None:
            return output
        else:
            return output, cache

Example:

In [56]:
gcn = GraphConvNet(Lattice(2, 2).causal_graph(), (2, 4, 3))

GraphConvNet(
  (layers): ModuleList(
    (0): GraphConvLayer(
      in_features=2, out_features=4, bias=True, self_loop=False
      Graph(4x4, 3 edges of 2 types)
    )
    (1): Tanh()
    (2): GraphConvLayer(
      in_features=4, out_features=3, bias=True, self_loop=1
      Graph(4x4, 6 edges of 3 types)
    )
  )
)

Foward by depth iteratively.

In [62]:
x = torch.rand(1,4,2)
cache = None
for depth in range(gcn.graph.max_depth+1):
    y, cache = gcn(x, depth, cache)
    print('---- depth: {} ----'.format(depth))
    print(y)

---- depth: 0 ----
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1460, -0.1877, -0.0758],
         [-0.0525,  0.2649, -0.2509],
         [-0.0525,  0.2649, -0.2509]]], grad_fn=<ViewBackward>)
---- depth: 1 ----
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1460, -0.1877, -0.0758],
         [ 0.2149, -0.2660, -0.6753],
         [ 0.9043,  0.2574, -0.0924]]], grad_fn=<ViewBackward>)
---- depth: 2 ----
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1460, -0.1877, -0.0758],
         [ 0.2149, -0.2660, -0.6753],
         [ 1.3022, -0.0668, -0.7082]]], grad_fn=<ViewBackward>)


Depth-wise forward and one-shot forward result in the same output (upto roundoff error).

In [59]:
x = torch.rand(3,4,2)
cache = None
for depth in range(gcn.graph.max_depth+1):
    y, cache = gcn(x, depth, cache)
gcn(x) - y

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.4901e-08, -1.4901e-08,  0.0000e+00],
         [ 0.0000e+00,  4.4703e-08, -5.9605e-08]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.4901e-08, -5.9605e-08]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.9802e-08,  0.0000e+00],
         [ 0.0000e+00,  5.9605e-08, -5.9605e-08]]], grad_fn=<SubBackward0>)

## Autoregressive Model

### Test Case

In [68]:
class Autoregressive(nn.Module, dist.Distribution):
    """ Represent a generative model that can generate samples and evaluate log probabilities.
        
        Args:
        latt: Lattice
        order: group order
        hidden_features: a list of integers specifying hidden dimensions
        radius (optional): radius used to construct causal graph 
        bias, nonlinearity (optional): to set graph convolutional network
    """
    
    def __init__(self, latt:Lattice, order:int, hidden_features, radius=1., bias:bool = True, nonlinearity:str = 'Tanh'):
        super(Autoregressive, self).__init__()
        self.latt = latt
        self.order = order
        self.graph = latt.causal_graph(radius)
        features = [order] + hidden_features + [order]
        self.gcn = GraphConvNet(self.graph, features, bias, nonlinearity)
        self.sampler = dist.OneHotCategorical
    
    def log_prob(self, sample):
        logits = self.gcn(sample) # forward pass to get logits
        return torch.sum(sample * F.log_softmax(logits, dim=-1), (-2,-1))
    
    def sample(self, sample_size: int):
        device = next(self.parameters()).device # determine device
        samples = torch.zeros(sample_size, self.latt.sites, self.order, device=device) # prepare sample container
        cache = None
        for depth in range(self.graph.max_depth + 1):
            logits, cache = self.gcn(samples, depth, cache)
            select = self.graph.source_depths == depth # select nodes of the depth
            samples[...,select,:] = self.sampler(logits=logits[...,select,:]).sample() # sample from logits
            print('--- depth: {} ---'.format(depth))
            print(logits.data)
            print(samples.data)
        print('---')
        print(self.gcn(samples).data)

In [70]:
ag = Autoregressive(Lattice(2, 2), 2, [4])
ag.sample(1)

--- depth: 0 ---
tensor([[[ 0.0000,  0.0000],
         [ 0.4735, -0.3400],
         [-0.2883,  0.1043],
         [-0.2883,  0.1043]]])
tensor([[[0., 1.],
         [1., 0.],
         [0., 0.],
         [0., 0.]]])
--- depth: 1 ---
tensor([[[ 0.0000,  0.0000],
         [ 0.4735, -0.3400],
         [ 0.1348, -0.1086],
         [ 0.3147,  0.8451]]])
tensor([[[0., 1.],
         [1., 0.],
         [1., 0.],
         [0., 0.]]])
--- depth: 2 ---
tensor([[[ 0.0000,  0.0000],
         [ 0.4735, -0.3400],
         [ 0.1348, -0.1086],
         [ 0.9257,  0.5878]]])
tensor([[[0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.]]])
---
tensor([[[ 0.0000,  0.0000],
         [ 0.4735, -0.3400],
         [ 0.1348, -0.1086],
         [ 0.9257,  0.5878]]])


### Final Version

In [74]:
class Autoregressive(nn.Module, dist.Distribution):
    """ Represent a generative model that can generate samples and evaluate log probabilities.
        
        Args:
        latt: Lattice
        order: group order
        hidden_features: a list of integers specifying hidden dimensions
        radius (optional): radius used to construct causal graph 
        bias, nonlinearity (optional): to set graph convolutional network
    """
    
    def __init__(self, latt:Lattice, order:int, hidden_features, radius=1., bias:bool = True, nonlinearity:str = 'Tanh'):
        super(Autoregressive, self).__init__()
        self.latt = latt
        self.order = order
        self.graph = latt.causal_graph(radius)
        features = [order] + hidden_features + [order]
        self.gcn = GraphConvNet(self.graph, features, bias, nonlinearity)
        self.sampler = dist.OneHotCategorical
    
    def log_prob(self, samples):
        logits = self.gcn(samples) # forward pass to get logits
        log_prob = torch.sum(samples * F.log_softmax(logits, dim=-1), (-2,-1))
        return log_prob
    
    def sample(self, sample_size: int, return_log_prob:bool = False):
        device = next(self.parameters()).device # determine device
        samples = torch.zeros(sample_size, self.latt.sites, self.order, device=device) # prepare sample container
        cache = None
        for depth in range(self.graph.max_depth + 1):
            logits, cache = self.gcn(samples, depth, cache)
            select = self.graph.source_depths == depth # select nodes of the depth
            samples[...,select,:] = self.sampler(logits=logits[...,select,:]).sample() # sample from logits
        if return_log_prob:
            log_prob = torch.sum(samples * F.log_softmax(logits, dim=-1), (-2,-1))
            return samples, log_prob
        else:
            return samples

In [77]:
ag = Autoregressive(Lattice(2, 2), 2, [4])
ag.sample(2, return_log_prob=True)

(tensor([[[0., 1.],
          [0., 1.],
          [0., 1.],
          [1., 0.]],
 
         [[1., 0.],
          [1., 0.],
          [1., 0.],
          [1., 0.]]]),
 tensor([-3.1582, -1.7350], grad_fn=<SumBackward1>))

## Node

Node object represents a single node in the lattice. Properties:
- type: `'lat'` - latent node, `'phy'` - physical node.
- ind: node index.
- center: $(x,y)$ coordinate projected to the boundary coordinate system.
- generation: generation of node in the hyperbolic structure.
- parent: point to the parent node (except for node 0).
- children: (*latent only*) [ch1, ch2] a pair of children nodes.
- site: (*physical only*) site index of the physical node.

In [7]:
class Node(object):
    """ Represent a node object, containing coordinate and relationship information.
    """
    def __init__(self, ind:int):
        self.type = None
        self.ind = ind
        self.center = None
        self.generation = None
        self.parent = None
        self.children = [None, None]
        self.site = None
        
    def __repr__(self):
        return 'Node({})'.format(self.ind)
    
    def ancestors(self):
        # ancestor = self + ancestor of parent
        if self.parent is not None:
            return [self] + self.parent.ancestors()
        else:
            return []
    
    def shadow_sites(self):
        # shadow_sites = sum of shadow_sites of children
        if self.type is 'lat':
            shd = []
            for node in self.children:
                shd += node.shadow_sites()
            return shd
        elif self.type is 'phy':
            return [self.site]
        
    def action_sites(self):
        # action_sites = shadow_sites of last child
        if self.type is 'lat':
            return self.children[-1].shadow_sites()
        elif self.type is 'phy':
            return []

Example:

In [255]:
[Node(i) for i in range(4)]

[Node(0), Node(1), Node(2), Node(3)]

## Lattice

Lattice object host a list of nodes. Nodes are classified by types: latent and physical. The physical node is associated with site index. 

In [8]:
class Lattice(object):
    """ Hosts lattice information and construct causal graph
        
        Args:
        size: number of size along one dimension (assuming square/cubical lattice)
              must be a power of 2 for binary tree construction
        dimension: dimension of the lattice
    """
    def __init__(self, size:int, dimension:int):
        assert (size & (size-1) == 0) and size != 0, "size must be a power of 2."
        assert dimension > 0, "dimension must be a positive integer."
        self.size = size
        self.dimension = dimension
        self.sites = size**dimension
        self.nodes = [Node(i) for i in range(2*self.sites)]
        self.nodes[0].type = 'lat'
        self.nodes[0].generation = 0
        self.nodes[0].children = [self.nodes[1]]
        self.nodes[1].parent = self.nodes[0]
        def partition(rng: torch.Tensor, dim: int, ind: int, gen: int):
            this_node = self.nodes[ind]
            this_node.center = rng.float().mean(-1)
            this_node.generation = gen
            if rng[dim].sum()%2 == 0:
                this_node.type = 'lat'
                mid = rng[dim].sum()//2
                rng1 = rng.clone()
                rng1[dim, 1] = mid
                rng2 = rng.clone()
                rng2[dim, 0] = mid
                ind1 = (ind-2**gen)+2*2**gen
                ind2 = (ind-2**gen)+3*2**gen
                partition(rng1, (dim + 1)%self.dimension, ind1, gen+1)
                partition(rng2, (dim + 1)%self.dimension, ind2, gen+1)
                this_node.children = [self.nodes[ind1], self.nodes[ind2]]
                self.nodes[ind1].parent = this_node
                self.nodes[ind2].parent = this_node
            else:
                this_node.type = 'phy'
                this_node.site = rng[:,0].dot(self.size**torch.arange(0,self.dimension).flip(0)).item()
        partition(torch.tensor([[0, self.size]]*self.dimension), 0, 1, 0)
        
    def __repr__(self):
        return 'Lattice({} grid)'.format('x'.join(str(self.size) for k in range(self.dimension)))
    
    def wavelet_maps(self):
        decoder_map = torch.zeros((self.sites,self.sites), dtype=torch.long)
        for node in self.nodes:
            if node.type is 'lat':
                source = node.ind
                for target in node.action_sites():
                    decoder_map[target, source] = 1
        encoder_map = torch.inverse(decoder_map.double()).round().long()
        return encoder_map, decoder_map
                    
    def relevant_nodes(self, node, radius = 1.):
        # relevant_nodes = union of ancestors of adjacent nodes within given radius
        scaled_radius = radius * self.size / 2**(node.generation/self.dimension)
        relevant_nodes = set()
        for prior_node in self.nodes[1:node.ind]:
            displacement = prior_node.center - node.center
            displacement = (displacement + self.size/2)%self.size - self.size/2
            if displacement.norm() < scaled_radius:
                relevant_nodes.update(prior_node.ancestors())
        return relevant_nodes
    
    def common_ancestor(self, node1, node2):
        # the closest common ancestor of two nodes
        common_ancestor = None
        while common_ancestor is None:
            if node1.generation == node2.generation:
                if node1 is node2:
                    common_ancestor = node1
                else:
                    node1 = node1.parent
                    node2 = node2.parent
            elif node1.generation < node2.generation:
                node2 = node2.parent
            else: # node1.generation > node2.generation
                node1 = node1.parent
        return common_ancestor
    
    def relationship(self, node1, node2):
        common_ancestor = self.common_ancestor(node1, node2)
        return (node1.generation - common_ancestor.generation, node2.generation - common_ancestor.generation)
    
    def causal_graph(self, radius = 1.):
        relations = set()
        edges = {}
        for target_node in self.nodes[1:]:
            if target_node.type is 'lat':
                for source_node in self.relevant_nodes(target_node, radius):
                    relation = self.relationship(source_node, target_node)
                    relations.add(relation)
                    edges[(target_node.ind, source_node.ind)] = relation
        relations = list(relations)
        relations.sort()
        type_map = {relation: k+1 for k, relation in enumerate(relations)}
        indices = torch.zeros((2, len(edges)), dtype = torch.long)
        edge_types = torch.zeros(len(edges), dtype = torch.long)
        for k, (edge, relation) in enumerate(edges.items()):
            indices[0, k] = edge[0]
            indices[1, k] = edge[1]
            edge_types[k] = type_map[relation]
        graph = Graph(self.sites, indices, edge_types)
        graph.type_dict = {edge_type: relation for relation, edge_type in type_map.items()}
        return graph

Example:

In [162]:
latt = Lattice(4,2)
print('ind type gen  children parent site')
for node in latt.nodes:
    print('{:3d} {:>3s} {:4d} {:>9s} {:>6s} {:>4s}'.format(node.ind, node.type, node.generation, ','.join(str(c.ind) for c in node.children if c is not None), str(node.parent.ind) if node.parent is not None else '', str(node.site) if node.site is not None else ''))

ind type gen  children parent site
  0 lat    0         1            
  1 lat    0       2,3      0     
  2 lat    1       4,6      1     
  3 lat    1       5,7      1     
  4 lat    2      8,12      2     
  5 lat    2      9,13      3     
  6 lat    2     10,14      2     
  7 lat    2     11,15      3     
  8 lat    3     16,24      4     
  9 lat    3     17,25      5     
 10 lat    3     18,26      6     
 11 lat    3     19,27      7     
 12 lat    3     20,28      4     
 13 lat    3     21,29      5     
 14 lat    3     22,30      6     
 15 lat    3     23,31      7     
 16 phy    4                8    0
 17 phy    4                9    8
 18 phy    4               10    2
 19 phy    4               11   10
 20 phy    4               12    4
 21 phy    4               13   12
 22 phy    4               14    6
 23 phy    4               15   14
 24 phy    4                8    1
 25 phy    4                9    9
 26 phy    4               10    3
 27 phy    4        

### Wavelet Transform Data

Sites on which each latent variable acts.

In [142]:
print('ind    action_sites')
for node in latt.nodes:
    if node.type is 'lat':
        print('{:3d} -> {}'.format(node.ind, node.action_sites()))

ind    action_sites
  0 -> [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
  1 -> [8, 9, 12, 13, 10, 11, 14, 15]
  2 -> [2, 3, 6, 7]
  3 -> [10, 11, 14, 15]
  4 -> [4, 5]
  5 -> [12, 13]
  6 -> [6, 7]
  7 -> [14, 15]
  8 -> [1]
  9 -> [9]
 10 -> [3]
 11 -> [11]
 12 -> [5]
 13 -> [13]
 14 -> [7]
 15 -> [15]


Encoder map $e$ and decoder map $d$ of wavelet transform, s.t.
$$\text{encode: }z_a=\prod_i x_i^{e_{a i}},\quad\text{decode: } x_i=\prod_a z_a^{d_{i a}}.$$

In [348]:
latt.wavelet_maps()

(tensor([[ 1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [-1,  0,  0,  0,  0,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0],
         [-1,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0, -1,  0,  1,  0,  0,  0,  0,  0],
         [-1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0, -1,  0,  0,  0,  1,  0,  0,  0],
         [ 0,  0, -1,  0,  0,  0,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1,  0,  0,  0,  1,  0],
         [-1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0, -1,  1,  0,  0,  0,  0,  0,  0],
         [ 0,  0, -1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0, -1,  1,  0,  0,  0,  0],
         [ 0,  0,  0,  0, -1,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0,

### Causal Graph Data

- **Relevant nodes** of a given node is defined to be the nodes in the *past light-cone* of the *vicinity* of the given node within a given relative radius. 
- Given a target node, the **causal relation** will be established from its relavant nodes to itself.
- The **edge type** can be determined by the relationship between the nodes. The **relationship** between a pair of node is given by their *relative generation* to their *closest common ancestor*.

In [146]:
print('ind relevant_nodes')
for node in latt.nodes[1:]:
    if node.type is 'lat':
        print('{:3d} {}'.format(node.ind, latt.relevant_nodes(node, 1.5)))

ind relevant_nodes
  1 set()
  2 {Node(1)}
  3 {Node(2), Node(1)}
  4 {Node(2), Node(1), Node(3)}
  5 {Node(2), Node(1), Node(4), Node(3)}
  6 {Node(2), Node(1), Node(5), Node(4), Node(3)}
  7 {Node(2), Node(1), Node(5), Node(6), Node(4), Node(3)}
  8 {Node(2), Node(1), Node(5), Node(6), Node(4), Node(3)}
  9 {Node(2), Node(1), Node(5), Node(8), Node(4), Node(3), Node(7)}
 10 {Node(2), Node(1), Node(8), Node(6), Node(4), Node(3), Node(7)}
 11 {Node(2), Node(1), Node(5), Node(10), Node(6), Node(9), Node(3), Node(7)}
 12 {Node(2), Node(1), Node(5), Node(8), Node(6), Node(4), Node(9), Node(3)}
 13 {Node(2), Node(1), Node(5), Node(8), Node(12), Node(4), Node(9), Node(3), Node(7)}
 14 {Node(2), Node(1), Node(10), Node(11), Node(12), Node(6), Node(4), Node(3), Node(7)}
 15 {Node(2), Node(1), Node(5), Node(10), Node(11), Node(14), Node(6), Node(3), Node(7), Node(13)}


Causal graph.

In [164]:
graph = latt.causal_graph()
graph.type_dict, graph.adjacency_matrix().to_dense()

({1: (0, 1), 2: (0, 2), 3: (0, 3), 4: (1, 1), 5: (1, 3), 6: (2, 3), 7: (3, 3)},
 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 1, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 3, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 3, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 3, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 3, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 3, 2, 5, 1, 6, 0, 0, 4, 7, 0, 0, 0, 0, 0, 0],
         [0, 3, 5, 2, 6, 1, 0, 0, 7, 4, 0, 0, 0, 0, 0, 0],
         [0, 3, 2, 5, 0, 0, 1, 6, 0, 0, 4, 7, 0, 0, 0, 0],
         [0, 3, 5, 2, 0, 0, 6, 1, 0

Increasing the radius will include more edges to the causal graph.

In [163]:
latt.causal_graph(1.), latt.causal_graph(1.5), latt.causal_graph(2.)

(Graph(16x16, 51 edges of 7 types),
 Graph(16x16, 85 edges of 9 types),
 Graph(16x16, 101 edges of 9 types))