# MAML MPLP


In [None]:
# @title Connect to internal TF kernel and run this.
import os
import io
import numpy as np
import glob

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

# When True, connect this kernel to a remote GPU worker through remote_eager.
# Set to False to run everything on the local kernel.
REMOTE_GPU = True
if REMOTE_GPU:
  from google3.learning.brain.contrib.eager.python import remote_eager
  from google3.third_party.tensorflow.python.eager import context

  # Replace with your BNS address (omit the task index).
  WORKER_NMB = '13129177'
  BNS_ADDRESS = '/bns/li/borg/li/bns/etr/etr_headless_gpu_{}.1.gpu_worker'.format(WORKER_NMB)
  NUM_WORKERS = 1

  # Connect to task 0 of the job (the '/0' suffix is the task index).
  print("Connecting")
  remote_eager.enable('{}/0'.format(BNS_ADDRESS), num_workers=NUM_WORKERS)
  print("Connected")

  # Show the logical devices visible after connecting.
  print(tf.config.experimental.list_logical_devices())


import matplotlib.pyplot as plt # visualization

from collections import defaultdict
import random

import itertools
import tensorflow_datasets as tfds
import tensorflow.compat.v2 as tf
import matplotlib.pyplot as plt

import IPython.display as display
from IPython.display import clear_output

from PIL import Image
import numpy as np
import os


In [None]:
from colabtools import adhoc_import
import getpass

# Import the project modules via adhoc_import from the CitC client
# `client_name` of the current user. Each module is passed through
# adhoc_import.Reload, presumably so that re-running this cell picks up edits
# made in that client (NOTE(review): confirm Reload semantics).
client_name = 'twp'
user_name = getpass.getuser()
# NOTE(review): `outputs` is never used, and `core` is not referenced by any
# cell visible here; kept in case the imports have side effects.
with adhoc_import.Google3CitcClient(client_name, user_name) as outputs:
  from google3.experimental.selforg.metalearn import core
  core = adhoc_import.Reload(core)

  from google3.experimental.selforg.metalearn import tf_layers
  tf_layers = adhoc_import.Reload(tf_layers)
  from google3.experimental.selforg.metalearn import util
  util = adhoc_import.Reload(util)
  from google3.experimental.selforg.metalearn.preprocessing import sinusoidals
  sinusoidals = adhoc_import.Reload(sinusoidals)
  from google3.experimental.selforg.metalearn import training
  training = adhoc_import.Reload(training)

# Short aliases for the classes used throughout the notebook.
MPDense = tf_layers.MPDense
MPActivation = tf_layers.MPActivation
MPSoftmax = tf_layers.MPSoftmax
MPL1Loss = tf_layers.MPL1Loss
MPL2Loss = tf_layers.MPL2Loss
MPNetwork = tf_layers.MPNetwork
MPMetrics = tf_layers.MPMetrics
SinusoidalsDS = sinusoidals.SinusoidalsDS
SamplePool = util.SamplePool
TrainingRegime = training.TrainingRegime

The task is to fit sinusoids, starting from randomly initialized networks.

Therefore, there are:
* Outer batch size = 4: the number of tasks at every step. Each task has a different amplitude and phase (in this MAML-style setup all tasks start from the same shared prior weights).
* Inner batch size = 10: the number of examples in each forward/backward step.
* Num steps = 2 (`NUM_STEPS`): the number of inner steps the network gets to improve.
* Train/eval split: the network only sees train instances during forward/backward. The meta-learning regime *may* choose to use the eval splits as well, MAML-style.

In [None]:
# @title create dataset and plot it
OUTER_BATCH_SIZE = 4
INNER_BATCH_SIZE = 10
NUM_STEPS = 2

ds_factory = SinusoidalsDS()

# Each dataset element is unpacked below as (x_train, y_train, x_eval, y_eval).
# NOTE(review): the nested loops suggest each tensor is indexed
# [task, inner_step, ...]; confirm against SinusoidalsDS.create_ds.
ds = ds_factory.create_ds(OUTER_BATCH_SIZE, INNER_BATCH_SIZE, NUM_STEPS)
ds_iter = iter(ds)

# Dense grid over the input domain, used for plotting and for the eval loss.
xrange_inputs = np.linspace(-5, 5, 100).reshape((100, 1)).astype(np.float32)

