In [1]:
import torch
import math
import os
import time
import json
import logging
import warnings

from collections import OrderedDict
from torchmeta.utils.data import BatchMetaDataLoader

from maml.datasets import get_benchmark_by_name
from maml.metalearners import ModelAgnosticMetaLearning
from maml.metalearners import FlatModelAgnosticMetaLearning
from maml.metalearners import SamModelAgnosticMetaLearning

from sam import SAM
from sam_folder.model.smooth_cross_entropy import smooth_crossentropy
from sam_folder.utility.bypass_bn import enable_running_stats, disable_running_stats
from sam_folder.model.wide_res_net import WideResNet
from sam_folder.utility.step_lr import StepLR

# Pick the compute device once for the whole notebook: CUDA when PyTorch
# reports a usable GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Silence every warning category for the rest of the session.
# NOTE(review): this also hides deprecation warnings from torch/torchmeta.
warnings.filterwarnings('ignore')

In [2]:
# --- Experiment configuration -------------------------------------------
# Few-shot task layout and optimisation hyperparameters. Everything a reader
# might want to tune lives here so the rest of the cell only references names.
ways = 5                  # classes per task
shots = 1                 # examples per class
meta_lr = 0.003           # learning rate of the outer (meta) optimizer
fast_lr = 0.5             # passed to the meta-learner as its adaptation step size
meta_batch_size = 32      # tasks per meta-training batch
meta_val_batch_size = 16  # tasks per meta-validation batch

# Build the Omniglot benchmark (datasets, model and loss function are read
# from the returned object below).
# NOTE(review): `shots` is passed twice; presumably the train and test shot
# counts per class -- confirm against get_benchmark_by_name's signature.
benchmark = get_benchmark_by_name('omniglot',
                                  './data',
                                  ways,
                                  shots,
                                  shots,
                                  hidden_size=64)

# The training loader now honours `meta_batch_size`; previously that setting
# was declared but ignored in favour of a hard-coded 32.
meta_train_dataloader = BatchMetaDataLoader(benchmark.meta_train_dataset,
                                            batch_size=meta_batch_size,
                                            shuffle=True,
                                            num_workers=2,
                                            pin_memory=True)
meta_val_dataloader = BatchMetaDataLoader(benchmark.meta_val_dataset,
                                          batch_size=meta_val_batch_size,
                                          shuffle=True,
                                          num_workers=2,
                                          pin_memory=True)

# Outer-loop optimizer over the benchmark model's parameters.
meta_optimizer = torch.optim.Adam(benchmark.model.parameters(), lr=meta_lr)
# Alternative kept for reference: wrap Adam in the SAM optimizer instead.
#base_optimizer = torch.optim.Adam
#meta_optimizer = SAM(benchmark.model.parameters(), base_optimizer, rho=0.05,
#                        adaptive=False, lr=meta_lr)

# Meta-learner configured for second-order updates (first_order=False) with a
# single adaptation step of size `fast_lr`.
metalearner = FlatModelAgnosticMetaLearning(benchmark.model,
                                            meta_optimizer,
                                            first_order=False,
                                            num_adaptation_steps=1,
                                            step_size=fast_lr,
                                            loss_function=benchmark.loss_function,
                                            device=device)

In [3]:
# Best validation metric seen so far: accuracy (higher is better) when the
# evaluation reports 'accuracies_after', otherwise mean outer loss (lower is
# better). None until the first epoch has been evaluated.
best_value = None
num_epochs = 200
num_batches = 500  # cap on batches per train / evaluate call

# Training loop
# Builds a template such as 'Epoch {0: <3d}' so the epoch number is
# left-aligned and padded to the digit count of `num_epochs`.
epoch_desc = 'Epoch {{0: <{0}d}}'.format(1 + int(math.log10(num_epochs)))
for epoch in range(num_epochs):
    #if epoch%10==0:
    #    metalearner.calculate_flatness(meta_val_dataloader,
    #                                    max_batches=1)

    metalearner.train(meta_train_dataloader,
                      max_batches=num_batches,
                      verbose=True,
                      desc='Training',
                      leave=False)
    results = metalearner.evaluate(meta_val_dataloader,
                                   max_batches=num_batches,
                                   verbose=True,
                                   desc=epoch_desc.format(epoch + 1))

    # Save best model
    # Reset the flag every epoch. Without this, an epoch whose accuracy did
    # not improve would inherit save_model=True from an earlier epoch and
    # overwrite the best checkpoint with a worse model.
    save_model = False
    if 'accuracies_after' in results:
        if (best_value is None) or (best_value < results['accuracies_after']):
            best_value = results['accuracies_after']
            save_model = True
    elif (best_value is None) or (best_value > results['mean_outer_loss']):
        best_value = results['mean_outer_loss']
        save_model = True

    if save_model:
        current_directory = os.getcwd()
        final_directory = os.path.join(current_directory, r'models')
        # exist_ok avoids the separate isdir check and the race it implies.
        os.makedirs(final_directory, exist_ok=True)
        torch.save(benchmark.model.state_dict(), './models/flat_maml.pth')

# Close each dataset that exposes close(); check them independently so one
# dataset lacking the method neither skips nor breaks the other.
for dataset in (benchmark.meta_train_dataset, benchmark.meta_val_dataset):
    if hasattr(dataset, 'close'):
        dataset.close()

Epoch 1  : 100%|██████████| 500/500 [00:32<00:00, 15.34it/s, accuracy=0.9303, loss=0.2116]
Epoch 2  : 100%|██████████| 500/500 [00:33<00:00, 15.14it/s, accuracy=0.9535, loss=0.1427]
Epoch 3  : 100%|██████████| 500/500 [00:31<00:00, 15.63it/s, accuracy=0.9599, loss=0.1231]
Epoch 4  : 100%|██████████| 500/500 [00:33<00:00, 15.00it/s, accuracy=0.9630, loss=0.1142]
Epoch 5  : 100%|██████████| 500/500 [00:33<00:00, 14.91it/s, accuracy=0.9678, loss=0.0981]
Epoch 6  : 100%|██████████| 500/500 [00:34<00:00, 14.58it/s, accuracy=0.9717, loss=0.0859]
Epoch 7  : 100%|██████████| 500/500 [00:33<00:00, 14.81it/s, accuracy=0.9691, loss=0.0995]
Epoch 8  : 100%|██████████| 500/500 [00:32<00:00, 15.23it/s, accuracy=0.9653, loss=0.1077]
Epoch 9  : 100%|██████████| 500/500 [00:33<00:00, 14.95it/s, accuracy=0.9710, loss=0.0909]
Epoch 10 : 100%|██████████| 500/500 [00:32<00:00, 15.31it/s, accuracy=0.9745, loss=0.0777]
Epoch 11 : 100%|██████████| 500/500 [00:31<00:00, 15.83it/s, accuracy=0.9750, loss=0.0781]

KeyboardInterrupt: 