This repo contains the code for our paper "NTRENet++: Unleashing the Power of Non-target Knowledge for Few-shot Semantic Segmentation" by Nian Liu, Yuanwei Liu, Yi Wu, Hisham Cholakkal, Rao Muhammad Anwer, Xiwen Yao, and Junwei Han.
Figure: The framework of NTRENet++.

Figure: The framework of clip-NTRENet++.

Abstract: Few-shot semantic segmentation (FSS) aims to segment the target object given only a few annotated samples. However, current studies on FSS primarily concentrate on extracting information related to the object, resulting in inadequate identification of ambiguous regions, particularly in non-target areas, including the background (BG) and Distracting Objects (DOs). Intuitively, to alleviate this problem, we propose a novel framework, namely NTRENet++, to explicitly mine and eliminate BG and DO regions in the query. First, we introduce a BG Mining Module (BGMM) to extract BG information and generate a comprehensive BG prototype from all images. For this purpose, a BG mining loss is formulated to supervise the learning of BGMM, utilizing only the known target object segmentation ground truth. Subsequently, based on this BG prototype, we employ a BG Eliminating Module to filter out the BG information from the query and obtain a BG-free result. Following this, the target information is utilized in the target matching module to generate the initial segmentation result. Finally, a DO Eliminating Module is proposed to further mine and eliminate DO regions, based on which we can obtain a BG- and DO-free target object segmentation result. Moreover, we present a prototypical-pixel contrastive learning algorithm to enhance the model's capability to differentiate the target object from DOs. Extensive experiments conducted on both PASCAL-5i and COCO-20i datasets demonstrate the effectiveness of our approach despite its simplicity. Additionally, we extend our approach to the few-shot video segmentation task and achieve state-of-the-art performance on the YouTube-VIS dataset, demonstrating its generalization ability.
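To make the BG-mining idea above concrete, here is a minimal PyTorch sketch of the two core steps: pooling a background prototype from masked features, and using that prototype to suppress BG responses in the query. This is an illustrative assumption about the mechanism, not the official NTRENet++ code; the function names, tensor shapes, and the simple cosine-similarity filtering are ours.

```python
# Minimal sketch of BG-prototype extraction and BG elimination.
# NOT the official NTRENet++ implementation; shapes and names are assumptions.
import torch
import torch.nn.functional as F

def masked_average_pooling(feat, mask):
    """Pool features over a (soft) mask to obtain a prototype.

    feat: (B, C, H, W) feature maps; mask: (B, 1, H, W) with values in [0, 1].
    Returns a (B, C) prototype vector.
    """
    mask = F.interpolate(mask, size=feat.shape[-2:], mode="bilinear", align_corners=False)
    return (feat * mask).sum(dim=(2, 3)) / (mask.sum(dim=(2, 3)) + 1e-6)

def eliminate_background(query_feat, bg_prototype):
    """Down-weight query locations similar to the BG prototype (cosine similarity)."""
    proto = bg_prototype.unsqueeze(-1).unsqueeze(-1).expand_as(query_feat)
    sim = F.cosine_similarity(query_feat, proto, dim=1)   # (B, H, W)
    bg_score = sim.clamp(min=0).unsqueeze(1)              # positive similarity read as BG evidence
    return query_feat * (1.0 - bg_score)                  # BG-suppressed ("BG-free") query features

# Toy usage: one query image, 256-channel features on a 32x32 grid.
query_feat = torch.randn(1, 256, 32, 32)
bg_mask = torch.rand(1, 1, 32, 32)                        # stand-in for a predicted BG region
bg_prototype = masked_average_pooling(query_feat, bg_mask)
filtered = eliminate_background(query_feat, bg_prototype)
print(filtered.shape)  # torch.Size([1, 256, 32, 32])
```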
- Python 3.8
- PyTorch 1.7.0
- cuda 11.0
- torchvision 0.8.1
- tensorboardX 2.14
Please download the following datasets:
- PASCAL-5i: based on PASCAL VOC 2012 and SBD, where the val images should be excluded from the list of training samples.
This code reads data from .txt files in which each line contains the path of an image and the path of its corresponding label, separated by a space, as in the following example:
```
image_path_1 label_path_1
image_path_2 label_path_2
image_path_3 label_path_3
...
image_path_n label_path_n
```
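If you need to generate such list files yourself, a minimal sketch is below. The VOC-style directory layout (`JPEGImages` / `SegmentationClassAug`) and the output path are assumptions; adapt them to wherever you placed the data.

```python
# Minimal sketch for producing the 'image_path label_path' list files above.
# Directory names are assumptions borrowed from the usual PASCAL VOC layout.
import os

def write_list_file(image_dir, label_dir, out_txt):
    """Write one 'image_path label_path' pair per line."""
    with open(out_txt, "w") as f:
        for name in sorted(os.listdir(image_dir)):
            stem, _ = os.path.splitext(name)
            label = os.path.join(label_dir, stem + ".png")
            if os.path.exists(label):  # keep only images that have a label
                f.write(f"{os.path.join(image_dir, name)} {label}\n")

write_list_file("VOCdevkit/VOC2012/JPEGImages",
                "VOCdevkit/VOC2012/SegmentationClassAug",
                "lists/pascal/train.txt")
```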
Then update the train/val/test list paths in the config files.
- Update the config file by specifying the target split and the path (`weights`) for loading the checkpoint (see the sketch after this list).
- Execute `mkdir initmodel` at the root directory.
- Download the ImageNet pretrained backbones and put them into the `initmodel` directory.
- Execute this command at the root directory: `python train.py`
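As referenced in the first step above, a checkpoint path is supplied through the config's `weights` field. Below is a minimal sketch of how such a checkpoint could be loaded; the YAML filename and the `state_dict` key are hypothetical, since the repo's exact config layout may differ.

```python
# Minimal sketch: load a checkpoint whose path is given by the config's
# `weights` field. The config filename and "state_dict" key are assumptions.
import torch
import yaml

with open("config/pascal/pascal_split0_resnet50.yaml") as f:  # hypothetical config path
    cfg = yaml.safe_load(f)

checkpoint = torch.load(cfg["weights"], map_location="cpu")
state_dict = checkpoint.get("state_dict", checkpoint)  # tolerate either layout
# model.load_state_dict(state_dict)  # `model` would be the built NTRENet++ network
```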
Performance comparison with the state-of-the-art approach PFENet in terms of average mIoU across all folds.
| Backbone | Method | 1-shot | 5-shot |
|---|---|---|---|
| ResNet50 | PFENet | 60.8 | 61.9 |
| ResNet50 | NTRENet++ (ours) | 65.3 (+4.5) | 66.4 (+4.5) |
| ResNet101 | PFENet | 60.1 | 61.4 |
| ResNet101 | NTRENet++ (ours) | 64.8 (+4.7) | 69.0 (+7.6) |
This repo is mainly built upon PFENet, RePRI, and SemSeg. Thanks for their great work!
If you find our work and this repository useful, please consider giving it a star ⭐ and a citation 📚.