In [1]:
import torch, math
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from nets import Qnetn, MLP_NN
from taskmeta import SINTASK, SINTASK2, SINTASK3, SIN_xlow, SIN_xhigh


# Notebook-wide placement and precision used when building tensors/networks.
device = 'cpu'
dtype = torch.float32


def tensor(data, rgrad=False):
    """Create a ``torch.Tensor`` from ``data`` using the notebook's ``device``/``dtype``.

    Args:
        data: Anything ``torch.tensor`` accepts (list, scalar, ndarray, ...).
        rgrad: Whether the new tensor should record gradients.

    Returns:
        A new tensor holding a copy of ``data``.
    """
    options = dict(device=device, dtype=dtype, requires_grad=rgrad)
    return torch.tensor(data, **options)

# Global random seed: every sub-seed handed out below derives from it.
GRS = 12
RNG = np.random.default_rng(GRS)


def randseed():
    """Draw the next sub-seed from ``RNG``: an integer in ``[1, 10_000)``.

    Each call advances the shared generator, so the sequence of sub-seeds is
    reproducible for a fixed ``GRS`` and a fixed call order.
    """
    return RNG.integers(low=1, high=10_000)
# Short aliases for the bounds exported by taskmeta.
# NOTE(review): presumably the lower/upper x-range of the sine tasks - confirm in taskmeta.
xl = SIN_xlow
xh = SIN_xhigh

# Layer sizes handed to MLP_NN: 1 input, three layers of 25 units, 1 output.
HH = [1] + [25] * 3 + [1]


def lossF(pp, yy):
    """Element-wise half squared error between prediction ``pp`` and target ``yy``."""
    diff = pp - yy
    return 0.5 * diff ** 2


def lossM(pp, yy):
    """Half squared error between ``pp`` and ``yy`` summed over all elements (scalar tensor)."""
    half_sq = 0.5 * (pp - yy) ** 2
    return torch.sum(half_sq)

def NN():
    """Build a fresh ``MLP_NN`` with layer sizes ``HH``, ReLU activation and a new sub-seed."""
    return MLP_NN(
        HH,
        device,
        dtype,
        actF=nn.ReLU,
        seed=randseed(),
    )


def NP(ext_params):
    """Build an ``MLP_NN`` from externally supplied parameters (``from_param=True``).

    Args:
        ext_params: Parameter collection forwarded to ``MLP_NN`` in place of
            the layer sizes.
            NOTE(review): presumably a list of weight/bias tensors in the
            network's own parameter order - confirm in nets.MLP_NN.
    """
    return MLP_NN(
        ext_params,
        device,
        dtype,
        actF=nn.ReLU,
        seed=randseed(),
        from_param=True,
    )

# Meta Learning

In [2]:
# One SINTASK2 per configuration pair, each seeded with its own sub-seed.
# NOTE(review): the meaning of each pair is defined by taskmeta.SINTASK2 - confirm there.
_task_configs = ((7, 3), (6, 2), (5, 1), (4, 0), (3, -1))
taskerL = [SINTASK2(config, seed=randseed()) for config in _task_configs]

# algorithm

In [3]:
# Meta-learning hyperparameters.
outer_epochs = 1_000   # number of meta (outer-loop) updates
outer_lr = 1e-3        # step size of the meta update on theta
inner_lr = 1e-3        # step size of the per-task adaptation step
task_batch_size = 5    # tasks used per outer epoch
train_K = 8            # points sampled per task for the adaptation step
test_K = 8             # points sampled per task for the meta loss

# Logging alias used by the training cell.
P = print

# training

In [4]:
def _xy_tensors(db):
    """Split a sampled data array into ``(x, y)`` column tensors.

    Column 0 of ``db`` is read as the input and column 1 as the target; each
    gets a trailing axis appended. The tensors are built with the notebook's
    ``tensor`` helper so they follow the configured ``device``/``dtype``
    instead of a hard-coded ``torch.float32`` on the default device.
    """
    x = tensor(np.expand_dims(db[:, 0], axis=-1))
    y = tensor(np.expand_dims(db[:, 1], axis=-1))
    return x, y


P('# 1: randomly initialize theta ')
theta = NN()
theta.info(show_vals=True, P=P)

P('# 2: while not done ', outer_epochs)
for outer_epoch in range(outer_epochs):

    # 3: sample a batch of tasks (cycles through taskerL in order)
    task_batch = [taskerL[i % len(taskerL)] for i in range(task_batch_size)]
    P('# 3: outer_epoch:[{}] - sampled task batch: -({})-'.format(outer_epoch,task_batch))

    # 4: for each task do
    thetaL, dbL = [], []
    for tasker in task_batch:

        # 5: sample train_K data points
        train_db = tasker.sample(train_K)
        batch_x, batch_y = _xy_tensors(train_db)

        # 6: evaluate grad_theta(loss)
        pred = theta.forward(batch_x)
        loss = lossM(pred, batch_y)
        P('\t Loss:', loss.item())
        # create_graph=True keeps the graph of this gradient so the meta
        # update below can differentiate through the adaptation step.
        grads = torch.autograd.grad(loss, theta.parameters, create_graph=True)

        # 7: compute adapted parameters with one gradient-descent step
        theta_i = [
            t_param - inner_lr * grad
            for t_param, grad in zip(theta.parameters, grads)
        ]
        thetaL.append(theta_i)
        # NOTE(review): torch.autograd.grad returns gradients instead of
        # accumulating them, so this is likely a no-op - kept from the
        # original; confirm against MLP_NN.zero_grad.
        theta.zero_grad()

        # 8: sample test_K data points for the meta loss
        test_db = tasker.sample(test_K)
        dbL.append(test_db)

    # 9: end for (tasker)

    # 10: meta update - grad_theta(sum of all per-task losses)
    oL = []
    for Ti, Di in zip(thetaL, dbL):
        dx, dy = _xy_tensors(Di)

        # Network built from the adapted parameters, which are still
        # connected to theta through the autograd graph.
        temp = NP(Ti)
        xpred = temp.forward(dx)
        xloss = lossM(xpred, dy)
        oL.append(xloss)

    oloss = torch.sum(torch.stack(oL))
    ograds = torch.autograd.grad(oloss, theta.parameters, create_graph=False)
    # Logged through P like every other message so a redirected P sees it.
    P('Outer loss:', oloss.item())

    # ... update outer params in place, outside of autograd tracking
    with torch.no_grad():
        for t_param, grad in zip(theta.parameters, ograds):
            t_param -= outer_lr * grad