xtb, ytb, xeb, yeb = next(ds_iter)
plt.figure(figsize=(14, 10))
colors = itertools.cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
for xts, yts, xes, yes in zip(xtb, ytb, xeb, yeb):
  # One color for the train split and one for the eval split of each task.
  c_t = next(colors)
  c_e = next(colors)
  # One marker per inner step. Avoid ',' (a single pixel): in a scatter plot
  # it is practically invisible.
  markers = itertools.cycle(('s', '+', '.', 'o', '*'))
  for xtsib, ytsib, xesib, yesib in zip(xts, yts, xes, yes):
    marker = next(markers)
    plt.scatter(xtsib, ytsib, c=c_t, marker=marker)
    plt.scatter(xesib, yesib, c=c_e, marker=marker)

plt.title("Sampled sinusoid tasks (train and eval splits)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()




Create a message-passing (MP) network:

In [None]:

# Size of the message passed.
message_size = 8
stateful = False
stateful_hidden_n = 15

# If metrics were used, they would need resetting for every new network:
# MPMetrics.reset()

# Keyword arguments shared by every layer and by the loss.
layer_kwargs = dict(stateful=stateful, stateful_hidden_n=stateful_hidden_n)

# Keras-style construction: the layers get no input sizes here, and
# network.setup() receives in_dim and message_size. A standalone layer would
# need in_dim and message_size passed to it explicitly.
network = MPNetwork(
    [
        MPDense(20, **layer_kwargs),
        MPActivation(tf.nn.relu, **layer_kwargs),
        MPDense(20, **layer_kwargs),
        MPActivation(tf.nn.relu, **layer_kwargs),
        MPDense(1, **layer_kwargs),
    ],
    MPL2Loss(**layer_kwargs))
network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)

# List every trainable weight and count the total number of parameters.
tr_w = network.get_trainable_weights()
print("trainable weights:")
tot_w = 0
for weight in tr_w:
  weight_np = weight.numpy()
  print(weight_np.shape, weight_np.size)
  tot_w += weight_np.size
print("tot n:", tot_w)

In [None]:
# MAML-style training uses exactly one shared set of prior weights. They are
# tf.Variables because the training step updates them with the optimizer.
trained_pfw = list(map(tf.Variable, network.init()))


In [None]:
# Passed to the compiled training step as a tensor (tf.function retraces on new
# Python values, not on new tensor values).
num_steps = tf.constant(NUM_STEPS)
learning_schedule = 1e-4

# heldout_weight controls how the meta-loss is split between the train and the
# eval (heldout) examples passed to the network. Empirically heldout_weight=0.0
# (or None) gives much worse overall performance, on both train and test losses.
training_regime = TrainingRegime(
    network,
    heldout_weight=1.0,
    hint_loss_ratio=None,
    remember_loss_ratio=None,
)

last_step = 0

# Data-dependent initialization. Run minibatch_init over every inner step of one
# task (not just a single minibatch), then fold the collected statistics into
# the network. More passes (or re-running the cell) can refine it further.
NUM_INIT_PASSES = 1
for init_pass in range(NUM_INIT_PASSES):
  print("on", init_pass)
  pfw = trained_pfw
  stats = []

  x_b, y_b, _, _ = next(ds_iter)
  # Keep only the first task of the outer batch.
  x_b = x_b[0]
  y_b = y_b[0]
  for step_idx in range(NUM_STEPS):
    x_step = x_b[step_idx]
    pfw, stats_i = network.minibatch_init(
        x_step, y_b[step_idx], x_step.shape[0], pfw=pfw)
    stats.append(stats_i)
  network.update_statistics(stats, update_perc=1.)

  print("final mean:")
  for param in tf.nest.flatten(pfw):
    print(param.shape, tf.reduce_mean(param), tf.math.reduce_std(param))

# Outer-loop optimizer. SGD/Momentum are more stable but much slower than Adam.
trainer = tf.keras.optimizers.Adam(learning_schedule)

loss_log = []
def smoothen(l, lookback=20):
  """Moving-average smoothing of a loss curve, for plotting.

  The series is left-padded with `lookback - 1` copies of its first value, so
  the output has the same length as the input.

  Args:
    l: sequence of scalar losses (a list, a numpy array, or anything
      np.asarray accepts, e.g. a list of scalar tf tensors).
    lookback: window length of the moving average (>= 1).

  Returns:
    1-D numpy array of len(l) smoothed values (empty if `l` is empty).
  """
  values = np.asarray(l, dtype=np.float64).reshape(-1)
  if values.size == 0:
    return values
  # np.repeat rather than list `*`: on a numpy array `*` would scale the
  # values instead of repeating them.
  padded = np.concatenate([np.repeat(values[:1], lookback - 1), values])
  kernel = np.full(lookback, 1.0 / lookback)
  return np.convolve(padded, kernel, "valid")


