In [1]:
# import necesary libraries and modules
from __future__ import print_function
import os
import argparse

from zoo.orca import init_orca_context, stop_orca_context
from zoo.orca import OrcaContext

In [2]:

# recommended to set it to True when running Analytics Zoo in Jupyter notebook. 
OrcaContext.log_output = True # (this will display terminal's stdout and stderr in the Jupyter notebook).

cluster_mode = "local"

if cluster_mode == "local":
    init_orca_context(cores=1, memory="2g")   # run in local mode
elif cluster_mode == "k8s":
    init_orca_context(cluster_mode="k8s", num_nodes=2, cores=4) # run on K8s cluster
elif cluster_mode == "yarn":
    init_orca_context(
      cluster_mode="yarn-client", cores=4, num_nodes=2, memory="2g",
      driver_memory="10g", driver_cores=1,
      conf={"spark.rpc.message.maxSize": "1024",
            "spark.task.maxFailures": "1",
            "spark.driver.extraJavaOptions": "-Dbigdl.failure.retryTimes=1"})   # run on Hadoop YARN cluster

Initializing orca context
Current pyspark location is : /intern/spark/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/__init__.py
Start to getOrCreate SparkContext
pyspark_submit_args is:  --driver-class-path /home/zhenhao/analytics-zoo/zoo/target/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-dist-all/lib/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-jar-with-dependencies.jar pyspark-shell 
2021-02-18 16:40:15 WARN  Utils:66 - Your hostname, intern01 resolves to a loopback address: 127.0.1.1; using 10.239.44.107 instead (on interface eno1)
2021-02-18 16:40:15 WARN  Utils:66 - Set SPARK_LOCAL_IP if you need to bind to another address
2021-02-18 16:40:15 WARN  NativeCodeLoader:62 - Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


SLF4J: Class path contains multiple SLF4J bindings.
SLF4J: Found binding in [jar:file:/home/zhenhao/analytics-zoo/zoo/target/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-dist-all/lib/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.10.0-SNAPSHOT-jar-with-dependencies.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: Found binding in [jar:file:/intern/spark/spark-2.4.3-bin-hadoop2.7/jars/slf4j-log4j12-1.7.16.jar!/org/slf4j/impl/StaticLoggerBinder.class]
SLF4J: See http://www.slf4j.org/codes.html#multiple_bindings for an explanation.
SLF4J: Actual binding is of type [org.slf4j.impl.Log4jLoggerFactory]
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


cls.getname: com.intel.analytics.bigdl.python.api.Sample
BigDLBasePickler registering: bigdl.util.common  Sample
cls.getname: com.intel.analytics.bigdl.python.api.EvaluatedResult
BigDLBasePickler registering: bigdl.util.common  EvaluatedResult
cls.getname: com.intel.analytics.bigdl.python.api.JTensor
BigDLBasePickler registering: bigdl.util.common  JTensor
cls.getname: com.intel.analytics.bigdl.python.api.JActivity
BigDLBasePickler registering: bigdl.util.common  JActivity
Successfully got a SparkContext



User settings:

   KMP_AFFINITY=granularity=fine,compact,1,0
   KMP_BLOCKTIME=0
   KMP_SETTINGS=1
   OMP_NUM_THREADS=1

Effective settings:

   KMP_ABORT_DELAY=0
   KMP_ADAPTIVE_LOCK_PROPS='1,1024'
   KMP_ALIGN_ALLOC=64
   KMP_ALL_THREADPRIVATE=128
   KMP_ATOMIC_MODE=2
   KMP_BLOCKTIME=0
   KMP_CPUINFO_FILE: value is not defined
   KMP_DETERMINISTIC_REDUCTION=false
   KMP_DEVICE_THREAD_LIMIT=2147483647
   KMP_DISP_HAND_THREAD=false
   KMP_DISP_NUM_BUFFERS=7
   KMP_DUPLICATE_LIB_OK=false
   KMP_FORCE_REDUCTION: value is not defined
   KMP_FOREIGN_THREADS_THREADPRIVATE=true
   KMP_FORKJOIN_BARRIER='2,2'
   KMP_FORKJOIN_BARRIER_PATTERN='hyper,hyper'
   KMP_FORKJOIN_FRAMES=true
   KMP_FORKJOIN_FRAMES_MODE=3
   KMP_GTID_MODE=3
   KMP_HANDLE_SIGNALS=false
   KMP_HOT_TEAMS_MAX_LEVEL=1
   KMP_HOT_TEAMS_MODE=0
   KMP_INIT_AT_FORK=true
   KMP_INIT_WAIT=2048
   KMP_ITT_PREPARE_DELAY=0
   KMP_LIBRARY=throughput
   KMP_LOCK_KIND=queuing
   KMP_MALLOC_POOL_INCR=1M
   KMP_NEXT_WAIT=1024
   KMP_NUM_L

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = LeNet()
model.train()
criterion = nn.NLLLoss()
lr = 0.001

adam = torch.optim.Adam(model.parameters(), lr)

In [2]:
print(model)

LeNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)


