A PyTorch implementation of our paper "Finding Task-Relevant Features for Few-Shot Learning by Category Traversal", published at CVPR 2019 as an oral presentation.
By Hongyang Li, David Eigen, Samuel Dodge, Matthew Zeiler, and Xiaogang Wang.
[End-of-internship Presentation Slides] (40 mins)
[Short Slides] (5 mins at CVPR)
(a) depicts conventional metric-based methods; (b) depicts the proposed CTM, in which features are traversed across categories to obtain better representations.
The following figure shows a detailed configuration of our proposed CTM module.
- PyTorch 0.4 or above, tested on Linux with cluster, multi-GPU, and single-GPU setups.
- Datasets: tieredImagenet and miniImagenet
- A metric-based few-shot learning algorithm
- The proposed Category Traversal Module (CTM) serves as a plug-and-play unit for most existing methods, improving accuracy by roughly 2%.
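To make "metric-based" concrete, here is a minimal NumPy sketch of one such method: a nearest-prototype classifier in the style of Prototypical Networks, where each class prototype is the mean of its support features and queries are assigned to the closest prototype under squared Euclidean distance. It is an illustration, not code from this repo; `prototype_classify` is a hypothetical name. Under the plug-and-play view, the (flattened) features passed in could be either raw backbone features or CTM-improved ones.

```python
import numpy as np


def prototype_classify(support, support_labels, query, n_way):
    """Hypothetical nearest-prototype classifier (illustrative only).

    support:        (N*K, D) flattened support features
    support_labels: (N*K,) integer labels in [0, n_way)
    query:          (Q, D) flattened query features
    Returns predicted labels (Q,) and logits (Q, N).
    """
    # One prototype per class: the mean of that class's support features.
    protos = np.stack(
        [support[support_labels == c].mean(axis=0) for c in range(n_way)]
    )  # (N, D)
    # Squared Euclidean distance from every query to every prototype.
    dists = ((query[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)  # (Q, N)
    logits = -dists  # closer prototype -> larger logit
    return logits.argmax(axis=1), logits


if __name__ == "__main__":
    # Toy 2-way 2-shot episode with two well-separated clusters.
    support = np.array([[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.2, 5.0]])
    labels = np.array([0, 0, 1, 1])
    query = np.array([[0.1, 0.1], [4.9, 5.1]])
    pred, _ = prototype_classify(support, labels, query, n_way=2)
    print(pred)  # [0 1]
```

Negative distances are used as logits so that a softmax over them would favor the nearest prototype; other metric-based methods swap in a different (possibly learned) distance.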
There are some dependencies; be sure to install recent versions so they are compatible with the latest PyTorch. For example:
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
conda install -c anaconda pillow pyyaml opencv requests
conda install -c conda-forge visdom
Prepare the dataset (miniImageNet, for example) and run the demo:
sh dataset/get_tier_and_mini.sh
python main.py --yaml_file configs/demo/mini/20way_1shot.yaml
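The demo config name (`20way_1shot`) refers to the usual N-way K-shot episodic setup. The standard-library sketch below shows how one such episode could be drawn: pick N classes, then K support and some query images per class, relabelling classes 0..N-1 within the episode. `sample_episode`, its arguments, and the choice of 15 queries per class are illustrative assumptions, not the keys used in the repo's YAML files or its data loader.

```python
import random


def sample_episode(images_by_class, n_way, k_shot, n_query, seed=None):
    """Hypothetical N-way K-shot episode sampler (illustrative only).

    images_by_class: dict mapping class name -> list of image ids
    Returns (support, query) as lists of (image_id, episode_label) pairs,
    where episode labels are re-assigned to 0..n_way-1 for every episode.
    """
    rng = random.Random(seed)
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw support and query images without overlap within a class.
        picks = rng.sample(images_by_class[cls], k_shot + n_query)
        support += [(img, label) for img in picks[:k_shot]]
        query += [(img, label) for img in picks[k_shot:]]
    return support, query


if __name__ == "__main__":
    # Toy pool: 64 classes with 20 image ids each.
    toy = {f"class_{c}": [f"img_{c}_{i}" for i in range(20)] for c in range(64)}
    support, query = sample_episode(toy, n_way=20, k_shot=1, n_query=15, seed=0)
    print(len(support), len(query))  # 20 300
```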
We conduct all experiments on the tieredImagenet and miniImagenet benchmarks; to download them, please refer to DATASET.md.
Please refer to the forward_CTM method in core/model.py for details.
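For a rough picture of the data flow before reading forward_CTM, below is a simplified, framework-agnostic NumPy sketch of a CTM-style forward pass with a concentrator, a projector, and a reshaper. It is written for illustration and is not taken from core/model.py. Assumptions, all hypothetical: every sub-network is a single 1x1 convolution with random, untrained weights (so spatial size never changes); the concentrator averages the K shots of each class; the projector stacks the N class features along channels and turns them into a mask with a softmax over the channel dimension; the mask multiplies the reshaped support and query features. `ctm_sketch`, `m2`, and `m3` are illustrative names.

```python
import numpy as np


def conv1x1(x, weight):
    """1x1 convolution as channel mixing: (B, C_in, H, W) -> (B, C_out, H, W)."""
    return np.einsum("oc,bchw->bohw", weight, x)


def relu(x):
    return np.maximum(x, 0.0)


def channel_softmax(x):
    """Numerically stabilized softmax over the channel axis (axis=1)."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def ctm_sketch(support, query, n_way, k_shot, m2=8, m3=16, seed=0):
    """Hypothetical CTM-style forward pass with random 1x1 conv weights.

    support: (n_way * k_shot, m1, d, d), ordered class by class
    query:   (n_query, m1, d, d)
    Returns improved support features, improved query features, and the mask.
    """
    rng = np.random.default_rng(seed)
    m1 = support.shape[1]
    w_c = rng.standard_normal((m2, m1))          # concentrator (assumed 1x1 conv)
    w_p = rng.standard_normal((m3, n_way * m2))  # projector (assumed 1x1 conv)
    w_r = rng.standard_normal((m3, m1))          # reshaper (assumed 1x1 conv)

    # Concentrator: per-sample features, then average the K shots of each class.
    o = relu(conv1x1(support, w_c))                               # (N*K, m2, d, d)
    o = o.reshape(n_way, k_shot, m2, *o.shape[2:]).mean(axis=1)   # (N, m2, d, d)

    # Projector: traverse categories by stacking all N classes along channels.
    stacked = o.reshape(1, n_way * m2, *o.shape[2:])              # (1, N*m2, d, d)
    mask = channel_softmax(conv1x1(stacked, w_p))                 # (1, m3, d, d)

    # Reshaper: map support/query features to the mask's shape, then apply it.
    improved_s = relu(conv1x1(support, w_r)) * mask               # (N*K, m3, d, d)
    improved_q = relu(conv1x1(query, w_r)) * mask                 # (n_query, m3, d, d)
    return improved_s, improved_q, mask


if __name__ == "__main__":
    feats = np.random.default_rng(1)
    s = feats.standard_normal((20 * 1, 64, 5, 5))  # 20-way 1-shot support features
    q = feats.standard_normal((30, 64, 5, 5))      # 30 query features
    s_new, q_new, p = ctm_sketch(s, q, n_way=20, k_shot=1)
    print(s_new.shape, q_new.shape, p.shape)  # (20, 16, 5, 5) (30, 16, 5, 5) (1, 16, 5, 5)
```

Because one mask is computed from the whole support set and shared by every support and query feature, the improved features depend on all categories in the episode; the outputs can then be flattened and fed to any metric-based classifier.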
The current version contains some legacy variable names from early trial experiments; we will remove them later to make the repo cleaner.
If you find this work useful in your research, please cite it as follows:
@inproceedings{li2019ctm,
title = {{Finding Task-Relevant Features for Few-Shot Learning by Category Traversal}},
author = {Hongyang Li and David Eigen and Samuel Dodge and Matthew Zeiler and Xiaogang Wang},
booktitle = {CVPR},
year = {2019}
}