Learning Strategy:

Train on every image in each epoch, for n epochs in total.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import cv2
import numpy as np
import os
import re
import shutil

import keras

from models import BachNet
from models import ChopinNet

import networkx as nx
import sys
import time


from heapq import heappop as pop
from heapq import heappush as push
from utils import graph_utils
from utils import display_utils
from utils import prediction_utils
from utils import preprocessing_utils

In [None]:
# Root folders of the training and test data.
train_path, test_path = "data/train", "data/test"

# Sub-folder names; input_path is handed to the input generator below.
input_path, output_path = "input", "output"

# Tag handed to the input generator (presumably marks ground-truth files — confirm).
gt_tag = "gt"

# Receptive field used when building both networks, and the number of
# passes the training loop makes over the cached images.
receptive_field_shape = (12, 12)
n_epochs = 32

In [None]:
# Instantiate BachNet, build it for the configured receptive field, and load
# the saved model from disk. `bach` is later passed to the input generator.
bach = BachNet.BachNet()
# NOTE(review): the meaning of the second argument (1) is not visible in this
# file — confirm against BachNet.build.
bach.build(receptive_field_shape, 1)
bach.load_model('models/saved_models/Bach/model.h5')

In [None]:
# Run the input generator over the training data once and cache everything it
# yields, keyed by file name, so the training loop can revisit each image
# every epoch without regenerating it.
batch = {}
input_gen = prediction_utils.input_generator(bach, train_path, input_path, gt_tag)

# Iterate the generator directly rather than calling next() inside a broad
# try/except StopIteration: the loop still ends when the generator is
# exhausted, but a StopIteration leaking out of prims_initialize is no longer
# mistaken for "end of data" (which would silently truncate the batch).
for f_name, img, bps, I_a, gt, gt_cuts, seeds in input_gen:
    # Build the graph for this image once, up front, and store it alongside
    # the generator outputs.
    graph = graph_utils.prims_initialize(img)
    batch[f_name] = img, bps, I_a, gt, gt_cuts, seeds, graph

In [None]:
# Create the Chopin network, build it for the configured receptive field with
# a learning rate of 1e-6, and initialize its session.
chopin = ChopinNet.Chopin()
chopin.build(receptive_field_shape, learning_rate=1e-6)
# Checkpoint loading is disabled, so load_model is not called in this cell.
# NOTE(review): this path says "saved_model" while the Bach cell loads from
# "models/saved_models/..." — confirm the directory name before re-enabling.
#chopin.load_model("models/saved_model/Chopin/checkpoint")
chopin.initialize_session()

In [None]:
# Train Chopin on every cached image once per epoch. After each image the
# loss is appended to a tab-separated log, boundary and overlay previews are
# saved to the image's output folder, and the per-image and global loss
# curves are re-rendered.
global_loss_timeline = []
loss_timelines = dict()

# The context manager closes the log even if training raises part-way through.
with open("data/train/chopin/global_loss.txt", 'w') as loss_file:
    loss_file.write("f_name\tepoch\tloss\n")

    for epoch in range(n_epochs):
        # .items() works on Python 2 and 3; .iteritems() exists only on Python 2.
        for f_name, (img, bps, I_a, gt, gt_cuts, seeds, graph) in batch.items():

            foldername = os.path.join(train_path, "chopin", f_name)
            # imsave/savefig cannot write into a folder that does not exist.
            if not os.path.isdir(foldername):
                os.makedirs(foldername)

            start = time.time()

            loss, segmentations, cuts = chopin.train_on_image(img, bps, I_a, gt, gt_cuts, seeds, graph)

            print(time.time() - start)
            print("Loss: ", loss)

            # Predicted cuts drawn on a blank canvas shaped like the image.
            boundaries = display_utils.view_boundaries(np.zeros_like(img), cuts)
            plt.imsave(os.path.join(foldername, "epoch_{}_bw".format(epoch)), boundaries)

            # The mask is the image argument of imsave, not a path component:
            # the original passed it to os.path.join, which raised TypeError.
            mask = display_utils.transparent_mask(img, segmentations)
            plt.imsave(os.path.join(foldername, "epoch_{}_overlay".format(epoch)), mask)

            loss_file.write(f_name + "\t" + str(epoch) + "\t" + str(loss) + "\n")
            # Flush per step so the log is usable while a long run is in progress.
            loss_file.flush()

            loss_timelines.setdefault(f_name, []).append(loss)
            # Record every step in training order; the original never appended
            # here, so the global plot was always empty.
            global_loss_timeline.append(loss)

            # Clear the current figure before each plot so curves from other
            # images and earlier iterations do not pile up on the same axes.
            plt.clf()
            plt.plot(loss_timelines[f_name])
            plt.savefig(os.path.join(foldername, "local_loss"))

            plt.clf()
            plt.plot(global_loss_timeline)
            plt.savefig("data/train/chopin/global_loss")