In [None]:
import torch, math
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from nets import Qnetn, MLP_NN
from taskmeta import SINTASK, SINTASK2, SINTASK3, SIN_xlow, SIN_xhigh


# Every tensor built through `tensor()` shares this device and dtype.
device = 'cpu'
dtype = torch.float32


def tensor(data, rgrad=False):
    """Wrap `data` in a torch tensor on the notebook-wide `device` / `dtype`.

    `rgrad` is forwarded to torch as `requires_grad`.
    """
    options = dict(device=device, dtype=dtype, requires_grad=rgrad)
    return torch.tensor(data, **options)

# A single generator, seeded with the fixed constant GRS, hands out every
# seed used further down.
GRS = 12
RNG = np.random.default_rng(GRS)


def randseed():
    """Draw the next seed from `RNG`: an integer in the range [1, 10_000)."""
    return RNG.integers(low=1, high=10_000)
# Short aliases for the SIN_xlow / SIN_xhigh values imported from taskmeta.
xl = SIN_xlow
xh = SIN_xhigh

# Size list handed to MLP_NN by NN() -- presumably the layer widths
# (1 input, three hidden layers of 25, 1 output); confirm against nets.MLP_NN.
HH = [1, 25, 25, 25, 1]


def lossF(pp, yy):
    """Element-wise half squared error between predictions `pp` and targets `yy`."""
    diff = pp - yy
    return 0.5 * diff ** 2


def lossM(pp, yy):
    """Half squared error between `pp` and `yy`, summed into one scalar tensor."""
    diff = pp - yy
    return torch.sum(0.5 * diff ** 2)

def NN():
    """Build a new MLP_NN from the size list `HH`, with ReLU activations and a seed from randseed()."""
    return MLP_NN(
        HH,
        device,
        dtype,
        actF=nn.ReLU,
        seed=randseed(),
    )


def NP(ext_params):
    """Build an MLP_NN from `ext_params` (passed with from_param=True), with ReLU activations and a seed from randseed()."""
    return MLP_NN(
        ext_params,
        device,
        dtype,
        actF=nn.ReLU,
        seed=randseed(),
        from_param=True,
    )

# Meta Learning

In [None]:
# One SINTASK2 per parameter pair; each task gets its own seed from
# randseed(), drawn in list order.
task_params = [(7, 3), (6, 2), (5, 1), (4, 0), (3, -1)]
taskerL = [SINTASK2(pair, seed=randseed()) for pair in task_params]

# The first task is the one the following cells sample from and plot.
tasker = taskerL[0]

# Algorithm

In [None]:
# Settings for the gradient-descent loop in the training cell.
inner_epochs = 1_000  # number of update steps the training loop runs
inner_lr = 1e-3       # step size multiplied into each gradient
train_K = 8           # count passed to tasker.sample() / tasker.space()
P = print             # short alias used for progress output

In [None]:
# Initial parameters (theta): a network freshly built by NN().
model = NN()
# Uncomment the next line to load 'reptile_pie' through load_external
# (the "meta pie" starting point, per the original note) instead of
# keeping the fresh initialisation.
# model.load_external('reptile_pie')
# Show the model summary; show_vals=True is passed through to MLP_NN.info.
model.info(show_vals=True)

In [None]:
# Baseline plot: model output against the task curve before the training
# loop has run.
# tasker.space(train_K) is indexed below as a 2-D array whose column 0 is x
# and column 1 is y (presumably evenly spaced points -- confirm in taskmeta).
db = tasker.space(train_K)

# Build the inputs through the notebook's `tensor` helper so they follow the
# configured device/dtype instead of a hard-coded torch.float32.
bx = tensor(np.expand_dims(db[:, 0], axis=-1))
by = tensor(np.expand_dims(db[:, 1], axis=-1))
bp = []  # not read by any visible cell; kept in case a later cell uses it

with torch.no_grad():
    pred = model.forward(bx)

fig, ax = plt.subplots()
ax.plot(db[:, 0], db[:, 1], color='green', label='task')
# Hand matplotlib a NumPy array explicitly rather than a torch tensor.
ax.plot(db[:, 0], pred.detach().cpu().numpy(), color='blue',
        label='model (before training)')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
plt.show()


# Training

In [None]:


log_every = 100     # print the loss once per this many steps, not every step
smooth_window = 5   # width of the moving average applied to the loss curve

loss_hist = []
for inner_epoch in range(inner_epochs):

    # 5: sample train_K data points from the task
    train_db = tasker.sample(train_K)
    # Use the notebook's `tensor` helper so the batch follows the configured
    # device/dtype instead of a hard-coded torch.float32.
    batch_x = tensor(np.expand_dims(train_db[:, 0], axis=-1))
    batch_y = tensor(np.expand_dims(train_db[:, 1], axis=-1))

    # 6: evaluate grad_theta(loss)
    model.zero_grad()
    pred = model.forward(batch_x)
    loss = lossM(pred, batch_y)

    loss_hist.append(loss.item())
    if inner_epoch % log_every == 0 or inner_epoch == inner_epochs - 1:
        P(f'\t step {inner_epoch}: Loss: {loss_hist[-1]}')
    # The gradients are taken from the return value of torch.autograd.grad.
    grads = torch.autograd.grad(loss, model.parameters, create_graph=False)

    # 7: compute adapted parameters with gradient descent
    with torch.no_grad():
        for t_param, grad in zip(model.parameters, grads):
            t_param -= inner_lr * grad


def npc(n):
    """Return a length-`n` averaging kernel (every weight equal to 1/n)."""
    return np.full(n, 1 / n)


if loss_hist:
    # mode='valid' keeps only windows that lie fully inside the history.
    # The default 'full' mode pads with zeros, which makes the curve ramp up
    # at the start and drop towards zero at the end regardless of the loss.
    window = min(smooth_window, len(loss_hist))
    smoothed = np.convolve(loss_hist, npc(window), mode='valid')
    # Each smoothed value is plotted at the last step of its window.
    steps = np.arange(window - 1, len(loss_hist))

    plt.figure(figsize=(12, 5))
    plt.ylim(0, 80)
    plt.plot(steps, smoothed)
    plt.xlabel('training step')
    plt.ylabel(f'loss (moving average over {window} steps)')
    plt.show()

In [None]:
# Result plot: model output against the task curve after the training loop.
# tasker.space(train_K) is indexed below as a 2-D array whose column 0 is x
# and column 1 is y (presumably evenly spaced points -- confirm in taskmeta).
db = tasker.space(train_K)

# Build the inputs through the notebook's `tensor` helper so they follow the
# configured device/dtype instead of a hard-coded torch.float32.
bx = tensor(np.expand_dims(db[:, 0], axis=-1))
by = tensor(np.expand_dims(db[:, 1], axis=-1))
bp = []  # not read by any visible cell; kept in case a later cell uses it

with torch.no_grad():
    pred = model.forward(bx)

fig, ax = plt.subplots()
ax.plot(db[:, 0], db[:, 1], color='green', label='task')
# Hand matplotlib a NumPy array explicitly rather than a torch tensor.
ax.plot(db[:, 0], pred.detach().cpu().numpy(), color='blue',
        label='model (after training)')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
plt.show()
