In [1]:
import os
import sys
import numpy as np
import cv2
import re
import networkx as nx
import keras 
from learnedwatershed.utils import prediction_utils, display_utils, preprocessing_utils, graph_utils
from learnedwatershed import ChopinNet

Using TensorFlow backend.


In [2]:
train_path = 'data/train/segmentation_labels/'
input_path = 'data/train/segmentation_labels'
model_path = 'models/'

receptive_field_shape = (23, 23)
num_epochs = 10

## Load Training Data

In [3]:
image_data = {}

files = os.listdir(os.path.join(train_path))

for i, filename in enumerate(files):
    
    sys.stdout.write("\rProgress: %.2f%% || %d/%d" % (i / len(files),
                                                      i,
                                                      len(files)))
    
    f_name, ext = os.path.splitext(filename)
    
    if "gt" in f_name.split("_"):
        continue
    
    gt_path = os.path.join(train_path, (f_name + "_gt" + ext))
    augmentation_path = os.path.join(train_path, (f_name + "_augmentation" + ext))
    seeds_path = os.path.join(train_path, (f_name + "_seeds.txt"))
    
    if not os.path.isfile(gt_path) or not os.path.isfile(augmentation_path) or not os.path.isfile(seeds_path):
        continue    
        
    image_path = os.path.join(train_path, filename)
    
    img = cv2.imread(image_path, 0)
    gt = cv2.imread(gt_path, 0)
    augmentation = cv2.imread(augmentation_path, 0)
    
    seeds = []
    
    with open(seeds_path, 'r') as f:
        for line in f:
            y = int(float(re.split(' ', line)[0]))
            x = int(float(re.split(' ', line)[1]))
            seed = (x, y)
            seeds.append(seed)
    
#     img, gt, seeds = preprocessing_utils.load_image(foldername,
#                                     image_path,
#                                     gt_path) 
    
    image_data[f_name] = img, augmentation, gt, seeds
    
sys.stdout.write("\rProgress: Done! || %d/%d" % (len(files),
                                                len(files)))
sys.stdout.flush()

Progress: 0.00% || 0/242Progress: 0.00% || 1/242Progress: 0.00% || 2/242Progress: 0.00% || 3/242Progress: 0.00% || 4/242Progress: 0.00% || 5/242Progress: 0.00% || 6/242Progress: 0.00% || 7/242Progress: 0.00% || 8/242Progress: 0.00% || 9/242Progress: 0.00% || 10/242Progress: 0.00% || 11/242Progress: 0.00% || 12/242Progress: 0.00% || 13/242Progress: 0.00% || 14/242Progress: 0.00% || 15/242Progress: 0.00% || 16/242Progress: 0.00% || 17/242Progress: 0.00% || 18/242Progress: 0.00% || 19/242Progress: 0.00% || 20/242Progress: 0.00% || 21/242Progress: 0.00% || 22/242Progress: 0.00% || 23/242Progress: 0.00% || 24/242Progress: 0.00% || 25/242Progress: 0.00% || 26/242Progress: 0.00% || 27/242Progress: 0.00% || 28/242Progress: 0.00% || 29/242Progress: 0.00% || 30/242Progress: 0.00% || 31/242Progress: 0.00% || 32/242Progress: 0.00% || 33/242Progress: 0.00% || 34/242Progress: 0.00% || 35/242Progress: 0.00% || 36/242Progress: 0.00% || 37/242Progress: 0.00% || 38

## Training

In [4]:
gt.shape, img.shape

((32, 32), (32, 32))

In [5]:
chopin = ChopinNet.Chopin()
chopin.build(receptive_field_shape)
chopin.initialize_session()

