Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Update model to add save/load and period checkpoint #105

Merged
merged 2 commits into from
Sep 20, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import model
from . import initializer
from . import visualization
import atexit
# use viz as short for mx.ndarray
from . import visualization as viz

__version__ = "0.1.0"
41 changes: 41 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,44 @@ def ctypes2numpy_shared(cptr, shape):
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)


def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.

num_args : mx_uint
Number of arguments.

arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.

arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.

arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.

remove_dup : boolean, optional
Whether remove duplication or not.

Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
15 changes: 4 additions & 11 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .base import _LIB
from .base import c_array, c_str, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import check_call
from .base import check_call, ctypes2docstring
from .ndarray import NDArray

class DataIter(object):
Expand Down Expand Up @@ -99,24 +99,17 @@ def _make_io_iterator(handle):
ctypes.byref(arg_types), \
ctypes.byref(arg_descs)))
iter_name = py_str(name.value)
param_str = []
for i in range(num_args.value):
ret = '%s : %s' % (arg_names[i], arg_types[i])
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)

doc_str = ('%s\n\n' +
'Parameters\n' +
'----------\n' +
'%s\n' +
'name : string, required.\n' +
' Name of the resulting data iterator.\n\n' +
'Returns\n' +
'-------\n' +
'iterator: Iterator\n'+
'iterator: DataIter\n'+
' The result iterator.')
doc_str = doc_str % (desc.value, '\n'.join(param_str))
doc_str = doc_str % (desc.value, param_str)

def creator(*args, **kwargs):
"""Create an iterator.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self):
def update(self, pred, label):
pred = pred.asnumpy()
label = label.asnumpy().astype('int32')
y = np.argmax(pred, axis=1)
self.sum_metric += np.sum(y == label)
py = np.argmax(pred, axis=1)
self.sum_metric += np.sum(py == label)
self.num_inst += label.size


Expand Down
Loading