In [None]:
from bigdl.dataset import transformer
from bigdl.nn import criterion
from bigdl.nn import layer
from bigdl.optim import optimizer
from bigdl.util import common
import pyspark
import os
from bigdl.dataset import mnist
from bigdl.dataset import transformer
import glob
import imageio
import numpy as np

In [None]:
# Build the Spark configuration on top of BigDL's defaults (create_spark_conf).
conf = (
    common.create_spark_conf()
    .setAppName('bigdl-mnist')
    .set('spark.executor.cores', 1)
    .set('spark.cores.max', 1)
)
# SparkConf.set/setMaster stringify their value, so passing os.environ.get(...) for an
# unset variable would store the literal string "None" (an invalid master URL / jar
# path). Only apply the overrides that are actually provided.
spark_master = os.environ.get('SPARK_MASTER')
if spark_master:
    conf = conf.setMaster(spark_master)
## ADD BIGDL_JARS
bigdl_jars = os.environ.get('BIGDL_JARS')
if bigdl_jars:
    conf = conf.set('spark.jars', bigdl_jars)
# getOrCreate keeps this cell re-runnable: constructing a second SparkContext in the
# same kernel raises "Cannot run multiple SparkContexts at once". If a context already
# exists, it is reused and `conf` is ignored.
sc = pyspark.SparkContext.getOrCreate(conf=conf)
# Initialise BigDL's engine on the running context before any model is built.
common.init_engine()

In [None]:
def build_model(class_num):
    """Build the LeNet-style MNIST classifier trained in this notebook.

    Args:
        class_num: number of output classes (width of the final Linear layer).

    Returns:
        A ``layer.Sequential`` that reshapes each input to ``1 x 28 x 28``,
        applies two 5x5 convolution / 2x2 max-pooling stages, then two fully
        connected layers, and ends in ``LogSoftMax`` (log-probabilities, the
        input expected by ``ClassNLLCriterion``).
    """
    model = layer.Sequential()
    pipeline = (
        # Feature extractor: 1x28x28 -> 6x24x24 -> 6x12x12 -> 12x8x8 -> 12x4x4.
        layer.Reshape([1, 28, 28]),
        layer.SpatialConvolution(1, 6, 5, 5),
        layer.Tanh(),
        layer.SpatialMaxPooling(2, 2, 2, 2),
        layer.Tanh(),
        layer.SpatialConvolution(6, 12, 5, 5),
        layer.SpatialMaxPooling(2, 2, 2, 2),
        # Classifier head on the flattened 12*4*4 feature map.
        layer.Reshape([12 * 4 * 4]),
        layer.Linear(12 * 4 * 4, 100),
        layer.Tanh(),
        layer.Linear(100, class_num),
        layer.LogSoftMax(),
    )
    for stage in pipeline:
        model.add(stage)
    return model

In [None]:
## Files from local dataset (expects PNGs under DATA_DIR/train/)
# os.environ[...] fails fast with a KeyError naming DATA_DIR when it is unset, instead of
# a "NoneType + str" TypeError. glob's ordering depends on the file system, so sort the
# list to make the RDD contents reproducible across runs.
files = sorted(glob.glob(os.path.join(os.environ['DATA_DIR'], 'train', '*.png')))
# An empty list would otherwise only fail much later, inside training.
assert files, 'no training PNGs found under DATA_DIR/train'
def mapper(x):
    label = int(x.split('/')[-1].split('-')[-1][:-4])+1
    image = imageio.imread('file://'+x).astype(np.float32).reshape(1,28,28)/255
    return common.Sample.from_ndarray(image, label)
# Distribute the training file paths; each is decoded into a Sample lazily by `mapper`.
trainRDD = (
    sc.parallelize(files)
    .map(mapper)
)

In [None]:
# Training hyperparameters, named so they can be found and tuned in one place.
NUM_CLASSES = 10
LEARNING_RATE = 0.01
LEARNING_RATE_DECAY = 0.0002
MAX_EPOCHS = 1
# NOTE(review): BigDL's distributed optimizer expects the batch size to be a multiple
# of the total core count; spark.cores.max is 1 in the Spark setup, so 10 is valid.
# Revisit if more cores are allocated.
BATCH_SIZE = 10

