# Experiments on Omniglot: Meta Neural Kernel vs. MAML & iMAML

In [1]:
from meta_cntk_utils import *
from maml_utils import *
from tqdm.notebook import trange
import torch.nn.functional as F
from time import time
from copy import deepcopy
import pickle
import numpy as np

# IPython magics: reload edited modules automatically before executing each
# cell (so changes to the *_utils modules are picked up without restarting the
# kernel), and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Define Hyperparameters

In [2]:
# Test accuracies collected for each method, keyed by method name.
accs = {}
n_characters = 20  # number of characters (i.e., classes) in the training dataset
n_way = 5  # number of classes per task
n_shot = 1  # number of support samples per class
gpu_id = 3  # CUDA device index passed to torch.cuda.set_device
seed = 3
# random_cnn determines whether we use a random or a trained CNN as the
# feature extractor. If False, we first train a CNN on the training dataset
# in a supervised way and then take its hidden layers as the feature extractor.
random_cnn = False
# Last hidden size of the CNN, i.e. the dimension of the label encodings.
# (Observation for a random CNN: emb-dim = 1000 works better than 500 or 2000.)
label_encoding_dim = 250

# If a CNN is pre-trained on the training data, pretrain_epochs sets the
# number of epochs for the pre-training.
pretrain_epochs = 50
# We can resample from the n_way-class groups to obtain more training tasks;
# this is the enlargement ratio passed to augment_train_data.
train_data_enlarge_ratio = 15
maml_epochs = 200
kernel = 'CNTK'  # use the Convolutional NTK as the base kernel
# For simplicity, we only consider cases s.t. n_characters % n_way == 0.
# Raise explicitly rather than `assert`, so the check is not stripped by `python -O`.
if n_characters % n_way != 0:
    raise ValueError(
        f'n_characters ({n_characters}) must be divisible by n_way ({n_way})'
    )
# Each task consists of n_way classes, so the least number of tasks is
# n_characters // n_way.
n_task = n_characters // n_way

batch_norm = True  # batch_norm for the feature extractor
dropout = 0  # dropout for the feature extractor
pretrain_batch_size = 16
weight_decay = 0


# number of channels for the MAML CNN
n_channel_maml = 64

## Load Dataset and Preprocess

In [5]:
# Select the GPU. This cell (and the later cells that use `device`) assumes
# CUDA is available.
torch.cuda.set_device(gpu_id)
device = torch.device('cuda')
t0 = time()
# Load dataset
dataset = load_dataset(n_task,True,seed)

# train_set = get_train_data(dataset)

# Split for supervised pre-training of the feature extractor; presumably
# n_test_per_class=1 holds out one sample per class as test_set (confirm
# against get_train_data).
train_set,test_set = get_train_data(dataset,n_test_per_class=1)

# Construct a randomly initialized CNN as the feature extractor.
net = build_CNN(train_set['n_class'], device, n_channel=label_encoding_dim,batch_norm=batch_norm,dropout=dropout)

if not random_cnn and pretrain_epochs > 0:
    # Train a CNN on the training data by supervised learning to obtain a
    # better feature extractor than a random CNN. This helps when the training
    # data are relatively large; when they are small, the supervised CNN can
    # overfit and become a worse feature extractor than a random CNN.
    print('Pre-training the feature extractor')
    net,test_accs,test_losses = pretrain(net,train_set,test_set,device,seed=seed,epochs=pretrain_epochs,
                                         weight_decay=weight_decay,batch_size=pretrain_batch_size)

# Encode the labels with the (possibly pre-trained) feature extractor `net`;
# the return value is unused, so this presumably updates `dataset` in place.
encode_labels(dataset,net,device)

# Given n_way*n_task classes of samples, we can use resampling to obtain many
# tasks that each consist of n_way distinct classes of samples. The following
# is the resampling procedure. orig_dataset is copied first so the MAML/iMAML
# cells below can work on the un-resampled tasks; augment_train_data's return
# value is unused, so it presumably modifies `dataset` in place.
orig_dataset = deepcopy(dataset)
if train_data_enlarge_ratio > 1:
    augment_train_data(dataset,enlarge_ratio=train_data_enlarge_ratio,n_way=n_way,n_shot=n_shot,seed=seed)

