In [1]:
import os
from os.path import join
import sys
import cv2 as cv
import numpy as np
from numpy.random import RandomState
import pickle
import matplotlib.pyplot as plt

import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import regularizers
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def load_labels(specific_video=None):
    """ Loads in image data as numpy arrays """
    sequence = []
    none_count = 0 
    filedir = join(os.getcwd(),"labels")
    for file in os.listdir(filedir):
        ## change current seq when video_id change or marker number changes
        if file.endswith(".jpg"):
            file = file.split(".")[0]
            file = file.split("_")
            if specific_video == None:
                video_id, marker_num, marker_type, frame_num, x_pos, y_pos = file[0], int(file[1]), int(file[2]), int(file[3]), int(file[4]), int(file[5])
                current_seq = [video_id, marker_num, marker_type, frame_num, x_pos, y_pos]
                sequence.append(current_seq)
            else:
                if file[0] == video_id:
                    video_id, marker_num, marker_type, frame_num, x_pos, y_pos = file[0], int(file[1]), int(file[2]), int(file[3]), int(file[4]), int(file[5])
                    current_seq = [video_id, marker_num, marker_type, frame_num, x_pos, y_pos]
                    sequence.append(current_seq)
    return sequence

In [3]:
def sort_labels(sequence):
    sequence.sort(key=lambda x: x[3]) ## Sort by frame number
    sequence.sort(key=lambda x: x[1]) ## Sort by marker_num
    sequence.sort(key=lambda x: x[0]) ## Sort by video_name
    return sequence

In [4]:
def assign_labels(sequence):
    ##  ball change or vid change
    prev_vid_id, prev_marker_num = sequence[0][0][0], sequence[0][0][1]
    for idx, seq in enumerate(sequence):
        vid_id, marker_num = seq[0], seq[1]
        if (vid_id != prev_vid_id) or (marker_num != prev_marker_num):
            prev_vid_id, prev_marker_num = vid_id, marker_num
            prev_coords, prev_frame_num = np.array([-1,-1]), -1
        frame_num, x_pos, y_pos = seq[3], seq[4], seq[5]
        current_coords = np.array([x_pos, y_pos])
        if (prev_coords[0] == -1) & (prev_coords[1] == -1):
            dist = -1
        else:
            dist = np.linalg.norm(current_coords - prev_coords)
        if (prev_frame_num == -1):
            frame_diff = -1
        else:
            frame_diff = frame_num - prev_frame_num
        prev_coords = current_coords
        prev_frame_num = frame_num
        sequence[idx].append(dist)
        sequence[idx].append(frame_diff)
    return sequence

In [5]:
def load_video(video_path, video_name, flip, display = False): ## Convert 3rd element in video name into flip
    width, height = 960, 540
    cap = cv.VideoCapture(join(video_path,video_name))
    ret, frame = cap.read()
    if (flip):
        frame = cv.flip(frame, 0)
    clone = cv.resize(frame, (width,height))
    if (display):
        cv.namedWindow("Video")
    frame_num = 0
    frames = []
    while (ret):
        if (display):
            cv.imshow("Video", clone)
        frames.append( [frame, frame_num] )
        if (display):
            key = cv.waitKey(0)
            if key == 113:
                break
        ret, frame = cap.read()
        if (ret):
            frame_num += 1
            if (flip):
                frame = cv.flip(frame, 0)
            clone = cv.resize(frame, (width, height))
    print("Frames: {}".format(frames[-1][1] + 1))
    cap.release()
    if (display):
        cv.destroyAllWindows()
    return frames

In [12]:
def load_data_and_labels(sequence, vid_format=".avi"):
    print("Loading videos")
    filedir = join(os.getcwd(),"resources")
    video_recorded = []
    for file in os.listdir(filedir):
        video_id = file.split(".")[0]
        video_recorded.append(video_id)
    video_annotated = list(sorted(set([i[0] for i in sequence])))
    video_data = []
    """ Checking the videos annotated is in the video recorded """
    for rec in video_recorded:
        if rec in video_annotated:
            data = []
            print("Found: {}".format(rec))
            labels = [i for i in sequence if i[0] == rec]
            labels.sort(key=lambda x: x[3]) ## Sort by frame number
            labels = np.asarray(labels)[:,1:].astype('float32')
            frames = load_video(filedir, rec + vid_format, False)
            frames = np.asarray(frames)
            print(frames.shape)
            prev_frame = 0
            frame_labels = []
            print(len(labels))
            for idx, label in enumerate(labels):
                print("idx: {}, label: {}, curr_frame = {}".format(idx, label, int(label[2])))
                curr_frame = int(label[2])
                if (curr_frame != prev_frame):
                    data.append([frames[prev_frame][0], frame_labels])                    
                    frame_labels = [label]
                elif idx == len(labels) - 1:
                    frame_labels.append(label)
                    data.append([frames[curr_frame][0],frame_labels])
                else:
                    frame_labels.append(label)
                prev_frame = int(curr_frame)
            video_data.append(data)
    print("Search Complete")
    return video_data    

