In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch import nn, optim
from torch.utils.data import DataLoader
from utils.dataloader import MiniImagenet
from common.cnn_models import FewshotClassifier
from maml.models import MAML
from maml.utils import fast_adapt
from tqdm import tqdm

In [2]:
torch.manual_seed(777)
torch.cuda.manual_seed_all(777)
np.random.seed(777)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
N = 5
K = 1
Q = 15
batch_size = 8
meta_lr=0.003
fast_lr=0.5
adaptation_steps=1

In [4]:
root_path = './datasets/miniimagenet/pkl_file/' 
train_dataset = MiniImagenet(path=root_path, N=N, K=K, Q=Q, mode='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size,\
                          shuffle=True, num_workers=1)
val_dataset = MiniImagenet(path=root_path, N=N, K=K, Q=Q, mode='validation', total_iter=1000)
val_loader = DataLoader(val_dataset, batch_size=batch_size,\
                        shuffle=True, num_workers=1)
test_dataset = MiniImagenet(path=root_path, N=N, K=K, Q=Q, mode='test', total_iter=5000)
test_loader = DataLoader(test_dataset, batch_size=batch_size,\
                         shuffle=True, num_workers=1)

In [5]:
# Create model
model = FewshotClassifier(output_size=N)
model.to(device)
maml = MAML(model, lr=fast_lr, first_order=False)
maml.to(device)
opt = optim.Adam(maml.parameters(), meta_lr)
loss = nn.CrossEntropyLoss(reduction='mean')

In [6]:
iteration = 0
for task_batch in train_loader:
    torch.cuda.empty_cache()
    opt.zero_grad()
    meta_train_error = []
    meta_train_accuracy = []
    for batch in zip(*task_batch):
        learner = maml.clone()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           device)
        evaluation_error.backward()
        meta_train_error.append(evaluation_error.item())
        meta_train_accuracy.append(evaluation_accuracy.item())
        iteration += 1 
    # Average the accumulated gradients and optimize
    for p in maml.parameters():
        p.grad.data.mul_(1.0 / len(task_batch))
    opt.step()

    if iteration % 1000 == 0:
        print('Iteration', iteration)
        print('Meta Train Error : ', round(np.mean(meta_train_error),4),\
              " std : ", round(np.std(meta_train_error), 4))
        print('Meta Train Accuracy : ', round(np.mean(meta_train_accuracy), 4),\
              " std : ", round(np.std(meta_train_accuracy), 4))
    
    if iteration % 5000 == 0:
        # Compute meta-validation loss
        meta_valid_error = []
        meta_valid_accuracy = []
        for task_batch in val_loader:
            for batch in zip(*task_batch):
                learner = maml.clone()
                evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                                   learner,
                                                                   loss,
                                                                   adaptation_steps,
                                                                   device)
                meta_valid_error.append(evaluation_error.item())
                meta_valid_accuracy.append(evaluation_accuracy.item())
        print('Valid Result')
        print('Meta Valid Error : ', round(np.mean(meta_valid_error),4),\
              " std : ", round(np.std(meta_valid_error), 4))
        print('Meta Valid Accuracy : ', round(np.mean(meta_valid_accuracy), 4),\
              " std : ", round(np.std(meta_valid_accuracy), 4))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  2%|▏         | 125/7500 [00:32<31:02,  3.96it/s]

Iteration 1000
Meta Train Error 3.548592448234558
Meta Train Error 0.4633333422243595


  3%|▎         | 250/7500 [01:03<31:33,  3.83it/s]

Iteration 2000
Meta Train Error 3.1803328096866608
Meta Train Error 0.5100000128149986


  5%|▌         | 375/7500 [01:35<29:59,  3.96it/s]

Iteration 3000
Meta Train Error 3.2033116221427917
Meta Train Error 0.5266666822135448


  7%|▋         | 500/7500 [02:07<29:39,  3.93it/s]

Iteration 4000
Meta Train Error 3.0240584313869476
Meta Train Error 0.666666679084301


  8%|▊         | 624/7500 [02:38<29:05,  3.94it/s]

Iteration 5000
Meta Train Error 3.1707969307899475
Meta Train Error 0.5433333441615105


  8%|▊         | 625/7500 [02:59<12:31:18,  6.56s/it]

Valid Result
Meta Valid Error :  1.567244961977005
Meta Valid Accuracy :  0.2865200060456991  std :  0.08335438652692144


 10%|█         | 750/7500 [03:31<28:42,  3.92it/s]   

Iteration 6000
Meta Train Error 3.0748651027679443
Meta Train Error 0.6133333407342434


 12%|█▏        | 875/7500 [04:03<28:02,  3.94it/s]

