Skip to content

Commit

Permalink
C-SGD codes
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnDing1994 committed Apr 30, 2019
1 parent c0f262a commit 48c77c7
Show file tree
Hide file tree
Showing 15 changed files with 2,435 additions and 0 deletions.
202 changes: 202 additions & 0 deletions cr_base.py
@@ -0,0 +1,202 @@
import os
import numpy as np
import tensorflow as tf
from mg_train import mg_train
from csgd_utils import calculate_eqcls_evenly, calculate_eqcls_biasly, calculate_eqcls_by_kmeans, calculate_bn_eqcls_dc40, tfm_prune_filters_and_save_dc40, tfm_prune_filters_and_save
from tf_utils import log_important
from tfm_callbacks import MergeGradientHandler
from tfm_origin_eval import evaluate_model

CR_OVERALL_LOG_FILE = 'cr_overall_logs.txt'

def cr_base_pipeline(network_type, train_dataset, eval_dataset, train_mode, eval_mode,
                     normal_builder_type, model_type,
                     subsequent_strategy, eqcls_follow_dict,
                     fc_layer_idxes,
                     st_gpu_idxes,
                     l2_factor, eval_batch_size, image_size,
                     try_arg, target_deps, origin_deps, pretrained_model_file,
                     cluster_method, eqcls_layer_idxes,
                     st_batch_size_per_gpu, st_max_epochs_per_gpu, st_lr_epoch_boundaries, st_lr_values,
                     num_steps_per_ckpt_st, frequently_save_interval, frequently_save_last_epochs,
                     diff_factor,
                     schedule_vector,
                     restore_itr=0, restore_st_step=0,
                     init_step=0,
                     slow_on_vec=False
                     ):
    """Iteratively compress a network with C-SGD (Centripetal SGD).

    For each entry of ``schedule_vector`` the pipeline:
      1. clusters the filters of every layer in ``eqcls_layer_idxes`` into
         equivalence classes (eqcls), or reloads previously saved eqcls,
      2. runs C-SGD training so the filters of each class converge,
      3. trims the now-redundant filters and saves the pruned weights,
      4. evaluates the model before and after trimming.

    Args:
        network_type: short network tag (e.g. 'dc40'); 'dc40' triggers the
            DenseNet-specific BN eqcls handling and pruning routine.
        train_dataset / eval_dataset: dataset objects exposing
            ``num_examples_per_epoch()``.
        train_mode / eval_mode: mode strings passed to ``model_type``.
        normal_builder_type: builder class; ``normal_builder_type(training, deps)``
            yields an object whose ``.build`` is the inference function.
        model_type: model class (e.g. TFModel).
        subsequent_strategy / eqcls_follow_dict / fc_layer_idxes: pruning
            topology hints forwarded to the pruning helpers; a follow entry
            ``{k: v}`` makes layer k reuse layer v's eqcls.
        st_gpu_idxes: GPU indices used for C-SGD training.
        l2_factor: L2 regularization factor, applied by the gradient handler.
        try_arg: experiment tag used to prefix all output files.
        target_deps / origin_deps: per-layer filter counts after / before pruning.
        pretrained_model_file: hdf5 weights to start itr 0 from (or falsy for scratch).
        cluster_method: one of 'kmeans', 'even', 'biased'.
        schedule_vector: fractions (0..1] of the total prune applied at each itr.
        restore_itr / restore_st_step / init_step: resume controls; a negative
            ``restore_st_step`` skips clustering+training and reloads saved eqcls.
        slow_on_vec: forwarded to MergeGradientHandler.
    """
    assert cluster_method in ['kmeans', 'even', 'biased']

    origin_deps = np.array(origin_deps)

    # All artifacts of one experiment share a common prefix.
    prefix = '{}_{}'.format(network_type, try_arg)
    train_dir_pattern = prefix + '_itr{}_train'
    ckpt_dir = prefix + '_ckpt'
    important_log_file = prefix + '_important_log.txt'
    eqcls_file_pattern = prefix + '_itr{}_eqcls.npy'
    sted_weights_pattern = prefix + '_itr{}_sted.hdf5'
    pruned_weights_pattern = prefix + '_itr{}_prunedweights.hdf5'

    target_deps = np.ceil(target_deps).astype(np.int32)
    deps_to_prune = origin_deps - target_deps
    for itr in range(restore_itr, len(schedule_vector)):
        # Itr 0 starts from the pretrained weights; later itrs start from the
        # weights pruned at the previous itr.
        if itr == 0:
            cur_start_model_file = pretrained_model_file
        else:
            cur_start_model_file = pruned_weights_pattern.format(itr - 1)
        # Filter counts at the start of this itr and after this itr's prune.
        if itr == 0:
            cur_remain_deps = np.ceil(origin_deps).astype(np.int32)
        else:
            cur_remain_deps = np.ceil(origin_deps - (deps_to_prune * schedule_vector[itr - 1])).astype(np.int32)

        next_remain_deps = np.ceil(origin_deps - (deps_to_prune * schedule_vector[itr])).astype(np.int32)

        cur_train_dir = train_dir_pattern.format(itr)
        cur_sted_weights = sted_weights_pattern.format(itr)
        cur_pruned_weights = pruned_weights_pattern.format(itr)

        str_start_from = cur_start_model_file or 'SCRATCH'
        log_important(
            'CSGD: start itr {}, start from {}, st to {}, cur remain deps {}, next remain deps {}, cur train dir {}'
            .format(itr, str_start_from, cur_sted_weights,
                    list(cur_remain_deps), list(next_remain_deps), cur_train_dir),
            log_file=important_log_file)

        eval_builder = normal_builder_type(training=False, deps=cur_remain_deps)
        train_builder = normal_builder_type(training=True, deps=cur_remain_deps)
        eqcls_file = eqcls_file_pattern.format(itr)

        # restore_st_step only applies to the itr we are resuming.
        if itr != restore_itr:
            restore_st_step = 0
        if restore_st_step >= 0:

            # Calculate eqcls and test the starting weights.
            test_model = model_type(dataset=eval_dataset, inference_fn=eval_builder.build, mode=eval_mode,
                                    batch_size=eval_batch_size,
                                    image_size=image_size)
            test_model.load_weights_from_file(cur_start_model_file)
            eqcls_dict = {}
            log_important('I calculate eqcls by {}'.format(cluster_method), log_file=important_log_file)
            # Reuse any eqcls already computed for this itr (partial resume).
            if os.path.exists(eqcls_file):
                eqcls_dict = np.load(eqcls_file).item()

            for i in eqcls_layer_idxes:
                if i in eqcls_dict:
                    continue
                if cluster_method == 'kmeans':
                    eqcls = calculate_eqcls_by_kmeans(test_model, i, next_remain_deps[i])
                elif cluster_method == 'even':
                    eqcls = calculate_eqcls_evenly(filters=cur_remain_deps[i], num_eqcls=next_remain_deps[i])
                elif cluster_method == 'biased':
                    eqcls = calculate_eqcls_biasly(filters=cur_remain_deps[i], num_eqcls=next_remain_deps[i])
                else:
                    # Unreachable given the assertion above; raise instead of
                    # `assert False`, which is stripped under `python -O`.
                    raise ValueError('unknown cluster_method: {}'.format(cluster_method))
                eqcls_dict[i] = eqcls
            np.save(eqcls_file, eqcls_dict)

            # Layers that must share clustering with another layer.
            if eqcls_follow_dict is not None:
                for k, v in eqcls_follow_dict.items():
                    if v in eqcls_dict:
                        eqcls_dict[k] = eqcls_dict[v]

            if network_type == 'dc40':
                bn_layer_to_eqcls = calculate_bn_eqcls_dc40(eqcls_dict)
            else:
                bn_layer_to_eqcls = None

            # Eval before C-SGD training.
            evaluate_model(test_model, num_examples=eval_dataset.num_examples_per_epoch(),
                           results_record_file=important_log_file,
                           comment='eval at itr {}, cur remain deps {}'.format(itr, cur_remain_deps))
            del test_model

            if st_max_epochs_per_gpu <= 0:  # skip C-SGD training
                cur_sted_weights = cur_start_model_file
            else:
                # C-SGD train.
                # Note that l2_factor=0 here cancels the original L2 regularization term
                # (base_loss += 0.5 * l2_factor * sum(kernels ** 2)); the gradient
                # handler applies the decay itself instead.
                train_model = model_type(dataset=train_dataset, inference_fn=train_builder.build, mode=train_mode,
                                         batch_size=st_batch_size_per_gpu, image_size=image_size, l2_factor=0,
                                         deps=cur_remain_deps)
                gradient_handler = MergeGradientHandler(model=train_model, layer_to_eqcls=eqcls_dict,
                                                        l2_factor=l2_factor,
                                                        diff_factor=diff_factor, exclude_l2_decay_keywords=None,
                                                        bn_layer_to_eqcls=bn_layer_to_eqcls,
                                                        version=2, slow_on_vec=slow_on_vec)
                cur_init_step = init_step if itr == restore_itr else 0
                lr = train_model.get_piecewise_lr(st_lr_values, boundaries_epochs=st_lr_epoch_boundaries, init_step=cur_init_step)
                optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9, use_nesterov=True)
                # Resume from a mid-itr checkpoint if requested.
                if itr == restore_itr and restore_st_step > 0:
                    load_ckpt = os.path.join(cur_train_dir, 'model.ckpt-{}'.format(restore_st_step))
                    print('the CSGD train restarts from ', load_ckpt)
                else:
                    load_ckpt = None
                if itr == restore_itr:
                    start_step = max(init_step, restore_st_step)
                else:
                    start_step = 0
                mg_train(model=train_model, train_dir=cur_train_dir, optimizer=optimizer, layer_to_eqcls=eqcls_dict,
                         max_epochs_per_gpu=st_max_epochs_per_gpu, max_steps_per_gpu=None,
                         init_step=start_step, init_file=cur_start_model_file, load_ckpt=load_ckpt,
                         save_final_hdf5=cur_sted_weights,
                         ckpt_dir=ckpt_dir, ckpt_prefix='st_ckpt_itr{}'.format(itr),
                         num_steps_every_ckpt=num_steps_per_ckpt_st,
                         gradient_handler=gradient_handler, gpu_idxes=st_gpu_idxes, bn_layer_to_eqcls=bn_layer_to_eqcls,
                         frequently_save_interval=frequently_save_interval, frequently_save_last_epochs=frequently_save_last_epochs)
                log_important('C-SGD train completed. save to: {}'.format(cur_sted_weights),
                              log_file=important_log_file)
                del train_model

        else:
            # restore_st_step < 0: skip clustering and training, reload the
            # eqcls saved by a previous run and go straight to trimming.
            print('load eqcls from file: ', eqcls_file)
            eqcls_dict = np.load(eqcls_file).item()
            if eqcls_follow_dict is not None:
                for k, v in eqcls_follow_dict.items():
                    eqcls_dict[k] = eqcls_dict[v]

        # Test after C-SGD training, before trimming.
        test_model = model_type(dataset=eval_dataset, inference_fn=eval_builder.build, mode=eval_mode,
                                batch_size=eval_batch_size, image_size=image_size)
        test_model.load_weights_from_file(cur_sted_weights)
        evaluate_model(test_model, num_examples=eval_dataset.num_examples_per_epoch(),
                       results_record_file=important_log_file,
                       comment='eval before trimming at itr {}, cur remain deps {}'.format(itr, cur_remain_deps),
                       close_threads=False)
        del test_model

        # Trim the C-SGD trained weights.
        trim_model = model_type(dataset=eval_dataset, inference_fn=eval_builder.build, mode=eval_mode,
                                batch_size=eval_batch_size, image_size=image_size)
        trim_model.load_weights_from_file(cur_sted_weights)
        if network_type == 'dc40':
            bn_layer_to_eqcls = calculate_bn_eqcls_dc40(eqcls_dict)
            tfm_prune_filters_and_save_dc40(trim_model, eqcls_dict, bn_layer_to_eqcls=bn_layer_to_eqcls,
                                            save_file=cur_pruned_weights,
                                            new_deps=next_remain_deps)
        else:
            tfm_prune_filters_and_save(trim_model, eqcls_dict, save_file=cur_pruned_weights,
                                       fc_layer_idxes=fc_layer_idxes, subsequent_strategy=subsequent_strategy, new_deps=next_remain_deps)

        log_important('finished trimming at itr {}, save to {}'.format(itr, cur_pruned_weights),
                      log_file=important_log_file)
        del trim_model

        # Test the trimmed model.
        pruned_builder = normal_builder_type(training=False, deps=next_remain_deps)
        pruned_test_model = model_type(dataset=eval_dataset, inference_fn=pruned_builder.build, mode=eval_mode,
                                       batch_size=eval_batch_size, image_size=image_size)
        pruned_test_model.load_weights_from_file(cur_pruned_weights)
        evaluate_model(pruned_test_model, num_examples=eval_dataset.num_examples_per_epoch(),
                       results_record_file=important_log_file,
                       comment='eval after trimming at itr {}, cur remain deps {}'.format(itr, next_remain_deps))
        del pruned_test_model
