In [36]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.optim.lr_scheduler as lr_scheduler
import time
import os
import glob
import copy
from methods import backbone, resnet12
from data.datamgr import SetDataManager
# from methods.protonet import ProtoNet
from methods.protonet_multi_gpu import ProtoNetMulti
from loss.nt_xent import NTXentLoss
from io_utils import model_dict, parse_args, get_resume_file, get_trlog, save_fig
from utils import Timer
import argparse
from utils import euclidean_dist

import warnings
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
parser = argparse.ArgumentParser()
# parser.add_argument('--dataset'     , default='CUB',        help='CUB/miniImagenet/cross/omniglot/cross_char')
parser.add_argument('--os'          , default='linux',        help='linux/windows')

parser.add_argument('--dataset'     , default='NWPU',        help='NPPU/WHU-RS19/UCMERCED')
parser.add_argument('--model'       , default='ResNet12',      help='model: Conv{4|6} / ResNet{10|18|34|50|101}') # 50 and 101 are not used in the paper
parser.add_argument('--method'      , default='cs_protonet',   help='rotate/baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/maml{_approx}') #relationnet_softmax replace L2 norm with softmax to expedite training, maml_approx use first-order approximation in the gradient for efficiency
parser.add_argument('--train_n_way' , default=5, type=int,  help='class num to classify for training') #baseline and baseline++ would ignore this parameter
parser.add_argument('--test_n_way'  , default=5, type=int,  help='class num to classify for testing (validation) ') #baseline and baseline++ only use this parameter in finetuning
parser.add_argument('--n_shot'      , default=1, type=int,  help='number of labeled data in each class, same as n_support') #baseline and baseline++ only use this parameter in finetuning
parser.add_argument('--n_query'      , default=8, type=int,  help='number of unlabeled  query data in each class, same as n_query') #baseline and baseline++ only use this parameter in finetuning

parser.add_argument('--train_aug'   , default=True, type=bool, help='perform data augmentation or not during training ') #still required for save_features.py and test.py to find the model path correctly
# parser.add_argument('--no_aug' ,dest='train_aug', action='store_false', default=True,  help='perform data augmentation or not during training ') 

parser.add_argument('--n_episode', default=100, type=int, help = 'num of episodes in each epoch')
parser.add_argument('--mlp_dropout' , default=0.7, help='dropout rate in word embedding transformer')
# parser.add_argument('--aux'   , default=False,  help='use attribute as auxiliary data, multimodal method')

# learning rate, optim
parser.add_argument('--lr_anneal', default='const', help='const/pwc/exp, schedule learning rate')
parser.add_argument('--init_lr', default=0.001)
parser.add_argument('--optim', default='Adam', help='Adam/SGD')
parser.add_argument('--wd', default='0', help='weight_decay  /  {0|0.001|...}')
parser.add_argument('--alpha'       , default=2.0, type=int, help='for manifold_mixup or S2M2 training ')

parser.add_argument('--num_classes' , default=200, type=int, help='total number of classes in softmax, only used in baseline') #make it larger than the maximum label value in base class
parser.add_argument('--save_freq'   , default=10, type=int, help='Save frequency')
parser.add_argument('--start_epoch' , default=0, type=int,help ='Starting epoch')
parser.add_argument('--stop_epoch'  , default=300, type=int, help ='Stopping epoch') #for meta-learning methods, each epoch contains 100 episodes. The default epoch number is dataset dependent. See train.py
parser.add_argument('--warmup'      , action='store_true', help='continue from baseline, neglected if resume is true') #never used in the paper



params = parser.parse_args([])
print(params)

Namespace(alpha=2.0, dataset='NWPU', init_lr=0.001, lr_anneal='const', method='cs_protonet', mlp_dropout=0.7, model='ResNet12', n_episode=100, n_query=8, n_shot=1, num_classes=200, optim='Adam', os='linux', save_freq=10, start_epoch=0, stop_epoch=300, test_n_way=5, train_aug=True, train_n_way=5, warmup=False, wd='0')


In [23]:
np.random.seed(10)
print(params)    

