<a href="https://colab.research.google.com/github/Chavdarova/jax_intro/blob/master/metaInit_mnist_jax_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MetaInit & MNIST classification in JAX and PyTorch




#  [MetaInit](https://papers.nips.cc/paper/9427-metainit-initializing-learning-by-learning-to-initialize.pdf), Dauphin & Schoenholz, NeurIPS '19



### MetaInit usage:

1. Initialize the parameters the way you normally do
2. Re-initialize the parameters with MetaInit
    - do this using random data as input samples and random targets (using the same loss function as you will use later for training)
3. Train


# MNIST classification 

In [0]:
#@title Hyperparameters and network layers

# MLP layer widths, input first and output last; consumed by
# `init_network_params` (JAX) and `MLP` (PyTorch).
layer_sizes = [784, 512, 512, 10]
# SGD learning rate; the notebook also passes it as `lr` to the MetaInit calls.
step_size = 0.0001
# Number of passes over the training generator.
num_epochs = 8
# Intended mini-batch size.
batch_size = 128
# Number of classes, i.e. the width of the one-hot label vectors.
n_targets = 10

In [5]:
#@title data loading (PyTorch: torchvision.datasets)
!pip install torch torchvision

import jax.numpy as np
import numpy as onp
from torch.utils import data
from torchvision.datasets import MNIST

def numpy_collate(batch):
  """Collate a list of samples into NumPy arrays instead of torch tensors.

  - samples that are ndarrays are stacked along a new leading axis;
  - samples that are tuples/lists are collated field by field (recursively),
    returning a list with one entry per field;
  - anything else (e.g. Python scalars) is wrapped in a single ndarray.
  """
  first = batch[0]
  if isinstance(first, onp.ndarray):
    return onp.stack(batch)
  if isinstance(first, (tuple, list)):
    # zip(*batch) regroups [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...)
    return [numpy_collate(field) for field in zip(*batch)]
  return onp.array(batch)

class NumpyLoader(data.DataLoader):
  """`DataLoader` that yields NumPy batches.

  Identical to `torch.utils.data.DataLoader` except that the collate
  function is fixed to `numpy_collate`, so batches are NumPy arrays rather
  than torch tensors. All other arguments are forwarded unchanged.
  """

  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    # Zero-argument super(): `super(self.__class__, self)` recurses forever
    # as soon as this class is subclassed, because `self.__class__` is then
    # the subclass and the lookup lands back on this very __init__.
    super().__init__(dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast:
  """Dataset transform: convert a picture to a flat float32 NumPy vector."""

  def __call__(self, pic):
    as_float = onp.array(pic, dtype=np.float32)
    return as_float.ravel()

def one_hot(x, k, dtype=np.float32):
  """Return the one-hot encoding of the integer labels `x` using `k` classes.

  `x` is indexed as a 1-D array of class ids; the result has one row per
  label and `k` columns, cast to `dtype`.
  """
  class_ids = np.arange(k)
  # Broadcasting a column of labels against the row of class ids yields a
  # boolean mask that is True exactly at each label's own class.
  mask = x[:, None] == class_ids
  return np.array(mask, dtype)

# Define our dataset, using torch datasets
mnist_dataset = MNIST('/tmp/mnist/', download=True, 
                      transform=FlattenAndCast())
# Use the `batch_size` hyperparameter instead of repeating the literal 128.
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size,
                                 num_workers=0)

# Get the full train dataset (for checking accuracy while training).
# `data` / `targets` replace the deprecated `train_data` / `train_labels`
# (and `test_data` / `test_labels`) aliases of torchvision's MNIST.
train_images = onp.array(mnist_dataset.data).reshape(
    len(mnist_dataset.data), -1)
train_labels = one_hot(onp.array(mnist_dataset.targets), n_targets)
print('[train shapes] images: {} labels: {}'.format(
    train_images.shape, train_labels.shape))
# float32 copy of the training images (the PyTorch cells evaluate on this one)
_train_images = np.array(mnist_dataset.data.numpy().reshape(
    len(mnist_dataset.data), -1), dtype=np.float32)

# Get full test dataset
mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = np.array(mnist_dataset_test.data.numpy().reshape(
    len(mnist_dataset_test.data), -1), dtype=np.float32)
test_labels = one_hot(onp.array(mnist_dataset_test.targets), n_targets)
print('[test shapes] images: {} labels: {}'.format(
    test_images.shape, test_labels.shape))