39 changes: 39 additions & 0 deletions cr_dc40.py
@@ -0,0 +1,39 @@
from cr_base import cr_base_pipeline
from tf_dataset import CIFAR10Data
from tfm_model import TFModel
from tfm_constants import *
from tfm_builder_densenet import DC40Builder

def cr_dc40(
        try_arg, target_deps, origin_deps, pretrained_model_file,
        cluster_method, eqcls_layer_idxes,
        st_batch_size_per_gpu, st_max_epochs_per_gpu, st_lr_epoch_boundaries, st_lr_values,
        diff_factor,
        schedule_vector,
        restore_itr=0, restore_st_step=0,
        init_step=0,
        frequently_save_interval=None, frequently_save_last_epochs=None,
        slow_on_vec=False
):
    """Run the C-SGD pruning pipeline with the standard DenseNet-40 / CIFAR-10 setup.

    Everything network- and dataset-specific is fixed here; the remaining
    arguments are forwarded unchanged to cr_base_pipeline.
    """
    # Fixed, DC40-specific configuration of the generic pipeline.
    dc40_setup = dict(
        network_type='dc40',
        train_dataset=CIFAR10Data('train'),
        eval_dataset=CIFAR10Data('validation'),
        train_mode='train',
        eval_mode='eval',
        normal_builder_type=DC40Builder,
        model_type=TFModel,
        subsequent_strategy=DC40_SUBSEQUENT_STRATEGY,
        eqcls_follow_dict=DC40_FOLLOW_DICT,
        fc_layer_idxes=DC40_FC_LAYERS,
        st_gpu_idxes=[0],
        l2_factor=1e-4,
        eval_batch_size=500,
        image_size=32,
        num_steps_per_ckpt_st=20000,
    )
    cr_base_pipeline(
        try_arg=try_arg, target_deps=target_deps, origin_deps=origin_deps,
        pretrained_model_file=pretrained_model_file,
        cluster_method=cluster_method, eqcls_layer_idxes=eqcls_layer_idxes,
        st_batch_size_per_gpu=st_batch_size_per_gpu,
        st_max_epochs_per_gpu=st_max_epochs_per_gpu,
        st_lr_epoch_boundaries=st_lr_epoch_boundaries,
        st_lr_values=st_lr_values,
        # Falsy values (None/0) fall back to the defaults, as in the original.
        frequently_save_interval=frequently_save_interval or 1000,
        frequently_save_last_epochs=frequently_save_last_epochs or 50,
        diff_factor=diff_factor,
        schedule_vector=schedule_vector,
        restore_itr=restore_itr, restore_st_step=restore_st_step,
        init_step=init_step,
        slow_on_vec=slow_on_vec,
        **dc40_setup)