In [None]:
# Shapes of the shared prior weights.
print([weight.shape for weight in trained_pfw])

In [None]:
training_steps = 200000
print("Stop this block whenever after 1-2k steps. It's good even very early.")

@tf.function
def step(pfw, xts, yts, xes, yes, num_steps):
  """One meta-training step of the message-passing learner.

  Computes the meta-loss with training_regime.batch_mp_loss. Then it applies
  one `trainer` update to both the network's own trainable weights and the
  shared prior weights `pfw`.

  Args:
    pfw: list of tf.Variables holding the shared prior weights.
    xts, yts: inner-loop (train) inputs/targets for the outer batch of tasks.
    xes, yes: held-out (eval) inputs/targets for the same tasks.
    num_steps: number of inner steps, as a tensor.

  Returns:
    The meta-loss `l` returned by batch_mp_loss.
  """
  # Printed only while tracing, so it shows when the function is (re)compiled.
  print("compiling")
  with tf.GradientTape() as tape:
    pfw_serialized = network.serialize_pfw(pfw)
    # Only xts/yts drive the inner updates; xes/yes also enter the
    # cross-validation (heldout) part of the loss.
    l, _, _ = training_regime.batch_mp_loss(
        pfw_serialized, xts, yts, xes, yes, num_steps, same_pfw=True)
  # Build a fresh list. `+=` on the returned list would extend it in place,
  # which corrupts the network's weight list if get_trainable_weights()
  # returns an internal list rather than a copy.
  all_weights = list(network.get_trainable_weights()) + list(pfw)
  # ZERO instead of None for weights that do not affect the loss, so the
  # normalization below never receives None.
  grads = tape.gradient(
      l, all_weights, unconnected_gradients=tf.UnconnectedGradients.ZERO)
  # Rescale each gradient tensor to (about) unit L2 norm against explosions.
  # This is normalization, not clipping: small gradients are scaled up too.
  grads = [grad / (tf.norm(grad) + 1e-8) for grad in grads]
  trainer.apply_gradients(zip(grads, all_weights))
  return l


import time
start_time = time.time()

for i in range(last_step + 1, last_step + 1 + training_steps):
  # Updated before the step runs, so re-running this cell continues the
  # numbering from last_step + 1.
  last_step = i

  xts, yts, xes, yes = next(ds_iter)

  l = step(trained_pfw, xts, yts, xes, yes, num_steps)
  loss_log.append(l)

  if i % 50 == 0:
    print(i)
    # Cumulative wall-clock time since the loop started.
    print("--- %s seconds ---" % (time.time() - start_time))
  if i % 500 == 0:
    plt.plot(smoothen(loss_log, 100), label='mp')
    plt.yscale('log')
    plt.xlabel("training step")
    plt.ylabel("smoothed meta-loss")
    plt.legend()
    plt.show()
print("--- %s seconds ---" % (time.time() - start_time))



In [None]:

print(loss_log[-1])
plt.plot(smoothen(loss_log, 100), label='mp')
plt.yscale('log')
# On a log axis the bottom limit must be positive: ylim(0.0, ...) makes
# matplotlib warn and ignore the bottom value. So only cap the top.
plt.ylim(top=1e-1)
plt.gca().yaxis.grid(True)
plt.xlabel("training step")
plt.ylabel("smoothed meta-loss")
plt.legend()
plt.show()

# Proper evaluation: run 100 different few-shot instances (new tasks, i.e. new amplitude and phase), each starting from the trained prior weights.

The train loss is computed only on the points that the network has already observed.

The eval loss is computed on the entire input range [-5, 5].


In [None]:
!mkdir tmp

!ls tmp -R

In [None]:

!mkdir tmp
file_path = "tmp/maml_sin_net_weights"

network.save_weights(file_path, last_step)

with open("tmp/maml_sin_prior_weights_{:08d}.npy".format(
    last_step), "wb") as fout:
  prior_to_save = tf.concat([tf.reshape(e, [-1]) for e in trained_pfw], 0)
  np.save(fout, prior_to_save.numpy())