# if params.os == 'linux':
#     base_file = os.path.join('./filelists', params.dataset, 'base_linux.json')
#     val_file = os.path.join('./filelists', params.dataset, 'val_linux.json')
# else:
base_file = os.path.join('./filelists', params.dataset, ('base_%s.json' % params.os))
val_file = os.path.join('./filelists', params.dataset,  ('val_%s.json' % params.os))
# novel_file = os.path.join('./filelists', params.dataset, 'novel.json')

image_size = 84
params.image_size = image_size

print('image_size = ', image_size)
print("n_query = ", params.n_query)
params.batch_size = (params.n_query + params.n_shot) * params.train_n_way

train_few_shot_params   = dict(n_way = params.train_n_way, n_support = params.n_shot, n_query=params.n_query) 
base_datamgr            = SetDataManager(image_size, n_episode=params.n_episode, params=params, **train_few_shot_params)
base_loader             = base_datamgr.get_data_loader( base_file , aug = params.train_aug )

test_few_shot_params     = dict(n_way = params.test_n_way, n_support = params.n_shot, n_query = params.n_query) 
val_datamgr             = SetDataManager(image_size, n_episode=params.n_episode, params=params, **test_few_shot_params)
val_loader              = val_datamgr.get_data_loader(val_file, aug = False) 

model = ProtoNetMulti(model_dict[params.model], params=params, **train_few_shot_params)
model = model.cuda()
params.checkpoint_dir = 'checkpoints/%s/%s_%s' %(params.dataset, params.model, params.method)
if params.train_aug:
    params.checkpoint_dir += '_aug'

params.checkpoint_dir += '_%s_lr%s_%s_wd%s' % (params.optim, str(params.init_lr), params.lr_anneal, str(params.wd))

if not params.method  in ['baseline', 'baseline++']: 
    params.checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)
params.model_dir = os.path.join(params.checkpoint_dir, 'model')
if not os.path.isdir(params.model_dir):
    os.makedirs(params.model_dir)
print('checkpoint_dir = ', params.checkpoint_dir)
start_epoch = params.start_epoch
stop_epoch = params.stop_epoch
max_acc = 0

Namespace(alpha=2.0, batch_size=45, checkpoint_dir='checkpoints/NWPU/ResNet12_cs_protonet_aug_Adam_lr0.001_const_wd0_5way_1shot', dataset='NWPU', init_lr=0.001, lr_anneal='const', method='cs_protonet', mlp_dropout=0.7, model='ResNet12', model_dir='checkpoints/NWPU/ResNet12_cs_protonet_aug_Adam_lr0.001_const_wd0_5way_1shot/model', n_episode=100, n_query=8, n_shot=1, num_classes=200, optim='Adam', os='linux', save_freq=10, start_epoch=0, stop_epoch=300, test_n_way=5, train_aug=True, train_n_way=5, warmup=False, wd='0')
image_size =  84
n_query =  8
checkpoint_dir =  checkpoints/NWPU/ResNet12_cs_protonet_aug_Adam_lr0.001_const_wd0_5way_1shot


### train contrastive learning

In [21]:
cs_mlp = nn.Sequential(
    nn.Linear(model.feat_dim, model.feat_dim),
    nn.ReLU(), 
    nn.Linear(model.feat_dim, 256)
).cuda()

cs_criterion = NTXentLoss(params.batch_size).cuda()
init_lr = params.init_lr
optimizer = torch.optim.Adam([
            {'params': model.parameters()},
            {'params': cs_mlp.parameters()}
    ], lr=init_lr)