112 changes: 112 additions & 0 deletions csgd_standalone.py
@@ -0,0 +1,112 @@
from tfm_model import TFModel
from tf_dataset import CIFAR10Data
from tfm_origin_train import train
from tfm_builder_densenet import DC40Builder
from tfm_constants import DC40_ORIGIN_DEPS, customized_dc40_deps, DC40_ALL_CONV_LAYERS
import tensorflow as tf
import sys
from cr_dc40 import cr_dc40
import os
import numpy as np
from csgd_utils import calculate_bn_eqcls_dc40, tfm_prune_filters_and_save_dc40
from tfm_origin_eval import evaluate_model
from tf_utils import extract_deps_from_weights_file

# Baseline DenseNet-40 weights (filename suggests 93.82% CIFAR-10 accuracy).
PRETRAINED_MODEL_FILE = 'std_dc40_9382.hdf5'


# Piecewise-constant learning-rate schedule: LR_VALUES[i] is used up to the
# epoch boundary LR_BOUNDARIES[i]; the last value runs until MAX_EPOCHS.
LR_VALUES = [3e-3, 3e-4, 3e-5, 3e-6]
LR_BOUNDARIES = [200, 400, 500]
MAX_EPOCHS = 600
BATCH_SIZE = 64

# Centripetal strength (the "epsilon" / diff_factor of C-SGD training).
EPSILON_VALUE = 3e-3

