Disable weight update for the vgg parameters #2990
tensorflow_variables[key] = _tf.Variable(
initial_value=_utils.convert_conv2d_coreml_to_tf(net_params[key]),
name=key,
trainable=trainable,
trainable=train_param,
What is the transformer layer? Why is the trainable parameter switching between True and False for it?
So trainable specifies whether the convolutional layers of the transformer network are trainable. The instance normalization layers are always trainable, whereas the convolutional layers of the vgg network should never be trainable. Since the vgg network contains only convolutional layers, and the only other network that contains convolutional layers is the transformer network, this split works.
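The split described above can be sketched in plain Python, without TensorFlow. This is only an illustration of the rule as stated in the comment, not the code in this PR: the function name `is_trainable`, the `vgg` key prefix, and the example key names are all hypothetical, and the real parameter names in the toolkit may differ.

```python
def is_trainable(key, transformer_conv_trainable):
    """Return True if the parameter named `key` should receive weight updates.

    Assumptions (hypothetical naming, for illustration only):
    - convolutional parameters have "conv" in their name,
    - vgg parameters have a name starting with "vgg".
    """
    if "conv" not in key:
        # Instance normalization parameters are always trainable.
        return True
    if key.startswith("vgg"):
        # vgg convolutional layers are never trainable.
        return False
    # Remaining convolutional layers belong to the transformer network
    # and follow the caller-supplied flag.
    return transformer_conv_trainable


# Hypothetical parameter names, one of each kind.
example_keys = [
    "transformer_conv1_weight",
    "transformer_instancenorm1_gamma",
    "vgg_block1_conv1_weight",
]

# Map each key to its trainable flag when the transformer convolutions are trained.
flags = {key: is_trainable(key, True) for key in example_keys}
```

In a TensorFlow setting, a flag computed this way would be the value passed as `trainable=` when each `_tf.Variable` is created.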
src/python/turicreate/toolkits/style_transfer/_tf_model_architecture.py
Returns
-------
out: dict
The TF Variable dictionary.
"""
tensorflow_variables = dict()
for key in net_params.keys():
if "weight" in key:
if "conv" in key:
if 'weight' in key:
Why are we going back to single quotes here? Are we not using the new formatter/linter?
Overview
All the parameters, including those of the vgg network, were being trained in the TF style transfer implementation. Only the Transformer network needs to be trained, so this change disables weight updates for the vgg parameters.
Previous Results
New Results