timer = Timer()
print_freq = 50
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [56]:
print_freq = 10
for i, (x, _) in enumerate(base_loader, 1):
   # manifold mixup loss
    xi = x[0].cuda()    # n_way * (k_shot+query)
    xj = x[1].cuda()    # n_way * (k_shot+query)
    # manifold mixup loss
    inputs = xi
    targets = torch.from_numpy(np.repeat(range(model.n_way), model.n_query)).cuda()
    lam = np.random.beta(params.alpha, params.alpha)
    # 分割 x_support, x_query
    inputs = inputs.view(model.n_way, (model.n_support + model.n_query), 3, params.image_size, params.image_size)

    x_support   = inputs[:, :model.n_support]   # (n_way, n_support, )
    x_query     = inputs[:, model.n_support:]  
    x_support   = x_support.contiguous().view(model.n_way*model.n_support, 3, params.image_size, params.image_size)  # (n_way * n_support) 
    x_query     = x_query.contiguous().view(model.n_way*model.n_query, 3, params.image_size, params.image_size)  # (n_way * n_query) 
    # x_support：model forward：得到z_support
    z_support = model.forward(x_support)
    z_support = z_support.view(model.n_way, model.n_support, -1)
    # x_query 计算插值, 得到z_query, target_a, target_b
    z_query, target_a , target_b = model.forward(x_query, targets, mixup_hidden= True, lam = lam)  # (n_way * n_query, feat_dim)
    img_proto   = z_support.mean(1)   # (n_way, feat_dim)
    # 用z_pred，tareget_a, target_b计算loss
    dists = euclidean_dist(z_query, img_proto)  # (n_way*n_query, n_way)
    scores = -dists
    criterion = model.loss_fn
    mm_loss = mixup_criterion(criterion, scores, target_a, target_b, lam)

    mm_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    cum_mm_loss = mm_loss.item()
    avg_mm_loss = cum_mm_loss / float(i)
    # train_loss += loss.data.item()
    if i % print_freq == 0:
        print('Mixup Loss {:.3f}'.format(avg_mm_loss))

Mixup Loss 6.131
Mixup Loss 1.741
Mixup Loss 2.777
Mixup Loss 1.962
Mixup Loss 0.613
Mixup Loss 1.056
Mixup Loss 0.754
Mixup Loss 0.773
Mixup Loss 0.385
Mixup Loss 0.503


In [27]:
print(inputs.shape)

x_support.shape

torch.Size([5, 9, 3, 84, 84])


torch.Size([5, 1, 3, 84, 84])

In [51]:
x_support   = inputs[:, :model.n_support]   # (n_way, n_support, )
x_query     = inputs[:, model.n_support:]  
x_support   = x_support.contiguous().view(model.n_way*model.n_support, 3, params.image_size, params.image_size)  # (n_way * n_support) 
x_query     = x_query.contiguous().view(model.n_way*model.n_query, 3, params.image_size, params.image_size)  # (n_way * n_query) 
# x_support：model forward：得到z_support
z_support = model.forward(x_support)
z_support = z_support.view(model.n_way, model.n_support, -1)
# x_query 计算插值, 得到z_query, target_a, target_b
z_query, target_a , target_b = model.forward(x_query, targets, mixup_hidden= True, lam = lam)
img_proto   = z_support.mean(1)   # (n_way, feat_dim)
# 用z_pred，tareget_a, target_b计算loss
dists = euclidean_dist(z_query, img_proto)  # n_way
scores = -dists
criterion = model.loss_fn
loss = mixup_criterion(criterion, scores, target_a, target_b, lam)

In [52]:
print(x_support.shape)
print(z_support.shape)
print(z_query.shape)
print(img_proto.shape)
print(dists.shape)  # torch.Size([40, 5])
print(loss)

torch.Size([5, 3, 84, 84])
torch.Size([5, 1, 512])
torch.Size([40, 512])
torch.Size([5, 512])
torch.Size([40, 5])
tensor(33.5620, device='cuda:0', grad_fn=<AddBackward0>)


In [13]:
scores = model.compute_score(zi)
correct_this, count_this = model.correct(scores)
y_query = torch.from_numpy(np.repeat(range(model.n_way), model.n_query))
y_query = Variable(y_query.long().cuda())
clf_loss = model.loss_fn(scores, y_query)

In [15]:
# compute contrastive loss
zics = cs_mlp(zi)
zjcs = cs_mlp(zj)
cs_loss = cs_criterion(zics, zjcs)

In [18]:
print(correct_this, count_this)
print(clf_loss, cs_loss)

12 40
tensor(35.4992, device='cuda:0', grad_fn=<NllLossBackward>) tensor(4.2756, device='cuda:0', grad_fn=<DivBackward0>)


In [14]:
a = torch.randn(45,3,84,84)
a.shape
a = a.view(5, 9, -1, a.shape[-2], a.shape[-1])
support = a[:,:1]
query = a[:, 1:]
support.shape
query.shape
query.view()

torch.Size([5, 8, 3, 84, 84])

## train_s2m2

In [48]:
a = "hello wod"
b = "llo"
b in a

True