In [5]:
import numpy as np
import torch
from torch import nn, autograd as ag
import matplotlib.pyplot as plt
from copy import deepcopy

# --- Hyperparameters --------------------------------------------------------
seed = 0               # shared seed for numpy (task sampling) and torch (weight init)
plot = True            # draw evaluation figures during training
innerstepsize = 0.02   # learning rate of the inner, per-task SGD
innerepochs = 1        # passes over x_all per task in the inner loop
outerstepsize0 = 0.1   # initial meta (outer) step size; decayed linearly over training
niterations = 30000    # number of outer (meta) updates; one task is visited per update

# --- Reproducibility --------------------------------------------------------
rng = np.random.RandomState(seed)
torch.manual_seed(seed)

# Define task distribution: every task is evaluated on the same fixed input grid.
x_all = np.linspace(-5, 5, 50).reshape(-1, 1)  # 50 evenly spaced inputs in [-5, 5], shape (50, 1)
ntrain = 10  # minibatch size for inner SGD and for the few-shot eval set
def gen_task(random_state=None):
    """Sample a random sine-regression task.

    Draws a phase from U(0, 2*pi) and then an amplitude from U(0.1, 5), and
    returns ``f(x) = sin(x + phase) * ampl``.

    Parameters
    ----------
    random_state : np.random.RandomState, optional
        Source of randomness. Defaults to the module-level ``rng``, so existing
        zero-argument calls consume exactly the same random stream as before.

    Returns
    -------
    callable
        Element-wise function mapping an array of inputs to regression targets.
    """
    rs = rng if random_state is None else random_state
    # Draw order (phase, then amplitude) is kept so seeded runs are unchanged.
    phase = rs.uniform(low=0, high=2*np.pi)
    ampl = rs.uniform(0.1, 5)
    f_randomsine = lambda x : np.sin(x + phase) * ampl
    return f_randomsine

def gen_task2(phase, ampl):
    """Build the sine-regression task ``x -> sin(x + phase) * ampl``.

    Deterministic counterpart of ``gen_task``: the parameters are given
    explicitly, which is how tasks read from the CSV files are reconstructed.
    """
    def f_randomsine(x):
        return np.sin(x + phase) * ampl
    return f_randomsine

# Define model. Reptile paper uses ReLU, but Tanh gives slightly better results
# 1 -> 64 -> 64 -> 1 MLP. Layers stay positional inside nn.Sequential so the
# state_dict keys remain "0.*", "2.*", "4.*" (later cells index "0.weight").
_hidden_width = 64
_layers = [
    nn.Linear(1, _hidden_width), nn.Tanh(),
    nn.Linear(_hidden_width, _hidden_width), nn.Tanh(),
    nn.Linear(_hidden_width, 1),
]
model = nn.Sequential(*_layers)

def totorch(x):
    """Convert array-like ``x`` into a new tensor of torch's default float dtype.

    ``torch.autograd.Variable`` has been a no-op wrapper since PyTorch 0.4.
    ``torch.Tensor(n)`` with a Python int allocates an *uninitialised* tensor of
    size ``n`` instead of wrapping the value. ``torch.tensor`` always copies the
    data and treats scalars as values, avoiding both pitfalls. The dtype stays
    the global default, which is what ``torch.Tensor`` produced.
    """
    return torch.tensor(x, dtype=torch.get_default_dtype())

def train_on_batch(x, y, net=None, stepsize=None):
    """Take one plain-SGD step on the mean-squared error of ``net`` on (x, y).

    Parameters
    ----------
    x, y : array-like
        Inputs and targets, converted with ``totorch``. In this notebook both are
        ``(batch, 1)`` slices of ``x_all`` / ``f(x_all)``.
    net : nn.Module, optional
        Model updated in place. Defaults to the module-level ``model``.
    stepsize : float, optional
        Learning rate. Defaults to the module-level ``innerstepsize``.

    Returns
    -------
    float
        Batch loss measured *before* the update (existing callers ignore it).
    """
    net = model if net is None else net
    lr = innerstepsize if stepsize is None else stepsize
    x = totorch(x)
    y = totorch(y)
    net.zero_grad()
    ypred = net(x)
    loss = (ypred - y).pow(2).mean()
    loss.backward()
    # Apply the update outside autograd so the step itself is not tracked.
    with torch.no_grad():
        for param in net.parameters():
            if param.grad is None:  # parameter did not take part in this loss
                continue
            param -= lr * param.grad
    return loss.item()

