<a href="https://colab.research.google.com/github/Tianananana/MAML-Naive-Walkthrough/blob/main/v2_MAML_P%26P_(Autograd).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [78]:
import torch
import numpy as np
import pdb
import copy
import json
import pandas as pd

## Classes

In [2]:
class SingleNet():
  def __init__(self):
    theta = torch.tensor([1.], requires_grad=True) # we will need the gradients of meta model weights.
    # theta.retain_grad()
    self.weight = theta # set initial weight to 1
    self.retain_gradient()
  
  def __call__(self, x, phij=None):
    if phij is None:
      return self.weight * x
    else:
      # manally pass in task-specific weights
      return phij * x
  
  def zero_grad(self):
    self.weight.grad = None # setting gradients to None deletes all previous computational tree.

  def retain_gradient(self):
    self.weight.retain_grad()

In [3]:
class Loss():
  def __init__(self):
    pass

  def L1loss(self, pred, targ):
    loss = torch.sum((pred - targ)**2)
    return loss

  def __call__(self, pred, targ):
    return self.L1loss(pred, targ)

In [93]:
class OptimSGD():
  def __init__(self, net, lr):
    self.net = net
    self.lr = lr

  def _SGD(self):
    # SGD formula
    self.net.weight = self.net.weight - self.lr * self.net.weight.grad
    self.net.retain_gradient()
    print(self.net.weight)


  def step(self):
    self._SGD()

## v1 New pipeline with SingleNet, Loss, and Optim class

### Define Inner Trainer: compute task-specific weights (phi_j)

In [87]:
class InnerTrainer:
  def __init__(self, net, alpha, criterion):
    self.net = net
    self.alpha = alpha
    self.loss = criterion
    self.tmp_weight = None # for k step inner update

  def _one_inner_epoch(self, x, y):
    if self.tmp_weight is None:
      # use theta (meta-model weights)
      pred = self.net(x) 
      curr_weight = self.net.weight
    else:
      # use intermmediate steps
      pred = self.net(x, self.tmp_weight) 
      curr_weight = self.tmp_weight

    loss = self.loss(pred, y)
    
    # computing phi_j
    gradient = torch.autograd.grad(loss, curr_weight, create_graph=True)[0]
    print(f"INNER GRADIENT: {gradient.item()}")
    self.tmp_weight = curr_weight - self.alpha * gradient
    print(f"PHI_J: {self.tmp_weight.item()}")

    return self.tmp_weight, gradient


  def k_inner_epoch(self, x, y, k):
    """ Returns phi_j """
    # Outputs both phi_j and phi_j gradients for logging.
    final_phi_j = None
    final_grad = None
    for i in range(k):
      final_phi_j, final_grad = self._one_inner_epoch(x, y)

    return final_phi_j, final_grad

### Define Meta Trainer: update meta weight (theta)

In [88]:
class MetaTrainer:
  def __init__(self, net, alpha, beta, criterion, k=1):
    self.net = net
    self.loss = criterion
    self.k = k
    self.alpha = alpha
    self.opt = OptimSGD(self.net, beta)

  def one_epoch(self, datasets):
    # loss_all = torch.tensor([0.])
    self.net.zero_grad()

    # Initialize dicts per dataset
    info['InnerGrad'], info['PhiJ'] = {}, {}
    info['L_q'], info['OuterGrad(Accum)'] = {}, {}

    for i in range(len(datasets)):
      # iterate over each dataset (D1, D2)
      print(f"RUNNING INNER LOOP OF D{i+1}")
      sx, sy = datasets[i]['support'][:, 0], datasets[i]['support'][:, 1]
      qx, qy = datasets[i]['query'][:, 0], datasets[i]['query'][:, 1]
      
      innerTrainer = InnerTrainer(net=self.net, alpha=self.alpha, criterion=self.loss)

      # Initialize list for each dataset
      phi_j, phi_j_grad = innerTrainer.k_inner_epoch(sx, sy, self.k)
      info['PhiJ'][f'D{i+1}'], info['InnerGrad'][f'D{i+1}'] = phi_j.item(), phi_j_grad.item()
        ## Check whether 2nd outer epoch calls tmp_weights on 1st inner training or use net.weight.

      pred = self.net(qx, phi_j)
      loss = self.loss(pred, qy)
      print(f"L_Q OF D{i+1}: {loss}")
      info['L_q'][f'D{i+1}'] = loss.item()

      loss.backward() # take gradients for each dataset separately
      print(f"ACCUM GRAD OF L_Q OF D{i+1}: {self.net.weight.grad}") #v2 TODO: think of a way to remove comp tree aft accumulating grads
      info['OuterGrad(Accum)'][f'D{i+1}'] = self.net.weight.grad.item()

    # update meta weights (theta)
    self.opt.step()

    print(f"NEW META WEIGHT {self.net.weight.item()}")
    info['NewWeight'] = self.net.weight.item()

In [89]:
alpha = 0.1
beta = 0.5
net = SingleNet()
L1 = Loss()

In [90]:
# We define 2 datasets for our case. 
D1 = {'query': torch.tensor([(1, 2)], dtype=torch.int64), 'support': torch.tensor([(2, 4), (3, 1)])} # (x, y) pairs for query (Q1) and support (S1) set.
D2 = {'query': torch.tensor([(4, 1)], dtype=torch.int64), 'support': torch.tensor([(5, 3), (6, 0)])}
D_all = [D1, D2]

In [None]:
trainer = MetaTrainer(net=net, alpha=alpha, beta=beta, criterion=L1, k=1)
# initialize dict to store in json file
log = {}
log_dir = './log'

