In [83]:
import argparse
import os.path as osp
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from utils import IOStream, set_gpu, ensure_path, Timer, count_acc,  Averager, compute_confidence_interval, euclidean_metric
from dataset.dataset import MiniImagenet
from dataset.sampler import CategoriesSampler
from model.protonet import ProtoNet
from tensorboardX import SummaryWriter
import shutil

%matplotlib inline
%load_ext autoreload
%autoreload 2

import warnings

warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
# Experiment configuration. Each entry is (flag, keyword options for add_argument);
# the flags are registered in the order listed.
parser = argparse.ArgumentParser()
for _flag, _options in (
    # task / episode layout
    ('--dataset', dict(default='mini_imagenet')),
    ('--distance', dict(default='l2')),
    ('--shot', dict(type=int, default=1)),
    ('--way', dict(type=int, default=5)),
    ('--query', dict(type=int, default=5)),
    ('--num_tasks', dict(type=int, default=5)),
    ('--n_batch', dict(type=int, default=100)),
    ('--max_epoch', dict(type=int, default=200)),
    ('--gpu', dict(default=0)),
    # optimizer (step_size and gamma are handed to the LR scheduler)
    ('--lr', dict(type=float, default=0.0001)),
    ('--step_size', dict(type=int, default=10)),
    ('--gamma', dict(type=float, default=0.2)),
    # model
    ('--init_weights', dict(type=str, default=None)),
    ('--model_type', dict(type=str, default='ConvNet', choices=['ConvNet', 'ResNet', 'AmdimNet'])),
    # io
    ('--save_path', dict(type=str, default='./MINI_ProtoNet_MINI_1shot_5way')),
):
    parser.add_argument(_flag, **_options)
del _flag, _options  # keep the notebook namespace free of loop temporaries

# An empty argv is parsed on purpose: inside the notebook every option takes its default.
args = parser.parse_args([])

# Echo the resolved configuration through the project's IOStream helper.
io = IOStream('log/run.log')
io.cprint(str(args))


Namespace(dataset='mini_imagenet', distance='l2', gamma=0.2, gpu=0, init_weights=None, lr=0.0001, max_epoch=200, model_type='ConvNet', n_batch=100, num_tasks=5, query=5, save_path='./MINI_ProtoNet_MINI_1shot_5way', shot=1, step_size=10, way=5)


In [84]:
def removeDir(dirPath):
    """Recursively delete the directory ``dirPath`` and everything inside it.

    Symbolic links found inside the tree are unlinked, never followed, so
    nothing outside ``dirPath`` is deleted. Dangling links and other
    non-directory entries are removed as well, so the final ``os.rmdir`` is
    not left facing a non-empty directory.

    Args:
        dirPath: Path of the directory to remove. A path that is not an
            existing directory is ignored.

    Raises:
        OSError: If an entry or the directory itself cannot be removed.
    """
    if not os.path.isdir(dirPath):
        return
    for name in os.listdir(dirPath):
        entryPath = os.path.join(dirPath, name)
        # isdir() follows links, so a link to a directory must be excluded
        # explicitly before recursing.
        if os.path.isdir(entryPath) and not os.path.islink(entryPath):
            removeDir(entryPath)
        else:
            # Regular files, links (valid or dangling) and any other
            # non-directory entry.
            os.remove(entryPath)
    os.rmdir(dirPath)

In [94]:
# removeDir('MINI_ProtoNet_MINI_1shot_5way/mini_imagenet-ConvNet-ProtoNet/')
# os.rmdir('MINI_ProtoNet_MINI_1shot_5way/mini_imagenet-ConvNet-ProtoNet/1_5_5/summary')
# osp.exists('MINI_ProtoNet_MINI_1shot_5way/mini_imagenet-ConvNet-ProtoNet/1_5_5/summary')
# os.mkdir('MINI_ProtoNet_MINI_1shot_5way/mini_imagenet-ConvNet-ProtoNet/1_5_5/summary')
# shutil.rmtree('MINI_ProtoNet_MINI_1shot_5way/mini_imagenet-ConvNet-ProtoNet/')

False

