In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn

from few_shot.datasets import OmniglotDataset, MiniImageNet
from few_shot.core import NShotTaskSampler, create_nshot_task_label, EvaluateFewShot
from few_shot.maml import meta_gradient_step
from few_shot.models import FewShotClassifier
from few_shot.train import fit
from few_shot.callbacks import *
from few_shot.utils import setup_dirs
from config import PATH

In [2]:
class Args:
    """Container of experiment settings, read elsewhere in the notebook as ``args.<name>``."""

    # Which dataset to sample tasks from: 'miniImageNet' or 'omniglot'.
    dataset = 'miniImageNet'
    # Index of the CUDA device to run on.
    gpu = 1

    # Task dimensions (passed to the task sampler as n / k / q).
    n = 1
    k = 5
    q = 5

    # Inner-loop settings.
    inner_train_steps = 2
    inner_val_steps = 1
    inner_lr = 0.01

    # Meta-level settings.
    meta_lr = 0.001
    meta_batch_size = 1
    order = 2

    # Run length.
    epochs = 1
    epoch_len = 5
    eval_batches = 1


args = Args()
    
# This notebook requires CUDA: fail fast with a readable message instead of
# carrying a CPU fallback that could never be reached after the assertion.
assert torch.cuda.is_available(), 'CUDA is required to run this notebook'
# Catch a bad GPU index here rather than as an opaque CUDA error later.
assert 0 <= args.gpu < torch.cuda.device_count(), \
    'args.gpu={} is not a valid CUDA device index ({} device(s) visible)'.format(
        args.gpu, torch.cuda.device_count())

# Make the chosen GPU the current device and keep a handle to it for `.to(device)`.
torch.cuda.set_device(args.gpu)
device = torch.device('cuda:{}'.format(args.gpu))
# Let cuDNN benchmark and pick convolution algorithms.
torch.backends.cudnn.benchmark = True

In [3]:
# Per-dataset settings, keyed by the value of `args.dataset`:
# (dataset_class, fc_layer_size, num_input_channels)
_DATASET_SETTINGS = {
    'omniglot': (OmniglotDataset, 64, 1),
    'miniImageNet': (MiniImageNet, 1600, 3),
}

if args.dataset not in _DATASET_SETTINGS:
    raise ValueError('Unsupported dataset')

dataset_class, fc_layer_size, num_input_channels = _DATASET_SETTINGS[args.dataset]

In [4]:
###################
# Create datasets #
###################

# Keyword arguments shared by the background and evaluation task samplers.
task_sampler_kwargs = dict(n=args.n, k=args.k, q=args.q, num_tasks=args.meta_batch_size)

# Background split: sampler built with `args.epoch_len`.
background = dataset_class('background')
background_sampler = NShotTaskSampler(background, args.epoch_len, **task_sampler_kwargs)
background_taskloader = DataLoader(background, batch_sampler=background_sampler, num_workers=8)

# Evaluation split: sampler built with `args.eval_batches`.
evaluation = dataset_class('evaluation')
evaluation_sampler = NShotTaskSampler(evaluation, args.eval_batches, **task_sampler_kwargs)
evaluation_taskloader = DataLoader(evaluation, batch_sampler=evaluation_sampler, num_workers=8)

Indexing background...


48000it [00:00, 564358.96it/s]


Indexing evaluation...


12000it [00:00, 594803.15it/s]


In [5]:
# Two separately constructed classifiers with the same configuration, both in
# double precision on `device`: `*_ori` and `*_new`.
meta_model_ori = FewShotClassifier(num_input_channels, args.k, fc_layer_size).to(device, dtype=torch.double)
meta_model_new = FewShotClassifier(num_input_channels, args.k, fc_layer_size).to(device, dtype=torch.double)

# One Adam optimiser over all parameters of each model.
# (The `_ori` optimiser must be built from `meta_model_ori`, the name bound above;
# no `meta_model_origin` exists in this notebook.)
optimiser_ori = torch.optim.Adam(meta_model_ori.parameters(), lr=args.meta_lr)
optimiser_new = torch.optim.Adam(meta_model_new.parameters(), lr=args.meta_lr)
# Extra optimisers over the model's `conv_param` / `other_param` attributes.
# NOTE(review): presumably these are parameter subsets exposed by this project's
# FewShotClassifier - confirm against few_shot.models.
conv_optimiser_new = torch.optim.Adam(meta_model_new.conv_param, lr=args.meta_lr)
other_optimiser_new = torch.optim.Adam(meta_model_new.other_param, lr=args.meta_lr)

