In [1]:
import os
import pickle

import torch

from learner import Learner
from meta_learner import Meta_Learner
from mydataset.miniimagenet import MiniImageNet
from learning import *

In [2]:
class MyArgs:
    """Lightweight stand-in for an argparse namespace holding the MAML settings.

    Every constructor argument is stored unchanged as an attribute of the same
    name (in declaration order), so ``MyArgs(k_spt=5).k_spt == 5``.

    Fields:
        epoch: number of training epochs passed to ``train``.
        n_way: classes per episode (the "N-way" of N-way K-shot).
        k_spt: support images per class (the "K-shot").
        k_qry: query images per class.
        imgsz: side length images are resized to (passed as ``resize``).
        imgc: presumably the number of image channels — TODO confirm.
        task_num: tasks per meta-update, as passed to ``train``.
        meta_lr, update_lr, update_step, update_step_test: consumed by
            ``Meta_Learner`` (not visible here); presumably outer/inner learning
            rates and inner-loop step counts for training/testing.
    """

    def __init__(
        self,
        epoch=10,
        n_way=5,
        k_spt=1,
        k_qry=15,
        imgsz=84,
        imgc=3,
        task_num=4,
        meta_lr=1e-3,
        update_lr=0.01,
        update_step=5,
        update_step_test=10,
    ):
        # Table of (attribute name, value); order matches the signature so the
        # instance __dict__ keeps the same ordering as an explicit assignment list.
        fields = (
            ("epoch", epoch),
            ("n_way", n_way),
            ("k_spt", k_spt),
            ("k_qry", k_qry),
            ("imgsz", imgsz),
            ("imgc", imgc),
            ("task_num", task_num),
            ("meta_lr", meta_lr),
            ("update_lr", update_lr),
            ("update_step", update_step),
            ("update_step_test", update_step_test),
        )
        for attr_name, attr_value in fields:
            setattr(self, attr_name, attr_value)

## Training setup: 5-way 1-shot
+ Support set: 5-way 1-shot (1 image per class); query set: 5-way 15-shot (15 images per class)

In [3]:
args = MyArgs(
    # Episode shape: 5 classes, 1 support image and 15 query images per class.
    n_way=5,
    k_spt=1,
    k_qry=15,
    # Inputs are resized to 84x84; imgc presumably = RGB channels.
    imgsz=84,
    imgc=3,
    # Optimisation schedule (task_num tasks per meta-update).
    epoch=10,
    task_num=8,
    meta_lr=1e-3,
    update_lr=0.01,
    update_step=5,
    update_step_test=10,
)

In [4]:
# Conv-4 backbone for Learner, as (layer type, [params]) pairs. Param order, as
# confirmed by the printed model below: conv2d [ch_out, ch_in, k, k, stride,
# padding], max_pool2d [kernel, stride, padding], linear [out, in].
#
# The flatten size depends on the input resolution. Each block is a 3x3
# unpadded conv (-2 px) followed by a 2x2 max-pool (stride 2; the last one uses
# stride 1). For imgsz=84 that is 84 -> 82 -> 41 -> 39 -> 19 -> 17 -> 8 -> 6 -> 5,
# i.e. 32*5*5 = 800 features.
feat_hw = args.imgsz
for pool_stride in (2, 2, 2, 1):
    feat_hw = (feat_hw - 2 - 2) // pool_stride + 1
assert feat_hw > 0, "imgsz={} is too small for this backbone".format(args.imgsz)