[train shapes] images: (60000, 784) labels: (60000, 10)
[test shapes] images: (10000, 784) labels: (10000, 10)




### JAX
Based on the code from [this JAX tutorial](https://github.com/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb).

In [0]:
#@title imports, initialization, loss, and network def
from __future__ import print_function, division, absolute_import
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp
import functools
import time


def random_layer_params(m, n, key, scale=1e-2, zero_bias=False):
  """Randomly initialize the weights and biases of one dense layer.

  Returns `(w, b)` with `w` of shape `(n, m)` and `b` of shape `(n,)`.
  Weights are `scale` times standard-normal samples; the bias is drawn the
  same way unless `zero_bias` is set, in which case it is all zeros.
  """
  w_key, b_key = random.split(key)
  weights = scale * random.normal(w_key, (n, m))
  if zero_bias:
    biases = np.zeros((n,))
  else:
    biases = scale * random.normal(b_key, (n,))
  return weights, biases


def init_network_params(sizes, key, zero_bias=False):
  """Initialize every layer of an MLP whose layer widths are `sizes`.

  Returns a list with one `(w, b)` tuple per consecutive pair of widths.
  """
  # One key per entry of `sizes` is drawn; zip() below consumes only the
  # first len(sizes) - 1 of them.
  layer_keys = random.split(key, len(sizes))
  params = []
  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    params.append(
        random_layer_params(fan_in, fan_out, layer_key, zero_bias=zero_bias))
  return params

@jit
def relu(x):
  """Element-wise rectified linear unit: max(x, 0)."""
  return np.maximum(x, 0)

@jit
def predict(params, image):
  """Log-probabilities of the MLP for a single (un-batched) input.

  `params` is a list of `(w, b)` tuples; every layer but the last is followed
  by a ReLU.
  """
  *hidden_layers, (out_w, out_b) = params
  hidden = image
  for weight, bias in hidden_layers:
    hidden = relu(np.dot(weight, hidden) + bias)

  logits = np.dot(out_w, hidden) + out_b
  # Subtracting logsumexp normalizes the logits into log-softmax outputs.
  return logits - logsumexp(logits)

# Make a batched version of the `predict` function:
# `in_axes=(None, 0)` shares `params` across the batch and maps over the
# leading axis of the second argument (one input per row).
batched_predict = vmap(predict, in_axes=(None, 0))

@jit
def accuracy(params, images, targets):
  """Fraction of `images` whose arg-max prediction matches the arg-max of
  the corresponding row of `targets`."""
  log_probs = batched_predict(params, images)
  hits = np.argmax(log_probs, axis=1) == np.argmax(targets, axis=1)
  return np.mean(hits)

@jit
def loss(params, images, targets):
  """Summed (not averaged) cross-entropy between the batched predictions and
  the `targets`, computed as minus the sum of `preds * targets`."""
  log_probs = batched_predict(params, images)
  # A debug print placed here only shows tracers while under jit/grad.
  return -np.sum(log_probs * targets)

@jit
def update(params, x, y):
  """One plain SGD step on `loss` using the module-level `step_size`.

  Returns a new list of `(w, b)` tuples; `params` itself is not modified.
  """
  grads = grad(loss)(params, x, y)
  stepped = []
  for (w, b), (dw, db) in zip(params, grads):
    stepped.append((w - step_size * dw, b - step_size * db))
  return stepped

def train_mnist_jax(params, training_generator, 
                    tr_images, tr_labels, te_images, te_labels,
                    num_epochs=10, n_targets=10):
  """Train the JAX MLP with SGD and report accuracy after every epoch.

  params: list of `(w, b)` tuples to start from.
  training_generator: iterable yielding `(inputs, integer_labels)` batches.
  tr_images / tr_labels, te_images / te_labels: full train and test sets
    (labels one-hot) used only for the per-epoch accuracy printout.
  Returns the trained parameters.
  """
  for epoch in range(num_epochs):
    tic = time.time()
    for x, y in training_generator:  # x: bsize x 784
      # integer labels -> one-hot rows (bsize x n_targets) before the step
      params = update(params, x, one_hot(y, n_targets))
    # only the parameter updates are timed, not the evaluation below
    epoch_time = time.time() - tic

    train_acc = accuracy(params, tr_images, tr_labels)
    test_acc = accuracy(params, te_images, te_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
  return params


# TEST ----------------------------------------------------
# # `batched_predict` has the same call signature as `predict`
# params = init_network_params(layer_sizes, random.PRNGKey(0))
# random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
# batched_preds = batched_predict(params, random_flattened_images)
# print(batched_preds.shape)

In [7]:
#@title Training loop (no MetaInit)

# Baseline: default random initialization followed directly by SGD training.
params = init_network_params(layer_sizes, random.PRNGKey(0))
params = train_mnist_jax(params, training_generator, 
                         train_images, train_labels,
                         test_images, test_labels,
                         num_epochs, n_targets)

Epoch 0 in 5.99 sec
Training set accuracy 0.9594333171844482
Test set accuracy 0.9560999870300293
Epoch 1 in 2.97 sec
Training set accuracy 0.97843337059021
Test set accuracy 0.9702000021934509
Epoch 2 in 2.98 sec
Training set accuracy 0.9874666929244995
Test set accuracy 0.976699948310852
Epoch 3 in 2.93 sec
Training set accuracy 0.9911167025566101
Test set accuracy 0.9785999655723572
Epoch 4 in 2.94 sec
Training set accuracy 0.9934166669845581
Test set accuracy 0.9788999557495117
Epoch 5 in 2.97 sec
Training set accuracy 0.9956166744232178
Test set accuracy 0.9799000024795532
Epoch 6 in 3.15 sec
Training set accuracy 0.996399998664856
Test set accuracy 0.9799999594688416
Epoch 7 in 2.99 sec
Training set accuracy 0.9974666833877563
Test set accuracy 0.9810999631881714


##### MetaInit in Jax for MNIST classification

In [0]:
#@title MetaInit functions
import jax.ops as ops

@jit
def qg_prod(params, x, y):
  """Helper function for `gradient_quotient_jax`.

  Returns half the squared norm of the loss gradient, summed over the weight
  matrices only (bias gradients are ignored). Differentiating this scalar
  w.r.t. the parameters gives the Hessian-gradient product.
  """
  grads = grad(loss)(params, x, y)
  # Builtin sum over a generator: current jax.numpy reductions reject plain
  # Python lists, so `np.sum([...])` is avoided here.
  return sum((dw ** 2).sum() / 2 for dw, _ in grads)

@jit
def gradient_quotient_jax(params, x, y, eps=np.float32(1e-5)):
  """Gradient quotient (Equation 1 of the MetaInit paper).

  params: list of `(w, b)` tuples; only the weight matrices enter the quotient
  x: input to the model
  y: targets
  eps: small constant keeping the denominator away from zero

  Returns a scalar: the sum over all weight entries of
  `|(g - Hg) / (g + eps * sign(g)) - 1|`, divided by the number of weight
  entries, where `g` is the loss gradient and `Hg` is obtained by
  differentiating `qg_prod`.
  """
  grads = grad(loss)(params, x, y)          # one (dw, db) tuple per layer
  hess_grads = grad(qg_prod)(params, x, y)  # same structure as `grads`
  total = 0.
  # Accumulate with Python arithmetic instead of `np.sum([...])`: current
  # jax.numpy reductions reject plain Python lists.
  for (g, _), (p, _) in zip(grads, hess_grads):
    # +1 where g >= 0 and -1 elsewhere, so eps pushes the denominator away
    # from zero
    sign = 2 * np.array(g >= 0, dtype=np.float32) - 1
    total = total + np.sum(np.abs((g - p) / (g + eps * sign) - 1))
  return total / sum(w.size for w, _ in params)


def metainit_jax(params, loss, x_size, y_size, lr=0.1,
                 momentum=0.9, steps=500, 
                 eps=np.float32(1e-5), key=10,
                 print_freq=100):
  """Rescale the weight norms of `params` with MetaInit.

  params: list of `(w, b)` tuples; only the weights are rescaled, the biases
    are carried over unchanged.
  loss: accepted for signature parity with the PyTorch `metainit` but not
    used; `gradient_quotient_jax` always differentiates the module-level
    `loss`.
  x_size: shape of the random normal inputs drawn at every step.
  y_size: `(n_samples, n_classes)`; random integer targets are drawn and
    one-hot encoded with `n_classes` columns.
  lr, momentum: step size and momentum of the sign-SGD update on each norm.
  steps: number of MetaInit iterations.
  eps: forwarded to `gradient_quotient_jax`.
  key: integer seed for the JAX PRNG.
  print_freq: the gradient quotient is printed every `print_freq` steps and
    on the last step.

  Returns a new list of `(w, b)` tuples; the list passed in is left untouched.
  """
  key = random.PRNGKey(key)
  n_targets = y_size[1]
  # Work on a shallow copy so the caller's list is not modified in place.
  params = list(params)
  memory = [0] * len(params) 
  for i in range(steps):
    key, input_key, target_key = random.split(key, 3)
    inputs = random.normal(input_key, shape=x_size)  # cant be jitted
    target = one_hot(random.randint(target_key, shape=(y_size[0],), 
                                    minval=0, maxval=n_targets), 
                     n_targets)
    _grad = grad(gradient_quotient_jax)(params, inputs, target, eps)

    if i % print_freq  == 0 or i == (steps-1):
      print("%d/GQ = %.2f" % (i, 
            gradient_quotient_jax(params, inputs, target, eps)))

    for j, ((p, b), (g_all, _)) in enumerate(zip(params, _grad)):
      norm = np.linalg.norm(p)
      # Sign of the derivative of the gradient quotient w.r.t. the weight norm.
      g = np.sign(np.sum(p * g_all) / norm)
      memory[j] = momentum * memory[j] - lr * g
      new_norm = norm + memory[j]
      # Rescale the weights so their norm becomes `new_norm`.
      params[j] = (p * (new_norm / norm), b)
  return params

# TEST:
# params = init_network_params(layer_sizes, random.PRNGKey(0), zero_bias=True) 
# params = metainit_jax(params, loss, test_images.shape, test_labels.shape,
#                       lr=step_size, momentum=.0)

In [11]:
#@title Training loop with MetaInit
# Same seed as the baseline run, with randomly initialized biases.
params = init_network_params(layer_sizes, random.PRNGKey(0), zero_bias=False)
# MetaInit draws random inputs/targets whose shapes are those of the test set;
# `step_size` is reused as the MetaInit learning rate and momentum is disabled.
params =  metainit_jax(params, loss, test_images.shape, test_labels.shape,
                       lr=step_size, momentum=.0)

params = train_mnist_jax(params, training_generator, 
                         train_images, train_labels,
                         test_images, test_labels,
                         num_epochs, n_targets)

0/GQ = 919.96
100/GQ = 1109.75
200/GQ = 1184.07
300/GQ = 1451.96
400/GQ = 1217.61
499/GQ = 1310.14
Epoch 0 in 2.96 sec
Training set accuracy 0.9618000388145447
Test set accuracy 0.9571999907493591
Epoch 1 in 2.93 sec
Training set accuracy 0.9778333306312561
Test set accuracy 0.9698999524116516
Epoch 2 in 2.96 sec
Training set accuracy 0.9872833490371704
Test set accuracy 0.9764999747276306
Epoch 3 in 2.92 sec
Training set accuracy 0.9917666912078857
Test set accuracy 0.9788999557495117
Epoch 4 in 2.92 sec
Training set accuracy 0.9919166564941406
Test set accuracy 0.9768999814987183
Epoch 5 in 2.98 sec
Training set accuracy 0.9952666759490967
Test set accuracy 0.9805999994277954
Epoch 6 in 2.99 sec
Training set accuracy 0.996916651725769
Test set accuracy 0.9800999760627747
Epoch 7 in 3.18 sec
Training set accuracy 0.9980000257492065
Test set accuracy 0.98089998960495


In [12]:
#@title Training loop with MetaInit: biases are `init`ed to 0
# Same pipeline as the previous MetaInit run, but biases start at zero.
params = init_network_params(layer_sizes, random.PRNGKey(0), zero_bias=True)
params =  metainit_jax(params, loss, test_images.shape, test_labels.shape,
                       lr=step_size, momentum=.0)
params = train_mnist_jax(params, training_generator, 
                         train_images, train_labels,
                         test_images, test_labels,
                         num_epochs, n_targets)

0/GQ = 1041.03
100/GQ = 1114.56
200/GQ = 1229.71
300/GQ = 1307.12
400/GQ = 1136.16
499/GQ = 1194.94
Epoch 0 in 3.00 sec
Training set accuracy 0.9610166549682617
Test set accuracy 0.9560999870300293
Epoch 1 in 3.01 sec
Training set accuracy 0.9790000319480896
Test set accuracy 0.9710999727249146
Epoch 2 in 3.20 sec
Training set accuracy 0.9879666566848755
Test set accuracy 0.9774999618530273
Epoch 3 in 2.92 sec
Training set accuracy 0.9922167062759399
Test set accuracy 0.9793999791145325
Epoch 4 in 2.98 sec
Training set accuracy 0.9900833368301392
Test set accuracy 0.975600004196167
Epoch 5 in 2.95 sec
Training set accuracy 0.995983362197876
Test set accuracy 0.9802999496459961
Epoch 6 in 2.96 sec
Training set accuracy 0.9972500205039978
Test set accuracy 0.9807999730110168
Epoch 7 in 2.98 sec
Training set accuracy 0.9977499842643738
Test set accuracy 0.9805999994277954


### PyTorch

In [13]:
#@title network class and functions
import numpy as onp
import torch
import torch.nn as nn
# Conversion helpers, attached as attributes onto the `torch` module and onto
# `np` (which is `jax.numpy` in this notebook).
# jax -> pytorch: goes through a NumPy view of the array
torch.from_jax = lambda x: torch.from_numpy(onp.asarray(x))   
# pytorch -> jax: despite its name, `np.astensor` takes a torch tensor and
# returns a jax array; it calls `x.numpy()` on the tensor first
np.astensor =  lambda x: np.asarray(x.numpy())

class MLP(nn.Module):
  """Fully connected network mirroring the JAX `predict` function.

  layer_sizes: layer widths, input first and output last.
  non_lin: module class instantiated after every hidden `nn.Linear`.
  """

  def __init__(self, layer_sizes, non_lin=nn.ReLU):
    super().__init__()
    layers = []
    # hidden layers: Linear followed by the non-linearity
    for fan_in, fan_out in zip(layer_sizes[:-2], layer_sizes[1:-1]):
      layers.append(nn.Linear(fan_in, fan_out))
      layers.append(non_lin())
    # output layer: Linear only
    layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))
    self.main = nn.Sequential(*layers)
    # every parametrized layer contributes a weight and a bias
    self.n_param_layers = len(list(self.main.named_parameters())) // 2

  def forward(self, x):
    """Return log-softmax outputs of shape `bsize x n_labels`."""
    x = self.main(x)  # bsize x n_labels 
    y = torch.logsumexp(x, dim=1)
    return x.sub(y.view(-1, 1))
  
  def cp_params(self, cp_params):
    """Copy `(w, b)` tuples (e.g. JAX parameters) into the linear layers.

    cp_params: list with one `(weight, bias)` pair per linear layer, in
      layer order; each is converted through NumPy and cast to float.

    Raises:
      TypeError: if `cp_params` is not a list.
      ValueError: if its length differs from the number of linear layers.
    """
    if not isinstance(cp_params, list):
      raise TypeError("Expected list. Got {}".format(type(cp_params)))
    if self.n_param_layers != len(cp_params):
      # Report the layer count: `named_parameters()` is a generator, so
      # calling len() on it raised TypeError and hid this ValueError.
      raise ValueError("Expected equal len. Got {} and {}.".format(
          self.n_param_layers, len(cp_params)))
    
    with torch.no_grad():
      for i, (w, b) in enumerate(cp_params):
        # Linear layers sit at the even indices of `self.main`; the odd
        # indices hold the non-linearities.
        layer = self.main[i * 2]
        layer.weight.copy_(torch.from_numpy(onp.asarray(w)).float())
        layer.bias.copy_(torch.from_numpy(onp.asarray(b)).float())

