Skip to content

Commit

Permalink
If sequence=True, turn the np.arrays() into lists
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 10, 2018
1 parent f64014b commit 2d59da7
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,7 +1717,10 @@ def propagate(self, input, batch_size=32, class_id=None,
outputs = self.model.predict(inputs, batch_size=batch_size)
## Shape the outputs:
if sequence:
pass
if isinstance(outputs, list):
outputs = [bank.tolist() for bank in outputs]
else:
outputs = outputs.tolist()
elif self.num_target_layers == 1:
shape = self[self.output_bank_order[0]].shape
try:
Expand Down Expand Up @@ -1807,6 +1810,10 @@ def propagate_from(self, layer_name, input, output_layer_names=None,
if self.debug: print("propagate_from 2: class_id_name:", class_id_name)
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
if sequence:
if isinstance(outputs, list):
outputs = [bank.tolist() for bank in outputs]
else:
outputs = outputs.tolist()
return outputs
index = 0
for layer_name in output_layer_names:
Expand Down Expand Up @@ -1932,6 +1939,10 @@ def propagate_to(self, layer_name, inputs, batch_size=32, class_id=None,
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
## Shape the outputs:
if sequence:
if isinstance(outputs, list):
outputs = [bank.tolist() for bank in outputs]
else:
outputs = outputs.tolist()
return outputs
shape = self[layer_name].shape
if shape and all([isinstance(v, numbers.Integral) for v in shape]):
Expand Down

0 comments on commit 2d59da7

Please sign in to comment.