Iteration 7000
Meta Train Error 3.0864515602588654
Meta Train Error 0.676666684448719


 13%|█▎        | 1000/7500 [04:35<27:34,  3.93it/s]

Iteration 8000
Meta Train Error 2.8263639509677887
Meta Train Error 0.7433333396911621


 15%|█▌        | 1125/7500 [05:07<27:06,  3.92it/s]

Iteration 9000
Meta Train Error 2.853819817304611
Meta Train Error 0.7633333541452885


 17%|█▋        | 1249/7500 [05:38<26:22,  3.95it/s]

Iteration 10000
Meta Train Error 2.9128872752189636
Meta Train Error 0.783333346247673


 17%|█▋        | 1250/7500 [06:00<11:16:29,  6.49s/it]

Valid Result
Meta Valid Error :  1.5280929988622665
Meta Valid Accuracy :  0.3259066734910011  std :  0.10500354519598899


 18%|█▊        | 1375/7500 [06:31<25:51,  3.95it/s]   

Iteration 11000
Meta Train Error 2.799357682466507
Meta Train Error 0.7433333545923233


 20%|██        | 1500/7500 [07:03<25:18,  3.95it/s]

Iteration 12000
Meta Train Error 2.7568684220314026
Meta Train Error 0.856666699051857


 22%|██▏       | 1625/7500 [07:35<24:56,  3.93it/s]

Iteration 13000
Meta Train Error 2.811398059129715
Meta Train Error 0.8333333507180214


 23%|██▎       | 1750/7500 [08:07<24:33,  3.90it/s]

Iteration 14000
Meta Train Error 2.6892464458942413
Meta Train Error 0.8833333551883698


 25%|██▍       | 1874/7500 [08:39<23:47,  3.94it/s]

Iteration 15000
Meta Train Error 2.8890113532543182
Meta Train Error 0.7600000128149986


 25%|██▌       | 1875/7500 [09:00<10:16:28,  6.58s/it]

Valid Result
Meta Valid Error :  1.4711107640266419
Meta Valid Accuracy :  0.3681066748648882  std :  0.12669357795716024


 27%|██▋       | 2000/7500 [09:32<23:10,  3.96it/s]   

Iteration 16000
Meta Train Error 2.695343554019928
Meta Train Error 0.9333333522081375


 28%|██▊       | 2125/7500 [10:03<22:42,  3.95it/s]

Iteration 17000
Meta Train Error 2.7572497129440308
Meta Train Error 0.8333333693444729


 30%|███       | 2250/7500 [10:35<22:06,  3.96it/s]

Iteration 18000
Meta Train Error 2.773266524076462
Meta Train Error 0.8300000093877316


 32%|███▏      | 2375/7500 [11:07<21:59,  3.88it/s]

Iteration 19000
Meta Train Error 2.75663298368454
Meta Train Error 0.8600000217556953


 33%|███▎      | 2499/7500 [11:39<20:55,  3.98it/s]

Iteration 20000
Meta Train Error 2.608903229236603
Meta Train Error 0.9600000232458115


 33%|███▎      | 2500/7500 [12:00<9:00:26,  6.49s/it]

Valid Result
Meta Valid Error :  1.447327131986618
Meta Valid Accuracy :  0.3875333417057991  std :  0.1401058516728079


 35%|███▌      | 2625/7500 [12:32<20:42,  3.92it/s]  

Iteration 21000
Meta Train Error 2.607917070388794
Meta Train Error 0.9066666960716248


 37%|███▋      | 2750/7500 [13:03<20:02,  3.95it/s]

Iteration 22000
Meta Train Error 2.326435685157776
Meta Train Error 1.1066666841506958


 38%|███▊      | 2875/7500 [13:35<19:25,  3.97it/s]

Iteration 23000
Meta Train Error 2.5056224167346954
Meta Train Error 0.976666696369648


 40%|████      | 3000/7500 [14:07<18:59,  3.95it/s]

Iteration 24000
Meta Train Error 2.4532414972782135
Meta Train Error 1.003333367407322


 42%|████▏     | 3124/7500 [14:38<18:37,  3.92it/s]

Iteration 25000
Meta Train Error 2.3961819857358932
Meta Train Error 1.0233333557844162


 42%|████▏     | 3125/7500 [14:59<7:57:24,  6.55s/it]

Valid Result
Meta Valid Error :  1.4057217273116112
Meta Valid Accuracy :  0.40802667586505414  std :  0.1561622334187434


 43%|████▎     | 3250/7500 [15:32<18:10,  3.90it/s]  