In [4]:
import torch
from torchvision import datasets, transforms

torch.manual_seed(0)
dir='./dataset'
batch_size=320
test_batch_size=320

train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size= batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(dir, train=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=test_batch_size, shuffle=False)

In [5]:

from zoo.orca.learn.pytorch import Estimator 
from zoo.orca.learn.metrics import MAE

est = Estimator.from_torch(model=model, optimizer=adam, loss=criterion, metrics=[MAE()])

creating: createZooKerasMAE
creating: createTorchLoss
creating: createTorchOptim
creating: createTorchModel
creating: createEstimator


In [6]:

from zoo.orca.learn.trigger import EveryEpoch 

est.fit(data=train_loader, epochs=1, validation_data=test_loader,
        checkpoint_trigger=EveryEpoch())

creating: createEveryEpoch
creating: createMaxEpoch


                                                                                

linux-x86_64/libiomp5.so
linux-x86_64/libmklml_intel.so
linux-x86_64/libtensorflow_framework-zoo.so
linux-x86_64/libtensorflow_jni.so


                                                                                

2021-02-18 16:40:54 INFO  DistriOptimizer$:818 - caching training rdd ...


                                                                                

2021-02-18 16:41:18 INFO  DistriOptimizer$:649 - Cache thread models...
2021-02-18 16:41:18 INFO  DistriOptimizer$:632 - model thread pool size is 1
2021-02-18 16:41:18 INFO  DistriOptimizer$:651 - Cache thread models... done
2021-02-18 16:41:18 INFO  DistriOptimizer$:161 - Count dataset


[Stage 7:>                                                          (0 + 1) / 1]

Warn: jep.JepException: <class 'StopIteration'>
	at jep.Jep.exec(Native Method)
	at jep.Jep.exec(Jep.java:478)
	at com.intel.analytics.zoo.common.PythonInterpreter$$anonfun$1.apply$mcV$sp(PythonInterpreter.scala:108)
	at com.intel.analytics.zoo.common.PythonInterpreter$$anonfun$1.apply(PythonInterpreter.scala:107)
	at com.intel.analytics.zoo.common.PythonInterpreter$$anonfun$1.apply(PythonInterpreter.scala:107)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

2021-02-18 16:41:27 INFO  DistriOptimizer$:165 - Count dataset comp

[Stage 8:>                                                          (0 + 1) / 1]

Warn: jep.JepException: <class 'StopIteration'>
	at jep.Jep.exec(Native Method)
	at jep.Jep.exec(Jep.java:478)
	at com.intel.analytics.zoo.common.PythonInterpreter$$anonfun$1.apply$mcV$sp(PythonInterpreter.scala:108)
	at com.intel.analytics.zoo.common.PythonInterpreter$$anonfun$1.apply(PythonInterpreter.scala:107)
	at com.intel.analytics.zoo.common.PythonInterpreter$$anonfun$1.apply(PythonInterpreter.scala:107)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