In [6]:
for f_name, (img, augmentation, gt, seeds) in image_data.iteritems():
    
    I_a = np.stack((img, gt), -1)
    print(seeds)
    
    
    I_a = preprocessing_utils.pad_for_window(I_a,
                           chopin.receptive_field_shape[0],
                           chopin.receptive_field_shape[1])

    graph = graph_utils.prims_initialize(img)

    ground_truth_cuts, gt_graph, gt_assignments = graph_utils.generate_gt_cuts(gt,
                                                                               seeds,
                                                                               assignments=True)
    local_loss = []
    local_accuracy = []

    #saved_models_path = os.path.join(foldername, "saved_models")

    # if os.path.exists(saved_models_path):
    #     shutil.rmtree(saved_models_path)
    # os.mkdir(saved_models_path)

    for epoch in range(num_epochs):
        print("Epoch {}".format(epoch + 1))
        msf = chopin.predicted_msf(I_a, graph, seeds)
        segmentations = display_utils.assignments(np.zeros_like(img), msf, seeds)

        shortest_paths = nx.get_node_attributes(msf, 'path')
        assignments = nx.get_node_attributes(msf, 'seed')
        cuts = graph_utils.get_cut_edges(msf)

        acc = graph_utils.accuracy(assignments, gt_assignments)

        print("Accuracy: ", acc)
        local_accuracy.append(acc)

        filename = "{}_epoch_{}.png".format(f_name, epoch + 1)

        boundaries = display_utils.view_boundaries(np.zeros_like(img),
                                                   cuts)

        cv2.imwrite(os.path.join(input_path, filename), boundaries)

        constrained_msf = msf.copy()

        constrained_msf.remove_edges_from(ground_truth_cuts)

        constrained_msf = graph_utils.minimum_spanning_forest(img, constrained_msf, seeds)

        ground_truth_paths = nx.get_node_attributes(constrained_msf, 'path')

        children = graph_utils.compute_root_error_edge_children(shortest_paths,
                                                          ground_truth_paths, cuts,
                                                          ground_truth_cuts)

        weights = []
        static_images = []
        dynamic_images = []

        for (u, v), weight in children.iteritems():

            try:
                static_images.append(msf.get_edge_data(u, v)['static_image'])
                dynamic_images.append(msf.get_edge_data(u, v)['dynamic_image'])
                weights.append(weight)
                altitude_val = msf.get_edge_data(u, v)['weight']
            except KeyError:
                pass

        batches = zip(preprocessing_utils.create_batches(np.expand_dims(np.stack(weights), 1)),
                      preprocessing_utils.create_batches(np.stack(static_images)),
                      preprocessing_utils.create_batches(np.stack(dynamic_images)))

        loss = 0
        with chopin.sess.as_default():
            chopin.sess.run(chopin.zero_ops)

            for w, s, d in batches:
                feed_dict = {chopin.gradient_weights: w.transpose(),
                             chopin.static_input: s,
                             chopin.dynamic_input: d,
                             keras.backend.learning_phase(): 0}

                chopin.sess.run(chopin.accum_ops, feed_dict)
                loss = chopin.sess.run(chopin.loss, feed_dict)
                loss += loss[0][0]

            chopin.sess.run(chopin.train_step)

        local_loss.append(loss)
        print("Loss: ", loss)

    #     chopin.save_model("models/saved_model/Chopin/model.ckpt")
    #     model_name = "epoch_{}".format(epoch)
    #     chopin.save_model(os.path.join(foldername, "saved_models", model_name, model_name))


