In [1]:
# Compatibility Imports.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

In [2]:
# Python Imports.
import keras
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import os
import sys
import shutil

import ChopinNet
import BachNet
from utils import graph_utils
from utils import display_utils
from utils import preprocessing_utils
from utils import prediction_utils

Using TensorFlow backend.


In [3]:
%matplotlib inline

In [4]:
train_path = "data/train"
test_path = "data/test"

input_path = "input"
output_path = "output"

path_to_original_images = "original"
path_to_gt_images = "gt"

receptive_field_shape = (15, 15)
n_epochs = 8

In [5]:
# Import Images

bach = BachNet.BachNet()
bach.build(receptive_field_shape, 1)
bach.load_model('models/saved_model/Bach/model.h5')

image_data = dict()

files = os.listdir(os.path.join(train_path, input_path))

for i, filename in enumerate(files):
    
    sys.stdout.write("\rProgress: %.2f%% || %d/%d" % (i / len(files),
                                                      i,
                                                      len(files)))
    
    f_name, ext = os.path.splitext(filename)
    
    
    if "gt" in f_name.split("_"):
        continue
        
    gt_filename = f_name + "_gt"
    
    gt_path = os.path.join(train_path, input_path, (gt_filename + ext))
    if not os.path.isfile(gt_path):
        continue    
    
    foldername = "data/train/chopin/" + f_name
        
    image_path = os.path.join(train_path, input_path, filename)
    
    img, gt, seeds = preprocessing_utils.load_image(foldername,
                                    image_path,
                                    gt_path) 
    
    bps = bach.boundary_probabilities(img)
    
    plt.imsave(os.path.join(train_path, "bach", f_name), bps, cmap='gray')
    
    image_data[f_name] = img, bps, gt, seeds
    
sys.stdout.write("\rProgress: Done! || %d/%d" % (len(files),
                                                len(files)))
sys.stdout.flush()

Progress: 0.00% || 0/178data/train/input/slice_39_gt.png
Progress: 0.02% || 3/178data/train/input/slice_55_gt.png
Progress: 0.03% || 5/178data/train/input/slice_86_gt.png
Progress: 0.03% || 6/178data/train/input/slice_35_gt.png
Progress: 0.06% || 10/178data/train/input/slice_36_gt.png
Progress: 0.06% || 11/178data/train/input/slice_38_gt.png
Progress: 0.07% || 13/178data/train/input/slice_40_gt.png
Progress: 0.09% || 16/178data/train/input/slice_1_gt.png
Progress: 0.10% || 18/178data/train/input/slice_10_gt.png
Progress: 0.11% || 19/178data/train/input/slice_23_gt.png
Progress: 0.12% || 21/178data/train/input/slice_29_gt.png
Progress: 0.12% || 22/178data/train/input/slice_4_gt.png
Progress: 0.15% || 27/178data/train/input/slice_16_gt.png
Progress: 0.16% || 28/178data/train/input/slice_24_gt.png
Progress: 0.17% || 30/178data/train/input/slice_46_gt.png
Progress: 0.17% || 31/178data/train/input/slice_20_gt.png
Progress: 0.19% || 33/178data/train/input/slice_32_gt.png
Progress: 0.19% || 3

In [6]:
image_data

