
AILab-CVC/UniRepLKNet


UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition

¹Tencent AI Lab  ²The Chinese University of Hong Kong
* Equal Contribution 

arXiv · Hugging Face Models · Website · License: Apache 2.0

Star and watch me if you are interested in this project :)

Code and checkpoints have been thoroughly tested. Please raise an issue if you encounter a bug. And do not hesitate to try our efficient PyTorch implementation of large-kernel convolution (see here)!


Motivation

  • We note that most existing large-kernel ConvNets simply follow the architectural designs of other models. The architectural design of large-kernel ConvNets remains under-explored.
  • Transformers have demonstrated universal perception ability across multimodal research areas (image, audio, video, time-series, etc.). We are curious whether ConvNets can also deliver universal perception across multiple modalities with a unified architecture.

Highlights

A ConvNet unifies multiple modalities and outperforms modality-specific models. This paper summarizes architectural guidelines for building large-kernel CNNs, which work amazingly well with images and other modalities. This is the latest contribution to two influential areas: Structural Re-param (since RepVGG, Ding et al. 2021) and very-large-kernel ConvNets (since RepLKNet, Ding et al. 2022). UniRepLKNet reaches an ImageNet accuracy of 88.0%, a COCO AP of 56.4, and an ADE20K mIoU of 55.6 with only ImageNet-22K pretraining, with higher actual speed and performance than recent models like ConvNeXt v2 and InternImage. With a unified architecture and extremely simple modality-specific preprocessing, it achieves state-of-the-art performance on audio recognition and, most amazingly, on Global Temperature & Wind Speed Forecasting (a challenging, huge-scale time-series forecasting task), outperforming the existing global forecasting system.

More specifically, we contribute from two aspects:

  • We propose four architectural guidelines for designing large-kernel ConvNets, the core of which is to exploit the essential characteristic that distinguishes large kernels from small kernels: they can see wide without going deep. Following these guidelines, our proposed large-kernel ConvNet shows leading performance in image recognition.
  • We discover that large kernels are the key to unlocking the exceptional performance of ConvNets in domains where they were originally not proficient. With certain modality-related preprocessing approaches, the proposed model achieves state-of-the-art performance on time-series forecasting and audio recognition tasks even without modality-specific customization to the architecture.
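The phrase "see wide without going deep" can be made concrete with receptive-field arithmetic. The sketch below is an illustration under simple assumptions (stride-1 convolutions, theoretical receptive field along one axis, illustrative kernel sizes); it is not code from this repository. Each layer with kernel size k and dilation d grows the receptive field by (k - 1) * d pixels, so a single 13x13 layer covers as much as six stacked 3x3 layers.

```python
def receptive_field(kernel_sizes, dilations=None):
    """Theoretical receptive field (pixels along one axis) of stacked stride-1 convs."""
    if dilations is None:
        dilations = [1] * len(kernel_sizes)
    rf = 1
    for k, d in zip(kernel_sizes, dilations):
        rf += (k - 1) * d  # each layer extends the field by (k - 1) * d pixels
    return rf


def layers_needed(target_rf, kernel_size):
    """Smallest n such that n stacked k x k stride-1 convs satisfy 1 + n * (k - 1) >= target_rf."""
    return -(-(target_rf - 1) // (kernel_size - 1))  # ceiling division


print(receptive_field([13]))       # 13: a single 13x13 layer
print(receptive_field([3] * 6))    # 13: six stacked 3x3 layers
print(layers_needed(31, 3))        # 15 stacked 3x3 layers to match one 31x31 kernel
```

A small-kernel network must add depth to reach the same spatial context, whereas a large kernel reaches it in one layer and leaves depth free for building more abstract features.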

UniRepLKNet not only signifies a "comeback" for ConvNet in its original domain but also showcases large-kernel ConvNet’s potential to "conquer" new territories, highlighting further adaptability and broad utility across different modalities and tasks.

TODOs

  • Model code
  • Most of the ImageNet-1K and ImageNet-22K pretrained weights
  • Weights released on both Google Drive (see this page) and Hugging Face (see unireplknet.py)
  • PyTorch efficient large-kernel conv implementation
  • ImageNet training code
  • Code and documents of audio, video, point cloud, and time-series tasks
  • Semantic segmentation code, document, and all the checkpoints
  • Object detection code, document, and all the checkpoints
  • Checkpoints of audio, video, point cloud, and time-series tasks

The ImageNet, COCO, and ADE20K checkpoints have been released (see the Hugging Face repo shown below), except for the ImageNet-22K pretrained UniRepLKNet-S and UperNet with UniRepLKNet-XL, which were lost; we are reproducing them.

Latest news: we fixed a bug introduced by this commit on Dec 1st, 2023. If you used unireplknet.py after Dec 1st, 2023, please check your code.

Code design

  1. There is some MMDetection- and MMSegmentation-related code in unireplknet.py so that you can directly copy-paste it into your MMDetection or MMSegmentation, e.g., here and here. If you do not want to use it with MMDetection or MMSegmentation, you can safely delete those lines of code.
  2. We have provided code to automatically build our models and load our released weights. See the functions here. You can also use timm.create_model to build the models. For example, model = timm.create_model('unireplknet_l', num_classes=num_classes_of_your_task, in_22k_pretrained=True) will call the function unireplknet_l defined here, which will build a UniRepLKNet-L and automatically download our checkpoints and load the weights.
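    # The simplest way to use our model in your project is to copy-paste unireplknet.py into your working directory and create models. For example:
    import timm
    from unireplknet import *   # defines the UniRepLKNet model functions so timm.create_model can find them
    # num_classes_of_your_task: the number of classes in your task
    model = timm.create_model('unireplknet_l', num_classes=num_classes_of_your_task, in_22k_pretrained=True)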
    
  3. As UniRepLKNet also uses the Structural Re-parameterization methodology, we provide a function reparameterize_unireplknet() that converts a trained UniRepLKNet into the inference structure, equivalently removing the parallel branches in Dilated Reparam Blocks, the Batch Norm layers, and the bias term in GRN. The pseudo-code of the full pipeline is as follows:
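    import torch
    from unireplknet import unireplknet_l
    # train() and evaluate() stand for your own training and evaluation routines
    training_model = unireplknet_l(deploy=False)      # plus your other arguments, e.g., num_classes
    train(training_model)
    training_model.eval()
    trained_results = evaluate(training_model)
    training_model.reparameterize_unireplknet()
    inference_results = evaluate(training_model)
    # inference_results will match trained_results (up to tiny floating-point differences)
    torch.save(training_model.state_dict(), 'converted_weights.pth')
    # use the converted model
    deploy_model = unireplknet_l(deploy=True)         # same arguments as above
    deploy_model.load_state_dict(torch.load('converted_weights.pth', map_location='cpu'))
    deploy_model.eval()
    deploy_results = evaluate(deploy_model)
    # deploy_results will match inference_results and trained_results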
  4. You may want to read this if you are familiar with the timm library. We sincerely thank timm for providing a convenient re-parameterize function. The code design of UniRepLKNet is compatible with it. That is, calling some_unireplknet_model.reparameterize_unireplknet() is equivalent to calling timm.utils.reparameterize_model(some_unireplknet_model). So if you use our code with timm's codebase, e.g., timm's evaluation code, just add --reparam to your command so that timm.utils.reparameterize_model will be called (see here).
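The conversion performed by reparameterize_unireplknet() rests on the linearity of convolution. Below is a minimal numpy sketch of the idea, not the repository's code: it handles a single channel of a depthwise conv and assumes each branch is a bias-free conv followed by inference-mode BN. The 13x13 large kernel and the two dilated branches (5x5 with dilation 2, 3x3 with dilation 3) are illustrative choices, and the GRN bias handling is omitted. Each BN is folded into its conv, each dilated kernel is rewritten as an equivalent sparse dense kernel, zero-padded to the large size, and everything is summed into one kernel and one bias.

```python
import numpy as np


def conv2d_same(x, w, dilation=1):
    """Single-channel 2-D cross-correlation (what nn.Conv2d computes), stride 1,
    zero 'same' padding; assumes an odd kernel size."""
    k = w.shape[0]
    pad = (k - 1) * dilation // 2
    xp = np.pad(x, pad)
    out = np.zeros(x.shape)
    for i in range(k):
        for j in range(k):
            out += w[i, j] * xp[i * dilation:i * dilation + x.shape[0],
                                j * dilation:j * dilation + x.shape[1]]
    return out


def bn_forward(y, bn, eps=1e-5):
    """Inference-mode BatchNorm for one channel; bn = (gamma, beta, mean, var)."""
    gamma, beta, mean, var = bn
    return gamma * (y - mean) / np.sqrt(var + eps) + beta


def fuse_bn(w, bn, eps=1e-5):
    """Fold BN into a bias-free conv: BN(conv(x, w)) == conv(x, w') + b'."""
    gamma, beta, mean, var = bn
    scale = gamma / np.sqrt(var + eps)
    return w * scale, beta - mean * scale


def dilated_to_dense(w, dilation):
    """A k x k kernel with dilation r equals a dense ((k-1)*r+1)-sized kernel
    holding the original weights every r positions and zeros elsewhere."""
    k = w.shape[0]
    size = (k - 1) * dilation + 1
    dense = np.zeros((size, size))
    dense[::dilation, ::dilation] = w
    return dense


def pad_to(w, size):
    """Zero-pad a kernel symmetrically so it sits centered in a size x size kernel."""
    return np.pad(w, (size - w.shape[0]) // 2)


def merge_branches(large_w, large_bn, branches):
    """branches: list of (w, dilation, bn). Returns one dense kernel and one bias."""
    W, B = fuse_bn(large_w, large_bn)
    for w, d, bn in branches:
        wf, bf = fuse_bn(w, bn)
        W = W + pad_to(dilated_to_dense(wf, d), large_w.shape[0])
        B = B + bf
    return W, B


def training_forward(x, large_w, large_bn, branches):
    """Training-time structure: large conv + BN plus parallel dilated conv + BN branches."""
    out = bn_forward(conv2d_same(x, large_w), large_bn)
    for w, d, bn in branches:
        out = out + bn_forward(conv2d_same(x, w, d), bn)
    return out


rng = np.random.default_rng(0)


def random_bn():
    return (rng.uniform(0.5, 1.5), rng.normal(), rng.normal(), rng.uniform(0.5, 1.5))


x = rng.normal(size=(16, 16))
large_w, large_bn = rng.normal(size=(13, 13)), random_bn()
# Illustrative branches: 5x5 with dilation 2 (spans 9x9), 3x3 with dilation 3 (spans 7x7)
branches = [(rng.normal(size=(5, 5)), 2, random_bn()),
            (rng.normal(size=(3, 3)), 3, random_bn())]

W, B = merge_branches(large_w, large_bn, branches)
train_out = training_forward(x, large_w, large_bn, branches)
deploy_out = conv2d_same(x, W) + B
print(np.abs(train_out - deploy_out).max())  # only floating-point error remains
```

Because every step is linear, the single merged kernel reproduces the multi-branch output; this is why the converted model's results match the trained model's.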

Models

We have provided five ways to download our checkpoints.

  1. Download via the Google Drive links shown below.
  2. Visit our huggingface repo at https://huggingface.co/DingXiaoH/UniRepLKNet/tree/main and click the download icons.
  3. Use huggingface_hub in your Python code. First, install huggingface_hub:
pip install huggingface_hub

Then, use huggingface_hub like this in your python code, for example,

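import torch
from huggingface_hub import hf_hub_download
repo_id = 'DingXiaoH/UniRepLKNet'
FILE_NAME = 'unireplknet_xl_in22k_pretrain.pth'   # see the note below for other file names
cache_file = hf_hub_download(repo_id=repo_id, filename=FILE_NAME)
checkpoint = torch.load(cache_file, map_location='cpu')
model.load_state_dict(checkpoint)   # model: a UniRepLKNet built with the matching architecture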

See our huggingface repo or our code for FILE_NAME (e.g., unireplknet_xl_in22k_pretrain.pth).

  4. Use the Hugging Face CLI. Check the tutorial.

  5. Automatically download our checkpoints by passing in_1k_pretrained=True, in_22k_pretrained=True, or in_22k_to_1k=True while calling our provided functions. See the code here.

ImageNet-1K Pretrained Weights

name resolution acc@1 #params FLOPs Weights
UniRepLKNet-A 224x224 77.0 4.4M 0.6G ckpt
UniRepLKNet-F 224x224 78.6 6.2M 0.9G ckpt
UniRepLKNet-P 224x224 80.2 10.7M 1.6G ckpt
UniRepLKNet-N 224x224 81.6 18.3M 2.8G ckpt
UniRepLKNet-T 224x224 83.2 31M 4.9G ckpt
UniRepLKNet-S 224x224 83.9 56M 9.1G ckpt

ImageNet-22K Pretrained Weights

name resolution #params FLOPs ckpt
UniRepLKNet-S 224x224 56M 26.7G ckpt
UniRepLKNet-B 224x224 98M 47.2G ckpt
UniRepLKNet-L 192x192 218M 105.4G ckpt
UniRepLKNet-XL 192x192 386M 187G ckpt

Pretrained on ImageNet-22K then finetuned on ImageNet-1K

name resolution acc@1 #params FLOPs ckpt
UniRepLKNet-S 384x384 86.4 56M 26.7G ckpt
UniRepLKNet-B 384x384 87.4 98M 47.2G ckpt
UniRepLKNet-L 384x384 87.9 218M 105.4G ckpt
UniRepLKNet-XL 384x384 88.0 386M 187G ckpt
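As a rough consistency check on the FLOPs columns: for a convolutional backbone, the cost of each layer is proportional to the number of spatial positions, so FLOPs scale approximately with the number of input pixels. The sketch below is an approximation (it ignores the small classifier head after global pooling, which does not scale); scaling UniRepLKNet-S's 9.1G at 224x224 gives about 26.7G at 384x384, the figure listed above.

```python
def scale_flops(gflops, from_res, to_res):
    """Approximate GFLOPs of a conv backbone at a new square input resolution.

    Conv cost is proportional to the number of spatial positions, so it scales
    with (to_res / from_res) ** 2; the classifier head does not scale, so this
    is only an approximation.
    """
    return gflops * (to_res / from_res) ** 2


# UniRepLKNet-S: 9.1 GFLOPs at 224x224 in the ImageNet-1K table
print(round(scale_flops(9.1, 224, 384), 1))  # 26.7
```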

COCO Object Detection

Code, document, and config files have been released. See the detection guide here.

Checkpoints have already been released on Hugging Face. You can download them right now from https://huggingface.co/DingXiaoH/UniRepLKNet/tree/main.

Or you can download these checkpoints from Google Drive as follows:

name resolution box mAP mask mAP #params FLOPs Weights
UniRepLKNet-T 1280x800 51.7 44.9 89M 749G ckpt
UniRepLKNet-S 1280x800 53.0 45.9 113M 835G ckpt
UniRepLKNet-S_22K 1280x800 54.3 47.1 113M 835G ckpt
UniRepLKNet-B_22K 1280x800 54.8 47.4 155M 978G ckpt
UniRepLKNet-L_22K 1280x800 55.8 48.4 276M 1385G ckpt
UniRepLKNet-XL_22K 1280x800 56.4 49.0 443M 1952G ckpt

ADE-20K Semantic Segmentation

Code, document, and config files have been released. See the segmentation guide here.

Checkpoints have already been released on Hugging Face. You can download them right now from https://huggingface.co/DingXiaoH/UniRepLKNet/tree/main.

Or you can download these checkpoints from Google Drive as follows:

name resolution mIoU (ss/ms) #params FLOPs Weights
UniRepLKNet-T 512x512 48.6/49.1 61M 946G ckpt
UniRepLKNet-S 512x512 50.5/51.0 86M 1036G ckpt
UniRepLKNet-S_22K 512x512 51.9/52.7 86M 1036G ckpt
UniRepLKNet-S_22K 640x640 52.3/52.7 86M 1618G ckpt
UniRepLKNet-B_22K 640x640 53.5/53.9 130M 1850G ckpt
UniRepLKNet-L_22K 640x640 54.5/55.0 254M 2507G ckpt
UniRepLKNet-XL_22K 640x640 55.2/55.6 425M 3420G ckpt

ImageNet evaluation and training

We give an example evaluation command.

Single-GPU

python main.py --model unireplknet_b --eval true \
--resume unireplknet_b_in22k_to_in1k_384_acc87.40.pth  \
--input_size 384 \
--data_path /path/to/imagenet-1k

Multi-GPU

python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model unireplknet_b --eval true \
--resume unireplknet_b_in22k_to_in1k_384_acc87.40.pth  \
--input_size 384 \
--data_path /path/to/imagenet-1k
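The acc@1 column in the ImageNet tables is top-1 accuracy: the percentage of validation images whose highest-scoring class is the correct label. A minimal numpy sketch of the metric (top-k in general) on toy logits, for illustration only and not the code in main.py:

```python
import numpy as np


def topk_accuracy(logits, labels, k=1):
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    topk = np.argsort(-logits, axis=1)[:, :k]      # indices of the k largest logits per row
    hits = (topk == labels[:, None]).any(axis=1)   # True where the label is among them
    return 100.0 * hits.mean()


logits = np.array([[2.0, 1.0, 0.1],
                   [0.2, 0.3, 3.0],
                   [1.5, 2.5, 0.0],
                   [0.9, 0.1, 0.8]])
labels = np.array([0, 2, 0, 2])
print(topk_accuracy(logits, labels, k=1))  # 50.0
print(topk_accuracy(logits, labels, k=2))  # 100.0
```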

For training or finetuning UniRepLKNets on ImageNet-1K or 22K, see this guide.

Universal perception of audio, video, point cloud, and time-series tasks

For detailed documentation, please refer to these documents as follows:

Use an efficient large-kernel convolution with PyTorch

We use a large-kernel conv implementation in PyTorch that is more efficient than the native torch.nn.Conv2d. It is implemented based on the iGEMM (implicit GEMM) algorithm and a lightweight tool named cutlass. The installation is very simple and takes less than a minute. If you do not install this implementation, you can still use our model anywhere you wish, but it will be a bit slower.
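GEMM-based convolution treats a convolution as a matrix multiplication. As a minimal numpy sketch of that idea (an illustration, not the CUDA extension), the code below lowers a depthwise convolution to one GEMM per channel via explicit im2col and checks it against direct computation. As we understand it, an implicit-GEMM kernel performs the same product but forms the im2col entries on the fly inside the GPU kernel instead of materializing the (k*k, H*W) matrix, which for a 13x13 kernel would hold 169 times as many elements as the input. Shapes here are illustrative.

```python
import numpy as np


def depthwise_conv_direct(x, w):
    """Depthwise 2-D cross-correlation, stride 1, zero 'same' padding.
    x: (C, H, W) input, w: (C, k, k) kernels with odd k."""
    C, H, W = x.shape
    k = w.shape[-1]
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    out = np.zeros((C, H, W))
    for i in range(k):
        for j in range(k):
            out += w[:, i, j][:, None, None] * xp[:, i:i + H, j:j + W]
    return out


def depthwise_conv_gemm(x, w):
    """The same convolution lowered to one matrix product per channel (explicit im2col)."""
    C, H, W = x.shape
    k = w.shape[-1]
    p = k // 2
    xp = np.pad(x, ((0, 0), (p, p), (p, p)))
    # cols[c, i * k + j, r * W + s] = xp[c, r + i, s + j]; shape (C, k*k, H*W)
    cols = np.stack([xp[:, i:i + H, j:j + W].reshape(C, H * W)
                     for i in range(k) for j in range(k)], axis=1)
    # Batched GEMM: (C, 1, k*k) @ (C, k*k, H*W) -> (C, 1, H*W)
    out = np.matmul(w.reshape(C, 1, k * k), cols)
    return out.reshape(C, H, W)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 20, 20))
w = rng.normal(size=(4, 13, 13))
print(np.allclose(depthwise_conv_direct(x, w), depthwise_conv_gemm(x, w)))  # True
```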

  1. Download cutlass.zip, unzip it, and enter the directory. The version of cutlass provided in this repository works fine with our large-kernel implementation and multiple Python versions. You may alternatively use the cutlass branch maintained by the MegEngine team (clone https://github.com/MegEngine/cutlass), but you may need to be more careful with your Python version (see this issue).
  2. cd examples/19_large_depthwise_conv2d_torch_extension
  3. ./setup.py install --user (or python setup.py install --user). If you get errors, check your CUDA_HOME.
  4. You may do a quick check to verify that the results of forward/backward computations are the same as torch.nn.Conv2d: python depthwise_conv2d_implicit_gemm.py
  5. Add PATH_TO_CUTLASS_DIRECTORY/examples/19_large_depthwise_conv2d_torch_extension to your PYTHONPATH so that you can from depthwise_conv2d_implicit_gemm import DepthWiseConv2dImplicitGEMM anywhere. Then you may use DepthWiseConv2dImplicitGEMM as a replacement for nn.Conv2d.

It should work with a wide range of GPUs and PyTorch/CUDA versions. We suggest you try it first and check your environment only if you get errors. Our latest tests used both of the following environments:

  1. Ubuntu 18.04 + CUDA 11.3 + nvcc 11.3 + cudnn 8.2.0 + python 3.8.12 + pytorch 1.10 + gcc 7.3.0 + nccl 2.10.3 + NVIDIA driver 450.102.04 + V100 and A100 GPUs
  2. Ubuntu 18.04 + CUDA 10.2 + nvcc 10.0 + cudnn 7.6.5 + python 3.6.9 + pytorch 1.9 + gcc 7.5.0 + nccl 2.7.8 + NVIDIA driver 460.32.03 + 2080Ti and V100 GPUs

It is reported (see here) that a Python version mismatch may result in an error (forward_fp32.cu(212): error: more than one instance of constructor "cutlass::Tensor4DCoord::Tensor4DCoord" ... or cutlass/include/cutlass/fast_math.h(741): error: no suitable conversion function from "__half" to "float" exists). If this happens, please upgrade or downgrade your Python. We sincerely thank @sleeplessai and @ewrfcas for sharing their experience.

Pull requests (e.g., better or other implementations, or implementations on other frameworks) are welcome.

Citation

If the code and paper help your research, please kindly cite:

@article{ding2023unireplknet,
  title={UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition},
  author={Ding, Xiaohan and Zhang, Yiyuan and Ge, Yixiao and Zhao, Sijie and Song, Lin and Yue, Xiangyu and Shan, Ying},
  journal={arXiv preprint arXiv:2311.15599},
  year={2023}
}

License

This project is released under the Apache 2.0 license. Please see the LICENSE file for more information.

About

[CVPR'24] UniRepLKNet: A Universal Perception Large-Kernel ConvNet for Audio, Video, Point Cloud, Time-Series and Image Recognition