def accuracy_pytorch(net, images, targets):
  """Accuracy of `net` on `images` against the one-hot `targets`.

  The network is moved to the CPU for the forward pass (the images are
  converted through NumPy) and afterwards moved back to the device it was on
  when the function was called.
  """
  target_class = np.argmax(targets, axis=1)
  # Remember where the model lives so it can be put back afterwards, instead
  # of unconditionally calling `net.cuda()` (which fails without a GPU and
  # silently relocates a CPU model).
  first_param = next(net.parameters(), None)
  orig_device = first_param.device if first_param is not None else None
  net.cpu()
  try:
    with torch.no_grad():
      output = np.astensor(net(torch.from_jax(images)))
      predicted_class = np.argmax(output, axis=1)
  finally:
    if orig_device is not None:
      net.to(orig_device)
  return np.mean(predicted_class == target_class)


def criterion(x, y):
  """Minus the sum of the element-wise product of `x` and `y`."""
  weighted = x * y
  return -weighted.sum()

def train_mnist_pytorch(net, training_generator, 
                        tr_images, tr_labels, te_images, te_labels,
                        num_epochs=10, n_targets=10):
  """Train `net` with plain SGD and report accuracy after every epoch.

  net: the PyTorch model; batches are moved to whatever device it lives on.
  training_generator: iterable yielding `(inputs, integer_labels)` NumPy
    batches.
  tr_images / tr_labels, te_images / te_labels: full train and test sets
    (labels one-hot) used only for the per-epoch accuracy printout.
  The learning rate is the module-level `step_size`. `net` is updated in
  place; nothing is returned.
  """
  # `torch.optim` is reachable through the `torch` import, so the function
  # does not depend on a later cell importing `optim`.
  optimizer = torch.optim.SGD(net.parameters(), lr=step_size, momentum=0)
  for epoch in range(num_epochs):
    # Follow the model's device rather than hard-coding "cuda". It is
    # re-read every epoch because the accuracy helper may move the model.
    device = next(net.parameters()).device
    start_time = time.time()
    for x, y in training_generator:  # x: bsize x 784
      labels = one_hot(y, n_targets)  # y: bsize x 10
      labels = torch.from_jax(labels).long().to(device)
      inputs = torch.from_numpy(x).float().to(device)
      optimizer.zero_grad()
      outputs = net(inputs)
      # named `batch_loss` so it does not shadow the JAX `loss` function
      batch_loss = criterion(outputs, labels)
      batch_loss.backward()
      optimizer.step()
    epoch_time = time.time() - start_time

    train_acc = accuracy_pytorch(net, tr_images, tr_labels)
    test_acc = accuracy_pytorch(net, te_images, te_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))