!ls -lh tmp
#files.download('tmp/weights_{:08d}.npy'.format(last_step))

In [None]:

%download_file tmp/maml_sin_net_weights_00258202.npy
%download_file tmp/maml_sin_prior_weights_00258202.npy

In [None]:
# Optional check that the saved network weights load back into a freshly built
# network. It replaces the global `network`, so it is disabled by default. A
# flag is used instead of raising, so "Run All" continues past this cell.
RUN_RELOAD_CHECK = False
if RUN_RELOAD_CHECK:
  network = MPNetwork(
      [
          MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
          MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
          MPDense(20, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
          MPActivation(tf.nn.relu, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
          MPDense(1, stateful=stateful, stateful_hidden_n=stateful_hidden_n),
      ],
      MPL2Loss(stateful=stateful, stateful_hidden_n=stateful_hidden_n))
  network.setup(in_dim=1, message_size=message_size, inner_batch_size=INNER_BATCH_SIZE)
  # NOTE(review): save_weights was given (file_path, last_step); confirm that
  # load_weights(file_path) finds the step-tagged file.
  network.load_weights(file_path)


In [None]:

eval_tot_steps = 100

# Loss table per evaluation run. Train losses are recorded after each inner
# step; eval losses also before the first step (column 0).
tr_losses = np.zeros((eval_tot_steps, NUM_STEPS))
ev_losses = np.zeros((eval_tot_steps, NUM_STEPS + 1))

@tf.function
def get_loss(pfw, x, y):
  """Loss of the network with weights `pfw` on inputs `x` and targets `y`.

  Returns the first output of network.compute_loss; callers average it.
  """
  (predictions, _, _) = network.forward(pfw, x)
  (loss_values, _) = network.compute_loss(predictions, y)
  return loss_values

start_time = time.time()

for r in range(eval_tot_steps):
  # Every run starts from the learned prior (not a fresh network.init()) and
  # adapts it to a newly sampled task.
  p_fw = trained_pfw

  A, ph = ds_factory._create_task()
  targets = A * np.sin(xrange_inputs + ph)
  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)

  # Step 0: loss of the unadapted prior over the whole input range.
  loss = get_loss(p_fw, xrange_inputs, targets)
  ev_losses[r, 0] = tf.reduce_mean(loss)

  for i in range(NUM_STEPS):
    p_fw, _, _ = network.inner_update(p_fw, xt[i], yt[i])

    # Train loss: only the points observed so far (inner steps 0..i).
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    loss = get_loss(p_fw, x_observed_so_far, y_observed_so_far)
    tr_losses[r, i] = tf.reduce_mean(loss)

    # Eval loss: the whole input range [-5, 5].
    loss = get_loss(p_fw, xrange_inputs, targets)
    ev_losses[r, i + 1] = tf.reduce_mean(loss)
print("--- %s seconds ---" % (time.time() - start_time))

tr_losses_m = np.mean(tr_losses, axis=0)
ev_losses_m = np.mean(ev_losses, axis=0)

tr_losses_sd = np.std(tr_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)

print("tr_l, m:", tr_losses_m, " sd:", tr_losses_sd)
print("ev_l, m:", ev_losses_m, " sd:", ev_losses_sd)

# Mean +/- one standard deviation. Train losses start at step 1; eval losses
# start at step 0 (the unadapted prior).
ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)
plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')
plt.ylim(0.0, 0.025)
plt.xlabel("num steps")
plt.ylabel("L2 loss")
plt.legend()

# Write the figure to disk; the download cell expects tmp/mplp_losses.png.
os.makedirs("tmp", exist_ok=True)
plt.savefig("tmp/mplp_losses.png")
plt.show()


In [None]:
# Mean train / eval loss per inner step, from the latest evaluation run.
mean_losses = (tr_losses_m, ev_losses_m)
print(*mean_losses)

In [None]:
# Download the MPLP loss figure. Requires tmp/mplp_losses.png to exist, e.g.
# written with plt.savefig after the evaluation plot.
%download_file tmp/mplp_losses.png

In [None]:
# title Show an example run:
# Five independent runs. Left column: the target and the predictions after
# each inner step. Right column: train / eval loss after each inner step.

fig, axs = plt.subplots(5, 2, figsize=(10, 15))

