In [1]:
import torch
import torch.nn as nn
import math
import model

In [2]:
class GraphConv(nn.Module):
    """Graph convolution layer applied over the sites of a lattice.

    The layer turns ``lattice.graph`` into a sparse operator of dense shape
    ``(sites * out_features, sites * in_features)`` and left-multiplies the
    flattened input by it. Each non-zero entry of the operator takes its value
    from a learned scalar weight selected by that entry's edge type.

    Args:
        lattice: object exposing ``sites`` (used as the number of nodes) and
            ``graph`` (the graph object handed to ``update_graph``).
        in_features: number of input features (per node).
        out_features: number of output features (per node).
        bias: whether to allocate an edge-type-dependent bias of shape
            ``(edge_types, out_features)``. NOTE(review): the bias is created
            and initialised but is not yet applied in ``forward``.
        self_loop: whether to keep self loops in message passing; when False
            they are stripped with ``graph.remove_self_loops()``.
    """

    def __init__(self, lattice, in_features: int, out_features: int,
                 bias: bool = True, self_loop: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.self_loop = self_loop
        # Reserve the ``bias`` parameter slot up front; update_graph replaces
        # it with a real nn.Parameter when a bias was requested.
        self._use_bias = bool(bias)
        self.register_parameter('bias', None)
        # None forces update_graph to allocate parameters on first call.
        self.edge_types = None
        self.lattice_sites = lattice.sites
        self.update_graph(lattice.graph)
        # Dense shape of the sparse operator assembled in forward().
        self.conv_size = (self.lattice_sites * out_features,
                          self.lattice_sites * in_features)

    def update_graph(self, graph):
        """Install ``graph`` as the message-passing graph.

        Self loops are removed when ``self_loop`` is False. The graph is then
        passed through ``expand_features(in_features, out_features)`` and
        ``inverse_connections()``, and the per-depth sub-graphs are rebuilt.
        Weight (and bias) parameters are reallocated and re-initialised only
        when the graph's ``edge_types`` count differs from the current one.

        Returns:
            self, so calls can be chained.
        """
        if not self.self_loop:
            graph = graph.remove_self_loops()
        expanded = graph.expand_features(self.in_features, self.out_features)
        self.graph = expanded.inverse_connections()
        # NOTE(review): the depth assignment is taken from the graph before
        # feature expansion / inversion — confirm this is intended.
        self.depth_assignment = graph.get_depth_assignment()
        self.forwarding_graphs_init()
        edge_types = self.graph.edge_types
        if edge_types != self.edge_types:
            # One scalar weight per edge type.
            self.weight = nn.Parameter(torch.empty(edge_types))
            if self._use_bias:
                self.bias = nn.Parameter(torch.empty(edge_types, self.out_features))
            self.reset_parameters()
        self.edge_types = edge_types
        return self

    def forwarding_graphs_init(self):
        """Build one sub-graph per entry of ``depth_assignment``."""
        self.forwarding_graphs = [self.graph.select_connections(*depth)
                                  for depth in self.depth_assignment]

    def reset_parameters(self):
        """Re-initialise weight (and bias, if present) uniformly.

        The weight bound is 1/sqrt(sites * in_features), i.e. 1/sqrt of the
        column count of the dense operator; the bias bound is
        1/sqrt(in_features).
        """
        bound = 1 / math.sqrt(self.lattice_sites * self.in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        return 'edge_types={}, in_features={}, out_features={}, bias={}, self_loop={}'.format(
            self.edge_types, self.in_features, self.out_features, self.bias is not None, self.self_loop)

    def forward(self, input, depth=None):
        """Apply the graph convolution.

        Args:
            input: tensor whose ``flatten(1)`` is ``(batch, sites * in_features)``;
                presumably shaped ``(batch, sites, in_features)`` — TODO confirm.
            depth: if None, use the full graph; otherwise use only the
                connections of ``forwarding_graphs[depth]``.

        Returns:
            Tensor of shape ``(batch, sites, out_features)``.
        """
        # Put samples in columns so the sparse operator can left-multiply.
        x = input.flatten(1).t()
        graph = self.graph if depth is None else self.forwarding_graphs[depth]
        signal, edge_type = graph.sparse_graph()
        # Value of each sparse entry = weight of its edge type.
        weights = torch.gather(self.weight, 0, edge_type)
        conv = torch.sparse_coo_tensor(signal, weights, size=self.conv_size)
        output = torch.sparse.mm(conv, x)
        # TODO(review): self.bias (shape (edge_types, out_features)) is never
        # added here; how an edge-type bias should be aggregated per node is
        # not visible in this file.
        return output.t().unflatten(1, (self.lattice_sites, self.out_features))

In [29]:
# Smoke-test the layer on a small lattice without self loops.
test_lattice = model.Lattice(4, 2)
test_layer = GraphConv(test_lattice, in_features=2, out_features=4, self_loop=False)
# Batch of 5 samples, 16 nodes, 2 features each (presumably 16 == test_lattice.sites — confirm).
test_input = torch.arange(5 * 16 * 2, dtype=torch.float32).reshape(5, 16, 2)
test_layer(test_input)

raw input tensor([[[  0.,   1.],
         [  2.,   3.],
         [  4.,   5.],
         [  6.,   7.],
         [  8.,   9.],
         [ 10.,  11.],
         [ 12.,  13.],
         [ 14.,  15.],
         [ 16.,  17.],
         [ 18.,  19.],
         [ 20.,  21.],
         [ 22.,  23.],
         [ 24.,  25.],
         [ 26.,  27.],
         [ 28.,  29.],
         [ 30.,  31.]],

        [[ 32.,  33.],
         [ 34.,  35.],
         [ 36.,  37.],
         [ 38.,  39.],
         [ 40.,  41.],
         [ 42.,  43.],
         [ 44.,  45.],
         [ 46.,  47.],
         [ 48.,  49.],
         [ 50.,  51.],
         [ 52.,  53.],
         [ 54.,  55.],
         [ 56.,  57.],
         [ 58.,  59.],
         [ 60.,  61.],
         [ 62.,  63.]],

        [[ 64.,  65.],
         [ 66.,  67.],
         [ 68.,  69.],
         [ 70.,  71.],
         [ 72.,  73.],
         [ 74.,  75.],
         [ 76.,  77.],
         [ 78.,  79.],
         [ 80.,  81.],
         [ 82.,  83.],
         [ 84.,  85.

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
         [6.0000e+00, 6.0000e+00, 6.0000e+00, 6.0000e+00],
         [1.5000e+01, 1.5000e+01, 1.5000e+01, 1.5000e+01],
         [2.8000e+01, 2.8000e+01, 2.8000e+01, 2.8000e+01],
         [4.5000e+01, 4.5000e+01, 4.5000e+01, 4.5000e+01],
         [6.6000e+01, 6.6000e+01, 6.6000e+01, 6.6000e+01],
         [9.1000e+01, 9.1000e+01, 9.1000e+01, 9.1000e+01],
         [1.2000e+02, 1.2000e+02, 1.2000e+02, 1.2000e+02],
         [1.5300e+02, 1.5300e+02, 1.5300e+02, 1.5300e+02],
         [1.9000e+02, 1.9000e+02, 1.9000e+02, 1.9000e+02],
         [2.3100e+02, 2.3100e+02, 2.3100e+02, 2.3100e+02],
         [2.7600e+02, 2.7600e+02, 2.7600e+02, 2.7600e+02],
         [3.2500e+02, 3.2500e+02, 3.2500e+02, 3.2500e+02],
         [3.7800e+02, 3.7800e+02, 3.7800e+02, 3.7800e+02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+