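# Test EZNet with Its Trainer Object

Load the EZNet training/testing data, build an `EZNet` model, and train it for one epoch with `EZNetTrainer`.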

In [1]:
import sys
import numpy as np 
import os

import keras
import tensorflow as tf

# Make the local `dnn` package importable. The path is relative to the
# kernel's working directory, so this only works when the notebook is run
# from its own folder (presumably three levels below the repo root — confirm).
sys.path.append('../../../')
import dnn

# import basic plotting
# NOTE(review): np, keras, tf, matplotlib, plt, sns, SVG and model_to_dot are
# not used by any of the cells below; drop them if nothing else needs them.
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot

# Project classes used in this notebook: model, data reader, and trainer.
from dnn.keras_models.nets.eznet import EZNet
from dnn.io.readerauxdataset import ReaderEZNetDataset
from dnn.keras_models.trainers.eznet import EZNetTrainer

# Import magic commands for jupyter notebook
# - autoreloading a module (mode 2 reloads modules before each cell runs,
#   so edits to the `dnn` package are picked up without restarting)
# - profiling functions for memory usage and scripts (%lprun / %memit)
%load_ext autoreload
%autoreload 2
%load_ext line_profiler
%load_ext memory_profiler

Using TensorFlow backend.


In [2]:
# Reader used below to locate (loadbydir) and load (loadfiles) the EZNet train/test files.
reader = ReaderEZNetDataset()

In [3]:
# Root of the EZNet dataset; it must contain `trainpats/` and `testpats/`.
# Set EZNET_DATA_DIR to point elsewhere instead of editing a user-specific path.
data_root = os.environ.get('EZNET_DATA_DIR',
                           '/Users/adam2392/Downloads/eznet_traindata')
# The trailing '' component keeps the trailing separator the directories had.
traindir = os.path.join(data_root, 'trainpats', '')
testdir = os.path.join(data_root, 'testpats', '')

reader.loadbydir(traindir, testdir)

2018-07-02 17:34:19,827 - INFO - ReaderEZNetDataset - Reading testing data directory /Users/adam2392/Downloads/eznet_traindata/testpats/ 
2018-07-02 17:34:19,831 - INFO - ReaderEZNetDataset - Reading training data directory /Users/adam2392/Downloads/eznet_traindata/trainpats/ 
2018-07-02 17:34:19,846 - INFO - ReaderEZNetDataset - Found 49 training files and 38 testing files.
2018-07-02 17:34:19,847 - INFO - ReaderEZNetDataset - Finished reading in data by directories!


In [4]:
# Load the training split, then the testing split, into the reader.
for load_mode in ('TRAIN', 'TEST'):
    reader.loadfiles(mode=load_mode)

2018-07-02 17:34:19,965 - INFO - ReaderEZNetDataset - Loading files from directory!
(49000, 30, 480)
(49000,)
(49000, 480)
2018-07-02 17:34:27,501 - INFO - ReaderEZNetDataset - Image tensor shape: (3549, 30, 480, 1)
2018-07-02 17:34:27,509 - INFO - ReaderEZNetDataset - Loading files from directory!
(38000, 30, 480)
(38000,)
(38000, 480)
2018-07-02 17:34:36,175 - INFO - ReaderEZNetDataset - Image tensor shape: (3297, 30, 480, 1)


In [5]:
# Show the dataset objects the reader built for each split.
print(reader.train_dataset)
print(reader.test_dataset)

<dnn.io.dataloaders.baseaux.TrainDataset object at 0x11e12aac8>
<dnn.io.dataloaders.baseaux.TestDataset object at 0x11e12ab00>


# Create Model

In [6]:
# Model geometry: must agree with the loaded image tensors, which are
# 30 x 480 with a single colour channel, and with the 2-column labels.
model_params = dict(
    length_imsize=30,
    width_imsize=480,
    num_classes=2,
    n_colors=1,
)

# Instantiate the network and build the Keras model (output=True is passed
# straight through to EZNet.buildmodel).
eznet = EZNet(**model_params)
model = eznet.buildmodel(output=True)

# Create Trainer

In [7]:
# Training hyper-parameters passed positionally to EZNetTrainer below.
num_epochs = 1
batch_size = 16
# Root directory for trainer artefacts (it logs to output/, traininglogs/ and
# tensorboard/ under it). Override with EZNET_OUTPUT_DIR rather than editing
# a user-specific absolute path.
outputdir = os.environ.get('EZNET_OUTPUT_DIR', '/Users/adam2392/Downloads/testdnn')
learning_rate = 5e-2
shuffle = True   # presumably reshuffles samples each epoch — confirm in EZNetTrainer
augment = True   # enables the trainer's real-time data augmentation path

In [8]:
# Arguments are positional: keep this order in sync with EZNetTrainer's signature.
trainer = EZNetTrainer(
    model,
    num_epochs,
    batch_size,
    outputdir,
    learning_rate,
    shuffle,
    augment,
)

2018-07-02 17:34:39,432 - INFO - EZNetTrainer - Logging output data to: /Users/adam2392/Downloads/testdnn/output
2018-07-02 17:34:39,434 - INFO - EZNetTrainer - Logging experimental data at: /Users/adam2392/Downloads/testdnn/traininglogs
2018-07-02 17:34:39,436 - INFO - EZNetTrainer - Logging tensorboard data at: /Users/adam2392/Downloads/testdnn/tensorboard


In [9]:
# Hand both loaded splits to the trainer.
# NOTE(review): the log for this call reports 3549 training steps per epoch,
# which equals the number of training images even though batch_size is 16;
# confirm how EZNetTrainer derives steps per epoch.
trainer.composedatasets(reader.train_dataset, reader.test_dataset)

2018-07-02 17:34:39,528 - INFO - EZNetTrainer - Each training epoch is 3549 steps and each validation is 3297 steps.
2018-07-02 17:34:39,530 - INFO - EZNetTrainer - Setting the datasets for training/testing in trainer object!
2018-07-02 17:34:39,534 - INFO - EZNetTrainer - Image size is (30, 480) with 1 colors


In [10]:
# Sanity-check the shapes of the test split's X_aux, X_chan and ylabels arrays.
for attr_name in ('X_aux', 'X_chan', 'ylabels'):
    print(getattr(reader.test_dataset, attr_name).shape)

(3297, 30, 480, 1)
(3297, 480, 1)
(3297, 2)


In [11]:
# Set up the trainer before training; what gets configured (optimizer,
# callbacks, ...) is defined inside EZNetTrainer — not visible here.
trainer.configure()

In [12]:
# Train the model for the configured number of epochs.
# NOTE(review): the recorded run of this cell failed during epoch 1 with
# "StopIteration: local variable 'idx' referenced before assignment".
# Steps per epoch were logged as the number of images rather than
# images / batch_size, so the batch generator may be asked for more batches
# than it can yield — check EZNetTrainer's step computation and generator.
trainer.train()

Training data:  (3549, 30, 480, 1) (3549, 2)
Testing data:  (3297, 30, 480, 1) (3297, 2)
Class weights are:  [0.58085106 3.59210526]
class imbalance:  494 3549
Using real-time data augmentation.
Epoch 1/1


StopIteration: local variable 'idx' referenced before assignment

In [None]:
# Inspect the batch size stored on the trainer (compare with `batch_size` passed in).
print(trainer.batch_size)