Removed vmnist; made mnist smaller; added new mnist_h5; dataset.set_*_from_function
dsblank committed Sep 14, 2018
1 parent a93ded5 commit 90fca34
Showing 8 changed files with 149 additions and 94 deletions.
76 changes: 72 additions & 4 deletions conx/dataset.py
@@ -471,7 +471,7 @@ def select(self, function, slice=None, index=False):
Examples:
>>> import conx as cx
>>> print("Downloading...");ds = cx.Dataset.get("vmnist") # doctest: +ELLIPSIS
>>> print("Downloading...");ds = cx.Dataset.get("mnist") # doctest: +ELLIPSIS
Downloading...
>>> ds.inputs.select(lambda i,dataset: True, slice=10, index=True)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
@@ -1122,12 +1122,12 @@ def get(dataset_name=None, *args, **kwargs):
Must be called from the Dataset class.
>>> import conx as cx
>>> print("Downloading..."); ds = cx.Dataset.get("vmnist") # doctest: +ELLIPSIS
>>> print("Downloading..."); ds = cx.Dataset.get("mnist") # doctest: +ELLIPSIS
Downloading...
>>> len(ds.inputs)
70000
>>> ds.targets[0]
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
>>> ds.clear()
"""
@@ -1262,6 +1262,74 @@ def warn_once(self, message, category=None):
print(message, file=sys.stderr)
self.warnings[message] = True

def set_labels_from_function(self, function, length=None):
"""
The function should take an index from 0 to length - 1,
where length is the given length or, if not given, the
length of the inputs/targets. The function should return
a string that represents the category of the input/target
pair.
>>> ds = Dataset()
>>> ds.load([[[0, 0], [0]],
... [[0, 1], [1]],
... [[1, 0], [1]],
... [[1, 1], [0]]])
>>> ds.set_labels_from_function(lambda i: "0" if ds.targets[i] == [0] else "1")
>>> ds.labels[:]
['0', '1', '1', '0']
"""
length = length if length is not None else len(self)
labels = [function(i) for i in range(length)]
self._labels = [np.array(labels, dtype=str)]

def set_targets_from_function(self, function, length=None):
"""
The function should take an index from 0 to length - 1,
where length is the given length or, if not given, the length
of the inputs/targets. The function should return a list or
matrix that is the target pattern.
>>> import conx as cx
>>> ds = cx.Dataset()
>>> ds.set_inputs_from_function(lambda i: cx.binary(i, 2), length=4)
>>> def xor(i1, i2):
... return 1 if ((i1 or i2) and not (i1 and i2)) else 0
>>> ds.set_targets_from_function(lambda i: [xor(*ds.inputs[i])])
>>> ds.set_labels_from_function(lambda i: "0" if ds.targets[i] == [0] else "1")
>>> ds.labels[:]
['0', '1', '1', '0']
"""
length = length if length is not None else len(self)
targets = [function(i) for i in range(length)]
self._targets = [np.array(targets)]
self._cache_values()

def set_inputs_from_function(self, function, length=None):
"""
The function should take an index from 0 to length - 1,
where length is the given length or, if not given, the length
of the inputs/targets. The function should return a list or
matrix that is the input pattern.
>>> import conx as cx
>>> ds = cx.Dataset()
>>> ds.set_inputs_from_function(lambda i: cx.binary(i, 2), length=4)
>>> def xor(i1, i2):
... return 1 if ((i1 or i2) and not (i1 and i2)) else 0
>>> ds.set_targets_from_function(lambda i: [xor(*ds.inputs[i])])
>>> ds.set_labels_from_function(lambda i: "0" if ds.targets[i] == [0] else "1")
>>> ds.labels[:]
['0', '1', '1', '0']
"""
length = length if length is not None else len(self)
inputs = [function(i) for i in range(length)]
self._inputs = [np.array(inputs)]
self._cache_values()

def set_targets_from_inputs(self, f=None, input_bank=0, target_bank=0):
"""
Copy the inputs to targets. Optionally, apply a function f to
@@ -1509,7 +1577,7 @@ def chop(self, amount):
0-1, or an integer number of patterns to drop.
>>> import conx as cx
>>> print("Downloading..."); dataset = cx.Dataset.get("vmnist") # doctest: +ELLIPSIS
>>> print("Downloading..."); dataset = cx.Dataset.get("mnist") # doctest: +ELLIPSIS
Downloading...
>>> len(dataset)
70000
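
As a further illustration of the three set_*_from_function additions above, here is a minimal sketch of a hypothetical 3-bit parity dataset built entirely from functions; it assumes cx.binary(i, 3) returns the 3-bit pattern for i, as in the doctests.

import conx as cx

ds = cx.Dataset()
ds.set_inputs_from_function(lambda i: cx.binary(i, 3), length=8)   # patterns 000 through 111
ds.set_targets_from_function(lambda i: [sum(ds.inputs[i]) % 2])    # 1 when an odd number of bits are set
ds.set_labels_from_function(lambda i: "odd" if ds.targets[i] == [1] else "even")
print(ds.labels[:])   # expected: ['even', 'odd', 'odd', 'even', 'odd', 'even', 'even', 'odd']
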
2 changes: 1 addition & 1 deletion conx/datasets/__init__.py
@@ -3,7 +3,7 @@
## All modules must be named differently from their functions!
## Otherwise, confuses tools like nose, inspect, etc.

from ._mnist import mnist, vmnist
from ._mnist import mnist, mnist_h5
from ._cifar10 import cifar10
from ._cifar100 import cifar100
from .cmu_faces import cmu_faces_full_size
104 changes: 29 additions & 75 deletions conx/datasets/_mnist.py
@@ -1,7 +1,7 @@
import conx as cx
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.utils.data_utils import get_file

description = """
Original source: http://yann.lecun.com/exdb/mnist/
@@ -17,93 +17,47 @@
![MNIST Images](https://github.com/Calysto/conx/raw/master/data/mnist_images.png)
"""

def vmnist(*args, cache_size=3200, **kwargs):
path = "mnist.npz"
path = get_file(path,
origin='https://s3.amazonaws.com/img-datasets/mnist.npz',
file_hash='8a61469f7ea1b51cbae51d4f78837e45')
fp = np.load(path, mmap_mode="r")
img_rows, img_cols = 28, 28
total_len = len(fp["x_train"]) + len(fp["x_test"])

def get_cache(self, cache):
"""
Uses both test and train as data.
"""
## print("vmnist: getting cache #%s" % cache)
pos = cache * self._cache_size
if pos >= total_len:
raise Exception("position %s is beyond data" % pos)
if pos < len(fp["x_train"]):
key_input = "x_train"
key_target = "y_train"
else:
key_input = "x_test"
key_target = "y_test"
pos = pos - len(fp["x_train"])
## print("vmnist: getting pos #%s of %s" % (pos, key_input))
## inputs:
x_train = fp[key_input][pos:pos + self._cache_size]
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_train /= 255
inputs = x_train
## targets:
y_train = fp[key_target][pos:pos + self._cache_size]
labels = y_train
## FIXME: at least one is mis-labeled:
## labels[10994] = 9
targets = to_categorical(labels)
labels = np.array([str(label) for label in labels], dtype=str)
self._labels = [labels]
return [inputs], [targets]

dataset = cx.VirtualDataset(get_cache, total_len,
input_shapes=[(img_rows,img_cols,1)],
target_shapes=[(10,)],
inputs_range=[(0,1)],
targets_range=[(0,1)],
cache_size=cache_size,
load_cache_direct=True)
dataset.name = "MNIST"
dataset.description = ("This is a virtual dataset that loads from disk as needed.\n" +
description)
return dataset

def mnist(*args, **kwargs):
def mnist_h5(*args, **kwargs):
"""
Load the Keras MNIST dataset and format it as images.
Load the Keras MNIST dataset from an H5 file.
"""
import h5py

path = "mnist.h5"
url = "https://raw.githubusercontent.com/Calysto/conx/master/data/mnist.h5"
path = get_file(path, origin=url)
h5 = h5py.File(path, "r")
dataset = cx.Dataset()
dataset._inputs = h5["inputs"]
dataset._targets = h5["targets"]
dataset._labels = h5["labels"]
dataset.h5 = h5
dataset.name = "MNIST-H5"
dataset.description = description
dataset._cache_values()
return dataset

def mnist(*args, **kwargs):
from keras.datasets import mnist
import keras.backend as K

# input image dimensions
img_rows, img_cols = 28, 28
# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
## We need to convert the data to images, but which format?
## We ask this Keras instance what it wants, and convert:
if K.image_data_format() == 'channels_first':
x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
input_shape = (1, img_rows, img_cols)
else:
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
inputs = np.concatenate((x_train,x_test))
labels = np.concatenate((y_train,y_test))
x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)
x_train = x_train.astype('float16')
x_test = x_test.astype('float16')
inputs = np.concatenate((x_train,x_test)) / 255
labels = np.concatenate((y_train,y_test)) # ints, 0 to 9
###########################################
# fix mis-labeled image(s) in Keras dataset
labels[10994] = 9
###########################################
targets = to_categorical(labels)
targets = to_categorical(labels).astype("uint8")
labels = np.array([str(label) for label in labels], dtype=str)
dataset.name = "MNIST"
dataset.description = description
dataset = cx.Dataset()
dataset.load_direct([inputs], [targets], [labels])
return dataset
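
A short usage sketch for the two loaders above; the "mnist_h5" name is assumed to resolve through Dataset.get the same way "mnist" does, since both are exported from conx.datasets.

import conx as cx

# In-memory loader: 70,000 28x28x1 float16 images with uint8 one-hot targets.
ds = cx.Dataset.get("mnist")
print(len(ds.inputs), ds.targets[0])   # 70000 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]

# H5-backed loader: the arrays stay in the downloaded mnist.h5 file rather than in RAM.
ds_h5 = cx.Dataset.get("mnist_h5")
print(ds_h5.name)                      # MNIST-H5
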
2 changes: 1 addition & 1 deletion conx/network.py
@@ -1556,7 +1556,7 @@ def get_dataset(self, dataset_name):
'output'
>>> net.connect()
>>> net.compile(error="mse", optimizer="adam")
>>> net.get_dataset("vmnist")
>>> net.get_dataset("mnist")
"""
self.set_dataset(Dataset.get(dataset_name))

6 changes: 3 additions & 3 deletions conx/tests/test_network.py
@@ -86,7 +86,7 @@ def test_dataset():
net.connect('hidden1', 'hidden2')
net.connect('hidden2', 'output')
net.compile(optimizer="adam", error="binary_crossentropy")
net.get_dataset("vmnist")
net.get_dataset("mnist")
assert net is not None
net.dataset.clear()

@@ -103,15 +103,15 @@ def test_dataset2():
net.connect('hidden1', 'hidden2')
net.connect('hidden2', 'output')
net.compile(optimizer="adam", error="binary_crossentropy")
net.get_dataset("vmnist")
net.get_dataset("mnist")
net.dataset.split(100)
net.dataset.slice(100)
assert net is not None
net.dataset.clear()

def test_images():
net = Network("MNIST")
net.get_dataset("vmnist")
net.get_dataset("mnist")
assert net.dataset.inputs.shape == [(28,28,1)]
net.add(Layer("input", shape=(28, 28, 1), colormap="hot", minmax=(0,1)))
net.add(FlattenLayer("flatten"))
50 changes: 41 additions & 9 deletions conx/utils.py
@@ -31,6 +31,7 @@
from urllib.parse import urlparse
import requests
import zipfile
import types
import math
import tqdm
import sys
@@ -847,7 +848,7 @@ def image_to_array(img, resize=None, raw=False):
img = img.resize(resize)
if img.mode != "RGB":
img = img.convert("RGB")
img = (np.array(img, "float32") / 255.0)
img = (np.array(img, "float16") / 255.0)
if not raw:
img = img.tolist()
return img
@@ -2096,15 +2097,46 @@ def shape(item, summary=False):

def load_data(filename, *args, **kwargs):
"""
Load a numpy datafile.
Load a numpy or h5 datafile.
"""
return np.load(filename, *args, **kwargs)

def save_data(filename, data, *args, **kwargs):
"""
Save an numpy datafile.
"""
np.save(filename, data, *args, **kwargs)
import h5py
if filename.endswith("h5"):
h5 = h5py.File(filename, 'r')  # keep the file open; the returned h5py dataset reads lazily
return h5["data"]
else:
return np.load(filename, *args, **kwargs)

def save_data(filename, data, dtype='uint8', *args, **kwargs):
"""
Save data to disk. data can be a generator.
If the extension is npy, we'll just save it with np.save.
If it is h5, the data is written one chunk (row) at a time into a
single resizable 'data' dataset; data may be:
* [array(), array(), array(), ...]
* [list, list, list, ...]
"""
import h5py
if filename.endswith("h5"):
row_count = 0
with h5py.File(filename, 'w') as h5:
# Initialize a resizable dataset to hold the output
dset = None
for chunk in data: # list or generator
if dset is None:
chunk_shape = (1,) + shape(chunk)
maxshape = (None,) + shape(chunk)
dset = h5.create_dataset('data',
maxshape=maxshape,
shape=chunk_shape,
dtype=dtype)
dset.resize(row_count + 1, axis=0)
dset[row_count:] = chunk
row_count += 1 ## could allow bigger chunks!
else:
np.save(filename, data, *args, **kwargs)

def get_ranges(items, form, dims=tuple()):
"""
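
To round out the utils changes, a small round-trip sketch for the h5 path of save_data above; the file name and row shape are illustrative, and save_data is assumed to be exported at the conx package level like the other utils.

import numpy as np
import h5py
import conx as cx

# Write 100 rows of shape (28, 28, 1), one chunk per row; a generator keeps memory flat.
rows = (np.zeros((28, 28, 1), dtype='uint8') for _ in range(100))
cx.save_data("patterns.h5", rows, dtype='uint8')

# The h5 file holds a single resizable 'data' dataset; load_data("patterns.h5")
# is meant to hand back that same dataset rather than an in-memory copy.
with h5py.File("patterns.h5", "r") as h5:
    print(h5["data"].shape)   # (100, 28, 28, 1)
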
1 change: 1 addition & 0 deletions requirements.txt
@@ -11,3 +11,4 @@ tqdm
requests
pydot
cairosvg
h5py
2 changes: 1 addition & 1 deletion setup.py
@@ -31,7 +31,7 @@
url='https://github.com/Calysto/conx',
install_requires=['numpy', 'keras>=2.1.3', 'matplotlib', 'ipywidgets>=7.0',
'Pillow', 'IPython', 'h5py', "svgwrite", "sklearn",
"tqdm", "requests", "pydot", "cairosvg"],
"tqdm", "requests", "pydot", "cairosvg", "h5py"],
packages=find_packages(include=['conx', 'conx.*']),
include_data_files = True,
test_suite = 'nose.collector',