2021-02-18 16:41:36 WARN  DistriOptimizer$:167 - If the dataset is 

                                                                                

2021-02-18 16:41:37 INFO  DistriOptimizer$:427 - [Epoch 1 320/60160][Iteration 1][Wall Clock 0.59850156s] Trained 320.0 records in 0.59850156 seconds. Throughput is 534.66864 records/second. Loss is 2.2994292. 
2021-02-18 16:41:37 INFO  DistriOptimizer$:427 - [Epoch 1 640/60160][Iteration 2][Wall Clock 0.849114011s] Trained 320.0 records in 0.250612451 seconds. Throughput is 1276.872 records/second. Loss is 2.186614. 
2021-02-18 16:41:37 INFO  DistriOptimizer$:427 - [Epoch 1 960/60160][Iteration 3][Wall Clock 1.097876899s] Trained 320.0 records in 0.248762888 seconds. Throughput is 1286.3655 records/second. Loss is 2.0546646. 
2021-02-18 16:41:37 INFO  DistriOptimizer$:427 - [Epoch 1 1280/60160][Iteration 4][Wall Clock 1.383833165s] Trained 320.0 records in 0.285956266 seconds. Throughput is 1119.0522 records/second. Loss is 1.8639662. 
2021-02-18 16:41:38 INFO  DistriOptimizer$:427 - [Epoch 1 1600/60160][Iteration 5][Wall Clock 1.627987086s] Trained 320.0 records in 0.244153921 second

2021-02-18 16:41:46 INFO  DistriOptimizer$:427 - [Epoch 1 12800/60160][Iteration 40][Wall Clock 9.561469003s] Trained 320.0 records in 0.218275417 seconds. Throughput is 1466.0377 records/second. Loss is 0.17649579. 
2021-02-18 16:41:46 INFO  DistriOptimizer$:427 - [Epoch 1 13120/60160][Iteration 41][Wall Clock 9.777986721s] Trained 320.0 records in 0.216517718 seconds. Throughput is 1477.9391 records/second. Loss is 0.30072063. 
2021-02-18 16:41:46 INFO  DistriOptimizer$:427 - [Epoch 1 13440/60160][Iteration 42][Wall Clock 9.995656842s] Trained 320.0 records in 0.217670121 seconds. Throughput is 1470.1145 records/second. Loss is 0.2571685. 
2021-02-18 16:41:46 INFO  DistriOptimizer$:427 - [Epoch 1 13760/60160][Iteration 43][Wall Clock 10.226178561s] Trained 320.0 records in 0.230521719 seconds. Throughput is 1388.1555 records/second. Loss is 0.20060596. 
2021-02-18 16:41:47 INFO  DistriOptimizer$:427 - [Epoch 1 14080/60160][Iteration 44][Wall Clock 10.446905388s] Trained 320.0 records

2021-02-18 16:41:54 INFO  DistriOptimizer$:427 - [Epoch 1 24960/60160][Iteration 78][Wall Clock 17.867503079s] Trained 320.0 records in 0.217766055 seconds. Throughput is 1469.4668 records/second. Loss is 0.15499051. 
2021-02-18 16:41:54 INFO  DistriOptimizer$:427 - [Epoch 1 25280/60160][Iteration 79][Wall Clock 18.082381759s] Trained 320.0 records in 0.21487868 seconds. Throughput is 1489.2124 records/second. Loss is 0.13634413. 
2021-02-18 16:41:54 INFO  DistriOptimizer$:427 - [Epoch 1 25600/60160][Iteration 80][Wall Clock 18.313012033s] Trained 320.0 records in 0.230630274 seconds. Throughput is 1387.5021 records/second. Loss is 0.117556736. 
2021-02-18 16:41:55 INFO  DistriOptimizer$:427 - [Epoch 1 25920/60160][Iteration 81][Wall Clock 18.526501735s] Trained 320.0 records in 0.213489702 seconds. Throughput is 1498.9014 records/second. Loss is 0.15658899. 
2021-02-18 16:41:55 INFO  DistriOptimizer$:427 - [Epoch 1 26240/60160][Iteration 82][Wall Clock 18.74485568s] Trained 320.0 reco

