Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
HobbitLong committed Sep 2, 2020
1 parent 9eabfb2 commit f8c837b
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 47 deletions.
52 changes: 31 additions & 21 deletions dataset/transform_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,49 +28,62 @@
])
]

transform_A_test = [

transform_B = [
transforms.Compose([
lambda x: Image.fromarray(x),
transforms.RandomCrop(84, padding=8),
transforms.RandomResizedCrop(84, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
lambda x: np.asarray(x),
transforms.ToTensor(),
normalize
transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))
]),

transforms.Compose([
lambda x: Image.fromarray(x),
transforms.Resize(92),
transforms.CenterCrop(84),
transforms.ToTensor(),
normalize
transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))
])
]

# CIFAR style transformation
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]
normalize_cifar100 = transforms.Normalize(mean=mean, std=std)
transform_D = [
transform_C = [
transforms.Compose([
lambda x: Image.fromarray(x),
transforms.RandomCrop(32, padding=4),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
# transforms.Resize(92, interpolation = PIL.Image.BICUBIC),
transforms.RandomResizedCrop(80),
# transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomHorizontalFlip(),
lambda x: np.asarray(x),
transforms.ToTensor(),
normalize_cifar100
# Lighting(0.1, imagenet_pca['eigval'], imagenet_pca['eigvec']),
# normalize
transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))
]),

transforms.Compose([
lambda x: Image.fromarray(x),
transforms.Resize(92),
transforms.CenterCrop(80),
transforms.ToTensor(),
normalize_cifar100
# normalize
transforms.Normalize(np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]))
])
]

