<a href="https://colab.research.google.com/github/Tianananana/MAML-Naive-Walkthrough/blob/main/v0_MAML_P%26P_(Manual).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Implementation of MAML with a single-neuron neural network model. Version v0 computes every gradient manually from closed-form formulas.


In [None]:
import torch
import numpy as np

## Create model

In [None]:
class SingleNet():
  '''Single-neuron linear model without bias: y_hat = weight * x.

  `weight` is the meta (task-agnostic) parameter theta; per-task weights are
  derived from it by `inner_weight` and it is overwritten by `meta_weight`.
  '''
  def __init__(self, init_weight=1.):
    '''
    param init_weight: initial value of the meta weight (default 1, as in the
                       original walkthrough)
    '''
    # requires_grad=True is kept so tensors derived from the weight carry a
    # grad_fn, even though v0 computes every gradient by hand.
    self.weight = torch.tensor([float(init_weight)], requires_grad=True)

  def __call__(self, x):
    '''Return the prediction weight * x.'''
    return self.weight * x

## Create datasets

In [None]:
# Two toy regression tasks. Each task maps a split name to a tensor of integer
# (x, y) rows: 'support' (S_i) feeds the inner, task-specific step and
# 'query' (Q_i) feeds the outer, meta step.
D1 = {
  'query': torch.tensor([(1, 2)], dtype=torch.int64),  # Q1
  'support': torch.tensor([(2, 4), (3, 1)]),           # S1
}
D2 = {
  'query': torch.tensor([(4, 1)], dtype=torch.int64),  # Q2
  'support': torch.tensor([(5, 3), (6, 0)]),           # S2
}
D_all = [D1, D2]  # all tasks, in order

## Loss function

In [None]:
""" LOSS FUNCTIONS """
def loss(weight, dataset, mode='train'):
  '''
  Sum-of-squared-errors regression loss of y_hat = weight * x (Eqn 5).

  param weight: weight of the single-neuron model (tensor or number)
  param dataset: task dict with 'support' and 'query' tensors of (x, y) rows
  param mode: 'train' evaluates on the support set, 'test' on the query set
  return: 0-d tensor, sum over rows of (y - weight * x) ** 2
  raises ValueError: if mode is neither 'train' nor 'test'
  '''
  # Eqn 5
  if mode == 'train':
    data = dataset['support']
  elif mode == 'test':
    data = dataset['query']
  else:
    raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")
  x, y = data[:, 0], data[:, 1]
  return torch.sum((y - weight * x) ** 2)

In [None]:
""" TASK SPECIFIC FUNCTIONS """
def inner_gradient(net, dataset, mode='train', weight=0):
  '''
  Closed-form gradient of the Eqn 5 loss w.r.t. the model weight (Eqn 6):

      d/dw sum (y - w * x)^2 = -2 * sum x * (y - w * x)

  param net: model whose `weight` is used in 'train' mode
  param dataset: task dict with 'support' and 'query' tensors of (x, y) rows
  param mode: 'train' -> support set evaluated at net.weight;
              'test'  -> query set evaluated at the passed-in `weight`
  param weight: task-specific weight, used only in 'test' mode; a plain number
                or a single-element tensor
  return: 0-d tensor holding the gradient
  raises ValueError: if mode is neither 'train' nor 'test'
  '''
  # Eqn 6
  if mode == 'train':
    data = dataset['support']
    weight = net.weight
  elif mode == 'test':
    data = dataset['query']  # task-specific weight is supplied by the caller
  else:
    raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

  # The gradient is computed by formula, so only the weight's value is needed.
  # as_tensor makes the numeric default (0) work as well as tensors; float()
  # keeps the result floating-point even for an integer weight.
  w = float(torch.as_tensor(weight).item())
  x, y = data[:, 0], data[:, 1]
  return -2 * torch.sum(x * (y - w * x))

def inner_weight(net, dataset, alpha=0.1):
  '''Adapt the meta weight to one task with a single gradient step (Eqn 7).

  param net: model holding the meta weight `net.weight`
  param dataset: task dict; its 'support' split determines the step
  param alpha: inner-loop (task-level) learning rate
  return: task-specific weight phi = net.weight - alpha * dL_support/dweight
  '''
  # Eqn 7
  support_grad = inner_gradient(net, dataset, mode='train')
  step = alpha * support_grad
  return net.weight - step

In [None]:
""" META FUNCTIONS """

