In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import time
import os
import glob
import copy
# import configs
import backbone
from data.datamgr import SetDataManager
import argparse

from io_utils import model_dict, parse_args, get_resume_file, get_trlog, save_fig
from utils import Timer
from methods.protonet import ProtoNet
from data.dataset import SimpleDataset, EpisodicSampler


import argparse

# Enable IPython's autoreload extension so edits to the imported project
# modules (backbone, data, methods, ...) are picked up without a kernel restart.
%load_ext autoreload
# Mode 2: reload all imported modules before each cell execution.
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [2]:
# Build the experiment configuration with argparse so the notebook shares the
# option names used by the project's command-line scripts. `parse_args([])`
# below ignores the kernel's own argv and yields the defaults defined here.


def _str2bool(value):
    """Convert a command-line token ('true', '0', 'no', ...) into a real bool.

    Without this, argparse would store `--train_aug False` as the non-empty
    string 'False', which is truthy.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y'):
        return True
    if token in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--dataset'     , default='NWPU',        help='NWPU/WHU-RS19/UCMERCED')

parser.add_argument('--model'       , default='ResNet10',      help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
parser.add_argument('--method'      , default='protonet',   help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') #relationnet_softmax replace L2 norm with softmax to expedite training, maml_approx use first-order approximation in the gradient for efficiency
parser.add_argument('--train_n_way' , default=5, type=int,  help='class num to classify for training') #baseline and baseline++ would ignore this parameter
parser.add_argument('--test_n_way'  , default=5, type=int,  help='class num to classify for testing (validation) ') #baseline and baseline++ only use this parameter in finetuning
parser.add_argument('--n_shot'      , default=5, type=int,  help='number of labeled data in each class, same as n_support') #baseline and baseline++ only use this parameter in finetuning
parser.add_argument('--n_query'      , default=8, type=int,  help='number of unlabeled  query data in each class, same as n_query') #baseline and baseline++ only use this parameter in finetuning

# type=_str2bool: a bare string value would always be truthy (see helper above).
parser.add_argument('--train_aug'   , default=True, type=_str2bool, help='perform data augmentation or not during training ') #still required for save_features.py and test.py to find the model path correctly
parser.add_argument('--n_episode', default=100, type=int, help = 'num of episodes in each epoch')
# type=float: otherwise a value given on the command line stays a string.
parser.add_argument('--mlp_dropout' , default=0.7, type=float, help='dropout rate in word embedding transformer')
# parser.add_argument('--aux'   , default=False,  help='use attribute as auxiliary data, multimodal method')

# learning rate, optim
parser.add_argument('--lr_anneal', default='const', help='const/pwc/exp, schedule learning rate')
# type=float: otherwise a value given on the command line stays a string.
parser.add_argument('--init_lr', default=0.001, type=float)
parser.add_argument('--optim', default='Adam', help='Adam/SGD')
# learning rate decay

parser.add_argument('--num_classes' , default=200, type=int, help='total number of classes in softmax, only used in baseline') #make it larger than the maximum label value in base class
parser.add_argument('--save_freq'   , default=10, type=int, help='Save frequency')
parser.add_argument('--start_epoch' , default=0, type=int,help ='Starting epoch')
parser.add_argument('--stop_epoch'  , default=300, type=int, help ='Stopping epoch') #for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py
parser.add_argument('--warmup'      , action='store_true', help='continue from baseline, neglected if resume is true') #never used in the paper

# Empty argv: use the defaults above instead of the Jupyter kernel's arguments.
params = parser.parse_args([])
print(params)

Namespace(dataset='NWPU', init_lr=0.001, lr_anneal='const', method='protonet', mlp_dropout=0.7, model='ResNet10', n_episode=100, n_query=8, n_shot=5, num_classes=200, optim='Adam', save_freq=10, start_epoch=0, stop_epoch=300, test_n_way=5, train_aug=True, train_n_way=5, warmup=False)


In [4]:
# Side length (in pixels) handed to the data manager.
image_size = 224

# File list for the selected dataset: ./filelists/<dataset>/base.json
base_file = os.path.join('./filelists', params.dataset, 'base.json')

# Episode layout shared by the data manager and the model constructor.
train_few_shot_params = {
    'n_way': params.train_n_way,
    'n_support': params.n_shot,
    'n_query': params.n_query,
}


In [12]:
# Data manager configured with the image size, the number of episodes per
# epoch and the n_way / n_support / n_query episode layout.
base_datamgr = SetDataManager(
    image_size,
    n_episode=params.n_episode,
    **train_few_shot_params
)

In [17]:
# Wrap the base file list in a SimpleDataset, reusing the data manager's
# composed transform with augmentation switched on.
data_file = base_file
transform = base_datamgr.trans_loader.get_composed_transform(
    aug=True,
)
dataset = SimpleDataset(
    data_file,
    transform,
)

In [19]:
# Sanity check: show the tensor shape and label of the first sample only.
for x, y in dataset:
    print(x.shape, y, sep='\n')
    break

torch.Size([3, 224, 224])
0


In [20]:
# Data loader over the base file list, built by the data manager; whether
# augmentation is applied follows params.train_aug (True by default).
base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )


In [21]:
# Sanity check: fetch one batch from the loader and show its shape and labels.
# NOTE(review): the recorded run of this cell failed with
#   PicklingError: Can't pickle <function <lambda> ...> on data.dataset
# Presumably the loader starts worker processes that must pickle a dataset
# holding a lambda; replacing that lambda with a module-level function in
# data/dataset.py, or loading with zero workers, should avoid it -- confirm
# against the data manager's loader settings.
for x, y in base_loader:
    print(x.shape, y, sep='\n')
    break

PicklingError: Can't pickle <function <lambda> at 0x000001B971B92620>: attribute lookup <lambda> on data.dataset failed

In [5]:
# Build the ProtoNet with the backbone selected by params.model and the same
# episode layout as the data manager, then move it to the GPU.
model = ProtoNet(
    model_dict[params.model],
    params=params,
    **train_few_shot_params
)
model = model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (trunk): Sequential(
    (0): SimpleBlock(
      (C1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (C2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace)
      (relu2): ReLU(inplace)
    )
    (1): SimpleBlock(
      (C1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (BN1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (C2): Conv2d(128, 128, k

In [6]:
# Random (unseeded) batch of 10 three-channel 224x224 tensors, moved to the
# GPU; 224 matches the image_size set earlier. Requires a CUDA device.
x = torch.randn(10,3,224,224).cuda()

In [7]:
# Smoke test: call the model on the random batch.
# NOTE(review): what ProtoNet's forward returns is not visible in this
# notebook -- presumably backbone features; confirm in methods/protonet.py.
z = model(x)