In [None]:
# Ref:
# https://github.com/sudharsan13296/Hands-On-Meta-Learning-With-Python/tree/master/06.%20MAML%20and%20it's%20Variants

![MAML algorithm](Algorithm.PNG)

In [1]:
import numpy as np

In [2]:
def sample_points(k, n_features=50):
    """Draw a random binary-classification task with ``k`` data points.

    Parameters
    ----------
    k : int
        Number of data points (shots) to draw.
    n_features : int, optional
        Number of input features per data point. The default of 50 matches
        the size of the parameter vector created in ``MAML.__init__``.

    Returns
    -------
    x : numpy.ndarray of shape (k, n_features)
        Features drawn uniformly from [0, 1).
    y : numpy.ndarray of shape (k, 1)
        Labels drawn from {0, 1} with equal probability; they are sampled
        independently of ``x``.
    """
    # Features first, labels second: the draw order fixes the random stream,
    # so it must not be swapped if seeded runs are to stay reproducible.
    x = np.random.rand(k, n_features)
    y = np.random.choice([0, 1], size=k, p=[.5, .5]).reshape([-1, 1])
    return x, y

In [3]:
# Smoke test: draw a two-shot task and display the returned (features, labels) tuple.
sample_points(2)

(array([[0.28990196, 0.90793764, 0.23421841, 0.67913086, 0.76418995,
         0.18841729, 0.12222369, 0.55810223, 0.24859761, 0.32549487,
         0.39976121, 0.80769984, 0.38958264, 0.57690576, 0.49416128,
         0.90208894, 0.62708224, 0.05897736, 0.4249155 , 0.25914113,
         0.33541824, 0.21043635, 0.34308769, 0.13740121, 0.98087293,
         0.460416  , 0.67733664, 0.50000008, 0.57856388, 0.82864749,
         0.89001534, 0.0418834 , 0.3681971 , 0.22715244, 0.03816675,
         0.90126498, 0.35985741, 0.48516955, 0.32467412, 0.7556343 ,
         0.0679308 , 0.44552869, 0.31944692, 0.8317235 , 0.26212497,
         0.68972078, 0.91044049, 0.05556849, 0.94329861, 0.97070162],
        [0.58815487, 0.58490588, 0.49850777, 0.67846093, 0.12363671,
         0.02581901, 0.40651341, 0.93896464, 0.01332282, 0.26109609,
         0.3754193 , 0.15806783, 0.76559732, 0.68129808, 0.12001497,
         0.7535202 , 0.47517047, 0.61089555, 0.94064037, 0.52245568,
         0.80133887, 0.32472378, 

In [4]:
# Shape check: a (1, 2) row times a (2, 1) column yields a (1, 1) matrix, here 1*2 + 2*2 = 6.
np.matmul(np.array([[1, 2]]).reshape(1, 2), np.array([[2, 2]]).reshape(2, 1))

array([[6]])

In [5]:
class MAML(object):
    """Meta-learning loop (MAML style) for a 50-feature logistic-regression model.

    Each epoch samples a batch of tasks with ``sample_points``. For every
    task one gradient step from the shared parameters ``theta`` produces
    task-specific parameters (inner loop); the gradients of those adapted
    parameters on freshly sampled data are then averaged and used to update
    ``theta`` (outer loop).

    The outer gradient is evaluated at the adapted parameters and applied
    directly to ``theta``; no derivative is taken through the inner update.

    Parameters
    ----------
    num_tasks : int, optional
        Number of tasks in each batch of tasks.
    num_samples : int, optional
        Number of data points (shots) sampled per task, for both the inner
        training set and the outer test set.
    epochs : int, optional
        Number of outer-loop iterations.
    alpha : float, optional
        Step size of the inner (per-task) gradient update.
    beta : float, optional
        Step size of the outer (meta) gradient update.

    Raises
    ------
    ValueError
        If ``num_tasks`` or ``num_samples`` is smaller than 1.
    """

    def __init__(self, num_tasks=10, num_samples=10, epochs=1000,
                 alpha=0.001, beta=0.001):
        if num_tasks < 1 or num_samples < 1:
            raise ValueError("num_tasks and num_samples must both be at least 1")

        # number of tasks we need in each batch of tasks
        self.num_tasks = num_tasks

        # number of shots, i.e. data points (k) we need to have in each task
        self.num_samples = num_samples

        # number of epochs, i.e. training iterations
        self.epochs = epochs

        # hyperparameter for the inner loop (inner gradient update)
        self.alpha = alpha

        # hyperparameter for the outer loop (meta optimization)
        self.beta = beta

        # randomly initialize the model parameter theta as a (50, 1) column
        self.theta = np.random.normal(size=50).reshape(50, 1)

        # per-task adapted parameters; rebuilt at the start of every epoch
        self.theta_ = []

    def sigmoid(self, a):
        """Return the element-wise logistic function ``1 / (1 + exp(-a))``.

        The input is clipped to [-500, 500] before exponentiation so that
        ``exp`` cannot overflow; values inside that range are unaffected.
        """
        return 1.0 / (1 + np.exp(-np.clip(a, -500, 500)))

    def train(self, log_every=200):
        """Run the meta-training loop and return the final ``theta``.

        Parameters
        ----------
        log_every : int, optional
            Print progress every ``log_every`` epochs (starting with epoch
            0). A falsy value such as 0 or ``None`` disables printing.

        Returns
        -------
        numpy.ndarray of shape (50, 1)
            The meta-learned parameters, also stored in ``self.theta``.

        Notes
        -----
        The loss that is printed is the training loss of the last task of
        the current batch, evaluated before that task's inner update. It is
        not an average over the batch.
        """
        for e in range(self.epochs):

            self.theta_ = []

            # Inner loop: one gradient step per task, starting from the shared theta.
            for i in range(self.num_tasks):

                # sample k data points and prepare our train set
                XTrain, YTrain = sample_points(self.num_samples)

                a = np.matmul(XTrain, self.theta)

                YHat = self.sigmoid(a)

                # Mean binary cross entropy, written on the logits as
                # log(1 + exp(a)) - y*a. This equals
                # -y*log(YHat) - (1 - y)*log(1 - YHat) but stays finite when
                # the sigmoid saturates to exactly 0 or 1.
                loss = np.mean(np.logaddexp(0.0, a) - YTrain * a)

                # gradient of the mean cross entropy with respect to theta
                gradient = np.matmul(XTrain.T, (YHat - YTrain)) / self.num_samples

                # take one step and keep the adapted parameter theta' of this task
                self.theta_.append(self.theta - self.alpha*gradient)

            # initialize meta gradients
            meta_gradient = np.zeros(self.theta.shape)

            # Outer loop. Kept separate from the inner loop so that all train
            # sets are drawn before any test set, which fixes the order in
            # which random numbers are consumed.
            for i in range(self.num_tasks):

                # Sample the meta-test set with the same shot count that the
                # normalisation below divides by.
                XTest, YTest = sample_points(self.num_samples)

                # predict the value of y with the adapted parameters
                a = np.matmul(XTest, self.theta_[i])

                YPred = self.sigmoid(a)

                # accumulate the gradient evaluated at theta'
                meta_gradient += np.matmul(XTest.T, (YPred - YTest)) / self.num_samples

            # update the shared parameter theta with the averaged meta gradient
            self.theta = self.theta-self.beta*meta_gradient/self.num_tasks

            if log_every and e % log_every == 0:
                print("Epoch {}: Loss {}\n".format(e,loss))
                print('Updated Model Parameter Theta\n')
                print('Sampling Next Batch of Tasks \n')
                print('---------------------------------\n')
        return self.theta

In [6]:
# Build the meta-learner. __init__ draws the initial theta at random and no seed is set in the
# visible cells, so results differ from run to run.
model = MAML()

In [7]:
# Run meta-training. train() prints progress every 200 epochs and returns the final theta,
# which is displayed as this cell's output.
model.train()

Epoch 0: Loss 1.471125890239991

Updated Model Parameter Theta

Sampling Next Batch of Tasks 

---------------------------------

Epoch 200: Loss 0.971124679744533

Updated Model Parameter Theta

Sampling Next Batch of Tasks 

---------------------------------

Epoch 400: Loss 1.0434002569741918

Updated Model Parameter Theta

Sampling Next Batch of Tasks 

---------------------------------

Epoch 600: Loss 0.5292412384581618

Updated Model Parameter Theta

Sampling Next Batch of Tasks 

---------------------------------

Epoch 800: Loss 0.5666395851080482

Updated Model Parameter Theta

Sampling Next Batch of Tasks 

---------------------------------



array([[-1.03845523],
       [ 0.02096362],
       [ 0.45813487],
       [ 0.38818897],
       [ 1.06218167],
       [ 0.99728921],
       [ 0.22654533],
       [ 0.03183934],
       [-0.35730649],
       [-0.58902643],
       [-0.72061311],
       [ 1.06188063],
       [ 1.72587482],
       [ 0.11144021],
       [-0.43881127],
       [-1.44020855],
       [-1.12379434],
       [ 0.35574965],
       [-0.8169855 ],
       [-0.54429624],
       [ 0.23368033],
       [-0.71888756],
       [ 0.52729837],
       [-1.78481406],
       [-1.02497378],
       [ 2.07476839],
       [ 0.28599288],
       [-1.092588  ],
       [-0.42452784],
       [ 1.14156754],
       [-0.81809066],
       [-1.28336159],
       [ 1.97591674],
       [-1.10784741],
       [ 0.24784227],
       [-0.34259046],
       [ 0.98150067],
       [ 0.45907924],
       [ 1.02645133],
       [-0.49365937],
       [-2.47417796],
       [ 1.64296297],
       [-0.61929726],
       [ 1.51158755],
       [-0.93139109],
       [ 0