
@CVPR2018: Efficient unrolling iterative matrix square-root normalized ConvNets, implemented in PyTorch (with code for B-CNN, compact bilinear pooling, etc.), for training from scratch and finetuning.


OrkhanHI/fast-MPN-COV

 
 


Fast MPN-COV (i.e., iSQRT-COV)

Created by Jiangtao Xie and Peihua Li

      

Introduction

This repository contains the PyTorch source code and the models trained on the ImageNet 2012 dataset for the following paper:

     @InProceedings{Li_2018_CVPR,
           author = {Li, Peihua and Xie, Jiangtao and Wang, Qilong and Gao, Zilin},
           title = {Towards Faster Training of Global Covariance Pooling Networks by Iterative Matrix Square Root Normalization},
           booktitle = {IEEE Int. Conf. on Computer Vision and Pattern Recognition (CVPR)},
           month = {June},
           year = {2018}
     }

This paper proposes an iterative matrix square root normalization network (called fast MPN-COV), which is very efficient and well suited to large-scale datasets, in contrast to its predecessor (MPN-COV, published at ICCV 2017), which performs matrix power normalization by eigendecomposition. Code for bilinear CNN (B-CNN), compact bilinear pooling, global average pooling, etc. is also released, for both training from scratch and finetuning. If you use this code, please cite both this fast MPN-COV work and its predecessor, MPN-COV.
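As described in the paper, the forward pass of fast MPN-COV computes a covariance matrix from the last convolutional features, pre-normalizes it by its trace, runs a few coupled Newton-Schulz iterations to approximate its square root, and then compensates for the trace. The NumPy sketch below illustrates these steps for a single image. It is not the repository's PyTorch meta-layer (src/representation/MPNCOV.py); the function names, the (d channels × M positions) feature layout and the default of 5 iterations are assumptions made for illustration.

```python
import numpy as np


def covariance(x: np.ndarray) -> np.ndarray:
    """Sample covariance of features x with shape (d, M): d channels, M spatial positions."""
    m = x.shape[1]
    # Centering matrix I_bar = (I - 1/M * 1 1^T) / M, so that Sigma = X I_bar X^T
    i_bar = (np.eye(m) - np.ones((m, m)) / m) / m
    return x @ i_bar @ x.T


def isqrt_cov(x: np.ndarray, num_iter: int = 5) -> np.ndarray:
    """Approximate the square root of the covariance of x by coupled Newton-Schulz iteration."""
    sigma = covariance(x)
    eye = np.eye(sigma.shape[0])
    # Pre-normalization: A = Sigma / tr(Sigma) has eigenvalues in [0, 1],
    # which (per the paper) keeps the Newton-Schulz iteration convergent.
    trace = np.trace(sigma)
    y = sigma / trace      # Y_0 = A, tends to A^(1/2)
    z = eye.copy()         # Z_0 = I, tends to A^(-1/2)
    for _ in range(num_iter):
        t = 0.5 * (3.0 * eye - z @ y)
        y, z = y @ t, t @ z  # both updates use the previous y and z
    # Post-compensation: sqrt(Sigma) = sqrt(tr(Sigma)) * sqrt(A)
    return np.sqrt(trace) * y


def triu_vector(c: np.ndarray) -> np.ndarray:
    """Keep the upper triangle (including the diagonal) of a symmetric matrix."""
    rows, cols = np.triu_indices(c.shape[0])
    return c[rows, cols]
```

With only a few iterations the result is an approximation; increasing `num_iter` brings it closer to the exact square root obtained by eigendecomposition, at the cost of more matrix multiplications.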

Classification results

Classification results (single crop 224x224, %) on ImageNet 2012 validation set

Network                   Top-1 Error (%)        Top-5 Error (%)        Pre-trained models
                          paper    reproduced    paper    reproduced    GoogleDrive   BaiduCloud
fast MPN-COV-ResNet50     22.14    21.71         6.22     6.13          217.3MB       217.3MB
fast MPN-COV-ResNet101    21.21    20.99         5.68     5.56          289.9MB       289.9MB
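Top-k error is the percentage of validation images whose ground-truth label is not among the k highest-scoring classes. A small NumPy sketch of the metric, assuming `scores` is an (N, C) array of class scores and `labels` holds integer class indices (both names are hypothetical, not taken from main.py):

```python
import numpy as np


def topk_error(scores: np.ndarray, labels: np.ndarray, k: int) -> float:
    """Percentage of samples whose true label is not among the k top-scoring classes."""
    # Indices of the k largest scores in each row (their order does not matter)
    topk = np.argsort(scores, axis=1)[:, -k:]
    hit = (topk == labels[:, None]).any(axis=1)
    return 100.0 * (1.0 - hit.mean())
```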

Fine-grained classification results (top-1 accuracy rates, %)

Backbone model   Dim.   Birds                Aircrafts            Cars
                        paper   reproduced   paper   reproduced   paper   reproduced
ResNet-50        32K    88.1    88.0         90.0    90.3         92.8    92.3
ResNet-101       32K    88.7    TODO         91.4    TODO         93.3    TODO
  • Our method uses neither bounding boxes nor part annotations.
  • The reproduced results are obtained by simply finetuning our pre-trained fast MPN-COV-ResNet model with a small learning rate; unlike the procedure described in our paper, they do not use an SVM classifier.
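The "32K" in the Dim. column follows from vectorizing the upper triangle of a symmetric d × d matrix, which has d(d+1)/2 entries. Assuming, as in the fast MPN-COV paper, that the backbone's last convolutional features are first reduced to d = 256 channels by a 1×1 convolution, the arithmetic is:

```python
def triu_dim(d: int) -> int:
    """Number of entries in the upper triangle (with diagonal) of a d x d matrix."""
    return d * (d + 1) // 2


print(triu_dim(256))  # 32896, i.e. the "32K" in the table
```

Without such a reduction, the 2048 channels of ResNet's last stage would give triu_dim(2048) = 2,098,176 dimensions, which is why the channel reduction matters.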

Implementation details

We implement our Fast MPN-COV (i.e., iSQRT-COV) meta-layer in PyTorch. Although the autograd package of PyTorch 0.4.0 and above computes the gradients of our meta-layer correctly, that of PyTorch 0.3.0 does not. We therefore implement the backpropagation of our meta-layer by hand, without relying on autograd; this implementation works with both PyTorch 0.3.0 and 0.4.0.
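To illustrate what implementing backpropagation without autograd involves, the NumPy sketch below hand-derives the gradient of a single step of the meta-layer, the trace pre-normalization A = Σ / tr(Σ), and checks it against finite differences. With t = tr(Σ), the differential is dA = dΣ / t - Σ tr(dΣ) / t², so for an upstream gradient G = ∂L/∂A the chain rule gives ∂L/∂Σ = G / t - (⟨G, Σ⟩ / t²) I, where ⟨·,·⟩ is the Frobenius inner product. The function names are hypothetical; the repository's PyTorch code handles the complete meta-layer, including the Newton-Schulz iterations.

```python
import numpy as np


def trace_norm_forward(sigma: np.ndarray) -> np.ndarray:
    """Pre-normalization step: A = Sigma / tr(Sigma)."""
    return sigma / np.trace(sigma)


def trace_norm_backward(sigma: np.ndarray, grad_a: np.ndarray) -> np.ndarray:
    """Hand-derived gradient dL/dSigma, given grad_a = dL/dA."""
    t = np.trace(sigma)
    inner = np.sum(grad_a * sigma)  # Frobenius inner product <G, Sigma>
    return grad_a / t - (inner / t**2) * np.eye(sigma.shape[0])


def numerical_grad(sigma: np.ndarray, grad_a: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Central finite differences of L(Sigma) = <grad_a, A(Sigma)>, entry by entry."""
    num = np.zeros_like(sigma)
    for idx in np.ndindex(sigma.shape):
        plus, minus = sigma.copy(), sigma.copy()
        plus[idx] += eps
        minus[idx] -= eps
        loss_plus = np.sum(grad_a * trace_norm_forward(plus))
        loss_minus = np.sum(grad_a * trace_norm_forward(minus))
        num[idx] = (loss_plus - loss_minus) / (2 * eps)
    return num
```

The same pattern (derive the backward pass by hand, then verify it numerically) scales to the remaining steps of the layer.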

To make it easy to add our Fast MPN-COV meta-layer to a network, we restructured the official PyTorch demo imagenet/ and models/ so that every network is divided into three parts: 1) a feature extractor; 2) a global image representation; 3) a classifier. As a result, any network can be combined with Fast MPN-COV or with other global image representation methods (e.g., global average pooling, bilinear pooling, compact bilinear pooling). Based on this design, we can:


  • Finetune a pre-trained model on any image classification dataset.
    AlexNet, VGG, ResNet, Inception, etc.
  • Finetune a pre-trained model with a powerful global image representation method on any image classification dataset.
    Fast MPN-COV, Bilinear Pooling (B-CNN), Compact Bilinear Pooling (CBP), etc.
  • Train a model from scratch with a powerful global image representation method on any image classification dataset.

Finetune demo and Train from scratch demo
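The three-part split (feature extractor, global image representation, classifier) can be sketched with plain functions composed in sequence, so that swapping the representation does not touch the other two parts. The NumPy code below is a toy illustration, not the repository's src/network or src/representation modules: `make_features`, `build_model`, the random 1×1 projection used as a "feature extractor", and the two representations (global average pooling, and B-CNN-style bilinear pooling with signed square root and L2 normalization) are assumptions for illustration.

```python
from typing import Callable, Dict

import numpy as np

Array = np.ndarray


def make_features(in_ch: int, feat_dim: int, rng: np.random.Generator) -> Callable[[Array], Array]:
    """Toy feature extractor: 1x1 projection + ReLU, (in_ch, H, W) -> (feat_dim, H*W)."""
    proj = rng.standard_normal((feat_dim, in_ch))

    def features(image: Array) -> Array:
        return np.maximum(proj @ image.reshape(image.shape[0], -1), 0.0)

    return features


def gavp(feat: Array) -> Array:
    """Global average pooling over spatial positions: (d, M) -> (d,)."""
    return feat.mean(axis=1)


def bilinear(feat: Array) -> Array:
    """Bilinear pooling: averaged outer product, signed sqrt, then L2 normalization."""
    b = (feat @ feat.T / feat.shape[1]).reshape(-1)
    b = np.sign(b) * np.sqrt(np.abs(b))
    return b / (np.linalg.norm(b) + 1e-12)


REPRESENTATIONS: Dict[str, Callable[[Array], Array]] = {"gavp": gavp, "bcnn": bilinear}


def build_model(features: Callable[[Array], Array], representation: str, feat_dim: int,
                num_classes: int, rng: np.random.Generator) -> Callable[[Array], Array]:
    """Compose features -> representation -> linear classifier."""
    rep = REPRESENTATIONS[representation]
    # Infer the representation size from a dummy (feat_dim, 1) feature map
    out_dim = rep(np.zeros((feat_dim, 1))).size
    weight = rng.standard_normal((num_classes, out_dim)) * 0.01

    def model(image: Array) -> Array:
        return weight @ rep(features(image))

    return model
```

For example, `build_model(make_features(3, 8, rng), "bcnn", 8, 10, rng)` returns a callable that maps a (3, H, W) array to 10 class scores; only the `representation` key changes when switching from "gavp" to "bcnn".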

Contributions are welcome. We will keep updating this repository to include more networks and global image representation methods.

Created and modified files

├── main.py
├── imagepreprocess.py
├── functions.py
├── model_init.py
├── src
│   ├── network
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── inception.py
│   │   ├── alexnet.py
│   │   ├── mpncovresnet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── representation
│   │   ├── __init__.py
│   │   ├── MPNCOV.py
│   │   ├── GAvP.py
│   │   ├── BCNN.py
│   │   ├── CBP.py
│   │   └── Custom.py
│   └── torchviz
│       ├── __init__.py
│       └── dot.py
├── trainingFromScratch
│       └── train.sh
└── finetune
        ├── finetune.sh
        └── two_stage_finetune.sh

For more convenient training and finetuning, we
  • implement functions for plotting convergence curves;
  • adopt the network visualization tool pytorchviz for plotting network structures;
  • use shell scripts to manage the process.

Installation and Usage

  1. Install PyTorch (0.4.0 or above)
  2. git clone https://github.com/jiangtaoxie/fast-MPN-COV
  3. pip install -r requirements.txt
  4. Prepare the dataset as follows:
.
├── train
│   ├── class1
│   │   ├── class1_001.jpg
│   │   ├── class1_002.jpg
│   │   └── ...
│   ├── class2
│   ├── class3
│   ├── ...
│   ├── ...
│   └── classN
└── val
    ├── class1
    │   ├── class1_001.jpg
    │   ├── class1_002.jpg
    │   └── ...
    ├── class2
    ├── class3
    ├── ...
    ├── ...
    └── classN
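This layout (one sub-folder per class under train/ and val/) is the convention expected by torchvision's `datasets.ImageFolder`, which the official PyTorch ImageNet example uses. Before launching training, a quick standard-library check such as the hypothetical helper below (not part of the repository) can confirm that both splits contain the same class folders and count the images. Class indices are assigned in sorted folder-name order, mirroring ImageFolder; the accepted image extensions are an assumption.

```python
from pathlib import Path
from typing import Dict, List, Tuple

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}  # assumed set of image extensions


def scan_split(split_dir: Path) -> Tuple[Dict[str, int], List[Tuple[Path, int]]]:
    """Map sorted class-folder names to indices and list (image path, class index) pairs."""
    classes = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        for path in sorted((split_dir / name).rglob("*")):
            if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
                samples.append((path, class_to_idx[name]))
    return class_to_idx, samples


def check_dataset(root: Path) -> Dict[str, int]:
    """Verify that train/ and val/ share the same classes; return image counts per split."""
    train_classes, train_samples = scan_split(root / "train")
    val_classes, val_samples = scan_split(root / "val")
    if train_classes != val_classes:
        raise ValueError("train/ and val/ contain different class folders")
    return {"train": len(train_samples), "val": len(val_samples)}
```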

For training from scratch:

  1. cp trainingFromScratch/train.sh ./
  2. modify the dataset path in train.sh
  3. sh train.sh

For finetuning our fast MPN-COV model:

  1. cp finetune/finetune.sh ./
  2. modify the dataset path in finetune.sh
  3. sh finetune.sh

For finetuning a VGG model with B-CNN:

  1. cp finetune/two_stage_finetune.sh ./
  2. modify the dataset path in two_stage_finetune.sh
  3. sh two_stage_finetune.sh

Other Implementations

  1. MatConvNet Implementation
  2. TensorFlow Implementation (coming soon)

Contact

If you have any questions or suggestions, please contact me

jiangtaoxie@mail.dlut.edu.cn

Languages

  • Python 88.3%
  • Shell 11.7%