P('# 11: end for')


# 1: randomly initialize theta 
--------------------------
~ N_LAYERS:[4]
~ D_TYPE:[torch.float32]
~ DEV:[cpu]
--------------------------
--> Weights[0]:: Params[25] of Shape[torch.Size([25, 1])]
 ~--> [PARAMETER TENSOR]: tensor([[ 4.0752e-02],
        [-8.0139e-02],
        [-2.0934e-03],
        [ 9.1835e-02],
        [-6.5284e-05],
        [ 5.8189e-02],
        [-8.6809e-02],
        [ 2.3397e-02],
        [ 6.8264e-02],
        [-1.3586e-02],
        [ 2.6145e-02],
        [-1.7595e-02],
        [ 3.8888e-02],
        [-8.4653e-02],
        [ 2.8353e-02],
        [-5.0440e-02],
        [-5.0463e-02],
        [-2.3697e-02],
        [-6.8030e-02],
        [-3.7875e-02],
        [-2.6412e-02],
        [ 5.6413e-04],
        [-8.0990e-02],
        [-4.2174e-03],
        [-5.4436e-02]], requires_grad=True)
--> Bias[0]:: Params[25] of Shape[torch.Size([25])]
 ~--> [PARAMETER TENSOR]: tensor([-0.0140, -0.0385,  0.0818, -0.0299, -0.0257,  0.0169, -0.0935,  0.0139,
         0.0154,  0.0660

In [5]:

# Hand the meta-trained network to MLP_NN.save_external with the name 'meta_pie'.
# NOTE(review): presumably writes the parameters to disk under that name -
# confirm the target path/format in nets.MLP_NN.
theta.save_external('meta_pie')
# Print the layer summary without values; as the last expression of the cell,
# its return value is displayed too (the parameter count, 1376 in the saved output).
theta.info()

--------------------------
~ N_LAYERS:[4]
~ D_TYPE:[torch.float32]
~ DEV:[cpu]
--------------------------
--> Weights[0]:: Params[25] of Shape[torch.Size([25, 1])]
--> Bias[0]:: Params[25] of Shape[torch.Size([25])]
--> Weights[1]:: Params[625] of Shape[torch.Size([25, 25])]
--> Bias[1]:: Params[25] of Shape[torch.Size([25])]
--> Weights[2]:: Params[625] of Shape[torch.Size([25, 25])]
--> Bias[2]:: Params[25] of Shape[torch.Size([25])]
--> Weights[3]:: Params[25] of Shape[torch.Size([1, 25])]
--> Bias[3]:: Params[1] of Shape[torch.Size([1])]
--------------------------
PARAMS:	 1,376
--------------------------


1376

In [6]:
# Dump every weight/bias tensor of the meta-trained theta, for comparison with
# the initial values printed right after theta was created in the training cell.
theta.info(show_vals=True)

--------------------------
~ N_LAYERS:[4]
~ D_TYPE:[torch.float32]
~ DEV:[cpu]
--------------------------
--> Weights[0]:: Params[25] of Shape[torch.Size([25, 1])]
 ~--> [PARAMETER TENSOR]: tensor([[ 2.4490e-03],
        [-8.0139e-02],
        [-2.5064e-01],
        [ 9.5909e-01],
        [-6.5284e-05],
        [ 4.8062e-01],
        [-8.6809e-02],
        [ 6.5355e-03],
        [ 4.4192e-01],
        [-7.2053e-01],
        [ 8.9951e-01],
        [-1.7595e-02],
        [-2.0167e-02],
        [-1.3581e-01],
        [-2.8878e-01],
        [-5.0440e-02],
        [-5.0478e-02],
        [-2.3697e-02],
        [-6.8030e-02],
        [-4.1850e-01],
        [-2.6412e-02],
        [ 5.6413e-04],
        [-8.6678e-02],
        [-4.2174e-03],
        [-2.6082e-01]], requires_grad=True)
--> Bias[0]:: Params[25] of Shape[torch.Size([25])]
 ~--> [PARAMETER TENSOR]: tensor([-1.6234e-02, -3.8518e-02,  9.0392e-01, -2.0014e+00, -2.5732e-02,
        -1.0545e+00, -9.3511e-02, -6.5731e-02, -9.8871e-01,  6.

1376