TARGET_DEPS = customized_dc40_deps('3-3-3') # 3 filters per incremental conv layer (the original is 12)

# L2 weight-decay factor used when fine-tuning the magnitude-pruned baseline.
DC40_L2_FACTOR = 1e-4



def eval_model(weights_path):
    """Evaluate a saved DC40 model on CIFAR-10 validation and record the result."""
    val_set = CIFAR10Data('validation')
    # Recover the per-layer filter counts stored in the weights file; fall back
    # to the unpruned DC40 architecture when the file carries no deps record.
    deps = extract_deps_from_weights_file(weights_path)
    if deps is None:
        deps = DC40_ORIGIN_DEPS
    net_builder = DC40Builder(training=False, deps=deps)
    eval_net = TFModel(val_set, net_builder.build, 'eval', batch_size=250, image_size=32)
    eval_net.load_weights_from_file(weights_path)
    evaluate_model(eval_net, num_examples=val_set.num_examples_per_epoch(),
                   results_record_file='origin_dc40_eval_record.txt')


def compare_csgd(prefix):
    """Prune DC40 to the 3-3-3 target with C-SGD (k-means clustering, single pruning itr)."""
    csgd_config = {
        'target_deps': TARGET_DEPS,
        'origin_deps': DC40_ORIGIN_DEPS,
        'pretrained_model_file': PRETRAINED_MODEL_FILE,
        'cluster_method': 'kmeans',
        'eqcls_layer_idxes': DC40_ALL_CONV_LAYERS,
        'st_batch_size_per_gpu': BATCH_SIZE,
        'st_max_epochs_per_gpu': MAX_EPOCHS,
        'st_lr_epoch_boundaries': LR_BOUNDARIES,
        'st_lr_values': LR_VALUES,
        'diff_factor': EPSILON_VALUE,
        # A single schedule entry of 1 means the full prune happens in one itr.
        'schedule_vector': [1],
        'slow_on_vec': False,
    }
    cr_dc40(prefix, **csgd_config)