def predict(x, net=None):
    """Return ``net``'s predictions for ``x`` as a numpy array.

    Runs under ``torch.no_grad()`` so no autograd graph is built for pure
    inference. For this notebook's 1 -> 1 model, an ``(N, 1)`` input yields an
    ``(N, 1)`` array.

    Parameters
    ----------
    x : array-like
        Model inputs, converted with ``totorch``.
    net : nn.Module, optional
        Model to evaluate. Defaults to the module-level ``model``.
    """
    net = model if net is None else net
    with torch.no_grad():
        return net(totorch(x)).numpy()

# Choose a fixed task and minibatch for visualization

def _read_csv_rows(path, convert):
    """Read a headerless comma-separated file into a list of rows.

    Each non-blank line is split on "," and every field is passed through
    ``convert`` (e.g. ``float`` or ``int``). Blank lines, such as a trailing
    empty line, are skipped instead of raising ``ValueError``.
    """
    with open(path, "r") as fh:
        return [[convert(field) for field in line.split(",")]
                for line in fh if line.strip()]

# Random fallback; it also advances `rng` the same way as before the files existed.
f_plot = gen_task()
eval_task_rows = _read_csv_rows("eval_task.csv", float)
if eval_task_rows:
    # If the file holds several rows, the last one wins (as before).
    phase, ampl = eval_task_rows[-1]
    f_plot = gen_task2(phase, ampl)

xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]
eval_id_rows = _read_csv_rows("eval_ids.csv", int)
if eval_id_rows:
    # Row of indices into x_all; again the last row wins.
    xtrain_plot = x_all[eval_id_rows[-1]]

# One (phase, ampl) row per outer iteration, consumed as tasks[iteration].
tasks = _read_csv_rows("tasks.csv", float)
# One index ordering of x_all per outer iteration (presumably a permutation of
# range(len(x_all)); TODO confirm), consumed as ext_inds[iteration].
ext_inds = _read_csv_rows("ext_inds.csv", int)
        
# Reptile training loop
eval_every = 1000     # evaluate (and optionally plot) every this many outer iterations
eval_steps = 32       # inner SGD steps taken on the fixed eval minibatch
eval_plot_every = 8   # draw an intermediate prediction curve every this many eval steps

# Fail fast if the pre-generated files cannot cover every outer iteration,
# instead of hitting an IndexError part-way through a long run.
assert len(tasks) >= niterations, "tasks.csv has {} rows, need {}".format(len(tasks), niterations)
assert len(ext_inds) >= niterations, "ext_inds.csv has {} rows, need {}".format(len(ext_inds), niterations)

for iteration in range(niterations):
    weights_before = deepcopy(model.state_dict())
    # Rebuild this iteration's task from its stored (phase, ampl) row
    f = gen_task2(tasks[iteration][0], tasks[iteration][1])
    y_all = f(x_all)
    # Do SGD on this task. The visiting order of x_all comes from ext_inds.csv
    # instead of a fresh rng.permutation(len(x_all)).
    inds = ext_inds[iteration]
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start+ntrain]
            train_on_batch(x_all[mbinds], y_all[mbinds])
    # Interpolate between current weights and trained weights from this task
    # I.e. (weights_before - weights_after) is the meta-gradient
    weights_after = model.state_dict()
    outerstepsize = outerstepsize0 * (1 - iteration / niterations) # linear schedule
    model.load_state_dict({name :
        weights_before[name] + (weights_after[name] - weights_before[name]) * outerstepsize
        for name in weights_before})

    # Periodically evaluate few-shot adaptation on the fixed task and minibatch.
    # `plot` only controls drawing; the loss is reported either way. (The former
    # `plot and a or b` parsed as `(plot and a) or b`, so it ignored `plot`.)
    if iteration == 0 or (iteration + 1) % eval_every == 0:
        f = f_plot
        eval_snapshot = deepcopy(model.state_dict())  # restored after the probe below
        if plot:
            plt.cla()
            plt.plot(x_all, predict(x_all), label="pred after 0", color=(0,0,1))
        for inneriter in range(eval_steps):
            train_on_batch(xtrain_plot, f(xtrain_plot))
            if plot and (inneriter + 1) % eval_plot_every == 0:
                frac = (inneriter + 1) / eval_steps
                plt.plot(x_all, predict(x_all), label="pred after %i"%(inneriter+1), color=(frac, 0, 1-frac))
        lossval = np.square(predict(x_all) - f(x_all)).mean()
        if plot:
            plt.plot(x_all, f(x_all), label="true", color=(0,1,0))
            plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
            plt.ylim(-4,4)
            plt.legend(loc="lower right")
            plt.pause(0.01)
        model.load_state_dict(eval_snapshot) # restore meta-learned weights
        print("-----------------------------")
        print("iteration               {}".format(iteration+1))
        print("loss on plotted curve   {:.3f}".format(lossval)) # would be better to average loss over a set of examples, but this is optimized for brevity

