# Original Meta

In [1]:
import torch
from collections import OrderedDict
from torch.optim import Optimizer
from torch.nn import Module
from typing import Dict, List, Callable, Union

from few_shot.core import create_nshot_task_label


def replace_grad(parameter_gradients, parameter_name):
    def replace_grad_(module):
        return parameter_gradients[parameter_name]

    return replace_grad_


def meta_gradient_step_ori(model: Module,
                       optimiser: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
            meta-gradients after applying the update to
        order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the
            query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated
            weights on the query with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation
    """
    data_shape = x.shape[2:]
    create_graph = (True if order == 2 else False) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch in x:
        # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first  dimension we are iterating through the meta batches
        x_task_train = meta_batch[:n_shot * k_way]
        x_task_val = meta_batch[n_shot * k_way:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            y = create_nshot_task_label(k_way, n_shot).to(device)
            logits = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y)
            gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients)
            )

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(k_way, q_queries).to(device)
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y)
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)}
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
            loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')

In [2]:
from torch import nn
import numpy as np
import torch.nn.functional as F
import torch
from typing import Dict

def conv_block(in_channels: int, out_channels: int) -> nn.Module:
    """Returns a Module that performs 3x3 convolution, ReLu activation, 2x2 max pooling.
    # Arguments
        in_channels:
        out_channels:
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
    )


def functional_conv_block(x: torch.Tensor, weights: torch.Tensor, biases: torch.Tensor,
                          bn_weights, bn_biases) -> torch.Tensor:
    """Performs 3x3 convolution, ReLu activation, 2x2 max pooling in a functional fashion.
    # Arguments:
        x: Input Tensor for the conv block
        weights: Weights for the convolutional block
        biases: Biases for the convolutional block
        bn_weights:
        bn_biases:
    """
    x = F.conv2d(x, weights, biases, padding=1)
    x = F.batch_norm(x, running_mean=None, running_var=None, weight=bn_weights, bias=bn_biases, training=True)
    x = F.relu(x)
    x = F.max_pool2d(x, kernel_size=2, stride=2)
    return x

class FewShotClassifierOri(nn.Module):
    def __init__(self, num_input_channels: int, k_way: int, final_layer_size: int = 64):
        """Creates a few shot classifier as used in MAML.
        This network should be identical to the one created by `get_few_shot_encoder` but with a
        classification layer on top.
        # Arguments:
            num_input_channels: Number of color channels the model expects input data to contain. Omniglot = 1,
                miniImageNet = 3
            k_way: Number of classes the model will discriminate between
            final_layer_size: 64 for Omniglot, 1600 for miniImageNet
        """
        super(FewShotClassifierOri, self).__init__()
        self.conv1 = conv_block(num_input_channels, 64)
        self.conv2 = conv_block(64, 64)
        self.conv3 = conv_block(64, 64)
        self.conv4 = conv_block(64, 64)

        self.logits = nn.Linear(final_layer_size, k_way)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = x.view(x.size(0), -1)

        return self.logits(x)

    def functional_forward(self, x, weights):
        """Applies the same forward pass using PyTorch functional operators using a specified set of weights."""

        for block in [1, 2, 3, 4]:
            x = functional_conv_block(x, weights[f'conv{block}.0.weight'], weights[f'conv{block}.0.bias'],
                                      weights.get(f'conv{block}.1.weight'), weights.get(f'conv{block}.1.bias'))

        x = x.view(x.size(0), -1)

        x = F.linear(x, weights['logits.weight'], weights['logits.bias'])

        return x

# Our Evaluation

In [3]:
from torch.utils.data import DataLoader
from torch import nn

from few_shot.datasets import OmniglotDataset, MiniImageNet
from few_shot.core import NShotTaskSampler, create_nshot_task_label, EvaluateFewShot
from few_shot.maml import meta_gradient_step
from few_shot.models import FewShotClassifier
from few_shot.train import fit
from few_shot.callbacks import *
from few_shot.utils import setup_dirs
from config import PATH

