Skip to content

Commit

Permalink
Network.propagate can take an image
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Aug 9, 2017
1 parent a34091b commit 5d32ae0
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import base64
import copy
import io
import PIL

import numpy as np
import keras
Expand Down Expand Up @@ -373,9 +374,9 @@ def _cache_dataset_values(self):
assert self.get_inputs_length() == self.get_targets_length(), "inputs/targets lengths do not match"

## FIXME: add these for users' convenience:
#def image_to_channels_last(self, matrix):
#def matrix_to_channels_last(self, matrix): ## vecteor
# x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
#def image_to_channels_first(self, matrix):
#def matrix_to_channels_first(self, matrix):
# x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)

def set_dataset(self, pairs, verbose=True):
Expand All @@ -387,7 +388,7 @@ def set_dataset(self, pairs, verbose=True):
Note:
If you have images in your dataset, they must match K.image_data_format().
See also :any:`image_to_channels_last` and :any:`image_to_channels_first`.
See also :any:`matrix_to_channels_last` and :any:`matrix_to_channels_first`.
"""
## Either the inputs/targets are a list of a list -> np.array(...) (np.array() of vectors)
## or are a list of list of list -> [np.array(), np.array()] (list of np.array cols of vectors)
Expand Down Expand Up @@ -1002,10 +1003,17 @@ def propagate(self, input, batch_size=32):
Propagate an input (in human API) through the network.
If visualizing, the network image will be updated.
"""
import keras.backend as K
if isinstance(input, dict):
input = [input[name] for name in self.get_input_layer_order()]
if self.num_input_layers == 1:
input = input[0]
elif isinstance(input, PIL.Image.Image):
input = np.array(input)
if len(input.shape) == 2:
input = input.reshape(input.shape + (1,))
if K.image_data_format() == 'channels_first':
input = self.matrix_to_channels_first(input)
if self.num_input_layers == 1:
outputs = list(self.model.predict(np.array([input]), batch_size=batch_size)[0])
else:
Expand Down

0 comments on commit 5d32ae0

Please sign in to comment.