PUTM

Official PyTorch implementation of PUTM: Prototypes-oriented Transductive Few-shot Learning with Conditional Transport (ICCV 2023).

Abstract:

Transductive Few-Shot Learning (TFSL) has recently attracted increasing attention since it typically outperforms its inductive peers by leveraging the statistics of query samples. However, previous TFSL methods usually encode a uniform prior assuming that all classes within the query samples are equally likely, which is biased in imbalanced TFSL and causes severe performance degradation. Given this pivotal issue, we propose a novel Conditional Transport (CT) based imbalanced TFSL model, the Prototypes-oriented Unbiased Transfer Model (PUTM), to fully exploit the unbiased statistics of imbalanced query samples. PUTM employs forward and backward navigators as transport matrices to balance the per-class prior of query samples between uniform and adaptive data-driven distributions. To efficiently transfer the statistics learned by CT, we further derive a closed-form solution that refines the prototypes via maximum a posteriori (MAP) estimation given the learned navigators. These two steps of discovering and transferring unbiased statistics alternate iteratively, forming our EM-based solver. Experimental results on four standard benchmarks, including miniImageNet, tieredImageNet, CUB, and CIFAR-FS, demonstrate the superiority of our model in class-imbalanced generalization.
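To make the transport-matrix idea concrete, below is a minimal NumPy sketch of the conditional-transport building block described above: a forward plan that distributes each query over the prototypes, and a backward plan that distributes each prototype over the queries. This is an illustrative simplification, not the paper's exact navigator parameterization; the `tau` temperature and the plain squared-distance similarity are assumptions for the sketch.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def transport_plans(queries, prototypes, tau=1.0):
    """Forward/backward transport plans from negative squared distances.

    queries: (n, d) query features; prototypes: (k, d) class prototypes.
    The forward plan normalizes each row (one per query) over prototypes;
    the backward plan normalizes each column (one per prototype) over
    queries. tau is an assumed temperature, not a value from the paper.
    """
    d2 = ((queries[:, None, :] - prototypes[None, :, :]) ** 2).sum(-1)
    sim = -d2 / tau
    forward = softmax(sim, axis=1)   # each query -> distribution over prototypes
    backward = softmax(sim, axis=0)  # each prototype -> distribution over queries
    return forward, backward
```

In PUTM the two plans are additionally shaped by learned navigators and feed a closed-form MAP update of the prototypes; the sketch only shows the row- versus column-normalization that makes the prior adaptive rather than uniform.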


Download Features

For the WRN-28-10 (S2M2) backbone, we use the pre-trained features provided by S2M2_fewshot, which can be downloaded here: checkpoints.

For the WRN-28-10 (softmax) and ResNet18 backbones, we use the pre-trained features provided by TIM, which can be downloaded here: checkpoints.

If the above links are unavailable, you can also download the features here: features.

Code Structure

├── features
│   ├── resnet18
│   │   ├── cub
│   │   └── mini
│   ├── wideres
│   │   ├── mini
│   │   └── tiered
│   └── wrn_s2m2
│       ├── cifar
│       ├── cub
│       ├── mini
│       └── tiered
├── PUTM
│   ├── cache
│   ├── config
│   │   ├── balanced
│   │   └── dirichlet
│   ├── datasets
│   ├── methods
│   ├── models
│   ├── split
│   │   ├── cub
│   │   ├── mini
│   │   └── tiered
│   └── train_mean # train_mean files for some methods, not necessary

Evaluation

First, modify the configuration file "config/base_config.yaml" to select the evaluation setting (dataset, balanced/imbalanced split, backbone, etc.).

cd PUTM
mkdir cache

python eval.py --base_config config/base_config.yaml --method_config config/[balanced, dirichlet]/methods_config/[method_name].yaml

For example, to evaluate PUTM in the imbalanced setting, use the following config file,

MODEL:
  arch: 'wrn_s2m2' # ('resnet18', 'wideres', 'wrn_s2m2')

DATA:
  dataset: 'mini' # ('mini', 'cub', 'cifar', 'tiered')
  batch_size_loader: 256
  enlarge: True
  num_workers: 4
  disable_random_resize: False
  jitter: False
  path: 'data'

EVAL:
  evaluate: True      # Set to True to evaluate methods
  number_tasks: 1000 # Number of tasks to evaluate
  batch_size: 1000
  n_ways: 5
  n_query: 15 # Balanced case: 15 query data per class
  balanced: 'dirichlet' # ('balanced' | 'dirichlet')
  alpha_dirichlet: 2
  model_tag: 'best'
  plt_metrics: ['accs']
  shots: [1, 3, 5]
  used_set: 'test'
  fresh_start: False

and run this command,

python eval.py --base_config config/base_config.yaml --method_config config/dirichlet/methods_config/putm.yaml
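The `balanced: 'dirichlet'` and `alpha_dirichlet: 2` options above control how imbalanced query sets are generated: per-class proportions are drawn from a symmetric Dirichlet distribution and the total query budget (`n_ways * n_query`) is split accordingly. A minimal sketch of this sampling scheme, assuming a multinomial split of the budget (the repo's actual task generator may round differently):

```python
import numpy as np

def sample_query_counts(n_ways=5, n_query=15, alpha=2.0, seed=None):
    """Sample per-class query counts for a Dirichlet-imbalanced task.

    Class proportions are drawn from a symmetric Dirichlet(alpha); the
    total query budget (n_ways * n_query) is then allocated to classes
    by a multinomial draw over those proportions. Smaller alpha gives
    more imbalanced tasks; alpha -> infinity recovers the balanced case.
    """
    rng = np.random.default_rng(seed)
    total = n_ways * n_query
    proportions = rng.dirichlet(alpha * np.ones(n_ways))
    counts = rng.multinomial(total, proportions)
    return counts
```

With the config above (`n_ways: 5`, `n_query: 15`, `alpha_dirichlet: 2`) each task still contains 75 query samples in total, but their distribution over the 5 classes varies from task to task.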