transform_D_test = [
# CIFAR style transformation
mean = [0.5071, 0.4867, 0.4408]
std = [0.2675, 0.2565, 0.2761]
normalize_cifar100 = transforms.Normalize(mean=mean, std=std)
transform_D = [
transforms.Compose([
lambda x: Image.fromarray(x),
transforms.RandomCrop(32, padding=4),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
transforms.RandomHorizontalFlip(),
lambda x: np.asarray(x),
transforms.ToTensor(),
Expand All @@ -85,15 +98,12 @@
]


transforms_list = ['A', 'D']
transforms_list = ['A', 'B', 'C', 'D']


transforms_options = {
'A': transform_A,
'B': transform_B,
'C': transform_C,
'D': transform_D,
}

transforms_test_options = {
'A': transform_A_test,
'D': transform_D_test,
}
39 changes: 32 additions & 7 deletions eval/meta_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@

import torch
from sklearn import metrics
from sklearn.svm import SVC
from sklearn.svm import SVC, LinearSVC
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def mean_confidence_interval(data, confidence=0.95):
a = 1.0 * np.array(data)
Expand All @@ -27,7 +30,7 @@ def normalize(x):
return out


def meta_test(net, testloader, use_logit=True, is_norm=True, classifier='LR'):
def meta_test(net, testloader, use_logit=True, is_norm=True, classifier='LR', opt=None):
net = net.eval()
acc = []

Expand All @@ -36,9 +39,9 @@ def meta_test(net, testloader, use_logit=True, is_norm=True, classifier='LR'):
support_xs, support_ys, query_xs, query_ys = data
support_xs = support_xs.cuda()
query_xs = query_xs.cuda()
batch_size, _, height, width, channel = support_xs.size()
support_xs = support_xs.view(-1, height, width, channel)
query_xs = query_xs.view(-1, height, width, channel)
batch_size, _, channel, height, width = support_xs.size()
support_xs = support_xs.view(-1, channel, height, width)
query_xs = query_xs.view(-1, channel, height, width)

if use_logit:
support_features = net(support_xs).view(support_xs.size(0), -1)
Expand All @@ -59,20 +62,29 @@ def meta_test(net, testloader, use_logit=True, is_norm=True, classifier='LR'):
support_ys = support_ys.view(-1).numpy()
query_ys = query_ys.view(-1).numpy()

# clf = SVC(gamma='auto', C=0.1)
if classifier == 'LR':
clf = LogisticRegression(penalty='l2',
random_state=0,
C=10,
C=1.0,
solver='lbfgs',
max_iter=1000,
class_weight='balanced',
multi_class='multinomial')
clf.fit(support_features, support_ys)
query_ys_pred = clf.predict(query_features)
elif classifier == 'SVM':
clf = make_pipeline(StandardScaler(), SVC(gamma='auto',
C=1,
kernel='linear',
decision_function_shape='ovr'))
clf.fit(support_features, support_ys)
query_ys_pred = clf.predict(query_features)
elif classifier == 'NN':
query_ys_pred = NN(support_features, support_ys, query_features)
elif classifier == 'Cosine':
query_ys_pred = Cosine(support_features, support_ys, query_features)
elif classifier == 'Proto':
query_ys_pred = Proto(support_features, support_ys, query_features, opt)
else:
raise NotImplementedError('classifier not supported: {}'.format(classifier))

Expand All @@ -81,6 +93,19 @@ def meta_test(net, testloader, use_logit=True, is_norm=True, classifier='LR'):
return mean_confidence_interval(acc)


def Proto(support, support_ys, query, opt):
    """Prototypical-network classifier.

    Builds one prototype per class by averaging the `opt.n_shots` support
    features of that class, then labels each query by the nearest prototype
    under (negative) squared Euclidean distance.

    Args:
        support: support-set features, flattened to (episodes * n_ways * n_shots, dim).
        support_ys: support labels. NOTE(review): unused — the reshape below
            assumes supports are already grouped/ordered by class; verify
            against the sampler before relying on this.
        query: query-set features, flattened to (episodes * n_queries, dim).
        opt: options object providing `n_ways` and `n_shots`.

    Returns:
        1-D numpy array of predicted class indices, one per query.
    """
    feat_dim = support.shape[-1]
    # Group shots by class and average over the shot axis -> class prototypes,
    # shape (episodes, 1, n_ways, feat_dim).
    prototypes = np.reshape(
        support, (-1, 1, opt.n_ways, opt.n_shots, feat_dim)
    ).mean(axis=3)
    n_episodes = prototypes.shape[0]
    # Shape queries to (episodes, n_queries, 1, feat_dim) so the distance
    # computation broadcasts against every prototype.
    queries = np.reshape(query, (n_episodes, -1, 1, feat_dim))
    # Higher score = closer prototype (negative squared Euclidean distance).
    scores = -np.sum((queries - prototypes) ** 2, axis=-1)
    return np.argmax(scores, axis=-1).reshape(-1)


def NN(support, support_ys, query):
"""nearest classifier"""
support = np.expand_dims(support.transpose(), 0)
Expand Down
54 changes: 35 additions & 19 deletions eval_fewshot.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
from __future__ import print_function

import os
import argparse
import socket
import time
import sys

import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from models import model_pool
from models import model_dict, model_pool
from models.util import create_model

from dataset.mini_imagenet import MetaImageNet
from dataset.tiered_imagenet import MetaTieredImageNet
from dataset.cifar import MetaCIFAR100
from dataset.transform_cfg import transforms_test_options, transforms_list
from dataset.mini_imagenet import ImageNet, MetaImageNet
from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
from dataset.cifar import CIFAR100, MetaCIFAR100
from dataset.transform_cfg import transforms_options, transforms_list

from eval.meta_eval import meta_test


def parse_option():

hostname = socket.gethostname()

parser = argparse.ArgumentParser('argument for training')

# load pretrained model
Expand All @@ -32,9 +36,6 @@ def parse_option():
'CIFAR-FS', 'FC100'])
parser.add_argument('--transform', type=str, default='A', choices=transforms_list)

# specify data_root
parser.add_argument('--data_root', type=str, default='', help='path to data root')

# meta setting
parser.add_argument('--n_test_runs', type=int, default=600, metavar='N',
help='Number of test runs')
Expand All @@ -46,6 +47,8 @@ def parse_option():
help='Number of query in test')
parser.add_argument('--n_aug_support_samples', default=5, type=int,
help='The number of augmented samples for each meta test sample')
parser.add_argument('--data_root', type=str, default='data', metavar='N',
help='Root dataset')
parser.add_argument('--num_workers', type=int, default=3, metavar='N',
help='Number of workers for dataloader')
parser.add_argument('--test_batch_size', type=int, default=1, metavar='test_batch_size',
Expand All @@ -59,11 +62,16 @@ def parse_option():
opt.use_trainval = False

# set the path according to the environment
if not opt.data_root:
opt.data_root = './data/{}'.format(opt.dataset)
if hostname.startswith('visiongpu'):
opt.data_root = '/data/vision/phillipi/rep-learn/{}'.format(opt.dataset)
opt.data_aug = True
elif hostname.startswith('instance'):
opt.data_root = '/mnt/globalssd/fewshot/{}'.format(opt.dataset)
opt.data_aug = True
elif opt.data_root != 'data':
opt.data_aug = True
else:
opt.data_root = '{}/{}'.format(opt.data_root, opt.dataset)
opt.data_aug = True
raise NotImplementedError('server invalid: {}'.format(hostname))

return opt

Expand All @@ -74,9 +82,11 @@ def main():

# test loader
args = opt
args.batch_size = args.test_batch_size
# args.n_aug_support_samples = 1

if opt.dataset == 'miniImageNet':
train_trans, test_trans = transforms_test_options[opt.transform]
train_trans, test_trans = transforms_options[opt.transform]
meta_testloader = DataLoader(MetaImageNet(args=opt, partition='test',
train_transform=train_trans,
test_transform=test_trans,
Expand All @@ -94,7 +104,7 @@ def main():
else:
n_cls = 64
elif opt.dataset == 'tieredImageNet':
train_trans, test_trans = transforms_test_options[opt.transform]
train_trans, test_trans = transforms_options[opt.transform]
meta_testloader = DataLoader(MetaTieredImageNet(args=opt, partition='test',
train_transform=train_trans,
test_transform=test_trans,
Expand All @@ -112,7 +122,7 @@ def main():
else:
n_cls = 351
elif opt.dataset == 'CIFAR-FS' or opt.dataset == 'FC100':
train_trans, test_trans = transforms_test_options['D']
train_trans, test_trans = transforms_options['D']
meta_testloader = DataLoader(MetaCIFAR100(args=opt, partition='test',
train_transform=train_trans,
test_transform=test_trans,
Expand Down Expand Up @@ -150,22 +160,28 @@ def main():
start = time.time()
val_acc, val_std = meta_test(model, meta_valloader)
val_time = time.time() - start
print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc, val_std, val_time))
print('val_acc: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc, val_std,
val_time))

start = time.time()
val_acc_feat, val_std_feat = meta_test(model, meta_valloader, use_logit=False)
val_time = time.time() - start
print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc_feat, val_std_feat, val_time))
print('val_acc_feat: {:.4f}, val_std: {:.4f}, time: {:.1f}'.format(val_acc_feat,
val_std_feat,
val_time))

start = time.time()
test_acc, test_std = meta_test(model, meta_testloader)
test_time = time.time() - start
print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std, test_time))
print('test_acc: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc, test_std,
test_time))

start = time.time()
test_acc_feat, test_std_feat = meta_test(model, meta_testloader, use_logit=False)
test_time = time.time() - start
print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc_feat, test_std_feat, test_time))
print('test_acc_feat: {:.4f}, test_std: {:.4f}, time: {:.1f}'.format(test_acc_feat,
test_std_feat,
test_time))


if __name__ == '__main__':
Expand Down

0 comments on commit f8c837b

Please sign in to comment.