symbolic+imperative nn interface #5705

Closed
piiswrong wants to merge 17 commits into master from piiswrong:nn

Conversation

9 participants
@piiswrong
Contributor

piiswrong commented Apr 5, 2017

A high-level NN interface that supports both symbolic and imperative use, with a mixed Keras and PyTorch flavor.
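For a rough sense of the intended flavor, here is a toy sketch pieced together from the code excerpts reviewed below; the exact class names, the forward signature, and whether layers are callable are assumptions, not the final API:

import mxnet as mx
from mxnet import nn  # the package added by this PR

class Net(nn.Layer):
    def __init__(self, prefix=None, params=None):
        super(Net, self).__init__(prefix=prefix, params=params)
        self.dense1 = nn.Dense(20, in_units=10, prefix=self.prefix+'dense1_')
        self.dense2 = nn.Dense(10, in_units=20, prefix=self.prefix+'dense2_')

    def forward(self, x):
        # x can be an NDArray (imperative) or a Symbol (symbolic);
        # the same definition is used in both modes.
        return self.dense2(self.dense1(x))

net = Net()
imperative_out = net(mx.nd.ones((4, 10)))    # eager execution on NDArrays
symbolic_out = net(mx.sym.Variable('data'))  # builds a Symbol graph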

@jermainewang @pluskid @sxjscience @yajiedesign @tqchen @mli @madjam

@sxjscience

Member

sxjscience commented Apr 12, 2017

Should we try to combine this with mx.mod.Module? (We have written a lot of code around Module.) Also, I find that the old mod.fit way hides too many details and is not as flexible as PyTorch's manual forward/backward approach. Should we encourage the user to manually write the training loop in the future?

@piiswrong piiswrong force-pushed the piiswrong:nn branch 2 times, most recently from f838f76 to 2a5b780 Apr 14, 2017

@piiswrong piiswrong changed the base branch from dev to master Apr 14, 2017

@piiswrong piiswrong force-pushed the piiswrong:nn branch from 2a5b780 to 620a89e Apr 19, 2017

@mli
Member

mli left a comment

I also have the same question as @sxjscience: what's the relationship between nn and mod, and how should we educate users on which one to use?

# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data


@mli

mli Apr 20, 2017

Member

Suggest adding a function to test_utils.py, for example:

def get_mnist():
    """Download and load the MNIST dataset

    Returns
    -------
    dict
        A dict containing the train/test images and labels.
    """
    def read_data(label_url, image_url):
        with gzip.open(mx.test_utils.download(label_url)) as flbl:
            magic, num = struct.unpack(">II", flbl.read(8))
            label = np.fromstring(flbl.read(), dtype=np.int8)
        with gzip.open(mx.test_utils.download(image_url), 'rb') as fimg:
            magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
            image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
            image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32)/255
        return (label, image)

    # changed to mxnet.io for more stable hosting
    # path='http://yann.lecun.com/exdb/mnist/'
    path='http://data.mxnet.io/data/mnist/'
    (train_lbl, train_img) = read_data(
        path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')
    (test_lbl, test_img) = read_data(
        path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
    return {'train_data':train_img, 'train_label':train_lbl,
            'test_data':test_img, 'test_label':test_lbl}

Then we can use NDArrayIter for them.
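For instance, a minimal sketch assuming the get_mnist helper above lands in mx.test_utils:

import mxnet as mx

mnist = mx.test_utils.get_mnist()
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
                               batch_size=100, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'],
                             batch_size=100)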


@piiswrong

piiswrong Apr 20, 2017

Contributor

This is temporary. I'll do something like pytorch.dataset later

Layers can also contain other Layers, allowing you to nest them in a tree
structure. You can assign sublayers as regular attributes::
from mxnet import nn


@mli

mli Apr 20, 2017

Member

there should be an empty line before from


@mli

mli Apr 20, 2017

Member

Also suggest adding >>> before each line; then we can have doctest run all these commands.
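For example, the snippet above could be written as a doctest (class and layer names borrowed from other examples in this PR):

>>> from mxnet import nn
>>> class Model(nn.Layer):
...     def __init__(self, **kwargs):
...         super(Model, self).__init__(**kwargs)
...         self.dense1 = nn.Dense(20, in_units=10)
...         self.dense2 = nn.Dense(20, in_units=20)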

@@ -0,0 +1,93 @@
# coding: utf-8
# pylint: disable=


@mli

mli Apr 20, 2017

Member

I think both the coding and the pylint lines are unnecessary.

@piiswrong piiswrong referenced this pull request Apr 21, 2017

Closed

Dynamic Graphs #5918

@piiswrong piiswrong force-pushed the piiswrong:nn branch from 9134641 to 7ef67ea Apr 23, 2017



class SimpleLayer(Layer):
"""SimpleLayer is a Layer that supports forwarding with both `Symbol` and `NDArray`.


@ZihengJiang

ZihengJiang Apr 25, 2017

Contributor

What's the difference between SimpleLayer and Layer? Layer cannot forward with NDArray?

@piiswrong piiswrong force-pushed the piiswrong:nn branch 4 times, most recently from f4eb5bf to bd9f800 Apr 26, 2017

@piiswrong piiswrong changed the title from Nn to symbolic+imperative nn interface May 2, 2017

@sbodenstein

Contributor

sbodenstein commented May 3, 2017

@piiswrong: we are very interested in the dynamic graph construction direction MXNet is going in. A question: how flexible is this compared to PyTorch? For example, can complicated NLP examples be implemented (like https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/)?

@piiswrong

Contributor

piiswrong commented May 4, 2017

@sbodenstein It should match PyTorch feature for feature. If you see anything missing, it should be easy to add.

@piiswrong piiswrong force-pushed the piiswrong:nn branch from bd9f800 to b1a7adf May 4, 2017

@pluskid

Contributor

pluskid commented May 5, 2017

@sxjscience The existing module API can actually do low-level forward/backward easily, as shown in a demo here for writing your own explicit training loop. But compared to PyTorch's full interface, it is still a bit basic. I do think we have a lot to learn from the PyTorch/Chainer API design, either by extending module or by building a new API.
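For reference, an explicit loop with the current Module API looks roughly like this; a sketch only, not the code from the linked demo, assuming train_iter is an NDArrayIter such as the MNIST one discussed above:

import mxnet as mx

data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=10)
net = mx.sym.SoftmaxOutput(fc, name='softmax')

mod = mx.mod.Module(symbol=net, data_names=['data'], label_names=['softmax_label'])
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params()
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1),))
metric = mx.metric.Accuracy()

for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)      # explicit forward
        mod.update_metric(metric, batch.label)
        mod.backward()                         # explicit backward
        mod.update()                           # optimizer step
    print('epoch %d, training %s' % (epoch, metric.get()))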

label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
outputs = []
with ag.train_section():
    for x, y in zip(data, label):


@pluskid

pluskid May 5, 2017

Contributor

Why are we having an inner loop? Is it iterating through the batch one example at a time? Is it used to demonstrate that we can do this or is it suggesting the user should do this?


@piiswrong

piiswrong May 10, 2017

Contributor

This is doing data parallelism across multiple GPUs.
We could have a DataParallel util like PyTorch, but collecting the outputs onto one GPU is not a good idea.


@pluskid

pluskid May 10, 2017

Contributor

I see. Is it possible to hide the data-parallel mechanism like the original MXNet way? I guess in the imperative-style API the users always explicitly write for loops to handle the time index or other dimensions, so it seems reasonable to always assume the first dimension is the batch dimension.


@mli

mli May 19, 2017

Member

It also confuses me; I thought you were iterating example by example. Probably rename load_data to something like split_data_to_dev.

def __init__(self, prefix=None, params=None):
    super(Net, self).__init__(prefix=prefix, params=params)
    self.dense1 = nn.Dense(20, in_units=10, prefix=self.prefix+'dense1_')
    self.dense2 = nn.Dense(20, in_units=20, prefix=self.prefix+'dense2_')


@pluskid

pluskid May 5, 2017

Contributor

PyTorch uses metaclass magic to automatically name the sub-layer according to the variable name it gets assigned to. So

self.dense1 = nn.Dense(...)

will automatically get the (nested) name dense1. However, I'm not sure this is actually a feature we want to have; it could sometimes confuse people. If you do something like

self.layers = [nn.Dense(...) for i in range(5)]

the metaclass will no longer be able to figure it out automatically, and will not even be able to collect all the parameters of the layer.

parameter by passing the same dictionary to them. For example::
params = nn.ParameterDict(prefix='dense_')
dense1 = nn.Dense(20, in_units=10, prefix='dense1_', params=params)
dense2 = nn.Dense(20, in_units=10, prefix='dense2_', params=params)


@pluskid

pluskid May 5, 2017

Contributor

What does this example mean? The two separate prefixes dense1_ and dense2_ are used to name the layers, but their parameters will be named dense_XXX?


@piiswrong

piiswrong May 10, 2017

Contributor

yes, for sharing parameters

self._children = []

def __setattr__(self, name, value):
    """Automatically register sublayers."""


@pluskid

pluskid May 5, 2017

Contributor

Maybe we are not going to make things so complicated as to support setting a list or dict of layers (sometimes people want to do that when there are a lot of inner layers or they are generated programmatically). But it could be useful to detect this situation and print a warning that the layers might not be registered.
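A possible shape for such a check; register_child is a hypothetical registration hook used only to illustrate the warning being suggested:

import warnings

def __setattr__(self, name, value):
    """Automatically register sublayers, and warn when layers are hidden
    inside a plain list/tuple/dict and therefore cannot be registered."""
    if isinstance(value, Layer):
        self.register_child(value)  # hypothetical registration hook
    elif isinstance(value, (list, tuple, dict)):
        values = value.values() if isinstance(value, dict) else value
        if any(isinstance(v, Layer) for v in values):
            warnings.warn("Layers assigned to '%s' inside a %s will not be "
                          "registered automatically" % (name, type(value).__name__))
    super(Layer, self).__setattr__(name, value)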


def __setattr__(self, name, value):
    """Automatically register sublayers."""
    super(Layer, self).__setattr__(name, value)


@pluskid

pluskid May 5, 2017

Contributor

Maybe we can use the name as a prefix if the layer being registered does not already explicitly have a prefix?


@piiswrong

piiswrong May 10, 2017

Contributor

I've considered this. But it gets complicated and confusing when you start to nest layers.

Should renaming be recursive?


@pluskid

pluskid May 10, 2017

Contributor

Hmm, I see. Yes, it should be recursive, and then it becomes a bit complicated.

I checked the PyTorch version. Their Module object does not have a name attached to it. The name is only used as a key when a module is nested in other modules, and it is generated on the fly, so recursive key prefixing is simple: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py#L440

For us, the module has its own name, and if we assign the same module as a submodule in two different places, the names could potentially conflict. So the situation is more complicated.
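For reference, the PyTorch scheme described above amounts to something like the following; the attribute names here are hypothetical, only the recursion pattern matters:

def named_parameters(layer, prefix=''):
    """Yield (key, parameter) pairs, building keys on the fly from the
    attribute names used for nesting, so layers themselves stay nameless."""
    for name, param in layer._own_params.items():      # hypothetical attribute
        yield prefix + name, param
    for name, child in layer._named_children.items():  # hypothetical attribute
        for item in named_parameters(child, prefix + name + '.'):
            yield item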


@piiswrong

piiswrong May 11, 2017

Contributor

Overall implicit registration on assign is pretty bad engineering. But it does make the interface easier to use.


@jekbradbury

jekbradbury May 13, 2017

Chainer used to use explicit registration at __init__ time, and is moving to PyTorch's implicit registration on assign because it works much better with IDEs.


@jekbradbury

jekbradbury May 16, 2017

Actually the with approach is very clever; it solves the list vs ModuleList problem too!


@piiswrong

piiswrong May 16, 2017

Contributor

We could also call it build or constructor instead of init so that it's even more obvious.


@piiswrong

piiswrong May 16, 2017

Contributor

Yeah, that's the idea. I have been banging my head against this for 2 days now...
PyTorch let the devil out of the box, and now we have to follow suit.


@pluskid

pluskid May 16, 2017

Contributor

I like this scope thing, except that I would prefer not to hide it but to require the user to write it explicitly (so that they know what is happening and do not get confused when they accidentally override __init__ instead of init). So I'm suggesting we could make it explicit by using an informative name for the scope, such as

with self.collect_params():
    self.fcs = [nn.Dense(...) for i in range(3)]
    self.pred = nn.Dense(...)

collect_params could be replaced with a better name, but my point is to ask the user to do it explicitly instead of hiding it (see the sketch below). What do you guys think?
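One way such a scope could be wired up, as a toy sketch only (ignoring thread safety and the PR's actual implementation):

from contextlib import contextmanager

class Layer(object):
    _current_scope = None  # the Layer whose collect_params() scope is active

    def __init__(self):
        self._children = []
        if Layer._current_scope is not None:
            # registered at construction time, even if the new layer only
            # ends up inside a plain Python list or dict
            Layer._current_scope._children.append(self)

    @contextmanager
    def collect_params(self):
        prev, Layer._current_scope = Layer._current_scope, self
        try:
            yield
        finally:
            Layer._current_scope = prev

Because registration happens at construction time rather than at attribute assignment, this also sidesteps the list vs ModuleList issue mentioned above.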


@pluskid

pluskid May 16, 2017

Contributor

@jekbradbury, as for __call__ and forward, I think both are fine. What I do not like about PyTorch is that __call__ seems to do extra bookkeeping beyond forward, and I do not know what consequences a user faces when calling forward directly vs calling __call__. On the other hand, the Chainer way, where __call__ and forward are the same thing, sounds better and less confusing.

@yzhliu

Member

yzhliu commented May 11, 2017

It seems we are going to manually define a bunch of NN layers in Python, e.g., Pooling, BatchNorm, etc. Can they be generated automatically from the C lib, just like the functions for symbol? Other frontend languages could then follow the same approach for code generation.
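A rough, speculative illustration of the idea; stateless operators are easy to wrap this way, while parameterized ones like BatchNorm would still need explicit parameter handling (Layer here is the class from this PR):

import mxnet as mx

def make_op_layer(op_name):
    """Wrap an operator that exists on both mx.nd and mx.sym into a Layer."""
    class _OpLayer(Layer):
        def __init__(self, **op_kwargs):
            super(_OpLayer, self).__init__()
            self._op_kwargs = op_kwargs

        def forward(self, x):
            # dispatch to the right namespace depending on the input type
            F = mx.sym if isinstance(x, mx.sym.Symbol) else mx.nd
            return getattr(F, op_name)(x, **self._op_kwargs)

    _OpLayer.__name__ = op_name
    return _OpLayer

Pooling = make_op_layer('Pooling')  # e.g. Pooling(kernel=(2, 2), pool_type='max')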

@tdomhan

Contributor

tdomhan commented May 13, 2017

This is a super exciting development. I'm also wondering what the relationship to the existing APIs, like the Module API, will be in the long run.

@jekbradbury


jekbradbury commented May 13, 2017

This is awesome! My nvidia blogpost has one example of the kind of model that needs this sort of framework, but a more practically useful one that'll demonstrate the full power of this approach in MXNet is https://github.com/facebookresearch/clevr-iep. I think it would be surprisingly easy to port as a demonstration (one benefit of PyTorch not building in any training loop abstractions etc. is that there's less to port when you move it to another framework).

@piiswrong piiswrong force-pushed the piiswrong:nn branch from 5b29943 to 97de14b May 15, 2017

@pluskid

Contributor

pluskid commented May 16, 2017

Several comments from previous conversations that I think might be good to post here for public discussion.

Merging Metric and Loss

The current proposal keeps a list of metrics inside a loss object. An alternative proposal is to unify metric and loss. A Metric is an object that

  • supports forward, which maps inputs to a scalar
  • supports backward, which by default is a NOP (i.e., zero gradient)

A Loss is a subclass of Metric that has a meaningful backward implementation.

When constructing a module, a list of Metrics can be given. During training and scoring, all the metrics are forward-evaluated. During training, all the Losses are additionally backward-evaluated to provide the initial gradient for back-propagation.

So the effective loss is the sum of all the Loss objects in the supplied list of metrics.
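A minimal sketch of that unified interface; the L2 example and exact signatures are illustrative only:

import mxnet as mx

class Metric(object):
    """Maps an output/label pair to a scalar; contributes no gradient."""
    def forward(self, output, label):
        raise NotImplementedError

    def backward(self, output, label):
        return mx.nd.zeros(output.shape)  # NOP: zero initial gradient

class L2Loss(Metric):
    """A Loss is just a Metric with a meaningful backward."""
    def forward(self, output, label):
        return mx.nd.mean(mx.nd.square(output - label)).asscalar()

    def backward(self, output, label):
        # d(mean squared error)/d(output), used as the head gradient for back-prop
        return 2 * (output - label) / output.size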

Layer and SimpleLayer

I think it might be good to rename Layer to BaseLayer, and SimpleLayer to Layer, as the user will mostly be interacting with Layer.

The simple_forward function, which supports both NDArray forward and symbolic forward, might be renamed to mixed_forward or generic_forward (name coming from C++ generics).

Whether to provide a built-in fit function

The current fit implementation on Module is actually simple. However, it looks complicated because far too many keyword arguments have been added to that function. One way to avoid this is to ask users to always write their own training loop, which should be fairly straightforward.

To make common scenarios simple, we can provide a training-loop code template that users can copy, paste, and customize at will (see the sketch below). Alternatively, we could provide a simplified fit function with only a few arguments and state in the docs that users are encouraged to write their own training loop when they need more flexibility.
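Pieced together from the code excerpts reviewed above, such a template could look roughly like the following; net, optim, ctx, train_iter and num_epochs are assumed to be set up already, and the names may change before this PR is finalized:

from mxnet import nn
from mxnet.contrib import autograd as ag

for epoch in range(num_epochs):
    train_iter.reset()
    for batch in train_iter:
        # split the batch across the devices in ctx along the batch axis
        data = nn.utils.load_data(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = nn.utils.load_data(batch.label[0], ctx_list=ctx, batch_axis=0)
        with ag.train_section():
            for x, y in zip(data, label):  # one slice per device
                z = net(x)
                loss = nn.loss.softmax_cross_entropy_loss(z, y)
                ag.compute_gradient([loss])
        optim.step(batch.data[0].shape[0])  # scale the update by batch size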

@piiswrong piiswrong force-pushed the piiswrong:nn branch from 5bcf972 to 3c612ea May 17, 2017

loss = nn.loss.softmax_cross_entropy_loss(z, y)
ag.compute_gradient([loss])
outputs.append(z)
optim.step(batch.data[0].shape[0])


@mli

mli May 19, 2017

Member

change to batch_size

from ...contrib import autograd

# pylint: disable= invalid-name
tensor_types = (symbol.Symbol, ndarray.NDArray)


@mli

mli May 19, 2017

Member

you can change to

tensor_types = (symbol.Symbol, ndarray.NDArray)  # pylint: disable=invalid-name

@piiswrong piiswrong force-pushed the piiswrong:nn branch from b058bca to 7a537a0 May 28, 2017

@piiswrong piiswrong closed this May 30, 2017
