# Setup

In [None]:
# Imports
# `from __future__` imports must be the first statement of a cell; placed after the
# other imports they are a SyntaxError and the whole cell fails to run.
from __future__ import division

import cv2
import copy
import itertools
import keras
import logging
import matplotlib.pyplot as plt
import math
import networkx as nx
import numpy as np
import os
import pandas as pd
import random
import re
import scipy.misc
import sklearn.feature_extraction
import tensorflow as tf
import time
import threading
import gc

from scipy import sparse
from scipy import ndimage
from heapq import heappush, heappop
from networkx_viewer import Viewer
from matplotlib import colors as mcolors
from sys import stdout

#%matplotlib inline

# Image to graph

In [None]:
def _make_edges_3d(n_x, n_y, n_z=1):
    """Returns a list of edges for a 3D image.
    Parameters
    ===========
    n_x: integer
        The size of the grid in the x direction.
    n_y: integer
        The size of the grid in the y direction.
    n_z: integer, optional
        The size of the grid in the z direction, defaults to 1
    """
    vertices = np.arange(n_x * n_y * n_z).reshape((n_x, n_y, n_z))
    edges_deep = np.vstack((vertices[:, :, :-1].ravel(),
                            vertices[:, :, 1:].ravel()))
    edges_right = np.vstack((vertices[:, :-1].ravel(),
                             vertices[:, 1:].ravel()))
    edges_down = np.vstack((vertices[:-1].ravel(), vertices[1:].ravel()))
    edges = np.hstack((edges_deep, edges_right, edges_down))
    return edges

In [None]:
def _compute_altitude_3d(edges, img):
    n_x, n_y, n_z = img.shape
    gradient = np.abs(img[edges[0] // (n_y * n_z),
                          (edges[0] % (n_y * n_z)) // n_z,
                          (edges[0] % (n_y * n_z)) % n_z] -
                          img[edges[1] // (n_y * n_z),
                          (edges[1] % (n_y * n_z)) // n_z,
                          (edges[1] % (n_y * n_z)) % n_z])
    return gradient

In [None]:
def img_to_graph(image):
    """Convert an image into a pixel graph (networkx 1.x API).

    The image is promoted with ``np.atleast_3d``; for a 2-D grayscale image this gives
    the 4-connected pixel grid. Nodes are relabelled to ``(row, col)`` and receive a
    ``'value'`` attribute with the pixel intensity. Edges carry the absolute intensity
    difference of their endpoints (from ``_compute_altitude_3d``) as matrix data.

    Args:
        image (numpy.ndarray): Image to convert (normally 2-D).

    Returns:
        networkx.Graph: The pixel graph, without self-loops.
    """
    voxels = np.atleast_3d(image)
    n_x, n_y, n_z = voxels.shape
    edges = _make_edges_3d(n_x, n_y, n_z)
    weights = _compute_altitude_3d(edges, voxels)

    intensities = voxels.ravel()
    n_voxels = intensities.size
    diagonal = np.arange(n_voxels)

    # Symmetric adjacency: each edge is stored as (i, j) and (j, i). The intensities sit
    # on the diagonal and disappear again below when self-loops are removed.
    data = np.hstack((weights, weights, intensities))
    rows = np.hstack((edges[0], edges[1], diagonal))
    cols = np.hstack((edges[1], edges[0], diagonal))
    adjacency = sparse.coo_matrix((data, (rows, cols)),
                                  (n_voxels, n_voxels),
                                  dtype=None)

    graph = nx.from_scipy_sparse_matrix(adjacency)
    graph.remove_edges_from(graph.selfloop_edges())

    graph = nx.relabel_nodes(graph, map_node(voxels))
    nx.set_node_attributes(graph, 'value', values=get_altitude_map(voxels))
    return graph