# One cross-entropy loss per model.
loss_fn_ori = nn.CrossEntropyLoss().to(device)
loss_fn_new = nn.CrossEntropyLoss().to(device)

In [6]:
def prepare_meta_batch(n, k, q, meta_batch_size):
    """Return a callable that turns a raw loader batch into a meta-batch.

    Args:
        n: Support samples per class in each task.
        k: Number of classes in each task.
        q: Query samples per class in each task.
        meta_batch_size: Number of tasks contained in one batch.

    Returns:
        A function ``prepare_meta_batch_(batch)`` that takes an ``(x, y)`` pair
        and returns ``(x, y)`` where ``x`` has been reshaped to
        ``(meta_batch_size, n*k + q*k, num_input_channels, H, W)``, cast to
        double and moved to ``device``, and ``y`` is the label tensor from
        ``create_nshot_task_label(k, q)`` repeated ``meta_batch_size`` times on
        the same device.

    Relies on the notebook-level names ``num_input_channels``, ``device`` and
    ``create_nshot_task_label``.
    """
    def prepare_meta_batch_(batch):
        # The labels delivered by the loader are not used; they are rebuilt below.
        x, _ = batch

        # Reshape to `meta_batch_size` number of tasks. Each task contains
        # n*k support samples to train the fast model on and q*k query samples to
        # evaluate the fast model on and generate meta-gradients
        x = x.reshape(meta_batch_size, n*k + q*k, num_input_channels, x.shape[-2], x.shape[-1])
        # Move to device
        x = x.double().to(device)
        # Create label on the same device as `x` (not just the current CUDA device)
        y = create_nshot_task_label(k, q).to(device).repeat(meta_batch_size)
        return x, y

    return prepare_meta_batch_

In [7]:
# Settings for the loop below. Each name is bound to the plain object: a trailing
# comma after the right-hand side would silently wrap it in a 1-tuple and force
# `[0]` indexing at every use.
dataloader = background_taskloader
prepare_batch = prepare_meta_batch(args.n, args.k, args.q, args.meta_batch_size)
fit_function = meta_gradient_step
fit_function_kwargs = {
    'n_shot': args.n, 'k_way': args.k, 'q_queries': args.q,
    'train': True,
    'order': args.order, 'device': device, 'inner_train_steps': args.inner_train_steps,
    'inner_lr': args.inner_lr,
}

for epoch in range(1, args.epochs + 1):
    for batch_index, batch in enumerate(dataloader):
        x, y = prepare_batch(batch)

        # Disabled call on the `_ori` model, kept for reference:
        # loss_ori, y_pred_ori = fit_function(meta_model_ori, optimiser_ori, loss_fn_ori,
        #                                     x, y, origin=True, **fit_function_kwargs)

        # Run the same batch through `fit_function` twice on `meta_model_new`,
        # first with origin=True and then with origin=False, printing the first
        # element of the conv1[0] bias gradient after each call.
        for origin in (True, False):
            loss_new, y_pred_new = fit_function(
                meta_model_new, optimiser_new, loss_fn_new, x, y, origin=origin,
                other_optim=[conv_optimiser_new, other_optimiser_new],
                p_task=[1, 1, 1, 1], p_meta=[1, 1, 1, 1], **fit_function_kwargs)
            print(meta_model_new.conv1[0].bias.grad[0])

torch.Size([1, 30, 3, 84, 84])
torch.Size([25])
original:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
new:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
torch.Size([1, 30, 3, 84, 84])
torch.Size([25])
original:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
new:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
torch.Size([1, 30, 3, 84, 84])
torch.Size([25])
original:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
new:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0>)
torch.Size([1, 30, 3, 84, 84])
torch.Size([25])
original:  tensor([ 0.1233,  0.1583, -0.1633], device='cuda:1', dtype=torch.float64,
       grad_fn=<SliceBackward0