[(12, 15), (0, 31)]
Epoch 1
Starting gradient segmentation...
Segmentation done: 10.125757s
('Accuracy: ', 0.283203125)
('Loss: ', array([[ 281.06362915]], dtype=float32))
Epoch 2
Starting gradient segmentation...
Segmentation done: 9.618827s
('Accuracy: ', 0.37109375)
('Loss: ', array([[-36664.33984375]], dtype=float32))
Epoch 3
Starting gradient segmentation...
Segmentation done: 9.679310s
('Accuracy: ', 0.369140625)
('Loss: ', array([[ 15278.8359375]], dtype=float32))
Epoch 4
Starting gradient segmentation...
Segmentation done: 10.026381s
('Accuracy: ', 0.3701171875)
('Loss: ', array([[-310765.125]], dtype=float32))
Epoch 5
Starting gradient segmentation...
Segmentation done: 9.765884s
('Accuracy: ', 0.7197265625)
('Loss: ', array([[ 18990.5]], dtype=float32))
Epoch 6
Starting gradient segmentation...
Segmentation done: 9.696550s
('Accuracy: ', 0.6396484375)
('Loss: ', array([[ 6328.73632812]], dtype=float32))
Epoch 7
Starting gradient segmentation...
Segmentation done: 9.739666s
('

Segmentation done: 10.007804s
('Accuracy: ', 0.912109375)
('Loss: ', array([[ -1.63674522e+09]], dtype=float32))
Epoch 5
Starting gradient segmentation...
Segmentation done: 9.879662s
('Accuracy: ', 0.919921875)
('Loss: ', array([[  2.22547149e+09]], dtype=float32))
Epoch 6
Starting gradient segmentation...
Segmentation done: 9.956976s
('Accuracy: ', 0.919921875)
('Loss: ', array([[  1.32783309e+09]], dtype=float32))
Epoch 7
Starting gradient segmentation...
Segmentation done: 10.002641s
('Accuracy: ', 0.873046875)
('Loss: ', array([[ -5.25128909e+09]], dtype=float32))
Epoch 8
Starting gradient segmentation...
Segmentation done: 9.894445s
('Accuracy: ', 0.873046875)
('Loss: ', array([[ -5.23842970e+09]], dtype=float32))
Epoch 9
Starting gradient segmentation...
Segmentation done: 9.996564s
('Accuracy: ', 0.873046875)
('Loss: ', array([[ -5.69550234e+09]], dtype=float32))
Epoch 10
Starting gradient segmentation...
Segmentation done: 9.878604s
('Accuracy: ', 0.8740234375)
('Loss: ', arra

('Loss: ', array([[  5.62051416e+11]], dtype=float32))
Epoch 7
Starting gradient segmentation...
Segmentation done: 10.109127s
('Accuracy: ', 0.1884765625)
('Loss: ', array([[  3.62956194e+11]], dtype=float32))
Epoch 8
Starting gradient segmentation...
Segmentation done: 9.811080s
('Accuracy: ', 0.1884765625)
('Loss: ', array([[  1.65150720e+11]], dtype=float32))
Epoch 9
Starting gradient segmentation...
Segmentation done: 9.953214s
('Accuracy: ', 0.1884765625)
('Loss: ', array([[ -2.90602353e+10]], dtype=float32))
Epoch 10
Starting gradient segmentation...
Segmentation done: 9.943660s
('Accuracy: ', 0.1884765625)
('Loss: ', array([[ -2.11332104e+11]], dtype=float32))
[(7, 0), (22, 15)]
Epoch 1
Starting gradient segmentation...
Segmentation done: 9.788570s
('Accuracy: ', 0.603515625)
('Loss: ', array([[ -1.25417701e+13]], dtype=float32))
Epoch 2
Starting gradient segmentation...
Segmentation done: 9.826012s
('Accuracy: ', 0.6015625)
('Loss: ', array([[ -1.30428510e+13]], dtype=float32)

Segmentation done: 9.877623s
('Accuracy: ', 0.8759765625)
('Loss: ', array([[ -3.38247025e+12]], dtype=float32))
Epoch 10
Starting gradient segmentation...
Segmentation done: 9.876077s
('Accuracy: ', 0.8759765625)
('Loss: ', array([[ -3.37307553e+12]], dtype=float32))
[(17, 10), (2, 27)]
Epoch 1
Starting gradient segmentation...
Segmentation done: 10.038380s
('Accuracy: ', 0.892578125)
('Loss: ', array([[ -3.45073202e+12]], dtype=float32))
Epoch 2
Starting gradient segmentation...
Segmentation done: 9.920528s
('Accuracy: ', 0.892578125)
('Loss: ', array([[ -3.04757578e+12]], dtype=float32))
Epoch 3
Starting gradient segmentation...
Segmentation done: 10.160117s
('Accuracy: ', 0.892578125)
('Loss: ', array([[ -2.70019749e+12]], dtype=float32))
Epoch 4
Starting gradient segmentation...
Segmentation done: 10.206356s
('Accuracy: ', 0.892578125)
('Loss: ', array([[ -2.54341598e+12]], dtype=float32))
Epoch 5
Starting gradient segmentation...
Segmentation done: 10.007513s
('Accuracy: ', 0.892

KeyboardInterrupt: 