# --- Sanity check: obtain same test acc for the trained params as above
net = MLP(layer_sizes)
net.cp_params(params)  # assumes that params were trained in the cells above
print('pytorch+copied params test accuracy: {}'.format(
    accuracy_pytorch(net, test_images, test_labels)))

pytorch+copied params test accuracy: 0.9806000590324402


In [14]:
#@title training in pytorch
import torch.optim as optim

# Start from the same JAX-generated parameters as the JAX baseline run.
params = init_network_params(layer_sizes, random.PRNGKey(0)) # JAX
net = MLP(layer_sizes).cuda()
net.cp_params(params)

# Accuracy is evaluated on the float32 copy of the training images.
train_mnist_pytorch(net, training_generator, 
                    tr_images=_train_images, tr_labels=train_labels,
                    te_images=test_images, te_labels=test_labels,
                    num_epochs=num_epochs, n_targets=n_targets)

Epoch 0 in 3.66 sec
Training set accuracy 0.9591833353042603
Test set accuracy 0.9552000164985657
Epoch 1 in 3.38 sec
Training set accuracy 0.9778666496276855
Test set accuracy 0.9711000323295593
Epoch 2 in 3.29 sec
Training set accuracy 0.9874500036239624
Test set accuracy 0.9777000546455383
Epoch 3 in 3.29 sec
Training set accuracy 0.9908833503723145
Test set accuracy 0.9794000387191772
Epoch 4 in 3.29 sec
Training set accuracy 0.9940833449363708
Test set accuracy 0.9789000749588013
Epoch 5 in 3.56 sec
Training set accuracy 0.9948500394821167
Test set accuracy 0.9787000417709351
Epoch 6 in 3.35 sec
Training set accuracy 0.9974166750907898
Test set accuracy 0.9803000688552856
Epoch 7 in 3.29 sec
Training set accuracy 0.9975500106811523
Test set accuracy 0.9809000492095947