In [4]:
class Args:
    n = 1
    k = 5
    q = 5
    inner_train_steps = 2
    inner_val_steps = 1
    inner_lr = 0.01
    meta_lr = 0.01
    meta_batch_size = 1
    order = 2
    epochs = 1
    epoch_len = 2
    eval_batches = 1
    
    dataset = 'miniImageNet'
#     dataset = 'omniglot'
    gpu = 1
    
args = Args()
    
assert torch.cuda.is_available()
torch.cuda.set_device(args.gpu)
device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

In [5]:
if args.dataset == 'omniglot':
    dataset_class = OmniglotDataset
    fc_layer_size = 64
    num_input_channels = 1
elif args.dataset == 'miniImageNet':
    dataset_class = MiniImageNet
    fc_layer_size = 1600
    num_input_channels = 3
else:
    raise(ValueError('Unsupported dataset'))

In [6]:
###################
# Create datasets #
###################

background = dataset_class('background')
background_taskloader = DataLoader(
    background,
    batch_sampler=NShotTaskSampler(background, args.epoch_len, n=args.n, k=args.k, q=args.q,
                                   num_tasks=args.meta_batch_size),
    num_workers=8
)
evaluation = dataset_class('evaluation')
evaluation_taskloader = DataLoader(
    evaluation,
    batch_sampler=NShotTaskSampler(evaluation, args.eval_batches, n=args.n, k=args.k, q=args.q,
                                   num_tasks=args.meta_batch_size),
    num_workers=8
)

Indexing background...


48000it [00:00, 550665.45it/s]


Indexing evaluation...


12000it [00:00, 579423.80it/s]


In [7]:
meta_model_ori = FewShotClassifierOri(num_input_channels, args.k, fc_layer_size).to(device, dtype=torch.double)
meta_model_new = FewShotClassifierOri(num_input_channels, args.k, fc_layer_size).to(device, dtype=torch.double)
# meta_model_new = FewShotClassifier(num_input_channels, args.k, fc_layer_size).to(device, dtype=torch.double)

# optimiser_ori = torch.optim.Adam(meta_model_ori.parameters(), lr=args.meta_lr)
# optimiser_new = torch.optim.Adam(meta_model_new.parameters(), lr=args.meta_lr)
# conv_optimiser_new = torch.optim.Adam(meta_model_new.conv_param, lr=args.meta_lr)
# other_optimiser_new = torch.optim.Adam(meta_model_new.other_param, lr=args.meta_lr)

optimiser_ori = torch.optim.SGD(meta_model_ori.parameters(), lr=args.meta_lr)
optimiser_new = torch.optim.SGD(meta_model_new.parameters(), lr=args.meta_lr)
# conv_optimiser_new = torch.optim.SGD(meta_model_new.conv_param, lr=args.meta_lr)
# other_optimiser_new = torch.optim.SGD(meta_model_new.other_param, lr=args.meta_lr)


loss_fn_ori = nn.CrossEntropyLoss().to(device)
loss_fn_new = nn.CrossEntropyLoss().to(device)

In [8]:
# make network parameters consistent
for (n, p) in meta_model_ori.named_parameters():
    print('n: ', n)
    for i in range(len(n)):
        if n[i] == '.' and n[i + 1].isdigit():
            module1 = getattr(meta_model_ori, n[:i])[int(n[i+1])]
            module2 = getattr(meta_model_new, n[:i])[int(n[i+1])]
            getattr(module2, n[i+3:]).data.copy_(getattr(module1, n[i+3:]))
            break
    if 'logits' in n:
        module1 = getattr(meta_model_ori, 'logits')
        module2 = getattr(meta_model_new, 'logits')
        getattr(module2, n[7:]).data.copy_(getattr(module1, n[7:]))

n:  conv1.0.weight
n:  conv1.0.bias
n:  conv1.1.weight
n:  conv1.1.bias
n:  conv2.0.weight
n:  conv2.0.bias
n:  conv2.1.weight
n:  conv2.1.bias
n:  conv3.0.weight
n:  conv3.0.bias
n:  conv3.1.weight
n:  conv3.1.bias
n:  conv4.0.weight
n:  conv4.0.bias
n:  conv4.1.weight
n:  conv4.1.bias
n:  logits.weight
n:  logits.bias


