ChannelPrunedLearner does not execute #175

Open · as754770178 opened this issue Jan 7, 2019 · 13 comments
@as754770178

I compress the model with ChannelPrunedLearner. When execution reaches self.sess_train.run(self.train_op) in the __train_pruned_model(self, finetune=False) function, the program does not continue. I don't know whether this is because my machine is too slow or because there is an error in the code.

My GPU is a Tesla K80.

The program stops at:

  def __train_pruned_model(self, finetune=False):
    """Train pruned model"""
    # Initialize variables
    self.sess_train.run(self.train_init_op)

    if FLAGS.enbl_multi_gpu:
      self.sess_train.run(self.bcast_op)

    ## Fine-tuning & distilling
    self.time_prev = timer()

    nb_iters = int(FLAGS.cp_nb_iters_ft_ratio * self.nb_iters_train) \
      if finetune and not FLAGS.cp_retrain else self.nb_iters_train

    for self.idx_iter in range(nb_iters):
      # train the model
      if (self.idx_iter + 1) % FLAGS.summ_step != 0:
        self.sess_train.run(self.train_op)  # stop at here
      else:
        __, summary, log_rslt = self.sess_train.run([self.train_op, self.summary_op, self.log_op])
        self.__monitor_progress(summary, log_rslt)

      # save the model at certain steps
      if (self.idx_iter + 1) % FLAGS.save_step == 0:
        #summary, log_rslt = self.sess_train.run([self.summary_op, self.log_op])
        #self.__monitor_progress(summary, log_rslt)
        if self.__is_primary_worker():
          self.__save_model() # sess_train
          self.evaluate()
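
One way to tell a stuck step apart from a merely slow one is to run it with a deadline, so Session.run() raises instead of blocking forever. A minimal diagnostic sketch (TF 1.x; sess and train_op stand for the learner's sess_train and train_op):

import tensorflow as tf

run_options = tf.RunOptions(timeout_in_ms=60 * 1000)  # give the step 60 seconds
try:
  sess.run(train_op, options=run_options)
except tf.errors.DeadlineExceededError:
  # The step did not finish within the deadline: blocked, not just slow.
  print('train_op timed out -- likely a stuck input pipeline')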
@as754770178 (Author)

This is the log:

INFO:tensorflow:build pruned training model
INFO:tensorflow:Restoring parameters from ./models/pruned_model.ckpt
INFO:tensorflow:Restoring parameters from ./models/pruned_model.ckpt
WARNING:tensorflow:From /home/zgz/anaconda2/envs/tf-1.8-cp3/lib/python3.6/site-packages/tensorflow/python/util/tf_should_use.py:118: initialize_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.variables_initializer` instead.
INFO:tensorflow:training pruned model

I have run PocketFlow with several learners, and it can't run smoothly.

@jiaxiang-wu (Contributor)

You mean the program does not produce any output? Can you check the GPU utilization via nvidia-smi? If you decrease the value of FLAGS.summ_step (e.g. set it to 1), is there any output?
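
For reference, a small helper (the names are mine, not PocketFlow's) that polls the same numbers from Python while a step runs:

import subprocess
import time

def poll_gpu_utilization(n_polls=5, interval_s=2.0):
  """Print GPU utilization and memory usage a few times via nvidia-smi."""
  for _ in range(n_polls):
    out = subprocess.check_output(
        ['nvidia-smi',
         '--query-gpu=index,utilization.gpu,memory.used',
         '--format=csv,noheader'])
    print(out.decode().strip())
    time.sleep(interval_s)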

@as754770178 (Author)

nvidia-smi info:

 3  Tesla K80           On   | 00000000:86:00.0 Off |                    0 |
| N/A   74C    P0    74W / 149W |  10493MiB / 11439MiB |      0%      Default |

I changed FLAGS.summ_step to 1; it still prints nothing.

@jiaxiang-wu (Contributor)

Are you using the pre-defined models in PocketFlow, or are you using a self-defined model implemented by yourself?

@as754770178 (Author)

I am using a self-defined model implemented by myself.

@jiaxiang-wu (Contributor)

Does ChannelPrunedLearner hang when using the pre-defined models in PocketFlow? Maybe it is not compatible with your self-defined model?

@as754770178 (Author)

Yes, but I want to know how to use a self-defined model implemented by myself. Can you suggest some way to locate this problem?

@jiaxiang-wu (Contributor)

Can you post the code for your self-defined model (ModelHelper and the corresponding xxx_run.py script)?

@as754770178 (Author)

I tried to trim it down, but the code is too long.

@jiaxiang-wu (Contributor)

Or, can you upload the file?

@as754770178 (Author) commented Jan 8, 2019

There are too many files, so I wrote a demo of my net. In particular, I mocked the dataset and removed the data augmentation and pre-processing.
Net definition:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import collections

import tensorflow as tf
slim = tf.contrib.slim

class NoOpScope(object):
  """No-op context manager."""

  def __enter__(self):
    return None

  def __exit__(self, exc_type, exc_val, traceback):
    return False

class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
  """A named tuple describing a ResNet block."""

def subsample(inputs, factor, scope=None):
  if factor == 1:
    return inputs
  else:
    return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)

@slim.add_arg_scope
def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None, data_format='NHWC'):
  if stride == 1:
    return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
                       padding='SAME', scope=scope)
  else:
    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
    pad_total = kernel_size_effective - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg
    if data_format == 'NHWC':
        inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
    else:
        inputs = tf.pad(inputs, [[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
    return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
                       rate=rate, padding='VALID', scope=scope)

@slim.add_arg_scope
def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
               outputs_collections=None, scope=None, use_bounded_activations=False,
               data_format='NHWC'):
  """Bottleneck residual unit variant with BN after convolutions.
  """
  with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
    if data_format == 'NHWC':
      depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
    else:
      depth_in = inputs.get_shape().as_list()[1]
    if depth == depth_in:
      shortcut = subsample(inputs, stride, 'shortcut')
    else:
      shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
                             activation_fn=tf.nn.relu6 if use_bounded_activations else None,
                             scope='shortcut')

    residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
                           scope='conv1')
    residual = conv2d_same(residual, depth_bottleneck, 3, stride,
                                        rate=rate, scope='conv2')
    residual = slim.conv2d(residual, depth, [1, 1], stride=1,
                           activation_fn=None, scope='conv3')

    if use_bounded_activations:
      # Use clip_by_value to simulate a bandpass activation
      residual = tf.clip_by_value(residual, -6.0, 6.0)
      output = tf.nn.relu6(shortcut + residual)
    else:
      output = tf.nn.relu(shortcut + residual)

    return slim.utils.collect_named_outputs(outputs_collections,
                                            sc.original_name_scope,
                                            output)

@slim.add_arg_scope
def residual_unit(inputs, depth, stride, rate=1,
                  outputs_collections=None, scope=None, data_format='NHWC'):
  with tf.variable_scope(scope, 'residual_unit_v1', [inputs]) as sc:
    if data_format == 'NHWC':
      depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
    else:
      depth_in = inputs.get_shape().as_list()[1]
    if depth == depth_in:
      shortcut = subsample(inputs, stride, 'shortcut')
    else:
      shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
                             activation_fn=None, scope='shortcut')

    residual = slim.conv2d(inputs, depth, [3, 3], stride=1,
                           scope='conv1')

    residual = slim.conv2d(residual, depth, [3, 3], stride=stride,
                           activation_fn=None, scope='conv2')

    output = tf.nn.relu(shortcut + residual)

    return slim.utils.collect_named_outputs(outputs_collections,
                                            sc.original_name_scope,
                                            output)

def resnet_v1_block(scope, base_depth, num_units, stride, use_bottleneck=True):

  if use_bottleneck is True:
      return Block(scope, bottleneck, [{
          'depth': base_depth * 4,
          'depth_bottleneck': base_depth,
          'stride': 1
      }] * (num_units - 1) + [{
          'depth': base_depth * 4,
          'depth_bottleneck': base_depth,
          'stride': stride
      }])
  else:
      return Block(scope, residual_unit, [{
          'depth': base_depth,
          'stride': 1,
      }] + (num_units - 1) * [{
          'depth': base_depth,
          'stride': stride,
      }])

@slim.add_arg_scope
def stack_blocks_dense(net, blocks, output_stride=None,
                       store_non_strided_activations=False,
                       outputs_collections=None):
  current_stride = 1

  # The atrous convolution rate parameter.
  rate = 1

  for block in blocks:
    with tf.variable_scope(block.scope, 'block', [net]) as sc:
      block_stride = 1
      for i, unit in enumerate(block.args):
        if store_non_strided_activations and i == len(block.args) - 1:
          # Move stride from the block's last unit to the end of the block
          block_stride = unit.get('stride', 1)
          unit = dict(unit, stride=1)
        if output_stride is not None and current_stride > output_stride:
          raise ValueError('The target output_stride cannot be reached.')

        with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
          # If we have reached the target output_stride, then we need to employ
          # atrous convolution with stride=1 and multiply the atrous rate by the
          # current unit's stride for use in subsequent layers.
          if output_stride is not None and current_stride == output_stride:
            net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
            rate *= unit.get('stride', 1)

          else:
            net = block.unit_fn(net, rate=1, **unit)
            current_stride *= unit.get('stride', 1)
            if output_stride is not None and current_stride > output_stride:
              raise ValueError('The target output_stride cannot be reached.')

      net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)

      # Subsampling of the block's output activations.
      if output_stride is not None and current_stride == output_stride:
        rate *= block_stride
      else:
        net = subsample(net, block_stride)
        current_stride *= block_stride
        if output_stride is not None and current_stride > output_stride:
          raise ValueError('The target output_stride cannot be reached.')

  if output_stride is not None and current_stride != output_stride:
    raise ValueError('The target output_stride cannot be reached.')

  return net

@slim.add_arg_scope
def resnet_v1(inputs,
              blocks,
              num_classes=None,
              is_training=True,
              global_pool=True,
              output_stride=None,
              include_root_block=True,
              spatial_squeeze=False,
              store_non_strided_activations=False,
              reuse=None,
              scope=None,
              data_format='NHWC'):
  with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.name + '_end_points'
    with slim.arg_scope([slim.conv2d, bottleneck,
                         stack_blocks_dense],
                        outputs_collections=end_points_collection):
      with (slim.arg_scope([slim.batch_norm], is_training=is_training)
            if is_training is not None else NoOpScope()):
        net = inputs
        if include_root_block:
          if output_stride is not None:
            if output_stride % 4 != 0:
              raise ValueError('The output_stride needs to be a multiple of 4.')
            output_stride /= 4
          net = conv2d_same(net, 64, 7, stride=2, scope='conv1')
          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
        net = stack_blocks_dense(net, blocks, output_stride,
                                              store_non_strided_activations)

        # Convert end_points_collection into a dictionary of end_points.
        end_points = slim.utils.convert_collection_to_dict(
          end_points_collection)

        if global_pool:
          # Global average pooling.
          if data_format == 'NHWC':
            net = tf.reduce_mean(net, [1, 2], name='pool5', keepdims=True)
            end_points['global_pool'] = net
          else:
            net = tf.reduce_mean(net, [2, 3], name='pool5', keepdims=True)
            end_points['global_pool'] = net
        if num_classes is not None:
          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                            normalizer_fn=None, scope='logits')
          end_points[sc.name + '/logits'] = net
          if spatial_squeeze:
            if data_format == 'NHWC':
              net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
              end_points[sc.name + '/spatial_squeeze'] = net
            else:
              net = tf.squeeze(net, [2, 3], name='SpatialSqueeze')
              end_points[sc.name + '/spatial_squeeze'] = net
          end_points['predictions'] = slim.softmax(net, scope='predictions')
        return net, end_points
resnet_v1.default_image_size = 224
resnet_v1.default_labels_offset = 1
resnet_v1.default_logits_pattern = 'logits'
def resnet_v1_18(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 store_non_strided_activations=False,
                 reuse=None,
                 scope='resnet_v1_18'):
  """ResNet-18 model of [1]. See resnet_v1() for arg and return description."""
  blocks = [
      resnet_v1_block('block1', base_depth=64, num_units=2, stride=2, use_bottleneck=False),
      resnet_v1_block('block2', base_depth=128, num_units=2, stride=2, use_bottleneck=False),
      resnet_v1_block('block3', base_depth=256, num_units=2, stride=2, use_bottleneck=False),
      resnet_v1_block('block4', base_depth=512, num_units=2, stride=1, use_bottleneck=False),
  ]
  return resnet_v1(inputs, blocks, num_classes, is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   store_non_strided_activations=store_non_strided_activations,
                   reuse=reuse, scope=scope)
resnet_v1_18.default_image_size = resnet_v1.default_image_size
resnet_v1_18.default_labels_offset = resnet_v1.default_labels_offset
resnet_v1_18.default_logits_pattern = resnet_v1.default_logits_pattern


def resnet_arg_scope(weight_decay=0.0001,
                     batch_norm_decay=0.997,
                     batch_norm_epsilon=1e-5,
                     batch_norm_scale=True,
                     data_format='NHWC',
                     use_batch_norm=True,
                     batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS,
                     **kwargs):
  batch_norm_params = {
      'decay': batch_norm_decay,
      'epsilon': batch_norm_epsilon,
      'scale': batch_norm_scale,
      'updates_collections': batch_norm_updates_collections,
  }

  if 'weights_initializer_params' in kwargs:
    weights_initializer_params = kwargs['weights_initializer_params']
  else:
    weights_initializer_params = {}

  with slim.arg_scope(
      [slim.conv2d],
      weights_regularizer=slim.l2_regularizer(weight_decay),
      weights_initializer=slim.variance_scaling_initializer(**weights_initializer_params),
      activation_fn=tf.nn.relu,
      normalizer_fn=slim.batch_norm if use_batch_norm else None,
      normalizer_params=batch_norm_params):
    with slim.arg_scope([conv2d_same], data_format=data_format):
      with slim.arg_scope([slim.batch_norm], **batch_norm_params):
        with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
          return arg_sc

My model helper:

import tensorflow as tf
import numpy as np

from pocket_flow.nets.abstract_model_helper import AbstractModelHelper
from pocket_flow.utils.lrn_rate_utils import setup_lrn_rate_piecewise_constant
from pocket_flow.utils.lrn_rate_utils import setup_lrn_rate_exponential_decay
from .nets.resnet_v1_18 import resnet_v1_18
from .nets.resnet_v1_18 import resnet_arg_scope
slim = tf.contrib.slim

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_float('lrn_rate_init', 1e-2, 'initial learning rate')
tf.flags.DEFINE_float('batch_size_norm', 128, 'normalization factor of batch size')
tf.flags.DEFINE_float('loss_w_dcy', 2e-4, 'weight decaying loss\'s coefficient')
tf.flags.DEFINE_float('nb_epochs_rat', 1.0, '# of training epochs\'s ratio')
tf.flags.DEFINE_integer('learning_rate_version', 1, 'Learning rate\'s version '
                                                        '(1: Setup the learning rate with piecewise constant strategy.'
                                                        ' or 2: Setup the learning rate with exponential decaying strategy.)')
tf.flags.DEFINE_integer('compress_nb_epochs', 100,
                       'Max number of epochs to compress once.')
tf.flags.DEFINE_string('compress_idxs_epoch', '30,50,100',
                       'Use with compress_decay_rates in Learning rate version 1.')
tf.flags.DEFINE_string('compress_decay_rates', '0.1,0.01,0.001,0.0001',
                       'Use with compress_idxs_epoch in Learning rate version 1.')
tf.flags.DEFINE_float('compress_epoch_step', 2.5,
                       'Use with compress_decay_rate in Learning rate version 2.')
tf.flags.DEFINE_float('compress_decay_rate', 0.98,
                       'Use with compress_epoch_step in Learning rate version 2.')


class ModelHelper(AbstractModelHelper):
  """Model helper for creating a ResNet model for the CIFAR-10 dataset."""

  def __init__(self, image_size=224, data_format='NHWC', num_classes=5, batch_size=64):
    """Constructor function."""

    # class-independent initialization
    super(ModelHelper, self).__init__()
    self.image_size = image_size
    self.data_format = data_format
    self.num_classes = num_classes
    self.batch_size = batch_size

  def build_dataset_train(self, enbl_trn_val_split=False):
    """Build the data subset for training, usually with data augmentation."""

    dataset = tf.data.Dataset.from_tensor_slices({
      "image": np.random.uniform(size=(640, self.image_size, self.image_size, 3)),
      # "label" must share the same first dimension (number of samples) as "image"
      "label": np.random.uniform(size=(640, self.num_classes))
    })

    dataset = dataset.shuffle(1000).batch(self.batch_size)
    iterator = dataset.make_initializable_iterator()

    return iterator

  def build_dataset_eval(self):
    """Build the data subset for evaluation, usually without data augmentation."""

    dataset = tf.data.Dataset.from_tensor_slices({
      "image": np.random.uniform(
        size=(640, self.image_size, self.image_size, 3)),
      # same first dimension (number of samples) as "image"
      "label": np.random.uniform(size=(640, self.num_classes))
    })

    dataset = dataset.shuffle(1000).batch(self.batch_size)
    iterator = dataset.make_initializable_iterator()

    return iterator

  def forward_train(self, inputs):
    arg_scope = resnet_arg_scope(weight_decay=FLAGS.weight_decay,
                                 data_format=self.data_format)
    with slim.arg_scope(arg_scope):
      with slim.arg_scope([slim.bias_add, slim.batch_norm, slim.conv2d,
                           slim.conv2d_in_plane, slim.conv2d_transpose,
                           slim.avg_pool2d, slim.max_pool2d,
                           slim.unit_norm],
                          data_format=self.data_format):
        with slim.arg_scope([slim.batch_norm], fused=True,
                            renorm=False):
          logits, end_points = resnet_v1_18(inputs,
                                            num_classes=5,
                                            is_training=True)
    return logits

  def forward_eval(self, inputs):
    arg_scope = resnet_arg_scope(weight_decay=FLAGS.weight_decay,
                                 data_format=self.data_format)
    with slim.arg_scope(arg_scope):
      with slim.arg_scope([slim.bias_add, slim.batch_norm, slim.conv2d,
                           slim.conv2d_in_plane, slim.conv2d_transpose,
                           slim.avg_pool2d, slim.max_pool2d,
                           slim.unit_norm],
                          data_format=self.data_format):
        with slim.arg_scope([slim.batch_norm], fused=True,
                            renorm=False):
          logits, end_points = resnet_v1_18(inputs,
                                            num_classes=5,
                                            is_training=False)
    return logits

  def calc_loss(self, labels, outputs, trainable_vars):
    """Calculate loss (and some extra evaluation metrics)."""

    loss = tf.losses.softmax_cross_entropy(labels, outputs)
    loss_filter = lambda var: 'batch_normalization' not in var.name
    reg_vars = [tf.nn.l2_loss(var) for var in trainable_vars if loss_filter(var)]
    loss += FLAGS.loss_w_dcy \
      * tf.add_n(reg_vars)
    accuracy = tf.reduce_mean(
      tf.cast(tf.equal(tf.argmax(labels, axis=1), tf.argmax(outputs, axis=1)), tf.float32))
    metrics = {'accuracy': accuracy}

    return loss, metrics

  def setup_lrn_rate(self, global_step):
    """Setup the learning rate (and number of training iterations)."""

    #batch_size = FLAGS.batch_size * (1 if not FLAGS.enbl_multi_gpu else mgw.size())
    batch_size = FLAGS.batch_size
    if FLAGS.learning_rate_version == 1:
      nb_epochs = FLAGS.compress_nb_epochs or 100
      idxs_epoch = [int(int_item) for int_item in FLAGS.compress_idxs_epoch.split(',')] or [30, 60, 80, 90]
      decay_rates = [float(float_item) for float_item in FLAGS.compress_decay_rates.split(',')] or [1.0, 0.1, 0.01, 0.001, 0.0001]
      lrn_rate = setup_lrn_rate_piecewise_constant(global_step, batch_size, idxs_epoch, decay_rates, self.train_dataset_meta)
      nb_iters = int(self.train_dataset_meta.total_num_samples * nb_epochs * FLAGS.nb_epochs_rat / batch_size)
    elif FLAGS.learning_rate_version == 2:
      nb_epochs = FLAGS.compress_nb_epochs or 412
      epoch_step = FLAGS.compress_epoch_step or 2.5
      decay_rate = FLAGS.compress_decay_rate or 0.98
      decay_rate = decay_rate ** epoch_step  # which is better, 0.98 OR (0.98 ** epoch_step)?
      lrn_rate = setup_lrn_rate_exponential_decay(global_step, batch_size, epoch_step, decay_rate, self.train_dataset_meta)
      nb_iters = int(self.train_dataset_meta.total_num_samples * nb_epochs * FLAGS.nb_epochs_rat / batch_size)
    else:
      raise ValueError('invalid learning rate version: {}'.format(FLAGS.learning_rate_version))

    return lrn_rate, nb_iters

  @property
  def model_name(self):
    """Model's name."""

    return FLAGS.model_name

  @property
  def dataset_name(self):
    """Dataset's name."""

    return FLAGS.dataset_name

My call function:

model_helper = ModelHelper()
learner = create_learner(sm_writer, model_helper)
learner.train()
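
For context, a hedged sketch of the *_run.py wrapper these three lines usually sit in (modeled on PocketFlow's bundled run scripts; the import paths and log directory are assumptions):

import tensorflow as tf

from learners.learner_utils import create_learner  # assumed import path
from my_model_helper import ModelHelper  # hypothetical module holding the helper above

FLAGS = tf.app.flags.FLAGS

def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  sm_writer = tf.summary.FileWriter('./logs')  # './logs' is an assumed directory
  model_helper = ModelHelper()
  learner = create_learner(sm_writer, model_helper)
  learner.train()

if __name__ == '__main__':
  tf.app.run()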

@as754770178 (Author)

@jiaxiang-wu
I define the input of the pretrained model in the following code.

  def input_fn(run_mode, **kwargs):
    if use_eval_data_in_training:
      dataset = eval_dataset
    else:
      dataset = train_dataset

    with tf.variable_scope("data"):
      image, label = dataset.get(['image', 'label'])
      mem_images = tf.placeholder(dtype=image.dtype,
                                  shape=image.shape)
      mem_labels = tf.placeholder(dtype=label.dtype,
                                  shape=label.shape)

    tf.add_to_collection('mem_images', mem_images)
    tf.add_to_collection('mem_labels', mem_labels)
    if run_mode == 'train':
      tf.add_to_collection('train_images', image)
      tf.add_to_collection('train_labels', label)
    else:
      tf.add_to_collection('eval_images', image)
      tf.add_to_collection('eval_labels', label)

    image, label = dataset.get(['image', 'label'])
    return image, label

I only define placeholders for mem_images and mem_labels and add them to collections. Is this how you do it?
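
As a sanity check on the mechanism itself, a minimal sketch (independent of PocketFlow) showing that a collection hands back the very tensor object registered under the key, which is what the retrieval side in learner.py relies on:

import tensorflow as tf

mem_images = tf.placeholder(tf.float32, [64, 224, 224, 3], name='mem_images')
tf.add_to_collection('mem_images', mem_images)

# Later, e.g. inside the learner, the same graph object comes back by key.
retrieved = tf.get_collection('mem_images')[0]
assert retrieved is mem_images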

@as754770178 (Author)

I read the code in learner.py.

  def __finetune_pruned_model(self, path=None, finetune=False):
    if path is None:
      path = FLAGS.cp_channel_pruned_path
    start = timer()
    tf.logging.info('build pruned evaluating model')
    self.__build_pruned_evaluate_model(path)
    tf.logging.info('build pruned training model')
    self.__build_pruned_train_model(path, finetune=finetune)
    tf.logging.info('training pruned model')
    self.__train_pruned_model(finetune=finetune)
    tf.logging.info('fine-tuning time cost: {}s'.format(timer() - start))

Why is the graph edited in __build_pruned_train_model?

      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      self.sess_train.close()

      graph_editor.reroute_ts(train_images, mem_images)
      graph_editor.reroute_ts(train_labels, mem_labels)

      self.sess_train = tf.Session(config=config)
      self.saver_train.restore(self.sess_train, path)

When training the pruned model in __train_pruned_model, the code does not use mem_images, and my code stops at self.sess_train.run(self.train_op).

    for self.idx_iter in range(nb_iters):
      # train the model
      if (self.idx_iter + 1) % FLAGS.summ_step != 0:
        self.sess_train.run(self.train_op)
      else:
        __, summary, log_rslt = self.sess_train.run([self.train_op, self.summary_op, self.log_op])
        self.__monitor_progress(summary, log_rslt)
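
For what it's worth, my reading of tf.contrib.graph_editor.reroute_ts (from its docstring, not from PocketFlow) is that consumers of the tensors in the second argument are reconnected to the tensors in the first argument. If that is right, then after reroute_ts(train_images, mem_images) the training graph pulls batches from the dataset tensors directly and the mem_* placeholders are left dangling, which would explain why mem_images is never fed during fine-tuning. A toy sketch:

import tensorflow as tf
from tensorflow.contrib import graph_editor

placeholder = tf.placeholder(tf.float32, [4], name='mem_input')   # plays mem_images
dataset_t = tf.constant([1.0, 2.0, 3.0, 4.0], name='data_input')  # plays train_images
doubled = placeholder * 2.0  # "network" built on top of the placeholder

# Reconnect the consumers of `placeholder` to `dataset_t`.
graph_editor.reroute_ts(dataset_t, placeholder)

with tf.Session() as sess:
  print(sess.run(doubled))  # [2. 4. 6. 8.], with no feed_dict needed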
