/
eval_matchnet.py
82 lines (72 loc) · 3.43 KB
/
eval_matchnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import argparse
import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from feat.dataloader.samplers import CategoriesSampler
from feat.models.matchnet import MatchNet
from feat.utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, compute_confidence_interval
from tensorboardX import SummaryWriter
# Evaluate a trained MatchNet checkpoint on sampled few-shot test episodes and
# report the mean accuracy with a confidence interval.


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to ``True``. This converter keeps the ``--flag VALUE`` form
    working while actually honouring the value.

    Args:
        value: The raw command-line string (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}.'.format(value))


def build_parser():
    """Return the command-line parser for MatchNet evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=15)
    # Accepts both `--use_bilstm False` / `--use_bilstm True` and a bare `--use_bilstm`.
    parser.add_argument('--use_bilstm', type=str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--model_type', type=str, default='ConvNet', choices=['ConvNet', 'ResNet'])
    parser.add_argument('--dataset', type=str, default='MiniImageNet', choices=['MiniImageNet', 'CUB', 'TieredImageNet'])
    parser.add_argument('--model_path', type=str, default=None)
    parser.add_argument('--gpu', default='0')
    # Number of sampled test episodes (formerly hard-coded to 10000 in two places).
    parser.add_argument('--num_test_episodes', type=int, default=10000)
    return parser


def dataset_class(name):
    """Import and return the dataset class registered under ``name``.

    Imports are deferred so only the selected dataset module is loaded.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    if name == 'MiniImageNet':
        from feat.dataloader.mini_imagenet import MiniImageNet as Dataset
    elif name == 'CUB':
        from feat.dataloader.cub import CUB as Dataset
    elif name == 'TieredImageNet':
        from feat.dataloader.tiered_imagenet import tieredImageNet as Dataset
    else:
        raise ValueError('Non-supported Dataset.')
    return Dataset


def support_onehot(way, shot):
    """Return a float one-hot label matrix of shape ``(way * shot, way)``.

    Row ``j`` encodes class ``j % way``, i.e. the label order produced by
    ``torch.arange(way).repeat(shot)``. This presumes the episode sampler
    emits support images in that class-cyclic order; confirm against
    ``CategoriesSampler``.
    """
    label_support = torch.arange(way).repeat(shot).long()
    onehot = torch.zeros(way * shot, way)
    onehot.scatter_(1, label_support.unsqueeze(1), 1)
    return onehot


if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args()
    if args.model_path is None:
        parser.error('--model_path is required for evaluation.')
    if args.num_test_episodes < 1:
        parser.error('--num_test_episodes must be a positive integer.')
    args.temperature = 1
    pprint(vars(args))
    # NOTE(review): set_gpu presumably sets CUDA_VISIBLE_DEVICES, so keep it
    # ahead of the first torch.cuda query below.
    set_gpu(args.gpu)
    use_cuda = torch.cuda.is_available()

    Dataset = dataset_class(args.dataset)

    model = MatchNet(args)
    if use_cuda:
        torch.backends.cudnn.benchmark = True
        model = model.cuda()

    num_episodes = args.num_test_episodes
    test_set = Dataset('test', args)
    sampler = CategoriesSampler(test_set.label, num_episodes, args.way, args.shot + args.query)
    loader = DataLoader(test_set, batch_sampler=sampler, num_workers=8, pin_memory=use_cuda)
    test_acc_record = np.zeros((num_episodes,))

    # Load onto the CPU first so GPU-saved checkpoints also work on CPU-only
    # hosts; load_state_dict copies the weights onto the model's own device.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['params'])
    model.eval()

    ave_acc = Averager()
    # Query labels follow the same class-cyclic order as the support labels.
    label = torch.arange(args.way).repeat(args.query).long()
    label_support_onehot = support_onehot(args.way, args.shot)  # KN x N
    if use_cuda:
        label = label.cuda()
        label_support_onehot = label_support_onehot.cuda()

    k = args.way * args.shot  # number of support images per episode
    with torch.no_grad():
        for i, batch in enumerate(loader, 1):
            # batch[0] holds the images; the dataset's own labels are not used.
            data = batch[0]
            if use_cuda:
                data = data.cuda()
            data_shot, data_query = data[:k], data[k:]
            logits = model(data_shot, data_query)  # KqN x KN x 1 (per original comment)
            # Weight each support one-hot label by the model's score and sum
            # over the support axis, giving per-class scores: KqN x N.
            prediction = torch.sum(logits * label_support_onehot.unsqueeze(0), 1)
            acc = count_acc(prediction, label)
            ave_acc.add(acc)
            test_acc_record[i - 1] = acc
            print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

    m, pm = compute_confidence_interval(test_acc_record)
    print('Test Acc {:.4f} + {:.4f}'.format(m, pm))