In [9]:
# print('origin: ', OrderedDict(meta_model_ori.named_parameters())['logits.weight'][0,0])
# print('new: ', meta_model_new.other_param[8][0,0])
# print('origin: ', OrderedDict(meta_model_ori.named_parameters())['logits.bias'][0])
# print('new: ', meta_model_new.other_param[9][0])

In [10]:
def prepare_meta_batch(n, k, q, meta_batch_size):
    def prepare_meta_batch_(batch):
        x, y = batch

        # Reshape to `meta_batch_size` number of tasks. Each task contains
        # n*k support samples to train the fast model on and q*k query samples to
        # evaluate the fast model on and generate meta-gradients
        x = x.reshape(meta_batch_size, n*k + q*k, num_input_channels, x.shape[-2], x.shape[-1])
        # Move to device
        x = x.double().to(device)
        # Create label
        y = create_nshot_task_label(k, q).cuda().repeat(meta_batch_size)
        return x, y

    return prepare_meta_batch_

In [11]:
dataloader=background_taskloader,
prepare_batch=prepare_meta_batch(args.n, args.k, args.q, args.meta_batch_size)
fit_function=meta_gradient_step,
fit_function_kwargs={'n_shot': args.n, 'k_way': args.k, 'q_queries': args.q,
                     'train': True,
                     'order': args.order, 'device': device, 'inner_train_steps': args.inner_train_steps,
                     'inner_lr': args.inner_lr}
    
for epoch in range(1, args.epochs+1):
    for batch_index, batch in enumerate(dataloader[0]):
        x, y = prepare_batch(batch)
        x.requires_grad = False
        y.requires_grad = False

        loss_ori, y_pred_ori = meta_gradient_step_ori(meta_model_ori, optimiser_ori, loss_fn_ori, 
                                                     x, y, **fit_function_kwargs)
        
        print('origin grad: ', meta_model_ori.conv1[0].bias.grad[0])
        print('origin w: ', meta_model_ori.conv1[0].bias[0])
        
        
        loss_new, y_pred_new = meta_gradient_step_ori(meta_model_new, optimiser_new, loss_fn_new, 
                                                     x, y, **fit_function_kwargs)
#         loss_new, y_pred_new = fit_function[0](meta_model_new, optimiser_new, loss_fn_new, x, y, origin=False, 
#                                        other_optim=[conv_optimiser_new, other_optimiser_new], 
#                                        p_task=[1,1,1,1], p_meta=[1,1,1,1], **fit_function_kwargs)
    
        print('new grad: ', meta_model_new.conv1[0].bias.grad[0])
        print('new w: ', meta_model_new.conv1[0].bias[0])

#         loss_ori, y_pred_ori = fit_function[0](meta_model_new, optimiser_new, loss_fn_new, x, y, origin=False, 
#                                                other_optim=[conv_optimiser_new, other_optimiser_new], 
#                                                p_task=[1,1,1,1], p_meta=[1,1,1,1], **fit_function_kwargs)

#         loss_new, y_pred_new = fit_function[0](meta_model_new, optimiser_new, loss_fn_new, x, y, origin=True, 
#                                                other_optim=[conv_optimiser_new, other_optimiser_new], 
#                                                p_task=[1,1,1,1], p_meta=[1,1,1,1], **fit_function_kwargs)

#         loss_ori, y_pred_ori = meta_gradient_step_ori(meta_model_ori, optimiser_ori, loss_fn_ori,
#                                                   x, y, **fit_function_kwargs)

#         print('old_grad')
#         print(meta_model_ori.conv1[0].bias[0], meta_model_ori.conv1[0].bias.grad[0])

#         loss_new, y_pred_new = meta_gradient_step_ori(meta_model_new, optimiser_new, loss_fn_new,
#                                                   x, y, **fit_function_kwargs)

#         print('new_grad')
#         print(meta_model_ori.conv1[0].bias[0], meta_model_new.conv1[0].bias.grad[0])
        