def meta_gradient(net, dataset, weight, mode='test', alpha=0.1):
  '''Derivative of a task's query loss w.r.t. the meta weight theta (Eqn 12).

  With phi = theta - alpha * dL_support/dtheta (Eqn 7), the chain rule gives
      dL_query(phi)/dtheta = dL_query/dphi * dphi/dtheta,
      dphi/dtheta          = 1 - 2 * alpha * sum(x_support ** 2).

  param net: model holding the meta weight
  param dataset: task dict with 'support' and 'query' splits
  param weight: task-specific (adapted) weight phi for this task
  param mode: not used; the query split is always evaluated
  param alpha: inner-loop learning rate that produced `weight`
  return: single-element tensor
  '''
  # Eqn 12
  # Outer factor: slope of the query loss at the adapted weight phi.
  dloss_dphi = inner_gradient(net, dataset, mode='test', weight=weight)
  # Inner factor: slope of the one-step update phi(theta).
  xs = dataset['support'][:, 0]
  dphi_dtheta = torch.tensor([1]) - 2 * alpha * torch.sum(xs ** 2)
  return dloss_dphi * dphi_dtheta


def meta_weight(net, datasets, weight, beta=0.5, alpha=0.1):
  '''
  Meta (outer) update of the shared weight from all tasks' query sets (Eqn 11):

      theta <- theta - beta * sum_i dL_query_i(phi_i)/dtheta

  param net: model holding the meta weight; net.weight is replaced in place
  param datasets: list of task dicts
  param weight: task-specific weights phi_i, one per entry of `datasets`
  param beta: outer-loop (meta) learning rate
  param alpha: inner-loop learning rate used to produce `weight`; forwarded to
               meta_gradient (the default matches meta_gradient's default)
  return: the updated net.weight
  raises ValueError: if `datasets` and `weight` differ in length
  '''
  # Eqn 11
  if len(datasets) != len(weight):
    raise ValueError(
      f"expected one task-specific weight per dataset, got "
      f"{len(weight)} weights for {len(datasets)} datasets")
  all_meta_gradient = [
    meta_gradient(net, task, weight=phi, mode='test', alpha=alpha)
    for task, phi in zip(datasets, weight)
  ]
  # Builtin sum returns 0 for no tasks, leaving the weight's value unchanged.
  net.weight = net.weight - beta * sum(all_meta_gradient)
  return net.weight

In [None]:
## UNIT TEST ##
# Hand-derived expected values (defaults alpha=0.1, beta=0.5, initial weight 1):
#   support-set gradients:  4 and 92
#   task-specific weights:  1 - 0.1*4 = 0.6 and 1 - 0.1*92 = -8.2
#   meta gradients:         (-2.8)*(-1.6) = 4.48 and (-270.4)*(-11.2) = 3028.48
#   new meta weight:        1 - 0.5*(4.48 + 3028.48) = -1515.48
net = SingleNet()
# Output task-specific gradient
inner_grad1 = inner_gradient(net, D1, mode='train')
inner_grad2 = inner_gradient(net, D2, mode='train')
print(f"{inner_grad1=}")
print(f"{inner_grad2=}")
print("")
assert torch.allclose(inner_grad1, torch.tensor(4.))
assert torch.allclose(inner_grad2, torch.tensor(92.))

# Output task-specific weights
w1 = inner_weight(net, D1)
w2 = inner_weight(net, D2)
print(f"{w1=}")
print(f"{w2=}")
print("")
assert torch.allclose(w1, torch.tensor([0.6]))
assert torch.allclose(w2, torch.tensor([-8.2]))

# Output meta gradient
meta_grad1 = meta_gradient(net, D1, w1, mode='test', alpha=0.1)
meta_grad2 = meta_gradient(net, D2, w2, mode='test', alpha=0.1)
print(f"{meta_grad1=}")
print(f"{meta_grad2=}")
print("")
assert torch.allclose(meta_grad1, torch.tensor([4.48]))
assert torch.allclose(meta_grad2, torch.tensor([3028.48]))

# Output new meta weight
theta_1 = meta_weight(net, D_all, [w1, w2])
print(f"{theta_1=}")
print("")
assert torch.allclose(theta_1, torch.tensor([-1515.48]))

inner_grad1=tensor(4.)
inner_grad2=tensor(92.)

w1=tensor([0.6000], grad_fn=<SubBackward0>)
w2=tensor([-8.2000], grad_fn=<SubBackward0>)

meta_grad1=tensor([4.4800])
meta_grad2=tensor([3028.4800])

theta_1=tensor([-1515.4800], grad_fn=<SubBackward0>)