# SGD on the negative log-likelihood of the model's LogSoftMax output.
opt = optimizer.Optimizer(
    model=build_model(NUM_CLASSES),
    training_rdd=trainRDD,
    criterion=criterion.ClassNLLCriterion(),
    optim_method=optimizer.SGD(
        learningrate=LEARNING_RATE, learningrate_decay=LEARNING_RATE_DECAY
    ),
    end_trigger=optimizer.MaxEpoch(MAX_EPOCHS),
    batch_size=BATCH_SIZE,
)

In [None]:
# Run the training configured in `opt` (stops at its MaxEpoch end trigger) and keep
# the returned model for saving, evaluation and prediction below.
trained_model = opt.optimize()

In [None]:
# Persist the trained model: structure in model.pb, weights in model.bin.
MODEL_DIR = '/tmp/mnist'
# exist_ok=True keeps this cell re-runnable. A bare os.mkdir raised FileExistsError on
# the second run, which made over_write=True below unreachable.
os.makedirs(MODEL_DIR, exist_ok=True)
trained_model.saveModel(
    os.path.join(MODEL_DIR, 'model.pb'),
    os.path.join(MODEL_DIR, 'model.bin'),
    over_write=True
)

In [None]:
# Held-out split, decoded with the same `mapper` as training. Sorted for a reproducible
# order; os.environ[...] gives a clear KeyError if DATA_DIR is unset.
files = sorted(glob.glob(os.path.join(os.environ['DATA_DIR'], 'test', '*.png')))
assert files, 'no test PNGs found under DATA_DIR/test'
validateRDD = sc.parallelize(files).map(mapper)

In [None]:
# Score the trained model on the held-out split with Top-1 accuracy.
# Arguments are positional (dataset, batch size, list of validation methods).
results = trained_model.evaluate(
    validateRDD,
    10,
    [optimizer.Top1Accuracy()],
)

In [None]:
# `results` presumably holds one entry per validation method passed to evaluate();
# only Top1Accuracy was requested, so show the first (and only) entry.
print(results[0])

In [None]:
# Test PNGs for the prediction walkthrough. Sorted for a reproducible order;
# os.environ[...] gives a clear KeyError if DATA_DIR is unset.
files = sorted(glob.glob(os.path.join(os.environ['DATA_DIR'], 'test', '*.png')))
assert files, 'no test PNGs found under DATA_DIR/test'
def mapper_test(x):
    """Load the PNG at path ``x`` as a ``(label, image)`` pair.

    Unlike ``mapper``, the label is kept separate from the image so that
    predictions can later be compared against it.

    Returns:
        ``(label, image)``:
        - ``label``: the integer after the last '-' in the file stem, plus 1
          (1-based labels, matching ``argmax + 1`` at prediction).
        - ``image``: the pixels as a float32 ``(1, 28, 28)`` array divided by 255.
    """
    # Standard path helpers instead of splitting on '/' and slicing off 4 characters.
    stem = os.path.splitext(os.path.basename(x))[0]
    label = int(stem.rsplit('-', 1)[-1]) + 1
    image = imageio.imread(x).astype(np.float32).reshape(1, 28, 28) / 255
    return (label, image)
# Both predictRDD and labelsRDD are derived from testRDD. Without caching, every test
# image would be read and decoded once for each of them.
testRDD = sc.parallelize(files).map(mapper_test).cache()
# predict() takes Samples, which require a label. 2.0 is a placeholder here; presumably
# only the image features are used for prediction.
predictRDD = testRDD.map(lambda x: common.Sample.from_ndarray(x[1], np.array([2.0])))
labelsRDD = testRDD.map(lambda x: x[0])

In [None]:
# The model emits one log-probability per class. argmax gives a 0-based class index;
# +1 maps it back to the 1-based labels produced by mapper_test.
predicts = trained_model.predict(predictRDD).map(
    lambda scores: np.argmax(scores) + 1
)

In [None]:
# Preview (true_label, predicted_label) pairs. take() pulls only a handful of rows to
# the driver, instead of collect() dumping the whole test set into the notebook output.
labelsRDD.zip(predicts).take(20)

In [None]:
# Shut down the SparkContext created at the top of the notebook.
sc.stop()