# #         loss_new, y_pred_new = meta_gradient_step_ori(meta_model_ori, optimiser_new, loss_fn_new,
# #                                                   x, y, **fit_function_kwargs)

#         print('loss_ori: ', loss_ori)
# #         print('y_pred_ori', y_pred_ori)
#         print('loss_new: ', loss_new)
# #         print('y_pred_new', y_pred_new)

origin grad:  tensor(3.1442e-17, device='cuda:1', dtype=torch.float64)
origin w:  tensor(0.1569, device='cuda:1', dtype=torch.float64, grad_fn=<SelectBackward0>)
new grad:  tensor(3.1008e-17, device='cuda:1', dtype=torch.float64)
new w:  tensor(0.1569, device='cuda:1', dtype=torch.float64, grad_fn=<SelectBackward0>)
origin grad:  tensor(-4.4235e-17, device='cuda:1', dtype=torch.float64)
origin w:  tensor(0.1569, device='cuda:1', dtype=torch.float64, grad_fn=<SelectBackward0>)
new grad:  tensor(-5.0741e-17, device='cuda:1', dtype=torch.float64)
new w:  tensor(0.1569, device='cuda:1', dtype=torch.float64, grad_fn=<SelectBackward0>)


In [12]:
# different output
y_pred_new[0], y_pred_ori[0]

(tensor([0.1133, 0.1830, 0.1443, 0.4588, 0.1007], device='cuda:1',
        dtype=torch.float64, grad_fn=<SelectBackward0>),
 tensor([0.1133, 0.1830, 0.1443, 0.4588, 0.1007], device='cuda:1',
        dtype=torch.float64, grad_fn=<SelectBackward0>))

In [13]:
# parameters still consistent
meta_model_ori.conv1[0].weight[0, 0], meta_model_new.conv1[0].weight[0, 0]

(tensor([[ 0.1820,  0.1400,  0.0682],
         [ 0.0558, -0.1389,  0.1704],
         [ 0.1452, -0.1606,  0.1271]], device='cuda:1', dtype=torch.float64,
        grad_fn=<SelectBackward0>),
 tensor([[ 0.1820,  0.1400,  0.0682],
         [ 0.0558, -0.1389,  0.1704],
         [ 0.1452, -0.1606,  0.1271]], device='cuda:1', dtype=torch.float64,
        grad_fn=<SelectBackward0>))

In [14]:
meta_model_ori.conv1[0].bias.grad[0], meta_model_new.conv1[0].bias.grad[0]

(tensor(-4.4235e-17, device='cuda:1', dtype=torch.float64),
 tensor(-5.0741e-17, device='cuda:1', dtype=torch.float64))

In [15]:
for n, p in meta_model_new.named_parameters():
    print(n)

conv1.0.weight
conv1.0.bias
conv1.1.weight
conv1.1.bias
conv2.0.weight
conv2.0.bias
conv2.1.weight
conv2.1.bias
conv3.0.weight
conv3.0.bias
conv3.1.weight
conv3.1.bias
conv4.0.weight
conv4.0.bias
conv4.1.weight
conv4.1.bias
logits.weight
logits.bias


In [16]:
for n, p meta_model_ori.named_parameters():
    print(n)

SyntaxError: invalid syntax (2929393399.py, line 1)

In [None]:
meta_model_ori.conv1[0].bias.grad, meta_model_new.conv1[0].bias.grad

In [None]:
# make network parameters consistent
for (n, p) in meta_model_ori.named_parameters():
    for i in range(len(n)):
        if n[i] == '.' and n[i + 1].isdigit():
            module1 = getattr(meta_model_ori, n[:i])[int(n[i+1])]
            module2 = getattr(meta_model_new, n[:i])[int(n[i+1])]
            print(torch.all((getattr(module2, n[i+3:]).grad == (getattr(module1, n[i+3:])).grad)))
            break
    if 'logits' in n:
            module1 = getattr(meta_model_ori, 'logits')
            module2 = getattr(meta_model_new, 'logits')
            print(torch.all((getattr(module2, n[7:]).grad == (getattr(module1, n[7:])).grad)))