In [1]:
import torch
import networkx as nx
import matplotlib.pyplot as plt
import depthwise_model as model

In [2]:
class Graph(object):
    
    def __init__(self, connection_graph, encode = False):
        """Constructs a causal graph from a connection_graph structure
        
        args:
        connection_graph: a graph with connections and their connection types
        encode: a boolean value indicating wether to encode the connection types"""
        
        assert len(connection_graph[0][0]) == len(connection_graph[1]), "Size Mismatch"

        if len(connection_graph[0][0]) != 0:
            self.edge_types = len(connection_graph[1].unique())
            self.max_node = connection_graph[0].max().item()
            if encode:
                encoded_connection_types = torch.tensor(self.encode_connection_types(connection_graph[1]))
                connection_graph = (connection_graph[0], encoded_connection_types)
        self.connection_graph = connection_graph
        self.num_edges = len(connection_graph[0][0])
        
    @classmethod
    def from_python_lists(cls, in_connections, out_connections, connection_types, encode = False):
        """Constructs a causal graph from specified arguments
        
        args:
        in_connections: a list of where connections start
        out_connections: a list of where connections end
        connection_types: a list of types of connection
        encode: a boolean value indicating wether to encode the connection types"""
        
            
        obj = cls((torch.tensor([in_connections, out_connections]), torch.tensor(connection_types)), encode)
        return obj
        
    def encode_connection_types(self, connection_types):
        """encodes connection_types as consecutive numbers from start
        
        args:
        connection_types: list of the different connection types
        start: integer indicating where to start the encoding"""
        types = connection_types.unique()
        if min(types) < 0 or max(types) >= len(types):
            for encoding, connection_type in enumerate(types):
                connection_types[connection_types == connection_type] = encoding
        return connection_types
    
    def expand_features(self, in_expansion=1, out_expansion=1):
        #expansion of the in features
        out_features = self.connection_graph[0][1].repeat(in_expansion)
        in_features = self.connection_graph[0][0] * in_expansion
        in_features = in_features + torch.arange(0, in_expansion).unsqueeze(1)
        in_features = in_features.flatten()
        weights = self.connection_graph[1] * in_expansion
        weights = weights + torch.arange(0, in_expansion).unsqueeze(1)
        weights = weights.flatten()
    
        #expansion of the out features
        in_features = in_features.repeat(out_expansion)
        out_features = out_features * out_expansion
        out_features = out_features + torch.arange(0, out_expansion).unsqueeze(1)
        out_features = out_features.flatten()
        weights = weights * in_expansion
        weights = weights + torch.arange(0, out_expansion).unsqueeze(1)
        weights = weights.flatten()
        connection_graph = (torch.stack((in_features, out_features)), weights)
        return Graph(connection_graph, encode = True)
    
    def get_adj_matrix(self):
        sparse_graph = self.inverse_connections().sparse_graph()
        adj_size = self.max_node + 1
        return torch.sparse.FloatTensor(sparse_graph[0], torch.ones(self.num_edges),
                                        torch.Size([adj_size, adj_size])).to_dense().int()
    
    def get_depths(self):                                                                            
        """creates a list of the causal depths if there are no loops in the graph otherwise returns None"""
        depths = torch.zeros(self.max_node + 1, 1, dtype=torch.int32)
        depth_determiner = torch.ones(self.max_node + 1, 1, dtype=torch.int32)
        adj_matrix = self.remove_self_loops().get_adj_matrix()
        prev_parents = self.max_node + 1
        num_parents = self.max_node + 1
        while num_parents:
            depth_determiner = (adj_matrix.mm(depth_determiner) > 0).int()
            num_parents = sum(depth_determiner)
            if num_parents == prev_parents:
                return None
            prev_parents = num_parents
            depths += depth_determiner
        return depths.flatten()
    
    def get_depth_assignment(self):
        """creates a list grouping the different depths of each site"""
        depths = self.get_depths()
        if depths is None:
            return None
        depth_assignment = [[] for _ in range(max(depths) + 1)]
        for i, depth in enumerate(depths):
            depth_assignment[depth].append(i)
        return depth_assignment
    
    def get_edges(self):
        """creates a set of connections with out connection type"""
        return self.connection_graph[0]
    
    def get_max_nodes(self):
        max_row = self.connection_graph[0][0].max().item()
        max_column = self.connection_graph[0][1].max().item()
        return max_row, max_column
            
    def inverse_connections(self):
        """inverts directionality of edges in the graph"""
        inverse_connection_graph = (self.connection_graph[0].roll(1, 0) , self.connection_graph[1])
        return Graph(inverse_connection_graph)
        
    def is_self_loop(self):
        """checks if graph has any self loops"""
        return torch.all(self.connection_graph[0][0] != self.connection_graph[0][1])
    
    def remove_self_loops(self):
        """removes all self loops from the graph"""
        not_self_loops = self.connection_graph[0][0] != self.connection_graph[0][1]
        return Graph(self.take_connections(not_self_loops))
    
    def select_connections(self, *args, out = False):
        """creates a graph with specified in or out features
        
        args: connections to select
        out: boolean indicating to select by in feature or out feature
        """
        graph_selection = torch.tensor([node.item() in args for node in self.connection_graph[0][1 if out else 0]])
        
        return Graph(self.take_connections(graph_selection))
        
    def set_connection_graph(self, connection_graph):
        """updates connection_graph, num_edges, and weights based on input connection graph
        
        args:
        connection_graph: connection graph to set self.connection_graph to"""
        self.connection_graph = connection_graph
        self.num_edges = len(self.connection_graph[1])
        self.edge_types = len(torch.unique(self.connection_graph[1]))
        self.max_node = max(max(self.connection_graph[0][0]), max(self.connection_graph[0][1])).item()
        
    def sparse_graph(self):
        """creates a sparse graph for pytorch sparse tensors"""
            
        return self.connection_graph[0], self.connection_graph[1]
    
    def take_connections(self, bool_tensor):
        selected_graph = (self.connection_graph[0][:, bool_tensor], self.connection_graph[1][bool_tensor])
        return selected_graph
    
    def __repr__(self):
        """representation of connection_graph sorted by connection_type"""
        return "{0}\n{1}".format(self.connection_graph[0], self.connection_graph[1])