plot_every = 1  # Plot the predictions after every `plot_every` inner steps.
# Losses are recorded after inner steps 1..NUM_STEPS (0 is the unadapted prior).
inner_steps = np.arange(1, NUM_STEPS + 1)

for fig_n in range(5):
  p_fw = trained_pfw
  ax_pred, ax_loss = axs[fig_n]

  A, ph = ds_factory._create_task()

  targets = A * np.sin(xrange_inputs + ph)
  ax_pred.plot(xrange_inputs, targets, label='target')

  predictions, _, _ = network.forward(p_fw, xrange_inputs)
  ax_pred.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

  xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
  tr_losses = []
  ev_losses = []

  for i in range(NUM_STEPS):
    p_fw, _, _ = network.inner_update(p_fw, xt[i], yt[i])

    # Train loss: only the points observed so far.
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    predictions, _, _ = network.forward(p_fw, x_observed_so_far)
    loss, _ = network.compute_loss(predictions, y_observed_so_far)
    tr_losses.append(tf.reduce_mean(loss))

    # Predictions and eval loss over the whole input range.
    predictions, _, _ = network.forward(p_fw, xrange_inputs)
    if (i + 1) % plot_every == 0:
      ax_pred.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
    loss, _ = network.compute_loss(predictions, targets)
    ev_losses.append(tf.reduce_mean(loss))

  ax_loss.plot(inner_steps, tr_losses, label='tr_losses')
  ax_loss.plot(inner_steps, ev_losses, label='ev_losses')
  ax_loss.set_xlabel("inner step")

axs[0][0].legend()
axs[0][1].legend()
plt.show()

In [None]:
# title Single run for drawing.

p_fw = trained_pfw

plot_every = 1# NUM_STEPS // n_plot if NUM_STEPS >= n_plot else NUM_STEPS

predictions, _, _ = network.forward(p_fw, xrange_inputs)
#plt.plot(xrange_inputs, predictions, label='pre-update predictions')

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _, _= network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw, _, _ = network.inner_update(p_fw, xt[i], yt[i])
  #print([tf.reduce_mean(p) for p in p_fw])

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _, _= network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _, _= network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))


plt.legend()


with open("tmp/mplp_example_run.png", "wb") as fout:
  plt.savefig(fout)
%download_file tmp/mplp_example_run.png


# Compare with a MAML baseline

In [None]:
# Separate, freshly initialized prior weights and bookkeeping for the MAML
# baseline, so it trains from scratch independently of the MPLP run.
maml_pfw = list(map(tf.Variable, network.init()))
maml_last_step = 0
maml_loss_log = []

In [None]:
training_steps = 200000
print("Stop this block whenever after 1-2k steps. It's good even very early.")

def update_pfw(pfw, xt, yt, num_steps, inner_lr=0.05):
  """Plain-SGD inner loop used by the MAML baseline.

  Args:
    pfw: list of network weights (tf.Variables or tensors).
    xt, yt: per-step inputs/targets; inner step i uses xt[i] and yt[i].
    num_steps: number of inner SGD steps (int or scalar tensor).
    inner_lr: inner-loop SGD learning rate (default 0.05).

  Returns:
    A new list of adapted weight tensors; `pfw` itself is not modified.
  """
  for i in tf.range(num_steps):
    with tf.GradientTape() as tape:
      # Needed once pfw holds plain tensors rather than Variables.
      tape.watch(pfw)
      prediction, _, _ = network.forward(pfw, xt[i])
      loss, _ = network.compute_loss(prediction, yt[i])
      loss = tf.reduce_mean(loss)
    # ZERO instead of None for weights that do not affect the loss, so the
    # SGD update below never fails on a None gradient.
    grads = tape.gradient(
        loss, pfw, unconnected_gradients=tf.UnconnectedGradients.ZERO)

    pfw = [p - inner_lr * pg for p, pg in zip(pfw, grads)]
  return pfw

def single_loss(pfw, xt, yt, xe, ye, num_steps):
  """MAML meta-loss for one task.

  Adapts `pfw` on the train split (xt, yt) with update_pfw. Returns the mean
  loss of the adapted weights on the held-out split (xe, ye).
  """
  adapted_pfw = update_pfw(pfw, xt, yt, num_steps)
  heldout_prediction, _, _ = network.forward(adapted_pfw, xe)
  heldout_loss, _ = network.compute_loss(heldout_prediction, ye)
  return tf.reduce_mean(heldout_loss)

