In [1]:
import torch
import math
import os
import time
import json
import logging
import warnings

from collections import OrderedDict
from torchmeta.utils.data import BatchMetaDataLoader

from maml.datasets import get_benchmark_by_name
from maml.metalearners import ModelAgnosticMetaLearning
from maml.metalearners import FlatModelAgnosticMetaLearning
from maml.metalearners import SamModelAgnosticMetaLearning

from sam import SAM
from sam_folder.model.smooth_cross_entropy import smooth_crossentropy
from sam_folder.utility.bypass_bn import enable_running_stats, disable_running_stats
from sam_folder.model.wide_res_net import WideResNet
from sam_folder.utility.step_lr import StepLR

device = torch.device('cuda' if  torch.cuda.is_available() else 'cpu')
warnings.filterwarnings('ignore')

In [6]:
ways=5
shots=1
meta_lr=0.003
fast_lr=0.5
meta_batch_size=32

benchmark = get_benchmark_by_name('omniglot',
                                      './data',
                                      ways,
                                      shots,
                                      shots,
                                      hidden_size=64)

meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                            batch_size=32,
                                            shuffle=True,
                                            num_workers=2,
                                            pin_memory=True)
meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                            batch_size=16,
                                            shuffle=True,
                                            num_workers=2,
                                            pin_memory=True)

meta_optimizer = torch.optim.Adam(benchmark.model.parameters(), lr=meta_lr)
#base_optimizer = torch.optim.Adam
#meta_optimizer = SAM(benchmark.model.parameters(), base_optimizer, rho=0.05,
#                        adaptive=False, lr=meta_lr)
metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                        meta_optimizer,
                                        first_order=False,
                                        num_adaptation_steps=1,
                                        step_size=fast_lr,
                                        loss_function=benchmark.loss_function,
                                        device=device)

In [7]:
best_value = None
num_epochs = 200
num_batches = 500
# Training loop
epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(num_epochs)))
for epoch in range(num_epochs):
    #if epoch%10==0:
    #    metalearner.calculate_flatness(meta_val_dataloader,
    #                                    max_batches=1)
                                        
    metalearner.train(meta_train_dataloader,
                        max_batches=num_batches,
                        verbose=True,
                        desc='Training',
                        leave=False)
    results = metalearner.evaluate(meta_val_dataloader,
                                    max_batches=num_batches,
                                    verbose=True,
                                    desc=epoch_desc.format(epoch + 1))

    # Save best model
    if 'accuracies_after' in results:
        if (best_value is None) or (best_value < results['accuracies_after']):
            best_value = results['accuracies_after']
            save_model = True
    elif (best_value is None) or (best_value > results['mean_outer_loss']):
        best_value = results['mean_outer_loss']
        save_model = True
    else:
        save_model = False

    if save_model:
        current_directory = os.getcwd()
        final_directory = os.path.join(current_directory, r'models')
        if not os.path.isdir(final_directory):
                    os.mkdir(final_directory)
        torch.save(benchmark.model.state_dict(), './models/sam_omni.pth')

if hasattr(benchmark.meta_train_dataset, 'close'):
    benchmark.meta_train_dataset.close()
    benchmark.meta_val_dataset.close()

Epoch 1  : 100%|██████████| 500/500 [00:33<00:00, 14.83it/s, accuracy=0.9334, loss=0.2076]
Epoch 2  : 100%|██████████| 500/500 [00:33<00:00, 14.98it/s, accuracy=0.9479, loss=0.1586]
Epoch 3  : 100%|██████████| 500/500 [00:33<00:00, 15.13it/s, accuracy=0.9636, loss=0.1125]
Epoch 4  : 100%|██████████| 500/500 [00:32<00:00, 15.21it/s, accuracy=0.9667, loss=0.0999]
Epoch 5  : 100%|██████████| 500/500 [00:32<00:00, 15.34it/s, accuracy=0.9714, loss=0.0865]
Epoch 6  : 100%|██████████| 500/500 [00:33<00:00, 15.07it/s, accuracy=0.9769, loss=0.0722]
Epoch 7  : 100%|██████████| 500/500 [00:33<00:00, 15.11it/s, accuracy=0.9720, loss=0.0853]
Epoch 8  : 100%|██████████| 500/500 [00:33<00:00, 15.10it/s, accuracy=0.9764, loss=0.0701]
Epoch 9  : 100%|██████████| 500/500 [00:31<00:00, 15.83it/s, accuracy=0.9761, loss=0.0730]
Epoch 10 : 100%|██████████| 500/500 [00:34<00:00, 14.68it/s, accuracy=0.9811, loss=0.0568]
Epoch 11 : 100%|██████████| 500/500 [00:34<00:00, 14.69it/s, accuracy=0.9781, loss=0.0652]

