<a href="https://colab.research.google.com/github/arnav-pati/Brain-Age-Prediction/blob/main/Reptile_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn, optim, autograd
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np

In [5]:
class Learner(nn.Module):
    '''
    Holds two copies of the same network for first-order meta-learning.

    ``net`` carries the meta weights and is the only module exposed through
    ``parameters()``. ``net_pi`` is a scratch copy: it is reset from ``net``
    at the start of every ``forward`` call and then fine-tuned with plain SGD
    on the task's support set.

    ``net_class(*args)`` must build a module whose ``forward(x, y)`` returns
    a ``(loss, pred)`` pair, since that is how both copies are called here.
    '''

    def __init__(self, net_class, *args) -> None:
        '''
        net_class: a class (not an instance) used to build both copies.
        args: positional constructor arguments forwarded to net_class.

        Raises TypeError if net_class is not a class.
        '''
        super(Learner, self).__init__()
        # isinstance also accepts classes created by a custom metaclass, and
        # unlike ``assert`` this check is not stripped under ``python -O``.
        if not isinstance(net_class, type):
            raise TypeError(
                f'net_class must be a class, got {type(net_class).__name__}'
            )

        self.net = net_class(*args)
        self.net_pi = net_class(*args)
        self.learner_lr = 0.1
        # The inner-loop optimizer only ever updates the scratch copy.
        self.optimizer = optim.SGD(self.net_pi.parameters(), self.learner_lr)

    def parameters(self, recurse=True):
        '''
        Return only the meta weights (self.net).

        self.net_pi is deliberately excluded, so an optimizer built from this
        method never updates the scratch copy. ``recurse`` is forwarded to
        nn.Module.parameters for signature compatibility with the base class.
        '''
        return self.net.parameters(recurse=recurse)

    def update_pi(self):
        '''
        Overwrite every parameter of net_pi with the current value of the
        corresponding parameter of net.

        All parameters are copied, whatever layer type owns them. This
        includes bare nn.Parameter attributes and layers without a bias or
        weight. Buffers such as BatchNorm running statistics are not copied.
        '''
        with torch.no_grad():
            # Both modules come from the same class and arguments, so their
            # parameters() sequences line up one-to-one.
            for p_from, p_to in zip(self.net.parameters(), self.net_pi.parameters()):
                p_to.copy_(p_from)

    def forward(self, support_x, support_y, query_x, query_y, num_updates):
        '''
        Adapt net_pi to one task and evaluate it on the query set.

        net_pi is reset from net, trained for ``num_updates`` SGD steps on
        (support_x, support_y), then scored on (query_x, query_y).

        Returns:
            loss: query loss of the adapted net_pi. Its graph is consumed by
                the gradient computation below, so use it for reporting only.
            grads_pi: tuple with d(loss)/d(param) for each net_pi parameter,
                in parameters() order; plain tensors without a graph.
            acc: fraction of query samples whose argmax prediction (over
                dim 1 of pred) equals query_y.
        '''
        self.update_pi()
        for _ in range(num_updates):
            support_loss, _ = self.net_pi(support_x, support_y)
            self.optimizer.zero_grad()
            support_loss.backward()
            self.optimizer.step()

        loss, pred = self.net_pi(query_x, query_y)
        indices = torch.argmax(pred, dim=1)
        correct = torch.eq(indices, query_y).sum().item()
        acc = correct / query_y.size(0)

        # net_pi's parameters are separate leaves whose values were copied
        # from net, so a higher-order graph (create_graph=True) could never
        # reach net's parameters. Building it would only keep memory alive.
        grads_pi = autograd.grad(loss, self.net_pi.parameters())
        return loss, grads_pi, acc

    def net_forward(self, support_x, support_y):
        '''
        Run the meta network (net) directly and return its (loss, pred).
        '''
        loss, pred = self.net(support_x, support_y)
        return loss, pred

In [None]:
class MetaLearner(nn.Module):
    def __init__(self, net_class, net_class_args, n_way, k_shot, meta_batchesz, beta, num_updates) -> None:
        '''
        Build the meta-learner around a Learner that wraps ``net_class``.

        net_class: model class (not an instance), forwarded to Learner.
        net_class_args: positional constructor arguments for net_class.
        n_way, k_shot, meta_batchesz, num_updates: stored on the instance
            as attributes of the same name.
        beta: stored as ``self.beta`` and used as the Adam learning rate.
        '''
        super().__init__()
        # Hyper-parameters kept as attributes for use elsewhere in the class.
        (
            self.n_way,
            self.k_shot,
            self.meta_batchesz,
            self.beta,
            self.num_updates,
        ) = (n_way, k_shot, meta_batchesz, beta, num_updates)

        inner = Learner(net_class, *net_class_args)
        self.learner = inner
        # Learner.parameters() yields only the meta weights (its ``net``),
        # so the outer Adam optimizer never touches the scratch copy.
        self.optimizer = optim.Adam(inner.parameters(), lr=beta)
    
    def write_grads(self, dummy_loss, sum_grads_pi):
        hooks = []
        for i, v in enumerate(self.learner.parameters()):
            h = v.register_hook(lambda grad : sum_grads_pi[i])
            hooks.append(h)
        
        self.optimizer.zero_grad()
        dummy_loss.backward()
        self.optimizer.step()

        for h in hooks:
            h.remove()
    
    def forward(self, support_x, support_y, query_x, query_y):
        