def batch_maml_loss(pfw, xts, yts, xes, yes, num_steps):
  """Mean single_loss over the tasks of an outer batch.

  Tasks are addressed by a Python index rather than by iterating the tensors,
  so the loop is unrolled when traced inside tf.function.
  """
  per_task_losses = [
      single_loss(pfw, xts[t], yts[t], xes[t], yes[t], num_steps)
      for t in range(len(xts))
  ]
  return tf.reduce_mean(tf.stack(per_task_losses))

# The MAML baseline gets its own optimizer. Reusing `trainer` would share
# Adam's step counter (hence its bias correction) with the MPLP run. Newer
# Keras optimizers also reject variables they were not built with. Created only
# once, so re-running this cell keeps the Adam state.
if "maml_trainer" not in globals():
  maml_trainer = tf.keras.optimizers.Adam(learning_schedule)

@tf.function
def maml_step(pfw, xts, yts, xes, yes, num_steps):
  """One MAML outer step: meta-loss over the batch, then an Adam update of pfw.

  Args:
    pfw: list of tf.Variables holding the MAML prior weights.
    xts, yts, xes, yes: train/eval inputs and targets for the outer batch.
    num_steps: number of inner SGD steps.

  Returns:
    The batch meta-loss from batch_maml_loss.
  """
  # Printed only while tracing, so it shows when the function is (re)compiled.
  print("compiling")
  with tf.GradientTape() as tape:
    l = batch_maml_loss(pfw, xts, yts, xes, yes, num_steps)
  grads = tape.gradient(
      l, pfw, unconnected_gradients=tf.UnconnectedGradients.ZERO)
  # Per-tensor normalization to (about) unit L2 norm, as in the MPLP step.
  # This is not clipping: small gradients are scaled up too.
  grads = [grad / (tf.norm(grad) + 1e-8) for grad in grads]
  maml_trainer.apply_gradients(zip(grads, pfw))
  return l


import time
start_time = time.time()

for i in range(maml_last_step + 1, maml_last_step + 1 + training_steps):
  # Updated before the step runs, so re-running this cell continues the
  # numbering from maml_last_step + 1.
  maml_last_step = i

  xts, yts, xes, yes = next(ds_iter)

  l = maml_step(maml_pfw, xts, yts, xes, yes, num_steps)
  maml_loss_log.append(l)

  if i % 50 == 0:
    print(i)
    # Cumulative wall-clock time since the loop started.
    print("--- %s seconds ---" % (time.time() - start_time))
  if i % 500 == 0:
    # This curve is the MAML baseline, not the MPLP ('mp') run.
    plt.plot(smoothen(maml_loss_log, 100), label='maml')
    plt.yscale('log')
    plt.xlabel("training step")
    plt.ylabel("smoothed meta-loss")
    plt.legend()
    plt.show()
print("--- %s seconds ---" % (time.time() - start_time))



In [None]:
eval_tot_steps = 100

tr_losses = np.zeros([eval_tot_steps, NUM_STEPS])
ev_losses = np.zeros([eval_tot_steps, NUM_STEPS + 1])  # also 0 step.

for r in range(eval_tot_steps):
  # Every run starts from the MAML prior and adapts it to a newly sampled task.
  p_fw = maml_pfw

  # Deliberately not named A/ph: the MAML drawing cell reuses the global A/ph
  # set by the MPLP drawing cell, and this loop must not overwrite them.
  task_A, task_ph = ds_factory._create_task()
  targets = task_A * np.sin(xrange_inputs + task_ph)
  xt, yt = ds_factory._create_instance(task_A, task_ph, INNER_BATCH_SIZE, NUM_STEPS)

  # Step 0: loss of the unadapted prior over the whole input range.
  predictions, _, _ = network.forward(p_fw, xrange_inputs)
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses[r, 0] = tf.reduce_mean(loss)

  for i in range(NUM_STEPS):
    # One SGD step on the i-th minibatch. The slice keeps the leading step axis
    # that update_pfw indexes.
    p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)

    # Train loss: only the points observed so far (inner steps 0..i).
    x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
    y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
    predictions, _, _ = network.forward(p_fw, x_observed_so_far)
    loss, _ = network.compute_loss(predictions, y_observed_so_far)
    tr_losses[r, i] = tf.reduce_mean(loss)

    # Eval loss: the whole input range [-5, 5].
    predictions, _, _ = network.forward(p_fw, xrange_inputs)
    loss, _ = network.compute_loss(predictions, targets)
    ev_losses[r, i + 1] = tf.reduce_mean(loss)

