Skip to content

Commit

Permalink
Load and save datasets to disk
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 14, 2018
1 parent 90fca34 commit 2645283
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
29 changes: 22 additions & 7 deletions conx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import types
import keras
import operator
import os

from .utils import *
import conx.datasets
Expand Down Expand Up @@ -834,21 +835,35 @@ def save_to_disk(self, dir, *args, **kwargs):
"""
Save the dataset into the given directory.
"""
import h5py
if not os.path.isdir(dir):
os.makedirs(dir)
np.save(os.path.join(dir, "inputs.npy"), self._inputs, *args, **kwargs)
np.save(os.path.join(dir, "targets.npy"), self._targets, *args, **kwargs)
np.save(os.path.join(dir, "labels.npy"), self._labels, *args, **kwargs)
with h5py.File(os.path.join(dir, "dataset.h5"), "w") as h5:
h5.create_dataset('inputs', data=self._inputs, compression='gzip', compression_opts=9)
h5.create_dataset('targets', data=self._targets, compression='gzip', compression_opts=9)
if len(self._labels) > 0:
string = h5py.special_dtype(vlen=str)
if isinstance(self._labels, np.ndarray) and self._labels.dtype != string:
labels = self._labels.astype(string)
else:
labels = self._labels
h5.create_dataset('labels', data=labels, compression='gzip', compression_opts=9)

def load_from_disk(self, dir, *args, **kwargs):
    """
    Load the dataset from the given directory.

    Looks for ``dataset.h5`` inside ``dir``. When the file exists, its
    ``inputs``, ``targets`` and (if present) ``labels`` datasets are
    attached to this object, ``self.name`` is set to ``dir`` and
    ``self._cache_values()`` is called.

    ``*args`` and ``**kwargs`` are accepted for backward compatibility
    and are not used.

    Returns True if a dataset file was found and loaded, False otherwise.
    Raises KeyError if the file lacks ``inputs`` or ``targets``.
    """
    path = os.path.join(dir, "dataset.h5")
    if not os.path.exists(path):
        return False
    # Imported only when there is actually a file to read.
    import h5py
    h5 = h5py.File(path, "r")
    try:
        inputs = h5["inputs"]
        targets = h5["targets"]
        # 'labels' is optional in the file; fall back to an empty list
        # when absent.
        # NOTE(review): confirm [] is the right "no labels" value here.
        labels = h5["labels"] if "labels" in h5 else []
    except KeyError:
        # Do not leave the file handle open when the file is malformed.
        h5.close()
        raise
    self._inputs = inputs
    self._targets = targets
    self._labels = labels
    # The file is deliberately kept open and stored on the instance: the
    # attributes above are h5py objects that read from it on demand.
    self.h5 = h5
    self.name = dir
    self._cache_values()
    return True
Expand Down
2 changes: 1 addition & 1 deletion conx/datasets/_mnist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import conx as cx
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.utils import (to_categorical, get_file)

description = """
Original source: http://yann.lecun.com/exdb/mnist/
Expand Down
9 changes: 6 additions & 3 deletions conx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,14 +2095,17 @@ def shape(item, summary=False):
else:
return retval

def load_data(filename, return_h5=False, *args, **kwargs):
    """
    Load a numpy or h5 datafile.

    filename  - path to the file; names ending in "h5" are opened with
                h5py, everything else is passed to ``np.load``.
    return_h5 - only used for h5 files. When True, the open ``h5py.File``
                is returned and the caller is responsible for closing it.
                When False, the "data" dataset is read fully into memory
                and the file is closed before returning.
    *args, **kwargs - forwarded to ``np.load`` for non-h5 files.
    """
    if not filename.endswith("h5"):
        return np.load(filename, *args, **kwargs)
    # Imported only for h5 files, so plain numpy loads do not need h5py.
    import h5py
    if return_h5:
        return h5py.File(filename, 'r')
    # Slicing with [:] copies the dataset into memory, so the file can be
    # closed safely once the copy has been made.
    with h5py.File(filename, 'r') as h5:
        return h5["data"][:]

Expand Down

0 comments on commit 2645283

Please sign in to comment.