-----------------------------
iteration               1
loss on plotted curve   0.432
-----------------------------
iteration               1000
loss on plotted curve   0.317


KeyboardInterrupt: 

In [19]:
# Inspect the stored evaluation task parameters, one parsed row per line.
with open("eval_task.csv", "r") as f:
    for line in f:
        fields = line.split(",")
        print(list(map(float, fields)))

[3.8851604956705836, 1.634926426631627]


In [16]:
# Re-derive the few-shot eval inputs from eval_ids.csv (last row wins) and show them.
with open("eval_ids.csv", "r") as f:
    for line in f:
        ids = [int(tok) for tok in line.split(",")]
        xtrain_plot = x_all[ids]

xtrain_plot

array([[ 0.91836735],
       [ 4.79591837],
       [-2.14285714],
       [-1.12244898],
       [-5.        ],
       [ 4.18367347],
       [-3.36734694],
       [ 1.53061224],
       [ 0.51020408],
       [ 3.16326531]])

In [24]:
# Reload the per-iteration task parameters from disk.
tasks = []
with open("tasks.csv", "r") as f:
    for line in f:
        row = list(map(float, line.split(",")))
        tasks.append(row)

In [16]:
model[0].weight.detach().numpy()  # first Linear layer's weights, shape (64, 1)

array([[-0.0221422 ],
       [ 0.55535251],
       [-0.85915786],
       [-0.84218323],
       [-0.37831536],
       [ 0.32682943],
       [ 0.05167898],
       [ 0.80155498],
       [-0.12587641],
       [ 0.26184642],
       [-0.30406606],
       [-0.12208395],
       [-1.06078362],
       [-0.74192226],
       [-0.48071307],
       [ 0.06067892],
       [ 0.59191287],
       [ 0.69817364],
       [-0.72751081],
       [-0.58355033],
       [ 0.25810301],
       [ 0.92219269],
       [-0.21943124],
       [ 0.85931396],
       [-0.10027938],
       [ 0.08487212],
       [ 0.97331411],
       [-0.95175725],
       [-0.71544433],
       [-0.23541121],
       [-0.42948088],
       [ 0.89139748],
       [-0.76156646],
       [-0.47726426],
       [-0.7944538 ],
       [-0.97052705],
       [-0.71252966],
       [ 0.95814681],
       [ 0.50040412],
       [ 0.41395468],
       [ 0.01961884],
       [-0.36207476],
       [ 0.0747766 ],
       [-0.97882289],
       [-0.79214132],
       [-0

In [11]:
# Dump every parameter tensor in registration order.
for param in model.parameters():
    print(param.detach().numpy())

[[-0.0221422 ]
 [ 0.55535251]
 [-0.85915786]
 [-0.84218323]
 [-0.37831536]
 [ 0.32682943]
 [ 0.05167898]
 [ 0.80155498]
 [-0.12587641]
 [ 0.26184642]
 [-0.30406606]
 [-0.12208395]
 [-1.06078362]
 [-0.74192226]
 [-0.48071307]
 [ 0.06067892]
 [ 0.59191287]
 [ 0.69817364]
 [-0.72751081]
 [-0.58355033]
 [ 0.25810301]
 [ 0.92219269]
 [-0.21943124]
 [ 0.85931396]
 [-0.10027938]
 [ 0.08487212]
 [ 0.97331411]
 [-0.95175725]
 [-0.71544433]
 [-0.23541121]
 [-0.42948088]
 [ 0.89139748]
 [-0.76156646]
 [-0.47726426]
 [-0.7944538 ]
 [-0.97052705]
 [-0.71252966]
 [ 0.95814681]
 [ 0.50040412]
 [ 0.41395468]
 [ 0.01961884]
 [-0.36207476]
 [ 0.0747766 ]
 [-0.97882289]
 [-0.79214132]
 [-0.59230739]
 [ 0.73408157]
 [ 0.75951695]
 [-0.54085475]
 [-0.03548805]
 [ 0.70868784]
 [ 1.01554716]
 [ 0.3598974 ]
 [ 0.18038218]
 [ 0.70964849]
 [-0.60463977]
 [ 0.16329235]
 [-0.83828813]
 [-0.71010029]
 [-0.52556252]
 [ 0.6258198 ]
 [ 0.40616217]
 [-0.6864723 ]
 [ 0.35487333]]
[ 0.6648221  -0.13442098  0.12021134  0