Merge pull request #73 from amzn/SIB
Code for Synthetic Info Bottleneck method
adamian committed Mar 31, 2020
2 parents 0a78388 + fbe887e commit e5c8e60
Showing 20 changed files with 2,324 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
@@ -11,7 +11,9 @@ In more detail:
- [xfer-ml](xfer-ml): A library that allows quick and easy transfer of knowledge stored in deep neural networks implemented in MXNet. xfer-ml can be used with data of arbitrary numeric format, and can be applied to the common cases of image or text data. It can be used as a pipeline that spans from extracting features to training a repurposer. The repurposer is then an object that carries out predictions in the target task. You can also use individual components of the library as part of your own pipeline. For example, you can leverage the feature extractor to extract features from deep neural networks or ModelHandler, which allows for quick building of neural networks, even if you are not an MXNet expert.
- [leap](leap): MXNet implementation of "leap", the meta-gradient path learner published in ICLR 2019: [(link)](https://arxiv.org/abs/1812.01054) by S. Flennerhag, P. G. Moreno, N. Lawrence, A. Damianou.
- [nn_similarity_index](nn_similarity_index): PyTorch code for comparing trained neural networks using both feature and gradient information.
- [finite_ntk](finite_ntk): PyTorch implementation of finite width neural tangent kernels from the paper *On Transfer Learning with Linearised Neural Networks* [(link)](http://metalearning.ml/2019/papers/metalearn2019-maddox.pdf), by W. Maddox, S. Tang, P. G. Moreno, A. G. Wilson, and A. Damianou, which appeared at the 3rd MetaLearning Workshop at NeurIPS, 2019.
- [synthetic_info_bottleneck](synthetic_info_bottleneck): PyTorch implementation of the *synthetic information bottleneck* algorithm for few-shot classification on Mini-ImageNet, which is used in the ICLR 2020 paper *Empirical Bayes Transductive Meta-Learning with Synthetic Gradients* [(link)](https://openreview.net/forum?id=Hkg-xgrYvH).


Navigate to the corresponding folder for more details.

69 changes: 69 additions & 0 deletions synthetic_info_bottleneck/README.md
@@ -0,0 +1,69 @@
# \[ICLR 2020\] Synthetic information bottleneck for transductive meta-learning
This repo contains the implementation of the *synthetic information bottleneck* algorithm for few-shot classification on Mini-ImageNet,
which is used in our ICLR 2020 paper
[Empirical Bayes Transductive Meta-Learning with Synthetic Gradients](https://openreview.net/forum?id=Hkg-xgrYvH).

If our code is helpful for your research, please consider citing:
```
@inproceedings{
Hu2020Empirical,
title={Empirical Bayes Transductive Meta-Learning with Synthetic Gradients},
author={Shell Xu Hu and Pablo Garcia Moreno and Yang Xiao and Xi Shen and Guillaume Obozinski and Neil Lawrence and Andreas Damianou},
booktitle={International Conference on Learning Representations (ICLR)},
year={2020},
url={https://openreview.net/forum?id=Hkg-xgrYvH}
}
```

## Authors of the code
[Shell Xu Hu](http://hushell.github.io/), [Xi Shen](https://xishen0220.github.io/) and [Yang Xiao](https://youngxiao13.github.io/)


## Dependencies
The code is tested under a **PyTorch > 1.0 + Python 3.6** environment; install the extra packages with:
``` Bash
pip install -r requirements.txt
```


## How to use the code on Mini-ImageNet?
### **Step 0**: Download Mini-ImageNet dataset

``` Bash
cd data
bash download_miniimagenet.sh
cd ..
```

### **Step 1** (optional): Train a WRN-28-10 feature network (aka backbone)
The weights of the feature network are downloaded in Step 0, but you may also train the backbone from scratch by running:

``` Bash
python main_feat.py --outDir miniImageNet_WRN_60Epoch --cuda --dataset miniImageNet --nbEpoch 60
```

### **Step 2**: Meta-training on Mini-ImageNet, e.g., 5-way-1-shot:

``` Bash
python main.py --config config/miniImageNet_1shot.yaml --seed 100 --gpu 0
```
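
For intuition, the sketch below shows the kind of transductive inner loop that `netSIB` is expected to run for each episode, inferred from how `algorithm.py` calls `netSIB(SupportFeat, SupportLabel, QueryFeat, lr)` and `netSIB.dni(clsScore)`, and from the `nStep` setting in the configs. The prototype initialisation and the names `SyntheticGradientNet` / `sib_inner_loop` are illustrative assumptions, not the repository's actual `ClassifierSIB` API.

``` python
import torch
import torch.nn as nn

class SyntheticGradientNet(nn.Module):
    """Illustrative stand-in for the DNI module: maps query logits to a synthetic gradient."""
    def __init__(self, n_way):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_way, 128), nn.ReLU(), nn.Linear(128, n_way))

    def forward(self, cls_score):
        return self.net(cls_score)

def sib_inner_loop(support_feat, support_label, query_feat, dni, lr, n_way=5, n_step=3):
    # Initialise classifier weights from support-set class prototypes (a simplification;
    # the paper amortises this initialisation with a learned network).
    theta = torch.stack([support_feat[support_label == c].mean(0) for c in range(n_way)])
    for _ in range(n_step):
        cls_score = query_feat @ theta.t()                   # logits for the whole query set (transductive)
        grad_logit = dni(cls_score)                          # synthetic gradient w.r.t. the logits
        theta = theta - lr * (grad_logit.t() @ query_feat)   # chain rule back to theta; no query labels used
    return query_feat @ theta.t()                            # final query logits, i.e. clsScore

# Toy usage with random features (nFeat is 640 for WRN-28-10):
dni = SyntheticGradientNet(n_way=5)
logits = sib_inner_loop(torch.randn(5, 640), torch.arange(5), torch.randn(75, 640), dni, lr=1e-3)
```

In the released code this loop is encapsulated in the `ClassifierSIB` module; `algorithm.py` only supplies the features, labels and learning rate, and (when `coeffGrad > 0`) additionally regresses the synthetic gradient onto the true cross-entropy gradient via `compute_grad_loss`.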

### **Step 3**: Meta-testing on Mini-ImageNet with a checkpoint:

``` Bash
python main.py --config config/miniImageNet_1shot.yaml --seed 100 --gpu 0 --ckpt cache/miniImageNet_1shot_K3_seed100/outputs_xx.xxx/netSIBBestxx.xxx.pth
```
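
Test accuracy is reported as the mean over many random episodes together with a 95% confidence interval (computed by `getCi` from `utils/outils.py`). A minimal sketch of that aggregation, assuming the usual normal approximation (the exact formula in `getCi` may differ slightly):

``` python
import math

def mean_and_ci95(episode_accuracies):
    """Aggregate per-episode top-1 accuracies (in %) into a mean and a 95% CI,
    as in the 'Final Perf with 95% confidence intervals' line printed by validate()."""
    n = len(episode_accuracies)
    mean = sum(episode_accuracies) / n
    var = sum((a - mean) ** 2 for a in episode_accuracies) / (n - 1)
    return mean, 1.96 * math.sqrt(var / n)
```

For example, the CIFAR-FS config included in this commit sets `nEpisode: 2000` for testing.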

## Mini-ImageNet Results (LAST ckpt)

| Setup | 5-way-1-shot | 5-way-5-shot |
| ------------- | -------------:| ------------:|
| SIB (K=3) | 70.700% ± 0.585% | 80.045% ± 0.363%|
| SIB (K=5) | 70.494% ± 0.619% | 80.192% ± 0.372%|

## CIFAR-FS Results (LAST ckpt)

| Setup | 5-way-1-shot | 5-way-5-shot |
| ------------- | -------------:| ------------:|
| SIB (K=3) | 79.763% ± 0.577% | 85.721% ± 0.369%|
| SIB (K=5) | 79.627% ± 0.593% | 85.590% ± 0.375%|
267 changes: 267 additions & 0 deletions synthetic_info_bottleneck/algorithm.py
@@ -0,0 +1,267 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
# ==============================================================================

import os
import itertools
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from utils.outils import progress_bar, AverageMeter, accuracy, getCi
from utils.utils import to_device

class Algorithm:
"""
Algorithm logic is implemented here with training and validation functions etc.
:param args: experimental configurations
:type args: EasyDict
:param logger: logger
:param netFeat: feature network
:type netFeat: class `WideResNet` or `ConvNet_4_64`
:param netSIB: Classifier/decoder
:type netSIB: class `ClassifierSIB`
:param optimizer: optimizer
:type optimizer: torch.optim.SGD
:param criterion: loss
:type criterion: nn.CrossEntropyLoss
"""
def __init__(self, args, logger, netFeat, netSIB, optimizer, criterion):
self.netFeat = netFeat
self.netSIB = netSIB
self.optimizer = optimizer
self.criterion = criterion

self.nbIter = args.nbIter
self.nStep = args.nStep
self.outDir = args.outDir
self.nFeat = args.nFeat
self.batchSize = args.batchSize
self.nEpisode = args.nEpisode
self.momentum = args.momentum
self.weightDecay = args.weightDecay

self.logger = logger
self.device = torch.device('cuda' if args.cuda else 'cpu')

# Load pretrained model
if args.resumeFeatPth :
if args.cuda:
param = torch.load(args.resumeFeatPth)
else:
param = torch.load(args.resumeFeatPth, map_location='cpu')
self.netFeat.load_state_dict(param)
msg = '\nLoading netFeat from {}'.format(args.resumeFeatPth)
self.logger.info(msg)

if args.test:
self.load_ckpt(args.ckptPth)


def load_ckpt(self, ckptPth):
"""
Load checkpoint from ckptPth.
:param ckptPth: the path to the ckpt
:type ckptPth: string
"""
param = torch.load(ckptPth)
self.netFeat.load_state_dict(param['netFeat'])
self.netSIB.load_state_dict(param['SIB'])
lr = param['lr']
self.optimizer = torch.optim.SGD(itertools.chain(*[self.netSIB.parameters(),]),
lr,
momentum=self.momentum,
weight_decay=self.weightDecay,
nesterov=True)
msg = '\nLoading networks from {}'.format(ckptPth)
self.logger.info(msg)


def compute_grad_loss(self, clsScore, QueryLabel):
"""
Compute the loss between true gradients and synthetic gradients.
"""
# register hooks
def require_nonleaf_grad(v):
def hook(g):
v.grad_nonleaf = g
h = v.register_hook(hook)
return h
handle = require_nonleaf_grad(clsScore)
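        # Note: clsScore is a non-leaf tensor, so autograd does not retain its gradient by
        # default; the hook registered above stashes the incoming gradient in
        # clsScore.grad_nonleaf so it can be compared with the synthetic gradient below.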

loss = self.criterion(clsScore, QueryLabel)
loss.backward(retain_graph=True) # need to backward again

# remove hook
handle.remove()

gradLogit = self.netSIB.dni(clsScore) # B * n x nKnovel
gradLoss = F.mse_loss(gradLogit, clsScore.grad_nonleaf.detach())
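        # The DNI module is trained by regressing its synthetic gradient onto the true
        # gradient of the cross-entropy loss with respect to the logits.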

return loss, gradLoss


def validate(self, valLoader, lr=None, mode='val'):
"""
Run one epoch on val-set.
:param valLoader: the dataloader of val-set
:type valLoader: class `ValLoader`
:param float lr: learning rate for synthetic GD
        :param string mode: 'val' or 'test'
"""
if mode == 'test':
nEpisode = self.nEpisode
self.logger.info('\n\nTest mode: randomly sample {:d} episodes...'.format(nEpisode))
elif mode == 'val':
nEpisode = len(valLoader)
self.logger.info('\n\nValidation mode: pre-defined {:d} episodes...'.format(nEpisode))
valLoader = iter(valLoader)
else:
raise ValueError('mode is wrong!')

episodeAccLog = []
top1 = AverageMeter()

self.netFeat.eval()

if lr is None:
lr = self.optimizer.param_groups[0]['lr']

#for batchIdx, data in enumerate(valLoader):
for batchIdx in range(nEpisode):
data = valLoader.getEpisode() if mode == 'test' else next(valLoader)
data = to_device(data, self.device)

SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
data['SupportTensor'].squeeze(0), data['SupportLabel'].squeeze(0), \
data['QueryTensor'].squeeze(0), data['QueryLabel'].squeeze(0)

with torch.no_grad():
SupportFeat, QueryFeat = self.netFeat(SupportTensor), self.netFeat(QueryTensor)
SupportFeat, QueryFeat, SupportLabel = \
SupportFeat.unsqueeze(0), QueryFeat.unsqueeze(0), SupportLabel.unsqueeze(0)

clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
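            # The entire query set is classified in a single call, so the synthetic-gradient
            # updates inside netSIB can exploit query-set statistics (the transductive setting
            # referred to in the paper title).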
clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1], -1)
QueryLabel = QueryLabel.view(-1)
acc1 = accuracy(clsScore, QueryLabel, topk=(1,))
top1.update(acc1[0].item(), clsScore.shape[0])

msg = 'Top1: {:.3f}%'.format(top1.avg)
progress_bar(batchIdx, nEpisode, msg)
episodeAccLog.append(acc1[0].item())

mean, ci95 = getCi(episodeAccLog)
self.logger.info('Final Perf with 95% confidence intervals: {:.3f}%, {:.3f}%'.format(mean, ci95))
return mean, ci95


def train(self, trainLoader, valLoader, lr=None, coeffGrad=0.0) :
"""
Run one epoch on train-set.
:param trainLoader: the dataloader of train-set
:type trainLoader: class `TrainLoader`
:param valLoader: the dataloader of val-set
:type valLoader: class `ValLoader`
:param float lr: learning rate for synthetic GD
:param float coeffGrad: deprecated
"""
bestAcc, ci = self.validate(valLoader, lr)
        self.logger.info('Acc improved over validation set from 0% ---> {:.3f}% +- {:.3f}%'.format(bestAcc, ci))

self.netSIB.train()
self.netFeat.eval()

losses = AverageMeter()
top1 = AverageMeter()
history = {'trainLoss' : [], 'trainAcc' : [], 'valAcc' : []}

for episode in range(self.nbIter):
data = trainLoader.getBatch()
data = to_device(data, self.device)

with torch.no_grad() :
SupportTensor, SupportLabel, QueryTensor, QueryLabel = \
data['SupportTensor'], data['SupportLabel'], data['QueryTensor'], data['QueryLabel']
nC, nH, nW = SupportTensor.shape[2:]

SupportFeat = self.netFeat(SupportTensor.reshape(-1, nC, nH, nW))
SupportFeat = SupportFeat.view(self.batchSize, -1, self.nFeat)

QueryFeat = self.netFeat(QueryTensor.reshape(-1, nC, nH, nW))
QueryFeat = QueryFeat.view(self.batchSize, -1, self.nFeat)

if lr is None:
lr = self.optimizer.param_groups[0]['lr']

self.optimizer.zero_grad()

clsScore = self.netSIB(SupportFeat, SupportLabel, QueryFeat, lr)
clsScore = clsScore.view(QueryFeat.shape[0] * QueryFeat.shape[1], -1)
QueryLabel = QueryLabel.view(-1)
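            # When coeffGrad > 0, an auxiliary MSE loss trains the DNI module to match the true
            # cross-entropy gradient (see compute_grad_loss); the CIFAR-FS config in this commit
            # sets coeffGrad to 0, so only the classification loss is used there.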

if coeffGrad > 0:
loss, gradLoss = self.compute_grad_loss(clsScore, QueryLabel)
loss = loss + gradLoss * coeffGrad
else:
loss = self.criterion(clsScore, QueryLabel)

loss.backward()
self.optimizer.step()

acc1 = accuracy(clsScore, QueryLabel, topk=(1, ))
top1.update(acc1[0].item(), clsScore.shape[0])
losses.update(loss.item(), QueryFeat.shape[1])
msg = 'Loss: {:.3f} | Top1: {:.3f}% '.format(losses.avg, top1.avg)
if coeffGrad > 0:
msg = msg + '| gradLoss: {:.3f}%'.format(gradLoss.item())
progress_bar(episode, self.nbIter, msg)

if episode % 1000 == 999 :
acc, _ = self.validate(valLoader, lr)

if acc > bestAcc :
msg = 'Acc improved over validation set from {:.3f}% ---> {:.3f}%'.format(bestAcc , acc)
self.logger.info(msg)

bestAcc = acc
self.logger.info('Saving Best')
torch.save({
'lr': lr,
'netFeat': self.netFeat.state_dict(),
'SIB': self.netSIB.state_dict(),
'nbStep': self.nStep,
}, os.path.join(self.outDir, 'netSIBBest.pth'))

self.logger.info('Saving Last')
torch.save({
'lr': lr,
'netFeat': self.netFeat.state_dict(),
'SIB': self.netSIB.state_dict(),
'nbStep': self.nStep,
}, os.path.join(self.outDir, 'netSIBLast.pth'))

msg = 'Iter {:d}, Train Loss {:.3f}, Train Acc {:.3f}%, Val Acc {:.3f}%'.format(
episode, losses.avg, top1.avg, acc)
self.logger.info(msg)
history['trainLoss'].append(losses.avg)
history['trainAcc'].append(top1.avg)
history['valAcc'].append(acc)

losses = AverageMeter()
top1 = AverageMeter()

return bestAcc, acc, history
26 changes: 26 additions & 0 deletions synthetic_info_bottleneck/config/CIFAR_1shot.yaml
@@ -0,0 +1,26 @@
# Few-shot dataset
nClsEpisode: 5 # number of categories in each episode
nSupport: 1 # number of samples per category in the support set
nQuery: 15 # number of samples per category in the query set
dataset: 'Cifar' # choices = ['miniImageNet', 'Cifar']

# Network
nStep: 3 # number of synthetic gradient steps
architecture: 'WRN_28_10' # choices = ['WRN_28_10', 'Conv64_4']
batchSize: 1 # number of episodes in each batch

# Optimizer
lr: 0.001 # lr is fixed
weightDecay: 0.0005
momentum: 0.9

# Training details
expName: cifar-fs
nbIter: 50000 # number of training iterations
seed: 100 # can be reset with --seed
gpu: '1' # can be reset with --gpu
resumeFeatPth : './ckpts/CIFAR-FS/netFeatBest62.561.pth' # feat ckpt
coeffGrad: 0 # grad loss coeff

# Testing
nEpisode: 2000 # number of episodes for testing
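
For reference, the `Algorithm` docstring above indicates that such a configuration is consumed as an `EasyDict`; a plausible loading snippet (an assumption, since `main.py` is not shown in this part of the diff) would be:

``` python
import yaml
from easydict import EasyDict

# Hypothetical loader; the actual script may also apply --seed/--gpu overrides.
with open('config/CIFAR_1shot.yaml') as f:
    args = EasyDict(yaml.safe_load(f))

print(args.dataset, args.nStep, args.lr)  # Cifar 3 0.001
```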