def _produce_magnitude_equivalent_eqcls(target_deps, save_path):
    """Build eqcls that make magnitude pruning reusable by the C-SGD trimmer.

    In order to reuse the pruning function, we produce the clustering
    $\\mathcal{C}$ based on the magnitude of filter kernels: the filters with
    the smallest L1 kernel magnitude at each layer are merged into the first
    kept filter's class, so trimming by these eqcls is equivalent to deleting
    the smallest-magnitude filters (e.g. the 9 smallest per layer).

    Returns the eqcls dict (also saved to ``save_path`` as a .npy file).
    """
    net_builder = DC40Builder(True, deps=DC40_ORIGIN_DEPS)
    source_model = TFModel(CIFAR10Data('train'), net_builder.build, 'train', batch_size=BATCH_SIZE,
                           image_size=32, l2_factor=DC40_L2_FACTOR, deps=DC40_ORIGIN_DEPS)
    source_model.load_weights_from_file(PRETRAINED_MODEL_FILE)

    equivalent_dict_eqcls = {}
    for layer_idx in DC40_ALL_CONV_LAYERS:
        kernel = source_model.get_value(source_model.get_kernel_tensors()[layer_idx])
        # L1 magnitude of each filter (sum over height, width and input channels).
        magnitudes = np.sum(np.abs(kernel), axis=(0, 1, 2))
        assert len(magnitudes) == DC40_ORIGIN_DEPS[layer_idx]
        num_to_delete = DC40_ORIGIN_DEPS[layer_idx] - target_deps[layer_idx]
        smallest = np.argsort(magnitudes)[:num_to_delete]
        smallest_set = set(smallest)
        # One singleton class per kept filter ...
        layer_eqcls = [[f] for f in range(DC40_ORIGIN_DEPS[layer_idx])
                       if f not in smallest_set]
        # ... and the filters to delete are merged into the first kept class.
        for f in smallest:
            layer_eqcls[0].append(f)
        equivalent_dict_eqcls[layer_idx] = layer_eqcls

    np.save(save_path, equivalent_dict_eqcls)
    del source_model
    return equivalent_dict_eqcls