In [13]:
def data_to_np(data):
    data_np = np.asarray(data)
    x_values = []
    y_values = []
    
    for i in range(len(data_np)):
        print("Frames in video {}: {}".format(i,len(data_np[i])))            
    
    for vid_pos in range(len(data_np)):
        ##print("vid: {}".format(vid_pos))
        for frame_pos in range(len(data_np[vid_pos])):
            ##print("frame num: {}".format(frame_pos))
            x_np = data_np[vid_pos][frame_pos][0]
            x_shape = list(x_np[0].shape)
            x_shape[:0] = [len(x_np)]
            x_np = np.concatenate(x_np).reshape(x_shape)
            x_values.append(x_np)

            y_np = data_np[vid_pos][frame_pos][1]
            y_np = np.asarray(y_np)
            zero_np = np.zeros((16,7))
            if y_np.shape != (0,):
                zero_np[:y_np.shape[0],:y_np.shape[1]] = y_np
            y_values.append(zero_np)
        
    x_img = np.asarray(x_values)
    y_values = np.asarray(y_values)
    x_diff = y_values[:, :, 5:]
    y_cords = y_values[:, :, 3:5]
    print("x_img shape: {}, x_diff shape: {}, y_values shape: {}".format(x_img.shape, x_diff.shape, y_cords.shape))
    return x_img, x_diff, y_cords

In [14]:
def load_data():
    seq = load_labels()
    seq = sort_labels(seq)
    seq = assign_labels(seq)
    data = load_data_and_labels(seq)
    x_img, x_diff, y = data_to_np(data)
    return x_img, x_diff, y

In [15]:
x_img, x_diff, y = load_data()

Loading videos
Found: v0-0
Frames: 150
(150, 2)
450
idx: 0, label: [  0.   1.   0. 645. 531.  -1.  -1.], curr_frame = 0
idx: 1, label: [  1.   2.   0. 735. 719.  -1.  -1.], curr_frame = 0
idx: 2, label: [  2.   1.   0. 778. 883.  -1.  -1.], curr_frame = 0
idx: 3, label: [  0.   1.   1. 645. 531.   0.   1.], curr_frame = 1
idx: 4, label: [  1.   2.   1. 735. 720.   1.   1.], curr_frame = 1
idx: 5, label: [  2.   1.   1. 776. 883.   2.   1.], curr_frame = 1
idx: 6, label: [  0.          1.          2.        646.        532.          1.4142135
   1.       ], curr_frame = 2
idx: 7, label: [  1.   2.   2. 735. 720.   0.   1.], curr_frame = 2
idx: 8, label: [  2.   1.   2. 779. 883.   3.   1.], curr_frame = 2
idx: 9, label: [  0.   1.   3. 646. 532.   0.   1.], curr_frame = 3
idx: 10, label: [  1.   2.   3. 735. 720.   0.   1.], curr_frame = 3
idx: 11, label: [  2.   1.   3. 779. 882.   1.   1.], curr_frame = 3
idx: 12, label: [  0.   1.   4. 646. 532.   0.   1.], curr_frame = 4
idx: 13, la

Frames: 150
(150, 2)
450
idx: 0, label: [  0.   1.   0. 580. 381.  -1.  -1.], curr_frame = 0
idx: 1, label: [  1.   2.   0. 659. 579.  -1.  -1.], curr_frame = 0
idx: 2, label: [  2.   1.   0. 686. 743.  -1.  -1.], curr_frame = 0
idx: 3, label: [  0.   1.   1. 580. 381.   0.   1.], curr_frame = 1
idx: 4, label: [  1.   2.   1. 659. 579.   0.   1.], curr_frame = 1
idx: 5, label: [  2.          1.          1.        687.        744.          1.4142135
   1.       ], curr_frame = 1