KeyboardInterrupt: 

In [None]:
meta_test_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                            batch_size=32,
                                            shuffle=True,
                                            num_workers=2,
                                            pin_memory=True)

results2 = metalearner.evaluate(meta_test_dataloader,
                                   max_batches=500,
                                   verbose=True,
                                   desc='Test')

Test:   2%|▏         | 12/500 [00:02<01:42,  4.76it/s, accuracy=0.2547, loss=1.6896]


KeyboardInterrupt: 

In [None]:
Epoch 1  : 100%|██████████| 500/500 [00:30<00:00, 16.33it/s, accuracy=0.8871, loss=0.3333]
Epoch 2  : 100%|██████████| 500/500 [00:29<00:00, 16.71it/s, accuracy=0.9121, loss=0.2780]
Epoch 3  : 100%|██████████| 500/500 [00:30<00:00, 16.53it/s, accuracy=0.9222, loss=0.2342]
Epoch 4  : 100%|██████████| 500/500 [00:29<00:00, 16.70it/s, accuracy=0.9322, loss=0.2033]
Epoch 5  : 100%|██████████| 500/500 [00:28<00:00, 17.46it/s, accuracy=0.9334, loss=0.2084]
Epoch 6  : 100%|██████████| 500/500 [00:30<00:00, 16.61it/s, accuracy=0.9327, loss=0.2093]
Epoch 7  : 100%|██████████| 500/500 [00:30<00:00, 16.63it/s, accuracy=0.9342, loss=0.2058]
Epoch 8  : 100%|██████████| 500/500 [00:28<00:00, 17.52it/s, accuracy=0.9347, loss=0.2006]
Epoch 9  : 100%|██████████| 500/500 [00:30<00:00, 16.27it/s, accuracy=0.9432, loss=0.1772]
Epoch 10 : 100%|██████████| 500/500 [00:30<00:00, 16.33it/s, accuracy=0.9344, loss=0.2031]
Epoch 11 : 100%|██████████| 500/500 [00:30<00:00, 16.46it/s, accuracy=0.9462, loss=0.1665]
Epoch 12 : 100%|██████████| 500/500 [00:30<00:00, 16.67it/s, accuracy=0.9470, loss=0.1661]
Epoch 13 : 100%|██████████| 500/500 [00:29<00:00, 17.05it/s, accuracy=0.9538, loss=0.1431]
Epoch 14 : 100%|██████████| 500/500 [00:29<00:00, 16.95it/s, accuracy=0.9480, loss=0.1575]
Epoch 15 : 100%|██████████| 500/500 [00:30<00:00, 16.57it/s, accuracy=0.9533, loss=0.1425]
Epoch 16 : 100%|██████████| 500/500 [00:30<00:00, 16.63it/s, accuracy=0.9582, loss=0.1255]
Epoch 17 : 100%|██████████| 500/500 [00:30<00:00, 16.66it/s, accuracy=0.9610, loss=0.1153]
Epoch 18 : 100%|██████████| 500/500 [00:30<00:00, 16.65it/s, accuracy=0.9668, loss=0.0998]
Epoch 19 : 100%|██████████| 500/500 [00:30<00:00, 16.36it/s, accuracy=0.9644, loss=0.1030]
Epoch 20 : 100%|██████████| 500/500 [00:28<00:00, 17.61it/s, accuracy=0.9736, loss=0.0796]
Epoch 21 : 100%|██████████| 500/500 [00:28<00:00, 17.68it/s, accuracy=0.9713, loss=0.0826]
Epoch 22 : 100%|██████████| 500/500 [00:30<00:00, 16.46it/s, accuracy=0.9745, loss=0.0738]
Epoch 23 : 100%|██████████| 500/500 [00:31<00:00, 16.10it/s, accuracy=0.9740, loss=0.0756]
Epoch 24 : 100%|██████████| 500/500 [00:30<00:00, 16.21it/s, accuracy=0.9771, loss=0.0680]
Epoch 25 : 100%|██████████| 500/500 [00:30<00:00, 16.55it/s, accuracy=0.9732, loss=0.0775]
Epoch 26 : 100%|██████████| 500/500 [00:29<00:00, 16.92it/s, accuracy=0.9730, loss=0.0796]
Epoch 27 : 100%|██████████| 500/500 [00:30<00:00, 16.63it/s, accuracy=0.9762, loss=0.0722]
Epoch 28 : 100%|██████████| 500/500 [00:30<00:00, 16.27it/s, accuracy=0.9777, loss=0.0677]
Epoch 29 : 100%|██████████| 500/500 [00:30<00:00, 16.39it/s, accuracy=0.9778, loss=0.0640]
Epoch 30 : 100%|██████████| 500/500 [00:30<00:00, 16.25it/s, accuracy=0.9769, loss=0.0686]
Epoch 31 : 100%|██████████| 500/500 [00:29<00:00, 16.98it/s, accuracy=0.9734, loss=0.0765]
Epoch 32 : 100%|██████████| 500/500 [00:29<00:00, 16.74it/s, accuracy=0.9749, loss=0.0719]
Epoch 33 : 100%|██████████| 500/500 [00:30<00:00, 16.42it/s, accuracy=0.9782, loss=0.0640]
Epoch 34 : 100%|██████████| 500/500 [00:30<00:00, 16.34it/s, accuracy=0.9777, loss=0.0636]
Epoch 35 : 100%|██████████| 500/500 [00:29<00:00, 16.73it/s, accuracy=0.9764, loss=0.0691]
Epoch 36 : 100%|██████████| 500/500 [00:30<00:00, 16.62it/s, accuracy=0.9773, loss=0.0684]
Epoch 37 : 100%|██████████| 500/500 [00:30<00:00, 16.54it/s, accuracy=0.9778, loss=0.0630]
Epoch 38 : 100%|██████████| 500/500 [00:29<00:00, 16.67it/s, accuracy=0.9791, loss=0.0603]
Epoch 39 : 100%|██████████| 500/500 [00:30<00:00, 16.24it/s, accuracy=0.9779, loss=0.0639]
Epoch 40 : 100%|██████████| 500/500 [00:30<00:00, 16.42it/s, accuracy=0.9782, loss=0.0634]
Epoch 41 : 100%|██████████| 500/500 [00:29<00:00, 16.81it/s, accuracy=0.9764, loss=0.0695]
Epoch 42 : 100%|██████████| 500/500 [00:30<00:00, 16.27it/s, accuracy=0.9799, loss=0.0594]
Epoch 43 : 100%|██████████| 500/500 [00:29<00:00, 16.74it/s, accuracy=0.9794, loss=0.0603]
Epoch 44 : 100%|██████████| 500/500 [00:30<00:00, 16.28it/s, accuracy=0.9803, loss=0.0577]
Epoch 45 : 100%|██████████| 500/500 [00:30<00:00, 16.19it/s, accuracy=0.9811, loss=0.0575]
Epoch 46 : 100%|██████████| 500/500 [00:30<00:00, 16.50it/s, accuracy=0.9783, loss=0.0644]
Epoch 47 : 100%|██████████| 500/500 [00:30<00:00, 16.19it/s, accuracy=0.9800, loss=0.0586]
Epoch 48 : 100%|██████████| 500/500 [00:30<00:00, 16.46it/s, accuracy=0.9792, loss=0.0609]
Epoch 49 : 100%|██████████| 500/500 [00:30<00:00, 16.52it/s, accuracy=0.9810, loss=0.0567]
                                                                                         