In [1]:
import os
import pickle

import torch

from learner import Learner
from meta_learner import Meta_Learner
from mydataset.omniglot import Omniglot
from learning import *

In [2]:
class MyArgs:
    """Plain container for the hyperparameters of one few-shot experiment.

    Every constructor argument is stored, unchanged, as an instance attribute
    with the same name. No validation or conversion is performed.

    Args:
        epoch: number of epochs handed to ``train``.
        n_way: number of classes per task (the "N" in N-way).
        k_spt: support-set shots per class (forwarded as ``k_shot_spt``).
        k_qry: query-set shots per class (forwarded as ``k_shot_qry``).
        imgsz: side length images are resized to (forwarded as ``resize``).
        imgc: presumably the number of image channels -- TODO confirm
            against Meta_Learner.
        task_num: task count handed to ``train``.
        meta_lr: presumably the outer-loop learning rate -- TODO confirm.
        update_lr: presumably the inner-loop learning rate -- TODO confirm.
        update_step: presumably inner-loop updates during training --
            TODO confirm.
        update_step_test: presumably inner-loop updates during evaluation --
            TODO confirm.
    """

    def __init__(
        self,
        epoch=10,
        n_way=5,
        k_spt=1,
        k_qry=15,
        imgsz=84,
        imgc=3,
        task_num=4,
        meta_lr=1e-3,
        update_lr=0.01,
        update_step=5,
        update_step_test=10):
        # Attributes are assigned in signature order so that vars(self)
        # lists them in the same order as the constructor arguments.

        # Training length.
        self.epoch = epoch

        # Episode layout: classes per task, support shots, query shots.
        self.n_way, self.k_spt, self.k_qry = n_way, k_spt, k_qry

        # Input image geometry.
        self.imgsz, self.imgc = imgsz, imgc

        # Tasks per meta-update.
        self.task_num = task_num

        # Learning rates.
        self.meta_lr, self.update_lr = meta_lr, update_lr

        # Adaptation step counts.
        self.update_step, self.update_step_test = update_step, update_step_test

## 5-way 1-shot
+ query : 5-way 15-shot

In [3]:
# Hyperparameters for the 5-way 1-shot run. Every value is spelled out,
# including those that match the MyArgs defaults, so the cell documents the
# full experiment on its own.
args = MyArgs(
    epoch=10,
    # Episode layout: 5 classes, 1 support shot and 15 query shots per class.
    n_way=5,
    k_spt=1,
    k_qry=15,
    # Images are resized to 84x84; imgc=1 matches the 1-channel first conv.
    imgsz=84,
    imgc=1,
    task_num=8,
    meta_lr=1e-3,
    update_lr=0.01,
    update_step=5,
    update_step_test=10,
)

In [4]:
# Layer specification consumed by Meta_Learner.
#   conv2d     : [ch_out, ch_in, kernel_h, kernel_w, stride, padding]
#   max_pool2d : [kernel, stride, padding]
#   linear     : [out_features, in_features]
# Four identical conv -> relu -> bn -> max-pool stages; only the input channel
# count of the first stage (1) and the pooling stride of the last stage (1)
# differ. For 84x84 inputs the feature map entering `linear` works out to
# 32 x 5 x 5 (assuming floor-mode pooling -- TODO confirm in Learner).
config = [
    layer
    for ch_in, pool_stride in ((1, 2), (32, 2), (32, 2), (32, 1))
    for layer in (
        ("conv2d", [32, ch_in, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, pool_stride, 0]),
    )
] + [
    ("flatten", []),
    ("linear", [args.n_way, 32 * 5 * 5]),
]


# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, so without the parentheses 'cuda:0' was chosen even on CPU-only
# machines and the following `.to(device)` would fail there.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build the meta-learner from the hyperparameters and layer config, then move
# it to the selected device.
maml = Meta_Learner(args, config, task='classification').to(device)

print(maml)

