# Meta Learning in Jax (Practice)

In [1]:
# Imports
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from typing import Tuple, Dict, List
from IPython import display

In [2]:
# Typings
Task = Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]

## Preparing data

In [3]:
def generate_sin_task(key: random.PRNGKey, k, test_size=10) -> Tuple[random.PRNGKey, Task]:
    key, A_key, w_key, b_key = random.split(key, 4)
    key, input_key = random.split(key)

    # Generate random parameters for the sin curve
    A = random.uniform(A_key, minval=0.1, maxval=5)
    w = random.uniform(w_key, minval=0.8, maxval=1.2)
    b = random.uniform(b_key, minval=0, maxval=jnp.pi)

    # Generate random training input values
    x_train = random.uniform(input_key, shape=(k,1), minval=-5, maxval=5)

    # Generate random testing input values
    x_test = random.uniform(input_key, shape=(test_size,1), minval=-5, maxval=5)

    # Calculate traning output
    y_train = A * jnp.sin(w * x_train + b)

    # Calculate testing output
    y_test = A * jnp.sin(w * x_test + b)

    d_train = (x_train, y_train)
    d_test = (x_test, y_test)
    task = (d_train, d_test)

    return key, task

In [4]:
class TaskLoader:
    def __init__(self, dataset: List[Task], batch_size: int):
        self.dataset = dataset
        self.batch_size = batch_size
    
    def __len__(self) -> int:
        return len(self.dataset)
    
    def __iter__(self):
        length = len(self.dataset)
        for i in range(0, length, self.batch_size):
            yield self.convert_to_batch(self.dataset[i:i+self.batch_size])
    
    def convert_to_batch(self, tasks):
        x_train = []
        x_test = []
        y_train = []
        y_test = []
        for task in tasks:
            d_train, d_test = task
            x_train.append(d_train[0])
            x_test.append(d_test[0])
            y_train.append(d_train[1])
            y_test.append(d_test[1])
        return (jnp.stack(x_train), jnp.stack(y_train)), (jnp.stack(x_test), jnp.stack(y_test))    

In [5]:
def get_meta_train_dataset(key: random.PRNGKey, num_tasks: int, shots=5) -> Tuple[random.PRNGKey, List[Task]]:
    tasks = []
    for _ in range(num_tasks):
        key, task = generate_sin_task(key, shots)
        tasks.append(task)
    return key, tasks

In [6]:
def get_meta_test_dataset(key: random.PRNGKey, num_tasks: int = 100, shots=5) -> Tuple[random.PRNGKey, List[Task]]:
    tasks = []
    for _ in range(num_tasks):
        key, task = generate_sin_task(key, shots, test_size=100)
        tasks.append(task)
    return key, tasks

## Neural Network functions

In [7]:
# Activation
@jax.jit
def relu(x):
    return jnp.maximum(x, 0)

@jax.jit
def tanh(x):
    return jnp.tanh(x)

In [8]:
# Forward pass
def forward(params, x: jnp.ndarray) -> jnp.ndarray:
    for param in params[:-1]:
        w = param['w']
        b = param['b']
        x = w.T @ x + b
        x = relu(x)

    w = params[-1]['w']
    b = params[-1]['b']
    x = w.T @ x + b
    # x = 5 * tanh(x)
    return x