2021-02-18 16:42:02 INFO  DistriOptimizer$:427 - [Epoch 1 37120/60160][Iteration 116][Wall Clock 26.137964024s] Trained 320.0 records in 0.216236263 seconds. Throughput is 1479.8628 records/second. Loss is 0.07247432. 
2021-02-18 16:42:02 INFO  DistriOptimizer$:427 - [Epoch 1 37440/60160][Iteration 117][Wall Clock 26.353199351s] Trained 320.0 records in 0.215235327 seconds. Throughput is 1486.7449 records/second. Loss is 0.15185812. 
2021-02-18 16:42:03 INFO  DistriOptimizer$:427 - [Epoch 1 37760/60160][Iteration 118][Wall Clock 26.567673919s] Trained 320.0 records in 0.214474568 seconds. Throughput is 1492.0183 records/second. Loss is 0.05162517. 
2021-02-18 16:42:03 INFO  DistriOptimizer$:427 - [Epoch 1 38080/60160][Iteration 119][Wall Clock 26.784366677s] Trained 320.0 records in 0.216692758 seconds. Throughput is 1476.7452 records/second. Loss is 0.12618582. 
2021-02-18 16:42:03 INFO  DistriOptimizer$:427 - [Epoch 1 38400/60160][Iteration 120][Wall Clock 27.000979948s] Trained 320.

2021-02-18 16:42:11 INFO  DistriOptimizer$:427 - [Epoch 1 49280/60160][Iteration 154][Wall Clock 34.336829952s] Trained 320.0 records in 0.215002667 seconds. Throughput is 1488.3536 records/second. Loss is 0.06714714. 
2021-02-18 16:42:11 INFO  DistriOptimizer$:427 - [Epoch 1 49600/60160][Iteration 155][Wall Clock 34.549025707s] Trained 320.0 records in 0.212195755 seconds. Throughput is 1508.0415 records/second. Loss is 0.07002549. 
2021-02-18 16:42:11 INFO  DistriOptimizer$:427 - [Epoch 1 49920/60160][Iteration 156][Wall Clock 34.770275408s] Trained 320.0 records in 0.221249701 seconds. Throughput is 1446.3296 records/second. Loss is 0.18483515. 
2021-02-18 16:42:11 INFO  DistriOptimizer$:427 - [Epoch 1 50240/60160][Iteration 157][Wall Clock 34.984838116s] Trained 320.0 records in 0.214562708 seconds. Throughput is 1491.4054 records/second. Loss is 0.0683266. 
2021-02-18 16:42:11 INFO  DistriOptimizer$:427 - [Epoch 1 50560/60160][Iteration 158][Wall Clock 35.195254038s] Trained 320.0