# get_embeddings_from_PCA(dataset,PCA_method=PCA_method,n_components=n_components)
preprocess_label_embeddings(dataset)
# Presumably loads cached base-kernel matrices of type `kernel` into `dataset`
# (confirm in meta_cntk_utils).
load_precomputed_base_kernels(dataset,kernel=kernel)
print(f'Preprocessing takes {round(time()-t0,1)}s')

## Experiment on Meta Neural Kernels
Note: MetaCNTK = Meta Neural Kernel with the Convolutional NTK (CNTK) as the base kernel.

In [4]:
# Build the Meta Neural Kernel on the preprocessed dataset (with both the
# base-NTK and meta-NTK normalization options enabled), then evaluate it.
metaCNTK = build_MetaCNTK(
    dataset,
    normalize_NTK=True,
    normalize_metaNTK=True,
)
(
    metaCNTK_test_acc,
    metaCNTK_pred,
    metaCNTK_loss,
) = test_MetaCNTK(dataset, metaCNTK)
# Scale the accuracy by 100 so it is reported as a percentage.
metaCNTK_test_acc *= 100
accs['mnk_accs'] = metaCNTK_test_acc
print("MetaCNTK Accuracy:", metaCNTK_test_acc)

## Experiment on MAML
The core code and hyperparameters are adapted from `higher`, a PyTorch package: https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py
### Prepare dataset and build CNN

In [5]:

use_label_encodings = False  # If True, use encoded labels and l2 loss for MAML
l2_loss = use_label_encodings
preprocess_label_embeddings(orig_dataset)
# orig_dataset was copied before augment_train_data, so MAML trains on the
# original (non-resampled) training tasks.
tasks = vars(orig_dataset)

# Pool the query and support samples of every training task and regroup them
# by class, so that each row of x_train holds all samples of a single class.
X = np.concatenate([tasks['X_qry'], tasks['X_spt']], axis=1)
Y = np.concatenate([tasks['Y_qry'], tasks['Y_spt']], axis=1)
Y_emb = np.concatenate([tasks['Y_qry_emb'], tasks['Y_spt_emb']], axis=1)
new_X = []
new_Y_emb = []
for x, y, y_emb in zip(X, Y, Y_emb):
    # Sorting by label makes samples of the same class contiguous; the reshape
    # yields one row of sample indices per class (n_way rows per task).
    idxes = np.argsort(y).reshape(n_way, -1)
    # Integer-array indexing gathers all classes at once: x[idxes] has shape
    # (n_way, samples_per_class, ...), contributing one entry per class.
    new_X.extend(x[idxes])
    new_Y_emb.extend(y_emb[idxes])
x_train = remove_padding(np.array(new_X))
y_train = np.array(new_Y_emb) if use_label_encodings else None

# Test tasks use the label encodings with the l2 loss, and raw labels otherwise.
label_suffix = '_emb' if l2_loss else ''
test_tasks = (
    remove_padding(tasks['test_X_spt']),
    tasks['test_Y_spt' + label_suffix],
    remove_padding(tasks['test_X_qry']),
    tasks['test_Y_qry' + label_suffix],
)

from support.omniglot_loaders_original import OmniglotNShot
n_channel = n_channel_maml
batchsz = 32 if n_channel <= 1024 else 8

# Each row of x_train holds all samples of one class (assumes remove_padding
# only crops the image dimensions). n_shot of them form the support set and
# the rest form the query set. Deriving the task geometry from the
# hyperparameters and the data, instead of hard-coding 5-way / 1-shot /
# 19 queries, keeps the loader consistent with the regrouping above when
# n_way or n_shot change.
samples_per_class = x_train.shape[1]
db = OmniglotNShot(root=None,
    batchsz=batchsz,
    n_way=n_way,
    k_shot=n_shot,
    k_query=samples_per_class - n_shot,
    imgsz=28,
    device=device,
    n_train_tasks=None,
    given_x=True,
    x_train=x_train,
    x_test=None,
    y_train=y_train,
)
n_out = label_encoding_dim if use_label_encodings else n_way
net, meta_opt = build_MAML_model(n_out, device, lr=1e-3, n_channel=n_channel)

DB: train (20, 20, 1, 28, 28)


### Train and Test MAML