tr_losses_m = np.mean(tr_losses, axis=0)
ev_losses_m = np.mean(ev_losses, axis=0)
tr_losses_sd = np.std(tr_losses, axis=0)
ev_losses_sd = np.std(ev_losses, axis=0)
print("tr_l, m:", tr_losses_m, " sd:", tr_losses_sd)
print("ev_l, m:", ev_losses_m, " sd:", ev_losses_sd)

# Mean +/- one standard deviation. Train losses start at step 1; eval losses
# start at step 0 (the unadapted prior).
ub = [m + sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
lb = [m - sd for m, sd in zip(tr_losses_m, tr_losses_sd)]
plt.fill_between(range(1, len(tr_losses_m) + 1), ub, lb, alpha=.5)
plt.plot(range(1, len(tr_losses_m) + 1), tr_losses_m, label='train loss')

ub = [m + sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
lb = [m - sd for m, sd in zip(ev_losses_m, ev_losses_sd)]
plt.fill_between(range(0, len(ev_losses_m)), ub, lb, alpha=.5)
plt.plot(range(0, len(ev_losses_m)), ev_losses_m, label='eval loss')
plt.ylim(0.0, 0.04)
plt.xlabel("num steps")
plt.ylabel("L2 loss")
plt.legend()

# Save the figure next to the MPLP one for comparison.
os.makedirs("tmp", exist_ok=True)
plt.savefig("tmp/maml_losses.png")
plt.show()

In [None]:
# Same task as MPLP fo same drawing

p_fw = maml_pfw

#n_plot = 5
plot_every = 1# NUM_STEPS // n_plot

predictions, _, _ = network.forward(p_fw, xrange_inputs)
#plt.plot(xrange_inputs, predictions, label='pre-update predictions')

#A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _, _= network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
xe, ye = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)

  # loss specific to only what we observe.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _, _= network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Plotting for the continuous input range
  predictions, _, _= network.forward(p_fw, xrange_inputs)
  if (i+1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))

plt.legend()

with open("tmp/maml_example_run.png", "wb") as fout:
  plt.savefig(fout)

%download_file tmp/maml_example_run.png

In [None]:
##@title Show an example run:
n_plot = 5
# Plot every `plot_every`-th inner step (every step when NUM_STEPS < 2*n_plot).
plot_every = max(1, NUM_STEPS // n_plot)

p_fw = maml_pfw

A, ph = ds_factory._create_task()

targets = A * np.sin(xrange_inputs + ph)
plt.plot(xrange_inputs, targets, label='target')

predictions, _, _ = network.forward(p_fw, xrange_inputs)
plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(0))

xt, yt = ds_factory._create_instance(A, ph, INNER_BATCH_SIZE, NUM_STEPS)
tr_losses = []
ev_losses = []

for i in range(NUM_STEPS):
  p_fw = update_pfw(p_fw, xt[i:i+1], yt[i:i+1], num_steps=1)

  # Train loss: only the points observed so far.
  x_observed_so_far = tf.reshape(xt[:i+1], (-1, 1))
  y_observed_so_far = tf.reshape(yt[:i+1], (-1, 1))
  predictions, _, _ = network.forward(p_fw, x_observed_so_far)
  loss, _ = network.compute_loss(predictions, y_observed_so_far)
  tr_losses.append(tf.reduce_mean(loss))

  # Predictions and eval loss over the whole input range.
  predictions, _, _ = network.forward(p_fw, xrange_inputs)
  if (i + 1) % plot_every == 0:
    plt.plot(xrange_inputs, predictions, label='{}-step predictions'.format(i+1))
  loss, _ = network.compute_loss(predictions, targets)
  ev_losses.append(tf.reduce_mean(loss))

plt.legend()
plt.show()

# Losses are recorded after inner steps 1..NUM_STEPS, so plot them there.
inner_steps = np.arange(1, len(tr_losses) + 1)
plt.plot(inner_steps, tr_losses, label='tr_losses')
plt.plot(inner_steps, ev_losses, label='ev_losses')
plt.xlabel("inner step")
plt.ylabel("L2 loss")
plt.legend()
plt.show()
