Skip to content

Commit

Permalink
Correct names for npy files; can now save/load datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2018
1 parent 5230145 commit c142194
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
23 changes: 23 additions & 0 deletions conx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,29 @@ def load_direct(self, inputs=None, targets=None, labels=None):
self._labels = labels # should be a list of np.arrays(dtype=str), one per bank
self._cache_values()

def save_to_disk(self, dir, *args, **kwargs):
"""
Save the dataset into the given directory.
"""
if not os.path.isdir(dir):
os.makedirs(dir)
np.save(os.path.join(dir, "inputs.npy"), self._inputs, *args, **kwargs)
np.save(os.path.join(dir, "targets.npy"), self._targets, *args, **kwargs)
np.save(os.path.join(dir, "labels.npy"), self._labels, *args, **kwargs)

def load_from_disk(self, dir, *args, **kwargs):
"""
Load the dataset from the given directory.
"""
loaded = False
if os.path.exists(os.path.join(dir, "inputs.npy")):
self._inputs = np.load(os.path.join(dir, "inputs.npy"), *args, **kwargs)
self._targets = np.load(os.path.join(dir, "targets.npy"), *args, **kwargs)
self._labels = np.load(os.path.join(dir, "labels.npy"), *args, **kwargs)
self._cache_values()
loaded = True
return loaded

def load(self, pairs=None, inputs=None, targets=None, labels=None):
"""
Dataset.load() will clear and load a new dataset.
Expand Down
28 changes: 24 additions & 4 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3889,7 +3889,7 @@ def saved(self, dir=None):
return (os.path.isdir(dir) and
os.path.isfile("%s/network.pickle" % dir) and
os.path.isfile("%s/model.h5" % dir) and
os.path.isfile("%s/weights.h5" % dir))
os.path.isfile("%s/weights.npy" % dir))

def delete(self, dir=None):
"""
Expand Down Expand Up @@ -3952,6 +3952,26 @@ def save_model(self, dir=None, filename=None):
else:
raise Exception("need to build network before saving")

def load_dataset(self, dir=None, *args, **kwargs):
"""
Load a dataset from a directory. Returns True if data loaded,
else False.
"""
if dir is None:
dir = "%s.conx" % self.name.replace(" ", "_")
return self.dataset.load_from_disk(dir, *args, **kwargs)

def save_dataset(self, dir=None, *args, **kwargs):
"""
Save the dataset to directory.
"""
if len(self.dataset) > 0:
if dir is None:
dir = "%s.conx" % self.name.replace(" ", "_")
self.dataset.save_to_disk(dir, *args, **kwargs)
else:
raise Exception("need to load a dataset before saving")

def load_history(self, dir=None, filename=None):
"""
Load the history from a dir/file.
Expand Down Expand Up @@ -3997,7 +4017,7 @@ def load_weights(self, dir=None, filename=None):
if dir is None:
dir = "%s.conx" % self.name.replace(" ", "_")
if filename is None:
filename = "weights.h5"
filename = "weights.npy"
full_filename = os.path.join(dir, filename)
if os.path.exists(full_filename):
self.model.load_weights(full_filename)
Expand All @@ -4015,7 +4035,7 @@ def save_weights(self, dir=None, filename=None):
if dir is None:
dir = "%s.conx" % self.name.replace(" ", "_")
if filename is None:
filename = "weights.h5"
filename = "weights.npy"
if not os.path.isdir(dir):
os.makedirs(dir)
self.model.save_weights(os.path.join(dir, filename))
Expand Down Expand Up @@ -4306,7 +4326,7 @@ def save_network(datadir, network):
Saves the network name, layers, conecctions, compile args
to network.pickle.
Saves the weights to weights.h5.
Saves the weights to weights.npy.
Saves the training history to history.pickle.
Saves the network config to config.json.
"""
Expand Down

0 comments on commit c142194

Please sign in to comment.