#### MetaInit in  PyTorch

In [0]:
#@title code from [MetaInit paper](https://papers.nips.cc/paper/9427-metainit-initializing-learning-by-learning-to-initialize.pdf)

import torch

def gradient_quotient(loss, params, eps=1e-5):
  """Gradient quotient, Eq. 1 in the MetaInit paper.

  loss: scalar tensor whose graph reaches every tensor in `params`.
  params: list of tensors to differentiate with respect to.
  eps: small constant keeping the denominator away from zero.

  Returns a scalar tensor that is itself differentiable w.r.t. `params`
  (both gradient calls use `create_graph=True`).
  """
  first_order = torch.autograd.grad(loss, params,
                                    retain_graph=True,
                                    create_graph=True)
  # Differentiating half the squared gradient norm gives the
  # Hessian-gradient product.
  half_sq_norm = sum([(g**2).sum() / 2 for g in first_order])
  hess_grad = torch.autograd.grad(half_sq_norm, params,
                                  retain_graph=True,
                                  create_graph=True)
  total = 0
  for g, p in zip(first_order, hess_grad):
    # +1 where g >= 0 and -1 elsewhere; detached so no gradient flows
    # through the sign term
    sign = (2*(g >= 0).float() - 1).detach()
    total = total + ((g - p) / (g + eps * sign) - 1).abs().sum()
  n_elements = sum([p.data.nelement() for p in params])
  return total / n_elements

