Riemannian approach to batch normalization

A TensorFlow Implementation of "Riemannian approach to batch normalization"

This code was used for the experiments in "Riemannian approach to batch normalization" (NIPS 2017) by Minhyung Cho and Jaehyung Lee. The poster for the conference can be found here.

Refer to for a PyTorch implementation.


Batch Normalization (BN) has proven to be an effective algorithm for deep neural network training by normalizing the input to each neuron and reducing the internal covariate shift. The space of weight vectors in the BN layer can be naturally interpreted as a Riemannian manifold, which is invariant to linear scaling of weights. Following the intrinsic geometry of this manifold provides a new learning rule that is more efficient and easier to analyze. We also propose intuitive and effective gradient clipping and regularization methods for the proposed algorithm by utilizing the geometry of the manifold. The resulting algorithm consistently outperforms the original BN on various types of network architectures and datasets.
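
The scale invariance described above can be checked numerically. The following is a minimal sketch in plain Python, not code from this repository. `batch_norm_preactivation` and `geodesic_step` are hypothetical helper names, the normalization omits the learned scale and shift and uses `eps=0`, and the update is a simplified, momentum-free gradient step along a geodesic of the unit sphere with norm clipping, assumed here only for illustration.

```python
import math

def batch_norm_preactivation(w, batch, eps=0.0):
    """Normalize z = w . x over a batch (no learned scale or shift)."""
    z = [sum(wi * xi for wi, xi in zip(w, x)) for x in batch]
    mean = sum(z) / len(z)
    var = sum((zi - mean) ** 2 for zi in z) / len(z)
    return [(zi - mean) / math.sqrt(var + eps) for zi in z]

def geodesic_step(y, grad, lr, clip):
    """One momentum-free gradient step along a geodesic of the unit sphere.

    Assumes y has unit norm. Illustrative only.
    """
    # Project the gradient onto the tangent space at y: h = g - (y . g) y
    yg = sum(a * b for a, b in zip(y, grad))
    h = [g - yg * a for g, a in zip(grad, y)]
    # Clip the norm of the tangent vector to at most `clip`
    norm = math.sqrt(sum(v * v for v in h))
    if norm > clip:
        h = [v * clip / norm for v in h]
        norm = clip
    if norm == 0.0:
        return list(y)
    # Rotate y by the angle lr * norm towards the direction -h / |h|
    t = lr * norm
    return [a * math.cos(t) - (v / norm) * math.sin(t) for a, v in zip(y, h)]

batch = [[1.0, 2.0], [0.5, -1.0], [-2.0, 0.3], [1.5, 1.5]]
w = [0.6, 0.8]  # unit-norm weight vector

out = batch_norm_preactivation(w, batch)
out_scaled = batch_norm_preactivation([3.0 * wi for wi in w], batch)
# The two outputs agree up to floating-point rounding.
print(max(abs(a - b) for a, b in zip(out, out_scaled)))

w_new = geodesic_step(w, grad=[1.0, 0.0], lr=0.2, clip=0.1)
# The updated vector still has (numerically) unit norm.
print(math.sqrt(sum(v * v for v in w_new)))
```

Because the normalized output depends only on the direction of `w` (for positive scaling factors), the weight vector can be kept at unit norm, and a geodesic step preserves that norm by construction.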


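Classification error rate (%) on CIFAR (median of five runs):

| Model | CIFAR-10 SGD | CIFAR-10 SGD-G | CIFAR-10 Adam-G | CIFAR-100 SGD | CIFAR-100 SGD-G | CIFAR-100 Adam-G |
|---|---|---|---|---|---|---|
| VGG-13 | 5.88 | 5.87 | 6.05 | 26.17 | 25.29 | 24.89 |
| VGG-19 | 6.49 | 5.92 | 6.02 | 27.62 | 25.79 | 25.59 |
| WRN-28-10 | 3.89 | 3.85 | 3.78 | 18.66 | 18.19 | 18.30 |
| WRN-40-10 | 3.72 | 3.72 | 3.80 | 18.39 | 18.04 | 17.85 |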

Classification error rate (%) on SVHN (median of five runs):

| Model | SGD | SGD-G | Adam-G |
|---|---|---|---|
| VGG-13 | 1.78 | 1.74 | 1.72 |
| VGG-19 | 1.94 | 1.81 | 1.77 |
| WRN-16-4 | 1.64 | 1.67 | 1.61 |
| WRN-22-8 | 1.64 | 1.63 | 1.55 |


[Figure: three panels, WRN-28-10 on CIFAR10, WRN-28-10 on CIFAR100, WRN-22-8 on SVHN]

See for details.



The commands below are examples for reproducing results in the paper.


[SGD] python3 --model=resnet --depth=28 --widen_factor=10 --optimizer=sgd --learnRate=0.1 --weightDecay=0.0005 --biasDecay=0.0005 --gammaDecay=0.0005 --betaDecay=0.0005 --keep_prob=0.7 --data=cifar10
[SGD-G] python3 --model=resnet --depth=28 --widen_factor=10 --optimizer=sgdg --grassmann=True --learnRate=0.01 --learnRateG=0.2 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.7 --data=cifar10
[Adam-G] python3 --model=resnet --depth=28 --widen_factor=10 --optimizer=adamg --grassmann=True --learnRate=0.01 --learnRateG=0.05 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.7 --data=cifar10


[SGD] python3 --model=resnet --depth=28 --widen_factor=10 --optimizer=sgd --learnRate=0.1 --weightDecay=0.0005 --biasDecay=0.0005 --gammaDecay=0.0005 --betaDecay=0.0005 --keep_prob=0.7 --data=cifar100
[SGD-G] python3 --model=resnet --depth=28 --widen_factor=10 --optimizer=sgdg --grassmann=True --learnRate=0.01 --learnRateG=0.2 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.7 --data=cifar100
[Adam-G] python3 --model=resnet --depth=28 --widen_factor=10 --optimizer=adamg --grassmann=True --learnRate=0.01 --learnRateG=0.05 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.7 --data=cifar100


[SGD] python3 --model=resnet --depth=22 --widen_factor=8 --optimizer=sgd --learnRate=0.01 --weightDecay=0.0005 --biasDecay=0.0005 --gammaDecay=0.0005 --betaDecay=0.0005 --keep_prob=0.6 --learnRateDecay=0.1 --data=svhn
[SGD-G] python3 --model=resnet --depth=22 --widen_factor=8 --optimizer=sgdg --grassmann=True --learnRate=0.001 --learnRateG=0.02 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.6 --learnRateDecay=0.1 --data=svhn
[Adam-G] python3 --model=resnet --depth=22 --widen_factor=8 --optimizer=adamg --grassmann=True --learnRate=0.001 --learnRateG=0.005 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.6 --learnRateDecay=0.1 --data=svhn

Other examples:

[2GPUs] python3 --model=resnet --depth=40 --widen_factor=10 --optimizer=adamg --grassmann=True --learnRate=0.01 --learnRateG=0.05 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --keep_prob=0.7 --data=cifar100 --vali_batch_size=200 --num_gpus=2
[VGG-19] python3 --model=vgg19 --optimizer=adamg --grassmann=True --learnRate=0.01 --learnRateG=0.05 --omega=0.1 --grad_clip=0.1 --weightDecay=0.0005 --data=cifar100

Test the performance of a checkpoint

python3 --model=resnet --depth=40 --widen_factor=10  --data=cifar100 --task=test --load=./logs/resnet_train_cifar100/model.ckpt-78124

To apply this algorithm to your model

The grassmann_optimizer module is the main implementation. It provides the proposed SGD-G and Adam-G optimizers, as well as HybridOptimizer, an abstract convenience class. Applying the provided optimizers to your model takes the following steps:

  1. Collect all the weight parameters which need to be optimized on the Grassmann manifold (and initialize them to unit scale):

    import numpy as np
    import tensorflow as tf
    import gutils

    weight = [i for i in tf.trainable_variables() if 'weight' in i.name]
    for var in weight:
        shape = var.get_shape().as_list()
        undercomplete = np.prod(shape[0:-1]) > shape[-1]
        if undercomplete and ('conv' in var.name):
            ## initialize to scale 1
            var._initializer_op = tf.assign(var, gutils.unit_initializer()(var.shape)).op
            tf.add_to_collection('grassmann', var)
  2. Build the graph for orthogonality regularizer:

    for var in tf.get_collection('grassmann'):
        shape = var.get_shape().as_list()
        v = tf.reshape(var, [-1, shape[-1]])
        v_sim = tf.matmul(tf.transpose(v), v)
        eye = tf.eye(shape[-1])
        assert eye.get_shape()==v_sim.get_shape()
        # omega: strength of the orthogonality regularizer (cf. the --omega option)
        orthogonality = tf.multiply(tf.reduce_sum((v_sim - eye)**2), 0.5 * omega, name='orthogonality')
        tf.add_to_collection('orthogonality', orthogonality)

    Do not apply weight decay to the parameters above.

  3. Add orthogonality loss to the loss function:

    orthogonality = tf.add_n(tf.get_collection('orthogonality', scope), name='orthogonality')
    total_loss = cross_entropy_mean + weightcost + orthogonality
  4. Initialize the optimizer:

    import grassmann_optimizer
    opta = tf.train.MomentumOptimizer(learning_rate, momentum)
    optb = grassmann_optimizer.SgdgOptimizer(learning_rate, momentum, grad_clip) # or use Adam-G
    opt = grassmann_optimizer.HybridOptimizer(opta, optb)
  5. Build the training graph:

    Pass two lists of (gradient, variable) pairs to apply_gradients(). Variables in grads_a will be updated by opta and variables in grads_b will be updated by optb.

    grads_a = [i for i in grads if not i[1] in tf.get_collection('grassmann')]
    grads_b = [i for i in grads if i[1] in tf.get_collection('grassmann')]
    apply_gradient_op = opt.apply_gradients(grads_a, grads_b)
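
The selection rule in step 1 and the regularizer in step 2 can be sketched without TensorFlow. The following plain-Python illustration is not part of this repository: `is_undercomplete`, `unit_columns`, and `orthogonality_penalty` are hypothetical names, matrices are row-major lists of lists standing in for the reshaped weight `v`, and "unit scale" is assumed here to mean that every column (one per output unit) has unit Euclidean norm.

```python
import math

def is_undercomplete(shape):
    """True when the fan-in (product of all but the last dimension)
    exceeds the number of output units (the last dimension)."""
    return math.prod(shape[:-1]) > shape[-1]

def unit_columns(v):
    """Rescale every column of a row-major matrix to unit Euclidean norm."""
    n_cols = len(v[0])
    norms = [math.sqrt(sum(row[j] ** 2 for row in v)) for j in range(n_cols)]
    return [[row[j] / norms[j] for j in range(n_cols)] for row in v]

def orthogonality_penalty(v, omega):
    """0.5 * omega * sum((V^T V - I)**2) for a row-major matrix V."""
    n_cols = len(v[0])
    total = 0.0
    for i in range(n_cols):
        for j in range(n_cols):
            gram_ij = sum(row[i] * row[j] for row in v)
            total += (gram_ij - (1.0 if i == j else 0.0)) ** 2
    return 0.5 * omega * total

# A 3x3 convolution kernel with 16 input and 32 output channels:
# fan-in 3*3*16 = 144 > 32, so it is undercomplete.
print(is_undercomplete([3, 3, 16, 32]))   # True
# A fully connected layer from 10 to 100 units is not.
print(is_undercomplete([10, 100]))        # False

orthogonal = unit_columns([[2.0, 0.0], [0.0, 5.0], [0.0, 0.0]])
parallel = unit_columns([[2.0, 5.0], [0.0, 0.0], [0.0, 0.0]])
print(orthogonality_penalty(orthogonal, omega=0.1))  # 0.0
print(orthogonality_penalty(parallel, omega=0.1))    # 0.1
```

The penalty is zero when the unit-norm columns are mutually orthogonal and grows as they align, which is why it is added to the loss in step 3 in place of weight decay for these parameters.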