{'slice_0': (array([[150, 137, 123, ..., 195, 201, 208],
         [139, 121, 121, ..., 211, 216, 214],
         [136, 127, 127, ..., 210, 225, 208],
         ..., 
         [199, 190, 183, ..., 192, 210, 225],
         [201, 168, 177, ..., 185, 214, 225],
         [174, 177, 154, ..., 194, 210, 225]], dtype=uint8),
  array([[ 0.03880607,  0.03880607,  0.03880607, ...,  0.03880607,
           0.03880607,  0.03880607],
         [ 0.03880607,  0.03880607,  0.03880607, ...,  0.03880607,
           0.03880607,  0.03880607],
         [ 0.03880607,  0.03880607,  0.03880607, ...,  0.03880607,
           0.03880607,  0.03880607],
         ..., 
         [ 0.03880607,  0.03880607,  0.03880607, ...,  0.03880607,
           0.03880607,  0.03880607],
         [ 0.03880607,  0.03880607,  0.03880607, ...,  0.03880607,
           0.03880607,  0.03880607],
         [ 0.03880607,  0.03880607,  0.03880607, ...,  0.03880607,
           0.03880607,  0.03880607]], dtype=float32),
  array([[255, 255,   0, ..

In [7]:
chopin = ChopinNet.Chopin()
chopin.build(receptive_field_shape)
chopin.initialize_session()

In [8]:
def train_single(chopin,
                 img,
                 I_a,
                 gt,
                 seeds,
                 foldername,
                 global_loss=[],
                 global_accuracy=[],
                 num_epochs=16):
    
    
    I_a = preprocessing_utils.pad_for_window(I_a,
                               chopin.receptive_field_shape[0],
                               chopin.receptive_field_shape[1])
    
    graph = graph_utils.prims_initialize(img)

    ground_truth_cuts, gt_graph, gt_assignments = graph_utils.generate_gt_cuts(gt,
                                                                               seeds,
                                                                               assignments=True)

    local_loss = []
    local_accuracy = []
    
    saved_models_path = os.path.join(foldername, "saved_models")
    
    if os.path.exists(saved_models_path):
        shutil.rmtree(saved_models_path)
    os.mkdir(saved_models_path)

    for epoch in range(num_epochs):
        print("Epoch {}".format(epoch + 1))
        msf = chopin.predicted_msf(I_a, graph, seeds)
        segmentations = display_utils.assignments(np.zeros_like(img), msf, seeds)

        shortest_paths = nx.get_node_attributes(msf, 'path')
        assignments = nx.get_node_attributes(msf, 'seed')
        cuts = graph_utils.get_cut_edges(msf)

        acc = graph_utils.accuracy(assignments, gt_assignments)
        
        print("Accuracy: ", acc)
        local_accuracy.append(acc)

        filename = "epoch_{}.png".format(epoch + 1)
        
        boundaries = display_utils.view_boundaries(np.zeros_like(img),
                                                   cuts)

        plt.imsave(os.path.join(foldername, filename), boundaries)

        constrained_msf = msf.copy()

        constrained_msf.remove_edges_from(ground_truth_cuts)

        constrained_msf = graph_utils.minimum_spanning_forest(img, constrained_msf, seeds)

        ground_truth_paths = nx.get_node_attributes(constrained_msf, 'path')

        children = graph_utils.compute_root_error_edge_children(shortest_paths,
                                                          ground_truth_paths, cuts,
                                                          ground_truth_cuts)

        weights = []
        static_images = []
        dynamic_images = []

        for (u, v), weight in children.iteritems():

            try:
                static_images.append(msf.get_edge_data(u, v)['static_image'])
                dynamic_images.append(msf.get_edge_data(u, v)['dynamic_image'])
                weights.append(weight)
                altitude_val = msf.get_edge_data(u, v)['weight']
            except KeyError:
                pass

        batches = zip(preprocessing_utils.create_batches(np.expand_dims(np.stack(weights), 1)),
                      preprocessing_utils.create_batches(np.stack(static_images)),
                      preprocessing_utils.create_batches(np.stack(dynamic_images)))

        loss = 0
        with chopin.sess.as_default():
            chopin.sess.run(chopin.zero_ops)

            for w, s, d in batches:
                feed_dict = {chopin.gradient_weights: w.transpose(),
                             chopin.static_input: s,
                             chopin.dynamic_input: d,
                             keras.backend.learning_phase(): 0}

                chopin.sess.run(chopin.accum_ops, feed_dict)
                loss = chopin.sess.run(chopin.loss, feed_dict)
                loss += loss[0][0]

            chopin.sess.run(chopin.train_step)



        local_loss.append(loss)
        print("Loss: ", loss)
        
#         info = "Epoch: {}\tloss: {}\taccuracy: {}\n".format(epoch + 1, loss, acc)
#         loss_file.write(info)
#         loss_file.flush()

#         f, axarr = plt.subplots(2, sharex=True)
#         axarr[0].plot(local_loss)
#         axarr[0].set_title("Loss")
#         axarr[1].plot(local_accuracy)
#         axarr[1].set_title("Accuracy")

#         figurename = "Local Loss and Accuracy"

#         plt.savefig(os.path.join(foldername, figurename))
        
#         global_loss.append(loss)
#         global_accuracy.append(acc)
#         f, axarr = plt.subplots(2, sharex=True)
#         axarr[0].plot(global_loss)
#         axarr[0].set_title("Loss")
#         axarr[1].plot(global_accuracy)
#         axarr[1].set_title("Accuracy")

#         figurename = "Global Loss and Accuracy"

#         global_folder = foldername.split("/")[:-1]
#         global_folder = "/".join(global_folder)
        
#         plt.savefig(os.path.join(global_folder, figurename))
        

        chopin.save_model("models/saved_model/Chopin/model.ckpt")
        model_name = "epoch_{}".format(epoch)
        chopin.save_model(os.path.join(foldername, "saved_models", model_name, model_name))

    return segmentations, global_loss, global_accuracy

In [9]:
global_loss = list()
global_accuracy = list()

try:
    loss_file
except NameError:
    loss_file = open('data/train/chopin/global_loss.txt', 'w')


for name, (img, bps, gt, seeds) in image_data.iteritems():
    print("Training on " + name)
    
    img_info = "\nImage: {}\n\n".format(name)
    loss_file.write(img_info)
    loss_file.flush()
    
    I_a = np.stack((img, bps), axis=2)
    foldername = "data/train/chopin/" + name
    segs, glob, acc_timeline = train_single(chopin,
                                            img,
                                            I_a,
                                            gt,
                                            seeds,
                                            foldername,
                                            global_loss,
                                            global_accuracy,
                                            num_epochs=n_epochs)

Training on slice_27
Epoch 1
Starting gradient segmentation...
Segmentation done: 67.961964s
Accuracy:  0.612
Loss:  [[-730.77880859]]
Epoch 2
Starting gradient segmentation...
Segmentation done: 73.243678s
Accuracy:  0.0412
Loss:  [[ 13030.]]
Epoch 3
Starting gradient segmentation...
Segmentation done: 73.693943s
Accuracy:  0.0412
Loss:  [[ 13030.]]


KeyboardInterrupt: 