In [44]:
import math

import torch
import torch.nn as nn

import model_with_new_lattice

In [45]:
class GraphConv(nn.Module):
    """Graph convolution layer with one learnable scalar weight per edge type.

    Args:
        graph: graph object describing the connectivity. It must provide
            ``remove_self_loops()`` (only used when ``self_loop`` is False),
            ``get_depth_assignment()``, ``inverse_connections()``,
            ``expand_features(in_features, out_features)`` and an
            ``edge_types`` attribute; the expanded graph must additionally
            provide ``get_max_nodes()``, ``select_connections(*depth)`` and
            ``sparse_graph()``.
        in_features: number of input features (per node)
        out_features: number of output features (per node)
        bias: whether to learn an edge-dependent bias
            (allocated, but not yet applied in ``forward`` — see TODO there)
        self_loop: whether to include self loops in message passing
    """
    def __init__(self, graph, in_features: int, out_features: int,
                 bias: bool = True, self_loop: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Keep ``self.bias`` always either None or a Parameter (never a bare
        # bool); update_graph() allocates the real tensor when requested.
        self._use_bias = bias
        self.register_parameter('bias', None)
        self.edge_types = None
        self.self_loop = self_loop
        # update_graph() also builds the per-depth forwarding graphs.
        self.update_graph(graph)

    def update_graph(self, graph):
        """Install ``graph`` as this layer's connectivity.

        New weight/bias parameters are allocated (and initialised) only when
        the number of edge types changes; the per-depth forwarding graphs are
        rebuilt on every call so they never go stale.

        Returns:
            self, to allow chaining.
        """
        if not self.self_loop:
            graph = graph.remove_self_loops()  # drop self loops when disabled
        self.depth_assignment = graph.get_depth_assignment()
        graph = graph.inverse_connections()  # transposes the graph
        self.graph = graph.expand_features(self.in_features, self.out_features)
        edge_types = graph.edge_types
        if edge_types != self.edge_types:
            # One scalar weight per edge type (assumes edge_types is an int count).
            self.weight = nn.Parameter(torch.empty(edge_types))
            if self._use_bias:
                self.bias = nn.Parameter(torch.empty(edge_types, self.out_features))
            # torch.empty returns uninitialised memory: initialise right away.
            self.reset_parameters()
        self.edge_types = edge_types
        self.conv_size = self.graph.get_max_nodes()
        self.forwarding_graphs_init()
        return self

    def forwarding_graphs_init(self):
        """Build one sub-graph of ``self.graph`` per depth-assignment entry.

        Each entry of ``self.depth_assignment`` is unpacked as the positional
        arguments of ``self.graph.select_connections``.

        Raises:
            TypeError: if the depth assignment is None.
        """
        if self.depth_assignment is None:
            # NOTE(review): the recorded run of this notebook failed with
            # "TypeError: 'NoneType' object is not iterable", most likely while
            # iterating a None depth assignment — confirm that
            # get_depth_assignment() returns its result.
            raise TypeError(
                'graph.get_depth_assignment() returned None; expected an '
                'iterable of select_connections() argument tuples')
        self.forwarding_graphs = [self.graph.select_connections(*depth)
                                  for depth in self.depth_assignment]

    def reset_parameters(self):
        """Re-initialise weight (and bias) uniformly in +/- 1/sqrt(in_features).

        ``nn.init.kaiming_uniform_`` cannot be used on the 1-D per-edge-type
        weight (fan-in needs >= 2 dims); with ``a=sqrt(5)`` it reduces to this
        same bound, which mirrors ``nn.Linear``'s default. Using
        ``in_features`` as fan-in is an assumption — TODO confirm.
        """
        bound = 1 / math.sqrt(self.in_features)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def extra_repr(self):
        """Summarise the layer's configuration for ``repr``."""
        return 'edge_types={}, in_features={}, out_features={}, bias={}, self_loop={}'.format(
            self.edge_types, self.in_features, self.out_features, self.bias is not None, self.self_loop)

    def forward(self, input, depth=None):
        """Propagate ``input`` through the (optionally depth-restricted) graph.

        Args:
            input: dense right-hand operand of the sparse convolution matrix
                (presumably shaped (conv columns, k) — TODO confirm).
            depth: index into ``self.forwarding_graphs``; None uses the
                whole graph.

        Returns:
            ``conv.mm(input)``, where ``conv`` is a sparse matrix of size
            ``self.conv_size`` whose entries are the per-edge-type weights.
        """
        graph = self.graph if depth is None else self.forwarding_graphs[depth]
        signal, edge_type = graph.sparse_graph()
        # Look up each edge's scalar weight by its edge type.
        weights = torch.gather(self.weight, 0, edge_type)
        conv = torch.sparse_coo_tensor(signal, weights, size=self.conv_size)
        # TODO: self.bias (edge_types x out_features) is allocated but not
        # applied; its mapping onto the expanded output is not defined here.
        return conv.mm(input)

In [46]:
# Smoke test: wrap a small lattice's graph in a GraphConv with 2 input and 2 output features.
test_lattice = model_with_new_lattice.Lattice(4, 2)
test_layer = GraphConv(test_lattice.graph, in_features=2, out_features=2)

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)


TypeError: 'NoneType' object is not iterable