In [71]:
# Scratch cell: build the validation split and a class sampler over it, then inspect sizes.
# The trailing comments record values seen when the cell was run; only the last
# expression is shown as the cell output.
valset = MiniImagenet('val', args)
len(valset.data)  # 9600
# NOTE(review): this sampler is built with args.way * args.num_tasks, whereas the setup
# cell builds the validation sampler with args.way only — confirm which is intended.
val_sampler = CategoriesSampler(valset.label, args.n_batch, args.way * args.num_tasks, args.shot + args.query)
len(val_sampler.m_ind)   # 16
val_sampler.m_ind[0].shape  # 600
# for i in val_sampler:
#     print(i.shape)  # torch.Size([96])
#     break

torch.Size([600])

In [69]:
# Scratch cell: size of the training split and of the first index batch the sampler yields.
# NOTE(review): in the visible notebook, trainset and train_sampler are only assigned in
# the setup cell further down, so that cell has to be run before this one.
len(trainset.data)   #38400
for i in train_sampler:
    print(i.shape)   # torch.Size([150])  6*25
    break  # one batch is enough to see the shape

torch.Size([150])


In [95]:
# Run setup: output directories, data loaders, model, optimizer, LR schedule and
# optional initialisation from a checkpoint.
set_gpu(args.gpu)
save_path1 = '-'.join([args.dataset, args.model_type, 'ProtoNet'])
save_path2 = '_'.join([str(args.shot), str(args.query), str(args.way)])
run_subdir = osp.join(save_path1, save_path2)

# Append the run sub-directory only once: re-running this cell used to nest
# <save_path1>/<save_path2> inside args.save_path again on every execution.
_normalized_save_path = osp.normpath(args.save_path)
if not (_normalized_save_path == run_subdir
        or _normalized_save_path.endswith(osp.sep + run_subdir)):
    args.save_path = osp.join(args.save_path, run_subdir)
del _normalized_save_path

# NOTE(review): save_path1 is passed relative to the working directory, not joined
# onto args.save_path — confirm that this is what ensure_path should receive.
ensure_path(save_path1, remove=False)
ensure_path(args.save_path)

# Training episodes: way * num_tasks classes per batch, shot + query samples per class.
trainset = MiniImagenet('train', args)
train_sampler = CategoriesSampler(trainset.label, args.n_batch, args.way * args.num_tasks, args.shot + args.query)
train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, num_workers=2, pin_memory=True)

# only 16 classes in val, so no num_tasks
valset = MiniImagenet('val', args)
val_sampler = CategoriesSampler(valset.label, args.n_batch, args.way, args.shot + args.query)
val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, num_workers=2, pin_memory=True)

model = ProtoNet(args)
if args.model_type == 'ConvNet':
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
elif args.model_type == 'ResNet':
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0005)
else:
    # '--model_type' accepts further choices for which no optimizer is configured here.
    # Fail loudly instead of raising a NameError below or, in a re-run notebook, silently
    # reusing an optimizer that is still bound to a previous model.
    raise ValueError('no optimizer configured for model_type: {}'.format(args.model_type))

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

model_dict = model.state_dict()
# Optional pretrained initialisation. Original author's note: according to their
# tests this does not work.
if args.init_weights is not None:
    # SECURITY: torch.load unpickles the file — only load checkpoints from trusted sources.
    # map_location='cpu' lets a checkpoint saved on a GPU be read on a CPU-only machine;
    # load_state_dict below copies the values into the model's own tensors.
    model_detail = torch.load(args.init_weights, map_location='cpu')
    if 'params' in model_detail:
        pretrained_dict = model_detail['params']
        # Prefix keys with 'encoder.' and keep only those the model has
        # (this drops the weights for the FC layer).
        pretrained_dict = {'encoder.'+k: v for k, v in pretrained_dict.items()}
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)

    else:
        pretrained_dict = model_detail['model']
        # Strip the 'module.' prefix and keep only keys the model has.
        pretrained_dict = {k.replace('module.', ''): v for k, v in pretrained_dict.items() if k.replace('module.', '') in model_dict}
        # Original author's observation: pretrained_dict is empty at this point.
        model_dict.update(pretrained_dict)

model.load_state_dict(model_dict)

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

using gpu: 0


