Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Improve mxnet support for activity classifier save/load #129

Merged
merged 2 commits into from
Dec 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def setUpClass(self):
}
self.exposed_fields_ans = list(self.get_ans.keys())
self.fields_ans = self.exposed_fields_ans + ['_recalibrated_batch_size',
'_loss_model', '_pred_model', '_id_target_map',
'_pred_model', '_id_target_map',
'_predictions_in_chunk', '_target_id_map']


Expand Down
14 changes: 14 additions & 0 deletions src/unity/python/turicreate/toolkits/_internal_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,20 @@ def _validate_row_label(dataset, label=None, default_label='__id'):
## Return the modified dataset and label
return dataset, label

def _model_version_check(file_version, code_version):
"""
Checks if a saved model file with version (file_version)
is compatible with the current code version (code_version).
Throws an exception telling the user to upgrade.
"""
if (file_version > code_version):
raise RuntimeError("Failed to load model file.\n\n"
"The model that you are trying to load was saved with a newer version of\n"
"Turi Create than what you have. Please upgrade before attempting to load\n"
"the file again:\n"
"\n"
" pip install -U turicreate\n")

def _mac_ver():
"""
Returns Mac version as a tuple of integers, making it easy to do proper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,15 @@ def create(dataset, session_id, target, features=None, prediction_window=100,
# Train the model
log = _fit_model(loss_model, data_iter, valid_iter,
max_iterations, num_gpus, verbose)

# Set up prediction model
pred_model.bind(data_shapes=data_iter.provide_data, label_shapes=None,
for_training=False, shared_module=loss_model)
for_training=False)
arg_params, aux_params = loss_model.get_params()
pred_model.init_params(arg_params=arg_params, aux_params=aux_params)

# Save the model
state = {
'_loss_model': loss_model,
'_pred_model': pred_model,
'verbose': verbose,
'training_time': _time.time() - start_time,
Expand Down Expand Up @@ -258,21 +261,19 @@ class ActivityClassifier(_CustomModel):
This model should not be constructed directly.
"""

_PYTHON_ACTIVITY_CLASSIFIER_VERSION = 1
_PYTHON_ACTIVITY_CLASSIFIER_VERSION = 2

def __init__(self, state):
self.__proxy__ = _PythonProxy(state)

def _get_native_state(self):
state = self.__proxy__.get_state()
state['_loss_model'] = _mxnet_utils.get_mxnet_state(state['_loss_model'])
state['_pred_model'] = _mxnet_utils.get_mxnet_state(state['_pred_model'])
return state

@classmethod
def _load_version(cls, state, version):
if (version > cls._PYTHON_ACTIVITY_CLASSIFIER_VERSION):
raise RuntimeError("Corrupted model. Cannot load a model with this version.")
_tkutl._model_version_check(version, cls._PYTHON_ACTIVITY_CLASSIFIER_VERSION)

data_seq_len = state['prediction_window'] * state['_predictions_in_chunk']
data = {'data': (state['_recalibrated_batch_size'], data_seq_len, len(state['features']))}
Expand All @@ -281,12 +282,27 @@ def _load_version(cls, state, version):
('weights', (state['_recalibrated_batch_size'], state['_predictions_in_chunk'], 1))
]

from ._model_architecture import _define_model
import mxnet as _mx
context = _mxnet_utils.get_mxnet_context(max_devices=state['num_sessions'])
state['_loss_model'] = _mxnet_utils.load_mxnet_model_from_state(
state['_loss_model'], data, labels, None, context)
Copy link
Collaborator

@igiloh igiloh Dec 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this mean we're no longer backward compatible?
If someone saved a model using version 1, the weights are now saved only in the loss model, and therefore when later in lines 301-303 when loading params from state['_pred_model'] they would be all zeros, won't they?

I can understand not being forward compatible (model saved in new version should not load in old version). But backwards compatibility is important.
We could check for if version==1 or '_loss_model' in state then extract the params from loss model, else extract from pred model.
Right?

Copy link
Collaborator Author

@gustavla gustavla Dec 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! In the current v4, there is no weight sharing when it gets saved to file. All weights are saved twice. Looking at the actual saved files, a model saved with v4 takes 4 MB while a model saved with v4+ takes 2 MB. Therefore, there is no problem for v4+ to simply ignore half of those weights and load the model entirely from the pred_model.

Also, regarding backward compatibility. Every cell in the 6x6 matrix I showed in the original post is the result of an actual test and not just my hopes (I wanted to be very thorough!), so I have tested and verified full backward compatibility.

_, _pred_model = _define_model(state['features'], state['_target_id_map'],
state['prediction_window'],
state['_predictions_in_chunk'], context)

batch_size = state['batch_size']
preds_in_chunk = state['_predictions_in_chunk']
win = state['prediction_window'] * preds_in_chunk
num_features = len(state['features'])
data_shapes = [('data', (batch_size, win, num_features))]
target_shape= (batch_size, preds_in_chunk, 1)

_pred_model.bind(data_shapes=data_shapes, label_shapes=None,
for_training=False)
arg_params = _mxnet_utils.params_from_dict(state['_pred_model']['arg_params'])
aux_params = _mxnet_utils.params_from_dict(state['_pred_model']['aux_params'])
_pred_model.init_params(arg_params=arg_params, aux_params=aux_params)
state['_pred_model'] = _pred_model

state['_pred_model'] = _mxnet_utils.load_mxnet_model_from_state(
state['_pred_model'], data, None, state['_loss_model'], context)
return ActivityClassifier(state)

@classmethod
Expand Down Expand Up @@ -322,7 +338,7 @@ def export_coreml(self, filename):
(prob_name, _cmt.models.datatypes.Array(*(self.num_classes,)))
]

model_params = self._loss_model.get_params()
model_params = self._pred_model.get_params()
weights = {k: v.asnumpy() for k, v in model_params[0].items()}
weights = _mx.rnn.LSTMCell(num_hidden=_net_params['lstm_h']).unpack_weights(weights)
moving_weights = {k: v.asnumpy() for k, v in model_params[1].items()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,12 @@ def _get_native_state(self):
return state

@classmethod
def _load_version(self, state, version):
def _load_version(cls, state, version):
"""
A function to load a previously saved ImageClassifier
instance.
"""
_tkutl._model_version_check(version, cls._PYTHON_IMAGE_CLASSIFIER_VERSION)
from turicreate.toolkits.classifier.logistic_classifier import LogisticClassifier
state['classifier'] = LogisticClassifier(state['classifier'])
state['classes'] = state['classifier'].classes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def _get_native_state(self):
return state

@classmethod
def _load_version(self, state, version):
def _load_version(cls, state, version):
"""
A function to load a previously saved ImageClassifier
instance.
Expand All @@ -191,6 +191,7 @@ def _load_version(self, state, version):
version : int
Version number maintained by the class writer.
"""
_tkutl._model_version_check(version, cls._PYTHON_IMAGE_SIMILARITY_VERSION)
from turicreate.toolkits.nearest_neighbors import NearestNeighborsModel
state['similarity_model'] = NearestNeighborsModel(state['similarity_model'])
# Load pre-trained model & feature extractor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,8 @@ def _get_version(self):

@classmethod
def _load_version(cls, state, version):
_tkutl._model_version_check(version, cls._PYTHON_OBJECT_DETECTOR_VERSION)
from ._model import tiny_darknet as _tiny_darknet
if (version > cls._PYTHON_OBJECT_DETECTOR_VERSION):
raise RuntimeError("Corrupted model. Cannot load a model with this version.")

num_anchors = len(state['anchors'])
num_classes = state['num_classes']
Expand Down