Iteration 26000
Meta Train Error 2.728569447994232
Meta Train Error 0.8933333456516266


 45%|████▌     | 3375/7500 [16:03<17:31,  3.92it/s]

Iteration 27000
Meta Train Error 2.6625219583511353
Meta Train Error 0.9133333414793015


 47%|████▋     | 3500/7500 [16:35<16:50,  3.96it/s]

Iteration 28000
Meta Train Error 2.811162233352661
Meta Train Error 0.8200000040233135


 48%|████▊     | 3625/7500 [17:07<16:45,  3.86it/s]

Iteration 29000
Meta Train Error 2.512761741876602
Meta Train Error 0.9933333545923233


 50%|████▉     | 3749/7500 [17:38<15:49,  3.95it/s]

Iteration 30000
Meta Train Error 2.5350708067417145
Meta Train Error 0.976666696369648


 50%|█████     | 3750/7500 [17:59<6:44:20,  6.47s/it]

Valid Result
Meta Valid Error :  1.3934746989011764
Meta Valid Accuracy :  0.4167733429968357  std :  0.16350994787600182


 52%|█████▏    | 3875/7500 [18:31<15:24,  3.92it/s]  

Iteration 31000
Meta Train Error 2.647527754306793
Meta Train Error 0.9400000274181366


 53%|█████▎    | 4000/7500 [19:03<14:49,  3.94it/s]

Iteration 32000
Meta Train Error 2.722761854529381
Meta Train Error 0.8766666911542416


 55%|█████▌    | 4125/7500 [19:35<14:13,  3.95it/s]

Iteration 33000
Meta Train Error 2.4940733462572098
Meta Train Error 0.9133333414793015


 57%|█████▋    | 4250/7500 [20:07<13:57,  3.88it/s]

Iteration 34000
Meta Train Error 2.432102531194687
Meta Train Error 1.0366666838526726


 58%|█████▊    | 4374/7500 [20:39<13:12,  3.94it/s]

Iteration 35000
Meta Train Error 2.285119727253914
Meta Train Error 1.0766666904091835


 58%|█████▊    | 4375/7500 [20:59<5:33:05,  6.40s/it]

Valid Result
Meta Valid Error :  1.3730271518826485
Meta Valid Accuracy :  0.42472000969946383  std :  0.1593074577542633


 60%|██████    | 4500/7500 [21:31<12:49,  3.90it/s]  

Iteration 36000
Meta Train Error 2.619266599416733
Meta Train Error 0.940000019967556


 62%|██████▏   | 4625/7500 [22:03<12:04,  3.97it/s]

Iteration 37000
Meta Train Error 2.4627963602542877
Meta Train Error 0.9866666868329048


 63%|██████▎   | 4750/7500 [22:35<11:40,  3.92it/s]

Iteration 38000
Meta Train Error 2.237085923552513
Meta Train Error 1.1166667118668556


 65%|██████▌   | 4875/7500 [23:07<11:09,  3.92it/s]

Iteration 39000
Meta Train Error 2.425740748643875
Meta Train Error 1.0166666805744171


 67%|██████▋   | 4999/7500 [23:38<10:29,  3.98it/s]

Iteration 40000
Meta Train Error 2.13297376036644
Meta Train Error 1.2466666996479034


 67%|██████▋   | 5000/7500 [23:59<4:21:55,  6.29s/it]

Valid Result
Meta Valid Error :  1.3613393127918243
Meta Valid Accuracy :  0.4357866762429476  std :  0.17729712531935393


 68%|██████▊   | 5125/7500 [24:30<10:08,  3.91it/s]  

Iteration 41000
Meta Train Error 2.5209033340215683
Meta Train Error 0.9133333340287209


 70%|███████   | 5250/7500 [25:02<09:16,  4.04it/s]

Iteration 42000
Meta Train Error 2.299491301178932
Meta Train Error 1.0800000131130219


 72%|███████▏  | 5375/7500 [25:33<08:46,  4.04it/s]

Iteration 43000
Meta Train Error 2.225634679198265
Meta Train Error 1.1166666895151138


 73%|███████▎  | 5500/7500 [26:04<08:23,  3.97it/s]

Iteration 44000
Meta Train Error 2.1956450641155243
Meta Train Error 1.1433333531022072


 75%|███████▍  | 5624/7500 [26:35<07:52,  3.97it/s]

Iteration 45000
Meta Train Error 2.16292105615139
Meta Train Error 1.1533333659172058


 75%|███████▌  | 5625/7500 [26:56<3:21:21,  6.44s/it]

Valid Result
Meta Valid Error :  1.3449099182486535
Meta Valid Accuracy :  0.44149334298074244  std :  0.16912230682008192


 77%|███████▋  | 5750/7500 [27:27<07:32,  3.86it/s]  