config = [
    ("conv2d", [32, 3, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [32, 32, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [32, 32, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [32, 32, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 1, 0]),
    ("flatten", []),
    ("linear", [args.n_way, 32 * feat_hw * feat_hw]),
]

# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would select cuda:0 even on a CPU-only machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
maml = Meta_Learner(args, config, task='classification').to(device)

print(maml)

Meta_Learner(
  (net): Learner(
    conv2d:(ch_in:3, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:800, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 32x3x3x3 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (4): Parameter containing: [torch

In [5]:
# batchsz here means total episode number
# All three splits share the episode shape and image size taken from `args`;
# they differ only in split name and number of sampled episodes (batch_size).
# The generator builds them in train -> valid -> test order.
train_db, valid_db, test_db = (
    MiniImageNet(
        "miniimagenet/",
        mode=split,
        n_way=args.n_way,
        k_shot_spt=args.k_spt,
        k_shot_qry=args.k_qry,
        batch_size=n_episodes,
        resize=args.imgsz,
    )
    for split, n_episodes in (("train", 10000), ("valid", 80), ("test", 100))
)

Mini_ImageNet Dataset(train) :
	Batch_size : 10000
	Support sets : 5-way 1-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
Mini_ImageNet Dataset(valid) :
	Batch_size : 80
	Support sets : 5-way 1-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
Mini_ImageNet Dataset(test) :
	Batch_size : 100
	Support sets : 5-way 1-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84


In [6]:
# Meta-train `maml` for args.epoch epochs, args.task_num tasks per step
# (10000 train episodes / 8 tasks = the 1250 iterations per epoch shown below),
# reporting mean train/valid loss and accuracy after each epoch.
# NOTE(review): the trailing 16 and 32 are positional arguments of learning.train
# whose meaning is not visible here — confirm and pass them by keyword.
history = train(maml, device, args.epoch, args.task_num, train_db, valid_db, 16, 32)

1250it [1:28:17,  4.24s/it, epoch=1/10, loss=1.449661, acc=0.374]
Mean Train Loss : 1.536690
Mean Train acc  : 0.3400
Mean Valid Loss : 1.583008
Mean Valid acc  : 0.3198

1250it [1:12:20,  3.47s/it, epoch=2/10, loss=1.406114, acc=0.406]
Mean Train Loss : 1.448813
Mean Train acc  : 0.3772
Mean Valid Loss : 1.488281
Mean Valid acc  : 0.3611

1250it [1:12:52,  3.50s/it, epoch=3/10, loss=1.289559, acc=0.480]
Mean Train Loss : 1.400545
Mean Train acc  : 0.4029
Mean Valid Loss : 1.465820
Mean Valid acc  : 0.3750

1250it [1:12:33,  3.48s/it, epoch=4/10, loss=1.373032, acc=0.410]
Mean Train Loss : 1.366641
Mean Train acc  : 0.4228
Mean Valid Loss : 1.451172
Mean Valid acc  : 0.3850

1250it [1:12:01,  3.46s/it, epoch=5/10, loss=1.302749, acc=0.451]
Mean Train Loss : 1.334590
Mean Train acc  : 0.4374
Mean Valid Loss : 1.433594
Mean Valid acc  : 0.3943

1250it [1:11:05,  3.41s/it, epoch=6/10, loss=1.290852, acc=0.457]
Mean Train Loss : 1.296346
Mean Train acc  : 0.4552
Mean Valid Loss : 1.412109


In [7]:
test_result = evaluate(maml, device, test_db, mode='test')

# test_result[0] / [1] are sequences of mean query loss / accuracy (11 values
# in the output below, apparently one per fine-tuning step — TODO confirm);
# test_result[2] / [3] are the overall loss / accuracy.
loss_by_step = [round(v, 5) for v in test_result[0]]
print(f'loss mean : {loss_by_step}')
acc_by_step = [round(v, 3) for v in test_result[1]]
print(f'acc mean  : {acc_by_step}')
print(f'Total loss : {test_result[2]:.6f}')
print(f'Total acc  : {test_result[3]:.4f}')

100it [01:40,  1.01s/it, loss=1.231380, acc=0.513]
loss mean : [1.65234, 1.36328, 1.3584, 1.36035, 1.36328, 1.36523, 1.36816, 1.37012, 1.37305, 1.375, 1.37695]
acc mean  : [0.213, 0.423, 0.429, 0.431, 0.432, 0.433, 0.434, 0.434, 0.435, 0.435, 0.435]
Total loss : 1.393555
Total acc  : 0.4121


In [11]:
# Persist the training history and test metrics for later analysis.
# File names are derived from `args`, so this cell stays correct if the episode
# settings change; with the settings above this writes
# history/miniimagenet_5-way_1-shot_{train_history,test_result}.pickle.
run_tag = 'miniimagenet_{}-way_{}-shot'.format(args.n_way, args.k_spt)
# open(..., 'wb') does not create missing parent directories, so make sure
# history/ exists (e.g. on a fresh checkout).
os.makedirs('history', exist_ok=True)

with open(os.path.join('history', run_tag + '_train_history.pickle'), 'wb') as f:
    pickle.dump(history, f, pickle.HIGHEST_PROTOCOL)

with open(os.path.join('history', run_tag + '_test_result.pickle'), 'wb') as f:
    pickle.dump(test_result, f, pickle.HIGHEST_PROTOCOL)

## Training setup: 5-way 5-shot
+ Support set: 5-way 5-shot (5 images per class); query set: 5-way 15-shot (15 images per class)

In [3]:
args = MyArgs(
    # Episode shape: 5 classes, 5 support images and 15 query images per class.
    n_way=5,
    k_spt=5,
    k_qry=15,
    # Inputs are resized to 84x84; imgc presumably = RGB channels.
    imgsz=84,
    imgc=3,
    # Optimisation schedule (task_num tasks per meta-update).
    epoch=10,
    task_num=8,
    meta_lr=1e-3,
    update_lr=0.01,
    update_step=5,
    update_step_test=10,
)

In [4]:
# Conv-4 backbone for Learner, as (layer type, [params]) pairs. Param order, as
# confirmed by the printed model below: conv2d [ch_out, ch_in, k, k, stride,
# padding], max_pool2d [kernel, stride, padding], linear [out, in].
#
# The flatten size depends on the input resolution. Each block is a 3x3
# unpadded conv (-2 px) followed by a 2x2 max-pool (stride 2; the last one uses
# stride 1). For imgsz=84 that is 84 -> 82 -> 41 -> 39 -> 19 -> 17 -> 8 -> 6 -> 5,
# i.e. 32*5*5 = 800 features.
feat_hw = args.imgsz
for pool_stride in (2, 2, 2, 1):
    feat_hw = (feat_hw - 2 - 2) // pool_stride + 1
assert feat_hw > 0, "imgsz={} is too small for this backbone".format(args.imgsz)

config = [
    ("conv2d", [32, 3, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [32, 32, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [32, 32, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [32, 32, 3, 3, 1, 0]),
    ("relu", [True]),
    ("bn", [32]),
    ("max_pool2d", [2, 1, 0]),
    ("flatten", []),
    ("linear", [args.n_way, 32 * feat_hw * feat_hw]),
]

# torch.cuda.is_available must be *called*: the bare function object is always
# truthy, which would select cuda:0 even on a CPU-only machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
maml = Meta_Learner(args, config, task='classification').to(device)

print(maml)

Meta_Learner(
  (net): Learner(
    conv2d:(ch_in:3, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:800, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 32x3x3x3 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (4): Parameter containing: [torch

In [5]:
# batchsz here means total episode number
# All three splits share the episode shape and image size taken from `args`;
# they differ only in split name and number of sampled episodes (batch_size).
# The generator builds them in train -> valid -> test order.
train_db, valid_db, test_db = (
    MiniImageNet(
        "miniimagenet/",
        mode=split,
        n_way=args.n_way,
        k_shot_spt=args.k_spt,
        k_shot_qry=args.k_qry,
        batch_size=n_episodes,
        resize=args.imgsz,
    )
    for split, n_episodes in (("train", 10000), ("valid", 80), ("test", 100))
)

Mini_ImageNet Dataset(train) :
	Batch_size : 10000
	Support sets : 5-way 5-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
Mini_ImageNet Dataset(valid) :
	Batch_size : 80
	Support sets : 5-way 5-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
Mini_ImageNet Dataset(test) :
	Batch_size : 100
	Support sets : 5-way 5-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84


In [6]:
# Meta-train `maml` for args.epoch epochs, args.task_num tasks per step
# (10000 train episodes / 8 tasks = the 1250 iterations per epoch shown below),
# reporting mean train/valid loss and accuracy after each epoch.
# NOTE(review): the trailing 16 and 32 are positional arguments of learning.train
# whose meaning is not visible here — confirm and pass them by keyword.
history = train(maml, device, args.epoch, args.task_num, train_db, valid_db, 16, 32)

1250it [1:29:46,  4.31s/it, epoch=1/10, loss=1.165166, acc=0.536]
Mean Train Loss : 1.365055
Mean Train acc  : 0.4307
Mean Valid Loss : 1.419922
Mean Valid acc  : 0.4050

1250it [1:28:18,  4.24s/it, epoch=2/10, loss=1.049421, acc=0.578]
Mean Train Loss : 1.188116
Mean Train acc  : 0.5173
Mean Valid Loss : 1.272461
Mean Valid acc  : 0.4807

1250it [1:29:42,  4.31s/it, epoch=3/10, loss=0.946231, acc=0.626]
Mean Train Loss : 1.091933
Mean Train acc  : 0.5591
Mean Valid Loss : 1.215820
Mean Valid acc  : 0.5078

1250it [1:28:18,  4.24s/it, epoch=4/10, loss=1.084442, acc=0.557]
Mean Train Loss : 1.021567
Mean Train acc  : 0.5861
Mean Valid Loss : 1.188477
Mean Valid acc  : 0.5215

1250it [1:27:27,  4.20s/it, epoch=5/10, loss=1.155469, acc=0.544]
Mean Train Loss : 0.992494
Mean Train acc  : 0.5982
Mean Valid Loss : 1.175781
Mean Valid acc  : 0.5293

1250it [1:27:28,  4.20s/it, epoch=6/10, loss=0.883692, acc=0.628]
Mean Train Loss : 0.959545
Mean Train acc  : 0.6123
Mean Valid Loss : 1.174805


In [7]:
test_result = evaluate(maml, device, test_db, mode='test')

# test_result[0] / [1] are sequences of mean query loss / accuracy (11 values
# in the output below, apparently one per fine-tuning step — TODO confirm);
# test_result[2] / [3] are the overall loss / accuracy.
loss_by_step = [round(v, 5) for v in test_result[0]]
print(f'loss mean : {loss_by_step}')
acc_by_step = [round(v, 3) for v in test_result[1]]
print(f'acc mean  : {acc_by_step}')
print(f'Total loss : {test_result[2]:.6f}')
print(f'Total acc  : {test_result[3]:.4f}')

100it [01:30,  1.11it/s, loss=1.374872, acc=0.450]
loss mean : [2.05664, 1.10645, 1.04785, 1.02539, 1.02539, 1.02734, 1.0293, 1.03125, 1.0332, 1.03613, 1.03809]
acc mean  : [0.195, 0.564, 0.58, 0.594, 0.594, 0.597, 0.598, 0.598, 0.599, 0.599, 0.6]
Total loss : 1.132812
Total acc  : 0.5562


In [8]:
# Persist the training history and test metrics for later analysis.
# File names are derived from `args`, so this cell stays correct if the episode
# settings change; with the settings above this writes
# history/miniimagenet_5-way_5-shot_{train_history,test_result}.pickle.
run_tag = 'miniimagenet_{}-way_{}-shot'.format(args.n_way, args.k_spt)
# open(..., 'wb') does not create missing parent directories, so make sure
# history/ exists (e.g. on a fresh checkout).
os.makedirs('history', exist_ok=True)

with open(os.path.join('history', run_tag + '_train_history.pickle'), 'wb') as f:
    pickle.dump(history, f, pickle.HIGHEST_PROTOCOL)

with open(os.path.join('history', run_tag + '_test_result.pickle'), 'wb') as f:
    pickle.dump(test_result, f, pickle.HIGHEST_PROTOCOL)