Add save and load method for Module class #691
Do we need to save the params of the model?

Yes. Say we did the training and want to save the model, so that later we can deploy it as a web service for inference. This is common because we often train the model inside a cluster with many computational resources.

Got it. So I think we had better implement this feature on the Python side, because the scheduler does not have information about the operator types and has no concept of a neural network layer.

If it is on the Python side, I guess the easiest way is to use pickle to pack a Python list or dict, but a drawback is that pickle cannot pack SWIG Python objects (if I am not wrong).

Or we can serialize and deserialize the data ourselves. There is no need to serialize the entire objects; we just need to save the state data.
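A small sketch of the second option, for illustration only: `FakeTensor`, `dump_states`, and `load_states` are hypothetical names, and the point is that we pickle only plain state data (shapes, dtypes, values) rather than the SWIG-backed tensor objects themselves.

```python
import pickle

class FakeTensor:
    """Hypothetical stand-in for a SWIG-backed tensor, which pickle
    cannot serialize directly."""
    def __init__(self, shape, dtype, data):
        self.shape, self.dtype, self.data = shape, dtype, data

def dump_states(states):
    """Serialize only the plain state data, not the wrapper objects."""
    plain = {name: {"shape": t.shape, "dtype": t.dtype, "data": t.data}
             for name, t in states.items()}
    return pickle.dumps(plain)

def load_states(blob):
    """Rebuild tensors from the saved plain data."""
    plain = pickle.loads(blob)
    return {name: FakeTensor(d["shape"], d["dtype"], d["data"])
            for name, d in plain.items()}

states = {"conv1.W": FakeTensor((2, 2), "float32", [1.0, 2.0, 3.0, 4.0])}
restored = load_states(dump_states(states))
```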
Yes, I guess it is something like this.

Does Rafiki dump models in ONNX format?

No, I didn't mean that. They dump models in a different format, not in ONNX.
How are the topological connections (

I guess "high/low priority" refers to the preference? In the "low priority" option, the
Conclusion first. Good news:

Bad news:
ONNX Training Preview (TrainingInfoProto)

Last week, the ONNX team released a new version, 1.7.0, which upgrades its opset version to 12. In this new release, they add a new feature called TrainingInfoProto. This new feature defines training information. There are two main parts in it: the initialization step and the training-algorithm step.

Initialization step

The currently supported random methods are:

Training-algorithm step
In general, this graph contains the loss node, the gradient node, the optimizer node, the increment of the iteration count, and some calls to the inference graph. The field algorithm.node is the only place the user can use the GraphCall operator.

Loss node
Optimizer node
Gradient node

The gradient node actually only defines the information necessary to compute the gradient for the whole graph; for example, in the following graph, the gradient defines its inputs containing the

It doesn't define any logic about how to compute the
GraphCall node

The GraphCall operator invokes a graph inside TrainingInfoProto's algorithm field. The GraphCall inputs and outputs are bound to those of the invoked graph by position. Based on the above inference graph, GraphCall can be used like this:

The previous section's inference graph is called by

Then it uses the following
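As a rough mental model of how these nodes fit together (this is plain Python, not ONNX API code; all names are illustrative), the training-algorithm graph invokes the inference graph like a GraphCall, then applies the loss, gradient, and optimizer steps:

```python
# Toy analogue of the ONNX training graph described above. The "GraphCall"
# is a plain call into the inference function, followed by loss, gradient,
# and optimizer steps.

def inference_graph(x, w):
    """The inference graph being 'called': prediction y_ = w * x."""
    return w * x

def training_step(x, y, w, lr=0.1):
    y_ = inference_graph(x, w)    # GraphCall: inputs bound by position
    loss = (y_ - y) ** 2          # loss node
    grad = 2 * (y_ - y) * x       # gradient node: d(loss)/dw
    w = w - lr * grad             # optimizer node: SGD update
    return w, loss

w = 0.0
for _ in range(50):
    w, loss = training_step(2.0, 4.0, w)
```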
We do a forward pass (e.g., using the placeholders recorded by compile()) inside save() to get the output y, and then trace back to get all operations.

We provide two approaches.
Updated on May 15 night:

```python
class Layer:
    def get_params(self):
        """Return the params of this layer and its sublayers as a dict;
        a param name is layername.param.
        E.g., for self.W = Tensor() and self.b = Tensor(),
        the names of W and b are like conv1.W and conv1.b.
        """

    def get_states(self):
        """Return the states of this layer and its sublayers that are
        necessary for model training/evaluation/inference.
        The states include the params and others, e.g., the running mean
        and var of batchnorm.
        """


class Module(Layer):
    def compile(self, ...):
        """Set the name of each layer and its sublayers, which will be used
        to create the dicts for get_params and get_states. Then there is no
        need to manually configure the layer name in the __init__ method of
        a layer. For instance,

        class Blk(Layer):
            def __init__(self):
                self.conv1 = Conv2d()
                self.conv2 = Conv2d()

        class MyModel(Module):
            def __init__(self):
                self.blk1 = Blk()  # --> blk1.conv1, blk1.conv2
                self.blk2 = Blk()  # --> blk2.conv1, blk2.conv2
        """

    # high priority
    def save_states(self, fpath, aux_states={}):
        """Save states.

        Args:
            fpath: output file path (without the extension)
            aux_states(dict): values are standard data types or Tensor,
                e.g., epoch ID, learning rate, optimizer states
        """
        # states = get_states() + aux_states + input_placeholders
        # tensor_dict = {}
        # for k, v in states.items():
        #     if type(v) is Tensor:
        #         tensor_dict[k] = v
        #         states[k] = {'shape': v.shape, 'dtype': v.dtype}
        # save states as a json file
        # save tensor_dict via numpy or hdf5 or protobuf
        # zip the output files
        ...

    def load_states(self, fpath, dev, use_graph=True, graph_alg='sequence'):
        """Load the model onto dev.

        Args:
            fpath: input file path (without the extension)

        Returns:
            dict
        """
        # unzip the input file
        # load the json file --> states
        # load the tensor files --> tensor_dict
        # put the tensors into states
        # states --> model_states + input_placeholders + aux_states
        # self.compile(input_placeholders, dev, use_graph, graph_alg)
        # self.set_states(model_states)
        # return the rest of the states as a dict
        ...


# lower priority
def save(fpath, model):
    # attributes <-- model
    # replace all tensors in attributes --> {'shape': v.shape, 'dtype': v.dtype}
    # dump the tensors via numpy or protobuf or hdf5
    # dump the model via pickle
    # zip the output files
    ...

def load(fpath, dev, use_graph, graph_alg):
    # unzip the input file
    # load the model via pickle
    # load the tensors
    # restore the tensors in the model attributes
    # return the model
    ...


# handle ONNX
def to_onnx(model):
    # return an onnx model
    ...

class SONNXModel(Module):
    def __init__(self, onnx_model):
        self.store_output = store_output
        for layer_name, layer_config in get_layer(onnx_model):
            self.__dict__[layer_name] = CreateLayer(...)

    def forward(self, x, aux_output):
        # run forward according to the onnx graph
        # return the last output + aux_output
        ...

class MyModel(SONNXModel):
    def __init__(self, onnx):
        super().__init__(onnx)
        self.layer1 = Conv()
        self.layer2 = Conv()

    def forward(self, x):
        x1, x2 = super().forward(x, aux_output)
        x = self.layer1.forward(x2)
        return self.layer2.forward(x1) + x

    def train_one_batch(self, x, y):
        y_ = self.forward(x)
        ...
```

Clarification:
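A minimal runnable sketch of the proposed save_states/load_states flow, using only the standard library (JSON for both the manifest and the tensor data, zipfile for packaging; the design above would use numpy, hdf5, or protobuf for the tensor files). Plain dicts stand in here for singa.Tensor, and all names are illustrative:

```python
import json
import os
import tempfile
import zipfile

def save_states(fpath, states, aux_states=None):
    """Split states into a JSON manifest (shapes/dtypes and plain metadata)
    and a tensor-data member, then zip both, as in the proposed design."""
    manifest, tensors = {}, {}
    for k, v in {**states, **(aux_states or {})}.items():
        if isinstance(v, dict) and "data" in v:   # tensor-like value
            manifest[k] = {"shape": v["shape"], "dtype": v["dtype"]}
            tensors[k] = v["data"]
        else:                                     # plain metadata, e.g. epoch
            manifest[k] = v
    with zipfile.ZipFile(fpath + ".zip", "w") as zf:
        zf.writestr("states.json", json.dumps(manifest))
        zf.writestr("tensors.json", json.dumps(tensors))

def load_states(fpath):
    """Unzip, load the manifest, and put the tensor data back in place."""
    with zipfile.ZipFile(fpath + ".zip") as zf:
        manifest = json.loads(zf.read("states.json"))
        tensors = json.loads(zf.read("tensors.json"))
    for k, data in tensors.items():
        manifest[k]["data"] = data
    return manifest

base = os.path.join(tempfile.mkdtemp(), "ckpt")
save_states(base,
            {"conv1.W": {"shape": [2], "dtype": "float32", "data": [1.0, 2.0]}},
            aux_states={"epoch": 3, "lr": 0.01})
restored = load_states(base)
```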
If we have the model states, we can recreate the params. Do the placeholders still make sense? I think we don't need to compile the module if we use the set_states function.

The API is a bit ugly... But we need compile() to create the handles, which are not serialized as states.

Got it. I thought the handles were also state info.
The current save of params and states:

How about this one: we parse the onnx model by
Please check my inline comments starting with **.

It's good to reuse singa_rep.

```python
ox = onnx.load(fpath)
m = MyModel(ox)
m.compile([x], ...)
```
I will update the code with the comments. And I need to update the current
To be consistent, I think we'd better always call:

```python
m = MyModel()
m.compile([x], use_graph=True)
m.load_states(fpath)

m = MyONNXModel(onnx_model)
m.compile([x], use_graph=True)

m = singa.load(fpath)
m.compile([x], use_graph=True)
```

Then the

Any better solution?
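One way to resolve the compile-before-load ordering question raised here is to let the module buffer states that arrive before compile() and apply them once the handles exist. This is a hypothetical sketch, not the actual SINGA implementation:

```python
class Module:
    """Illustrative only: states set before compile() are buffered and
    applied when compile() has created the handles."""

    def __init__(self):
        self._compiled = False
        self._pending_states = None
        self.states = {}

    def compile(self, placeholders, use_graph=True):
        self._compiled = True          # real code would create handles here
        if self._pending_states is not None:
            self.set_states(self._pending_states)
            self._pending_states = None

    def set_states(self, states):
        if not self._compiled:         # defer until the handles exist
            self._pending_states = states
        else:
            self.states.update(states)

m = Module()
m.set_states({"conv1.W": [1.0, 2.0]})  # before compile: buffered
m.compile(["x"])                        # states applied here
```

With this, the three loading paths above can keep the same user-visible order without crashing when states arrive first.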
Actually, in the above SONNX API, we merge load_states into compile, right?
Quoted from @joddiy. From the perspective of a new onnx user, please let me know if this part is not correct.

Use case 1: load a model from an onnx file

```python
class MySONNXModel(SONNXModel):
    pass  # so we know the structure of the model already?

# load from the onnx model
onnx_model = onnx.load('./saved_models/onnx_model_downloaded')
m1 = MySONNXModel(onnx_model)  # so we know the structure of the model already?
m1.compile([placeholder_x], ...)
for _ in data:
    m1.train_one_batch(_)
```

Use case 2: save states and model

```python
# save
m1.save_states('./saved_models/my_checkpoint_1')
singa.save('./saved_models/my_model_1', m1)
```

Use case 3: load model and states from disk

```python
# later, reuse the model
m2 = singa.load('./saved_models/my_model_1')
m2.load_states('./saved_models/my_checkpoint_1')
m2.compile([placeholder_x], use_graph=True)
```

Use case 4: load states only

```python
# the singa model is known
class MyModel(Module):
    pass

m3 = MyModel(states_path='./saved_models/my_checkpoint_1')  # could only be states, right?
# m3 = MyModel('./saved_models/my_model_1')  # could not be a saved model, right? since we know the model
m3.compile(...)
```

To be frank, I am a bit overwhelmed by all the discussions, not just in this issue. Is it possible to consolidate the new API into a specification, including examples, in singa-doc? That would be useful for new users. Btw, is the API in onnx-doc going to change?
Here is the latest summary: https://gist.github.com/nudles/d7f8043f251872333ec06f2701696cce

The APIs in onnx-doc should be backward-compatible.

The save and load functions are now available in the 3.1 Model class.
Updated on May 15.

Clarification:
- Layer.get_params()
- Layer.get_states()
- class.__dict__: superset of the states.