In [1]:
import mxnet as mx
import numpy as np

In [2]:
class SimpleIter(mx.io.DataIter):
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            data = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_data, self.data_gen)]
            label = [mx.nd.array(g(d[1])) for d,g in zip(self._provide_label, self.label_gen)]
            return mx.io.DataBatch(data, label)
        else:
            raise StopIteration

In [3]:
num_classes = 10
net = mx.sym.Variable('data')
lbl1 = mx.sym.Variable('output_1')
lbl2 = mx.sym.Variable('output_2')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
fc1 = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
out1 = mx.sym.SoftmaxOutput(data=fc1, name='output_1', label=lbl1)
fc2 = mx.sym.FullyConnected(data=net, name='fc3', num_hidden=num_classes)
out2  = mx.sym.SoftmaxOutput(data=fc2, name='output_2', label=lbl2)
group = mx.sym.Group([out1, out2])

In [4]:
import logging
logging.basicConfig(level=logging.INFO)

n = 32
data_iter = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['output_1', 'output_2'], [(n,), (n,)],
                  [lambda s: np.random.randint(0, num_classes, s), lambda s: np.random.randint(0, num_classes, s)])

mod = mx.mod.Module(symbol=group, label_names=['output_1', 'output_2'])
mod.fit(data_iter, num_epoch=5)

INFO:root:Epoch[0] Train-accuracy=0.112500
INFO:root:Epoch[0] Time cost=0.012
INFO:root:Epoch[1] Train-accuracy=0.104688
INFO:root:Epoch[1] Time cost=0.009
INFO:root:Epoch[2] Train-accuracy=0.103125
INFO:root:Epoch[2] Time cost=0.009
INFO:root:Epoch[3] Train-accuracy=0.098437
INFO:root:Epoch[3] Time cost=0.009
INFO:root:Epoch[4] Train-accuracy=0.100000
INFO:root:Epoch[4] Time cost=0.009


In [5]:
data_iter2 = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['output_1', 'output_2'], [(n,), (n,)],
                  [lambda s: np.random.randint(0, num_classes, s), lambda s: np.random.randint(0, num_classes, s)])

In [6]:
predictions = mod.predict(next(data_iter2).data[0])

In [7]:
predictions


[[0.09980328 0.09958    0.09992559 0.0992334  0.10046541 0.100039
  0.10044528 0.10066289 0.09978553 0.10005961]
 [0.09990032 0.09943327 0.10007251 0.09912145 0.10055376 0.10018851
  0.10028589 0.10081293 0.0996526  0.0999788 ]
 [0.09980104 0.09943943 0.09991118 0.09928276 0.10051626 0.10013548
  0.10031877 0.10069867 0.09982008 0.1000763 ]
 [0.09981309 0.09960328 0.09996323 0.09940137 0.10052121 0.10006854
  0.10025609 0.10077504 0.09965497 0.09994324]
 [0.09992515 0.09953135 0.09996832 0.09941342 0.10048115 0.10008149
  0.10030657 0.10060601 0.09989382 0.09979264]
 [0.10007142 0.09947996 0.09998374 0.09924676 0.10051473 0.10007808
  0.10032851 0.10076615 0.09957285 0.0999578 ]
 [0.09995866 0.09957986 0.09992631 0.09935763 0.10050821 0.10001034
  0.10029974 0.1007157  0.0997372  0.09990631]
 [0.09987906 0.09946071 0.09998742 0.09935401 0.10053796 0.09998755
  0.10039282 0.10065696 0.09977262 0.09997077]
 [0.09998529 0.09934276 0.09983914 0.09932289 0.10044803 0.10013824
  0.1003972  

In [8]:
mod.get_outputs()

[
 [[0.09980328 0.09958    0.09992559 0.0992334  0.10046541 0.100039
   0.10044528 0.10066289 0.09978553 0.10005961]
  [0.09990032 0.09943327 0.10007251 0.09912145 0.10055376 0.10018851
   0.10028589 0.10081293 0.0996526  0.0999788 ]
  [0.09980104 0.09943943 0.09991118 0.09928276 0.10051626 0.10013548
   0.10031877 0.10069867 0.09982008 0.1000763 ]
  [0.09981309 0.09960328 0.09996323 0.09940137 0.10052121 0.10006854
   0.10025609 0.10077504 0.09965497 0.09994324]
  [0.09992515 0.09953135 0.09996832 0.09941342 0.10048115 0.10008149
   0.10030657 0.10060601 0.09989382 0.09979264]
  [0.10007142 0.09947996 0.09998374 0.09924676 0.10051473 0.10007808
   0.10032851 0.10076615 0.09957285 0.0999578 ]
  [0.09995866 0.09957986 0.09992631 0.09935763 0.10050821 0.10001034
   0.10029974 0.1007157  0.0997372  0.09990631]
  [0.09987906 0.09946071 0.09998742 0.09935401 0.10053796 0.09998755
   0.10039282 0.10065696 0.09977262 0.09997077]
  [0.09998529 0.09934276 0.09983914 0.09932289 0.10044803 0.1001