idx: 6, label: [  0.   1.   2. 580. 381.   0.   1.], curr_frame = 2
idx: 7, label: [  1.   2.   2. 659. 578.   1.   1.], curr_frame = 2
idx: 8, label: [  2.   1.   2. 687. 744.   0.   1.], curr_frame = 2
idx: 9, label: [  0.   1.   3. 580. 381.   0.   1.], curr_frame = 3
idx: 10, label: [  1.   2.   3. 659. 578.   0.   1.], curr_frame = 3
idx: 11, label: [  2.   1.   3. 687. 744.   0.   1.], curr_frame = 3
idx: 12, label: [  0.   1.   4. 580. 381.   0.   1.], curr_frame = 4
idx: 13, label: [  1.   2.   4. 659. 5

idx: 213, label: [  0.   1.  71. 598. 495.   2.   1.], curr_frame = 71
idx: 214, label: [  1.   2.  71. 786. 574.   0.   1.], curr_frame = 71
idx: 215, label: [  2.   1.  71. 769. 741.   0.   1.], curr_frame = 71
idx: 216, label: [  0.         1.        72.       597.       497.         2.236068
   1.      ], curr_frame = 72
idx: 217, label: [  1.   2.  72. 786. 574.   0.   1.], curr_frame = 72
idx: 218, label: [  2.   1.  72. 770. 741.   1.   1.], curr_frame = 72
idx: 219, label: [  0.         1.        73.       595.       501.         4.472136
   1.      ], curr_frame = 73
idx: 220, label: [  1.          2.         73.        785.        575.          1.4142135
   1.       ], curr_frame = 73
idx: 221, label: [  2.   1.  73. 770. 741.   0.   1.], curr_frame = 73
idx: 222, label: [  0.         1.        74.       593.       505.         4.472136
   1.      ], curr_frame = 74
idx: 223, label: [  1.   2.  74. 779. 575.   6.   1.], curr_frame = 74
idx: 224, label: [  2.   1.  74. 770. 74

Frames: 150
(150, 2)
343
idx: 0, label: [  1.   2.  33.   7. 651.  -1.  -1.], curr_frame = 33
idx: 1, label: [  1.         2.        34.        17.       648.        10.440307
   1.      ], curr_frame = 34
idx: 2, label: [  1.        2.       35.       28.      646.       11.18034   1.     ], curr_frame = 35
idx: 3, label: [  0.   1.  36.   3. 320.  -1.  -1.], curr_frame = 36
idx: 4, label: [  1.          2.         36.         36.        643.          8.5440035
   1.       ], curr_frame = 36
idx: 5, label: [  0.          1.         37.          7.        319.          4.1231055
   1.       ], curr_frame = 37
idx: 6, label: [  1.         2.        37.        43.       642.         7.071068
   1.      ], curr_frame = 37
idx: 7, label: [  0.          1.         38.         12.        317.          5.3851647
   1.       ], curr_frame = 38
idx: 8, label: [  1.          2.         38.         49.        640.          6.3245554
   1.       ], curr_frame = 38
idx: 9, label: [  2.   1.  38.  1

Frames: 150
(150, 2)
415
idx: 0, label: [  0.   1.   0.  21. 139.  -1.  -1.], curr_frame = 0
idx: 1, label: [  1.   2.   0.  17. 432.  -1.  -1.], curr_frame = 0
idx: 2, label: [  0.   1.   1.  21. 140.   1.   1.], curr_frame = 1
idx: 3, label: [  1.   2.   1.  17. 433.   1.   1.], curr_frame = 1
idx: 4, label: [  0.   1.   2.  27. 140.   6.   1.], curr_frame = 2
idx: 5, label: [  1.          2.          2.         22.        432.          5.0990195
   1.       ], curr_frame = 2
idx: 6, label: [  0.   1.   3.  30. 140.   3.   1.], curr_frame = 3
idx: 7, label: [  1.   2.   3.  22. 432.   0.   1.], curr_frame = 3
idx: 8, label: [  0.   1.   4.  28. 140.   2.   1.], curr_frame = 4
idx: 9, label: [  1.   2.   4.  23. 432.   1.   1.], curr_frame = 4
idx: 10, label: [  0.   1.   5.  28. 140.   0.   1.], curr_frame = 5
idx: 11, label: [  1.   2.   5.  23. 432.   0.   1.], curr_frame = 5
idx: 12, label: [  0.   1.   6.  30. 140.   2.   1.], curr_frame = 6
idx: 13, label: [  1.   2.   6.  23. 4

   1.      ], curr_frame = 104
idx: 279, label: [  2.   1. 104. 814. 757.   1.   1.], curr_frame = 104
idx: 280, label: [  0.   1. 105. 827. 202.  13.   1.], curr_frame = 105
idx: 281, label: [  1.         2.       105.       852.       539.         8.246211
   1.      ], curr_frame = 105
idx: 282, label: [  2.   1. 105. 818. 760.   5.   1.], curr_frame = 105
idx: 283, label: [  0.         1.       106.       841.       203.        14.035668
   1.      ], curr_frame = 106
idx: 284, label: [  1.         2.       106.       858.       542.         6.708204
   1.      ], curr_frame = 106
idx: 285, label: [  2.   1. 106. 819. 760.   1.   1.], curr_frame = 106
idx: 286, label: [  0.         1.       107.       855.       205.        14.142136
   1.      ], curr_frame = 107
idx: 287, label: [  1.          2.        107.        862.        543.          4.1231055
   1.       ], curr_frame = 107
idx: 288, label: [  2.   1. 107. 820. 760.   1.   1.], curr_frame = 107
idx: 289, label: [  0.     

Frames: 150
(150, 2)
176
idx: 0, label: [ 2.000e+00  2.000e+00  0.000e+00  1.394e+03  3.270e+02 -1.000e+00
 -1.000e+00], curr_frame = 0
idx: 1, label: [   2.        2.       64.        3.      505.     1402.3427   64.    ], curr_frame = 64
idx: 2, label: [  3.   2.  64.   8. 790.  -1.  -1.], curr_frame = 64
idx: 3, label: [  2.         2.        65.        30.       508.        27.166155
   1.      ], curr_frame = 65
idx: 4, label: [  3.        2.       65.       41.      792.       33.06055   1.     ], curr_frame = 65
idx: 5, label: [  2.         2.        66.        56.       512.        26.305893
   1.      ], curr_frame = 66
idx: 6, label: [  3.        2.       66.       77.      790.       36.05551   1.     ], curr_frame = 66
idx: 7, label: [  2.         2.        67.        79.       516.        23.345236
   1.      ], curr_frame = 67
idx: 8, label: [  3.         2.        67.        93.       794.        16.492422
   1.      ], curr_frame = 67
idx: 9, label: [  2.         2.    

Frames: 150
(150, 2)
353
idx: 0, label: [  0.   1.   0. 425. 340.  -1.  -1.], curr_frame = 0
idx: 1, label: [  1.   1.   0. 352. 932.  -1.  -1.], curr_frame = 0
idx: 2, label: [  2.   2.   0. 756. 540.  -1.  -1.], curr_frame = 0
idx: 3, label: [  3.   2.   0. 615. 847.  -1.  -1.], curr_frame = 0
idx: 4, label: [  0.         1.         1.       464.       342.        39.051247
   1.      ], curr_frame = 1
idx: 5, label: [  1.   1.   1. 386. 932.  34.   1.], curr_frame = 1
idx: 6, label: [  2.         2.         1.       772.       548.        17.888544
   1.      ], curr_frame = 1
idx: 7, label: [  3.         2.         1.       637.       852.        22.561028
   1.      ], curr_frame = 1
idx: 8, label: [1.0000000e+00 1.0000000e+00 2.0000000e+00 1.1330000e+03 9.3500000e+02
 7.4700604e+02 1.0000000e+00], curr_frame = 2
idx: 9, label: [2.0000000e+00 2.0000000e+00 2.0000000e+00 1.0810000e+03 5.9100000e+02
 3.1197757e+02 1.0000000e+00], curr_frame = 2
idx: 10, label: [  3.        2.       

idx: 260, label: [  2.        2.       83.      941.      613.       75.50497   1.     ], curr_frame = 83
idx: 261, label: [  0.       1.      84.     825.     329.     134.3019   5.    ], curr_frame = 84
idx: 262, label: [  1.          1.         84.        873.        943.          3.1622777
   1.       ], curr_frame = 84
idx: 263, label: [2.000000e+00 2.000000e+00 8.400000e+01 1.014000e+03 5.940000e+02
 7.543209e+01 1.000000e+00], curr_frame = 84
idx: 264, label: [  3.       2.      84.     999.     828.     330.6055   6.    ], curr_frame = 84
idx: 265, label: [  0.   1.  85. 855. 329.  30.   1.], curr_frame = 85
idx: 266, label: [  1.   1.  85. 874. 943.   1.   1.], curr_frame = 85
idx: 267, label: [2.000000e+00 2.000000e+00 8.500000e+01 1.090000e+03 5.770000e+02
 7.787811e+01 1.000000e+00], curr_frame = 85
idx: 268, label: [3.0000000e+00 2.0000000e+00 8.5000000e+01 1.0270000e+03 8.2300000e+02
 2.8442924e+01 1.0000000e+00], curr_frame = 85
idx: 269, label: [  0.         1.        8

Frames: 150
(150, 2)
248
idx: 0, label: [  1.   1.  56.  91. 966.  -1.  -1.], curr_frame = 56
idx: 1, label: [  1.        1.       57.      139.      963.       48.09366   1.     ], curr_frame = 57
idx: 2, label: [ 4.  1. 67. 19. 66. -1. -1.], curr_frame = 67
idx: 3, label: [  3.   1.  68.   7. 573.  -1.  -1.], curr_frame = 68
idx: 4, label: [ 4.        1.       68.       62.       69.       43.104523  1.      ], curr_frame = 68
idx: 5, label: [  3.         1.        69.        21.       571.        14.142136
   1.      ], curr_frame = 69
idx: 6, label: [  4.         1.        69.       101.        78.        40.024994
   1.      ], curr_frame = 69
idx: 7, label: [  3.         1.        70.        38.       570.        17.029387
   1.      ], curr_frame = 70
idx: 8, label: [  4.         1.        70.       143.        84.        42.426407
   1.      ], curr_frame = 70
idx: 9, label: [  3.   1.  71.  54. 570.  16.   1.], curr_frame = 71
idx: 10, label: [  4.         1.        71.       

Frames: 150
(150, 2)
346
idx: 0, label: [  0.   2.   0. 209.  30.  -1.  -1.], curr_frame = 0
idx: 1, label: [  1.   2.   0. 211.  32.  -1.  -1.], curr_frame = 0
idx: 2, label: [  0.   2.   1. 209.  30.   0.   1.], curr_frame = 1
idx: 3, label: [  0.   2.   2. 209.  30.   0.   1.], curr_frame = 2
idx: 4, label: [  0.   2.   3. 209.  30.   0.   1.], curr_frame = 3
idx: 5, label: [  0.   2.   4. 209.  30.   0.   1.], curr_frame = 4
idx: 6, label: [  0.   2.   5. 209.  30.   0.   1.], curr_frame = 5
idx: 7, label: [  0.   2.   6. 209.  30.   0.   1.], curr_frame = 6
idx: 8, label: [  0.   2.   7. 209.  30.   0.   1.], curr_frame = 7
idx: 9, label: [  0.   2.   8. 209.  30.   0.   1.], curr_frame = 8
idx: 10, label: [  0.   2.   9. 209.  30.   0.   1.], curr_frame = 9
idx: 11, label: [  0.   2.  10. 209.  30.   0.   1.], curr_frame = 10
idx: 12, label: [  0.   2.  11. 209.  30.   0.   1.], curr_frame = 11
idx: 13, label: [  0.   2.  12. 209.  30.   0.   1.], curr_frame = 12
idx: 14, label: 

 2.3769728e+01 1.0000000e+00], curr_frame = 103
idx: 275, label: [4.000000e+00 1.000000e+00 1.030000e+02 1.362000e+03 9.450000e+02
 2.236068e+00 1.000000e+00], curr_frame = 103
idx: 276, label: [5.000e+00 1.000e+00 1.030e+02 1.524e+03 6.160e+02 1.100e+01 1.000e+00], curr_frame = 103
idx: 277, label: [  0.   2. 104. 211.  32.   0.   1.], curr_frame = 104
idx: 278, label: [1.000000e+00 2.000000e+00 1.040000e+02 1.552000e+03 2.060000e+02
 3.238827e+01 1.000000e+00], curr_frame = 104
idx: 279, label: [4.0000000e+00 1.0000000e+00 1.0400000e+02 1.3660000e+03 9.4400000e+02
 4.1231055e+00 1.0000000e+00], curr_frame = 104
idx: 280, label: [5.0000000e+00 1.0000000e+00 1.0400000e+02 1.5380000e+03 6.1800000e+02
 1.4142136e+01 1.0000000e+00], curr_frame = 104
idx: 281, label: [  0.   2. 105. 211.  32.   0.   1.], curr_frame = 105
idx: 282, label: [1.0000000e+00 2.0000000e+00 1.0500000e+02 1.5800000e+03 2.1100000e+02
 2.8442924e+01 1.0000000e+00], curr_frame = 105
idx: 283, label: [4.000000e+00 1.00

MemoryError: 

In [None]:
# need to normalise y values
# need to build a model that includes distance and frame diff
# model predicting x and y values

# going to build model that predicts coords from image

In [None]:
def normalise_img(x_values):
    return x_values / 255

In [None]:
def cnn_prepare(x_value, y_value):
    #x_val = np.reshape(x_value,(x_value.shape[0] * x_value.shape[1], 
    #                            x_value.shape[2], 
    #                            x_value.shape[3], 
    #                            x_value.shape[4]))
    x_val = normalise_img(x_value)
    y_val = np.reshape(y_value,(y_value.shape[0],
                                y_value.shape[1] * y_value.shape[2]))
    return x_val, y_val

In [None]:
def random_np(x_np, y_np):
    prng = RandomState(0)
    randomise = prng.permutation(x_np.shape[0])
    x_np = x_np[randomise]
    y_np = y_np[randomise]
    return x_np, y_np

In [None]:
def split_np(x_data, y_data, percent):
    """ splits a numpy array into testing and training """
    position = int(len(x_data) * (1-percent))
    x_train, x_test = x_data[:position], x_data[position:]
    y_train, y_test = y_data[:position], y_data[position:]
    print('x_train shape: {}, x_test shape: {}'.format(x_train.shape,x_test.shape))
    print('y_train shape: {}, y_test shape: {}'.format(y_train.shape,y_test.shape))
    return x_train, y_train, x_test, y_test

In [None]:

x_img, y = random_np(x_img, y)

x_cnn, y_cnn = cnn_prepare(x_img, y)
#x_cnn, y_cnn = x_img, y
x_train, y_train, x_test, y_test = split_np(x_cnn, y_cnn, 0.2)

## remove useless variables
x_img, x_diff, y, x_cnn, y_cnn = None, None, None, None, None

In [None]:
## prepare data to pickle
import pickle

data = [x_train, y_train, x_test, y_test]
pickle.dump(data, open("./aws/data_rand.p","wb"))
print("Finished")

In [None]:
input_shape = x_train.shape[1:]
output_shape = y_train.shape[1]

model = Sequential()
model.add(Conv2D(32, kernel_size=(3,3),
                activation='relu',
                input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(32, (3, 3),activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(16, (3, 3),activation='relu'))
model.add(MaxPooling2D((2,2), strides=(2,2)))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(output_shape))
model.compile(loss='mean_squared_error', 
              optimizer='adam',
              metrics=['accuracy'])

In [None]:
monitor = EarlyStopping(monitor='val_loss', min_delta=1e-3, patience=5, verbose=1, mode='auto')
checkpointer = ModelCheckpoint(filepath="dnn/tmp_best_weights.hdf5", verbose=0, save_best_only=True) # save best model

batch_size = 4
epochs = 1000
import time
start_time = time.time()

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test),
          callbacks=[monitor,checkpointer])
model.load_weights('dnn/tmp_best_weights.hdf5') # load weights from best model


save_dir = join(os.getcwd(),"dnn")
save_path = join(save_dir,str(int(start_time)) + "_cnn.h5")
model.save(save_path)

score = model.evaluate(x_test, y_test, verbose=2)
print('Test loss: {}'.format(score[0]))
print('Test accuracy: {}'.format(score[1]))

elapsed_time = time.time() - start_time
print("Elapsed time: {}".format(hms_string(elapsed_time)))