Skip to content

Commit

Permalink
Beta for saving/loading entire networks
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2018
1 parent abb93d5 commit 46d5d95
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 98 deletions.
3 changes: 3 additions & 0 deletions conx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
# Boston, MA 02110-1301 USA

import warnings
## Numpy emits FutureWarnings that users of this package cannot act on:
warnings.simplefilter(action='ignore', category=FutureWarning)
## Keras warns with this message when loading a saved model that was
## never compile()d — expected for networks saved before training:
warnings.filterwarnings("ignore", "No training configuration found in save file.*")

import sys
import os
Expand Down
73 changes: 30 additions & 43 deletions conx/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,10 @@
except:
pass # won't turn Keras comments into rft for documentation


#------------------------------------------------------------------------
def make_layer(state):
    """Reconstruct a layer instance from its serialized ``state`` dict.

    ``state`` holds at least ``"class"``, ``"name"``, and ``"params"``;
    the known special cases also carry their shape/size entries, and any
    other class carries its positional ``"args"``.

    Returns a new layer instance of the class named by ``state["class"]``.
    """
    if state["class"] == "Layer":
        return Layer(state["name"], state["shape"], **state["params"])
    elif state["class"] == "ImageLayer":
        return ImageLayer(state["name"], state["dimension"], **state["params"])
    elif state["class"] == "EmbeddingLayer":
        return EmbeddingLayer(state["name"], state["in_size"], **state["params"])
    else:
        # Look the class up by name rather than eval()ing built source text:
        # the old eval interpolated state["name"] UNQUOTED into the source
        # (e.g. Layer(conv1, ...)), raising NameError, and would execute
        # arbitrary code embedded in a saved network file.
        layer_class = globals()[state["class"]]
        return layer_class(state["name"], *state["args"], **state["params"])
#------------------------------------------------------------------------
def make_layer(config):
    """Rebuild a layer from its saved ``config`` dictionary.

    ``config["class"]`` names a class in :mod:`conx.layers`; that class is
    instantiated with the stored layer name, positional ``args``, and
    keyword ``params``.
    """
    import conx.layers
    layer_class = getattr(conx.layers, config["class"])
    return layer_class(config["name"], *config["args"], **config["params"])

class _BaseLayer():
"""
Expand All @@ -81,7 +72,7 @@ class _BaseLayer():
CLASS = None

def __init__(self, name, *args, **params):
self._state = {
self.config = {
"class": self.__class__.__name__,
"name": name,
"args": args,
Expand Down Expand Up @@ -225,9 +216,6 @@ def _check_layer_name(self, layer_name):
if layer_name.count("%d") not in [0, 1]:
raise Exception("layer name must contain at most one %%d: '%s'" % layer_name)

def __getstate__(self):
    # Pickle/copy support: a layer's serialized form is the constructor
    # description dict stored by __init__ (suitable for make_layer()).
    # NOTE(review): surrounding diff suggests this attribute is being
    # renamed to self.config — confirm _state is still set by __init__.
    return self._state

def on_connect(self, relation, other_layer):
"""
relation is "to"/"from" indicating which layer self is.
Expand Down Expand Up @@ -555,14 +543,13 @@ class Layer(_BaseLayer):
"""
CLASS = keras.layers.Dense
def __init__(self, name: str, shape, **params):
_state = {
"class": "Layer",
super().__init__(name, **params)
self.config.update({
"class": self.__class__.__name__,
"name": name,
"shape": shape,
"args": [shape],
"params": copy.copy(params),
}
super().__init__(name, **params)
self._state = _state
})
if not valid_shape(shape):
raise Exception('bad shape: %s' % (shape,))
# set layer topology (shape) and number of units (size)
Expand Down Expand Up @@ -622,17 +609,15 @@ class ImageLayer(Layer):
A class for images. WIP.
"""
def __init__(self, name, dimensions, depth, **params):
_state = {
"class": self.__class__.__name__,
"name": name,
"dimensions": dimensions,
"depth": depth,
"params": copy.copy(params),
}
## get value before processing
keep_aspect_ratio = params.get("keep_aspect_ratio", True)
super().__init__(name, dimensions, **params)
self._state = _state
self.config.update({
"class": self.__class__.__name__,
"name": name,
"args": [dimensions, depth],
"params": copy.copy(params),
})
if self.vshape is None:
self.vshape = self.shape
## override defaults set in constructor:
Expand Down Expand Up @@ -674,14 +659,13 @@ class AddLayer(_BaseLayer):
CLASS = keras.layers.Add
def __init__(self, name, **params):
self.layers = []
_state = {
super().__init__(name)
self.config.update({
"class": self.__class__.__name__,
"name": name,
"layers": self.layers,
"args": [],
"params": copy.copy(params),
}
super().__init__(name)
self._state = _state
})
self.handle_merge = True

def make_keras_functions(self):
Expand Down Expand Up @@ -768,8 +752,13 @@ def make_keras_function(self):
class LambdaLayer(Layer):
CLASS = keras.layers.Lambda
def __init__(self, name, size, function, **params):
params["function"] = function
super().__init__(name, size, **params)
self.config.update({
"class": self.__class__.__name__,
"name": name,
"args": [size, function],
"params": copy.copy(params),
})

def make_keras_function(self):
"""
Expand All @@ -788,15 +777,13 @@ class EmbeddingLayer(Layer):
A class for embeddings. WIP.
"""
def __init__(self, name, in_size, out_size, **params):
_state = {
super().__init__(name, in_size, **params)
self.config.update({
"class": self.__class__.__name__,
"name": name,
"in_size": in_size,
"out_size": out_size,
"args": [in_size, out_size],
"params": copy.copy(params),
}
super().__init__(name, in_size, **params)
self._state = _state
})
if self.vshape is None:
self.vshape = self.shape
self.in_size = in_size
Expand Down

0 comments on commit 46d5d95

Please sign in to comment.