def metainit(model, criterion, x_size, y_size, lr=0.1,
             momentum=0.9, steps=500, eps=1e-5, print_freq=100):
  """Rescale the weight norms of `model.main` in place with MetaInit.

  model: network exposing its layers as `model.main`; only parameters with
    at least two dimensions (i.e. not the biases) are rescaled.
  criterion: loss applied to `model(inputs)` and the one-hot random targets.
  x_size: shape of the random normal inputs drawn at every step.
  y_size: number of classes for the random targets.
  lr, momentum: step size and momentum of the sign-SGD update on each norm.
  steps: number of MetaInit iterations.
  eps: forwarded to `gradient_quotient`.
  print_freq: the gradient quotient is printed every `print_freq` steps and
    on the last step.

  The random data is placed on the device of the model's parameters, and the
  model's train/eval mode is restored before returning.
  """
  was_training = model.training
  model.eval()
  params = [p for p in model.main.parameters() 
            if p.requires_grad and len(p.size()) >= 2]  # omits biases
  if not params:
    # nothing to rescale
    model.train(was_training)
    return
  # Follow the parameters' device instead of hard-coding `.cuda()`.
  device = params[0].device
  memory = [0] * len(params)
  try:
    for i in range(steps):
      inputs = torch.Tensor(*x_size).normal_(0, 1).to(device)
      target = one_hot_pytorch(torch.randint(0, y_size, (x_size[0],1)),
                               nb_targets=y_size, cuda=False).to(device)
      loss = criterion(model(inputs), target)
      gq = gradient_quotient(loss, params, eps)  # or list(model.parameters()) 
      grad = torch.autograd.grad(gq, params)
      for j, (p, g_all) in enumerate(zip(params, grad)):
        norm = p.data.norm().item()
        # Sign of the derivative of the gradient quotient w.r.t. the norm.
        g = torch.sign((p.data * g_all).sum() / norm)
        memory[j] = momentum * memory[j] - lr * g.item()
        new_norm = norm + memory[j]
        # Rescale the weights so their norm becomes `new_norm`.
        p.data.mul_(new_norm / norm)
      if i % print_freq  == 0 or i == (steps-1):
        print("%d/GQ = %.2f" % (i, gq.item())) 
  finally:
    # Do not leave the caller's model stuck in eval mode.
    model.train(was_training)