Py4JJavaError: An error occurred while calling o51.estimatorTrainMiniBatch.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 385.0 failed 1 times, most recent failure: Lost task 0.0 in stage 385.0 (TID 385, localhost, executor driver): java.lang.IllegalArgumentException: requirement failed: tensor size not match 320x10 320
	at scala.Predef$.require(Predef.scala:224)
	at com.intel.analytics.bigdl.tensor.DenseTensor$.expandSize(DenseTensor.scala:2790)
	at com.intel.analytics.bigdl.tensor.DenseTensor.expandTensor(DenseTensor.scala:979)
	at com.intel.analytics.bigdl.tensor.DenseTensor.add(DenseTensor.scala:973)
	at com.intel.analytics.bigdl.nn.AbsCriterion.updateOutput(AbsCriterion.scala:36)
	at com.intel.analytics.bigdl.nn.AbsCriterion.updateOutput(AbsCriterion.scala:28)
	at com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion.forward(AbstractCriterion.scala:73)
	at com.intel.analytics.zoo.pipeline.api.keras.metrics.MAE.apply(MAE.scala:32)
	at com.intel.analytics.zoo.pipeline.api.keras.metrics.MAE.apply(MAE.scala:27)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6$$anonfun$apply$9.apply(AbstractOptimizer.scala:158)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6$$anonfun$apply$9.apply(AbstractOptimizer.scala:157)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6.apply(AbstractOptimizer.scala:157)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6.apply(AbstractOptimizer.scala:141)
	at com.intel.analytics.bigdl.utils.ThreadPool$$anonfun$invokeAndWait$1$$anonfun$apply$3.apply(ThreadPool.scala:133)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1889)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1877)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1876)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1876)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2110)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2048)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
	at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1035)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
	at org.apache.spark.rdd.RDD.reduce(RDD.scala:1017)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer.validate(AbstractOptimizer.scala:168)
	at com.intel.analytics.bigdl.optim.DistriOptimizer$.optimize(DistriOptimizer.scala:490)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils$.invokeMethod(KerasUtils.scala:302)
	at com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils$.invokeMethodWithEv(KerasUtils.scala:329)
	at com.intel.analytics.zoo.pipeline.api.keras.models.InternalOptimizerUtil$.optimizeModels(Topology.scala:1063)
	at com.intel.analytics.zoo.pipeline.api.keras.models.InternalDistriOptimizer.train(Topology.scala:1262)
	at com.intel.analytics.zoo.pipeline.api.keras.models.InternalDistriOptimizer.train(Topology.scala:1475)
	at com.intel.analytics.zoo.pipeline.api.keras.models.InternalDistriOptimizer.train(Topology.scala:1145)
	at com.intel.analytics.zoo.pipeline.estimator.Estimator.train(Estimator.scala:190)
	at com.intel.analytics.zoo.pipeline.estimator.python.PythonEstimator.estimatorTrainMiniBatch(PythonEstimator.scala:117)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IllegalArgumentException: requirement failed: tensor size not match 320x10 320
	at scala.Predef$.require(Predef.scala:224)
	at com.intel.analytics.bigdl.tensor.DenseTensor$.expandSize(DenseTensor.scala:2790)
	at com.intel.analytics.bigdl.tensor.DenseTensor.expandTensor(DenseTensor.scala:979)
	at com.intel.analytics.bigdl.tensor.DenseTensor.add(DenseTensor.scala:973)
	at com.intel.analytics.bigdl.nn.AbsCriterion.updateOutput(AbsCriterion.scala:36)
	at com.intel.analytics.bigdl.nn.AbsCriterion.updateOutput(AbsCriterion.scala:28)
	at com.intel.analytics.bigdl.nn.abstractnn.AbstractCriterion.forward(AbstractCriterion.scala:73)
	at com.intel.analytics.zoo.pipeline.api.keras.metrics.MAE.apply(MAE.scala:32)
	at com.intel.analytics.zoo.pipeline.api.keras.metrics.MAE.apply(MAE.scala:27)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6$$anonfun$apply$9.apply(AbstractOptimizer.scala:158)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6$$anonfun$apply$9.apply(AbstractOptimizer.scala:157)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
	at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
	at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6.apply(AbstractOptimizer.scala:157)
	at com.intel.analytics.bigdl.optim.AbstractOptimizer$$anonfun$2$$anonfun$apply$5$$anonfun$3$$anonfun$apply$6.apply(AbstractOptimizer.scala:141)
	at com.intel.analytics.bigdl.utils.ThreadPool$$anonfun$invokeAndWait$1$$anonfun$apply$3.apply(ThreadPool.scala:133)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
	at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
	at java.util.concurrent.FutureTask.run(FutureTask.java:266)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


In [None]:
train_loader

In [None]:

result = est.evaluate(data=test_loader)
for r in result:
    print(str(r))