In [9]:
def mse(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.ndarray:
    return jnp.mean((prediction - target) ** 2)

In [10]:
class Optimizer:
    
    @jax.jit
    def SGD(params, grads, lr):
        return jax.tree_multimap(
      lambda p, g: p - lr * g, params, grads
    )

    @jax.jit
    def MetaSGD(theta, alpha, grads):
      return jax.tree_multimap(
        lambda t, a, g: t - a * g, theta, alpha, grads
      )

In [11]:
def init_params(param_key, layers):
    theta = []
    keys = random.split(param_key, len(layers)-1)
    for (l1, l2, key) in zip(layers[:-1], layers[1:], keys):
        w = random.normal(key, (l1, l2))*0.01
        b = jnp.zeros(shape=(l2,), dtype=jnp.float32)
        theta.append({'w': w, 'b': b})
    return tuple(theta)

## MAML

In [None]:
key = random.PRNGKey(0)
key, param_key = random.split(key)
meta_train_key, meta_test_key = random.split(key)

In [None]:
alpha = 0.01
shots = 5
task_size = 400
_, meta_train_dataset = get_meta_train_dataset(meta_train_key, num_tasks=task_size, shots=shots)
_, meta_test_dataset = get_meta_test_dataset(meta_test_key, num_tasks=100, shots=shots)

In [None]:
# Initialize parameters
dims = [1, 40, 40, 1]
theta = init_params(param_key, layers=dims)

### Meta Training

In [None]:
epochs = 200
fas = 1
beta = 0.01
task_loader = TaskLoader(meta_train_dataset, batch_size=4)

In [None]:
# This deals with input batches on a single task
def loss_fn(theta, data):
    x, y = data
    logits = jax.vmap(forward, in_axes=(None, 0))(theta, x)
    return mse(logits, y)

In [None]:
@jax.jit
def adapt(theta, data, lr):
    grads = jax.grad(loss_fn)(theta, data)
    return Optimizer.SGD(theta, grads, lr)

In [None]:
def meta_loss_fn(theta, tasks, lr):
    def single_meta_loss_fn(theta, task, lr):
        d_train, d_test = task
        theta_prime = adapt(theta, d_train, lr)
        return loss_fn(theta_prime, d_test)
    return jnp.mean(jax.vmap(single_meta_loss_fn, in_axes=(None, ((0, 0), (0, 0)), None))(theta, tasks, lr))

In [None]:
@jax.jit
def meta_optimize(theta, tasks, alpha, beta):
    loss, meta_grads = jax.value_and_grad(meta_loss_fn)(theta, tasks, lr=alpha)
    theta = Optimizer.SGD(theta, meta_grads, lr=beta)
    return theta, loss

In [None]:
# Meta training

for epoch in range(epochs):
    running_loss = []
    for tasks in task_loader:
        theta, loss = meta_optimize(theta, tasks, alpha, beta)
        running_loss.append(loss)
    loss = jnp.mean(jnp.array(running_loss))
    print(f"Epoch: {epoch+1}, Loss: {loss}")
    display.clear_output(wait=True)

### Meta Testing

In [None]:
# Meta testing
task_loader = TaskLoader(meta_test_dataset, batch_size=1)
total_loss = []

for task in task_loader:
    meta_loss = meta_loss_fn(theta, task, alpha)
    total_loss.append(meta_loss)
loss = jnp.mean(jnp.stack(total_loss))

print(f"Meta test loss: {loss}")

## META-SGD

In [12]:
key = random.PRNGKey(0)
key, param_key = random.split(key)
meta_train_key, meta_test_key = random.split(key)



In [13]:
def init_alpha(alpha_key, layers):
    a = random.uniform(alpha_key, (1,), minval=0.005, maxval=0.1).item()
    alpha = []
    for (l1, l2) in zip(layers[:-1], layers[1:]):
        w = jnp.full((l1, l2), a)
        b = jnp.full((l2,), a)
        alpha.append({'w': w, 'b': b})
    return tuple(alpha)

In [14]:
shots = 5
task_size = 400
_, meta_train_dataset = get_meta_train_dataset(meta_train_key, num_tasks=task_size, shots=shots)
_, meta_test_dataset = get_meta_test_dataset(meta_test_key, num_tasks=100, shots=shots)

In [15]:
# Initialize parameters
alpha_key, _ = random.split(param_key)
dims = [1, 40, 40, 1]
theta = init_params(param_key, layers=dims)
alpha = init_alpha(alpha_key, layers=dims)
params = (theta, alpha)

### Meta Training

In [16]:
epochs = 116
fas = 1
beta = 0.01
task_loader = TaskLoader(meta_train_dataset, batch_size=4)

In [17]:
# This deals with input batches on a single task
def loss_fn(theta, data):
    x, y = data
    logits = jax.vmap(forward, in_axes=(None, 0))(theta, x)
    return mse(logits, y)

In [18]:
@jax.jit
def adapt(params, data):
    theta, alpha = params
    grads = jax.grad(loss_fn)(theta, data)
    return Optimizer.MetaSGD(theta, alpha, grads)

In [19]:
def meta_loss_fn(params, tasks):
    def single_meta_loss_fn(params, task):
        d_train, d_test = task
        theta_prime = adapt(params, d_train)
        return loss_fn(theta_prime, d_test)
    return jnp.mean(jax.vmap(single_meta_loss_fn, in_axes=(None, ((0, 0), (0, 0))))(params, tasks))

In [20]:
@jax.jit
def meta_optimize(params, tasks, beta):
    loss, meta_grads = jax.value_and_grad(meta_loss_fn)(params, tasks)
    params = Optimizer.SGD(params, meta_grads, lr=beta)
    return params, loss

In [21]:
# Meta training

for epoch in range(epochs):
    running_loss = []
    for tasks in task_loader:
        params, loss = meta_optimize(params, tasks, beta)
        running_loss.append(loss)
    loss = jnp.mean(jnp.array(running_loss))
    print(f"Epoch: {epoch+1}, Loss: {loss}")
    display.clear_output(wait=True)

Epoch: 116, Loss: 0.871147871017456


In [25]:
alpha

({'b': DeviceArray([ 0.07510047, -0.03043799,  0.09361773,  0.09383196,
               -0.04353495,  0.35591167,  0.16577566,  0.00849801,
                0.05657737,  0.33873984,  0.10075056, -0.04562387,
               -0.10016309, -0.0271682 ,  0.09440634,  0.30484062,
                0.09440784, -0.32592985, -0.02220334,  0.0944064 ,
                0.0944078 , -0.01428996,  0.04325764, -0.22249158,
                0.04611416,  0.09839371,  0.22376578, -0.00235719,
                0.17428641,  0.4524485 ,  0.04886561,  0.09440812,
                0.00126575,  0.0431421 ,  0.07989103,  0.09440958,
                0.5607324 ,  0.09440977,  0.21418506,  0.2607757 ],            dtype=float32, weak_type=True),
  'w': DeviceArray([[ 0.1806522 , -0.07453348,  0.09189415,  0.09439375,
                -0.04526986, -0.07164492,  0.18109062, -0.00610501,
                 0.04987805, -0.03284879,  0.03893312,  0.075978  ,
                 0.1580771 , -0.01073916,  0.09440987,  0.09301425,
    

### Meta Testing

In [22]:
# Meta testing
task_loader = TaskLoader(meta_test_dataset, batch_size=1)
total_loss = []

for task in task_loader:
    meta_loss = meta_loss_fn(params, task)
    total_loss.append(meta_loss)
loss = jnp.mean(jnp.stack(total_loss))

print(f"Meta test loss: {loss}")

Meta test loss: 1.063297152519226
