
Learning Useful Representations for Shifting Tasks and Distributions

Official PyTorch implementation of the ICML 2023 paper "Learning Useful Representations for Shifting Tasks and Distributions".

Jianyu Zhang, Léon Bottou

Requirements

  • python==3.7
  • torch>=1.13.1
  • torchvision>=0.14.1
  • pyyaml==6.0
  • classy-vision==0.6.0

Datasets

We consider the following datasets: ImageNet, Inaturalist18, CIFAR10, and CIFAR100.

Download and extract the ImageNet and Inaturalist18 datasets to data/imagenet and data/inaturalist18 (a quick sanity check follows the folder tree below). The resulting folder structure should be:

📦 RRL
 ┣ 📂data
 ┃ ┣ 📂imagenet
 ┃ ┣ 📂inaturalist18
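
As a quick sanity check that the datasets are in place, the folders can be opened with torchvision's ImageFolder. This is only a minimal sketch: it assumes the usual one-sub-directory-per-class layout and a val/ split under data/imagenet, which may differ from your extraction.

```python
# Minimal sanity check (assumes the usual one-sub-directory-per-class layout
# and a "val" split under data/imagenet; adjust paths to your extraction).
import torchvision.datasets as datasets
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

imagenet_val = datasets.ImageFolder("data/imagenet/val", transform=transform)
print(len(imagenet_val.classes), len(imagenet_val))  # expect 1000 ImageNet classes
```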

Supervised transfer learning (ResNet)

Download the (ImageNet1k) supervised pretrained checkpoints (a hedged loading sketch follows the folder tree below). The resulting folder structure should be:

📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂supervised_pretrain
 ┃ ┃ ┣ 📂resnet50
 ┃ ┃ ┃ ┣📜 checkpoint_run0.pth.tar 
 ┃ ┃ ┃ ┃ ...            
 ┃ ┃ ┃ ┗📜 checkpoint_run9.pth.tar 
 ┃ ┃ ┣📜 2resnet50_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣📜 4resnet50_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣📜 resnet50w2_imagenet1k_supervised.pth.tar
 ┃ ┃ ┣📜 resnet50w4_imagenet1k_supervised.pth.tar
 ┃ ┃ ┗📜 resnet50_imagenet1k_supervised_distill5.pth.tar
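
As a hedged illustration (not the repository's exact evaluation code), a checkpoint such as checkpoint_run0.pth.tar can be loaded into a torchvision ResNet-50 roughly as follows; the state_dict wrapper and the module. prefix are assumptions that may need adjusting to the actual files.

```python
# Hedged sketch: load one supervised ImageNet1k checkpoint into a ResNet-50.
# The "state_dict" wrapper and the "module." prefix are assumptions.
import torch
import torchvision.models as models

ckpt = torch.load(
    "checkpoints/supervised_pretrain/resnet50/checkpoint_run0.pth.tar",
    map_location="cpu",
)
state_dict = ckpt.get("state_dict", ckpt)            # unwrap if needed
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

model = models.resnet50()
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)
```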

Transfer by Linear Probing, Fine-Tuning, and Two-stage Fine-Tuning:

Transfer the learned representations (pretrained on ImageNet1k) to CIFAR10, CIFAR100, and Inaturalist18 by one of the following methods (a minimal sketch follows the list):

  • Linear probing: concatenate the pretrained representations and learn one big linear classifier on top.
  • (Normal) fine-tuning: concatenate the pretrained representations, then fine-tune all weights.
  • (Two-stage) fine-tuning: fine-tune each pretrained representation on the target task separately, then concatenate the fine-tuned representations and apply linear probing.
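
The sketch below illustrates linear probing on concatenated representations; it is a simplified stand-in for the scripts in Tab 1 (the class name and the use of torchvision ResNet-50s are illustrative assumptions, not the repository's actual code). For two-stage fine-tuning, each backbone would first be fine-tuned separately on the target task before running the same probe on the concatenated, fine-tuned features.

```python
# Hedged sketch of linear probing on concatenated representations: freeze
# several pretrained backbones, concatenate their features, and train one big
# linear classifier on top. Names below are illustrative.
import torch
import torch.nn as nn
import torchvision.models as models

class CatLinearProbe(nn.Module):
    def __init__(self, backbones, num_classes):
        super().__init__()
        self.backbones = nn.ModuleList(backbones)
        for b in self.backbones:                      # freeze pretrained weights
            for p in b.parameters():
                p.requires_grad = False
        feat_dim = sum(b.fc.in_features for b in self.backbones)
        for b in self.backbones:                      # drop the ImageNet heads
            b.fc = nn.Identity()
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        # Backbones are typically kept in eval() mode during probing so that
        # batch-norm statistics stay frozen.
        feats = [b(x) for b in self.backbones]        # one feature vector per backbone
        return self.classifier(torch.cat(feats, dim=1))

# Example: probe the concatenation of two ResNet-50 representations on CIFAR-100.
model = CatLinearProbe([models.resnet50(), models.resnet50()], num_classes=100)
```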

The following table provides scripts for these transfer learning experiments:

| method  | architecture                   | target task      | linear probing | fine-tuning | two-stage fine-tuning |
|---------|--------------------------------|------------------|----------------|-------------|-----------------------|
| ERM     | resnet50                       | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| ERM     | resnet50w2/w4, 2x/4x resnet50  | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| CAT     | -                              | Cifar10/Cifar100 | scripts        | scripts     | scripts               |
| Distill | resnet50                       | Cifar10/Cifar100 | scripts        | scripts     | -                     |
| ERM     | resnet50                       | Inaturalist18    | scripts        | scripts     | -                     |
| ERM     | resnet50w2/w4, 2x/4x resnet50  | Inaturalist18    | scripts        | scripts     | -                     |
| CAT     | -                              | Inaturalist18    | scripts        | scripts     | scripts               |
| Distill | resnet50                       | Inaturalist18    | scripts        | scripts     | -                     |

Tab1: Scripts for the transfer learning experiments.

The following figure (focus on the solid curves) shows the transfer learning performance of different representations (ERM / CAT / Distill) and transfer methods (linear probing / fine-tuning / two-stage fine-tuning).

Fig1: Supervised transfer learning from ImageNet to Inat18, Cifar100, and Cifar10. The top row shows the superior linear probing performance of the CATn networks (blue, “cat”). The bottom row shows the performance of fine-tuned CATn, which is poor with normal fine-tuning (gray, “[init]cat”) and excellent with two-stage fine-tuning (blue, “[2ft]cat”). The DISTILLn representation (pink, “distill”) is obtained by distilling CATn into a single ResNet50.
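
The distillation step can be pictured with the following minimal sketch. It assumes a simple feature-matching objective (an L2 loss between a linear projection of the student features and the concatenated, frozen teacher features); this is only an illustration and not necessarily the exact objective used in the repository.

```python
# Hedged sketch of distilling concatenated CATn features into one ResNet-50
# student. The linear projection to the teacher width and the plain MSE loss
# are illustrative assumptions.
import torch
import torch.nn as nn
import torchvision.models as models

def build_student_and_projection(num_teachers, feat_dim=2048):
    student = models.resnet50()
    student.fc = nn.Identity()                       # expose 2048-d features
    proj = nn.Linear(feat_dim, num_teachers * feat_dim)
    return student, proj

def distill_step(student, proj, teachers, images, optimizer):
    with torch.no_grad():                            # CATn teachers stay frozen
        target = torch.cat([t(images) for t in teachers], dim=1)
    pred = proj(student(images))                     # match concatenated width
    loss = nn.functional.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```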

Supervised transfer learning (ViT)

Download the (ImageNet21k) pretrained & (ImageNet1k) fine-tuned ViT checkpoints according to download_checkpoint.md (a quick sanity check follows the folder tree below).

The resulting folder structure looks like:

📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂supervised_pretrain
 ┃ ┃ ┣ 📂vit
 ┃ ┃ ┃ ┣📜 vitaugreg/imagenet21k/ViT-B_16.npz
 ┃ ┃ ┃ ┣📜 vitaugreg/imagenet21k/ViT-L_16.npz
 ┃ ┃ ┃ ┣📜 vit/imagenet21k/ViT-B_16.npz
 ┃ ┃ ┃ ┗📜 vit/imagenet21k/ViT-L_16.npz
 ┃ ┃ ┣📜 vitaugreg/imagenet21k/imagenet2012/ViT-L_16.npz
 ┃ ┃ ┗📜 vit/imagenet21k/imagenet2012/ViT-L_16.npz
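
As a quick check that the ViT checkpoints downloaded correctly, the .npz files (released as archives of named parameter arrays by the original Vision Transformer / AugReg repositories) can be opened with NumPy; the path below is taken literally from the tree above and may need adjusting to your layout.

```python
# Quick sanity check: each ViT checkpoint is an .npz archive of named arrays.
import numpy as np

params = np.load("checkpoints/supervised_pretrain/vit/vit/imagenet21k/ViT-B_16.npz")
print(len(params.files))         # number of parameter arrays
print(sorted(params.files)[:5])  # a few parameter names
```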

Using the same experimental protocol as in Tab1, we obtain the following transfer learning curves with Vision Transformers:

Fig2: Supervised transfer learning with Vision Transformers (ImageNet21k pretrained / ImageNet1k fine-tuned), following the protocol of Tab1.

Self-supervised transfer learning

Download the SWAV and SEER checkpoints according to download_checkpoint.md (a hedged loading sketch follows the folder tree below).

The resulting folder structure looks like:

📦 RRL
 ┣ 📂checkpoints
 ┃ ┣ 📂self_supervised_pretrain
 ┃ ┃ ┣📜 swav_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_RN50w2_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_RN50w4_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_RN50w5_400ep_pretrain.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed5.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed6.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed7.pth.tar
 ┃ ┃ ┣📜 swav_400ep_pretrain_seed8.pth.tar
 ┃ ┃ ┣📜 seer_regnet32gf.pth
 ┃ ┃ ┣📜 seer_regnet64gf.pth
 ┃ ┃ ┣📜 seer_regnet128gf.pth
 ┃ ┃ ┣📜 seer_regnet256gf.pth
 ┃ ┃ ┣📜 seer_regnet32gf_finetuned.pth
 ┃ ┃ ┣📜 seer_regnet64gf_finetuned.pth
 ┃ ┃ ┣📜 seer_regnet128gf_finetuned.pth
 ┃ ┃ ┗📜 seer_regnet256gf_finetuned.pth
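
As with the supervised ResNet checkpoints, a SWAV backbone can be loaded into a torchvision ResNet-50 for the transfer experiments. The sketch below is an assumption-laden illustration: the module. prefix and the extra projection-head/prototype keys follow the official SWAV release, and strict=False simply skips keys with no counterpart in the plain ResNet-50.

```python
# Hedged sketch: load a SWAV checkpoint into a plain ResNet-50 backbone.
# SWAV checkpoints typically carry a "module." prefix plus projection-head and
# prototype weights that torchvision's ResNet-50 does not have.
import torch
import torchvision.models as models

state_dict = torch.load(
    "checkpoints/self_supervised_pretrain/swav_400ep_pretrain.pth.tar",
    map_location="cpu",
)
if "state_dict" in state_dict:                       # some releases wrap the weights
    state_dict = state_dict["state_dict"]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}

model = models.resnet50()
model.load_state_dict(state_dict, strict=False)      # extra SSL keys are ignored
```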

Using the same experimental protocol as in Tab1, we obtain the following self-supervised transfer learning curves:

Fig2: Self-supervised transfer learning with SWAV trained on unlabeled ImageNet(1K) (top row) and with SEER on Instagram1B (bottom row). The constructed rich representation, CATn, yields the best linear probing performance (“cat” and “catsub”) for supervised ImageNet, INAT18, CIFAR100, and CIFAR10 target tasks. The two-stage fine-tuning (“[2ft]cat”) matches equivalently sized baseline models (“[init]wide” and “[init]wide&deep”), but with much easier training. The sub-networks of CAT5 (and CAT2) in SWAV share the same architecture and differ only in the pretraining seed.

Meta-learning & few-shot learning, and out-of-distribution generalization

Fig3: Few-shot learning performance on MINIIMAGENET and CUB. Four common few-shot learning algorithms are shown in red (results from Chen et al. (2019) (https://arxiv.org/abs/1904.04232)). Two supervised transfer methods, with either a linear classifier (BASELINE) or a cosine-based classifier (BASELINE++), are shown in blue. The DISTILL and CAT results, with a cosine-based classifier, are shown in orange and gray respectively. The CAT5-S and DISTILL5-S results were obtained using five snapshots taken during a single training episode with a relatively high step size. The dark blue line shows the best individual snapshot. Standard deviations over five repeats are reported.
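
For reference, a cosine-based classifier of the kind used by BASELINE++ (and by the DISTILL/CAT few-shot results above) can be sketched as follows; the class name and the temperature value are illustrative assumptions, not the repository's exact implementation.

```python
# Minimal sketch of a cosine-similarity classifier head: class scores are
# scaled cosine similarities between normalized features and normalized class
# weight vectors. The temperature is illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CosineClassifier(nn.Module):
    def __init__(self, feat_dim, num_classes, temperature=10.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.temperature = temperature

    def forward(self, feats):
        feats = F.normalize(feats, dim=-1)            # unit-norm features
        weight = F.normalize(self.weight, dim=-1)     # unit-norm class prototypes
        return self.temperature * feats @ weight.t()  # scaled cosine similarities
```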

Fig4: Test accuracy on the CAMELYON17 dataset with DENSENET121. We compare various initializations (ERM, CATn, DISTILLn, and Bonsai (https://arxiv.org/pdf/2203.15516.pdf)) for two algorithms, VREX and ERM, using either the IID or the OOD hyperparameter tuning method. Standard deviations over 5 runs are reported.
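
For reference, the V-REx objective used in Fig4 adds a penalty on the variance of the per-environment risks to the mean risk. Below is a minimal sketch; the penalty weight beta is an illustrative assumption, not a tuned value.

```python
# Minimal sketch of the V-REx objective: mean risk across environments plus a
# penalty on the variance of per-environment risks. `beta` is illustrative.
import torch

def vrex_objective(env_losses: torch.Tensor, beta: float = 10.0) -> torch.Tensor:
    """env_losses: 1-D tensor with one average loss per training environment."""
    mean_risk = env_losses.mean()
    variance = ((env_losses - mean_risk) ** 2).mean()
    return mean_risk + beta * variance

# Example with three environments:
print(vrex_objective(torch.tensor([0.50, 0.62, 0.55])))
```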

Citation

If you find this code useful for your research, please consider citing our work:

@inproceedings{zhang2023learning,
  title={Learning useful representations for shifting tasks and distributions},
  author={Zhang, Jianyu and Bottou, L{\'e}on},
  booktitle={International Conference on Machine Learning},
  pages={40830--40850},
  year={2023},
  organization={PMLR}
}