In [58]:
# log component
trlog = {}
trlog['args'] = vars(args)
trlog['train_loss'] = []
trlog['val_loss'] = []
trlog['train_acc'] = []
trlog['val_acc'] = []
trlog['max_acc'] = 0.0
trlog['max_acc_epoch'] = 0

timer = Timer()
global_count = 0
writer = SummaryWriter(logdir=osp.join(args.save_path, 'summary'))

# Smoke test of the training path: one forward pass on the first batch of the first
# epoch, then stop. No loss is computed and no optimizer step is taken here.
for epoch in range(1, args.max_epoch + 1):
    # NOTE(review): the scheduler is stepped before any optimizer.step(); recent PyTorch
    # versions expect the opposite order — revisit when this becomes the real training loop.
    lr_scheduler.step()
    model.train()
    tl = Averager()  # average train loss of the epoch
    ta = Averager()  # train acc of the epoch

    for i, batch in enumerate(train_loader, 1):
        global_count = global_count + 1
        # Only move the batch to the GPU when one exists, matching the label handling
        # below and the validation cell (this was an unconditional .cuda() before).
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]

        # The first num_tasks * shot * way items of the batch are passed to the model as
        # data_shot, the remainder as data_query.
        p = args.num_tasks * args.shot * args.way
        data_shot, data_query = data[:p], data[p:]

        # Labels 0 .. way*num_tasks-1, repeated args.query times.
        label = torch.arange(args.way*args.num_tasks).repeat(args.query)
        if torch.cuda.is_available():
            label = label.type(torch.cuda.LongTensor)
        else:
            label = label.type(torch.LongTensor)

        logits = model(data_shot, data_query)

        break  # first batch only

    break  # first epoch only

In [59]:
# train
# Shapes of the tensors left behind by the one-batch training pass. The trailing
# comments record the values seen; only the last expression is shown as cell output.
label.shape  # torch.Size([125])
logits.shape  # torch.Size([125, 25])
data.shape

torch.Size([150, 3, 84, 84])

In [104]:
# Validation smoke test: run the first validation batch through the model with
# gradient tracking disabled.
# NOTE(review): the output recorded for this cell is a RuntimeError about reshaping to
# [1, 25, -1]; presumably the model expects way * num_tasks classes per batch while this
# loader samples args.way — confirm against the model code.
with torch.no_grad():
    for i, batch in enumerate(val_loader, 1):
        if not torch.cuda.is_available():
            data = batch[0]
        else:
            data, _ = [tensor.cuda() for tensor in batch]

        # Labels 0 .. way-1, repeated args.query times, as a LongTensor on the active device.
        label = torch.arange(args.way).repeat(args.query)
        label = label.type(torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor)

        # Split the batch at shot * way: the leading part is data_shot, the rest data_query.
        p = args.shot * args.way
        data_shot = data[:p]
        data_query = data[p:]

        logits = model(data_shot, data_query)
#         loss = F.cross_entropy(logits, label)
        break  # first batch only

torch.Size([5, 64])


RuntimeError: shape '[1, 25, -1]' is invalid for input of size 320

In [105]:
# val
# Shapes after the one-batch validation pass. The trailing comments record the values
# seen; only the last expression (data_query.shape) is shown as cell output.
label.shape   # torch.Size([25])
logits.shape  # torch.Size([71, 25])
data_query.shape  # torch.Size([25, 3, 84, 84])
# p    # 5
# data.shape   # torch.Size([30, 3, 84, 84])


torch.Size([25, 3, 84, 84])

In [18]:
# Toy check of the distance-as-logit computation: negative sum of squared differences
# over the last dimension of two (2, 3, 4) integer tensors.
a = torch.arange(24).view(2, 3, 4)
a.shape
b = torch.arange(24, 48).view(2, 3, 4)
logits = (a - b).pow(2).sum(dim=2).neg()
logits.shape

torch.Size([2, 3])

In [25]:
# Toy cross-entropy example: random logits for 3 samples over 5 classes against random
# integer class targets (e.g. tensor([1, 3, 4])).
ip = torch.randn((3, 5), requires_grad=True)
target = torch.randint(low=0, high=5, size=(3,), dtype=torch.int64)
loss = F.cross_entropy(input=ip, target=target)
loss


tensor(1.4281, grad_fn=<NllLossBackward>)