def one_hot_pytorch(x, nb_targets, dtype=torch.FloatTensor, cuda=True):
  """One-hot encode a column of class indices.

  x: integer tensor of shape `(n, 1)` holding one class index per row.
  nb_targets: number of classes (columns of the result).
  dtype: tensor constructor used to allocate the `(n, nb_targets)` result.
  cuda: when true, the result is moved to the GPU before being returned.
  """
  n_rows = x.size(0)
  encoded = dtype(n_rows, nb_targets).zero_()
  # write a 1 into column x[i, 0] of every row i
  encoded.scatter_(1, x, 1)
  return encoded.cuda() if cuda else encoded

In [17]:
#@title training in pytorch with MetaInit
import torch.optim as optim

# Start from the same JAX-generated parameters (random biases) as the JAX run.
params = init_network_params(layer_sizes, random.PRNGKey(0), zero_bias=False) # JAX
net = MLP(layer_sizes).cuda()
net.cp_params(params)

# MetaInit on random inputs shaped like the test set, with 10 random classes;
# `step_size` is reused as the MetaInit learning rate and momentum is disabled.
metainit(net, criterion, test_images.shape, 
         y_size=10, lr=step_size, momentum=.0)

train_mnist_pytorch(net, training_generator, 
                    tr_images=_train_images, tr_labels=train_labels,
                    te_images=test_images, te_labels=test_labels,
                    num_epochs=num_epochs, n_targets=n_targets)