Iteration 46000
Meta Train Error 2.1196274161338806
Meta Train Error 1.21666669100523


 78%|███████▊  | 5875/7500 [27:58<06:39,  4.06it/s]

Iteration 47000
Meta Train Error 2.2675233483314514
Meta Train Error 1.0466666966676712


 80%|████████  | 6000/7500 [28:30<06:15,  3.99it/s]

Iteration 48000
Meta Train Error 2.2813019901514053
Meta Train Error 1.0866666957736015


 82%|████████▏ | 6125/7500 [29:00<05:18,  4.32it/s]

Iteration 49000
Meta Train Error 2.3413216173648834
Meta Train Error 1.0566666945815086


 83%|████████▎ | 6249/7500 [29:30<05:05,  4.09it/s]

Iteration 50000
Meta Train Error 2.3164773136377335
Meta Train Error 1.0633333548903465


 83%|████████▎ | 6250/7500 [29:51<2:11:46,  6.33s/it]

Valid Result
Meta Valid Error :  1.3468253734111786
Meta Valid Accuracy :  0.44105334328114987  std :  0.18070862718249833


 85%|████████▌ | 6375/7500 [30:22<04:37,  4.06it/s]  

Iteration 51000
Meta Train Error 2.346715360879898
Meta Train Error 1.0666666775941849


 87%|████████▋ | 6500/7500 [30:53<04:06,  4.05it/s]

Iteration 52000
Meta Train Error 2.349168449640274
Meta Train Error 1.0966666787862778


 88%|████████▊ | 6625/7500 [31:24<03:40,  3.97it/s]

Iteration 53000
Meta Train Error 2.4034904688596725
Meta Train Error 1.0100000128149986


 90%|█████████ | 6750/7500 [31:55<03:04,  4.06it/s]

Iteration 54000
Meta Train Error 2.3302767872810364
Meta Train Error 1.0033333599567413


 92%|█████████▏| 6874/7500 [32:26<02:33,  4.08it/s]

Iteration 55000
Meta Train Error 2.3005583733320236
Meta Train Error 1.0933333411812782


 92%|█████████▏| 6875/7500 [32:46<1:04:23,  6.18s/it]

Valid Result
Meta Valid Error :  1.3355601305365563
Meta Valid Accuracy :  0.4485733433663845  std :  0.18662137999675227


 93%|█████████▎| 7000/7500 [33:17<02:03,  4.05it/s]  

Iteration 56000
Meta Train Error 2.3986108601093292
Meta Train Error 1.0333333685994148


 95%|█████████▌| 7125/7500 [33:47<01:32,  4.04it/s]

Iteration 57000
Meta Train Error 2.1632602512836456
Meta Train Error 1.1800000369548798


 97%|█████████▋| 7250/7500 [34:18<01:01,  4.05it/s]

Iteration 58000
Meta Train Error 1.9559813514351845
Meta Train Error 1.2266666889190674


 98%|█████████▊| 7375/7500 [34:49<00:30,  4.04it/s]

Iteration 59000
Meta Train Error 2.196537896990776
Meta Train Error 1.1033333465456963


100%|█████████▉| 7499/7500 [35:20<00:00,  4.04it/s]

Iteration 60000
Meta Train Error 2.162126451730728
Meta Train Error 1.1533333659172058


100%|██████████| 7500/7500 [35:41<00:00,  3.50it/s]

Valid Result
Meta Valid Error :  1.3371518250107766
Meta Valid Accuracy :  0.4474800105243921  std :  0.18822809740505483





In [9]:
meta_train_accuracy

4.613333463668823

In [10]:
meta_test_error = []
meta_test_accuracy = []
for task_batch in test_loader:
    for batch in zip(*task_batch):
        learner = maml.clone()
        evaluation_error, evaluation_accuracy = fast_adapt(batch,
                                                           learner,
                                                           loss,
                                                           adaptation_steps,
                                                           device)
        meta_test_error.append(evaluation_error.item())
        meta_test_accuracy.append(evaluation_accuracy.item())
        torch.cuda.empty_cache()
print('Valid Result')
print('Meta Valid Error : ', round(np.mean(meta_test_error),4),\
        " std : ", round(np.std(meta_test_error), 4))
print('Meta Valid Accuracy : ', round(np.mean(meta_test_accuracy), 4),\
        " std : ", round(np.std(meta_test_accuracy), 4))

Valid Result
Meta Valid Error :  1.3157  std :  0.1907
Meta Valid Accuracy :  0.4559  std :  0.0972