In [3]:
log = []
maml_eval_interval = 5  # evaluate on the test tasks every this many epochs
t = trange(maml_epochs, desc='MAML Training')
maml_test_accs = []
for epoch in t:
    train_acc = train_MAML(db, net, device, meta_opt, epoch, log, verbose=False, l2_loss=l2_loss)
    # Also evaluate after the final epoch: when maml_epochs - 1 is not a
    # multiple of the interval (e.g. epoch 199 of 200), the fully trained
    # model would otherwise never be evaluated.
    if epoch % maml_eval_interval == 0 or epoch == maml_epochs - 1:
        test_acc = test_MAML(db, net, device, epoch, log, test_tasks, verbose=False, l2_loss=l2_loss, dataset=dataset)
        maml_test_accs.append(test_acc)
        if use_label_encodings:
            # We did not implement the function to calculate training accuracy in the l2 loss case.
            t.set_postfix(test_acc=test_acc, max_test_acc=np.max(maml_test_accs))
        else:
            t.set_postfix(train_acc=train_acc, test_acc=test_acc, max_test_acc=np.max(maml_test_accs))
# Report the best test accuracy over all evaluations.
accs['maml_accs'] = np.max(maml_test_accs)

## Experiment on implicit MAML (iMAML)
The code is adapted from https://github.com/prolearner/hypertorch/tree/master/hypergrad, along with its default hyperparameters (copied below).
### Define hyperparameters

In [7]:
from imaml_utils import *

# iMAML hyperparameters: defaults taken from hypertorch's iMAML example.
hg_mode = 'CG'  # hypergradient mode passed to train_imaml ('CG' presumably = conjugate gradient)
inner_log_interval = None
inner_log_interval_test = None
ways, shots = n_way, n_shot
# NOTE(review): assumes 20 samples per class, of which n_shot are support samples.
test_shots = 20 - n_shot
batch_size = 16
n_channels = 64
reg_param = 2
# T: inner-loop steps, K: hypergradient iterations (presumably, following
# hypertorch's example -- confirm in imaml_utils).
T = 16
K = 5
T_test = T  # same number of inner steps at test time
inner_lr = 0.1

cuda = torch.cuda.is_available()
if cuda:
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}
# For extra reproducibility on GPU, the settings below can be enabled,
# see https://pytorch.org/docs/master/notes/randomness.html
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

# Seeding happens before the model is built so its random initialization is seeded.
torch.random.manual_seed(seed)
np.random.seed(seed)
meta_model = get_cnn_omniglot(n_channels, ways).to(device)

# Outer (meta) optimizer over meta_model's parameters; the inner optimizer is
# hypertorch's GradientDescent with step size inner_lr.
outer_opt = torch.optim.Adam(params=meta_model.parameters())
inner_opt_class = hg.GradientDescent
inner_opt_kwargs = dict(step_size=inner_lr)

def get_inner_opt(train_loss):
    """Return a fresh inner-loop optimizer for ``train_loss``.

    The optimizer class and its keyword arguments are read from the
    module-level ``inner_opt_class`` and ``inner_opt_kwargs`` at call time.
    """
    factory, options = inner_opt_class, inner_opt_kwargs
    return factory(train_loss, **options)

In [None]:
eval_interval = 5  # evaluate on the test tasks every this many epochs
log_interval = 5  # NOTE(review): unused here; test_imaml is called with log_interval=None
test_accs = []
t = trange(maml_epochs, desc='i-MAML Epoch')
for k in t:
    train_imaml(meta_model, db, reg_param, hg_mode, K, T, outer_opt, inner_log_interval)
    # Also evaluate after the final epoch: when maml_epochs - 1 is not a
    # multiple of eval_interval (e.g. epoch 199 of 200), the fully trained
    # model would otherwise never be evaluated.
    if k % eval_interval == 0 or k == maml_epochs - 1:
        test_losses, task_accs = test_imaml(test_tasks, meta_model, T_test, get_inner_opt, reg_param, log_interval=None)
        # Average the accuracies returned by test_imaml and scale by 100
        # (assumes they are fractions).
        test_acc = np.mean(task_accs) * 100
        test_accs.append(test_acc)
        t.set_postfix(test_acc=test_acc, max_test_acc=np.max(test_accs))
# Report the best test accuracy over all evaluations.
accs['i-maml_accs'] = np.max(test_accs)

## Compare Test Accuracy

In [None]:
# Test accuracy for MNK (MetaCNTK, scaled by 100), and the best test accuracy
# over the periodic evaluations for MAML and iMAML. As the cell's last
# expression, Jupyter displays the dict.
accs