In [None]:
def get_altitude_map(img):
    """Return ``{(row, col): intensity}`` for every pixel of ``img``.

    Pixels are read as ``img[row, col][0]``, so ``img`` is expected to carry a trailing
    channel axis (as produced by ``np.atleast_3d`` in ``img_to_graph``); only the first
    channel is used.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    return {(r, c): img[r, c][0] for r in range(n_rows) for c in range(n_cols)}

In [None]:
def map_node(img):
    """Return the relabelling ``flat_index -> (row, col)`` for the pixels of ``img``.

    Flat indices are row-major (``row * n_cols + col``). For a single-channel image this
    matches the node numbering of the sparse adjacency matrix built in ``img_to_graph``.

    Raises:
        AssertionError: if ``img`` is not a ``numpy.ndarray``.
    """
    assert isinstance(img, np.ndarray), "Not an image"
    n_cols = img.shape[1]
    return {r * n_cols + c: (r, c)
            for r in range(img.shape[0])
            for c in range(n_cols)}

In [None]:
def get_image_positions(img):
    """Map each pixel ``(row, col)`` to its drawing position ``(x, y) = (col, row)``."""
    n_rows, n_cols = img.shape[0], img.shape[1]
    return {(r, c): (c, r) for r in range(n_rows) for c in range(n_cols)}

In [None]:
def view_graph(img, graph, pixel_values=False, 
               figurename='graph.jpg'):
    """Draw ``graph`` on top of ``img``, coloured by seed assignment, and save the figure.

    Nodes must carry the ``'type'`` and ``'seed'`` attributes written by ``plant_seeds``.
    Drawing conventions:

    - Each seed's region gets its own colour.
    - Unassigned nodes (``'seed' == 'none'``) are pink.
    - Seed nodes are overlaid as red squares.
    - Edges are labelled with their ``'weight'``.

    Node positions are the pixel coordinates from ``get_image_positions``; every node is
    fixed in the layout.

    Args:
        img (numpy.ndarray): Image drawn underneath the graph.
        graph (networkx.Graph): Graph with ``(row, col)`` node labels (networkx 1.x API).
        pixel_values (bool): Label nodes with their ``'value'`` attribute instead of their name.
        figurename (str): File the figure is saved to.
    """
    node_size = 500
    seed_color = 'r'
    # colors[0] is reserved for unassigned nodes; seed regions cycle through the rest.
    colors = ["pink","orange","brown","red","green",
              "orange","beige","turquoise","blue"]

    # figsize is (width, height): columns first, then rows.
    plt.figure(figsize=(img.shape[1], img.shape[0]))

    positions = get_image_positions(img)
    pos = nx.spring_layout(graph, pos=positions, fixed=graph.nodes())
    edge_labels = nx.get_edge_attributes(graph, 'weight')

    seeds = [n for n in graph if graph.node[n]['type'] == 'seed']
    unassigned_nodes = [n for n in graph if graph.node[n]['seed'] == 'none']

    # Group nodes by their seed in one pass instead of one full scan per seed.
    regions = dict((s, []) for s in seeds)
    for n in graph:
        owner = graph.node[n]['seed']
        if owner in regions:
            regions[owner].append(n)

    for x, seed in enumerate(seeds):
        nx.draw_networkx_nodes(graph, pos,
                               nodelist=regions[seed],
                               node_color=colors[(x % (len(colors) - 1)) + 1],
                               node_size=node_size,
                               alpha=0.8)

    # Seeds are drawn once, on top of all regions (they used to be redrawn for every seed).
    if seeds:
        nx.draw_networkx_nodes(graph, pos,
                               nodelist=seeds,
                               node_color=seed_color,
                               node_size=node_size,
                               alpha=0.8, node_shape='s')

    nx.draw_networkx_nodes(graph, pos,
                           nodelist=unassigned_nodes,
                           node_color=colors[0],
                           node_size=node_size,
                           alpha=0.8)

    if pixel_values:
        values = nx.get_node_attributes(graph, 'value')
        nx.draw_networkx_labels(graph, pos, labels=values)
    else:
        nx.draw_networkx_labels(graph, pos)

    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels)
    nx.draw_networkx_edges(graph, pos, width=5, edge_color='b')

    # Save once, after the image has been drawn underneath the graph.
    plt.imshow(img, cmap='gray')
    plt.savefig(figurename, dpi=100)

In [None]:
def plant_seeds(graph, seeds):
    """Return a copy of ``graph`` annotated with seed bookkeeping attributes.

    Every node of the copy receives:

    - ``'type'``: ``"seed"`` if the node is in ``seeds``, otherwise ``"node"``.
    - ``'seed'``: the node itself for seeds, otherwise the string ``'none'``.
    - ``'path'``: a fresh empty list.

    The input graph is not modified (networkx 1.x ``set_node_attributes`` signature).
    """
    seeded = graph.copy()
    all_nodes = seeded.nodes()

    is_seed = dict((n, n in seeds) for n in all_nodes)
    nx.set_node_attributes(seeded, "type",
                           dict((n, "seed" if is_seed[n] else "node") for n in all_nodes))
    nx.set_node_attributes(seeded, "seed",
                           dict((n, n if is_seed[n] else 'none') for n in all_nodes))
    # Each node gets its own, independent empty list.
    nx.set_node_attributes(seeded, "path", dict((n, []) for n in all_nodes))

    return seeded

In [None]:
def generate_seeds(graph, n_seeds):
    """Pick ``n_seeds`` nodes of ``graph`` uniformly at random, with replacement.

    Uses the global ``random`` module state, so seed it for reproducibility. Because
    sampling is with replacement, the result may contain the same node more than once.
    """
    return [random.choice(graph.nodes()) for _ in range(n_seeds)]

In [None]:
def get_graph_seed_assignments(graph, seeds=None):
    """Map every assigned node of ``graph`` to the index of its seed.

    Args:
        graph: Graph whose nodes carry a ``'seed'`` attribute (and a ``'type'``
            attribute when the seeds have to be derived from the graph).
        seeds (list, optional): Ordered seeds that define the indices. When omitted,
            the notebook-level ``seeds`` list is used, as callers of the one-argument
            form expect. If no such global exists, the graph's ``'seed'``-type nodes
            are used in graph iteration order.

    Returns:
        dict: ``node -> seed index``. Nodes whose seed is not in ``seeds`` (e.g.
        ``'none'``) are omitted. If a seed occurs more than once in ``seeds``, its
        last index wins.
    """
    if seeds is None:
        # Backward compatibility: the original function read this global implicitly.
        seeds = globals().get('seeds')
    if seeds is None:
        seeds = [n for n in graph if graph.node[n]['type'] == 'seed']

    # One pass over the graph instead of one full scan per seed.
    index_of = {}
    for index, seed in enumerate(seeds):
        index_of[seed] = index  # later duplicates overwrite earlier ones

    assignments = dict()
    for n in graph:
        owner = graph.node[n]['seed']
        if owner in index_of:
            assignments[n] = index_of[owner]
    return assignments

In [None]:
def get_spaced_colors(n):
    """Return RGB triples sampled at even steps of the packed 24-bit colour range.

    The packed integers ``0, step, 2*step, ...`` below ``255**3`` (with
    ``step = int(255**3 / n)``) are split into their high, middle and low bytes. This
    gives at least ``n`` colours for positive ``n``, and the first one is always black.

    Raises:
        ZeroDivisionError: if ``n == 0``.
    """
    limit = 16581375  # 255**3
    step = int(limit / n)
    return np.array([((code >> 16) & 0xFF, (code >> 8) & 0xFF, code & 0xFF)
                     for code in range(0, limit, step)])

In [None]:
def projections(graph, seed, image):
    '''Return a mask highlighting one seed's region against all other seed regions.

    Pixels assigned (via ``get_graph_seed_assignments``) to the seed with index ``seed``
    are set to ``[0, 255, 0]``. Pixels of every other seed are set to ``[0, 0, 255]``.
    Unassigned pixels stay 0. Nothing is written to disk; the caller saves the mask.

    Args:
        graph: Pixel graph whose nodes are ``(row, col)`` tuples with a ``'seed'`` attribute.
        seed (int): Index of the seed to highlight.
        image (numpy.ndarray): Image whose first two dimensions set the mask size.

    Returns:
        numpy.ndarray: ``(rows, cols, 3)`` float mask.
    '''
    assignments = get_graph_seed_assignments(graph)
    mask = np.zeros((image.shape[0], image.shape[1], 3))

    # Write each pixel by its coordinates. Reshaping a pd.Series built from the dict
    # silently depended on the keys coming out in row-major order and on every pixel
    # being assigned.
    for (row, col), label in assignments.items():
        mask[row, col] = [0, 255, 0] if label == seed else [0, 0, 255]
    return mask

In [None]:
def prims_msf(G):
    """Grow a minimum spanning forest from every seed of ``G`` at once (Prim's algorithm).

    All seeds (nodes whose ``'type'`` is ``'seed'``) start in the forest. The lightest edge
    leaving the forest (``'weight'``, default 1) is added repeatedly, and each newly reached
    node inherits the ``'seed'`` attribute of the node it was reached from. Ties are broken
    by comparing the ``(weight, u, v)`` tuples.

    Side effects:
        Overwrites ``G.node[v]['seed']`` for every non-seed node that is reached.

    Args:
        G (networkx.Graph): Graph prepared by ``plant_seeds`` (networkx 1.x API).

    Returns:
        networkx.Graph: The minimum spanning forest, built with ``G``'s node and edge
        attribute dictionaries.

    Raises:
        ValueError: if some node cannot be reached from any seed. This includes a
            non-empty graph without seeds, which used to loop forever.
    """
    MSF = nx.Graph()
    seeds = [n for n, data in G.nodes(data=True) if data['type'] == 'seed']

    visited = set()  # O(1) membership; a list made the whole run O(n^2)
    frontier = []    # heap of (weight, u, v)

    for seed in seeds:
        visited.add(seed)
        MSF.add_node(seed, attr_dict=G.node[seed])
        # Seed edges are pushed unconditionally; any edge whose target is already
        # visited is discarded when it is popped.
        for u, v in G.edges(seed):
            heappush(frontier, (G[u][v].get('weight', 1), u, v))

    while frontier:
        _, u, v = heappop(frontier)
        if v in visited:
            continue

        # Assign the node to the seed that reached it.
        G.node[v]['seed'] = G.node[u]['seed']

        # Add node and edge to MSF.
        MSF.add_node(v, attr_dict=G.node[v])
        MSF.add_edge(u, v, attr_dict=G.edge[u][v])

        visited.add(v)
        for v, w in G.edges(v):
            if w not in visited:
                heappush(frontier, (G[v][w].get('weight', 1), v, w))

    unreached = G.number_of_nodes() - len(visited)
    if unreached:
        raise ValueError("prims_msf: {} node(s) are not reachable from any seed."
                         .format(unreached))
    return MSF

In [None]:
def field_of_view(master_image, center_pixel, crop_size):
    """Returns a ``crop_size`` window of a padded image centred on an original pixel.

    ``master_image`` is expected to be the original image padded by ``crop_size[0]`` rows
    and ``crop_size[1]`` columns on every side (see ``Watershed._prepare_input_images``).
    ``center_pixel`` is given in *unpadded* coordinates and is shifted by ``crop_size`` to
    locate it in the padded image. For odd crop sizes the pixel lands exactly in the middle
    of the window.

    Args:
        master_image (numpy.ndarray): The padded image; 2-D or with trailing channel axes.
        center_pixel (tuple): ``(row, col)`` of the pixel in the unpadded image.
        crop_size (tuple): ``(rows, cols)`` of the returned window.

    Returns:
        numpy.ndarray: A view of ``master_image`` of shape
        ``crop_size + master_image.shape[2:]``. The window is shorter if it runs past the
        edge of an insufficiently padded image.

    Examples:
        padded = np.pad(image, ((15, 15), (15, 15), (0, 0)), 'reflect')
        fov = field_of_view(padded, (0, 0), (15, 15))"""
    # Top-left corner: shift into padded coordinates, then back by half a window.
    row = center_pixel[0] + crop_size[0] - crop_size[0] // 2
    col = center_pixel[1] + crop_size[1] - crop_size[1] // 2

    # `...` keeps any trailing channel axes and also accepts plain 2-D images.
    return master_image[row:row + crop_size[0], col:col + crop_size[1], ...]

In [None]:
def assignment_mask(image, graph):
    """Return an RGB mask in which every pixel is coloured by the seed it is assigned to.

    Seed indices come from ``get_graph_seed_assignments``; colours come from
    ``get_spaced_colors``. Unassigned pixels stay 0.

    Args:
        image (numpy.ndarray): Image whose first two dimensions set the mask size.
        graph: Pixel graph whose nodes are ``(row, col)`` tuples with a ``'seed'`` attribute.

    Returns:
        numpy.ndarray: ``(rows, cols, 3)`` float mask (all zeros if nothing is assigned).
    """
    assignments = get_graph_seed_assignments(graph)
    mask = np.zeros((image.shape[0], image.shape[1], 3))
    if not assignments:
        return mask

    # One colour per seed index. The palette is sized by the largest index, so gaps in
    # the numbering cannot leave regions unpainted; for contiguous indices this is the
    # same palette as before.
    seed_colors = get_spaced_colors(max(assignments.values()) + 1)

    # Write each pixel by its coordinates. Reshaping a pd.Series built from the dict
    # silently depended on the keys coming out in row-major order and on every pixel
    # being assigned.
    for (row, col), label in assignments.items():
        mask[row, col] = seed_colors[label]
    return mask

In [None]:
def max_arc_topographical_distance(path, graph):
    """Returns the max-arc topographical distance of a path in a graph.

    Only the consecutive steps ``(path[i], path[i + 1])`` are considered. An induced
    subgraph would also pick up edges between non-consecutive path nodes in graphs
    that are not trees.

    Args:
        path (list): A list of nodes defining the path traversed in graph.
        graph (networkx.classes.graph.Graph): The graph containing the path; edges
            must carry a ``'weight'`` attribute (accessed as ``graph[u][v]``).

    Returns:
        tuple: ``(distance, edge)``. ``distance`` is the largest ``'weight'`` along the
        path and ``edge`` is the first step ``(u, v)`` attaining it. For paths with fewer
        than two nodes, returns ``(-inf, None)``."""
    nodes = list(path)

    distance = -np.inf
    edge = None
    for u, v in zip(nodes, nodes[1:]):
        altitude = graph[u][v]['weight']
        if altitude > distance:
            distance = altitude
            edge = (u, v)

    return distance, edge

In [None]:
def get_boundary_probabilities(img):
    """Return placeholder boundary probabilities for ``img``.

    Currently this is uniform random noise in [0, 1), one value per pixel (shape
    ``(img.shape[0], img.shape[1])``), drawn from NumPy's global RNG. A real boundary
    model is meant to replace it in a later iteration of the pipeline.
    """
    n_rows = img.shape[0]
    n_cols = img.shape[1]
    return np.random.rand(n_rows, n_cols)

In [None]:
def _AsList(x):
    return x if isinstance(x, (list, tuple)) else [x]

In [None]:
def calculate_altitudes(edges):
    """
    Predicts the altitudes of a batch of edges with the static network ``f_static``.

    Relies on the notebook-level TensorFlow objects ``sess``, ``f_static`` and
    ``image_placeholder``, and runs the network in inference mode (learning phase 0).

    Args:
        edges: A `Dictionary` whose keys are the edges that need their altitudes calculated
        and whose values are the input images used to calculate them.

    Returns:
        A new dictionary with the same keys, where each value is the first output unit of
        ``f_static`` for that edge's image. The input dictionary is not modified. An empty
        input returns an empty dictionary without running the network.
    """
    if not edges:
        return {}

    # Fix the key order once so that images and predictions stay aligned.
    keys = list(edges)
    # np.stack needs a real sequence; newer NumPy rejects a dict view.
    batch = np.stack([edges[key] for key in keys])

    with sess.as_default():
        altitudes = sess.run(f_static, feed_dict={image_placeholder: batch,
                                                  keras.backend.learning_phase(): 0})

    return dict((key, altitudes[i][0]) for i, key in enumerate(keys))

In [None]:
def prims_on_demand(G, image_dict):
    """
    Creates a minimum spanning forest from the seeds of ``G``, predicting altitudes lazily.

    The algorithm is Prim's, as in ``prims_msf``, except that edge weights are predicted
    when needed. When a node enters the forest, its incident edges are scored in one batch
    by ``calculate_altitudes``: every incident edge for a seed, or the edges to
    not-yet-visited neighbours otherwise. The network input for an edge is the image of its
    *target* node, ``image_dict[(row, col)]``.

    NOTE(review): because the input depends only on the target node, every edge into the
    same node presumably gets the same altitude. Confirm that this is intended.

    Side effects on ``G``:
        - ``G.edge[u][v]`` gets ``'f_static_image'`` and ``'f_dynamic_image'`` (the input
          image) and ``'weight'`` (the predicted altitude) for every scored edge.
        - ``G.node[v]['seed']`` is set to the seed that reached ``v``.
        - ``G.node[v]['path']`` is set to the node list from that seed to ``v``.

    Args:
        G (Graph): Graph prepared by ``plant_seeds`` (networkx 1.x API).
        image_dict (Dictionary): Maps each node ``(row, col)`` to its network input image.

    Returns:
        MSF (Graph): Minimum spanning forest.

    Raises:
        ValueError: if some node cannot be reached from any seed. This includes a
            non-empty graph without seeds, which used to loop forever.
    """
    MSF = nx.Graph()
    seeds = [n for n, data in G.nodes(data=True) if data['type'] == 'seed']

    visited = set()  # O(1) membership; a list made the whole run O(n^2)
    frontier = []    # heap of (altitude, u, v)

    def _collect(node, skip_visited):
        # Gather the edges leaving `node` with the target node's image, and store that
        # image on the edge.
        pending = {}
        for a, b in G.edges(node):
            if skip_visited and b in visited:
                continue
            fov = image_dict[(b[0], b[1])]
            pending[(a, b)] = fov
            G.edge[a][b]['f_static_image'] = fov
            G.edge[a][b]['f_dynamic_image'] = fov
        return pending

    def _score_and_push(pending):
        # Predict the altitudes of `pending` in one batch, store them and grow the frontier.
        if not pending:
            return
        for (a, b), altitude in calculate_altitudes(pending).items():
            G.edge[a][b]['weight'] = altitude
            heappush(frontier, (altitude, a, b))

    for seed in seeds:
        visited.add(seed)
        MSF.add_node(seed, attr_dict=G.node[seed])
        G.node[seed]['path'] = [seed]
        # Seeds score every incident edge, including edges to seeds visited earlier.
        _score_and_push(_collect(seed, skip_visited=False))

    while frontier:
        _, u, v = heappop(frontier)
        if v in visited:
            continue

        # Assign the node to the seed that reached it.
        G.node[v]['seed'] = G.node[u]['seed']

        # Add node and edge to MSF.
        MSF.add_node(v, attr_dict=G.node[v])
        MSF.add_edge(u, v, attr_dict=G.edge[u][v])

        # Store path.
        G.node[v]['path'] = G.node[u]['path'] + [v]

        visited.add(v)
        _score_and_push(_collect(v, skip_visited=True))

    unreached = G.number_of_nodes() - len(visited)
    if unreached:
        raise ValueError("prims_on_demand: {} node(s) are not reachable from any seed."
                         .format(unreached))
    return MSF

In [None]:
def epoch_data(img_graph):
    """
    Returns arrays containing edge calculation images and edge error weights.

    Only edges carrying both an ``'error_weight'`` and an ``'f_static_image'`` attribute
    are used. The two arrays are aligned: if ``images[k]`` was used to compute an edge,
    ``weights[k]`` is the error weight of that same edge.

    Args:
        img_graph: (graph) Graph whose edges may carry the two attributes.

    Returns:
        images (ndarray): Stacked ``f_static_image`` arrays of the selected edges.
        weights (ndarray): The matching ``error_weight`` values.
    """
    images = list()
    weights = list()

    for _, _, attrs in img_graph.edges(data=True):
        # Check both keys before appending. Appending the weight and then failing on the
        # image would leave the two lists misaligned.
        if 'error_weight' not in attrs or 'f_static_image' not in attrs:
            continue
        weights.append(attrs['error_weight'])
        images.append(attrs['f_static_image'])

    return np.array(images), np.array(weights)

In [None]:
def manually_check_loss(images, weights):
    """
    Returns ``np.dot(weights, altitudes)``, where ``altitudes`` are the predictions of
    ``f_static`` for ``images`` with the network in inference mode (learning phase 0).

    Relies on the notebook-level ``sess``, ``f_static`` and ``image_placeholder``.
    """
    with sess.as_default():
        feed = {image_placeholder: images,
                keras.backend.learning_phase(): 0}
        predicted = sess.run(f_static, feed_dict=feed)
        return np.dot(weights, predicted)

In [None]:
def create_batches(x, y, max_batch_size=32):
    """
    Splits ``x`` and ``y`` into aligned batches of at most ``max_batch_size`` elements.

    The number of batches is ``ceil(len(x) / max_batch_size)``. It is computed in integer
    arithmetic, so it does not depend on ``from __future__ import division``.
    ``np.array_split`` then makes the batch sizes differ by at most one.

    Args:
        x: A numpy array of the input data
        y: A numpy array of the output (same length as ``x``)
        max_batch_size: The maximum elements in each batch.

    Returns: A list of ``(x_batch, y_batch)`` tuples; empty when ``x`` is empty.

    """
    n_items = x.shape[0]
    if n_items == 0:
        return []

    n_batches = -(-n_items // max_batch_size)  # integer ceiling division
    return list(zip(np.array_split(x, n_batches), np.array_split(y, n_batches)))

In [None]:
def find_deviation(ground_truth_path, shortest_path):
    """
    Finds the edge where the ground truth path deviates from the shortest path.

    Both paths are expected to start at the same node (the same seed).

    Args:
        ground_truth_path (list): The list of nodes in the ground truth path.
        shortest_path (list): The list of nodes in the shortest path.

    Returns:
        tuple: The ground truth edge ``(previous_node, first_differing_node)``.

    Raises:
        ValueError: if the paths agree on all compared positions, or if they already
            differ at their first node (there is then no shared node to deviate from).
    """
    for i, (ground_truth_node, shortest_path_node) in enumerate(zip(ground_truth_path, shortest_path)):
        if shortest_path_node != ground_truth_node:
            if i == 0:
                # ground_truth_path[i - 1] would silently wrap around to the last node.
                raise ValueError('Paths start at different nodes.')
            return (ground_truth_path[i - 1], ground_truth_path[i])

    raise ValueError('No deviation.')

In [None]:
def find_first_false_cut(ground_truth_path, ground_truth_cuts, cut_edges):
    """
    Finds the first false cut edge of a ground truth path.

    An edge is a false cut when the current segmentation cuts it but the ground truth does
    not. Edges are matched in either orientation.

    Args:
        ground_truth_path (list): A list of nodes representing the path from the seed to the node.
        ground_truth_cuts (list): A list of ground truth cut edges.
        cut_edges (list): A list of cut edges from the minimum spanning forest.

    Returns:
        tuple: The first edge ``(u, v)`` of the ground truth path that is in the list of cut
        edges but not in the list of ground truth cuts, or ``None`` if there is none.
    """
    cuts = set(cut_edges)
    true_cuts = set(ground_truth_cuts)

    # Walk the consecutive steps of the path; indexing i + 1 used to hit the end of the
    # path and print a spurious error message.
    for u, v in zip(ground_truth_path, ground_truth_path[1:]):
        is_cut = (u, v) in cuts or (v, u) in cuts
        # A ground truth cut in either orientation is not a false cut. The old test
        # used `or`, which flagged almost every cut edge as false.
        is_true_cut = (u, v) in true_cuts or (v, u) in true_cuts
        if is_cut and not is_true_cut:
            return (u, v)
    return None


In [None]:
# Serialises updates to a shared `edge_error_weights` dictionary when find_root_edge
# runs in several threads at once.
_edge_error_weights_lock = threading.Lock()


def _add_edge_error_weight(edge_error_weights, edge, delta):
    """Add ``delta`` to ``edge_error_weights[edge]`` (starting from 0); ignore ``edge=None``."""
    # A None "edge" means no root edge was found. It is not an edge, and as a key it
    # would later break nx.set_edge_attributes.
    if edge is None:
        return
    with _edge_error_weights_lock:
        edge_error_weights[edge] = edge_error_weights.get(edge, 0) + delta


def find_root_edge(node, msf, constrained_msf, cut_edges, ground_truth_cuts, edge_error_weights):
    """
    Finds the root error edges for a node and records them in ``edge_error_weights``.

    The node is correct when its path from its seed in ``msf`` equals its path in
    ``constrained_msf``. For an incorrect node:

    - The root missing cut on the shortest path gets -1.
    - The root false cut gets +1. This is the first false cut on the ground-truth path
      when the assigned seed is wrong, and otherwise the edge where the two paths diverge.

    Args:
        node (tuple): The node to find the root error edges for.
        msf (Graph): The MSF used to find the shortest path.
        constrained_msf (Graph): The constrained MSF used to find the ground truth path.
        cut_edges (list): A list of tuples representing the cuts for the graph.
        ground_truth_cuts (list): A list of tuples representing the ground truth cuts.
        edge_error_weights (dictionary): The dictionary that holds all of the weights for
        the root error edges; updated in place.
    """
    # Get assigned seeds for MSF and constrained MSF.
    assigned_seed = msf.node[node]['seed']
    ground_truth_seed = constrained_msf.node[node]['seed']

    # Get the path from seed to node (each forest is a tree, so the path is unique).
    shortest_path = next(nx.all_simple_paths(msf, assigned_seed, node))
    ground_truth_path = next(nx.all_simple_paths(constrained_msf, ground_truth_seed, node))

    if shortest_path == ground_truth_path:
        return  # The node is correct.

    # Root missing cut on the shortest path (p(w)).
    root_missing_cut_edge = find_missing_cut(shortest_path, ground_truth_cuts, cut_edges)
    _add_edge_error_weight(edge_error_weights, root_missing_cut_edge, -1)

    # Root false cut.
    if assigned_seed != ground_truth_seed:
        root_false_cut_edge = find_first_false_cut(ground_truth_path,
                                                   ground_truth_cuts,
                                                   cut_edges)
    else:
        root_false_cut_edge = find_deviation(ground_truth_path, shortest_path)
    _add_edge_error_weight(edge_error_weights, root_false_cut_edge, +1)

In [None]:
def compute_root_error_edge_children(msf, constrained_msf, cut_edges, ground_truth_cuts):
    """
    Computes the root error edges used for a single training epoch of the system.

    This function will prepare the weight function and the altitude prediction used for the loss.
    The approach taken here is for every node in the graph, check if the node satisfies a failure 
    condition. If so, then add or subtract to the root error edge children.           

    By construction of the MSF, the shortest path and the ground truth path are equal
    for all nodes.  Conversely, they differ for incorrect nodes, causing the gound truth
    path distance to exceed the shortest path distance.

    TODO:
        Write function to fetch root error false cuts.
    """
    
    # Initialize edge error weights dictionary.
    edge_error_weights = dict()
    
    # Create a list of nodes and iterate through them.
    nodes = list((n for n in msf if msf.node[n]['type']=='node'))
    num_nodes = len(nodes)

    # Here multithreading is used to speed up root error edge computation.  Each thread 
    # computes the root error edges for a node.
    threads = []
    for i, node in enumerate(nodes):
        
        #stdout.write("\rChecking Node {}/{}".format(i, len(nodes)))
        #stdout.flush()
        
        thread = threading.Thread(target=find_root_edge, args=[node, msf, constrained_msf,
                                                              cut_edges, ground_truth_cuts,
                                                              edge_error_weights])
        threads.append(thread)
        thread.start()
    
    # Join threads
    [thread.join() for thread in threads]
    
    # Compute correct nodes and accuracy.
    correct_nodes = num_nodes - sum(map(abs, edge_error_weights.values()))
                
    accuracy = correct_nodes / num_nodes
    print "\nCorrect nodes: {}/{}".format(correct_nodes, num_nodes) 
    print "Accuracy: ", accuracy

    return edge_error_weights

In [None]:
def view_boundary_line(image, graph, figurename="graph.jpg"):
    """
    Draws the seeds and the nodes on both sides of every region boundary over ``image``.

    A boundary edge is an edge whose endpoints are assigned to different seeds. The subgraph
    induced by the seeds and all boundary-edge endpoints is drawn with ``view_graph``.

    Args:
        image: The image drawn underneath the graph.
        graph: The graph to use to compute the boundary lines.
        figurename: The file name of the saved figure (relative to the working directory).
    """
    boundary_nodes = set(n for n in graph if graph.node[n]['type'] == 'seed')

    for u, v in graph.edges_iter():
        # Compare seeds by value. Identity (`is not`) only worked while every assignment
        # happened to share the seed's tuple object.
        if graph.node[u]['seed'] != graph.node[v]['seed']:
            boundary_nodes.update((u, v))

    sub_graph = graph.subgraph(boundary_nodes)
    view_graph(image, sub_graph, figurename=figurename)

In [None]:
def find_missing_cut(shortest_path, ground_truth_cuts, cut_edges):
    """
    Computes the root error missing cut of a shortest path.

    Every incorrect shortest path has at least one erroneous cut edge. The first such
    edge is called the path's root error edge p(w) and is always a missing cut: an edge
    the ground truth cuts but the current segmentation does not. Edges are matched in
    either orientation.

    Args:
        shortest_path (list): The list of nodes in the shortest path.
        ground_truth_cuts (list): The list of ground truth cuts for the ground truth segmentation.
        cut_edges (list): The list of cut edges from the current segmentation.

    Returns:
        tuple: The first missing cut ``(u, v)`` on the shortest path, or ``None`` if there is none.
    """
    true_cuts = set(ground_truth_cuts)
    cuts = set(cut_edges)

    # Walk the consecutive steps of the path; indexing i + 1 used to hit the end of the
    # path and print a spurious error message.
    for u, v in zip(shortest_path, shortest_path[1:]):
        is_true_cut = (u, v) in true_cuts or (v, u) in true_cuts
        is_cut = (u, v) in cuts or (v, u) in cuts
        if is_true_cut and not is_cut:
            return (u, v)
    return None

In [None]:
class Watershed:
    
    def __init__(self, window_size=(32,32)):
        self.window_size = window_size
    
    def fit(self, images, ground_truth_images, epochs=16, batch_size=32, verbose=False):
        """
        Fits the model given an image or set of images.
        
        Args:
            images: A `numpy.ndarray` or list of arrays to be used to fit the model.
            ground_truth_images: A `numpy.ndarray` or list of arrays to be used to fit the model.
        """
        
        self.verbose=verbose
        self.epochs = epochs
        self.images = _AsList(images)
        self.ground_truth_images = _AsList(ground_truth_images)
        
        for image, ground_truth_image in zip(self.images, self.ground_truth_images):
            self.image = image
            self.ground_truth_image = ground_truth_image
            self._train_single()
        
    def _train_single(self):
        """
        Trains the model on a single image. Training is composed of two steps. First the image is
        segmented. Second, the model is updated. Before the model is trained, the ground truth
        cut edges are calculated.
        """
        
        print "Computing ground truth cuts."
        
        start = time.time()
        self.ground_truth_cuts = self._compute_ground_truth_cuts(self.ground_truth_image,
                                                                 self.seeds)
        end = time.time()
        
        print ("\nTime: %f" % (end - start))
        
        loss_timeline = list()
        
        for i in xrange(self.epochs):
            self.current_epoch = i + 1
            print "\n==========="
            print "Epoch ", self.current_epoch
            print "==========="
            
            print "Computing the MSF"
            start = time.time()
            segmentation = self.segment(self.image)
            end = time.time()
            print "Done in {} seconds".format(end - start)
            
            if self.verbose:
                filename = "training_images/epoch_{}.jpg".format(self.current_epoch)
                cv2.imwrite(filename, segmentation)
            
            loss_val = self._training_epoch()
            loss_timeline.append(loss_val)
            plt.plot(loss_timeline)
            plt.savefig("training_images/Loss.jpg")
            
            print "Loss: ", loss_val
        

    def segment(self, image):
        """
        Segments image.
        
        TODO:
            Create and return segmented image.
        """
        
        self._prepare_input_images()
        
        # Translate image to 4 connected grid graph.
        self.image_graph = img_to_graph(image)
            
        # Plant seeds.
        self.image_graph = plant_seeds(self.image_graph, self.seeds)
        
        # Compute image MSF.
        self.image_msf =  prims_on_demand(self.image_graph, self.input_images)

        # Compute cut edges
        self.cut_edges = []
        
        [self.cut_edges.append(e) if self.image_msf.node[e[0]]['seed']\
        is not self.image_msf.node[e[1]]['seed'] else '' for e in\
        self.image_graph.edges_iter()]
            
        return assignment_mask(image, self.image_graph)
        
    def _training_epoch(self):
        """
        This is the training epoch for each image.  The approach taken here is to first 
        create the constrained msf with the ground truth cut edges, then compute
        the root error edges, afterwards compute and apply the updates for the model parameters.
        """
        
        # Compute the constrained MSF.
        print "Compute constrained MSF"
        start = time.time()
        self.constrained_msf = self._compute_constrained_msf()
        end = time.time()
        print "Done in {} seconds".format(end - start)
        
        print "2. Identifying root edges and loss."
        start = time.time()
        edge_error_weights = compute_root_error_edge_children(self.image_msf, self.constrained_msf,
                                                                 self.cut_edges, self.ground_truth_cuts)
        
        end = time.time()
        print "Done in {} seconds".format(end - start)
        
        nx.set_edge_attributes(self.image_graph, 'error_weight', edge_error_weights)
        
        # Fetch the training data from the image graph.
        images, weights = epoch_data(self.image_graph)
        
        # Split data into batches.
        batches = create_batches(images, weights)
        
        # Compute loss.
        print "Computing Loss"
        start = time.time()
        with sess.as_default():
            loss_val = sess.run(loss, feed_dict={image_placeholder: images,
                                                 gradient_weights: [weights],
                                                 keras.backend.learning_phase(): 0})
        end = time.time()
        print "Done in {} seconds".format(end - start)
        loss_val = loss_val[0][0]
            
        # Update parameters
        print "Updating Parameters."
        start = time.time()
        with sess.as_default():
            
            # Zero out gradient accumulator.
            sess.run(zero_ops)
            
            # Accumulate gradients.
            for batch in batches:
                sess.run(accum_ops, feed_dict={image_placeholder: batch[0],
                                               gradient_weights: [batch[1]],
                                               keras.backend.learning_phase(): 0})
            
            sess.run(train_step)
        end = time.time()
        print "Done in {} seconds".format(end - start)
            
        gc.collect()
            
        return loss_val
    
    
    def _compute_ground_truth_cuts(self, ground_truth_image, seeds):
        """
        Computes the ground truth cuts of the given image.
        
        Args:
            ground_truth_image (ndarray): The image used to create the ground truth cuts.
            seeds (list): A list of seeds to start watershed.
        
        Returns:
            list: A list of ground truth cut edges.
        """
        
        ground_truth_graph = img_to_graph(ground_truth_image)
        ground_truth_graph = plant_seeds(ground_truth_graph, self.seeds)
        ground_truth_msf = prims_msf(ground_truth_graph)
        
        # Compute the ground truth cut edges
        ground_truth_cuts = []
        [ground_truth_cuts.append(e) if ground_truth_msf.node[e[0]]['seed']\
                 is not ground_truth_msf.node[e[1]]['seed'] else '' for e in\
                 ground_truth_graph.edges_iter()]
        
        ground_truth_segmentation = assignment_mask(self.image, ground_truth_graph)
        filename = "training_images/ground_truth_segmentation.jpg"
        cv2.imwrite(filename, ground_truth_segmentation)
        
        return ground_truth_cuts
            

    def _prepare_input_images(self):
        """
        Preprocess images to be used in the prediction of the edges.
        """
        
        self.padded_image = np.pad(self.image,(self.window_size [0],self.window_size [1]),'reflect')
        
        #boundary_probabilities = get_boundary_probabilities(self.padded_image)
        boundary_probabilities = np.pad(self.ground_truth_image,(self.window_size [0],self.window_size [1]),'reflect')

        #Augment the image with boundary probabilities
        self.augmented_image = np.dstack((self.padded_image, boundary_probabilities))

        # Get input images
        self.input_images = dict()

        for x in range(self.image.shape[0]):
            for y in range(self.image.shape[1]):
                node = (x, y)
                self.input_images[node] = field_of_view(self.augmented_image, node, self.window_size)

        
    def _compute_constrained_msf(self):
        """
        Returns the constained msf.
        
        """
        
        self.img_cuts = self.image_graph.copy()
        self.img_cuts.remove_edges_from(self.ground_truth_cuts)
        return prims_msf(self.img_cuts)
    
    
    @property
    def ground_truth_cuts(self):
        return self.ground_truth_cuts
    
    
    @property
    def cut_edges(self):
        return self.cut_edges
    
    
    @property
    def image_graph(self):
        return self.image_graph
    
    
    @property
    def seeds(self):
        return self.seeds
    
    
    @seeds.setter
    def seeds(self, seeds):
        self.seeds = seeds
        
    
    @property
    def window_size(self):
        return self.window_size
    
    @property
    def image_dict(self):
        return self.image_dict

# Start of the computation: load the image, ground truth and seeds, then build the model

In [None]:
# Set image size
image_tl = (0, 0)
image_size = (50, 50)
window_size = (15, 15)

# All figures and training images are written into this directory.
if not os.path.isdir("training_images"):
    os.makedirs("training_images")

#import Images
img = cv2.imread('1O.jpg', 0)
gt = cv2.imread('1G.jpg', 0)
# cv2.imread reports a missing or unreadable file by returning None instead of raising.
if img is None or gt is None:
    raise IOError("Could not read '1O.jpg' and/or '1G.jpg'.")

#resize (crop to the region of interest)
img = img[image_tl[0]:image_tl[0] + image_size[0],
          image_tl[1]:image_tl[1] + image_size[1]]
gt = gt[image_tl[0]:image_tl[0] + image_size[0],
        image_tl[1]:image_tl[1] + image_size[1]]

#set type (signed, so absolute differences between pixels cannot wrap around)
img = img.astype(np.int16)
gt = gt.astype(np.int16)

# Save image (an explicit extension tells matplotlib which format to write).
plt.imsave("training_images/image.png", cv2.resize(img, (1000,1000)), cmap='gray')
plt.clf()

# Import seeds: one "x y" pair per line, where x is the row and y the column in the
# full image (values may be written as floats).
seeds = []
with open("seeds1G.txt", 'r') as f:
    for line in f:
        fields = line.split()
        if len(fields) < 2:
            continue  # skip blank or incomplete lines

        # Shift into the coordinates of the cropped image.
        x = int(float(fields[0])) - image_tl[0]
        y = int(float(fields[1])) - image_tl[1]

        # Keep only seeds inside the crop; the upper bounds are exclusive.
        if 0 <= x < img.shape[0] and 0 <= y < img.shape[1]:
            seeds.append((x, y))

In [None]:
# This placeholder will contain the input images: a batch of
# window_size[0] x window_size[1] patches with 2 channels
# (image + boundary map).
image_placeholder = tf.placeholder(tf.float32, shape=(None, window_size[0], window_size[1], 2))

# Create model: a stack of 'same'-padded ELU convolutions, listed as
# (filters, kernel size, dilation rate, followed by batch norm?).
# Layers are created in the same order as before, so variable names and the
# order of tf.trainable_variables() are unchanged.
# NOTE(review): for the default 15x15 window, a 3x3 kernel with dilation 16
# has its outer taps entirely in the zero padding (acting like a 1x1 conv),
# and dilation 8 only partly overlaps the window -- confirm this is intended.
conv_specs = [
    (16, 5, 1, False),
    (16, 3, 1, True),
    (32, 3, 2, True),
    (32, 3, 4, True),
    (64, 3, 8, True),
    (64, 3, 16, True),
    (128, 3, 1, True),
]

m = image_placeholder
for n_filters, k_size, d_rate, with_bn in conv_specs:
    m = keras.layers.Conv2D(n_filters, k_size, padding='same',
                            activation='elu', dilation_rate=d_rate)(m)
    if with_bn:
        m = keras.layers.BatchNormalization()(m)

# Dense head producing one scalar per input window.
m = keras.layers.Flatten()(m)
m = keras.layers.Dense(1024, activation='relu')(m)
m = keras.layers.BatchNormalization()(m)
m = keras.layers.Dense(1, activation='elu')(m)
f_static = keras.layers.BatchNormalization()(m)

In [None]:
# This placeholder will hold the root error edge values: one weight per
# prediction in the batch, shape (1, batch).
gradient_weights = tf.placeholder(tf.float32, shape=(1, None))

# Define optimizer
LEARNING_RATE = 0.000001
opt = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)

tvs = tf.trainable_variables()

# Weighted sum of the per-window predictions: (1, batch) x (batch, 1) -> (1, 1).
loss = tf.matmul(gradient_weights, f_static)

# Gradients of the loss with respect to the parameters. Variables that do not
# feed into `loss` (e.g. left over from re-running the model cell) get a None
# gradient, which assign_add cannot accept, so they are dropped here.
gvs = [(grad, var) for grad, var in opt.compute_gradients(loss, tvs) if grad is not None]

# Accumulate gradients across several runs before applying them.
accum_vars = [tf.Variable(tf.zeros_like(var.initialized_value()), trainable=False)
              for _, var in gvs]
zero_ops = [av.assign(tf.zeros_like(av)) for av in accum_vars]
accum_ops = [av.assign_add(grad) for av, (grad, _) in zip(accum_vars, gvs)]

# Apply the accumulated gradients
train_step = opt.apply_gradients([(av, var) for av, (_, var) in zip(accum_vars, gvs)])

sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [None]:
# Train the segmenter: attach the seeds found inside the training crop, then
# fit on the cropped image and its ground truth (epochs=1, verbose output).
w = Watershed(window_size=window_size)

w.seeds = seeds

w.fit(img, gt, epochs=1, verbose=True)

In [None]:
# Inspect the attributes stored on node (0, 0) of the image graph.
# `Graph.node` was removed in networkx 2.4; building the mapping from
# `nodes(data=True)` works on both networkx 1.x and 2.x+.
dict(w.image_graph.nodes(data=True))[(0, 0)]

In [None]:
# Set image size (test crop; image_size is reused from the training setup cell)
image_tl = (300,300)

# Import images (flag 0 = load as single-channel grayscale)
img = cv2.imread('1O.jpg', 0)
gt = cv2.imread('1G.jpg', 0)
if img is None or gt is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise IOError("Could not read '1O.jpg' and/or '1G.jpg'")

# Crop the image_size window whose top-left corner is image_tl
img = img[image_tl[0]:image_tl[0] + image_size[0],
          image_tl[1]:image_tl[1] + image_size[1]]
gt = gt[image_tl[0]:image_tl[0] + image_size[0],
        image_tl[1]:image_tl[1] + image_size[1]]

# Set type
img = img.astype(np.int16)
gt = gt.astype(np.int16)

# Save an upscaled copy of the crop for visual inspection
plt.imsave("image", cv2.resize(img, (1000,1000)), cmap='gray')
plt.clf()

# Import the seeds that fall inside the crop, shifted into crop coordinates.
# Each line of the seed file holds whitespace-separated "x y" coordinates.
seeds = []
with open("seeds1G.txt", 'r') as seed_file:
    for line in seed_file:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines
        x = int(float(fields[0]))
        y = int(float(fields[1]))

        # Half-open bounds: the crop covers [tl, tl + size), so a seed at
        # tl + size would land one pixel outside the cropped image.
        if (image_tl[0] <= x < image_tl[0] + image_size[0]
                and image_tl[1] <= y < image_tl[1] + image_size[1]):
            seeds.append((x - image_tl[0], y - image_tl[1]))

In [None]:
# Segment the test crop using the seeds that fall inside it.
w.seeds = seeds

segmentation = w.segment(img)

# NOTE(review): JPEG is lossy; if `segmentation` holds label values,
# compression may alter them near boundaries -- a lossless format such as
# PNG may be preferable.
filename = "test.jpg"
if not cv2.imwrite(filename, segmentation):
    # cv2.imwrite reports failure (bad path, unsupported dtype, ...) only
    # through its return value, so check it explicitly.
    raise IOError("Could not write segmentation to %s" % filename)