Skip to content

Commit

Permalink
Added net.to_array, .from_array
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Aug 7, 2017
1 parent 2c09b63 commit c788f5f
Showing 1 changed file with 48 additions and 3 deletions.
51 changes: 48 additions & 3 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,9 +1648,54 @@ def pf(self, vector, **opts):
truncated = len(vector) > max_length
return "[" + ", ".join([("%." + str(precision) + "f") % v for v in vector[:max_length]]) + ("..." if truncated else "") + "]"

## FIXME: add these:
#def to_array(self):
#def from_array(self):
def to_array(self):
"""
Get the weights of a network as a flat, one-dimensional list.
Example:
>>> from conx import Network
>>> net = Network("Deep", 3, 4, 5, 2, 3, 4, 5)
>>> net.compile(optimizer="adam", error="mse")
>>> array = net.to_array()
>>> len(array)
103
Returns:
list: All of weights in a single, flat list.
"""
array = []
for layer in self.model.layers:
for weight in layer.get_weights():
array.extend(weight.flatten())
return array

def from_array(self, array):
"""
Load the weights from a list.
Args:
array (list) - a sequence (e.g., list, np.array) of numbers
Example:
>>> from conx import Network
>>> net = Network("Deep", 3, 4, 5, 2, 3, 4, 5)
>>> net.compile(optimizer="adam", error="mse")
>>> net.from_array([0] * 103)
>>> array = net.to_array()
>>> len(array)
103
"""
position = 0
for layer in self.model.layers:
weights = layer.get_weights()
new_weights = []
for i in range(len(weights)):
w = weights[i]
size = reduce(operator.mul, w.shape)
new_w = np.array(array[position:position + size]).reshape(w.shape)
new_weights.append(new_w)
position += size
layer.set_weights(new_weights)

class InterruptHandler():
"""
Expand Down

0 comments on commit c788f5f

Please sign in to comment.