In [1]:
import argparse
import glob
import json
import os
import random
import time

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim
import torch.utils.data.sampler
from torch.autograd import Variable

import configs
import backbone
import data.feature_loader as feat_loader
from data.datamgr import SetDataManager
from methods.protonet import ProtoNet
from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file

In [8]:
# Experiment configuration.
# `parse_args([])` deliberately ignores the Jupyter kernel's own sys.argv, so the defaults
# below are what the notebook runs with; pass a list of flags (e.g. ['--n_shot', '1']) to override.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset'     , default='CUB',        help='CUB/miniImagenet/cross/omniglot/cross_char')
parser.add_argument('--model'       , default='ResNet10',      help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
parser.add_argument('--method'      , default='protonet',   help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') #relationnet_softmax replace L2 norm with softmax to expedite training, maml_approx use first-order approximation in the gradient for efficiency
# baseline and baseline++ ignore --train_n_way; a 30-way training setting was also tried (--train_n_way 30).
parser.add_argument('--train_n_way' , default=5, type=int,  help='class num to classify for training')

parser.add_argument('--test_n_way'  , default=5, type=int,  help='class num to classify for testing (validation) ') #baseline and baseline++ only use this parameter in finetuning

parser.add_argument('--n_shot'      , default=5, type=int,  help='number of labeled data in each class, same as n_support') #baseline and baseline++ only use this parameter in finetuning
# With default=True, `store_true` alone can never turn augmentation off, so --no_train_aug
# writes False into the same `train_aug` destination. The flag still matters because it
# selects the checkpoint directory name (the `_aug` suffix).
parser.add_argument('--train_aug'   , default=True, action='store_true',  help='perform data augmentation or not during training ') #still required for save_features.py and test.py to find the model path correctly
parser.add_argument('--no_train_aug', dest='train_aug', action='store_false', help='use the checkpoint trained without data augmentation')

parser.add_argument('--split'       , default='novel', help='base/val/novel') #default novel, but you can also test base/val class accuracy if you want
parser.add_argument('--save_iter', default=-1, type=int,help ='saved feature from the model trained in x epoch, use the best model if x is -1')
parser.add_argument('--adaptation'  , action='store_true', help='further adaptation in test time or not')
params = parser.parse_args([])

In [11]:
# Evaluation bookkeeping (not used within this cell; presumably consumed by later cells).
acc_all = []
split_list = ['base', 'val']
iter_num = 600

# Test-time episode configuration (5-way 5-shot with the default params).
few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

# nn.Module.cuda() moves the parameters in place and returns the same module.
model = ProtoNet(model_dict[params.model], **few_shot_params).cuda()

# Rebuild the directory name used at training time:
#   <save_dir>/checkpoints/<dataset>/<model>_<method>[_aug][_<N>way_<K>shot]
checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, params.dataset, params.model, params.method)
if params.train_aug:
    checkpoint_dir = checkpoint_dir + '_aug'
if params.method not in ('baseline', 'baseline++'):
    # Episodic methods also encode the training way/shot in the directory name.
    checkpoint_dir = checkpoint_dir + '_%dway_%dshot' % (params.train_n_way, params.n_shot)

# Load the pre-trained weights selected by get_best_file into the model built above.
modelfile = get_best_file(checkpoint_dir)
tmp = torch.load(modelfile)
model.load_state_dict(tmp['state'])

In [15]:
# Features saved for the *base* split with the best checkpoint, stored under the
# "features" mirror of the checkpoint directory. The variable keeps the name
# `novel_file` because later cells refer to it; edit the file name to load
# another split's features instead.
novel_file = os.path.join(checkpoint_dir.replace("checkpoints", "features"),
                          "base_best.hdf5")
cl_data_file = feat_loader.init_loader(novel_file)

# Episode configuration. Note: no trailing commas -- `n_way = 5,` would bind the
# tuple (5,) rather than the int 5 and break any arithmetic/indexing downstream.
n_way = 5
n_support = 5
n_query = 15

In [31]:
# Open read-only (avoids h5py's deprecated implicit file mode and any accidental
# writes), read both datasets into memory, and close the file immediately.
with h5py.File(novel_file, 'r') as feat_file:   # <KeysViewHDF5 ['all_feats', 'all_labels', 'count']>
    print(feat_file.keys())
    feats = feat_file['all_feats'][...]     # e.g. (5888, 512) before trimming
    labels = feat_file['all_labels'][...]   # e.g. (5888,) before trimming

# Drop trailing rows whose feature sum is 0 (treated as unused padding), together
# with their labels. Same criterion as repeatedly deleting the last row, but done in
# one vectorized pass, and an all-zero array yields empty results instead of an IndexError.
nonzero_rows = np.flatnonzero(feats.sum(axis=1) != 0)
n_valid = nonzero_rows[-1] + 1 if nonzero_rows.size else 0
feats = feats[:n_valid]
labels = labels[:n_valid]

print(feats.shape)
print(labels.shape)

<KeysViewHDF5 ['all_feats', 'all_labels', 'count']>
(5885, 512)
(5885,)


  """Entry point for launching an IPython kernel.


In [None]:
# Placeholder module: an empty container with no layers yet (TODO: add layers).
theta_h = nn.Sequential()