# In[1]:
# Crafted by Collin Miller

# In[2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# In[3]:
import numpy as np
import networkx as nx
import time

# In[5]:
def img_to_graph(image):
    """Converts an image to a pixel-adjacency graph.

    The image is passed to ``feature_extraction.img_to_graph`` and the
    resulting sparse adjacency matrix is converted to a networkx graph,
    which for a 2-D image is a 4-connected grid graph.  Every pixel is a
    node, and each node is labelled with that pixel's array index, i.e.
    ``(row, col)`` for a 2-D image.  Self-loops (the matrix diagonal) are
    removed, so only edges between distinct pixels remain.  Each edge keeps
    the matrix entry as its ``'weight'`` attribute.

    NOTE(review): ``feature_extraction`` is not imported in the visible
    cells; presumably ``from sklearn import feature_extraction`` -- confirm.

    Args:
        image (numpy_array): The input image.

    Returns:
        A networkx graph whose nodes are pixel index tuples.
    """
    # Cast to a signed type; presumably so that neighbour differences taken
    # inside feature_extraction cannot wrap around for unsigned input.
    image = image.astype(np.int16)

    adjacency = feature_extraction.img_to_graph(image)

    # networkx 3.0 removed from_scipy_sparse_matrix; its replacement
    # from_scipy_sparse_array is available from networkx 2.7 onwards.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None)
    if from_sparse is None:
        from_sparse = nx.from_scipy_sparse_matrix
    graph = from_sparse(adjacency)

    # Drop self-loops (diagonal entries).  The list is built in full before
    # removal, because deleting edges while a lazy edge iterator is live
    # raises; Graph.selfloop_edges() also no longer exists in networkx >= 2.4.
    graph.remove_edges_from([(u, v) for u, v in graph.edges() if u == v])

    # The k-th node (in graph order) corresponds to the k-th pixel in
    # row-major order, so pair nodes with np.ndindex to get each pixel index.
    mapping = dict(zip(graph.nodes(), np.ndindex(*image.shape)))

    return nx.relabel_nodes(graph, mapping)


# In[ ]:


def prims_initialize(img):
    """Initializes an image for Prim's algorithm.

    Builds the pixel graph of ``img`` with ``img_to_graph`` and gives every
    node two attributes used by ``minimum_spanning_forest``:

    * ``'seed'`` -- the string ``'none'``, meaning "not yet assigned".
    * ``'path'`` -- an empty list, later filled with the nodes from the
      assigned seed to this node.

    Each node receives its own, independent list.

    Args:
        img (numpy_array): The image to be initialized.

    Returns:
        The initialized pixel graph.
    """
    graph = img_to_graph(img)

    # Keyword arguments keep these calls correct on networkx 1.x, whose
    # signature is (G, name, values), and on 2.x+, whose signature is
    # (G, values, name).  A positional call breaks on one of the two.
    nx.set_node_attributes(graph, name="seed",
                           values={node: 'none' for node in graph})
    nx.set_node_attributes(graph, name="path",
                           values={node: [] for node in graph})

    return graph

# In[6]:
def minimum_spanning_forest(graph, seeds, timed=False):
    """Computes a seeded minimum spanning forest over ``graph``, in place.

    All seeds are claimed first.  The cheapest edge on the frontier (by its
    ``'weight'`` attribute, default 1) is then taken repeatedly, and the
    unvisited node it reaches inherits the seed of the node it was reached
    from.  For a graph built from an image, the weights are the pixel
    gradients, so the result is a seeded segmentation of the image.

    Whenever an edge is pushed onto the frontier and the destination node
    has an ``'image'`` attribute, that value is copied onto the edge as
    ``'image'``.

    Args:
        graph (nx_graph): A networkx graph that has been initialized
            (nodes carry ``'seed'`` and ``'path'`` attributes).
        seeds (list): Nodes from which regions are grown; for an image
            graph these are pixel index tuples.
        timed (boolean): A flag that, if True, prints how long the minimum
            spanning forest took to compute.

    Returns:
        The same graph, modified in place.  Every node reachable from a
        seed has ``'seed'`` set to that seed and ``'path'`` set to the list
        of nodes from the seed to itself.  Nodes unreachable from every
        seed keep their previous attribute values.

    Raises:
        KeyError: If a seed is not a node of ``graph``.
    """
    from heapq import heappop, heappush

    # networkx < 2.4 exposes node data as ``graph.node``; later versions
    # removed it in favour of ``graph.nodes``.
    node_data = graph.node if hasattr(graph, 'node') else graph.nodes

    visited = set()  # O(1) membership tests; a list made the loop O(n^2)
    frontier = []

    def _push_edges(node, only_unvisited):
        """Push every edge leaving ``node`` onto the frontier heap."""
        for _, nbr in graph.edges(node):
            if only_unvisited and nbr in visited:
                continue
            edge = graph[node][nbr]
            try:
                edge['image'] = node_data[nbr]['image']
            except KeyError:
                pass  # best effort: nodes without an 'image' attribute
            heappush(frontier, (edge.get('weight', 1), node, nbr))

    if timed:
        print("Starting gradient segmentation...")
        start = time.time()

    # Claim every seed before any growth so no seed is absorbed by another.
    for seed in seeds:
        node_data[seed]['seed'] = seed
        visited.add(seed)
        node_data[seed]['path'] = [seed]
        _push_edges(seed, only_unvisited=False)

    # One drain of the frontier visits every node reachable from a seed.
    # Unreachable nodes are left untouched instead of looping forever
    # (e.g. when ``seeds`` is empty).
    while frontier:
        _, parent, node = heappop(frontier)
        if node in visited:
            continue
        node_data[node]['seed'] = node_data[parent]['seed']
        node_data[node]['path'] = node_data[parent]['path'] + [node]
        visited.add(node)
        _push_edges(node, only_unvisited=True)

    if timed:
        end = time.time()
        print("Segmentation done: %fs" % (end - start))

    return graph