def compare_magnitude(prefix):
    """Baseline: prune the smallest-magnitude filters per layer, then fine-tune."""
    pruned_weights = 'dc40_{}_prunedweights.hdf5'.format(prefix)
    target_deps = customized_dc40_deps('3-3-3')
    final_weights = '{}_trained.hdf5'.format(prefix)
    eqcls_save_path = 'dc40_equivalent_eqcls_{}.npy'.format(prefix)

    # Produce the pruned weights only once; a previous run may have left them on disk.
    if not os.path.exists(pruned_weights):
        eqcls_dict = _produce_magnitude_equivalent_eqcls(target_deps=target_deps,
                                                         save_path=eqcls_save_path)
        bn_layer_to_eqcls = calculate_bn_eqcls_dc40(eqcls_dict)

        full_builder = DC40Builder(False, deps=DC40_ORIGIN_DEPS)
        full_model = TFModel(CIFAR10Data('train'), full_builder.build, 'eval', batch_size=64,
                             image_size=32, l2_factor=1e-4, deps=DC40_ORIGIN_DEPS)
        full_model.load_weights_from_file(PRETRAINED_MODEL_FILE)
        tfm_prune_filters_and_save_dc40(full_model, eqcls_dict, bn_layer_to_eqcls=bn_layer_to_eqcls,
                                        save_file=pruned_weights, new_deps=target_deps)
        del full_model

    # Fine-tune the pruned network with the same schedule as the C-SGD run.
    slim_builder = DC40Builder(True, deps=target_deps)
    slim_model = TFModel(CIFAR10Data('train'), slim_builder.build, 'eval', batch_size=BATCH_SIZE,
                         image_size=32, l2_factor=DC40_L2_FACTOR, deps=target_deps)
    lr = slim_model.get_piecewise_lr(values=LR_VALUES, boundaries_epochs=LR_BOUNDARIES)
    optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9, use_nesterov=True)
    train(slim_model, train_dir='{}_train'.format(prefix), optimizer=optimizer,
          max_epochs_per_gpu=MAX_EPOCHS, gpu_idxes=[0], init_file=pruned_weights,
          save_final_hdf5=final_weights, ckpt_dir='{}_ckpt'.format(prefix),
          ckpt_prefix=prefix, num_steps_every_ckpt=20000)



if __name__ == '__main__':
    # Dispatch on the first CLI argument: a prefix containing 'csgd' or
    # 'magnitude' selects the corresponding comparison run; 'eval' evaluates
    # the weights file given as the second argument.
    prefix = sys.argv[1]
    if 'csgd' in prefix:
        compare_csgd(prefix)
    elif 'magnitude' in prefix:
        compare_magnitude(prefix)
    elif prefix == 'eval':
        eval_model(sys.argv[2])
    else:
        # Previously an unrecognized argument exited silently; fail loudly instead.
        raise ValueError('unrecognized command/prefix: {}'.format(prefix))

0 comments on commit 48c77c7

Please sign in to comment.