0/GQ = 1102.85
100/GQ = 1291.64
200/GQ = 1234.79
300/GQ = 1238.66
400/GQ = 1362.68
499/GQ = 1217.70
Epoch 0 in 3.68 sec
Training set accuracy 0.9599000215530396
Test set accuracy 0.955500066280365
Epoch 1 in 3.50 sec
Training set accuracy 0.9781000018119812
Test set accuracy 0.9700000286102295
Epoch 2 in 3.54 sec
Training set accuracy 0.9865666627883911
Test set accuracy 0.9757000207901001
Epoch 3 in 3.42 sec
Training set accuracy 0.9917333722114563
Test set accuracy 0.978600025177002
Epoch 4 in 3.43 sec
Training set accuracy 0.9900500178337097
Test set accuracy 0.9760000705718994
Epoch 5 in 3.63 sec
Training set accuracy 0.9961667060852051
Test set accuracy 0.9798000454902649
Epoch 6 in 3.43 sec
Training set accuracy 0.9975166916847229
Test set accuracy 0.9807000756263733
Epoch 7 in 3.34 sec
Training set accuracy 0.9983000159263611
Test set accuracy 0.9808000326156616


In [18]:
#@title training in pytorch with MetaInit: biases are `init`ed to 0
import torch.optim as optim

# Same pipeline as the previous PyTorch MetaInit run, but biases start at zero.
params = init_network_params(layer_sizes, random.PRNGKey(0), zero_bias=True)
net = MLP(layer_sizes).cuda()
net.cp_params(params)

metainit(net, criterion, test_images.shape, 
         y_size=10, lr=step_size, momentum=.0)

train_mnist_pytorch(net, training_generator, 
                    tr_images=_train_images, tr_labels=train_labels,
                    te_images=test_images, te_labels=test_labels,
                    num_epochs=num_epochs, n_targets=n_targets)

0/GQ = 1368.74
100/GQ = 1001.26
200/GQ = 1527.70
300/GQ = 1324.46
400/GQ = 1231.78
499/GQ = 875.56
Epoch 0 in 3.41 sec
Training set accuracy 0.9573833346366882
Test set accuracy 0.9548000693321228
Epoch 1 in 3.37 sec
Training set accuracy 0.9773833155632019
Test set accuracy 0.9690000414848328
Epoch 2 in 3.31 sec
Training set accuracy 0.987333357334137
Test set accuracy 0.9765000343322754
Epoch 3 in 3.32 sec
Training set accuracy 0.9915000200271606
Test set accuracy 0.9780000448226929
Epoch 4 in 3.31 sec
Training set accuracy 0.9936333298683167
Test set accuracy 0.9788000583648682
Epoch 5 in 3.27 sec
Training set accuracy 0.9955166578292847
Test set accuracy 0.979900062084198
Epoch 6 in 3.33 sec
Training set accuracy 0.996916651725769
Test set accuracy 0.9800000190734863
Epoch 7 in 3.50 sec
Training set accuracy 0.9977333545684814
Test set accuracy 0.9800000190734863


## Verification

MNIST classification without MetaInit: a sanity check that the JAX and PyTorch implementations produce (at least initially) the same loss values, up to floating-point precision.


#### JAX
```
loss jax Traced<ConcreteArray(320.86725)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(393.53812)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(275.22656)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(273.5281)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(230.26147)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(169.82643)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(124.452896)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(109.571625)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(134.94353)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(117.75346)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(190.89735)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(184.23071)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(115.22882)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(79.86719)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(72.12949)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(73.87135)>with<JVPTrace(level=2/0)>
loss jax Traced<ConcreteArray(56.154232)>with<JVPTrace(level=2/0)>
```
------

#### PyTorch

```
loss pytorch  320.8672180175781
loss pytorch  393.5380859375
loss pytorch  275.2265625
loss pytorch  273.528076171875
loss pytorch  230.26145935058594
loss pytorch  169.82643127441406
loss pytorch  124.45288848876953
loss pytorch  109.57162475585938
loss pytorch  134.94351196289062
loss pytorch  117.75344848632812
loss pytorch  190.89739990234375
loss pytorch  184.23074340820312
loss pytorch  115.22883605957031
loss pytorch  79.8671875
loss pytorch  72.12948608398438
loss pytorch  73.87134552001953
loss pytorch  56.154232025146484
```