In [3]:
test = Graph.from_python_lists([],[],[])
cords, weights = test.sparse_graph()
torch.sparse_coo_tensor(cords, weights, size = (10,10)).to_dense()

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [4]:
lat = model.Lattice(2,2)
#graph = graph.remove_self_loops()
cords, weights = lat.graph.sparse_graph()
print(torch.sparse_coo_tensor(cords, weights).to_dense())
graph = lat.graph.expand_features(4,2)
cords, weights = graph.sparse_graph()
print(torch.sparse_coo_tensor(cords, weights).to_dense())

tensor([[1, 2, 2],
        [0, 1, 3],
        [0, 0, 1]])
tensor([[ 0,  1,  8,  9,  8,  9],
        [ 2,  3, 10, 11, 10, 11],
        [ 4,  5, 12, 13, 12, 13],
        [ 6,  7, 14, 15, 14, 15],
        [ 0,  0,  0,  1, 16, 17],
        [ 0,  0,  2,  3, 18, 19],
        [ 0,  0,  4,  5, 20, 21],
        [ 0,  0,  6,  7, 22, 23],
        [ 0,  0,  0,  0,  0,  1],
        [ 0,  0,  0,  0,  2,  3],
        [ 0,  0,  0,  0,  4,  5],
        [ 0,  0,  0,  0,  6,  7]])


In [8]:
lat = model.Lattice(4,2)
connection_graph = lat.graph.connection_graph
expand_graph = model.ExpandGraph(connection_graph, False)
expand_graph = expand_graph.expand_features(4, 2)
expand_graph = expand_graph.remove_self_loops()
print(expand_graph.get_adj_matrix())
cords, weights = expand_graph.sparse_graph()
print(torch.sparse_coo_tensor(cords, weights).to_dense())
print(expand_graph.get_depth_assignment())

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32)
tensor([[ 0,  0,  8,  ..., 25, 24, 25],
        [ 0,  0, 10,  ..., 27, 26, 27],
        [ 0,  0, 12,  ..., 29, 28, 29],
        ...,
        [ 0,  0,  0,  ...,  0, 74, 75],
        [ 0,  0,  0,  ...,  0, 76, 77],
        [ 0,  0,  0,  ...,  0, 78, 79]])
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31], [32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47], [48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59]]


In [6]:
#creates the graph with ([[0,1,2,3,0], [1,2,3,4,0]], [2,3,4,5,1])
x = [0,1,2,3,0]
y = [1,2,3,4,0]
z = [2,3,4,5,1]
test_graph = Graph.from_python_lists(x, y, z)

print("Connection Graph:\n\n", test_graph)
print("Adjacency Matrix:\n\n", test_graph.get_adj_matrix())
print("Selecting Inputs from 2:\n\n", test_graph.select_connections(2))
test_graph = test_graph.remove_self_loops()
print("Adjacency Matrix with self loops removed:\n\n", test_graph.get_adj_matrix())
test_graph = test_graph.expand_features(2, 2)
print("Adjacency Matrix with expanded features:\n\n", test_graph.get_adj_matrix())
print("Depths of each feature:\n\n", test_graph.get_depths())
depth_assignment = test_graph.get_depth_assignment()
print("Depth Assignments:\n\n", depth_assignment)
print("Selecting Inputs from depth 0:\n\n", test_graph.select_connections(*depth_assignment[0]))

Connection Graph:

 tensor([[0, 1, 2, 3, 0],
        [1, 2, 3, 4, 0]])
tensor([2, 3, 4, 5, 1])
Adjacency Matrix:

 tensor([[1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0]], dtype=torch.int32)
Selecting Inputs from 2:

 tensor([[2],
        [3]])
tensor([4])
Adjacency Matrix with self loops removed:

 tensor([[0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0]], dtype=torch.int32)
Adjacency Matrix with expanded features:

 tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0]], dtype=torch.int32)
Depths of each feature

  encoded_connection_types = torch.tensor(self.encode_connection_types(connection_graph[1]))