Meta_Learner(
  (net): Learner(
    conv2d:(ch_in:1, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:800, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 32x1x3x3 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (4): Parameter containing: [torch

In [5]:
# Build the three Omniglot episode datasets in the order train, valid, test.
# They share the task layout (n_way / support shots / query shots) and the
# image size from `args`; only the split name and `batch_size` differ.
# `batch_size` here means the total number of episodes in the split.
train_db, valid_db, test_db = (
    Omniglot(
        "omniglot/",
        mode=split,
        n_way=args.n_way,
        k_shot_spt=args.k_spt,
        k_shot_qry=args.k_qry,
        batch_size=n_episodes,
        resize=args.imgsz,
    )
    for split, n_episodes in (("train", 5000), ("valid", 50), ("test", 100))
)

StanFord Dogs Dataset(train) :
	Batch_size : 5000
	Support sets : 5-way 1-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
StanFord Dogs Dataset(valid) :
	Batch_size : 50
	Support sets : 5-way 1-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
StanFord Dogs Dataset(test) :
	Batch_size : 100
	Support sets : 5-way 1-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84


In [6]:
# Meta-train the model; the returned value is kept as `history` and pickled
# in a later cell.
# NOTE(review): the trailing 32 and 64 are passed positionally and their
# meaning is defined by `train` in the `learning` module -- confirm and
# consider naming them.
history = train(
    maml,
    device,
    args.epoch,
    args.task_num,
    train_db,
    valid_db,
    32,
    64,
)

625it [09:57,  1.05it/s, epoch=1/10, loss=0.840079, acc=0.714]
Mean Train Loss : 1.122235
Mean Train acc  : 0.6012
Mean Valid Loss : 0.956055
Mean Valid acc  : 0.6523

625it [09:23,  1.11it/s, epoch=2/10, loss=0.982303, acc=0.661]
Mean Train Loss : 0.843689
Mean Train acc  : 0.7092
Mean Valid Loss : 0.771973
Mean Valid acc  : 0.7329

625it [09:28,  1.10it/s, epoch=3/10, loss=0.892296, acc=0.717]
Mean Train Loss : 0.764827
Mean Train acc  : 0.7379
Mean Valid Loss : 0.684082
Mean Valid acc  : 0.7651

625it [10:08,  1.03it/s, epoch=4/10, loss=0.661145, acc=0.805]
Mean Train Loss : 0.738936
Mean Train acc  : 0.7591
Mean Valid Loss : 0.618652
Mean Valid acc  : 0.7993

625it [08:46,  1.19it/s, epoch=5/10, loss=0.633064, acc=0.765]
Mean Train Loss : 0.657963
Mean Train acc  : 0.7838
Mean Valid Loss : 0.539551
Mean Valid acc  : 0.8208

625it [09:05,  1.15it/s, epoch=6/10, loss=0.675688, acc=0.789]
Mean Train Loss : 0.668048
Mean Train acc  : 0.7900
Mean Valid Loss : 0.524414
Mean Valid acc  : 

In [7]:
# Evaluate on the test split. `test_result` is indexed as:
#   [0] sequence of losses, [1] sequence of accuracies
#       (presumably one entry per adaptation step -- TODO confirm in evaluate),
#   [2] overall loss, [3] overall accuracy.
test_result = evaluate(maml, device, test_db, mode='test')

# Emit the four report lines with one print call, one line per item.
print(
    'loss mean : {}'.format([round(step_loss, 5) for step_loss in test_result[0]]),
    'acc mean  : {}'.format([round(step_acc, 3) for step_acc in test_result[1]]),
    'Total loss : {:.6f}'.format(test_result[2]),
    'Total acc  : {:.4f}'.format(test_result[3]),
    sep='\n',
)

100it [00:13,  7.28it/s, loss=0.861074, acc=0.761]
loss mean : [3.10352, 0.27612, 0.2428, 0.24048, 0.23865, 0.23718, 0.23608, 0.23499, 0.23401, 0.23315, 0.23242]
acc mean  : [0.202, 0.91, 0.92, 0.92, 0.92, 0.921, 0.921, 0.921, 0.921, 0.922, 0.922]
Total loss : 0.500977
Total acc  : 0.8545


In [8]:
# Persist the training history and the test results of the 5-way 1-shot run.
# Create the output directory first: opening a file inside a missing
# 'history/' folder raises FileNotFoundError, which would throw away the
# results of a long training run.
os.makedirs('history', exist_ok=True)

with open('history/omniglot_5-way_1-shot_train_history.pickle', 'wb') as f:
    pickle.dump(history, f, pickle.HIGHEST_PROTOCOL)

with open('history/omniglot_5-way_1-shot_test_result.pickle', 'wb') as f:
    pickle.dump(test_result, f, pickle.HIGHEST_PROTOCOL)

## 5-way 5-shot
+ query : 5-way 15-shot

In [3]:
# Hyperparameters for the 5-way 5-shot run. Every value is spelled out,
# including those that match the MyArgs defaults, so the cell documents the
# full experiment on its own.
args = MyArgs(
    epoch=10,
    # Episode layout: 5 classes, 5 support shots and 15 query shots per class.
    n_way=5,
    k_spt=5,
    k_qry=15,
    # Images are resized to 84x84; imgc=1 matches the 1-channel first conv.
    imgsz=84,
    imgc=1,
    task_num=8,
    meta_lr=1e-3,
    update_lr=0.01,
    update_step=5,
    update_step_test=10,
)

In [4]:
# Layer specification consumed by Meta_Learner.
#   conv2d     : [ch_out, ch_in, kernel_h, kernel_w, stride, padding]
#   max_pool2d : [kernel, stride, padding]
#   linear     : [out_features, in_features]
# Four identical conv -> relu -> bn -> max-pool stages; only the input channel
# count of the first stage (1) and the pooling stride of the last stage (1)
# differ. For 84x84 inputs the feature map entering `linear` works out to
# 32 x 5 x 5 (assuming floor-mode pooling -- TODO confirm in Learner).
config = [
    layer
    for ch_in, pool_stride in ((1, 2), (32, 2), (32, 2), (32, 1))
    for layer in (
        ("conv2d", [32, ch_in, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, pool_stride, 0]),
    )
] + [
    ("flatten", []),
    ("linear", [args.n_way, 32 * 5 * 5]),
]


# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, so without the parentheses 'cuda:0' was chosen even on CPU-only
# machines and the following `.to(device)` would fail there.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Build a fresh meta-learner for this run from the hyperparameters and layer
# config, then move it to the selected device.
maml = Meta_Learner(args, config, task='classification').to(device)

print(maml)

Meta_Learner(
  (net): Learner(
    conv2d:(ch_in:1, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:2, padding:0)
    conv2d:(ch_in:32, ch_out:32, k:3x3, stride:1, padding:0)
    relu:(True,)
    bn:(32,)
    max_pool2d:(k:2, stride:1, padding:0)
    flatten:()
    linear:(in:800, out:5)
    
    (vars): ParameterList(
        (0): Parameter containing: [torch.cuda.FloatTensor of size 32x1x3x3 (GPU 0)]
        (1): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (2): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (3): Parameter containing: [torch.cuda.FloatTensor of size 32 (GPU 0)]
        (4): Parameter containing: [torch

In [5]:
# Build the three Omniglot episode datasets in the order train, valid, test.
# They share the task layout (n_way / support shots / query shots) and the
# image size from `args`; only the split name and `batch_size` differ.
# `batch_size` here means the total number of episodes in the split.
train_db, valid_db, test_db = (
    Omniglot(
        "omniglot/",
        mode=split,
        n_way=args.n_way,
        k_shot_spt=args.k_spt,
        k_shot_qry=args.k_qry,
        batch_size=n_episodes,
        resize=args.imgsz,
    )
    for split, n_episodes in (("train", 5000), ("valid", 50), ("test", 100))
)

StanFord Dogs Dataset(train) :
	Batch_size : 5000
	Support sets : 5-way 5-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
StanFord Dogs Dataset(valid) :
	Batch_size : 50
	Support sets : 5-way 5-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84
StanFord Dogs Dataset(test) :
	Batch_size : 100
	Support sets : 5-way 5-shot
	Query sets : 5-way 15-shot
	Resizing Image : 84x84


In [6]:
# Meta-train the model; the returned value is kept as `history` and pickled
# in a later cell.
# NOTE(review): the trailing 32 and 64 are passed positionally and their
# meaning is defined by `train` in the `learning` module -- confirm and
# consider naming them.
history = train(
    maml,
    device,
    args.epoch,
    args.task_num,
    train_db,
    valid_db,
    32,
    64,
)

625it [09:26,  1.10it/s, epoch=1/10, loss=0.987737, acc=0.786]
Mean Train Loss : 0.970982
Mean Train acc  : 0.7196
Mean Valid Loss : 0.703613
Mean Valid acc  : 0.8032

625it [10:30,  1.01s/it, epoch=2/10, loss=1.068335, acc=0.822]
Mean Train Loss : 1.142785
Mean Train acc  : 0.8012
Mean Valid Loss : 0.750977
Mean Valid acc  : 0.8608

625it [10:35,  1.02s/it, epoch=3/10, loss=1.259367, acc=0.826]
Mean Train Loss : 1.299601
Mean Train acc  : 0.8248
Mean Valid Loss : 0.827148
Mean Valid acc  : 0.8818

625it [10:23,  1.00it/s, epoch=4/10, loss=1.321163, acc=0.845]
Mean Train Loss : 1.306313
Mean Train acc  : 0.8358
Mean Valid Loss : 0.813477
Mean Valid acc  : 0.8887

625it [10:17,  1.01it/s, epoch=5/10, loss=1.205620, acc=0.852]
Mean Train Loss : 1.322368
Mean Train acc  : 0.8422
Mean Valid Loss : 0.808105
Mean Valid acc  : 0.8955

625it [10:47,  1.04s/it, epoch=6/10, loss=1.179096, acc=0.826]
Mean Train Loss : 1.296361
Mean Train acc  : 0.8459
Mean Valid Loss : 0.786621
Mean Valid acc  : 

In [7]:
# Evaluate on the test split. `test_result` is indexed as:
#   [0] sequence of losses, [1] sequence of accuracies
#       (presumably one entry per adaptation step -- TODO confirm in evaluate),
#   [2] overall loss, [3] overall accuracy.
test_result = evaluate(maml, device, test_db, mode='test')

# Emit the four report lines with one print call, one line per item.
print(
    'loss mean : {}'.format([round(step_loss, 5) for step_loss in test_result[0]]),
    'acc mean  : {}'.format([round(step_acc, 3) for step_acc in test_result[1]]),
    'Total loss : {:.6f}'.format(test_result[2]),
    'Total acc  : {:.4f}'.format(test_result[3]),
    sep='\n',
)

100it [00:15,  6.60it/s, loss=0.783721, acc=0.913]
loss mean : [6.85938, 0.08521, 0.06696, 0.06067, 0.05936, 0.05832, 0.05746, 0.0567, 0.05606, 0.05551, 0.05499]
acc mean  : [0.202, 0.971, 0.978, 0.98, 0.98, 0.98, 0.98, 0.981, 0.981, 0.981, 0.981]
Total loss : 0.679199
Total acc  : 0.9087


In [8]:
# Persist the training history and the test results of the 5-way 5-shot run.
# Create the output directory first: opening a file inside a missing
# 'history/' folder raises FileNotFoundError, which would throw away the
# results of a long training run.
os.makedirs('history', exist_ok=True)

with open('history/omniglot_5-way_5-shot_train_history.pickle', 'wb') as f:
    pickle.dump(history, f, pickle.HIGHEST_PROTOCOL)

with open('history/omniglot_5-way_5-shot_test_result.pickle', 'wb') as f:
    pickle.dump(test_result, f, pickle.HIGHEST_PROTOCOL)