for i in range(3):
  print(f"\nEPOCH {i}")
  info = {}
  info['Epoch'] = i
  trainer.one_epoch(D_all)

  # append to current json file. Save each epoch logs once.
  with open('./log.json', 'a') as log:
    json.dump(info, log)
    log.write('\n')


In [92]:
"""Read json output file into pandas DataFrame """
df = pd.read_json('/content/log.json', lines=True)
df

Unnamed: 0,Epoch,InnerGrad,PhiJ,L_q,OuterGrad(Accum),NewWeight
0,0,"{'D1': 4.0, 'D2': 92.0}","{'D1': 0.600000023841857, 'D2': -8.19999980926...","{'D1': 1.959999918937683, 'D2': 1142.43994140625}","{'D1': 4.479999542236328, 'D2': 3032.9599609375}",-1515.48
1,1,"{'D1': -39424.48046875, 'D2': -184918.5625}","{'D1': 2426.968017578125, 'D2': 16976.376953125}","{'D1': 5880470.0, 'D2': 4611022336.0}","{'D1': -7759.89794921875, 'D2': -6092004.0}",3044486.0
2,2,"{'D1': 79156624.0, 'D2': 371427328.0}","{'D1': -4871176.0, 'D2': -34098244.0}","{'D1': 23728375201792.0, 'D2': 1.8603043104751...","{'D1': 15587772.0, 'D2': 12236398592.0}",-6115155000.0


## v0 Old code (Loss functions, task-specific functions, meta functions)

In [None]:
""" LOSS FUNCTIONS """
def loss(weight, dataset, mode='train'):
  # Eqn 5: l1 loss
  if mode == 'train':
    data = dataset['support']
  if mode == 'test':
    data = dataset['query']
    # HK: loss function need to pass in only pred and target.
  return torch.sum((data[:, 1] - weight * data[:, 0])**2)

In [None]:
""" TASK SPECIFIC FUNCTIONS """
def inner_gradient(weight, dataset, mode='train'):
  # Eqn 6
  if mode == 'train':
    data = dataset['support']

  if mode == 'test':
    data = dataset['query']

  # HK: Only pass in dataset
  # add in loss variable
  

  # Use torch.autograd.grad
  return grad(loss(weight, dataset, mode=mode), weight, create_graph=True)[0]
  # return grad(loss(weight, dataset, mode=mode), weight)[0]
  
def inner_weight(weight, dataset, alpha=0.1, mode='train'):
  # Eqn 7
  task_specific_grads = inner_gradient(weight, dataset, mode=mode)
  return weight - alpha * task_specific_grads

In [None]:
""" META FUNCTIONS """
def meta_gradient_1(theta_0, dataset):
  """ dag term: derivative of phi_j wrt theta_0 """
  phi_j_test = inner_weight(theta_0, dataset, mode='test')
  # Use torch.autograd.grad
  grad_wrt_theta_0 = grad(phi_j_test, theta_0)[0]
  return grad_wrt_theta_0

def meta_gradient_2(phi_j, dataset):
  """ ddag term: derivative wrt phi_j """
  meta_loss = loss(phi_j, dataset, mode='test')
  # Use torch.autograd.grad
  grad_wrt_phi_j = grad(meta_loss, phi_j)[0]
  return grad_wrt_phi_j


def meta_gradient(theta_0, dataset):
  """ Getting meta_gradient by applying chain rule """
  dag_term = meta_gradient_1(theta_0, dataset)
  phi_j = inner_weight(theta_0, dataset, mode='train')
  ddag_term = meta_gradient_2(phi_j, dataset)
  # use the chain rule
  return dag_term * ddag_term

def meta_loss(theta_0, dataset, alpha=0.1, beta=0.5):
  """ Computing the meta loss with meta gradients """
  total_grads = torch.tensor([0.])
  for d in dataset:
    total_grads += meta_gradient(theta_0, d)
  return theta_0 - beta * total_grads


In [None]:
dataset = D1
net = SingleNet()
phi_j_test = inner_weight(net.weight, dataset, mode='test')
phi_j_train = inner_weight(net.weight, dataset, mode='train')
meta_loss = loss(phi_j_train, dataset, mode='test')
print(meta_loss)
meta_loss.backward()
phi_j_train.backward()
print(net.weight.grad)
print(phi_j_train)
print(meta_loss.grad_fn(phi_j_train))

tensor(1.9600, grad_fn=<SumBackward0>)
tensor([-1.8000])
tensor([0.6000], grad_fn=<SubBackward0>)
tensor([0.6000], grad_fn=<ExpandBackward0>)


In [None]:
## UNIT TEST ##
net = SingleNet()
# Output task-specific gradient
inner_grad1 = inner_gradient(net.weight, D1, mode='train')
inner_grad2 = inner_gradient(net.weight, D2, mode='train')
print(f"{inner_grad1=}")
print(f"{inner_grad2=}")
print("")

# Output task-specific weights
w1 = inner_weight(net.weight, D1)
w2 = inner_weight(net.weight, D2)
print(f"{w1=}")
print(f"{w2=}")
print("")

# Output meta gradient
meta_grad1 = meta_gradient(net.weight, D1)
meta_grad2 = meta_gradient(net.weight, D2)
print(f"{meta_grad1}")
print(f"{meta_grad2}")
print("")

# Output new meta weight
theta_1 = meta_loss(net.weight, D_all)
print(f"{theta_1=}")
print("")

inner_grad1=tensor([4.], grad_fn=<SumBackward1>)
inner_grad2=tensor([92.], grad_fn=<SumBackward1>)

w1=tensor([0.6000], grad_fn=<SubBackward0>)
w2=tensor([-8.2000], grad_fn=<SubBackward0>)

tensor([-2.2400])
tensor([594.8800])

theta_1=tensor([-295.3200], grad_fn=<SubBackward0>)

