Using a pre-trained model in a way like TensorFlow Hub #254
Hi @johnnychhsu, an example of transfer learning is presented in #140. It's still a work in progress. Mainly, we need a generic user interface (via the config file) so that users can specify which variables are initialised from checkpoint files and which are randomly initialised. If you are also interested in this direction, we could collaborate on it. Alternatively, if you want full control of the graph, the attached code example could be a good starting point. One caveat: the preprocessing layers (not included here) should exactly match those used for training, e.g. https://github.com/NifTK/NiftyNet/blob/v0.3.0/niftynet/application/segmentation_application.py#L178.
import tensorflow as tf
import os
from niftynet.io.image_reader import ImageReader
from niftynet.engine.sampler_resize import ResizeSampler
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"]='1'
##### Address of the model to be restored
check_point_location='/home/niftynet/models/dense_vnet_abdominal_ct/models/model.ckpt-3000'
#####
##### Create a sampler
data_param = {'image': {'path_to_search': '~/niftynet/data/dense_vnet_abdominal_ct',
                        'filename_contains': 'CT',
                        'spatial_window_size': (144, 144, 144)}}
reader = ImageReader().initialise(data_param)
sampler = ResizeSampler(
    reader=reader,
    data_param=data_param,
    batch_size=1,
    shuffle_buffer=True,
    queue_length=35)
#####
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
    sampler.run_threads(sess, tf.train.Coordinator(), num_threads=1)
    from niftynet.network.dense_vnet import DenseVNet
    data_dict = sampler.pop_batch_op()
    net_logits = DenseVNet(num_classes=9)(data_dict['image'])
    # restore the variables (names in the graph must match those in the checkpoint)
    saver = tf.train.Saver()
    saver.restore(sess, check_point_location)
    # evaluate the logits tensor; keep the tensor and its value under separate names
    logits_value = sess.run(net_logits)
    print(logits_value.shape)
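The config-driven interface mentioned above would essentially need to decide, per variable, whether to load it from the checkpoint or leave it to random initialisation. A minimal sketch of that decision in plain Python, assuming variable names are available as strings. The function `split_variables` and example names such as `DenseVNet/conv_out/w` are hypothetical, not NiftyNet's API or DenseVNet's real variable names.

```python
import fnmatch


def split_variables(model_vars, checkpoint_vars, reinit_patterns):
    """Split model variable names into (restore, reinitialise) lists.

    Hypothetical helper: a variable is restored only if it exists in the
    checkpoint and does not match any glob pattern in `reinit_patterns`;
    every other variable is left for random initialisation.
    """
    ckpt = set(checkpoint_vars)
    restore, reinit = [], []
    for name in model_vars:
        excluded = any(fnmatch.fnmatchcase(name, pat) for pat in reinit_patterns)
        if name in ckpt and not excluded:
            restore.append(name)
        else:
            reinit.append(name)
    return restore, reinit


if __name__ == "__main__":
    # hypothetical variable names, e.g. re-initialising the output layer only
    names = ["DenseVNet/conv_1/w", "DenseVNet/conv_1/b",
             "DenseVNet/conv_out/w", "DenseVNet/conv_out/b"]
    restore, reinit = split_variables(names, names, ["DenseVNet/conv_out/*"])
    print(restore)
    print(reinit)
```

In TensorFlow 1.x, the checkpoint names could come from `tf.train.list_variables(check_point_location)` and the model names from `v.op.name` for each `v` in `tf.global_variables()`. The restored subset would then go to `tf.train.Saver(var_list=...)`, and the rest would be initialised with `tf.variables_initializer(...)`. This has not been tested against NiftyNet's networks.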
Thank you! I would like to work on it. Any suggestions? I think I can try this, thanks! |
@wyli
The sampler in the branch is sampler_resize_v2. The input is a tensor, which is fed into the graph when running the session. However, I still got an error. The error is
Do you have any idea? Thank you! |
Hi @johnnychhsu, the problem in your code is the
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
    from niftynet.network.dense_vnet import DenseVNet
    data_dict = sampler.pop_batch_op()
    net_logits = DenseVNet(num_classes=9)(data_dict['image'])
    saver = tf.train.Saver()
    saver.restore(sess, check_point_location)
    logits_value = sess.run(net_logits)
    print(logits_value.shape)
If you want to use placeholders, one possibility would be to override the sampler's dataset initialisation (https://github.com/NifTK/NiftyNet/blob/dev/niftynet/engine/image_window_dataset.py#L207) |
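When feeding your own data through a placeholder instead of the sampler, `feed_dict` maps each placeholder tensor to an array whose shape must match the placeholder's shape. A small numpy sketch of building such a batch. The helper name `make_feed_batch` is hypothetical, and the channels-last `(batch, x, y, z, 1)` layout for a single-modality image is an assumption to check against the network's actual input.

```python
import numpy as np


def make_feed_batch(volumes, window_shape):
    """Stack same-sized 3-D volumes into a (batch, x, y, z, 1) float32 array.

    Hypothetical helper; the single trailing channel axis is an assumption
    about the network's expected input layout.
    """
    window_shape = tuple(window_shape)
    batch = []
    for vol in volumes:
        arr = np.asarray(vol, dtype=np.float32)
        if arr.shape != window_shape:
            raise ValueError(f"volume shape {arr.shape} != window {window_shape}")
        batch.append(arr[..., np.newaxis])  # add a single channel axis
    return np.stack(batch, axis=0)  # prepend the batch axis
```

The result could then be fed as `sess.run(net_logits, feed_dict={image_placeholder: batch})`, with `image_placeholder` being something like a hypothetical `tf.placeholder(tf.float32, shape=(None, 144, 144, 144, 1))`. Restoring with the default `tf.train.Saver()` still requires the graph's variable names to match those in the checkpoint.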
@wyli Thank you! |
Yes @johnnychhsu, there are a few tasks in this direction:
import tensorflow as tf
import tensorflow_hub as hub
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) as sess:
    # TensorFlow Hub
    module_url = "https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1"
    segmentation_net = hub.Module(module_url)
    my_image_placeholder = tf.placeholder(...)
    output = segmentation_net(my_image_placeholder)
    ...
    # loading data with NiftyNet's IO
    data_dict = sampler()
    # TODO: make data_dict compatible with segmentation_net
    sess.run(finetuning_model_op, feed_dict={my_image_placeholder: data_dict})
    ...
Combining these features, we would be able to fine-tune a TensorFlow Hub model in NiftyNet without writing Python code :) |
Hi @wyli |
@johnnychhsu |
Your custom data doesn't seem to have the same dimensions as the data the network was trained on.
Without knowing more about what you're trying to do, we can't help you. However, I would suggest following these steps before requesting additional help. |
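If custom volumes differ in spatial size from the training windows, one option is to crop or zero-pad them to the expected size. The sketch below does a centre crop/pad with numpy. This is a swapped-in strategy: `ResizeSampler` in the first example, as far as I understand, rescales whole images to `spatial_window_size` instead. The `(144, 144, 144)` window is simply the value used in that example, and the helper `fit_to_window` is hypothetical. Intensity preprocessing still has to match training separately.

```python
import numpy as np


def fit_to_window(volume, window_shape):
    """Centre-crop or zero-pad a volume so its shape equals window_shape.

    Hypothetical helper: axes larger than the window are centre-cropped,
    smaller axes are centred and zero-padded.
    """
    vol = np.asarray(volume)
    if vol.ndim != len(window_shape):
        raise ValueError(f"expected {len(window_shape)} dims, got {vol.ndim}")
    out = np.zeros(window_shape, dtype=vol.dtype)
    src, dst = [], []
    for size, target in zip(vol.shape, window_shape):
        if size >= target:  # crop: take the central `target` voxels
            start = (size - target) // 2
            src.append(slice(start, start + target))
            dst.append(slice(0, target))
        else:  # pad: place the whole axis in the centre of the window
            start = (target - size) // 2
            src.append(slice(0, size))
            dst.append(slice(start, start + size))
    out[tuple(dst)] = vol[tuple(src)]
    return out


if __name__ == "__main__":
    vol = np.random.rand(160, 120, 144).astype(np.float32)
    print(fit_to_window(vol, (144, 144, 144)).shape)
```

Cropping keeps the voxel spacing unchanged, while resizing changes it. Which one matches the pre-trained model depends on how its training data was sampled.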
@johnnychhsu, were you able to use placeholders with the NiftyNet pipeline? I'm unable to restore the model when using a placeholder; please let me know how you fixed this issue. Thank you |
@wyli, is it possible to train a network in a similar way, using tf.Session directly? |
I am wondering whether we can use a NiftyNet pre-trained model in the same way as TensorFlow Hub,
such as the example from TensorFlow Hub:
I want to do transfer learning with the pre-trained model on my dataset. I think I could modify the config file to do what I want, but I am used to the TensorFlow Hub workflow, hence this question.
So my questions are:
Thank you!