In [1]:
import json
import logging
import math
import os
import random
import time
import warnings
from collections import OrderedDict

import torch
from torchmeta.utils.data import BatchMetaDataLoader

from maml.datasets import get_benchmark_by_name
from maml.metalearners import ModelAgnosticMetaLearning
from maml.metalearners import FlatModelAgnosticMetaLearning
from maml.metalearners import SamModelAgnosticMetaLearning

from sam import SAM
from sam_folder.model.smooth_cross_entropy import smooth_crossentropy
from sam_folder.utility.bypass_bn import enable_running_stats, disable_running_stats
from sam_folder.model.wide_res_net import WideResNet
from sam_folder.utility.step_lr import StepLR

device = torch.device('cuda' if  torch.cuda.is_available() else 'cpu')
warnings.filterwarnings('ignore')

In [24]:
# --- Evaluation configuration -------------------------------------------
ways = 5
shots = 1
meta_lr = 0.003          # not used in this evaluation notebook; kept for parity
fast_lr = 0.5            # inner-loop (adaptation) step size
meta_batch_size = 32     # number of tasks per batch drawn from the test loader
hidden_size = 64
seed = 0

# Checkpoint to evaluate. Set MAML_CHECKPOINT to point at another model
# instead of editing an absolute, machine-specific path in the notebook.
model_path = os.environ.get(
    'MAML_CHECKPOINT',
    '/home/kristi/Desktop/maml_res/models/flat_omni.pth',
)

# The test loader shuffles tasks, so without a seed every evaluation sees a
# different task set and results are not comparable across checkpoints.
# NOTE(review): assumes torchmeta's task sampler draws from Python's `random`
# and/or torch's RNG — confirm against the installed torchmeta version.
random.seed(seed)
torch.manual_seed(seed)

# --- Data and model --------------------------------------------------------
# NOTE(review): the two `shots` arguments are presumably the support and
# query shot counts — confirm against get_benchmark_by_name's signature.
benchmark = get_benchmark_by_name('omniglot',
                                  './data',
                                  ways,
                                  shots,
                                  shots,
                                  hidden_size=hidden_size)

meta_test_dataloader = BatchMetaDataLoader(benchmark.meta_test_dataset,
                                           batch_size=meta_batch_size,
                                           shuffle=True,
                                           num_workers=2,
                                           pin_memory=True)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
benchmark.model.load_state_dict(torch.load(model_path, map_location=device))

metalearner = ModelAgnosticMetaLearning(benchmark.model,
                                        first_order=False,
                                        num_adaptation_steps=1,
                                        step_size=fast_lr,
                                        loss_function=benchmark.loss_function,
                                        device=device)

In [21]:
# Meta-test evaluation over 500 task batches from the shuffled test loader.
# Because the loader shuffles, each run samples a different set of tasks, so
# accuracy varies between runs unless the RNGs are re-seeded beforehand.
results = metalearner.evaluate(meta_test_dataloader,
                               max_batches=500,
                               verbose=True,
                               desc='Test')
# Display the aggregated metrics so they are kept in the saved notebook,
# not only in the transient progress-bar postfix.
results

Test: 100%|██████████| 500/500 [01:16<00:00,  6.50it/s, accuracy=0.9843, loss=0.0451]


In [23]:
# Repeat of the meta-test evaluation (500 task batches). With a shuffled
# loader this run samples different tasks than the previous one, so its
# accuracy is a separate sample, not a re-check of the same number.
results = metalearner.evaluate(meta_test_dataloader,
                               max_batches=500,
                               verbose=True,
                               desc='Test')
# Display the aggregated metrics so they are kept in the saved notebook.
results

Test: 100%|██████████| 500/500 [01:14<00:00,  6.74it/s, accuracy=0.9878, loss=0.0356]


In [25]:
# Meta-test evaluation (500 task batches) of the model configured by the
# setup cell. Results depend on which checkpoint that cell loaded and on the
# tasks drawn from the shuffled test loader.
results = metalearner.evaluate(meta_test_dataloader,
                               max_batches=500,
                               verbose=True,
                               desc='Test')
# Display the aggregated metrics so they are kept in the saved notebook.
results

Test: 100%|██████████| 500/500 [01:22<00:00,  6.07it/s, accuracy=0.9762, loss=0.0701]
