In [1]:
%%capture
!pip install transformers datasets
!pip install flax optaxjaxlib
!pip install numpy pandas
!pip install tqdm
!pip install scikit-learn
!pip install evaluate

In [2]:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


In [3]:
import jax
print(jax.devices())

[CudaDevice(id=0)]


# Loading The Data


In [4]:
from datasets import Dataset
import pandas as pd
import os

def load_data_from_dir(dir_path, label):
    data = []
    for filename in sorted(os.listdir(dir_path)):
        file_path = os.path.join(dir_path, filename)

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read().strip()
        except UnicodeDecodeError:
            with open(file_path, 'r', encoding='ISO-8859-9') as f:
                text = f.read().strip()

        data.append({'sentence': text, 'labels': label})
    return data

neg_dir   = "/content/sorted_news/neg"
pos_dir   = "/content/sorted_news/pos"
neg_data  = load_data_from_dir(neg_dir, 0)
pos_data  = load_data_from_dir(pos_dir, 1)
all_data  = neg_data + pos_data

seed = 35
full_dataset = Dataset.from_pandas(pd.DataFrame(all_data))
full_dataset = full_dataset.shuffle(seed = seed)

print(full_dataset[2:5])
# for i in full_dataset:
#    print(i)

train_test_split = full_dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']

print(train_dataset)
print(test_dataset)

{'sentence': ['Cuma günü tarihi zirvesini 4.552,44 seviyesine taşıyan BIST-100 Endeksi kapanışa doğru etkili olan kâr satışlarıyla gün içi kazançlarını geri verdi. Yaşanan güçlü yükselişlerin ardından son günlerde endekste ve ana hisselerde gözlenen yorulma emareleri ve teknik indikatörlerdeki negatif uyuşmazlıklar olası bir düzeltme ihtimalini artırıyor. Endekste kısa vadede 4.400 seviyesi kısa vadeli destek olarak izlenecek olup, bu seviye altında 4.333 \x96 4.297 \x96 4.234 - 4.150 ve 4.100 seviyeleri destek olarak takip edilebilir. Endekste kısa vadede 4.400 üzerinde kalıcılığın korunması yükselişlerin devamlılığı açısından önem taşımaktadır. Endekste 4.400 üzerindeki tutunmanın korunması ve 4.500 üzerinde kapanışların yaşanması durumunda ise 4.552 \x96 4.575 ve 4.600 seviyeleri direnç konumunda bulunmaktadır.', 'Haftaya global eğilime paralel sınırlı zayıf eğilimle başlangıç beklediğimiz  endekste ilk aşamada 3500-3490 destek, 3610-3630 direnç bölgelerinin son 6 işlem gününde  olu

# Tokenizing the Data

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('dbmdz/bert-base-turkish-cased')

def preprocess_function(examples):
    texts     = (examples['sentence'],)
    processed = tokenizer(*texts, padding="max_length", max_length=256, truncation=True)

    processed["labels"] = examples["labels"]
    return processed

train_dataset = train_dataset.map(preprocess_function, batched=True)
test_dataset  = test_dataset.map(preprocess_function, batched=True)

train_dataset.set_format(type='np', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='np', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

print(train_dataset)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Map:   0%|          | 0/1364 [00:00<?, ? examples/s]

Map:   0%|          | 0/342 [00:00<?, ? examples/s]

Dataset({
    features: ['sentence', 'labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1364
})


# Loading the Model

In [6]:
from transformers import FlaxBertForSequenceClassification

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

model = FlaxBertForSequenceClassification.from_pretrained(
    'dbmdz/bert-base-turkish-cased',
    num_labels = 2,
    id2label   = id2label,
    label2id   = label2id
)

Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at dbmdz/bert-base-turkish-cased and are newly initialized: {('classifier', 'kernel'), ('classifier', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Evaluating Before

In [7]:
import evaluate
from tqdm import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix

def evaluate_model(model, eval_dataset):
    accuracy_metric  = evaluate.load('accuracy')
    precision_metric = evaluate.load('precision')
    recall_metric    = evaluate.load('recall')

    # for conf matrix
    all_predictions = []
    all_labels = []

    batch_size = 16

    num_batches  = len(eval_dataset) // batch_size
    progress_bar = tqdm(range(num_batches), desc="Evaluating")

    for i in progress_bar:
        batch = eval_dataset[i*batch_size:(i+1)*batch_size]

        input_ids      = jax.device_put(np.array(batch['input_ids']))
        attention_mask = jax.device_put(np.array(batch['attention_mask']))
        labels         = jax.device_put(np.array(batch['labels']))

        # Forward pass through the model
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits  = outputs.logits

        # Convert JAX predictions to numpy and get the predicted class
        predictions = np.argmax(logits, axis=-1)
        predictions = np.array(predictions)

        all_predictions.extend(predictions)
        all_labels.extend(labels)

        accuracy_metric.add_batch(predictions=predictions, references=labels)
        precision_metric.add_batch(predictions=predictions, references=labels)
        recall_metric.add_batch(predictions=predictions, references=labels)

    accuracy_result = accuracy_metric.compute()
    precision_result = precision_metric.compute()
    recall_result = recall_metric.compute()

    confusion_mat = confusion_matrix(all_labels, all_predictions)

    return {
        "accuracy": accuracy_result['accuracy'],
        "precision": precision_result['precision'],
        "recall": recall_result['recall'],
        "confusion_matrix": confusion_mat
    }

In [8]:
eval_results = evaluate_model(model, test_dataset)

print(f"Accuracy: {eval_results['accuracy']:.4f}")
print(f"Precision: {eval_results['precision']:.4f}")
print(f"Recall: {eval_results['recall']:.4f}")
print(f"Confusion Matrix:\n{eval_results['confusion_matrix']}")

Evaluating: 100%|██████████| 21/21 [00:19<00:00,  1.09it/s]

Accuracy: 0.4554
Precision: 0.7500
Recall: 0.0162
Confusion Matrix:
[[150   1]
 [182   3]]





In [9]:
def after_evaluate_model(model, params, eval_dataset):
    # Load evaluation metrics from the evaluate library
    accuracy_metric = evaluate.load('accuracy')
    precision_metric = evaluate.load('precision')
    recall_metric = evaluate.load('recall')

    # Lists to store all predictions and labels for confusion matrix calculation
    all_predictions = []
    all_labels = []

    batch_size = 16  # Define your batch size

    num_batches = len(eval_dataset) // batch_size
    progress_bar = tqdm(range(num_batches), desc="Evaluating")

    for i in progress_bar:
        # Slice the dataset into batches
        batch = eval_dataset[i * batch_size:(i + 1) * batch_size]

        # Move data to JAX devices
        input_ids = jax.device_put(np.array(batch['input_ids']))
        attention_mask = jax.device_put(np.array(batch['attention_mask']))
        labels = jax.device_put(np.array(batch['labels']))

        # Forward pass through the model using the trained parameters
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        logits = outputs.logits

        # Convert logits to class predictions (0 or 1)
        predictions = np.argmax(logits, axis=-1)
        predictions = np.array(predictions)

        # Append predictions and labels for confusion matrix calculation
        all_predictions.extend(predictions)
        all_labels.extend(labels)

        # Add batches to the evaluation metrics
        accuracy_metric.add_batch(predictions=predictions, references=labels)
        precision_metric.add_batch(predictions=predictions, references=labels)
        recall_metric.add_batch(predictions=predictions, references=labels)

    # Compute evaluation results
    accuracy_result = accuracy_metric.compute()
    precision_result = precision_metric.compute()
    recall_result = recall_metric.compute()

    # Compute confusion matrix
    confusion_mat = confusion_matrix(all_labels, all_predictions)

    # Return the evaluation results
    return {
        "accuracy": accuracy_result['accuracy'],
        "precision": precision_result['precision'],
        "recall": recall_result['recall'],
        "confusion_matrix": confusion_mat
    }

# Training

In [10]:
import optax  # Optimizer library for JAX
import jax
from flax.training import train_state

learning_rate = 2e-5
tx = optax.adamw(learning_rate=learning_rate)

state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=jax.device_put(model.params),
    tx=tx)


In [11]:
def train_step(state, batch, rng):
    """Performs a training step."""

    def loss_fn(params):
        logits = model(input_ids=batch['input_ids'],
                       attention_mask=batch['attention_mask'],
                       params=params,
                       dropout_rng=rng,
                       train=True).logits
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['labels'])
        print(loss.mean)
        return loss.mean()

    # Compute gradients
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)

    state = state.apply_gradients(grads=grads)

    return state, loss


In [12]:
from tqdm import tqdm

def train_model(state, train_dataset, num_epochs=3, batch_size=16):
    num_batches = len(train_dataset) // batch_size
    rng = jax.random.PRNGKey(0)

    rng, new_rng = jax.random.split(rng)
    rng = jax.device_put(rng)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        progress_bar = tqdm(range(num_batches), desc="Training")

        for i in progress_bar:
            batch = train_dataset[i*batch_size:(i+1)*batch_size]
            batch = {k: jax.device_put(v) for k, v in batch.items()}

            rng, new_rng = jax.random.split(rng)
            state, loss = train_step(state, batch, new_rng)

            progress_bar.set_postfix({"loss": loss.item()})

        print(f"Evaluating after epoch {epoch + 1}...")
        eval_results = after_evaluate_model(model, state.params, test_dataset)

        # Print evaluation results
        print(f"Accuracy: {eval_results['accuracy']:.4f}")
        print(f"Precision: {eval_results['precision']:.4f}")
        print(f"Recall: {eval_results['recall']:.4f}")
        print(f"Confusion Matrix:\n{eval_results['confusion_matrix']}")

    return state

In [13]:
state = train_model(state, train_dataset)

Epoch 1/3


Training:   0%|          | 0/85 [00:00<?, ?it/s]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.47242653 1.1312478  0.39314756 0.7948002  0.63561356 0.9901922
 0.34019652 1.3354644  1.550159   0.932637   1.4192448  1.0556989
 0.85567945 0.8104117  1.2469792  1.299995  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.47242653, 1.1312478 , 0.39314756, 0.7948002 , 0.63561356,
       0.9901922 , 0.34019652, 1.3354644 , 1.550159  , 0.932637  ,
       1.4192448 , 1.0556989 , 0.85567945, 0.8104117 , 1.2469792 ,
       1.299995  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196d250>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a7f6f0; to 'JaxprTracer' at 0x7c31a1a7fba0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   1%|          | 1/85 [00:12<16:52, 12.06s/it, loss=0.954]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.55466515 0.4072841  0.6829692  0.5432238  1.0124465  0.2603053
 0.89486057 1.0057849  1.0634753  0.7047622  0.7201892  0.73057795
 0.54509234 1.0175405  1.419709   0.87405515], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.55466515, 0.4072841 , 0.6829692 , 0.5432238 , 1.0124465 ,
       0.2603053 , 0.89486057, 1.0057849 , 1.0634753 , 0.7047622 ,
       0.7201892 , 0.73057795, 0.54509234, 1.0175405 , 1.419709  ,
       0.87405515], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a15196e0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a2f2e0; to 'JaxprTracer' at 0x7c31a1a2cdb0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   2%|▏         | 2/85 [00:14<08:56,  6.47s/it, loss=0.777]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.45930505 0.78970647 0.39886853 1.1944317  1.390727   0.5903328
 0.9832037  0.7030411  0.2466994  0.97599244 0.9611197  0.6429114
 0.7046882  0.33156762 1.1580653  0.9939206 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.45930505, 0.78970647, 0.39886853, 1.1944317 , 1.390727  ,
       0.5903328 , 0.9832037 , 0.7030411 , 0.2466994 , 0.97599244,
       0.9611197 , 0.6429114 , 0.7046882 , 0.33156762, 1.1580653 ,
       0.9939206 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a15cd760>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a55d00; to 'JaxprTracer' at 0x7c31a1a56070>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   4%|▎         | 3/85 [00:17<06:22,  4.67s/it, loss=0.783]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.42962253 0.821274   1.1155201  0.4406823  0.46224552 0.5167944
 0.5091909  0.98929626 0.8075218  0.59010524 0.41932032 0.45411408
 0.86086506 0.5577531  0.72886837 0.68552864], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.42962253, 0.821274  , 1.1155201 , 0.4406823 , 0.46224552,
       0.5167944 , 0.5091909 , 0.98929626, 0.8075218 , 0.59010524,
       0.41932032, 0.45411408, 0.86086506, 0.5577531 , 0.72886837,
       0.68552864], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a15917c0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c54860; to 'JaxprTracer' at 0x7c31a1c57b50>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   5%|▍         | 4/85 [00:20<05:22,  3.98s/it, loss=0.649]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.4979474  0.35635766 0.4364559  0.42432544 1.4658659  0.9331718
 1.2865661  0.50432634 1.1153369  0.22141229 0.9141086  1.0088587
 0.44393685 1.070999   0.47698286 0.7529982 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.4979474 , 0.35635766, 0.4364559 , 0.42432544, 1.4658659 ,
       0.9331718 , 1.2865661 , 0.50432634, 1.1153369 , 0.22141229,
       0.9141086 , 1.0088587 , 0.44393685, 1.070999  , 0.47698286,
       0.7529982 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a15cd820>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a7e750; to 'JaxprTracer' at 0x7c31a1a7c220>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   6%|▌         | 5/85 [00:22<04:45,  3.57s/it, loss=0.744]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.61458    0.54176027 0.54471517 0.65598124 0.29671925 1.0271785
 0.6530111  0.7474393  0.58059156 0.35513362 0.7088153  0.9410563
 1.2638801  0.3968453  1.0239205  0.7783472 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.61458   , 0.54176027, 0.54471517, 0.65598124, 0.29671925,
       1.0271785 , 0.6530111 , 0.7474393 , 0.58059156, 0.35513362,
       0.7088153 , 0.9410563 , 1.2638801 , 0.3968453 , 1.0239205 ,
       0.7783472 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a15f1850>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ba12b0; to 'JaxprTracer' at 0x7c31a1ba3a10>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   7%|▋         | 6/85 [00:25<04:22,  3.32s/it, loss=0.696]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.31209126 0.91090524 0.26412836 0.6508415  0.29579133 0.27151465
 0.44978625 0.8045969  0.58797204 0.7076105  1.5888437  0.24989492
 0.7692574  0.5758474  0.41741952 0.6578537 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.31209126, 0.91090524, 0.26412836, 0.6508415 , 0.29579133,
       0.27151465, 0.44978625, 0.8045969 , 0.58797204, 0.7076105 ,
       1.5888437 , 0.24989492, 0.7692574 , 0.5758474 , 0.41741952,
       0.6578537 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a15dd930>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a22700; to 'JaxprTracer' at 0x7c31a1a232e0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   8%|▊         | 7/85 [00:28<03:59,  3.06s/it, loss=0.595]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.8895738  0.9112375  0.30778638 0.4137913  0.27210274 0.6170012
 0.41523507 0.533598   0.25444636 0.7339643  0.36626473 0.40977746
 1.5162944  0.49619067 0.6938741  0.80361885], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.8895738 , 0.9112375 , 0.30778638, 0.4137913 , 0.27210274,
       0.6170012 , 0.41523507, 0.533598  , 0.25444636, 0.7339643 ,
       0.36626473, 0.40977746, 1.5162944 , 0.49619067, 0.6938741 ,
       0.80361885], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a14dd970>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ac22a0; to 'JaxprTracer' at 0x7c31a1ac0900>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   9%|▉         | 8/85 [00:31<04:03,  3.16s/it, loss=0.602]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.613163   0.82213384 1.2538716  0.8929019  0.29952437 1.5731236
 0.9351357  0.82941604 0.8062358  0.44458446 0.2824004  1.2012172
 1.1260085  0.22182958 0.431992   0.6979302 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.613163  , 0.82213384, 1.2538716 , 0.8929019 , 0.29952437,
       1.5731236 , 0.9351357 , 0.82941604, 0.8062358 , 0.44458446,
       0.2824004 , 1.2012172 , 1.1260085 , 0.22182958, 0.431992  ,
       0.6979302 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a14ed9b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a164fdd0; to 'JaxprTracer' at 0x7c31a164c4f0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  11%|█         | 9/85 [00:34<04:00,  3.16s/it, loss=0.777]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.44169313 0.5605985  0.2609786  0.61099833 1.2854424  0.5569312
 0.66078836 0.648501   0.491046   0.9956659  0.5059937  0.5142631
 0.71993923 0.5253342  0.5365274  1.3732965 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.44169313, 0.5605985 , 0.2609786 , 0.61099833, 1.2854424 ,
       0.5569312 , 0.66078836, 0.648501  , 0.491046  , 0.9956659 ,
       0.5059937 , 0.5142631 , 0.71993923, 0.5253342 , 0.5365274 ,
       1.3732965 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a14e1a10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a165d8a0; to 'JaxprTracer' at 0x7c31a165c0e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  12%|█▏        | 10/85 [00:37<03:56,  3.16s/it, loss=0.668]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7290268  1.3083252  0.512354   0.81252456 0.36897126 1.8305178
 1.3934124  0.8160329  0.6148239  1.0692539  1.383653   0.66884184
 1.3310932  1.0331653  0.28000185 0.4159424 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7290268 , 1.3083252 , 0.512354  , 0.81252456, 0.36897126,
       1.8305178 , 1.3934124 , 0.8160329 , 0.6148239 , 1.0692539 ,
       1.383653  , 0.66884184, 1.3310932 , 1.0331653 , 0.28000185,
       0.4159424 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1479a20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a48310; to 'JaxprTracer' at 0x7c31a1a480e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  13%|█▎        | 11/85 [00:42<04:15,  3.45s/it, loss=0.91]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5370989  0.46482015 0.4594398  0.47198385 1.1427051  0.37951913
 1.2883222  0.4730027  0.39024535 1.2285382  0.32785314 0.93347573
 0.42749608 0.5590021  0.4296818  0.5357875 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5370989 , 0.46482015, 0.4594398 , 0.47198385, 1.1427051 ,
       0.37951913, 1.2883222 , 0.4730027 , 0.39024535, 1.2285382 ,
       0.32785314, 0.93347573, 0.42749608, 0.5590021 , 0.4296818 ,
       0.5357875 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a14e1a50>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a17e02c0; to 'JaxprTracer' at 0x7c31a17e0f40>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  14%|█▍        | 12/85 [00:44<03:55,  3.22s/it, loss=0.628]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.33433455 0.32315072 0.5616638  1.3960114  0.5615688  0.9625906
 1.0893615  0.2081591  0.415745   1.3692154  1.3257483  0.4600299
 0.526405   0.3169423  0.21566941 1.1931751 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.33433455, 0.32315072, 0.5616638 , 1.3960114 , 0.5615688 ,
       0.9625906 , 1.0893615 , 0.2081591 , 0.415745  , 1.3692154 ,
       1.3257483 , 0.4600299 , 0.526405  , 0.3169423 , 0.21566941,
       1.1931751 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1479a60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c97b50; to 'JaxprTracer' at 0x7c31a1c96c00>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  15%|█▌        | 13/85 [00:47<03:48,  3.17s/it, loss=0.704]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.71910214 0.42794532 0.49575996 1.2321346  0.4774418  1.1493933
 0.9631913  1.0499656  1.2175045  0.909097   1.0107353  0.62780046
 0.38330093 0.21404947 0.35579565 0.32604882], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.71910214, 0.42794532, 0.49575996, 1.2321346 , 0.4774418 ,
       1.1493933 , 0.9631913 , 1.0499656 , 1.2175045 , 0.909097  ,
       1.0107353 , 0.62780046, 0.38330093, 0.21404947, 0.35579565,
       0.32604882], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a13b9aa0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1604310; to 'JaxprTracer' at 0x7c31a16063e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  16%|█▋        | 14/85 [00:50<03:30,  2.97s/it, loss=0.722]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.26421174 0.7407701  0.8407533  1.0760262  0.60210073 0.66967195
 0.6945776  0.9565179  0.5670149  0.51446086 1.3814331  0.6431542
 0.7197126  0.5738841  0.3351499  0.67827183], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.26421174, 0.7407701 , 0.8407533 , 1.0760262 , 0.60210073,
       0.66967195, 0.6945776 , 0.9565179 , 0.5670149 , 0.51446086,
       1.3814331 , 0.6431542 , 0.7197126 , 0.5738841 , 0.3351499 ,
       0.67827183], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a13e5ab0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1605f80; to 'JaxprTracer' at 0x7c31a16067f0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  18%|█▊        | 15/85 [00:53<03:26,  2.94s/it, loss=0.704]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7318731  0.57889676 0.8883345  1.000871   0.4128353  1.0245141
 0.70557886 0.6498471  0.385611   0.5141478  0.90968114 0.54055285
 0.55628026 0.3389335  0.53648925 0.44415617], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7318731 , 0.57889676, 0.8883345 , 1.000871  , 0.4128353 ,
       1.0245141 , 0.70557886, 0.6498471 , 0.385611  , 0.5141478 ,
       0.90968114, 0.54055285, 0.55628026, 0.3389335 , 0.53648925,
       0.44415617], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a131daa0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1606b60; to 'JaxprTracer' at 0x7c31a1605210>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  19%|█▉        | 16/85 [00:55<03:15,  2.83s/it, loss=0.639]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.6199726  0.9572302  0.47844243 0.80910504 0.9915552  0.72516376
 0.5134243  0.49362284 0.5982016  0.64010054 0.7937503  1.7390394
 0.88768816 0.5262048  0.97242147 0.86834526], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.6199726 , 0.9572302 , 0.47844243, 0.80910504, 0.9915552 ,
       0.72516376, 0.5134243 , 0.49362284, 0.5982016 , 0.64010054,
       0.7937503 , 1.7390394 , 0.88768816, 0.5262048 , 0.97242147,
       0.86834526], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a13edac0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1611df0; to 'JaxprTracer' at 0x7c31a1613560>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  20%|██        | 17/85 [00:58<03:12,  2.83s/it, loss=0.788]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.52899927 0.40739664 0.6871669  0.98375106 0.739133   0.6231907
 0.76451725 0.6888435  0.7759678  0.90942466 0.7146334  0.64504963
 0.7800666  0.9998024  0.7269424  0.6920889 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.52899927, 0.40739664, 0.6871669 , 0.98375106, 0.739133  ,
       0.6231907 , 0.76451725, 0.6888435 , 0.7759678 , 0.90942466,
       0.7146334 , 0.64504963, 0.7800666 , 0.9998024 , 0.7269424 ,
       0.6920889 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1341af0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1559ee0; to 'JaxprTracer' at 0x7c31a1558e50>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  21%|██        | 18/85 [01:01<03:12,  2.88s/it, loss=0.729]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2426562  0.40270057 0.8497147  0.3824104  0.55963624 1.3444066
 0.44112626 0.4778235  0.6807986  0.63644767 0.6235806  0.48506403
 0.55194724 0.7558608  0.8810848  0.6009147 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2426562 , 0.40270057, 0.8497147 , 0.3824104 , 0.55963624,
       1.3444066 , 0.44112626, 0.4778235 , 0.6807986 , 0.63644767,
       0.6235806 , 0.48506403, 0.55194724, 0.7558608 , 0.8810848 ,
       0.6009147 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a126db10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cee3e0; to 'JaxprTracer' at 0x7c31a1ceee80>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  22%|██▏       | 19/85 [01:04<03:03,  2.78s/it, loss=0.682]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.8044878  0.57611394 0.7825475  0.48209783 0.35530937 0.6487098
 0.7206357  0.58918655 0.6278315  0.59747267 0.37330884 0.556502
 1.2251649  0.6868497  1.0109537  0.9355315 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.8044878 , 0.57611394, 0.7825475 , 0.48209783, 0.35530937,
       0.6487098 , 0.7206357 , 0.58918655, 0.6278315 , 0.59747267,
       0.37330884, 0.556502  , 1.2251649 , 0.6868497 , 1.0109537 ,
       0.9355315 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a12c1b10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13ffb00; to 'JaxprTracer' at 0x7c31a13ffe20>], out_avals=[ShapedArray(float32[16])], primitive=pjit, 

Training:  24%|██▎       | 20/85 [01:07<03:03,  2.82s/it, loss=0.686]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2534759  0.65213954 0.6053695  0.8927909  0.49925685 0.65180975
 0.22900233 0.82548064 0.6967556  0.64021164 0.5799304  0.7889499
 0.22566894 0.33312646 0.85054475 0.7474554 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2534759 , 0.65213954, 0.6053695 , 0.8927909 , 0.49925685,
       0.65180975, 0.22900233, 0.82548064, 0.6967556 , 0.64021164,
       0.5799304 , 0.7889499 , 0.22566894, 0.33312646, 0.85054475,
       0.7474554 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a13b9b20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a148e7a0; to 'JaxprTracer' at 0x7c31a148f790>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  25%|██▍       | 21/85 [01:09<02:54,  2.73s/it, loss=0.654]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.95906174 0.63019204 0.8019825  0.6318706  0.31456003 0.3222444
 0.81537354 0.7863827  0.627741   0.83191544 1.4179331  0.43774655
 0.8377543  0.31236866 0.3912381  0.19470634], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.95906174, 0.63019204, 0.8019825 , 0.6318706 , 0.31456003,
       0.3222444 , 0.81537354, 0.7863827 , 0.627741  , 0.83191544,
       1.4179331 , 0.43774655, 0.8377543 , 0.31236866, 0.3912381 ,
       0.19470634], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1325bd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ac1bc0; to 'JaxprTracer' at 0x7c31a1ac3560>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  26%|██▌       | 22/85 [01:12<03:00,  2.87s/it, loss=0.645]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.39441654 0.5019969  0.53619987 0.63573635 0.44033644 0.46229386
 0.61194056 0.5384848  0.52599883 0.6028942  0.58889705 0.82397807
 0.8105687  0.6940143  0.656296   0.893943  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.39441654, 0.5019969 , 0.53619987, 0.63573635, 0.44033644,
       0.46229386, 0.61194056, 0.5384848 , 0.52599883, 0.6028942 ,
       0.58889705, 0.82397807, 0.8105687 , 0.6940143 , 0.656296  ,
       0.893943  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a12edbd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a140ec50; to 'JaxprTracer' at 0x7c31a140e340>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  27%|██▋       | 23/85 [01:15<02:53,  2.80s/it, loss=0.607]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0081854  0.73798317 0.26485133 0.8235117  0.6435855  0.30076322
 1.5312937  0.52067053 0.5732278  1.2118926  0.2925281  1.0884461
 0.9772312  0.49459323 1.1320925  1.1094968 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0081854 , 0.73798317, 0.26485133, 0.8235117 , 0.6435855 ,
       0.30076322, 1.5312937 , 0.52067053, 0.5732278 , 1.2118926 ,
       0.2925281 , 1.0884461 , 0.9772312 , 0.49459323, 1.1320925 ,
       1.1094968 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a12c5be0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16c6070; to 'JaxprTracer' at 0x7c31a16c49a0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  28%|██▊       | 24/85 [01:17<02:46,  2.72s/it, loss=0.794]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.09898473 0.48287353 1.2457452  0.18338914 0.46939647 0.37363666
 0.37545046 0.6516408  0.3105645  0.3960863  1.3340391  1.04151
 0.21630277 0.28604013 1.1694651  0.40820253], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.09898473, 0.48287353, 1.2457452 , 0.18338914, 0.46939647,
       0.37363666, 0.37545046, 0.6516408 , 0.3105645 , 0.3960863 ,
       1.3340391 , 1.04151   , 0.21630277, 0.28604013, 1.1694651 ,
       0.40820253], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119dc00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15d0860; to 'JaxprTracer' at 0x7c31a15d15d0>], out_avals=[ShapedArray(float32[16])], primitive=pjit, 

Training:  29%|██▉       | 25/85 [01:20<02:47,  2.80s/it, loss=0.565]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2212087  0.81209815 0.44319394 0.5959766  0.6059426  0.63742006
 0.6453938  0.26701903 0.43969676 0.91583467 0.62831515 0.9438747
 0.305541   1.3789614  0.4586373  0.4209078 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2212087 , 0.81209815, 0.44319394, 0.5959766 , 0.6059426 ,
       0.63742006, 0.6453938 , 0.26701903, 0.43969676, 0.91583467,
       0.62831515, 0.9438747 , 0.305541  , 1.3789614 , 0.4586373 ,
       0.4209078 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c850>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b7b010; to 'JaxprTracer' at 0x7c31a1b7ae80>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  31%|███       | 26/85 [01:23<02:42,  2.75s/it, loss=0.67]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5743835  0.7725183  1.062999   0.92535394 0.39280298 0.30044973
 0.8062916  0.22817828 0.53402096 0.5327919  0.32913193 0.7231678
 1.1867969  1.1684375  0.52929664 0.5384822 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5743835 , 0.7725183 , 1.062999  , 0.92535394, 0.39280298,
       0.30044973, 0.8062916 , 0.22817828, 0.53402096, 0.5327919 ,
       0.32913193, 0.7231678 , 1.1867969 , 1.1684375 , 0.52929664,
       0.5384822 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196df70>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a165f8d0; to 'JaxprTracer' at 0x7c31a16f92b0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  32%|███▏      | 27/85 [01:26<02:46,  2.87s/it, loss=0.663]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.32829973 1.4038484  0.5870446  0.7241082  1.7359926  0.6100152
 0.44322294 0.87245    1.2705818  0.7795093  0.9064827  0.82529783
 1.4401991  1.6439712  0.66155404 0.636546  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.32829973, 1.4038484 , 0.5870446 , 0.7241082 , 1.7359926 ,
       0.6100152 , 0.44322294, 0.87245   , 1.2705818 , 0.7795093 ,
       0.9064827 , 0.82529783, 1.4401991 , 1.6439712 , 0.66155404,
       0.636546  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87c00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1640d10; to 'JaxprTracer' at 0x7c31a1643290>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  33%|███▎      | 28/85 [01:29<02:37,  2.77s/it, loss=0.929]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.7454623  0.23321468 1.0742977  1.369025   0.87367207 0.45983416
 0.8762523  1.1635783  0.73500973 0.15921621 0.8665545  0.5027635
 0.92464054 0.26670647 1.0251188  0.3314603 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.7454623 , 0.23321468, 1.0742977 , 1.369025  , 0.87367207,
       0.45983416, 0.8762523 , 1.1635783 , 0.73500973, 0.15921621,
       0.8665545 , 0.5027635 , 0.92464054, 0.26670647, 1.0251188 ,
       0.3314603 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b849b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13fc900; to 'JaxprTracer' at 0x7c31a13fcdb0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  34%|███▍      | 29/85 [01:31<02:31,  2.70s/it, loss=0.788]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.12405703 0.54154885 0.23092058 0.33587644 0.743211   0.5871846
 0.8577552  0.18002379 1.073457   0.0758558  0.08242435 0.677968
 1.1096144  1.1364555  0.50373876 0.18914899], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.12405703, 0.54154885, 0.23092058, 0.33587644, 0.743211  ,
       0.5871846 , 0.8577552 , 0.18002379, 1.073457  , 0.0758558 ,
       0.08242435, 0.677968  , 1.1096144 , 1.1364555 , 0.50373876,
       0.18914899], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655580>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1558fe0; to 'JaxprTracer' at 0x7c31a1558310>], out_avals=[ShapedArray(float32[16])], primitive=pjit, 

Training:  35%|███▌      | 30/85 [01:34<02:32,  2.77s/it, loss=0.528]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.1248717  1.4002957  0.51664823 0.5468997  0.74791527 0.16424866
 0.67348045 0.6098717  0.5421462  0.3015257  0.3632048  1.3936603
 0.6959722  0.82195485 0.49112922 0.5186732 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.1248717 , 1.4002957 , 0.51664823, 0.5468997 , 0.74791527,
       0.16424866, 0.67348045, 0.6098717 , 0.5421462 , 0.3015257 ,
       0.3632048 , 1.3936603 , 0.6959722 , 0.82195485, 0.49112922,
       0.5186732 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84090>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1f2db20; to 'JaxprTracer' at 0x7c31a1f2c900>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  36%|███▋      | 31/85 [01:37<02:32,  2.83s/it, loss=0.682]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.6530176  0.51850283 0.63708216 0.7386611  0.27607897 0.13748929
 0.26564634 0.29951456 0.42657942 0.24736768 0.55959344 0.38185388
 0.1780753  0.48976052 0.5753471  0.12518512], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.6530176 , 0.51850283, 0.63708216, 0.7386611 , 0.27607897,
       0.13748929, 0.26564634, 0.29951456, 0.42657942, 0.24736768,
       0.55959344, 0.38185388, 0.1780753 , 0.48976052, 0.5753471 ,
       0.12518512], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119f280>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16043b0; to 'JaxprTracer' at 0x7c31a1606980>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  38%|███▊      | 32/85 [01:40<02:30,  2.84s/it, loss=0.407]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.4926789  0.11139596 0.4551376  0.29305467 0.925881   0.0679843
 0.23020966 0.3680719  0.2137635  0.59506404 0.43245035 0.45022047
 0.82426894 0.65068626 1.7388176  0.7193954 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.4926789 , 0.11139596, 0.4551376 , 0.29305467, 0.925881  ,
       0.0679843 , 0.23020966, 0.3680719 , 0.2137635 , 0.59506404,
       0.43245035, 0.45022047, 0.82426894, 0.65068626, 1.7388176 ,
       0.7193954 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196d4d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b17a60; to 'JaxprTracer' at 0x7c31a1605ee0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  39%|███▉      | 33/85 [01:43<02:22,  2.74s/it, loss=0.536]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.8347644  0.28803957 0.17120889 0.92496467 0.9193655  1.1739627
 0.3026142  1.3022655  1.593936   0.39260232 0.1005541  0.20027404
 0.7344167  0.30979994 0.9840655  0.07758772], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.8347644 , 0.28803957, 0.17120889, 0.92496467, 0.9193655 ,
       1.1739627 , 0.3026142 , 1.3022655 , 1.593936  , 0.39260232,
       0.1005541 , 0.20027404, 0.7344167 , 0.30979994, 0.9840655 ,
       0.07758772], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196ddf0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cee610; to 'JaxprTracer' at 0x7c31a1ceff60>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  40%|████      | 34/85 [01:45<02:16,  2.68s/it, loss=0.644]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.22543575 0.35927388 0.48331374 1.038642   0.68701524 0.33055842
 0.18351239 0.41652524 0.31264818 0.1329318  0.6605324  0.9706614
 0.53489494 0.88046217 0.18280925 1.0144155 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.22543575, 0.35927388, 0.48331374, 1.038642  , 0.68701524,
       0.33055842, 0.18351239, 0.41652524, 0.31264818, 0.1329318 ,
       0.6605324 , 0.9706614 , 0.53489494, 0.88046217, 0.18280925,
       1.0144155 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16577b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ec2b10; to 'JaxprTracer' at 0x7c31a1344a90>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  41%|████      | 35/85 [01:48<02:17,  2.74s/it, loss=0.526]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.140113   0.45327926 0.82375294 0.43622306 0.2948724  0.6008152
 0.6167896  0.8111936  0.80500925 0.3116968  0.6392202  0.8652714
 0.19927946 0.12966698 0.5100609  0.3676291 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.140113  , 0.45327926, 0.82375294, 0.43622306, 0.2948724 ,
       0.6008152 , 0.6167896 , 0.8111936 , 0.80500925, 0.3116968 ,
       0.6392202 , 0.8652714 , 0.19927946, 0.12966698, 0.5100609 ,
       0.3676291 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655270>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1102c00; to 'JaxprTracer' at 0x7c31a1100d60>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  42%|████▏     | 36/85 [01:51<02:22,  2.92s/it, loss=0.563]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0013869  0.36515528 0.5305556  0.5810108  0.6513135  1.8085515
 0.20687442 0.4179218  0.6060711  0.43300036 0.4789772  0.77317923
 0.24916327 0.6392818  0.5194058  1.1423969 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0013869 , 0.36515528, 0.5305556 , 0.5810108 , 0.6513135 ,
       1.8085515 , 0.20687442, 0.4179218 , 0.6060711 , 0.43300036,
       0.4789772 , 0.77317923, 0.24916327, 0.6392818 , 0.5194058 ,
       1.1423969 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e1d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15453a0; to 'JaxprTracer' at 0x7c31a1187b00>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  44%|████▎     | 37/85 [01:54<02:15,  2.82s/it, loss=0.65]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5778142  0.16019627 0.14538726 0.32772222 1.0574983  0.53602356
 0.34384346 0.17729345 0.33262697 0.8300384  0.34766465 1.0064782
 0.41022706 0.35020736 0.5971281  0.27183506], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5778142 , 0.16019627, 0.14538726, 0.32772222, 1.0574983 ,
       0.53602356, 0.34384346, 0.17729345, 0.33262697, 0.8300384 ,
       0.34766465, 1.0064782 , 0.41022706, 0.35020736, 0.5971281 ,
       0.27183506], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84f00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1416070; to 'JaxprTracer' at 0x7c31a15d19e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  45%|████▍     | 38/85 [01:56<02:08,  2.73s/it, loss=0.467]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.16179006 0.50058436 0.38880405 1.0405787  0.25434366 0.49855405
 0.8108884  0.5442524  1.1616889  1.8174261  0.6943771  0.2693634
 0.38663992 1.0159478  0.44369486 1.0074719 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.16179006, 0.50058436, 0.38880405, 1.0405787 , 0.25434366,
       0.49855405, 0.8108884 , 0.5442524 , 1.1616889 , 1.8174261 ,
       0.6943771 , 0.2693634 , 0.38663992, 1.0159478 , 0.44369486,
       1.0074719 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f1c0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16c5620; to 'JaxprTracer' at 0x7c31a16c69d0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  46%|████▌     | 39/85 [01:59<02:02,  2.67s/it, loss=0.687]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0841311  0.15628219 0.8958541  0.875621   1.8712704  0.5810617
 0.45759588 1.1270735  0.18994382 0.47221208 0.40113553 0.88688195
 1.4715765  0.3064809  0.39968553 2.308285  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0841311 , 0.15628219, 0.8958541 , 0.875621  , 1.8712704 ,
       0.5810617 , 0.45759588, 1.1270735 , 0.18994382, 0.47221208,
       0.40113553, 0.88688195, 1.4715765 , 0.3064809 , 0.39968553,
       2.308285  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119ff50>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b7b7e0; to 'JaxprTracer' at 0x7c31a1b78040>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  47%|████▋     | 40/85 [02:02<02:06,  2.81s/it, loss=0.843]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.6761889  0.23185205 0.42027184 0.5203802  0.47316328 0.18600978
 0.81985164 0.23708947 0.74674195 0.270506   0.92474663 1.438096
 0.3944265  1.2205613  1.2015986  0.5681372 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.6761889 , 0.23185205, 0.42027184, 0.5203802 , 0.47316328,
       0.18600978, 0.81985164, 0.23708947, 0.74674195, 0.270506  ,
       0.92474663, 1.438096  , 0.3944265 , 1.2205613 , 1.2015986 ,
       0.5681372 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a658b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16f8720; to 'JaxprTracer' at 0x7c31a1ac3790>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  48%|████▊     | 41/85 [02:05<02:07,  2.90s/it, loss=0.646]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0247947  0.27016667 0.6270883  1.4897847  0.7356736  0.12924448
 0.5296502  0.15509821 0.9026913  0.6801672  1.6947261  0.54857504
 0.61383545 0.42162248 0.42203417 0.6888853 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0247947 , 0.27016667, 0.6270883 , 1.4897847 , 0.7356736 ,
       0.12924448, 0.5296502 , 0.15509821, 0.9026913 , 0.6801672 ,
       1.6947261 , 0.54857504, 0.61383545, 0.42162248, 0.42203417,
       0.6888853 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85e10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16f9b70; to 'JaxprTracer' at 0x7c31a16c6020>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  49%|████▉     | 42/85 [02:08<01:59,  2.78s/it, loss=0.683]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.79849225 0.83402044 0.96681136 0.46807635 0.315056   0.4975154
 0.09245225 0.53035825 0.5250609  0.5619643  0.4351394  0.88188946
 0.11997683 0.8047503  0.83298695 0.43292376], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.79849225, 0.83402044, 0.96681136, 0.46807635, 0.315056  ,
       0.4975154 , 0.09245225, 0.53035825, 0.5250609 , 0.5619643 ,
       0.4351394 , 0.88188946, 0.11997683, 0.8047503 , 0.83298695,
       0.43292376], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65aa0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1aad030; to 'JaxprTracer' at 0x7c31a1aace50>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  51%|█████     | 43/85 [02:10<01:52,  2.68s/it, loss=0.569]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.49817842 1.782894   0.217482   0.4278331  0.36907947 0.9569348
 0.4983311  1.2842734  0.38631618 1.7631999  0.36503273 0.18178104
 1.024834   0.45068067 0.6091711  0.4975241 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.49817842, 1.782894  , 0.217482  , 0.4278331 , 0.36907947,
       0.9569348 , 0.4983311 , 1.2842734 , 0.38631618, 1.7631999 ,
       0.36503273, 0.18178104, 1.024834  , 0.45068067, 0.6091711 ,
       0.4975241 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c450>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a165e5c0; to 'JaxprTracer' at 0x7c31a165f3d0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  52%|█████▏    | 44/85 [02:13<01:47,  2.62s/it, loss=0.707]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.45774978 0.95147234 0.5113357  0.5224138  0.47751543 0.4325735
 0.24671271 1.3092852  0.24919736 0.30625987 1.0223649  0.34484515
 1.5881485  0.50020254 1.2772902  2.1986482 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.45774978, 0.95147234, 0.5113357 , 0.5224138 , 0.47751543,
       0.4325735 , 0.24671271, 1.3092852 , 0.24919736, 0.30625987,
       1.0223649 , 0.34484515, 1.5881485 , 0.50020254, 1.2772902 ,
       2.1986482 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64d40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cba2a0; to 'JaxprTracer' at 0x7c31a1cd8090>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  53%|█████▎    | 45/85 [02:15<01:46,  2.67s/it, loss=0.775]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.80724704 0.28894898 1.469613   0.57589114 0.982302   1.4103787
 0.28630573 0.4193562  0.41333604 0.17727318 0.41528457 0.41934067
 0.8383641  0.22216709 1.5345507  0.37667927], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.80724704, 0.28894898, 1.469613  , 0.57589114, 0.982302  ,
       1.4103787 , 0.28630573, 0.4193562 , 0.41333604, 0.17727318,
       0.41528457, 0.41934067, 0.8383641 , 0.22216709, 1.5345507 ,
       0.37667927], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16548f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31aa4ffa60; to 'JaxprTracer' at 0x7c31aa4fe160>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  54%|█████▍    | 46/85 [02:19<01:54,  2.93s/it, loss=0.665]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.28983173 0.63869053 0.83211493 0.30253115 0.4348485  1.0132046
 0.43324003 0.34176698 0.6266228  1.4553609  0.9767175  0.24571285
 1.2115278  0.6440515  0.41825864 0.38872176], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.28983173, 0.63869053, 0.83211493, 0.30253115, 0.4348485 ,
       1.0132046 , 0.43324003, 0.34176698, 0.6266228 , 1.4553609 ,
       0.9767175 , 0.24571285, 1.2115278 , 0.6440515 , 0.41825864,
       0.38872176], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a642b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b78f90; to 'JaxprTracer' at 0x7c31a1b7bc40>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  55%|█████▌    | 47/85 [02:22<01:47,  2.82s/it, loss=0.641]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0068486  0.519473   0.5191931  0.8340979  0.09658991 0.6717372
 1.309599   0.9333716  0.24016438 0.32457384 0.9603796  0.2465293
 0.24258277 1.5227302  0.38168803 0.47584784], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0068486 , 0.519473  , 0.5191931 , 0.8340979 , 0.09658991,
       0.6717372 , 1.309599  , 0.9333716 , 0.24016438, 0.32457384,
       0.9603796 , 0.2465293 , 0.24258277, 1.5227302 , 0.38168803,
       0.47584784], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65ec0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31aa381f30; to 'JaxprTracer' at 0x7c31aa383380>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  56%|█████▋    | 48/85 [02:24<01:41,  2.74s/it, loss=0.643]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.520073   0.27863425 1.6094831  0.80608463 0.2973879  0.27028158
 0.152395   1.1476305  0.90939146 0.22256088 1.4531934  0.39321828
 0.2575171  0.7249253  0.49120814 0.4263362 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.520073  , 0.27863425, 1.6094831 , 0.80608463, 0.2973879 ,
       0.27028158, 0.152395  , 1.1476305 , 0.90939146, 0.22256088,
       1.4531934 , 0.39321828, 0.2575171 , 0.7249253 , 0.49120814,
       0.4263362 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119d290>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b79ee0; to 'JaxprTracer' at 0x7c31a1b79350>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  58%|█████▊    | 49/85 [02:27<01:36,  2.67s/it, loss=0.685]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.1471125  0.93712187 0.26035637 0.14880404 0.3811892  0.2916222
 0.6849089  0.7979274  1.5309516  1.3280789  1.149884   0.16864116
 0.288257   0.2918455  0.4501523  0.5520372 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.1471125 , 0.93712187, 0.26035637, 0.14880404, 0.3811892 ,
       0.2916222 , 0.6849089 , 0.7979274 , 1.5309516 , 1.3280789 ,
       1.149884  , 0.16864116, 0.288257  , 0.2918455 , 0.4501523 ,
       0.5520372 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64e60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16062f0; to 'JaxprTracer' at 0x7c31a16047c0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  59%|█████▉    | 50/85 [02:30<01:37,  2.77s/it, loss=0.651]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.35463896 1.3028123  1.2388146  0.19632927 0.91586417 0.25928646
 0.34911793 0.9634309  0.7087488  0.21542753 0.5523803  0.513585
 0.06966352 0.4715745  0.8559324  0.6745105 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.35463896, 1.3028123 , 1.2388146 , 0.19632927, 0.91586417,
       0.25928646, 0.34911793, 0.9634309 , 0.7087488 , 0.21542753,
       0.5523803 , 0.513585  , 0.06966352, 0.4715745 , 0.8559324 ,
       0.6745105 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65690>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa2660; to 'JaxprTracer' at 0x7c31a0fa1e40>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  60%|██████    | 51/85 [02:33<01:39,  2.91s/it, loss=0.603]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.72336966 0.4185343  0.160185   0.07267417 0.8667725  0.9450956
 0.8140323  0.46387416 0.641774   0.34274885 0.74592024 0.16662385
 0.20543647 0.8137065  0.80785054 0.3966445 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.72336966, 0.4185343 , 0.160185  , 0.07267417, 0.8667725 ,
       0.9450956 , 0.8140323 , 0.46387416, 0.641774  , 0.34274885,
       0.74592024, 0.16662385, 0.20543647, 0.8137065 , 0.80785054,
       0.3966445 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86400>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b7a890; to 'JaxprTracer' at 0x7c31a1b7a700>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  61%|██████    | 52/85 [02:35<01:32,  2.80s/it, loss=0.537]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.43397963 0.5203126  0.60327226 0.6258647  0.7238706  0.3894805
 0.09245312 0.7440589  0.946089   0.5293299  0.5056696  0.17601316
 0.877754   0.21932559 0.19550195 0.5282591 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.43397963, 0.5203126 , 0.60327226, 0.6258647 , 0.7238706 ,
       0.3894805 , 0.09245312, 0.7440589 , 0.946089  , 0.5293299 ,
       0.5056696 , 0.17601316, 0.877754  , 0.21932559, 0.19550195,
       0.5282591 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84290>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ad4f40; to 'JaxprTracer' at 0x7c31a1ad7420>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  62%|██████▏   | 53/85 [02:38<01:27,  2.72s/it, loss=0.507]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.37990656 1.0859191  0.3733712  1.2329826  1.1629522  0.44181705
 0.4897448  0.62228024 0.15954006 0.33447936 0.4315465  0.3483803
 0.24900055 1.389915   1.2463028  0.66033137], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.37990656, 1.0859191 , 0.3733712 , 1.2329826 , 1.1629522 ,
       0.44181705, 0.4897448 , 0.62228024, 0.15954006, 0.33447936,
       0.4315465 , 0.3483803 , 0.24900055, 1.389915  , 1.2463028 ,
       0.66033137], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84d10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16c75b0; to 'JaxprTracer' at 0x7c31a16c7830>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  64%|██████▎   | 54/85 [02:41<01:24,  2.73s/it, loss=0.663]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.3057282  0.6956907  0.929901   0.5160417  0.7262619  0.5679192
 0.59388965 0.26684672 0.08425486 1.3917754  0.93077    0.50279164
 0.34018576 0.31721562 0.31348523 0.6775385 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.3057282 , 0.6956907 , 0.929901  , 0.5160417 , 0.7262619 ,
       0.5679192 , 0.59388965, 0.26684672, 0.08425486, 1.3917754 ,
       0.93077   , 0.50279164, 0.34018576, 0.31721562, 0.31348523,
       0.6775385 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654760>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a0abb0; to 'JaxprTracer' at 0x7c31a1a08e50>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  65%|██████▍   | 55/85 [02:44<01:31,  3.05s/it, loss=0.635]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.8765823  0.10914261 0.25587964 0.35583824 0.6885044  0.96805286
 0.7744225  0.39597479 0.5388792  0.47521478 0.48613513 1.1244206
 0.24586128 0.15320893 1.006485   0.7793596 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.8765823 , 0.10914261, 0.25587964, 0.35583824, 0.6885044 ,
       0.96805286, 0.7744225 , 0.39597479, 0.5388792 , 0.47521478,
       0.48613513, 1.1244206 , 0.24586128, 0.15320893, 1.006485  ,
       0.7793596 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654670>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1e3ad40; to 'JaxprTracer' at 0x7c31a1fdc090>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  66%|██████▌   | 56/85 [02:48<01:30,  3.13s/it, loss=0.577]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.33824396 0.8084351  0.5617208  0.46865097 1.152427   0.85189927
 0.29983014 0.30469203 1.7148886  0.5647626  0.76817685 0.5273803
 0.22777112 0.5967225  0.8797769  2.3193603 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.33824396, 0.8084351 , 0.5617208 , 0.46865097, 1.152427  ,
       0.85189927, 0.29983014, 0.30469203, 1.7148886 , 0.5647626 ,
       0.76817685, 0.5273803 , 0.22777112, 0.5967225 , 0.8797769 ,
       2.3193603 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655960>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1f2c7c0; to 'JaxprTracer' at 0x7c31a1f2c090>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  67%|██████▋   | 57/85 [02:50<01:22,  2.96s/it, loss=0.774]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.10096971 1.0093288  0.11777943 0.38211602 0.19418627 0.11259528
 0.977056   0.49173102 1.126318   0.17829636 0.06088601 2.2664485
 0.7667308  1.5708513  0.7358272  0.11268805], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.10096971, 1.0093288 , 0.11777943, 0.38211602, 0.19418627,
       0.11259528, 0.977056  , 0.49173102, 1.126318  , 0.17829636,
       0.06088601, 2.2664485 , 0.7667308 , 1.5708513 , 0.7358272 ,
       0.11268805], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196db90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1692d40; to 'JaxprTracer' at 0x7c31a1692110>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  68%|██████▊   | 58/85 [02:53<01:16,  2.83s/it, loss=0.638]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.58710307 1.4625942  0.745653   0.17388995 1.91811    0.37022656
 0.12940192 1.1703373  1.33916    2.5367234  0.09603506 0.16291691
 1.0085149  0.35366914 0.749048   1.2911873 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.58710307, 1.4625942 , 0.745653  , 0.17388995, 1.91811   ,
       0.37022656, 0.12940192, 1.1703373 , 1.33916   , 2.5367234 ,
       0.09603506, 0.16291691, 1.0085149 , 0.35366914, 0.749048  ,
       1.2911873 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87f20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fdcef0; to 'JaxprTracer' at 0x7c31a0fdd260>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  69%|██████▉   | 59/85 [02:55<01:11,  2.75s/it, loss=0.881]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.42211312 1.325207   0.26788157 0.2070673  0.45952266 0.06583492
 0.06820691 0.6048721  1.8492079  0.53532225 0.22794254 0.05149274
 0.31699294 0.30289719 1.0623839  1.165513  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.42211312, 1.325207  , 0.26788157, 0.2070673 , 0.45952266,
       0.06583492, 0.06820691, 0.6048721 , 1.8492079 , 0.53532225,
       0.22794254, 0.05149274, 0.31699294, 0.30289719, 1.0623839 ,
       1.165513  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85480>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cd96c0; to 'JaxprTracer' at 0x7c31a1cdb420>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  71%|███████   | 60/85 [02:59<01:12,  2.91s/it, loss=0.558]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05122832 0.88255674 0.86193776 0.6426369  0.74401104 0.62539357
 0.78387225 0.71817595 0.45415652 0.7119243  1.0211507  0.25322175
 0.42840913 0.3040445  2.6284912  0.33554238], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05122832, 0.88255674, 0.86193776, 0.6426369 , 0.74401104,
       0.62539357, 0.78387225, 0.71817595, 0.45415652, 0.7119243 ,
       1.0211507 , 0.25322175, 0.42840913, 0.3040445 , 2.6284912 ,
       0.33554238], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c390>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a22ed0; to 'JaxprTracer' at 0x7c31a1a227a0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  72%|███████▏  | 61/85 [03:02<01:10,  2.92s/it, loss=0.715]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.9187979  0.532197   1.0360585  0.7260697  0.28151605 1.0168738
 0.11379913 0.28403822 0.681503   0.08588327 0.39958337 1.4506032
 1.316062   0.66215545 0.9670234  0.15780047], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.9187979 , 0.532197  , 1.0360585 , 0.7260697 , 0.28151605,
       1.0168738 , 0.11379913, 0.28403822, 0.681503  , 0.08588327,
       0.39958337, 1.4506032 , 1.316062  , 0.66215545, 0.9670234 ,
       0.15780047], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87dd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e64e50; to 'JaxprTracer' at 0x7c31a0e67650>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  73%|███████▎  | 62/85 [03:04<01:04,  2.81s/it, loss=0.664]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.35082003 0.48851198 0.48963457 0.54051113 0.27301207 0.3158649
 0.39627102 0.47988015 0.54351    0.8484967  0.48152262 0.20626861
 0.85121095 0.7518201  0.96547675 0.5954315 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.35082003, 0.48851198, 0.48963457, 0.54051113, 0.27301207,
       0.3158649 , 0.39627102, 0.47988015, 0.54351   , 0.8484967 ,
       0.48152262, 0.20626861, 0.85121095, 0.7518201 , 0.96547675,
       0.5954315 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84be0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ced030; to 'JaxprTracer' at 0x7c31a1cec040>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  74%|███████▍  | 63/85 [03:07<00:59,  2.73s/it, loss=0.536]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.23375599 1.0663273  0.8640804  0.7482652  0.46033758 0.47073242
 0.92698675 0.50356865 0.1473745  0.5545229  0.5445404  0.38087744
 0.48234755 0.663122   0.13188702 0.08517105], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.23375599, 1.0663273 , 0.8640804 , 0.7482652 , 0.46033758,
       0.47073242, 0.92698675, 0.50356865, 0.1473745 , 0.5545229 ,
       0.5445404 , 0.38087744, 0.48234755, 0.663122  , 0.13188702,
       0.08517105], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a66490>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1563880; to 'JaxprTracer' at 0x7c31a1563330>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  75%|███████▌  | 64/85 [03:09<00:57,  2.72s/it, loss=0.516]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.50469995 0.8141257  0.17296262 1.3421851  0.29904228 0.24967131
 0.16537194 0.93581474 0.19566214 0.61449265 1.1126907  0.5582845
 0.3974023  0.8935098  0.20418881 0.6681504 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.50469995, 0.8141257 , 0.17296262, 1.3421851 , 0.29904228,
       0.24967131, 0.16537194, 0.93581474, 0.19566214, 0.61449265,
       1.1126907 , 0.5582845 , 0.3974023 , 0.8935098 , 0.20418881,
       0.6681504 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c170>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e03740; to 'JaxprTracer' at 0x7c31a0e031f0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  76%|███████▋  | 65/85 [03:13<00:56,  2.83s/it, loss=0.571]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.395959   0.32714903 0.91005874 1.9503094  0.08512212 0.23822783
 0.22783044 0.15671666 0.45134264 1.7628922  0.8272603  0.8597712
 0.35481238 1.2107682  0.3220614  1.4796327 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.395959  , 0.32714903, 0.91005874, 1.9503094 , 0.08512212,
       0.23822783, 0.22783044, 0.15671666, 0.45134264, 1.7628922 ,
       0.8272603 , 0.8597712 , 0.35481238, 1.2107682 , 0.3220614 ,
       1.4796327 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e1d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1e7f920; to 'JaxprTracer' at 0x7c31a1e7d940>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  78%|███████▊  | 66/85 [03:15<00:54,  2.86s/it, loss=0.722]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7764945  0.6023676  0.5389623  0.7442721  0.4485109  0.22512752
 0.83818483 0.4493463  0.18337366 0.23029006 0.35874364 0.78378576
 0.5644345  0.25024045 0.3859596  1.1972113 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7764945 , 0.6023676 , 0.5389623 , 0.7442721 , 0.4485109 ,
       0.22512752, 0.83818483, 0.4493463 , 0.18337366, 0.23029006,
       0.35874364, 0.78378576, 0.5644345 , 0.25024045, 0.3859596 ,
       1.1972113 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654240>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1022ac0; to 'JaxprTracer' at 0x7c31a1023d30>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  79%|███████▉  | 67/85 [03:18<00:49,  2.77s/it, loss=0.536]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.1780744  1.4448175  0.18131655 1.1682     0.18488052 0.37693632
 0.47526622 0.32114345 1.3069978  1.2653972  0.29873586 0.28065333
 0.39943883 1.1660199  0.36777773 0.09234813], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.1780744 , 1.4448175 , 0.18131655, 1.1682    , 0.18488052,
       0.37693632, 0.47526622, 0.32114345, 1.3069978 , 1.2653972 ,
       0.29873586, 0.28065333, 0.39943883, 1.1660199 , 0.36777773,
       0.09234813], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655750>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ac17b0; to 'JaxprTracer' at 0x7c31a1ac03b0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  80%|████████  | 68/85 [03:21<00:46,  2.71s/it, loss=0.594]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.20862105 0.37921348 0.06571537 0.61053056 1.2363462  1.5930398
 1.0182369  0.03900975 0.4470051  0.4699239  0.94479895 1.1062282
 1.3424108  0.23877572 0.19572164 0.21714415], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.20862105, 0.37921348, 0.06571537, 0.61053056, 1.2363462 ,
       1.5930398 , 1.0182369 , 0.03900975, 0.4470051 , 0.4699239 ,
       0.94479895, 1.1062282 , 1.3424108 , 0.23877572, 0.19572164,
       0.21714415], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87020>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1686a70; to 'JaxprTracer' at 0x7c31a1684680>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  81%|████████  | 69/85 [03:24<00:44,  2.78s/it, loss=0.632]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.66178864 1.4144039  0.8209974  0.24865478 0.09629189 0.7860491
 0.27774265 1.4548304  0.1761784  0.13335139 0.5536445  0.19369854
 0.5067309  0.26885507 1.3415961  0.93253136], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.66178864, 1.4144039 , 0.8209974 , 0.24865478, 0.09629189,
       0.7860491 , 0.27774265, 1.4548304 , 0.1761784 , 0.13335139,
       0.5536445 , 0.19369854, 0.5067309 , 0.26885507, 1.3415961 ,
       0.93253136], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16558d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31aa383a60; to 'JaxprTracer' at 0x7c31aa381d00>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  82%|████████▏ | 70/85 [03:26<00:42,  2.82s/it, loss=0.617]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08892211 0.29246944 0.1581501  0.3811101  0.12932524 2.3396664
 1.0801802  0.39645302 0.7611008  0.05331617 0.9495076  1.1770785
 1.566421   0.46233225 0.41188613 0.5591271 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08892211, 0.29246944, 0.1581501 , 0.3811101 , 0.12932524,
       2.3396664 , 1.0801802 , 0.39645302, 0.7611008 , 0.05331617,
       0.9495076 , 1.1770785 , 1.566421  , 0.46233225, 0.41188613,
       0.5591271 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fab0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e426b0; to 'JaxprTracer' at 0x7c31a0e41080>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  84%|████████▎ | 71/85 [03:29<00:39,  2.84s/it, loss=0.675]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.76331073 0.55133367 0.5312468  0.7928324  0.41375944 0.14439061
 0.0702892  0.17441265 0.16379368 0.15231413 0.33331552 0.7715725
 0.5801132  0.815647   0.10260332 0.38694698], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.76331073, 0.55133367, 0.5312468 , 0.7928324 , 0.41375944,
       0.14439061, 0.0702892 , 0.17441265, 0.16379368, 0.15231413,
       0.33331552, 0.7715725 , 0.5801132 , 0.815647  , 0.10260332,
       0.38694698], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e9d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1547060; to 'JaxprTracer' at 0x7c31a0e70b30>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  85%|████████▍ | 72/85 [03:32<00:35,  2.73s/it, loss=0.422]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08758013 1.5544156  0.5561029  1.11267    2.172971   0.31137493
 0.63148576 0.04664614 0.84045047 0.3647102  0.5629088  0.12782057
 0.5920714  0.48956376 0.3726823  0.44618043], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08758013, 1.5544156 , 0.5561029 , 1.11267   , 2.172971  ,
       0.31137493, 0.63148576, 0.04664614, 0.84045047, 0.3647102 ,
       0.5629088 , 0.12782057, 0.5920714 , 0.48956376, 0.3726823 ,
       0.44618043], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119ddd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31aa380cc0; to 'JaxprTracer' at 0x7c31aa383a60>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  86%|████████▌ | 73/85 [03:34<00:31,  2.66s/it, loss=0.642]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0448651  0.55923796 1.2922106  1.9047282  0.4804619  0.21091509
 1.2351898  0.29962853 2.1252415  0.27084693 0.47359633 0.8647891
 0.29078043 0.8704463  0.95287514 0.5180082 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0448651 , 0.55923796, 1.2922106 , 1.9047282 , 0.4804619 ,
       0.21091509, 1.2351898 , 0.29962853, 2.1252415 , 0.27084693,
       0.47359633, 0.8647891 , 0.29078043, 0.8704463 , 0.95287514,
       0.5180082 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119def0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a165fc40; to 'JaxprTracer' at 0x7c31a165e390>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  87%|████████▋ | 74/85 [03:38<00:31,  2.83s/it, loss=0.837]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7167245  1.4246162  0.61673373 0.7929058  0.47113818 0.4345554
 0.6309153  0.94669896 1.3322557  1.0807769  0.83841246 0.4687657
 0.75026685 0.9510647  0.10405003 0.19886906], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7167245 , 1.4246162 , 0.61673373, 0.7929058 , 0.47113818,
       0.4345554 , 0.6309153 , 0.94669896, 1.3322557 , 1.0807769 ,
       0.83841246, 0.4687657 , 0.75026685, 0.9510647 , 0.10405003,
       0.19886906], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c740>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c965c0; to 'JaxprTracer' at 0x7c31a1c94130>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  88%|████████▊ | 75/85 [03:40<00:27,  2.74s/it, loss=0.735]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5049554  0.29624307 0.21083523 0.6673014  0.92486405 0.5332255
 0.5596676  0.40963057 0.27145532 0.47267288 0.5593028  0.63446194
 0.403838   0.99630713 0.7045827  0.23149662], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5049554 , 0.29624307, 0.21083523, 0.6673014 , 0.92486405,
       0.5332255 , 0.5596676 , 0.40963057, 0.27145532, 0.47267288,
       0.5593028 , 0.63446194, 0.403838  , 0.99630713, 0.7045827 ,
       0.23149662], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65810>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c978d0; to 'JaxprTracer' at 0x7c31a1c96110>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  89%|████████▉ | 76/85 [03:43<00:25,  2.79s/it, loss=0.524]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.54185313 0.13271959 0.8664879  0.63635665 0.78708375 0.61801267
 0.85052454 0.8055617  1.6667459  0.91939616 0.17918605 0.94508827
 0.12694232 0.5914481  0.75511146 1.0364252 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.54185313, 0.13271959, 0.8664879 , 0.63635665, 0.78708375,
       0.61801267, 0.85052454, 0.8055617 , 1.6667459 , 0.91939616,
       0.17918605, 0.94508827, 0.12694232, 0.5914481 , 0.75511146,
       1.0364252 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a66340>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1613ba0; to 'JaxprTracer' at 0x7c31a1604180>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  91%|█████████ | 77/85 [03:45<00:21,  2.70s/it, loss=0.716]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.42757306 0.23959212 0.52558523 1.0735611  0.2143676  0.5065819
 0.37905812 0.19880241 0.62379414 1.0622916  0.4243287  0.7889497
 0.37273625 0.4253662  1.0088395  0.19590011], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.42757306, 0.23959212, 0.52558523, 1.0735611 , 0.2143676 ,
       0.5065819 , 0.37905812, 0.19880241, 0.62379414, 1.0622916 ,
       0.4243287 , 0.7889497 , 0.37273625, 0.4253662 , 1.0088395 ,
       0.19590011], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84550>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1055580; to 'JaxprTracer' at 0x7c31a1056ca0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  92%|█████████▏| 78/85 [03:48<00:18,  2.68s/it, loss=0.529]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.38612434 0.24789664 0.2912087  0.36894956 0.17301427 0.67387414
 0.45714468 1.3497934  0.3839535  0.44686982 0.58610934 0.6459482
 0.3690386  0.92626417 0.77218366 0.30018196], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.38612434, 0.24789664, 0.2912087 , 0.36894956, 0.17301427,
       0.67387414, 0.45714468, 1.3497934 , 0.3839535 , 0.44686982,
       0.58610934, 0.6459482 , 0.3690386 , 0.92626417, 0.77218366,
       0.30018196], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87cd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a2c400; to 'JaxprTracer' at 0x7c31a1a2db70>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  93%|█████████▎| 79/85 [03:51<00:16,  2.82s/it, loss=0.524]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.36077404 0.8996728  0.20536561 1.5238148  0.7272729  0.4890973
 0.388591   0.07999378 0.47996986 0.15201773 0.73088795 0.2501894
 0.9850712  0.6754033  0.6970546  0.7601266 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.36077404, 0.8996728 , 0.20536561, 1.5238148 , 0.7272729 ,
       0.4890973 , 0.388591  , 0.07999378, 0.47996986, 0.15201773,
       0.73088795, 0.2501894 , 0.9850712 , 0.6754033 , 0.6970546 ,
       0.7601266 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119fae0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e9d1c0; to 'JaxprTracer' at 0x7c31a0e9c310>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  94%|█████████▍| 80/85 [03:54<00:13,  2.73s/it, loss=0.588]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.2530709  0.35738885 0.99293864 0.5967444  0.59051174 0.2673098
 0.30401084 0.47993475 0.1961833  0.70375663 0.15716764 0.53320366
 0.7009294  0.78341097 0.34500828 0.6549711 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.2530709 , 0.35738885, 0.99293864, 0.5967444 , 0.59051174,
       0.2673098 , 0.30401084, 0.47993475, 0.1961833 , 0.70375663,
       0.15716764, 0.53320366, 0.7009294 , 0.78341097, 0.34500828,
       0.6549711 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87960>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a22480; to 'JaxprTracer' at 0x7c31a1a20220>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  95%|█████████▌| 81/85 [03:57<00:11,  2.79s/it, loss=0.495]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.1484395  0.27119493 0.61769587 1.0486714  0.710166   0.38839474
 0.24293101 0.27442    0.43809754 0.84473395 0.45366603 0.54357296
 0.3947208  0.5345799  0.45546228 0.5078292 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.1484395 , 0.27119493, 0.61769587, 1.0486714 , 0.710166  ,
       0.38839474, 0.24293101, 0.27442   , 0.43809754, 0.84473395,
       0.45366603, 0.54357296, 0.3947208 , 0.5345799 , 0.45546228,
       0.5078292 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119f4f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15443b0; to 'JaxprTracer' at 0x7c31a1544130>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  96%|█████████▋| 82/85 [03:59<00:08,  2.71s/it, loss=0.492]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.71792895 0.5741385  0.886595   0.30336702 0.21067226 0.41621038
 0.5720376  1.1013064  0.7842724  0.25073513 0.53498644 1.1886704
 0.5249097  0.44522727 1.0780166  0.25781587], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.71792895, 0.5741385 , 0.886595  , 0.30336702, 0.21067226,
       0.41621038, 0.5720376 , 1.1013064 , 0.7842724 , 0.25073513,
       0.53498644, 1.1886704 , 0.5249097 , 0.44522727, 1.0780166 ,
       0.25781587], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196de90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa3a10; to 'JaxprTracer' at 0x7c31a0fa3ab0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  98%|█████████▊| 83/85 [04:02<00:05,  2.75s/it, loss=0.615]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.2805575  0.432296   0.76243526 1.6373405  0.911358   0.6049072
 0.6848162  1.2047081  1.1917202  0.22884455 0.47530237 1.3060554
 0.80369127 0.56534684 0.6065055  0.6270256 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.2805575 , 0.432296  , 0.76243526, 1.6373405 , 0.911358  ,
       0.6049072 , 0.6848162 , 1.2047081 , 1.1917202 , 0.22884455,
       0.47530237, 1.3060554 , 0.80369127, 0.56534684, 0.6065055 ,
       0.6270256 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16569b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a10204a0; to 'JaxprTracer' at 0x7c31a1023100>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  99%|█████████▉| 84/85 [04:05<00:02,  2.81s/it, loss=0.77]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.3347508  0.9770185  0.5792075  1.1594241  0.59614897 0.31243008
 1.1500577  0.37024903 0.25568587 0.19510962 0.61112595 0.37661302
 0.7383341  0.39214778 0.87432677 0.6985262 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.3347508 , 0.9770185 , 0.5792075 , 1.1594241 , 0.59614897,
       0.31243008, 1.1500577 , 0.37024903, 0.25568587, 0.19510962,
       0.61112595, 0.37661302, 0.7383341 , 0.39214778, 0.87432677,
       0.6985262 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16577b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b14360; to 'JaxprTracer' at 0x7c31a1b16980>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training: 100%|██████████| 85/85 [04:08<00:00,  2.92s/it, loss=0.601]


Evaluating after epoch 1...


Evaluating: 100%|██████████| 21/21 [00:12<00:00,  1.69it/s]


Accuracy: 0.6845
Precision: 0.8990
Recall: 0.4811
Confusion Matrix:
[[141  10]
 [ 96  89]]
Epoch 2/3


Training:   0%|          | 0/85 [00:00<?, ?it/s]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.47131103 0.56303823 1.0388651  0.99371105 0.9585325  0.5032133
 0.2582425  1.1325299  0.48846012 0.61818117 0.594622   1.1354845
 0.6772652  1.0640416  0.5551835  0.7223594 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.47131103, 0.56303823, 1.0388651 , 0.99371105, 0.9585325 ,
       0.5032133 , 0.2582425 , 1.1325299 , 0.48846012, 0.61818117,
       0.594622  , 1.1354845 , 0.6772652 , 1.0640416 , 0.5551835 ,
       0.7223594 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84630>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16ac3b0; to 'JaxprTracer' at 0x7c31a16af6f0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   1%|          | 1/85 [00:02<03:33,  2.54s/it, loss=0.736]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.3721067  0.244403   0.18481176 0.2612298  0.7993476  0.6571694
 1.3107988  0.87941575 0.4587601  0.80359423 0.5923822  0.37878922
 1.5636498  0.88005203 0.43821946 0.2624449 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.3721067 , 0.244403  , 0.18481176, 0.2612298 , 0.7993476 ,
       0.6571694 , 1.3107988 , 0.87941575, 0.4587601 , 0.80359423,
       0.5923822 , 0.37878922, 1.5636498 , 0.88005203, 0.43821946,
       0.2624449 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87090>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31aa3a5f30; to 'JaxprTracer' at 0x7c31a17f2340>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   2%|▏         | 2/85 [00:05<03:27,  2.51s/it, loss=0.63]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.27339694 0.75297624 0.15729477 0.538643   0.7818582  0.8491989
 0.07178883 0.38310778 0.2984823  0.3172027  0.4748665  0.6354346
 0.24419442 0.4793836  0.33118144 0.70449114], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.27339694, 0.75297624, 0.15729477, 0.538643  , 0.7818582 ,
       0.8491989 , 0.07178883, 0.38310778, 0.2984823 , 0.3172027 ,
       0.4748665 , 0.6354346 , 0.24419442, 0.4793836 , 0.33118144,
       0.70449114], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a649a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1d5fa60; to 'JaxprTracer' at 0x7c31a1d5eb60>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   4%|▎         | 3/85 [00:08<03:58,  2.91s/it, loss=0.456]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.32141    0.09273402 0.3230456  0.1756168  1.1084692  0.39893773
 0.54277223 0.8950635  1.0139604  0.67600846 1.1287615  0.55207425
 0.7202308  0.34192553 0.7177439  0.1242115 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.32141   , 0.09273402, 0.3230456 , 0.1756168 , 1.1084692 ,
       0.39893773, 0.54277223, 0.8950635 , 1.0139604 , 0.67600846,
       1.1287615 , 0.55207425, 0.7202308 , 0.34192553, 0.7177439 ,
       0.1242115 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a66890>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1e7e070; to 'JaxprTracer' at 0x7c31a1e7c770>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   5%|▍         | 4/85 [00:11<04:01,  2.98s/it, loss=0.571]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.53437227 0.57220536 1.4988127  0.33897647 0.5021753  0.43139216
 0.7926421  1.0091095  0.76764274 0.47453678 0.32882643 0.0259834
 0.25578502 0.23668922 0.18672413 1.1791301 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.53437227, 0.57220536, 1.4988127 , 0.33897647, 0.5021753 ,
       0.43139216, 0.7926421 , 1.0091095 , 0.76764274, 0.47453678,
       0.32882643, 0.0259834 , 0.25578502, 0.23668922, 0.18672413,
       1.1791301 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a656f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a162e840; to 'JaxprTracer' at 0x7c31a1560cc0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   6%|▌         | 5/85 [00:14<03:57,  2.97s/it, loss=0.571]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.70097446 0.07386181 0.2843393  0.8137114  0.3495218  0.3755354
 0.3893952  0.90927815 0.5981805  0.1288415  0.40576848 0.6410588
 1.1264737  0.2115015  0.38844115 0.9403497 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.70097446, 0.07386181, 0.2843393 , 0.8137114 , 0.3495218 ,
       0.3755354 , 0.3893952 , 0.90927815, 0.5981805 , 0.1288415 ,
       0.40576848, 0.6410588 , 1.1264737 , 0.2115015 , 0.38844115,
       0.9403497 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a647f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c19d00; to 'JaxprTracer' at 0x7c31a1c1bba0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:   7%|▋         | 6/85 [00:17<03:43,  2.83s/it, loss=0.521]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7777105  0.68001485 0.10265668 0.38867602 0.26236504 0.13171294
 0.16038029 1.075216   0.6593253  0.13441902 0.64002085 0.10773087
 0.6801374  0.6550407  0.2349053  1.1937466 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7777105 , 0.68001485, 0.10265668, 0.38867602, 0.26236504,
       0.13171294, 0.16038029, 1.075216  , 0.6593253 , 0.13441902,
       0.64002085, 0.10773087, 0.6801374 , 0.6550407 , 0.2349053 ,
       1.1937466 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65410>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a17f0f90; to 'JaxprTracer' at 0x7c31a17f1fd0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   8%|▊         | 7/85 [00:20<03:47,  2.92s/it, loss=0.493]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.3055787  1.1597124  0.11573955 0.32670882 0.20413429 0.21065004
 0.26145315 0.10171546 0.17723933 0.4695888  0.31193995 0.14921066
 0.6583617  0.09105046 0.57854    1.0689092 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.3055787 , 1.1597124 , 0.11573955, 0.32670882, 0.20413429,
       0.21065004, 0.26145315, 0.10171546, 0.17723933, 0.4695888 ,
       0.31193995, 0.14921066, 0.6583617 , 0.09105046, 0.57854   ,
       1.0689092 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655d90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0d43150; to 'JaxprTracer' at 0x7c31a1cdb790>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   9%|▉         | 8/85 [00:22<03:40,  2.87s/it, loss=0.387]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.0449165  1.5245306  1.2334169  1.6665858  0.18045971 0.704058
 0.8329916  0.06176345 1.0472147  0.40293986 0.14820085 0.44754058
 0.2109926  0.15860656 0.16702145 0.6378786 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.0449165 , 1.5245306 , 1.2334169 , 1.6665858 , 0.18045971,
       0.704058  , 0.8329916 , 0.06176345, 1.0472147 , 0.40293986,
       0.14820085, 0.44754058, 0.2109926 , 0.15860656, 0.16702145,
       0.6378786 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65c50>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a145b740; to 'JaxprTracer' at 0x7c31a12f1df0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  11%|█         | 9/85 [00:25<03:30,  2.78s/it, loss=0.592]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7153073  0.22081617 0.2780973  0.9491308  0.18083355 0.39656606
 1.095631   0.2447326  0.43236047 0.71634614 0.67337877 0.7570245
 0.49925685 0.6914412  0.57085425 0.20178498], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7153073 , 0.22081617, 0.2780973 , 0.9491308 , 0.18083355,
       0.39656606, 1.095631  , 0.2447326 , 0.43236047, 0.71634614,
       0.67337877, 0.7570245 , 0.49925685, 0.6914412 , 0.57085425,
       0.20178498], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64170>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a10d0e50; to 'JaxprTracer' at 0x7c31a10d3880>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  12%|█▏        | 10/85 [00:27<03:22,  2.69s/it, loss=0.539]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2324041  1.6988178  0.10534683 0.3584243  0.31065196 2.4598618
 1.2691638  0.17065555 0.8147079  0.72223014 0.5526904  0.66628975
 1.5223734  1.8051405  0.03940154 0.3079057 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2324041 , 1.6988178 , 0.10534683, 0.3584243 , 0.31065196,
       2.4598618 , 1.2691638 , 0.17065555, 0.8147079 , 0.72223014,
       0.5526904 , 0.66628975, 1.5223734 , 1.8051405 , 0.03940154,
       0.3079057 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a67f80>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1610ef0; to 'JaxprTracer' at 0x7c31a16131f0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  13%|█▎        | 11/85 [00:30<03:23,  2.75s/it, loss=0.877]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5791032  0.9349745  0.12655418 0.15497868 0.63662076 0.03371574
 0.8930887  0.24304433 0.7539551  1.4829674  0.08913805 0.8367255
 0.11414568 0.38199872 0.5328925  0.12921557], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5791032 , 0.9349745 , 0.12655418, 0.15497868, 0.63662076,
       0.03371574, 0.8930887 , 0.24304433, 0.7539551 , 1.4829674 ,
       0.08913805, 0.8367255 , 0.11414568, 0.38199872, 0.5328925 ,
       0.12921557], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a651b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0d42f70; to 'JaxprTracer' at 0x7c31a0d41fd0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  14%|█▍        | 12/85 [00:34<03:32,  2.91s/it, loss=0.495]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.24813813 0.9822142  0.64151996 0.57780796 0.19830345 0.6802737
 1.7353555  0.21163578 0.16723903 0.8493046  0.62231034 0.5253664
 0.10821176 0.20824245 0.21843506 0.97337353], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.24813813, 0.9822142 , 0.64151996, 0.57780796, 0.19830345,
       0.6802737 , 1.7353555 , 0.21163578, 0.16723903, 0.8493046 ,
       0.62231034, 0.5253664 , 0.10821176, 0.20824245, 0.21843506,
       0.97337353], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86070>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa4b80; to 'JaxprTracer' at 0x7c31a0fa4900>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  15%|█▌        | 13/85 [00:36<03:21,  2.80s/it, loss=0.559]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.73865163 0.35704806 0.20365673 0.29399642 0.59199595 1.7100776
 1.0555836  0.6370373  0.4050645  0.7208538  0.6928286  0.36427248
 0.5676204  0.8672017  0.64206123 0.29376382], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.73865163, 0.35704806, 0.20365673, 0.29399642, 0.59199595,
       1.7100776 , 1.0555836 , 0.6370373 , 0.4050645 , 0.7208538 ,
       0.6928286 , 0.36427248, 0.5676204 , 0.8672017 , 0.64206123,
       0.29376382], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64440>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a145ab10; to 'JaxprTracer' at 0x7c31a1459620>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  16%|█▋        | 14/85 [00:39<03:13,  2.72s/it, loss=0.634]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.22282907 0.79676473 0.27701858 1.4927821  0.45355117 0.93672836
 0.10616844 0.4291991  0.17798442 0.40975958 0.2616936  1.6923686
 0.15684119 0.50184447 0.27252582 0.41862327], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.22282907, 0.79676473, 0.27701858, 1.4927821 , 0.45355117,
       0.93672836, 0.10616844, 0.4291991 , 0.17798442, 0.40975958,
       0.2616936 , 1.6923686 , 0.15684119, 0.50184447, 0.27252582,
       0.41862327], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a640b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13fe390; to 'JaxprTracer' at 0x7c31a13fefc0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  18%|█▊        | 15/85 [00:41<03:06,  2.66s/it, loss=0.538]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.53745705 1.5272192  0.11671117 0.53844666 0.41489416 0.66556996
 0.1892807  0.4338803  0.3299507  0.2079066  0.71431065 1.4092255
 0.14857031 0.43425763 0.43299744 0.8120667 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.53745705, 1.5272192 , 0.11671117, 0.53844666, 0.41489416,
       0.66556996, 0.1892807 , 0.4338803 , 0.3299507 , 0.2079066 ,
       0.71431065, 1.4092255 , 0.14857031, 0.43425763, 0.43299744,
       0.8120667 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119eb20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15d0e00; to 'JaxprTracer' at 0x7c31a15d3e20>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  19%|█▉        | 16/85 [00:44<03:15,  2.83s/it, loss=0.557]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.26250497 0.12508677 0.20095934 0.36316404 0.96461886 0.14792205
 0.11074874 0.30883488 0.22960722 0.31788293 0.42470422 0.21946199
 2.2871296  0.37091422 0.37580624 1.1604335 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.26250497, 0.12508677, 0.20095934, 0.36316404, 0.96461886,
       0.14792205, 0.11074874, 0.30883488, 0.22960722, 0.31788293,
       0.42470422, 0.21946199, 2.2871296 , 0.37091422, 0.37580624,
       1.1604335 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85980>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1257b00; to 'JaxprTracer' at 0x7c31a12551c0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  20%|██        | 17/85 [00:48<03:17,  2.91s/it, loss=0.492]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.93286633 0.30687147 0.29913774 0.15998113 0.8746135  1.4789188
 0.25010082 0.35647786 0.29261902 0.49678618 0.45780587 0.41111624
 0.3802785  0.18284082 1.1781617  0.09742545], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.93286633, 0.30687147, 0.29913774, 0.15998113, 0.8746135 ,
       1.4789188 , 0.25010082, 0.35647786, 0.29261902, 0.49678618,
       0.45780587, 0.41111624, 0.3802785 , 0.18284082, 1.1781617 ,
       0.09742545], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119da40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0d42200; to 'JaxprTracer' at 0x7c31a0d41580>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  21%|██        | 18/85 [00:50<03:07,  2.80s/it, loss=0.51]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.12834926 0.30263844 1.5450668  0.06876551 0.6678393  0.46328312
 0.31257972 0.30236623 0.7280457  0.77004707 0.2514138  0.51792914
 0.20264548 0.51643497 0.8407648  0.31652796], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.12834926, 0.30263844, 1.5450668 , 0.06876551, 0.6678393 ,
       0.46328312, 0.31257972, 0.30236623, 0.7280457 , 0.77004707,
       0.2514138 , 0.51792914, 0.20264548, 0.51643497, 0.8407648 ,
       0.31652796], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119fc00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16d7830; to 'JaxprTracer' at 0x7c31a16d6660>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  22%|██▏       | 19/85 [00:53<02:59,  2.71s/it, loss=0.496]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.09794969 0.1165026  0.11136375 0.91814095 0.33607185 0.23987463
 0.17780082 0.6209464  0.4044701  0.5062479  0.5437043  1.1303868
 0.34692678 0.5151683  0.5801208  0.9121351 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.09794969, 0.1165026 , 0.11136375, 0.91814095, 0.33607185,
       0.23987463, 0.17780082, 0.6209464 , 0.4044701 , 0.5062479 ,
       0.5437043 , 1.1303868 , 0.34692678, 0.5151683 , 0.5801208 ,
       0.9121351 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119fa80>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a10578d0; to 'JaxprTracer' at 0x7c31a0e43600>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  24%|██▎       | 20/85 [00:55<02:51,  2.64s/it, loss=0.472]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.22957301 1.5635518  0.04314367 0.16560814 0.20302717 0.81456053
 1.1101856  0.08175042 0.5498292  0.25299427 0.46146107 0.59256834
 0.30905092 0.08985213 0.9973874  0.24036624], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.22957301, 1.5635518 , 0.04314367, 0.16560814, 0.20302717,
       0.81456053, 1.1101856 , 0.08175042, 0.5498292 , 0.25299427,
       0.46146107, 0.59256834, 0.30905092, 0.08985213, 0.9973874 ,
       0.24036624], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119cbb0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f08b0; to 'JaxprTracer' at 0x7c31a12f1850>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  25%|██▍       | 21/85 [00:58<03:01,  2.84s/it, loss=0.482]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.04524366 1.5467458  0.5179986  0.1605533  0.09121621 0.15951343
 0.41520807 0.9089477  0.57840896 1.2359316  0.09765868 0.23394828
 0.46406978 1.1215374  0.16500156 0.3293966 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.04524366, 1.5467458 , 0.5179986 , 0.1605533 , 0.09121621,
       0.15951343, 0.41520807, 0.9089477 , 0.57840896, 1.2359316 ,
       0.09765868, 0.23394828, 0.46406978, 1.1215374 , 0.16500156,
       0.3293966 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119ed50>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e9f470; to 'JaxprTracer' at 0x7c31a0e9c720>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  26%|██▌       | 22/85 [01:01<03:02,  2.89s/it, loss=0.504]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.0218121  0.12351469 1.39558    0.56400883 0.10861436 0.27405334
 0.35500664 0.24879774 0.19536616 0.19063734 0.34526896 0.2048432
 0.25682014 0.33370468 0.14493845 0.17043984], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.0218121 , 0.12351469, 1.39558   , 0.56400883, 0.10861436,
       0.27405334, 0.35500664, 0.24879774, 0.19536616, 0.19063734,
       0.34526896, 0.2048432 , 0.25682014, 0.33370468, 0.14493845,
       0.17043984], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119fbf0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16922f0; to 'JaxprTracer' at 0x7c31a1690ea0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  27%|██▋       | 23/85 [01:04<02:52,  2.77s/it, loss=0.371]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.46935695 0.10410815 0.8039524  0.37413093 1.4247851  0.12164621
 0.6636103  0.3500679  0.35721505 0.85306305 0.21099192 0.6644893
 1.0324327  0.13518138 0.67199165 0.25491995], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.46935695, 0.10410815, 0.8039524 , 0.37413093, 1.4247851 ,
       0.12164621, 0.6636103 , 0.3500679 , 0.35721505, 0.85306305,
       0.21099192, 0.6644893 , 1.0324327 , 0.13518138, 0.67199165,
       0.25491995], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e270>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a11c1300; to 'JaxprTracer' at 0x7c31a11c3100>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  28%|██▊       | 24/85 [01:06<02:44,  2.70s/it, loss=0.531]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.411878   0.30973664 0.91301966 0.12242691 0.20350464 1.2350172
 0.09770982 0.09890904 0.12925778 0.10113855 2.2245612  0.3099243
 0.11110467 0.2812437  0.7009343  0.06084114], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.411878  , 0.30973664, 0.91301966, 0.12242691, 0.20350464,
       1.2350172 , 0.09770982, 0.09890904, 0.12925778, 0.10113855,
       2.2245612 , 0.3099243 , 0.11110467, 0.2812437 , 0.7009343 ,
       0.06084114], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119cc10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1641da0; to 'JaxprTracer' at 0x7c31a1641d50>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  29%|██▉       | 25/85 [01:09<02:38,  2.64s/it, loss=0.457]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2891294  0.607938   0.20675392 0.6733972  0.72792137 0.1847411
 0.21321522 0.07518793 0.50913954 0.13673139 0.22058108 0.14295019
 0.13520075 0.4287887  0.21814537 0.16599062], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2891294 , 0.607938  , 0.20675392, 0.6733972 , 0.72792137,
       0.1847411 , 0.21321522, 0.07518793, 0.50913954, 0.13673139,
       0.22058108, 0.14295019, 0.13520075, 0.4287887 , 0.21814537,
       0.16599062], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196d760>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1187e70; to 'JaxprTracer' at 0x7c31a0e004f0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  31%|███       | 26/85 [01:12<02:49,  2.87s/it, loss=0.371]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.4553451  0.50904065 0.3134517  0.20582032 0.13633433 0.36465773
 1.0840305  0.23539105 0.5507964  0.13860005 0.77026445 0.577615
 2.0143447  0.8712475  0.30158058 0.61154175], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.4553451 , 0.50904065, 0.3134517 , 0.20582032, 0.13633433,
       0.36465773, 1.0840305 , 0.23539105, 0.5507964 , 0.13860005,
       0.77026445, 0.577615  , 2.0143447 , 0.8712475 , 0.30158058,
       0.61154175], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e710>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1610db0; to 'JaxprTracer' at 0x7c31a1186b10>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  32%|███▏      | 27/85 [01:15<02:43,  2.83s/it, loss=0.571]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.12738039 1.7844826  0.4006019  0.24663502 1.8705528  0.1100647
 0.19295771 2.1294324  0.43664798 0.6844912  0.70598286 0.06392639
 1.1379617  0.17715175 0.35404262 0.06227054], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.12738039, 1.7844826 , 0.4006019 , 0.24663502, 1.8705528 ,
       0.1100647 , 0.19295771, 2.1294324 , 0.43664798, 0.6844912 ,
       0.70598286, 0.06392639, 1.1379617 , 0.17715175, 0.35404262,
       0.06227054], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655930>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cd93a0; to 'JaxprTracer' at 0x7c31a1cdb0b0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  33%|███▎      | 28/85 [01:18<02:35,  2.73s/it, loss=0.655]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.6606076  0.1707452  0.98640096 1.2500196  0.6484182  0.19084834
 0.23995173 0.13624103 0.48580986 0.12001458 0.4188032  0.35146034
 0.17307071 0.5920633  0.19877173 0.45041478], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.6606076 , 0.1707452 , 0.98640096, 1.2500196 , 0.6484182 ,
       0.19084834, 0.23995173, 0.13624103, 0.48580986, 0.12001458,
       0.4188032 , 0.35146034, 0.17307071, 0.5920633 , 0.19877173,
       0.45041478], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f8b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e10540; to 'JaxprTracer' at 0x7c31a0e10720>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  34%|███▍      | 29/85 [01:20<02:29,  2.66s/it, loss=0.505]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.19528928 0.07225184 0.06756321 0.2292108  0.14319547 0.16862184
 0.3076599  0.20990998 0.67386633 1.7726076  0.18957879 0.3366768
 0.34114367 0.278949   0.23325405 0.16167466], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.19528928, 0.07225184, 0.06756321, 0.2292108 , 0.14319547,
       0.16862184, 0.3076599 , 0.20990998, 0.67386633, 1.7726076 ,
       0.18957879, 0.3366768 , 0.34114367, 0.278949  , 0.23325405,
       0.16167466], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196df20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a148e250; to 'JaxprTracer' at 0x7c31a148fe20>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  35%|███▌      | 30/85 [01:23<02:23,  2.61s/it, loss=0.336]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.14401527 0.16384235 0.2304727  0.5701067  0.12183049 0.7611809
 0.18905911 1.5972333  0.32908776 1.6162486  0.52080786 1.8815833
 0.0960631  1.605149   0.12708953 0.23752792], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.14401527, 0.16384235, 0.2304727 , 0.5701067 , 0.12183049,
       0.7611809 , 0.18905911, 1.5972333 , 0.32908776, 1.6162486 ,
       0.52080786, 1.8815833 , 0.0960631 , 1.605149  , 0.12708953,
       0.23752792], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fcd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1673ec0; to 'JaxprTracer' at 0x7c31a1607ce0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  36%|███▋      | 31/85 [01:26<02:36,  2.90s/it, loss=0.637]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.68740714 0.17232418 0.9013869  0.40679285 0.11746573 0.1307533
 0.8947002  0.07778042 0.31133223 0.06737612 0.35175675 0.37441888
 0.10097036 0.06745502 0.27272168 0.1327079 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.68740714, 0.17232418, 0.9013869 , 0.40679285, 0.11746573,
       0.1307533 , 0.8947002 , 0.07778042, 0.31133223, 0.06737612,
       0.35175675, 0.37441888, 0.10097036, 0.06745502, 0.27272168,
       0.1327079 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65230>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1264040; to 'JaxprTracer' at 0x7c31a0e707c0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  38%|███▊      | 32/85 [01:29<02:27,  2.79s/it, loss=0.317]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.27412078 0.04795406 0.16664311 0.03499174 0.35414067 0.12683491
 0.14827485 0.09845519 0.16482669 0.86843705 0.2854482  0.05871174
 0.24935707 0.38648313 1.155739   0.24878287], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.27412078, 0.04795406, 0.16664311, 0.03499174, 0.35414067,
       0.12683491, 0.14827485, 0.09845519, 0.16482669, 0.86843705,
       0.2854482 , 0.05871174, 0.24935707, 0.38648313, 1.155739  ,
       0.24878287], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119ee30>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1517510; to 'JaxprTracer' at 0x7c31a0e72020>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  39%|███▉      | 33/85 [01:31<02:21,  2.73s/it, loss=0.292]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.39294574 1.273253   0.4635897  0.25447196 0.38475585 0.5327649
 0.10685462 0.24031533 0.543664   0.61926806 0.05479304 0.23914988
 0.23220783 0.22237727 0.43048897 0.1293241 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.39294574, 1.273253  , 0.4635897 , 0.25447196, 0.38475585,
       0.5327649 , 0.10685462, 0.24031533, 0.543664  , 0.61926806,
       0.05479304, 0.23914988, 0.23220783, 0.22237727, 0.43048897,
       0.1293241 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196d290>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1f2d710; to 'JaxprTracer' at 0x7c31a1f2e5c0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  40%|████      | 34/85 [01:34<02:15,  2.65s/it, loss=0.383]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.22063339 0.8619743  0.02998853 0.587987   0.47427183 1.4608269
 0.0380483  0.09006873 0.52939755 0.03980864 0.3090097  0.03866608
 0.54673177 0.61290616 0.02898849 0.38080624], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.22063339, 0.8619743 , 0.02998853, 0.587987  , 0.47427183,
       1.4608269 , 0.0380483 , 0.09006873, 0.52939755, 0.03980864,
       0.3090097 , 0.03866608, 0.54673177, 0.61290616, 0.02898849,
       0.38080624], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196ead0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a95080; to 'JaxprTracer' at 0x7c31a1a95350>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  41%|████      | 35/85 [01:36<02:10,  2.60s/it, loss=0.391]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03585619 0.57630837 0.51179004 0.34160572 0.3413002  0.33622935
 0.04379437 0.08313523 1.1162375  0.03117386 0.19323446 0.43764406
 0.06039482 0.02489435 0.6946074  0.51182544], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03585619, 0.57630837, 0.51179004, 0.34160572, 0.3413002 ,
       0.33622935, 0.04379437, 0.08313523, 1.1162375 , 0.03117386,
       0.19323446, 0.43764406, 0.06039482, 0.02489435, 0.6946074 ,
       0.51182544], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16542f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0f73b50; to 'JaxprTracer' at 0x7c31a0f72390>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  42%|████▏     | 36/85 [01:40<02:20,  2.86s/it, loss=0.334]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.14983787 1.2648131  0.10267001 0.54560816 0.15445387 1.6421154
 0.9271414  0.05053937 0.20347108 0.2696096  0.0313943  0.19766258
 0.15174255 0.25196815 0.1971328  0.0657637 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.14983787, 1.2648131 , 0.10267001, 0.54560816, 0.15445387,
       1.6421154 , 0.9271414 , 0.05053937, 0.20347108, 0.2696096 ,
       0.0313943 , 0.19766258, 0.15174255, 0.25196815, 0.1971328 ,
       0.0657637 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e910>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a95c10; to 'JaxprTracer' at 0x7c31a1a97b00>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  44%|████▎     | 37/85 [01:42<02:12,  2.76s/it, loss=0.388]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.420752   0.1757029  0.25706714 0.37251747 1.105494   0.23925827
 0.25201353 0.14164977 0.14792791 0.49107873 0.14011006 0.30062714
 0.09343502 0.07849472 0.11491781 0.06033062], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.420752  , 0.1757029 , 0.25706714, 0.37251747, 1.105494  ,
       0.23925827, 0.25201353, 0.14164977, 0.14792791, 0.49107873,
       0.14011006, 0.30062714, 0.09343502, 0.07849472, 0.11491781,
       0.06033062], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64e00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15d2a70; to 'JaxprTracer' at 0x7c31a15d2930>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  45%|████▍     | 38/85 [01:45<02:06,  2.68s/it, loss=0.337]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03326382 0.09086847 0.10820855 0.10812981 0.09686998 0.30796346
 0.39341578 0.05287009 0.24780415 1.3249805  0.95853645 0.1461442
 0.11082536 0.02450405 0.39361218 0.02515141], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03326382, 0.09086847, 0.10820855, 0.10812981, 0.09686998,
       0.30796346, 0.39341578, 0.05287009, 0.24780415, 1.3249805 ,
       0.95853645, 0.1461442 , 0.11082536, 0.02450405, 0.39361218,
       0.02515141], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64c60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa3150; to 'JaxprTracer' at 0x7c31a0fa1d50>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  46%|████▌     | 39/85 [01:47<02:04,  2.71s/it, loss=0.276]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.04314218 0.12725633 0.09403048 1.9586359  0.31951547 0.8821667
 0.9096242  0.3567606  0.04284494 0.06566782 0.0705644  1.1323649
 0.13476057 0.07612053 0.16575177 0.7891676 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.04314218, 0.12725633, 0.09403048, 1.9586359 , 0.31951547,
       0.8821667 , 0.9096242 , 0.3567606 , 0.04284494, 0.06566782,
       0.0705644 , 1.1323649 , 0.13476057, 0.07612053, 0.16575177,
       0.7891676 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86bf0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa20c0; to 'JaxprTracer' at 0x7c31a0fa1260>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  47%|████▋     | 40/85 [01:50<02:03,  2.75s/it, loss=0.448]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.06590032 0.04567835 0.03189959 0.05721442 0.06211047 0.2835627
 0.2306874  0.8620042  2.3940232  1.1148779  0.16987751 0.21768577
 0.11365753 1.485494   1.6344253  0.05441593], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.06590032, 0.04567835, 0.03189959, 0.05721442, 0.06211047,
       0.2835627 , 0.2306874 , 0.8620042 , 2.3940232 , 1.1148779 ,
       0.16987751, 0.21768577, 0.11365753, 1.485494  , 1.6344253 ,
       0.05441593], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a662e0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1692020; to 'JaxprTracer' at 0x7c31a1693ec0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  48%|████▊     | 41/85 [01:54<02:08,  2.92s/it, loss=0.551]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.3896926  0.01235887 0.13679762 0.19071546 0.27011946 0.02528311
 0.09895882 0.04897561 0.18790306 0.07548522 0.18654153 0.09595394
 1.1841147  0.01127114 0.08320895 0.8690045 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.3896926 , 0.01235887, 0.13679762, 0.19071546, 0.27011946,
       0.02528311, 0.09895882, 0.04897561, 0.18790306, 0.07548522,
       0.18654153, 0.09595394, 1.1841147 , 0.01127114, 0.08320895,
       0.8690045 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64130>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a11901d0; to 'JaxprTracer' at 0x7c31a1191d00>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  49%|████▉     | 42/85 [01:56<01:59,  2.78s/it, loss=0.242]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7127072  0.36399508 0.05086935 0.13501745 0.13152316 0.5963629
 0.03367148 0.5743666  0.06951853 0.14625739 1.1600596  0.06384274
 0.03565363 0.5437715  0.06387082 0.06511173], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7127072 , 0.36399508, 0.05086935, 0.13501745, 0.13152316,
       0.5963629 , 0.03367148, 0.5743666 , 0.06951853, 0.14625739,
       1.1600596 , 0.06384274, 0.03565363, 0.5437715 , 0.06387082,
       0.06511173], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a667b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1bf3010; to 'JaxprTracer' at 0x7c31a1bf3b50>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  51%|█████     | 43/85 [01:59<01:53,  2.71s/it, loss=0.297]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.7024825  0.10881025 0.04864638 0.86087286 0.03817269 1.2011867
 0.07036176 0.24396504 0.04735381 1.9198252  0.0162958  0.21393861
 0.8571071  0.06452304 0.14043953 0.05641174], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.7024825 , 0.10881025, 0.04864638, 0.86087286, 0.03817269,
       1.2011867 , 0.07036176, 0.24396504, 0.04735381, 1.9198252 ,
       0.0162958 , 0.21393861, 0.8571071 , 0.06452304, 0.14043953,
       0.05641174], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a66700>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1264090; to 'JaxprTracer' at 0x7c31a0e41440>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  52%|█████▏    | 44/85 [02:01<01:49,  2.66s/it, loss=0.474]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.0142838  0.1828988  0.2720182  0.66979533 0.1658352  0.03668449
 0.09318407 0.06388748 0.07882232 0.02817151 1.60029    0.09511945
 0.46676487 0.19364628 0.12762342 2.663711  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.0142838 , 0.1828988 , 0.2720182 , 0.66979533, 0.1658352 ,
       0.03668449, 0.09318407, 0.06388748, 0.07882232, 0.02817151,
       1.60029   , 0.09511945, 0.46676487, 0.19364628, 0.12762342,
       2.663711  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64340>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13fbfb0; to 'JaxprTracer' at 0x7c31a13fbc40>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  53%|█████▎    | 45/85 [02:04<01:52,  2.82s/it, loss=0.422]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02397463 0.05063264 1.8695835  0.01932975 2.0646696  2.0573015
 0.03102386 0.04442983 0.02589989 0.09197285 0.02296984 0.6222632
 0.5526484  0.06221789 2.197869   0.03295267], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02397463, 0.05063264, 1.8695835 , 0.01932975, 2.0646696 ,
       2.0573015 , 0.03102386, 0.04442983, 0.02589989, 0.09197285,
       0.02296984, 0.6222632 , 0.5526484 , 0.06221789, 2.197869  ,
       0.03295267], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16565e0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0f301d0; to 'JaxprTracer' at 0x7c31a0f31ee0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  54%|█████▍    | 46/85 [02:07<01:53,  2.91s/it, loss=0.611]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00562821 0.9699531  1.7713817  0.03828261 0.12988329 0.01995336
 0.0194351  0.04412544 0.01267449 0.7982569  0.0158586  0.17648241
 0.08565219 0.4380029  0.15390569 0.04250134], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00562821, 0.9699531 , 1.7713817 , 0.03828261, 0.12988329,
       0.01995336, 0.0194351 , 0.04412544, 0.01267449, 0.7982569 ,
       0.0158586 , 0.17648241, 0.08565219, 0.4380029 , 0.15390569,
       0.04250134], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16545a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e418a0; to 'JaxprTracer' at 0x7c31a0e427f0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  55%|█████▌    | 47/85 [02:10<01:46,  2.81s/it, loss=0.295]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([2.1050725  1.1998425  1.2115566  0.5034612  0.05335223 0.40523937
 2.6902378  2.978791   0.05712356 0.22969316 0.5256881  0.02067923
 0.05784782 0.35209715 0.05446199 0.09646489], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([2.1050725 , 1.1998425 , 1.2115566 , 0.5034612 , 0.05335223,
       0.40523937, 2.6902378 , 2.978791  , 0.05712356, 0.22969316,
       0.5256881 , 0.02067923, 0.05784782, 0.35209715, 0.05446199,
       0.09646489], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119ef50>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a20130; to 'JaxprTracer' at 0x7c31a1a235b0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  56%|█████▋    | 48/85 [02:13<01:40,  2.72s/it, loss=0.784]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2154014  0.16689445 0.2927046  0.1195187  0.04826773 0.73978806
 0.04655068 0.50884664 0.06284791 0.11691482 0.12051708 0.17892972
 0.07433868 0.01852731 0.10107844 0.0707198 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2154014 , 0.16689445, 0.2927046 , 0.1195187 , 0.04826773,
       0.73978806, 0.04655068, 0.50884664, 0.06284791, 0.11691482,
       0.12051708, 0.17892972, 0.07433868, 0.01852731, 0.10107844,
       0.0707198 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119f210>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ac3420; to 'JaxprTracer' at 0x7c31a1ac2890>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  58%|█████▊    | 49/85 [02:15<01:35,  2.65s/it, loss=0.243]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.51648307 0.3402838  0.0136156  0.05049857 0.04884927 0.17419314
 0.14635913 0.04422159 0.6498773  1.0649745  0.13262239 2.668006
 0.02653044 0.0725949  0.16601041 0.59889954], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.51648307, 0.3402838 , 0.0136156 , 0.05049857, 0.04884927,
       0.17419314, 0.14635913, 0.04422159, 0.6498773 , 1.0649745 ,
       0.13262239, 2.668006  , 0.02653044, 0.0725949 , 0.16601041,
       0.59889954], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85ac0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b30900; to 'JaxprTracer' at 0x7c31a1b328e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  59%|█████▉    | 50/85 [02:18<01:37,  2.80s/it, loss=0.42]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03648429 1.2157333  1.3987706  0.12665091 0.20567058 0.03167788
 0.07920863 0.07201958 0.18701753 0.48778445 0.18806042 0.09342828
 0.04110676 0.10661633 0.0296958  0.7759727 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03648429, 1.2157333 , 1.3987706 , 0.12665091, 0.20567058,
       0.03167788, 0.07920863, 0.07201958, 0.18701753, 0.48778445,
       0.18806042, 0.09342828, 0.04110676, 0.10661633, 0.0296958 ,
       0.7759727 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e870>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a10d22f0; to 'JaxprTracer' at 0x7c31a10d0d10>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  60%|██████    | 51/85 [02:22<01:44,  3.06s/it, loss=0.317]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5092619  2.1020074  0.05804829 0.01106661 0.52623165 0.08526454
 0.05970385 0.0571176  0.06374007 0.67874146 0.17019995 0.2172141
 2.1609452  1.4021988  0.12200809 0.0738844 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5092619 , 2.1020074 , 0.05804829, 0.01106661, 0.52623165,
       0.08526454, 0.05970385, 0.0571176 , 0.06374007, 0.67874146,
       0.17019995, 0.2172141 , 2.1609452 , 1.4021988 , 0.12200809,
       0.0738844 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a660e0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1bd8950; to 'JaxprTracer' at 0x7c31a1bd8590>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  61%|██████    | 52/85 [02:24<01:36,  2.91s/it, loss=0.519]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.23446089 0.13846523 0.28065276 0.1440033  0.24210677 0.11541364
 0.03418956 0.11243805 0.07133205 0.18569152 0.03583468 0.29856995
 0.06498159 0.5584928  0.13203152 1.4601775 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.23446089, 0.13846523, 0.28065276, 0.1440033 , 0.24210677,
       0.11541364, 0.03418956, 0.11243805, 0.07133205, 0.18569152,
       0.03583468, 0.29856995, 0.06498159, 0.5584928 , 0.13203152,
       1.4601775 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119f940>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c94090; to 'JaxprTracer' at 0x7c31a1c97f10>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  62%|██████▏   | 53/85 [02:27<01:30,  2.83s/it, loss=0.257]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.10611366 2.429992   0.0125981  1.790031   0.22685969 0.09975473
 0.074342   0.15071756 0.37660486 0.01890194 0.12245275 1.4991311
 0.13339457 1.1807549  0.06013903 0.27134782], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.10611366, 2.429992  , 0.0125981 , 1.790031  , 0.22685969,
       0.09975473, 0.074342  , 0.15071756, 0.37660486, 0.01890194,
       0.12245275, 1.4991311 , 0.13339457, 1.1807549 , 0.06013903,
       0.27134782], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86140>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1bdb9c0; to 'JaxprTracer' at 0x7c31a1bd9bc0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  64%|██████▎   | 54/85 [02:30<01:31,  2.94s/it, loss=0.535]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.0488386  1.473896   0.20535056 0.23088746 0.84436333 0.8035552
 0.1404823  0.02078864 0.02499958 0.12345169 0.08968191 0.6524487
 0.08061016 0.87772894 0.01226608 0.09441931], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.0488386 , 1.473896  , 0.20535056, 0.23088746, 0.84436333,
       0.8035552 , 0.1404823 , 0.02078864, 0.02499958, 0.12345169,
       0.08968191, 0.6524487 , 0.08061016, 0.87772894, 0.01226608,
       0.09441931], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86c60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e9e2f0; to 'JaxprTracer' at 0x7c31a0e72ed0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  65%|██████▍   | 55/85 [02:33<01:27,  2.91s/it, loss=0.358]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.6538749  0.04082779 0.0439549  0.09913686 1.332797   0.44534367
 0.21674736 0.21325529 0.6073345  1.1905872  0.02397801 0.23706897
 0.06895201 0.03019941 0.07037154 0.14573139], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.6538749 , 0.04082779, 0.0439549 , 0.09913686, 1.332797  ,
       0.44534367, 0.21674736, 0.21325529, 0.6073345 , 1.1905872 ,
       0.02397801, 0.23706897, 0.06895201, 0.03019941, 0.07037154,
       0.14573139], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86e30>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1bd8e50; to 'JaxprTracer' at 0x7c31a1bdb920>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  66%|██████▌   | 56/85 [02:36<01:24,  2.91s/it, loss=0.339]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.0183362  0.21106905 0.1521537  0.04269668 0.2162153  1.9597255
 0.01888603 0.13426988 0.12478082 0.313346   0.3839855  0.1460967
 0.01392742 1.1273584  0.23440357 0.68654203], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.0183362 , 0.21106905, 0.1521537 , 0.04269668, 0.2162153 ,
       1.9597255 , 0.01888603, 0.13426988, 0.12478082, 0.313346  ,
       0.3839855 , 0.1460967 , 0.01392742, 1.1273584 , 0.23440357,
       0.68654203], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b867d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a155a020; to 'JaxprTracer' at 0x7c31a1559530>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  67%|██████▋   | 57/85 [02:39<01:18,  2.81s/it, loss=0.361]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01689189 0.14660636 0.0048357  0.07245787 0.00702862 0.04621791
 0.90697545 0.03735813 0.14917092 0.04163736 0.02398371 1.5949458
 0.21571773 0.44596562 0.03420166 0.01309733], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01689189, 0.14660636, 0.0048357 , 0.07245787, 0.00702862,
       0.04621791, 0.90697545, 0.03735813, 0.14917092, 0.04163736,
       0.02398371, 1.5949458 , 0.21571773, 0.44596562, 0.03420166,
       0.01309733], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86b00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a10d1530; to 'JaxprTracer' at 0x7c31a1b31850>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  68%|██████▊   | 58/85 [02:41<01:13,  2.73s/it, loss=0.235]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.06456126 0.27229404 0.08295224 0.37183058 0.35865027 0.04292009
 0.07026887 1.3329709  0.38050318 0.42440966 0.03405246 0.01293671
 0.17784782 0.33240548 0.29838803 0.22607085], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.06456126, 0.27229404, 0.08295224, 0.37183058, 0.35865027,
       0.04292009, 0.07026887, 1.3329709 , 0.38050318, 0.42440966,
       0.03405246, 0.01293671, 0.17784782, 0.33240548, 0.29838803,
       0.22607085], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84380>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15f3ab0; to 'JaxprTracer' at 0x7c31a15f1f80>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  69%|██████▉   | 59/85 [02:44<01:14,  2.87s/it, loss=0.28]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.40147063 0.17998426 0.09482902 0.5225248  0.19643486 0.03209806
 0.00693249 0.15544115 0.343026   0.21360129 0.22456418 0.08359367
 0.13263597 0.09472103 0.3725656  0.8307886 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.40147063, 0.17998426, 0.09482902, 0.5225248 , 0.19643486,
       0.03209806, 0.00693249, 0.15544115, 0.343026  , 0.21360129,
       0.22456418, 0.08359367, 0.13263597, 0.09472103, 0.3725656 ,
       0.8307886 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b853b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a145a2a0; to 'JaxprTracer' at 0x7c31a145ab10>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  71%|███████   | 60/85 [02:47<01:09,  2.77s/it, loss=0.243]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01057592 0.36846626 0.13610671 0.32168803 0.22448714 0.02963307
 0.7948687  0.06814878 0.05848284 0.16961938 0.37976974 0.20238885
 0.21079604 0.15751405 0.55820775 0.03891596], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01057592, 0.36846626, 0.13610671, 0.32168803, 0.22448714,
       0.02963307, 0.7948687 , 0.06814878, 0.05848284, 0.16961938,
       0.37976974, 0.20238885, 0.21079604, 0.15751405, 0.55820775,
       0.03891596], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87a20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1346a70; to 'JaxprTracer' at 0x7c31a1346930>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  72%|███████▏  | 61/85 [02:50<01:08,  2.84s/it, loss=0.233]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.06651876 0.07329409 0.05278235 0.15862426 0.05244226 0.05690579
 0.0245871  0.02676515 0.07832895 0.01536967 0.9895294  0.2857486
 0.08348391 0.10736302 1.4228963  0.02852563], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.06651876, 0.07329409, 0.05278235, 0.15862426, 0.05244226,
       0.05690579, 0.0245871 , 0.02676515, 0.07832895, 0.01536967,
       0.9895294 , 0.2857486 , 0.08348391, 0.10736302, 1.4228963 ,
       0.02852563], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86c00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1185fd0; to 'JaxprTracer' at 0x7c31a1186b10>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  73%|███████▎  | 62/85 [02:52<01:03,  2.75s/it, loss=0.22]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05266157 0.09375483 0.15610771 0.01630108 0.6062811  1.1499807
 0.07388993 0.18897237 0.05739498 0.05091886 1.9209957  0.03281967
 1.0196457  0.07417201 0.13913469 0.04357573], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05266157, 0.09375483, 0.15610771, 0.01630108, 0.6062811 ,
       1.1499807 , 0.07388993, 0.18897237, 0.05739498, 0.05091886,
       1.9209957 , 0.03281967, 1.0196457 , 0.07417201, 0.13913469,
       0.04357573], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87b40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cfc9f0; to 'JaxprTracer' at 0x7c31a1cff880>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  74%|███████▍  | 63/85 [02:55<01:00,  2.75s/it, loss=0.355]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.24675053 0.1616061  0.68956363 0.05578983 0.04850886 0.02353821
 1.919482   0.04290045 0.03092423 0.07791827 0.08367733 0.84632003
 0.3653901  0.10367289 0.02166276 0.07418451], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.24675053, 0.1616061 , 0.68956363, 0.05578983, 0.04850886,
       0.02353821, 1.919482  , 0.04290045, 0.03092423, 0.07791827,
       0.08367733, 0.84632003, 0.3653901 , 0.10367289, 0.02166276,
       0.07418451], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64ab0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e3f5b0; to 'JaxprTracer' at 0x7c31a1c56750>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  75%|███████▌  | 64/85 [02:58<01:00,  2.87s/it, loss=0.299]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08547831 0.08702528 0.01610004 0.17329168 0.08717694 1.2387516
 0.05730808 0.2524593  0.03075975 0.7353941  0.3556265  0.03401478
 0.0525962  0.4149487  0.5141227  3.3502192 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08547831, 0.08702528, 0.01610004, 0.17329168, 0.08717694,
       1.2387516 , 0.05730808, 0.2524593 , 0.03075975, 0.7353941 ,
       0.3556265 , 0.03401478, 0.0525962 , 0.4149487 , 0.5141227 ,
       3.3502192 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a665a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa29d0; to 'JaxprTracer' at 0x7c31a165dfd0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  76%|███████▋  | 65/85 [03:01<00:56,  2.82s/it, loss=0.468]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.07487882 0.07033931 1.3821943  0.08400182 0.02036682 0.09417284
 0.02270325 0.06184873 0.11035116 0.08812124 0.03749363 0.0886457
 0.84861994 0.12069534 0.15100676 2.0738187 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.07487882, 0.07033931, 1.3821943 , 0.08400182, 0.02036682,
       0.09417284, 0.02270325, 0.06184873, 0.11035116, 0.08812124,
       0.03749363, 0.0886457 , 0.84861994, 0.12069534, 0.15100676,
       2.0738187 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119cc10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c95b70; to 'JaxprTracer' at 0x7c31a1c963e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  78%|███████▊  | 66/85 [03:04<00:54,  2.89s/it, loss=0.333]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.1702638  0.03319244 0.57720107 0.57375205 0.05547343 0.37220424
 0.0511137  0.05487666 0.09371097 0.2087094  0.14856662 0.0099921
 0.20535435 0.3807967  0.06112063 0.08140048], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.1702638 , 0.03319244, 0.57720107, 0.57375205, 0.05547343,
       0.37220424, 0.0511137 , 0.05487666, 0.09371097, 0.2087094 ,
       0.14856662, 0.0099921 , 0.20535435, 0.3807967 , 0.06112063,
       0.08140048], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b878d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16d58f0; to 'JaxprTracer' at 0x7c31a16d7330>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  79%|███████▉  | 67/85 [03:07<00:50,  2.78s/it, loss=0.192]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02313549 3.2220292  1.1178384  0.07131696 0.05989496 0.03256504
 0.25660878 0.0106192  1.2391621  1.2886636  0.06853556 0.03231253
 0.04207636 0.05299752 0.1068589  0.09237856], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02313549, 3.2220292 , 1.1178384 , 0.07131696, 0.05989496,
       0.03256504, 0.25660878, 0.0106192 , 1.2391621 , 1.2886636 ,
       0.06853556, 0.03231253, 0.04207636, 0.05299752, 0.1068589 ,
       0.09237856], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119dc80>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e85a80; to 'JaxprTracer' at 0x7c31a0e85fd0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  80%|████████  | 68/85 [03:10<00:49,  2.90s/it, loss=0.482]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02469457 0.06489547 0.02142663 0.6366961  0.06929565 2.1101308
 0.05563829 0.0250369  2.5917866  0.4050827  0.3456793  2.220981
 0.7086661  0.40551773 0.18099283 0.2295369 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02469457, 0.06489547, 0.02142663, 0.6366961 , 0.06929565,
       2.1101308 , 0.05563829, 0.0250369 , 2.5917866 , 0.4050827 ,
       0.3456793 , 2.220981  , 0.7086661 , 0.40551773, 0.18099283,
       0.2295369 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c4e0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a2cfe0; to 'JaxprTracer' at 0x7c31a138c810>], out_avals=[ShapedArray(float32[16])], primitive=pjit, 

Training:  81%|████████  | 69/85 [03:12<00:45,  2.82s/it, loss=0.631]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01482717 1.899234   0.02512467 0.2944314  0.00982497 0.5154051
 0.33772364 0.04815595 0.0213698  0.2685706  0.02328726 0.3364988
 0.62724525 0.06855281 0.28742716 0.06658658], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01482717, 1.899234  , 0.02512467, 0.2944314 , 0.00982497,
       0.5154051 , 0.33772364, 0.04815595, 0.0213698 , 0.2685706 ,
       0.02328726, 0.3364988 , 0.62724525, 0.06855281, 0.28742716,
       0.06658658], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657c90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15f2160; to 'JaxprTracer' at 0x7c31a15f3ba0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  82%|████████▏ | 70/85 [03:15<00:41,  2.75s/it, loss=0.303]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01080342 0.7587257  0.04339585 0.18140595 0.01917411 2.2146351
 3.7902908  0.08174866 0.06757436 0.0057132  0.02140119 2.8086357
 0.15339679 2.068756   0.10682494 0.03275472], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01080342, 0.7587257 , 0.04339585, 0.18140595, 0.01917411,
       2.2146351 , 3.7902908 , 0.08174866, 0.06757436, 0.0057132 ,
       0.02140119, 2.8086357 , 0.15339679, 2.068756  , 0.10682494,
       0.03275472], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119cd00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15f1b20; to 'JaxprTracer' at 0x7c31a15f3ab0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  84%|████████▎ | 71/85 [03:18<00:39,  2.81s/it, loss=0.773]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01122493 0.23759155 0.08926693 0.0350646  0.10347684 0.01557449
 0.34234414 0.09454784 0.08879439 0.09802632 0.0461347  0.38368574
 0.01737399 0.25314364 0.05331244 1.1597657 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01122493, 0.23759155, 0.08926693, 0.0350646 , 0.10347684,
       0.01557449, 0.34234414, 0.09454784, 0.08879439, 0.09802632,
       0.0461347 , 0.38368574, 0.01737399, 0.25314364, 0.05331244,
       1.1597657 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655d50>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e9f510; to 'JaxprTracer' at 0x7c31a0e9e4d0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  85%|████████▍ | 72/85 [03:21<00:36,  2.81s/it, loss=0.189]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05148221 1.4838303  0.0491285  0.1214733  1.7409954  0.3747502
 0.03559473 0.01579829 0.4253853  0.16394971 1.1087244  0.02058966
 0.05343676 0.01935992 1.0744295  0.46938336], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05148221, 1.4838303 , 0.0491285 , 0.1214733 , 1.7409954 ,
       0.3747502 , 0.03559473, 0.01579829, 0.4253853 , 0.16394971,
       1.1087244 , 0.02058966, 0.05343676, 0.01935992, 1.0744295 ,
       0.46938336], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e9a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ac0ae0; to 'JaxprTracer' at 0x7c31a1ac3650>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  86%|████████▌ | 73/85 [03:24<00:34,  2.90s/it, loss=0.451]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.1715603  0.0425251  0.8800783  0.08395644 2.521214   0.38258904
 0.26737612 0.20625794 0.22484623 0.02578778 0.01919177 0.10779947
 0.10302334 0.07866488 0.55390334 0.16020583], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.1715603 , 0.0425251 , 0.8800783 , 0.08395644, 2.521214  ,
       0.38258904, 0.26737612, 0.20625794, 0.22484623, 0.02578778,
       0.01919177, 0.10779947, 0.10302334, 0.07866488, 0.55390334,
       0.16020583], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654890>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a105ae80; to 'JaxprTracer' at 0x7c31a1193060>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  87%|████████▋ | 74/85 [03:26<00:30,  2.79s/it, loss=0.364]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05832025 0.13668731 0.07156835 0.09201199 0.01792669 0.03138529
 0.12204049 0.23344068 0.03397422 0.20861709 0.22310817 0.15484501
 0.05426125 0.09785133 0.02273821 0.19999678], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05832025, 0.13668731, 0.07156835, 0.09201199, 0.01792669,
       0.03138529, 0.12204049, 0.23344068, 0.03397422, 0.20861709,
       0.22310817, 0.15484501, 0.05426125, 0.09785133, 0.02273821,
       0.19999678], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a67900>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1265850; to 'JaxprTracer' at 0x7c31a12663e0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  88%|████████▊ | 75/85 [03:29<00:27,  2.71s/it, loss=0.11]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.022192   1.597303   0.04239862 0.01809072 0.81576306 0.07577635
 0.30490786 0.7400553  0.11383669 0.07060529 0.04359194 0.01154457
 0.7267121  1.412753   3.442303   0.08728128], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.022192  , 1.597303  , 0.04239862, 0.01809072, 0.81576306,
       0.07577635, 0.30490786, 0.7400553 , 0.11383669, 0.07060529,
       0.04359194, 0.01154457, 0.7267121 , 1.412753  , 3.442303  ,
       0.08728128], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84a90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1642d40; to 'JaxprTracer' at 0x7c31a1642fc0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  89%|████████▉ | 76/85 [03:32<00:25,  2.79s/it, loss=0.595]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([2.0998225  0.04692656 0.39288908 0.05529737 0.08894905 0.11784524
 0.05662185 0.17597596 0.38462302 0.44671395 0.0844569  0.06836635
 0.04983362 0.872023   0.60546434 0.25764287], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([2.0998225 , 0.04692656, 0.39288908, 0.05529737, 0.08894905,
       0.11784524, 0.05662185, 0.17597596, 0.38462302, 0.44671395,
       0.0844569 , 0.06836635, 0.04983362, 0.872023  , 0.60546434,
       0.25764287], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85200>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ef2700; to 'JaxprTracer' at 0x7c31a1ef3e20>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  91%|█████████ | 77/85 [03:35<00:22,  2.81s/it, loss=0.363]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.06917941 0.0480003  0.08925679 0.30369878 0.04617261 0.10952348
 0.03613091 0.0257333  0.18868832 1.0415956  0.14571552 2.4376647
 0.0205398  0.0629856  1.0042775  0.04640035], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.06917941, 0.0480003 , 0.08925679, 0.30369878, 0.04617261,
       0.10952348, 0.03613091, 0.0257333 , 0.18868832, 1.0415956 ,
       0.14571552, 2.4376647 , 0.0205398 , 0.0629856 , 1.0042775 ,
       0.04640035], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e550>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31aa4c9350; to 'JaxprTracer' at 0x7c31a13f81d0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  92%|█████████▏| 78/85 [03:38<00:20,  2.87s/it, loss=0.355]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01410351 0.03374939 0.134123   0.04100859 0.01102982 0.0234496
 0.38730425 0.06611257 0.06871799 0.07923781 0.06132875 0.14937145
 0.08233861 0.18911062 0.35506532 0.80092806], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01410351, 0.03374939, 0.134123  , 0.04100859, 0.01102982,
       0.0234496 , 0.38730425, 0.06611257, 0.06871799, 0.07923781,
       0.06132875, 0.14937145, 0.08233861, 0.18911062, 0.35506532,
       0.80092806], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f7b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15a73d0; to 'JaxprTracer' at 0x7c31a15a4d60>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  93%|█████████▎| 79/85 [03:40<00:16,  2.76s/it, loss=0.156]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08969717 1.1429303  0.08223397 0.08734584 0.02145206 0.17152807
 1.0358894  0.02006004 0.06620496 0.11813297 1.0747693  0.0138932
 0.10241093 0.05977111 0.07459751 0.1290054 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08969717, 1.1429303 , 0.08223397, 0.08734584, 0.02145206,
       0.17152807, 1.0358894 , 0.02006004, 0.06620496, 0.11813297,
       1.0747693 , 0.0138932 , 0.10241093, 0.05977111, 0.07459751,
       0.1290054 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f160>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15a70b0; to 'JaxprTracer' at 0x7c31a15a4680>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  94%|█████████▍| 80/85 [03:43<00:13,  2.69s/it, loss=0.268]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.17618479 0.20730974 0.07867116 0.02688271 0.06955834 0.25534084
 0.7953696  0.4178054  0.01817747 0.12318297 0.03293987 0.04814709
 0.7987348  0.02557714 0.06271659 1.6411083 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.17618479, 0.20730974, 0.07867116, 0.02688271, 0.06955834,
       0.25534084, 0.7953696 , 0.4178054 , 0.01817747, 0.12318297,
       0.03293987, 0.04814709, 0.7987348 , 0.02557714, 0.06271659,
       1.6411083 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fb60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1254540; to 'JaxprTracer' at 0x7c31a1255350>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  95%|█████████▌| 81/85 [03:45<00:10,  2.65s/it, loss=0.299]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00652386 0.22762644 0.05445714 0.77011496 0.02778469 1.3488843
 0.2405221  0.13181023 0.00901309 1.2853985  0.9145181  0.10917938
 0.2814377  2.8098524  0.18720004 0.3811502 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00652386, 0.22762644, 0.05445714, 0.77011496, 0.02778469,
       1.3488843 , 0.2405221 , 0.13181023, 0.00901309, 1.2853985 ,
       0.9145181 , 0.10917938, 0.2814377 , 2.8098524 , 0.18720004,
       0.3811502 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196ef00>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1345d00; to 'JaxprTracer' at 0x7c31a1346250>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  96%|█████████▋| 82/85 [03:49<00:08,  2.90s/it, loss=0.549]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02133351 0.01013042 0.1514804  0.01861729 0.97752106 0.14898576
 0.36577103 0.27192283 0.61944854 0.4140563  0.01252076 0.13547209
 0.9385684  0.10911375 0.02082016 0.3713012 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02133351, 0.01013042, 0.1514804 , 0.01861729, 0.97752106,
       0.14898576, 0.36577103, 0.27192283, 0.61944854, 0.4140563 ,
       0.01252076, 0.13547209, 0.9385684 , 0.10911375, 0.02082016,
       0.3713012 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f1d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f3650; to 'JaxprTracer' at 0x7c31a12f2b10>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  98%|█████████▊| 83/85 [03:52<00:05,  2.88s/it, loss=0.287]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.5210075  0.03246245 0.10983391 0.5382485  0.01599681 0.0413956
 0.05182818 0.04369806 0.87394977 0.20204684 0.15131862 0.13507192
 0.49945068 0.05478413 0.12236088 0.0353461 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.5210075 , 0.03246245, 0.10983391, 0.5382485 , 0.01599681,
       0.0413956 , 0.05182818, 0.04369806, 0.87394977, 0.20204684,
       0.15131862, 0.13507192, 0.49945068, 0.05478413, 0.12236088,
       0.0353461 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f480>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f2e80; to 'JaxprTracer' at 0x7c31a12f07c0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  99%|█████████▉| 84/85 [03:54<00:02,  2.77s/it, loss=0.214]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.13605314 0.06714675 0.12067167 0.15090682 0.23070349 0.02359421
 0.9346942  0.01414229 0.06962505 0.07523481 0.32471523 0.09950428
 0.5509988  0.03711602 0.04683363 0.18572043], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.13605314, 0.06714675, 0.12067167, 0.15090682, 0.23070349,
       0.02359421, 0.9346942 , 0.01414229, 0.06962505, 0.07523481,
       0.32471523, 0.09950428, 0.5509988 , 0.03711602, 0.04683363,
       0.18572043], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a66420>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f0f90; to 'JaxprTracer' at 0x7c31a12f3b50>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training: 100%|██████████| 85/85 [03:57<00:00,  2.79s/it, loss=0.192]


Evaluating after epoch 2...


Evaluating: 100%|██████████| 21/21 [00:12<00:00,  1.63it/s]


Accuracy: 0.7708
Precision: 0.9355
Recall: 0.6270
Confusion Matrix:
[[143   8]
 [ 69 116]]
Epoch 3/3


Training:   0%|          | 0/85 [00:00<?, ?it/s]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03619288 0.0786003  0.13703848 0.32872227 2.1032603  0.05101175
 0.03572334 0.454609   0.5399461  0.15347134 0.1279455  2.3937404
 0.01149849 2.0821345  0.12808248 0.6686789 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03619288, 0.0786003 , 0.13703848, 0.32872227, 2.1032603 ,
       0.05101175, 0.03572334, 0.454609  , 0.5399461 , 0.15347134,
       0.1279455 , 2.3937404 , 0.01149849, 2.0821345 , 0.12808248,
       0.6686789 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657330>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c34d83f84f0; to 'JaxprTracer' at 0x7c31a1c56bb0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   1%|          | 1/85 [00:03<04:33,  3.26s/it, loss=0.583]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.67525536 0.14296083 0.069033   2.1457005  0.04550921 0.09357008
 0.30506578 0.02810301 0.70170236 1.2501658  0.12110858 0.07468845
 0.8596852  1.3638626  0.12341671 0.03628302], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.67525536, 0.14296083, 0.069033  , 2.1457005 , 0.04550921,
       0.09357008, 0.30506578, 0.02810301, 0.70170236, 1.2501658 ,
       0.12110858, 0.07468845, 0.8596852 , 1.3638626 , 0.12341671,
       0.03628302], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16566d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e3f150; to 'JaxprTracer' at 0x7c31a0e3f790>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   2%|▏         | 2/85 [00:06<04:21,  3.15s/it, loss=0.502]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05610106 0.24938345 0.03639349 0.03210938 0.06750149 0.1455363
 0.03457657 0.05580696 2.2009387  2.6731372  0.84900653 0.09296782
 0.07099355 0.6669437  0.05722241 0.5679516 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05610106, 0.24938345, 0.03639349, 0.03210938, 0.06750149,
       0.1455363 , 0.03457657, 0.05580696, 2.2009387 , 2.6731372 ,
       0.84900653, 0.09296782, 0.07099355, 0.6669437 , 0.05722241,
       0.5679516 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654d30>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0d2eed0; to 'JaxprTracer' at 0x7c31a0d2e200>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   4%|▎         | 3/85 [00:08<03:54,  2.86s/it, loss=0.491]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.06203663 0.03761729 2.0914404  0.09757002 0.14283332 1.1193001
 0.13900046 0.9133365  0.05914342 0.31650224 0.1197442  0.15146688
 0.38960937 1.9530259  0.081141   0.02381738], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.06203663, 0.03761729, 2.0914404 , 0.09757002, 0.14283332,
       1.1193001 , 0.13900046, 0.9133365 , 0.05914342, 0.31650224,
       0.1197442 , 0.15146688, 0.38960937, 1.9530259 , 0.081141  ,
       0.02381738], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655160>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16f84a0; to 'JaxprTracer' at 0x7c31a16f8b30>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   5%|▍         | 4/85 [00:11<03:40,  2.72s/it, loss=0.481]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7118265  0.62803644 0.8117889  0.0890785  0.01629041 0.05530966
 0.10220464 0.089076   0.17034    2.26541    0.8978087  0.02836076
 0.01543963 0.07489851 2.3394237  0.37780732], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7118265 , 0.62803644, 0.8117889 , 0.0890785 , 0.01629041,
       0.05530966, 0.10220464, 0.089076  , 0.17034   , 2.26541   ,
       0.8978087 , 0.02836076, 0.01543963, 0.07489851, 2.3394237 ,
       0.37780732], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65950>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a7dc10; to 'JaxprTracer' at 0x7c31a1a7fbf0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   6%|▌         | 5/85 [00:13<03:32,  2.66s/it, loss=0.542]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.14352538 0.05581103 0.01653104 0.0197222  0.26455638 0.01134669
 0.12205093 0.10823391 0.1747182  0.0135588  0.0122264  0.02077428
 0.49398792 0.02391248 0.05314855 0.18064292], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.14352538, 0.05581103, 0.01653104, 0.0197222 , 0.26455638,
       0.01134669, 0.12205093, 0.10823391, 0.1747182 , 0.0135588 ,
       0.0122264 , 0.02077428, 0.49398792, 0.02391248, 0.05314855,
       0.18064292], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fe40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12678d0; to 'JaxprTracer' at 0x7c31a0e11170>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   7%|▋         | 6/85 [00:17<03:50,  2.91s/it, loss=0.107]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08880475 0.47281963 0.00766065 0.0378261  0.02168714 0.02096422
 0.00800154 2.6948712  3.0217404  0.00488339 0.5875683  0.0074148
 0.03887502 0.03834112 0.0121747  0.02758687], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08880475, 0.47281963, 0.00766065, 0.0378261 , 0.02168714,
       0.02096422, 0.00800154, 2.6948712 , 3.0217404 , 0.00488339,
       0.5875683 , 0.0074148 , 0.03887502, 0.03834112, 0.0121747 ,
       0.02758687], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196eaa0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1266390; to 'JaxprTracer' at 0x7c31a1264e00>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:   8%|▊         | 7/85 [00:20<03:48,  2.93s/it, loss=0.443]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.17159283 0.8026813  0.01130108 0.04371621 0.01177786 0.01735138
 0.04984167 0.00521122 0.03747273 1.3071799  0.01823729 0.02140866
 0.33916894 0.01219413 0.30448905 0.31266606], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.17159283, 0.8026813 , 0.01130108, 0.04371621, 0.01177786,
       0.01735138, 0.04984167, 0.00521122, 0.03747273, 1.3071799 ,
       0.01823729, 0.02140866, 0.33916894, 0.01219413, 0.30448905,
       0.31266606], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e0d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1265490; to 'JaxprTracer' at 0x7c31a1267a60>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:   9%|▉         | 8/85 [00:22<03:35,  2.80s/it, loss=0.217]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00701891 2.4859881  0.3963607  2.2033527  0.03354652 0.21654443
 0.5554671  0.00327218 0.55437446 0.00685471 0.0181022  0.15002811
 0.00799479 0.03288462 0.05293307 0.03807733], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00701891, 2.4859881 , 0.3963607 , 2.2033527 , 0.03354652,
       0.21654443, 0.5554671 , 0.00327218, 0.55437446, 0.00685471,
       0.0181022 , 0.15002811, 0.00799479, 0.03288462, 0.05293307,
       0.03807733], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e500>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b4b9c0; to 'JaxprTracer' at 0x7c31a1b49990>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  11%|█         | 9/85 [00:25<03:26,  2.72s/it, loss=0.423]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.14338377 0.00448982 0.00890368 2.2165616  0.0702312  0.61781776
 1.437827   0.00830766 0.01774589 0.19718066 0.11580252 0.34758365
 0.23818798 0.21399462 0.17091964 0.01123177], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.14338377, 0.00448982, 0.00890368, 2.2165616 , 0.0702312 ,
       0.61781776, 1.437827  , 0.00830766, 0.01774589, 0.19718066,
       0.11580252, 0.34758365, 0.23818798, 0.21399462, 0.17091964,
       0.01123177], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84290>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b49c60; to 'JaxprTracer' at 0x7c31a1d20ea0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  12%|█▏        | 10/85 [00:27<03:20,  2.67s/it, loss=0.364]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.28453422 0.71422184 0.00879179 0.02030351 0.01851841 3.5735817
 0.19756944 0.09942722 0.29779202 0.19272178 0.7150005  0.36193252
 2.4040031  1.0904372  0.01046397 0.36112094], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.28453422, 0.71422184, 0.00879179, 0.02030351, 0.01851841,
       3.5735817 , 0.19756944, 0.09942722, 0.29779202, 0.19272178,
       0.7150005 , 0.36193252, 2.4040031 , 1.0904372 , 0.01046397,
       0.36112094], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fad0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a10d2c50; to 'JaxprTracer' at 0x7c31a10d14e0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  13%|█▎        | 11/85 [00:31<03:41,  3.00s/it, loss=0.647]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.06668148 0.8932536  0.00986947 0.02308715 0.08989026 0.01189579
 0.24230185 0.07577524 0.57141733 0.39385864 0.04426516 0.13693869
 0.00855224 0.3584628  1.072647   0.01377093], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.06668148, 0.8932536 , 0.00986947, 0.02308715, 0.08989026,
       0.01189579, 0.24230185, 0.07577524, 0.57141733, 0.39385864,
       0.04426516, 0.13693869, 0.00855224, 0.3584628 , 1.072647  ,
       0.01377093], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a644b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a148d300; to 'JaxprTracer' at 0x7c31a148e250>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  14%|█▍        | 12/85 [00:34<03:30,  2.89s/it, loss=0.251]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02275242 0.7828071  0.08828769 0.5224034  0.00701086 1.0516309
 0.92365456 0.03774276 0.00790042 2.1673849  0.15348013 0.73104537
 0.01345894 0.00652907 0.00773459 0.09025054], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02275242, 0.7828071 , 0.08828769, 0.5224034 , 0.00701086,
       1.0516309 , 0.92365456, 0.03774276, 0.00790042, 2.1673849 ,
       0.15348013, 0.73104537, 0.01345894, 0.00652907, 0.00773459,
       0.09025054], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64c70>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a0a1b0; to 'JaxprTracer' at 0x7c31a1a0a250>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  15%|█▌        | 13/85 [00:36<03:21,  2.81s/it, loss=0.413]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.07475449 0.04039557 0.12223662 0.03397434 0.80105233 0.21654405
 0.31547078 0.06992166 0.14874826 0.19900516 0.05036766 0.5521203
 0.5791097  0.3697919  0.7043911  0.03669288], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.07475449, 0.04039557, 0.12223662, 0.03397434, 0.80105233,
       0.21654405, 0.31547078, 0.06992166, 0.14874826, 0.19900516,
       0.05036766, 0.5521203 , 0.5791097 , 0.3697919 , 0.7043911 ,
       0.03669288], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196daa0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c56f70; to 'JaxprTracer' at 0x7c31a16ac540>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  16%|█▋        | 14/85 [00:39<03:15,  2.75s/it, loss=0.27]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02634468 0.87172776 0.26727986 0.34145626 0.9821837  0.05953201
 0.02082857 0.23417391 0.03572817 0.0490772  0.09886725 1.4661732
 0.02251514 0.02629185 0.84669125 0.03561049], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02634468, 0.87172776, 0.26727986, 0.34145626, 0.9821837 ,
       0.05953201, 0.02082857, 0.23417391, 0.03572817, 0.0490772 ,
       0.09886725, 1.4661732 , 0.02251514, 0.02629185, 0.84669125,
       0.03561049], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b845b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a15589f0; to 'JaxprTracer' at 0x7c31a155ad90>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  18%|█▊        | 15/85 [00:42<03:16,  2.81s/it, loss=0.337]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02681018 0.4139858  0.06910476 0.06257405 0.09071476 0.29131764
 0.10279265 0.05309519 0.02927587 0.05629433 0.21792471 0.09822527
 0.5990039  0.09391491 0.04815641 0.6184737 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02681018, 0.4139858 , 0.06910476, 0.06257405, 0.09071476,
       0.29131764, 0.10279265, 0.05309519, 0.02927587, 0.05629433,
       0.21792471, 0.09822527, 0.5990039 , 0.09391491, 0.04815641,
       0.6184737 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e840>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b168e0; to 'JaxprTracer' at 0x7c31a1b15e90>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  19%|█▉        | 16/85 [00:46<03:29,  3.04s/it, loss=0.179]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03794042 0.0197721  0.02810568 0.12332081 1.3217603  0.01503399
 0.05910477 0.13945867 0.09021459 0.07345682 0.04231372 0.04986492
 0.24998163 0.73063827 0.02417212 1.1519295 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03794042, 0.0197721 , 0.02810568, 0.12332081, 1.3217603 ,
       0.01503399, 0.05910477, 0.13945867, 0.09021459, 0.07345682,
       0.04231372, 0.04986492, 0.24998163, 0.73063827, 0.02417212,
       1.1519295 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196ecf0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1140cc0; to 'JaxprTracer' at 0x7c31a15d2ac0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  20%|██        | 17/85 [00:48<03:17,  2.91s/it, loss=0.26]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08609822 0.00970892 0.01702492 0.07917867 0.08577015 1.2167273
 0.70799094 0.01628583 0.18842532 0.53136164 0.02499865 0.03120112
 0.02312571 0.04316536 0.33917293 0.05588656], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08609822, 0.00970892, 0.01702492, 0.07917867, 0.08577015,
       1.2167273 , 0.70799094, 0.01628583, 0.18842532, 0.53136164,
       0.02499865, 0.03120112, 0.02312571, 0.04316536, 0.33917293,
       0.05588656], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196eb40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16fb600; to 'JaxprTracer' at 0x7c31a1ef3e20>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  21%|██        | 18/85 [00:51<03:08,  2.81s/it, loss=0.216]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03842865 1.3930643  0.6911178  0.0225134  0.2922556  0.18785614
 0.0519589  0.02400152 0.21879154 0.52403647 0.28849638 0.09544609
 0.03541698 0.16797072 0.25020713 0.06292034], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03842865, 1.3930643 , 0.6911178 , 0.0225134 , 0.2922556 ,
       0.18785614, 0.0519589 , 0.02400152, 0.21879154, 0.52403647,
       0.28849638, 0.09544609, 0.03541698, 0.16797072, 0.25020713,
       0.06292034], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196f0d0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1bd8ae0; to 'JaxprTracer' at 0x7c31a1bd88b0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  22%|██▏       | 19/85 [00:53<03:00,  2.73s/it, loss=0.272]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01410762 0.00797942 0.01074434 0.09576526 0.1305738  0.05677965
 0.04792417 0.39620087 0.04555705 0.23816328 0.19284713 0.5572743
 0.01408329 0.09263764 0.9251771  0.28288537], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01410762, 0.00797942, 0.01074434, 0.09576526, 0.1305738 ,
       0.05677965, 0.04792417, 0.39620087, 0.04555705, 0.23816328,
       0.19284713, 0.5572743 , 0.01408329, 0.09263764, 0.9251771 ,
       0.28288537], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fec0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cbb380; to 'JaxprTracer' at 0x7c31a1cb8540>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  24%|██▎       | 20/85 [00:57<03:07,  2.88s/it, loss=0.194]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.2230374  2.318106   0.06066894 0.08022738 0.07335413 0.64154744
 0.25765964 0.03021548 1.17311    0.17636079 0.33013552 1.0520985
 0.0848361  0.01659401 0.20496877 0.0558064 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.2230374 , 2.318106  , 0.06066894, 0.08022738, 0.07335413,
       0.64154744, 0.25765964, 0.03021548, 1.17311   , 0.17636079,
       0.33013552, 1.0520985 , 0.0848361 , 0.01659401, 0.20496877,
       0.0558064 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c820>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16c7bf0; to 'JaxprTracer' at 0x7c31a0e43100>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  25%|██▍       | 21/85 [01:00<03:09,  2.96s/it, loss=0.424]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02013739 1.6106663  0.09593412 0.0312775  0.14484976 0.04341012
 0.02128041 0.18780348 1.3290348  0.37863722 0.03996637 0.03155082
 0.6251861  0.7083147  0.00701773 0.01629287], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02013739, 1.6106663 , 0.09593412, 0.0312775 , 0.14484976,
       0.04341012, 0.02128041, 0.18780348, 1.3290348 , 0.37863722,
       0.03996637, 0.03155082, 0.6251861 , 0.7083147 , 0.00701773,
       0.01629287], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119cdd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a164ef70; to 'JaxprTracer' at 0x7c31a164c590>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  26%|██▌       | 22/85 [01:02<02:59,  2.85s/it, loss=0.331]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.18460265 0.02617713 0.75856686 0.02996447 0.07075879 0.5937855
 0.03697427 1.4375892  0.02761308 0.02404679 0.03207174 0.06776509
 0.13983728 0.03870795 0.02487726 0.07146847], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.18460265, 0.02617713, 0.75856686, 0.02996447, 0.07075879,
       0.5937855 , 0.03697427, 1.4375892 , 0.02761308, 0.02404679,
       0.03207174, 0.06776509, 0.13983728, 0.03870795, 0.02487726,
       0.07146847], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655a80>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ef2340; to 'JaxprTracer' at 0x7c31a15f3060>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  27%|██▋       | 23/85 [01:05<02:50,  2.75s/it, loss=0.223]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.3845594  0.03110441 0.6142696  1.4251957  1.9418182  0.01735583
 0.8445349  0.19691332 0.78187525 1.7687641  0.10645977 0.07817098
 0.39983118 0.02391842 0.13472173 0.02178874], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.3845594 , 0.03110441, 0.6142696 , 1.4251957 , 1.9418182 ,
       0.01735583, 0.8445349 , 0.19691332, 0.78187525, 1.7687641 ,
       0.10645977, 0.07817098, 0.39983118, 0.02391842, 0.13472173,
       0.02178874], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657220>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e71b70; to 'JaxprTracer' at 0x7c31a0d40ea0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  28%|██▊       | 24/85 [01:08<02:48,  2.77s/it, loss=0.548]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.04914065 0.04664478 1.9049817  0.02820443 0.02484005 0.7431491
 0.02321086 0.01237547 0.03825828 0.01126666 2.48124    0.25502598
 0.03605272 0.14653212 0.21580246 0.01997077], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.04914065, 0.04664478, 1.9049817 , 0.02820443, 0.02484005,
       0.7431491 , 0.02321086, 0.01237547, 0.03825828, 0.01126666,
       2.48124   , 0.25502598, 0.03605272, 0.14653212, 0.21580246,
       0.01997077], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654f60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1fdf7e0; to 'JaxprTracer' at 0x7c31a1055080>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  29%|██▉       | 25/85 [01:11<02:53,  2.90s/it, loss=0.377]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.3004113  0.0911285  0.03910455 0.07362495 0.65529853 0.40864524
 0.05693575 0.01842198 0.02958887 0.02581485 0.10822289 0.02527951
 0.0475656  0.03884957 0.49564567 0.30692285], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.3004113 , 0.0911285 , 0.03910455, 0.07362495, 0.65529853,
       0.40864524, 0.05693575, 0.01842198, 0.02958887, 0.02581485,
       0.10822289, 0.02527951, 0.0475656 , 0.03884957, 0.49564567,
       0.30692285], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657f90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1610360; to 'JaxprTracer' at 0x7c31a1611620>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  31%|███       | 26/85 [01:14<02:52,  2.93s/it, loss=0.17]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.07429098 0.14327645 0.06758115 0.04628404 0.01029833 0.09585332
 0.8087888  0.00975638 0.1782018  0.02750279 0.13875581 0.05811702
 2.0538902  0.13294579 0.02760508 0.33130702], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.07429098, 0.14327645, 0.06758115, 0.04628404, 0.01029833,
       0.09585332, 0.8087888 , 0.00975638, 0.1782018 , 0.02750279,
       0.13875581, 0.05811702, 2.0538902 , 0.13294579, 0.02760508,
       0.33130702], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1656460>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1346e80; to 'JaxprTracer' at 0x7c31a1346a20>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  32%|███▏      | 27/85 [01:16<02:43,  2.81s/it, loss=0.263]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.03800745 0.67970294 0.16283943 0.03274745 0.14690755 0.08098742
 0.11168256 2.3262641  0.9185637  0.10662437 0.14998665 0.02971836
 0.28194803 0.07994282 0.08720622 0.00720545], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.03800745, 0.67970294, 0.16283943, 0.03274745, 0.14690755,
       0.08098742, 0.11168256, 2.3262641 , 0.9185637 , 0.10662437,
       0.14998665, 0.02971836, 0.28194803, 0.07994282, 0.08720622,
       0.00720545], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16570f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a121b060; to 'JaxprTracer' at 0x7c31a1545b20>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  33%|███▎      | 28/85 [01:19<02:35,  2.73s/it, loss=0.328]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.45900828 0.01913633 0.22583102 0.51911765 0.43153697 0.02925457
 0.06613846 0.04269018 0.03136507 0.01804987 0.2672888  0.01469315
 0.01755474 0.28553465 0.13267156 0.13985936], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.45900828, 0.01913633, 0.22583102, 0.51911765, 0.43153697,
       0.02925457, 0.06613846, 0.04269018, 0.03136507, 0.01804987,
       0.2672888 , 0.01469315, 0.01755474, 0.28553465, 0.13267156,
       0.13985936], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654bd0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1190ea0; to 'JaxprTracer' at 0x7c31a148f3d0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  34%|███▍      | 29/85 [01:22<02:41,  2.88s/it, loss=0.169]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.04168538 0.0611323  0.01983031 0.04872915 0.03560382 0.01946187
 0.04972281 0.19021954 0.8036389  0.6522785  0.03509084 0.11386637
 0.11690421 0.07856096 0.04689892 0.04854406], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.04168538, 0.0611323 , 0.01983031, 0.04872915, 0.03560382,
       0.01946187, 0.04972281, 0.19021954, 0.8036389 , 0.6522785 ,
       0.03509084, 0.11386637, 0.11690421, 0.07856096, 0.04689892,
       0.04854406], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65700>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a14b8040; to 'JaxprTracer' at 0x7c31a14b9cb0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  35%|███▌      | 30/85 [01:25<02:37,  2.87s/it, loss=0.148]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.32679558 0.06511519 0.09280106 0.0621332  0.2767983  0.65871996
 0.01212524 0.5677121  0.21558225 0.1909423  0.06683849 0.4155207
 0.03259365 0.13435565 0.00655299 0.05153079], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.32679558, 0.06511519, 0.09280106, 0.0621332 , 0.2767983 ,
       0.65871996, 0.01212524, 0.5677121 , 0.21558225, 0.1909423 ,
       0.06683849, 0.4155207 , 0.03259365, 0.13435565, 0.00655299,
       0.05153079], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64800>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e41cb0; to 'JaxprTracer' at 0x7c31a0e433d0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  36%|███▋      | 31/85 [01:27<02:29,  2.76s/it, loss=0.199]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.217908   0.01834907 0.06784228 0.16575086 0.039891   0.0038021
 0.01147881 0.00282722 0.17283276 0.00894539 0.04214527 0.91734374
 0.00665342 0.02587874 0.07442179 0.04213911], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.217908  , 0.01834907, 0.06784228, 0.16575086, 0.039891  ,
       0.0038021 , 0.01147881, 0.00282722, 0.17283276, 0.00894539,
       0.04214527, 0.91734374, 0.00665342, 0.02587874, 0.07442179,
       0.04213911], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119fb90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ad4ef0; to 'JaxprTracer' at 0x7c31a1ad7a10>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  38%|███▊      | 32/85 [01:30<02:29,  2.82s/it, loss=0.176]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.46009454 0.00832599 0.05478007 0.01238701 0.1405518  0.01686282
 0.1774984  0.01127055 0.04271553 1.0489378  0.03167395 0.03544701
 0.16368559 0.02191412 0.30298558 0.38144547], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.46009454, 0.00832599, 0.05478007, 0.01238701, 0.1405518 ,
       0.01686282, 0.1774984 , 0.01127055, 0.04271553, 1.0489378 ,
       0.03167395, 0.03544701, 0.16368559, 0.02191412, 0.30298558,
       0.38144547], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119d7a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e133d0; to 'JaxprTracer' at 0x7c31a0e10450>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  39%|███▉      | 33/85 [01:33<02:22,  2.75s/it, loss=0.182]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.7301822  1.5165762  0.40722293 0.0617881  0.04795724 1.9827589
 0.0099908  0.064554   0.07978643 0.33356535 0.00887166 0.01173651
 0.06060925 0.04893032 0.40744823 0.0701753 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.7301822 , 1.5165762 , 0.40722293, 0.0617881 , 0.04795724,
       1.9827589 , 0.0099908 , 0.064554  , 0.07978643, 0.33356535,
       0.00887166, 0.01173651, 0.06060925, 0.04893032, 0.40744823,
       0.0701753 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119eab0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ac0180; to 'JaxprTracer' at 0x7c31a1514270>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  40%|████      | 34/85 [01:36<02:27,  2.90s/it, loss=0.365]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00748177 0.0442533  0.0131418  0.30503502 0.03395786 0.07770333
 0.09376959 0.01751831 0.09437852 0.01088857 0.02755162 0.01294577
 0.10419086 0.1780733  0.02376477 0.3672255 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00748177, 0.0442533 , 0.0131418 , 0.30503502, 0.03395786,
       0.07770333, 0.09376959, 0.01751831, 0.09437852, 0.01088857,
       0.02755162, 0.01294577, 0.10419086, 0.1780733 , 0.02376477,
       0.3672255 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64300>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1e6c9f0; to 'JaxprTracer' at 0x7c31a1e6fe20>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  41%|████      | 35/85 [01:39<02:19,  2.78s/it, loss=0.0882]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00888289 0.12468864 0.10214479 0.01935618 0.0203104  0.09465466
 0.00394839 0.03288623 0.09862003 0.01067334 0.04481107 0.06767407
 0.13982308 0.00918402 0.16545418 0.3912601 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00888289, 0.12468864, 0.10214479, 0.01935618, 0.0203104 ,
       0.09465466, 0.00394839, 0.03288623, 0.09862003, 0.01067334,
       0.04481107, 0.06767407, 0.13982308, 0.00918402, 0.16545418,
       0.3912601 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84b40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1140400; to 'JaxprTracer' at 0x7c31a0d418a0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  42%|████▏     | 36/85 [01:41<02:13,  2.72s/it, loss=0.0834]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01287269 0.28594726 0.04139914 0.496567   0.21404505 0.384176
 0.36224675 0.00538032 0.12810062 0.0712498  0.02006635 0.03371159
 0.01940084 0.04592397 0.07633305 0.01313663], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01287269, 0.28594726, 0.04139914, 0.496567  , 0.21404505,
       0.384176  , 0.36224675, 0.00538032, 0.12810062, 0.0712498 ,
       0.02006635, 0.03371159, 0.01940084, 0.04592397, 0.07633305,
       0.01313663], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a66640>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a2d170; to 'JaxprTracer' at 0x7c31a1d5f3d0>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  44%|████▎     | 37/85 [01:44<02:14,  2.79s/it, loss=0.138]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.25321287 0.02010713 0.02490912 0.15767972 0.00891384 0.29745847
 1.8926862  0.02716791 0.25328812 0.14871877 0.0430763  0.02620848
 0.01714152 0.00605594 0.13485996 0.0148819 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.25321287, 0.02010713, 0.02490912, 0.15767972, 0.00891384,
       0.29745847, 1.8926862 , 0.02716791, 0.25328812, 0.14871877,
       0.0430763 , 0.02620848, 0.01714152, 0.00605594, 0.13485996,
       0.0148819 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1656670>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a11425c0; to 'JaxprTracer' at 0x7c31a1143510>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  45%|████▍     | 38/85 [01:47<02:12,  2.82s/it, loss=0.208]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01134268 0.00375471 0.3931441  0.00567741 0.00297366 0.03244445
 0.03963438 0.00277896 0.06924159 0.14505292 0.05295897 0.15527062
 0.00684382 0.00780284 0.37099993 0.00566496], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01134268, 0.00375471, 0.3931441 , 0.00567741, 0.00297366,
       0.03244445, 0.03963438, 0.00277896, 0.06924159, 0.14505292,
       0.05295897, 0.15527062, 0.00684382, 0.00780284, 0.37099993,
       0.00566496], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b87db0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13fb330; to 'JaxprTracer' at 0x7c31a13f8a40>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  46%|████▌     | 39/85 [01:50<02:14,  2.92s/it, loss=0.0816]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02186642 0.01835633 0.00548477 0.03848532 0.14846602 0.24841388
 0.10256039 0.08037522 0.02164106 0.04732357 0.00625415 0.9811875
 0.03200143 0.1374469  0.01305003 0.5780764 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02186642, 0.01835633, 0.00548477, 0.03848532, 0.14846602,
       0.24841388, 0.10256039, 0.08037522, 0.02164106, 0.04732357,
       0.00625415, 0.9811875 , 0.03200143, 0.1374469 , 0.01305003,
       0.5780764 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1654140>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0dc76f0; to 'JaxprTracer' at 0x7c31a0dc76a0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  47%|████▋     | 40/85 [01:53<02:06,  2.81s/it, loss=0.155]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.1410208e-01 5.8533665e-02 1.6690861e-03 1.8829761e-02 9.4752600e-03
 1.2345564e-02 7.5337306e-02 5.3700704e-02 2.9413398e-02 1.6987058e+00
 2.7992435e-02 6.1581248e-01 1.3871925e-02 1.1250042e+00 7.0869625e-02
 3.8246468e-02], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.1410208e-01, 5.8533665e-02, 1.6690861e-03, 1.8829761e-02,
       9.4752600e-03, 1.2345564e-02, 7.5337306e-02, 5.3700704e-02,
       2.9413398e-02, 1.6987058e+00, 2.7992435e-02, 6.1581248e-01,
       1.3871925e-02, 1.1250042e+00, 7.0869625e-02, 3.8246468e-02],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657430>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  48%|████▊     | 41/85 [01:55<02:00,  2.73s/it, loss=0.248]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.08174624 0.00616139 0.0277208  0.09831107 0.04102942 0.00314919
 0.01491432 0.03636521 0.08519831 0.03637452 0.20494944 0.08300029
 0.5724437  0.01010457 0.01066001 0.03831657], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.08174624, 0.00616139, 0.0277208 , 0.09831107, 0.04102942,
       0.00314919, 0.01491432, 0.03636521, 0.08519831, 0.03637452,
       0.20494944, 0.08300029, 0.5724437 , 0.01010457, 0.01066001,
       0.03831657], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b875e0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa4bd0; to 'JaxprTracer' at 0x7c31a0fa7060>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  49%|████▉     | 42/85 [01:58<02:00,  2.80s/it, loss=0.0844]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01209238 0.02360143 0.02840073 0.02418818 0.0659069  0.0696096
 0.00988009 0.01854182 0.01857142 0.01777868 2.784356   0.0107263
 0.00895638 0.01111895 0.05422444 0.01234474], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01209238, 0.02360143, 0.02840073, 0.02418818, 0.0659069 ,
       0.0696096 , 0.00988009, 0.01854182, 0.01857142, 0.01777868,
       2.784356  , 0.0107263 , 0.00895638, 0.01111895, 0.05422444,
       0.01234474], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84790>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1691c10; to 'JaxprTracer' at 0x7c31a1693380>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  51%|█████     | 43/85 [02:01<01:59,  2.86s/it, loss=0.198]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01213901 0.02374416 0.07353169 0.02717302 0.00875622 0.04292511
 0.00660475 0.33819485 0.01527481 0.2895994  0.00357774 0.02433769
 0.5178143  0.01618239 0.01247225 0.02984333], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01213901, 0.02374416, 0.07353169, 0.02717302, 0.00875622,
       0.04292511, 0.00660475, 0.33819485, 0.01527481, 0.2895994 ,
       0.00357774, 0.02433769, 0.5178143 , 0.01618239, 0.01247225,
       0.02984333], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85820>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b92890; to 'JaxprTracer' at 0x7c31a1b914e0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  52%|█████▏    | 44/85 [02:04<01:58,  2.88s/it, loss=0.0901]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00740817 0.0133872  0.00721161 0.07552623 0.01141906 0.00496844
 0.03688799 0.00733611 0.00813457 0.00451794 0.2011287  0.02007734
 0.02724738 0.15958345 0.01765055 2.550221  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00740817, 0.0133872 , 0.00721161, 0.07552623, 0.01141906,
       0.00496844, 0.03688799, 0.00733611, 0.00813457, 0.00451794,
       0.2011287 , 0.02007734, 0.02724738, 0.15958345, 0.01765055,
       2.550221  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85620>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1955940; to 'JaxprTracer' at 0x7c31a1956930>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  53%|█████▎    | 45/85 [02:07<01:50,  2.77s/it, loss=0.197]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([3.0345127e-03 6.1329505e-03 8.8819424e-03 1.8687657e-03 9.6040213e-01
 1.3413641e+00 1.0587896e-01 8.1447372e-03 2.5844708e-02 5.4412767e-02
 2.3328487e-02 1.9436732e-02 2.8548690e-02 2.0434335e-02 3.1865733e+00
 1.5915502e-02], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([3.0345127e-03, 6.1329505e-03, 8.8819424e-03, 1.8687657e-03,
       9.6040213e-01, 1.3413641e+00, 1.0587896e-01, 8.1447372e-03,
       2.5844708e-02, 5.4412767e-02, 2.3328487e-02, 1.9436732e-02,
       2.8548690e-02, 2.0434335e-02, 3.1865733e+00, 1.5915502e-02],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196ee60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  54%|█████▍    | 46/85 [02:09<01:45,  2.71s/it, loss=0.363]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([2.3627956e-03 2.5329227e+00 1.0208527e+00 3.7081547e-03 3.2194916e-02
 5.5438108e-03 6.3091153e-03 2.3251852e-02 2.0572229e-03 1.8911297e-02
 9.0433294e-03 1.8876789e-02 2.9762335e-02 2.0587149e+00 6.1051073e-03
 2.5223840e-03], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([2.3627956e-03, 2.5329227e+00, 1.0208527e+00, 3.7081547e-03,
       3.2194916e-02, 5.5438108e-03, 6.3091153e-03, 2.3251852e-02,
       2.0572229e-03, 1.8911297e-02, 9.0433294e-03, 1.8876789e-02,
       2.9762335e-02, 2.0587149e+00, 6.1051073e-03, 2.5223840e-03],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196d720>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  55%|█████▌    | 47/85 [02:13<01:53,  2.98s/it, loss=0.361]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.10179083 2.068393   0.51946604 0.02981221 0.01598026 0.02667195
 0.52289104 2.752938   0.00892731 0.0194358  0.286998   0.00629762
 0.03012249 0.09096002 0.00957586 0.10873819], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.10179083, 2.068393  , 0.51946604, 0.02981221, 0.01598026,
       0.02667195, 0.52289104, 2.752938  , 0.00892731, 0.0194358 ,
       0.286998  , 0.00629762, 0.03012249, 0.09096002, 0.00957586,
       0.10873819], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84b90>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1606430; to 'JaxprTracer' at 0x7c31a1607e20>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  56%|█████▋    | 48/85 [02:16<01:53,  3.06s/it, loss=0.412]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01324604 0.05157981 0.03650969 0.02345565 0.00540118 0.14444561
 0.14959976 0.06437215 0.00778995 0.17060609 0.0082411  0.73390335
 0.0058847  0.00149008 0.01416592 0.00724356], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01324604, 0.05157981, 0.03650969, 0.02345565, 0.00540118,
       0.14444561, 0.14959976, 0.06437215, 0.00778995, 0.17060609,
       0.0082411 , 0.73390335, 0.0058847 , 0.00149008, 0.01416592,
       0.00724356], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196d750>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1ad5990; to 'JaxprTracer' at 0x7c31a0e43380>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  58%|█████▊    | 49/85 [02:19<01:44,  2.91s/it, loss=0.0899]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([7.5052574e-02 1.8676966e-02 2.2662214e-03 2.7090413e-02 1.6143347e-01
 1.7711105e-02 3.4583244e-02 1.2276562e-02 8.6112216e-02 2.3790612e+00
 8.5167764e-03 3.2061207e+00 3.5778575e-03 4.6393378e-03 2.1754567e-02
 1.2415503e-02], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([7.5052574e-02, 1.8676966e-02, 2.2662214e-03, 2.7090413e-02,
       1.6143347e-01, 1.7711105e-02, 3.4583244e-02, 1.2276562e-02,
       8.6112216e-02, 2.3790612e+00, 8.5167764e-03, 3.2061207e+00,
       3.5778575e-03, 4.6393378e-03, 2.1754567e-02, 1.2415503e-02],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196fff0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  59%|█████▉    | 50/85 [02:21<01:37,  2.79s/it, loss=0.379]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00765781 1.681494   0.04326309 0.01426264 0.00605285 0.00485646
 0.33414468 0.01227303 0.00811009 0.03491323 0.02133981 0.00878044
 0.23957223 0.07891597 0.00532494 0.03414313], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00765781, 1.681494  , 0.04326309, 0.01426264, 0.00605285,
       0.00485646, 0.33414468, 0.01227303, 0.00811009, 0.03491323,
       0.02133981, 0.00878044, 0.23957223, 0.07891597, 0.00532494,
       0.03414313], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16568b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1f2f240; to 'JaxprTracer' at 0x7c31a1f2c400>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  60%|██████    | 51/85 [02:24<01:32,  2.72s/it, loss=0.158]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02590999 0.06255569 0.00473521 0.02806662 0.35249543 0.01214255
 0.00316536 0.02527335 0.0403097  0.03987325 0.01049948 0.11493045
 1.02504    1.1912727  0.00737894 0.0049395 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02590999, 0.06255569, 0.00473521, 0.02806662, 0.35249543,
       0.01214255, 0.00316536, 0.02527335, 0.0403097 , 0.03987325,
       0.01049948, 0.11493045, 1.02504   , 1.1912727 , 0.00737894,
       0.0049395 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196de30>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1d231a0; to 'JaxprTracer' at 0x7c31a1d23920>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  61%|██████    | 52/85 [02:27<01:34,  2.87s/it, loss=0.184]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.14114007 0.03682652 0.09659737 0.0083254  0.1970633  0.97610897
 0.02123747 0.03440715 0.00695108 0.12449288 0.00516046 0.0120488
 0.09694085 0.23486298 0.04213933 0.22253035], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.14114007, 0.03682652, 0.09659737, 0.0083254 , 0.1970633 ,
       0.97610897, 0.02123747, 0.03440715, 0.00695108, 0.12449288,
       0.00516046, 0.0120488 , 0.09694085, 0.23486298, 0.04213933,
       0.22253035], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a660a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a162f9c0; to 'JaxprTracer' at 0x7c31a1d21620>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  62%|██████▏   | 53/85 [02:30<01:33,  2.93s/it, loss=0.141]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.3362260e-02 8.0685496e-02 1.8339020e-03 3.4509976e+00 1.1438267e-02
 1.7323608e-02 3.6664039e-02 2.7414646e-02 8.3370402e-02 2.7610059e-03
 1.6962333e-02 3.9498832e-02 2.4614431e-02 1.2896580e-01 5.4672244e-03
 2.8142798e-01], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.3362260e-02, 8.0685496e-02, 1.8339020e-03, 3.4509976e+00,
       1.1438267e-02, 1.7323608e-02, 3.6664039e-02, 2.7414646e-02,
       8.3370402e-02, 2.7610059e-03, 1.6962333e-02, 3.9498832e-02,
       2.4614431e-02, 1.2896580e-01, 5.4672244e-03, 2.8142798e-01],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e8a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  64%|██████▎   | 54/85 [02:33<01:27,  2.81s/it, loss=0.264]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02202234 0.27544713 0.04360107 0.01797739 0.00943086 0.76371104
 0.00540237 0.00250395 0.02871365 0.03677757 0.01459588 0.0672523
 0.00544991 0.02788799 0.00659977 0.03556402], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02202234, 0.27544713, 0.04360107, 0.01797739, 0.00943086,
       0.76371104, 0.00540237, 0.00250395, 0.02871365, 0.03677757,
       0.01459588, 0.0672523 , 0.00544991, 0.02788799, 0.00659977,
       0.03556402], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16553b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a162c5e0; to 'JaxprTracer' at 0x7c31a162c310>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  65%|██████▍   | 55/85 [02:35<01:21,  2.71s/it, loss=0.0852]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.16371119 0.00375602 0.00642745 0.00468277 0.8368875  0.902555
 0.03560129 0.16196881 0.17145878 1.2359002  0.00398342 0.01888322
 0.01383677 0.0135976  0.0066713  0.18566519], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.16371119, 0.00375602, 0.00642745, 0.00468277, 0.8368875 ,
       0.902555  , 0.03560129, 0.16196881, 0.17145878, 1.2359002 ,
       0.00398342, 0.01888322, 0.01383677, 0.0135976 , 0.0066713 ,
       0.18566519], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86ef0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b309a0; to 'JaxprTracer' at 0x7c31a1b30900>], out_avals=[ShapedArray(float32[16])], primitive=pjit,

Training:  66%|██████▌   | 56/85 [02:38<01:16,  2.65s/it, loss=0.235]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00810252 0.01926555 0.0055105  0.02579801 0.01057214 0.27276024
 0.0074684  0.02127563 0.00801715 0.00914067 0.01346024 0.11956439
 0.00362501 0.07759137 0.00608295 0.23853984], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00810252, 0.01926555, 0.0055105 , 0.02579801, 0.01057214,
       0.27276024, 0.0074684 , 0.02127563, 0.00801715, 0.00914067,
       0.01346024, 0.11956439, 0.00362501, 0.07759137, 0.00608295,
       0.23853984], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655a60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0d42340; to 'JaxprTracer' at 0x7c31a0d41fd0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  67%|██████▋   | 57/85 [02:41<01:22,  2.95s/it, loss=0.0529]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00308657 0.15438451 0.00147651 0.05507705 0.00269193 0.00422466
 0.1384101  0.00355469 0.00926079 0.03710925 0.0087678  0.47055224
 0.02595739 0.03335895 0.00402118 0.01077949], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00308657, 0.15438451, 0.00147651, 0.05507705, 0.00269193,
       0.00422466, 0.1384101 , 0.00355469, 0.00926079, 0.03710925,
       0.0087678 , 0.47055224, 0.02595739, 0.03335895, 0.00402118,
       0.01077949], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b865a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a16ad670; to 'JaxprTracer' at 0x7c31a16ada80>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  68%|██████▊   | 58/85 [02:44<01:17,  2.88s/it, loss=0.0602]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00837954 0.04070236 0.01080213 0.25020406 0.00915036 0.08026038
 0.02316403 0.39333838 0.01145041 0.22532794 0.01141293 0.00520873
 0.07270011 0.00558803 0.00965237 0.12099493], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00837954, 0.04070236, 0.01080213, 0.25020406, 0.00915036,
       0.08026038, 0.02316403, 0.39333838, 0.01145041, 0.22532794,
       0.01141293, 0.00520873, 0.07270011, 0.00558803, 0.00965237,
       0.12099493], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655500>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0fa3ab0; to 'JaxprTracer' at 0x7c31a0fa3c40>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  69%|██████▉   | 59/85 [02:47<01:12,  2.79s/it, loss=0.0799]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.15837988 0.03098017 0.01727288 2.4743936  0.03209471 0.00463234
 0.00444187 0.00449623 0.08540202 0.00694859 0.1144599  0.05795796
 0.02308634 0.02065529 0.01172661 0.02737718], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.15837988, 0.03098017, 0.01727288, 2.4743936 , 0.03209471,
       0.00463234, 0.00444187, 0.00449623, 0.08540202, 0.00694859,
       0.1144599 , 0.05795796, 0.02308634, 0.02065529, 0.01172661,
       0.02737718], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84f60>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e9dd00; to 'JaxprTracer' at 0x7c31a0e9f1f0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  71%|███████   | 60/85 [02:49<01:07,  2.71s/it, loss=0.192]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00749407 0.19520935 0.01155659 0.02062773 0.01852005 0.11578096
 0.5637602  0.0299824  0.0045787  0.76537156 0.14760233 0.01532095
 0.03057062 0.03353661 0.06985853 0.02316333], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00749407, 0.19520935, 0.01155659, 0.02062773, 0.01852005,
       0.11578096, 0.5637602 , 0.0299824 , 0.0045787 , 0.76537156,
       0.14760233, 0.01532095, 0.03057062, 0.03353661, 0.06985853,
       0.02316333], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657bf0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a192b8d0; to 'JaxprTracer' at 0x7c31a1ba3e70>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  72%|███████▏  | 61/85 [02:52<01:04,  2.67s/it, loss=0.128]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02558272 0.01065565 0.01574243 0.00681445 0.05628227 0.00412316
 0.00435108 0.00538257 0.01528408 0.00263998 0.12207583 0.3662197
 0.01654676 0.01757009 0.32560554 0.0077282 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02558272, 0.01065565, 0.01574243, 0.00681445, 0.05628227,
       0.00412316, 0.00435108, 0.00538257, 0.01528408, 0.00263998,
       0.12207583, 0.3662197 , 0.01654676, 0.01757009, 0.32560554,
       0.0077282 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84e10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c1b6f0; to 'JaxprTracer' at 0x7c31a14bb6a0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  73%|███████▎  | 62/85 [02:55<01:07,  2.95s/it, loss=0.0627]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00379948 0.01516938 0.0234291  0.00233675 0.01134233 0.1512102
 0.02268111 0.05050288 0.00575244 0.00626185 0.01549632 0.00395849
 0.11438303 0.00671073 0.00858202 0.00753217], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00379948, 0.01516938, 0.0234291 , 0.00233675, 0.01134233,
       0.1512102 , 0.02268111, 0.05050288, 0.00575244, 0.00626185,
       0.01549632, 0.00395849, 0.11438303, 0.00671073, 0.00858202,
       0.00753217], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16560b0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1143790; to 'JaxprTracer' at 0x7c31a1140d10>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  74%|███████▍  | 63/85 [02:58<01:02,  2.82s/it, loss=0.0281]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.20798203 0.01616549 0.04574372 0.00217154 0.01030824 0.02329017
 1.6610608  0.00744793 0.00883562 0.00360838 0.01331086 0.06213186
 0.04058252 0.02696092 0.00364354 0.01209438], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.20798203, 0.01616549, 0.04574372, 0.00217154, 0.01030824,
       0.02329017, 1.6610608 , 0.00744793, 0.00883562, 0.00360838,
       0.01331086, 0.06213186, 0.04058252, 0.02696092, 0.00364354,
       0.01209438], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b85a10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1642e30; to 'JaxprTracer' at 0x7c31a1640cc0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  75%|███████▌  | 64/85 [03:00<00:57,  2.72s/it, loss=0.134]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05707481 0.08844844 0.00223006 0.01016712 1.6101651  1.3353064
 0.00879592 0.18994068 0.0441243  0.486046   0.05016394 0.01733392
 0.00310142 0.05286297 0.95474267 1.3753492 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05707481, 0.08844844, 0.00223006, 0.01016712, 1.6101651 ,
       1.3353064 , 0.00879592, 0.18994068, 0.0441243 , 0.486046  ,
       0.05016394, 0.01733392, 0.00310142, 0.05286297, 0.95474267,
       1.3753492 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655a40>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f3060; to 'JaxprTracer' at 0x7c31a12f2f70>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  76%|███████▋  | 65/85 [03:03<00:53,  2.65s/it, loss=0.393]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01711035 0.06351669 0.33032054 0.0065421  0.00289236 0.01011661
 0.00294062 0.0104983  0.06593782 0.0192838  0.0095846  0.01197165
 0.1200427  0.02532205 0.02158868 2.458691  ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01711035, 0.06351669, 0.33032054, 0.0065421 , 0.00289236,
       0.01011661, 0.00294062, 0.0104983 , 0.06593782, 0.0192838 ,
       0.0095846 , 0.01197165, 0.1200427 , 0.02532205, 0.02158868,
       2.458691  ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e770>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f39c0; to 'JaxprTracer' at 0x7c31a12f1a80>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  78%|███████▊  | 66/85 [03:06<00:51,  2.68s/it, loss=0.199]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02451626 0.02614751 0.0573127  0.02141519 0.00327087 0.00632061
 0.00602276 0.00425671 0.01653902 0.00594882 0.51551116 0.00347641
 0.04255149 0.0958387  0.01154209 0.01061484], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02451626, 0.02614751, 0.0573127 , 0.02141519, 0.00327087,
       0.00632061, 0.00602276, 0.00425671, 0.01653902, 0.00594882,
       0.51551116, 0.00347641, 0.04255149, 0.0958387 , 0.01154209,
       0.01061484], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64780>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1a23290; to 'JaxprTracer' at 0x7c31a1a22520>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  79%|███████▉  | 67/85 [03:09<00:51,  2.89s/it, loss=0.0532]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.2179527e-02 3.6465807e+00 1.6414589e-01 3.1331338e-02 6.9015929e-03
 6.0074367e-02 1.2486331e-01 1.3578670e-02 2.9044569e-01 3.3744089e-02
 7.5857677e-03 1.4306123e-02 3.9986190e-03 1.4052845e-02 3.2968935e-02
 2.2743093e-03], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2179527e-02, 3.6465807e+00, 1.6414589e-01, 3.1331338e-02,
       6.9015929e-03, 6.0074367e-02, 1.2486331e-01, 1.3578670e-02,
       2.9044569e-01, 3.3744089e-02, 7.5857677e-03, 1.4306123e-02,
       3.9986190e-03, 1.4052845e-02, 3.2968935e-02, 2.2743093e-03],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a64ab0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  80%|████████  | 68/85 [03:11<00:47,  2.77s/it, loss=0.279]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00445338 0.04958827 0.04078384 1.0808727  0.02437853 0.41345122
 0.01861214 0.03672851 0.8831005  0.01272252 2.278065   0.5509956
 2.847479   0.12789221 0.0159094  0.01587995], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00445338, 0.04958827, 0.04078384, 1.0808727 , 0.02437853,
       0.41345122, 0.01861214, 0.03672851, 0.8831005 , 0.01272252,
       2.278065  , 0.5509956 , 2.847479  , 0.12789221, 0.0159094 ,
       0.01587995], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e6a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a165e980; to 'JaxprTracer' at 0x7c31a165ef20>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  81%|████████  | 69/85 [03:14<00:42,  2.68s/it, loss=0.525]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00258326 0.5680226  0.00316512 0.22402169 0.01109691 1.5583938
 0.01701859 0.00488647 0.00540628 0.05906061 0.00486061 0.01472275
 0.18028641 0.01162788 0.038355   0.01657923], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00258326, 0.5680226 , 0.00316512, 0.22402169, 0.01109691,
       1.5583938 , 0.01701859, 0.00488647, 0.00540628, 0.05906061,
       0.00486061, 0.01472275, 0.18028641, 0.01162788, 0.038355  ,
       0.01657923], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16579a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0e3e390; to 'JaxprTracer' at 0x7c31a0e3cef0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  82%|████████▏ | 70/85 [03:16<00:39,  2.63s/it, loss=0.17]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([3.21906689e-03 9.09377262e-03 3.01707219e-02 8.24235827e-02
 6.60770712e-03 2.26095289e-01 3.83589530e+00 3.91645217e-03
 1.18900165e-02 2.30297307e-03 8.59029312e-03 1.54346704e+00
 2.59770989e-01 3.21666449e-01 4.69548814e-02 1.18955523e-02], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([3.21906689e-03, 9.09377262e-03, 3.01707219e-02, 8.24235827e-02,
       6.60770712e-03, 2.26095289e-01, 3.83589530e+00, 3.91645217e-03,
       1.18900165e-02, 2.30297307e-03, 8.59029312e-03, 1.54346704e+00,
       2.59770989e-01, 3.21666449e-01, 4.69548814e-02, 1.18955523e-02],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196e310>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), 

Training:  84%|████████▎ | 71/85 [03:19<00:38,  2.75s/it, loss=0.4]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00163529 0.0090881  0.01642762 0.00209874 0.0288318  0.00181236
 0.00432508 0.02115367 0.03435693 0.03273026 0.00718143 0.50840247
 0.00663684 0.01782927 0.15087011 1.5770552 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00163529, 0.0090881 , 0.01642762, 0.00209874, 0.0288318 ,
       0.00181236, 0.00432508, 0.02115367, 0.03435693, 0.03273026,
       0.00718143, 0.50840247, 0.00663684, 0.01782927, 0.15087011,
       1.5770552 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b84990>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a12f08b0; to 'JaxprTracer' at 0x7c31a12f1030>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  85%|████████▍ | 72/85 [03:23<00:37,  2.86s/it, loss=0.151]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([8.8688275e-03 5.6626230e-01 2.0514105e-02 2.5948325e-02 1.4875586e+00
 1.1936446e+00 1.2061283e-02 8.8175986e-04 1.0516603e-01 3.6204034e-01
 3.2400709e-02 1.1216070e-03 3.1657662e-02 4.8453058e-03 1.6908375e+00
 4.9843375e-02], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([8.8688275e-03, 5.6626230e-01, 2.0514105e-02, 2.5948325e-02,
       1.4875586e+00, 1.1936446e+00, 1.2061283e-02, 8.8175986e-04,
       1.0516603e-01, 3.6204034e-01, 3.2400709e-02, 1.1216070e-03,
       3.1657662e-02, 4.8453058e-03, 1.6908375e+00, 4.9843375e-02],      dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a196db10>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7

Training:  86%|████████▌ | 73/85 [03:25<00:32,  2.74s/it, loss=0.35]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.63642585 0.03406847 0.25821668 0.0210024  2.6586888  0.01675568
 0.03348127 0.01132394 0.01346976 0.01428109 0.01315169 0.01320039
 0.08855875 0.00396573 0.00884898 0.09227519], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.63642585, 0.03406847, 0.25821668, 0.0210024 , 2.6586888 ,
       0.01675568, 0.03348127, 0.01132394, 0.01346976, 0.01428109,
       0.01315169, 0.01320039, 0.08855875, 0.00396573, 0.00884898,
       0.09227519], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a665a0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1c55ad0; to 'JaxprTracer' at 0x7c31a1c57380>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  87%|████████▋ | 74/85 [03:28<00:29,  2.66s/it, loss=0.245]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.02943759 0.00556207 0.08238538 0.46259725 0.01168514 0.00556337
 0.00931855 0.01551064 0.015885   0.03793721 0.04573461 0.02850153
 0.02863036 0.02655111 0.00544209 0.03895872], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.02943759, 0.00556207, 0.08238538, 0.46259725, 0.01168514,
       0.00556337, 0.00931855, 0.01551064, 0.015885  , 0.03793721,
       0.04573461, 0.02850153, 0.02863036, 0.02655111, 0.00544209,
       0.03895872], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1a65270>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1606d40; to 'JaxprTracer' at 0x7c31a1607b50>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  88%|████████▊ | 75/85 [03:30<00:26,  2.62s/it, loss=0.0531]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01129613 2.086086   0.005492   0.00602134 0.37632725 0.01075224
 0.01143768 0.05712649 0.05427525 0.01658393 0.00918721 0.00654577
 1.1815683  0.02999732 3.287638   0.01489247], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01129613, 2.086086  , 0.005492  , 0.00602134, 0.37632725,
       0.01075224, 0.01143768, 0.05712649, 0.05427525, 0.01658393,
       0.00918721, 0.00654577, 1.1815683 , 0.02999732, 3.287638  ,
       0.01489247], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1655bc0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1222f70; to 'JaxprTracer' at 0x7c31a1222de0>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  89%|████████▉ | 76/85 [03:33<00:25,  2.78s/it, loss=0.448]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([1.1659049  0.02659975 0.6986352  0.02243693 0.0619842  0.03295359
 0.01347882 0.01996692 0.26761186 0.10303808 0.01059455 0.43674725
 0.02767164 0.06016159 0.68908054 0.09228019], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.1659049 , 0.02659975, 0.6986352 , 0.02243693, 0.0619842 ,
       0.03295359, 0.01347882, 0.01996692, 0.26761186, 0.10303808,
       0.01059455, 0.43674725, 0.02767164, 0.06016159, 0.68908054,
       0.09228019], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1657750>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13449f0; to 'JaxprTracer' at 0x7c31a1344180>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  91%|█████████ | 77/85 [03:36<00:21,  2.69s/it, loss=0.233]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01794437 0.01318851 0.05430529 0.02797795 0.01204644 0.03346859
 0.01792376 0.00348009 0.02999501 0.24008122 0.05335708 1.9458396
 0.02645789 0.01021715 0.21776469 0.02246503], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01794437, 0.01318851, 0.05430529, 0.02797795, 0.01204644,
       0.03346859, 0.01792376, 0.00348009, 0.02999501, 0.24008122,
       0.05335708, 1.9458396 , 0.02645789, 0.01021715, 0.21776469,
       0.02246503], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a16541f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13f9cb0; to 'JaxprTracer' at 0x7c31a13fb380>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  92%|█████████▏| 78/85 [03:39<00:19,  2.75s/it, loss=0.17]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.0246713  0.01358949 0.05351665 0.00923965 0.01375659 0.02635235
 0.25588176 0.08722491 0.01602414 0.1424332  0.00828094 0.20489126
 0.02082261 0.06147124 0.26702031 0.09869932], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.0246713 , 0.01358949, 0.05351665, 0.00923965, 0.01375659,
       0.02635235, 0.25588176, 0.08722491, 0.01602414, 0.1424332 ,
       0.00828094, 0.20489126, 0.02082261, 0.06147124, 0.26702031,
       0.09869932], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119d3f0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a0d2c8b0; to 'JaxprTracer' at 0x7c31a0d2fc90>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  93%|█████████▎| 79/85 [03:41<00:16,  2.68s/it, loss=0.0815]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.10830665 0.48232174 0.04564282 0.0144093  0.01160902 0.07307681
 0.03794891 0.01421152 0.07594    0.00767816 0.6126569  0.00342652
 0.04933436 0.00748496 0.00759713 0.04274683], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.10830665, 0.48232174, 0.04564282, 0.0144093 , 0.01160902,
       0.07307681, 0.03794891, 0.01421152, 0.07594   , 0.00767816,
       0.6126569 , 0.00342652, 0.04933436, 0.00748496, 0.00759713,
       0.04274683], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119dff0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1265f80; to 'JaxprTracer' at 0x7c31a1266480>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  94%|█████████▍| 80/85 [03:44<00:13,  2.70s/it, loss=0.0996]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.01271828 0.06447766 0.00950608 0.03267753 0.00687685 0.03506587
 0.59245825 1.0671979  0.00757406 0.04366919 0.01601018 0.07168797
 0.1721571  0.00862587 0.01667584 0.11413452], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.01271828, 0.06447766, 0.00950608, 0.03267753, 0.00687685,
       0.03506587, 0.59245825, 1.0671979 , 0.00757406, 0.04366919,
       0.01601018, 0.07168797, 0.1721571 , 0.00862587, 0.01667584,
       0.11413452], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119c800>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a17f08b0; to 'JaxprTracer' at 0x7c31a17f0130>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training:  95%|█████████▌| 81/85 [03:47<00:11,  2.81s/it, loss=0.142]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00435927 0.01548576 0.09554447 0.14435749 0.01460516 0.6187524
 0.02186665 0.00755891 0.01344624 0.0816377  0.5896456  0.03219341
 0.02503085 0.30215317 0.06882916 0.7763741 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00435927, 0.01548576, 0.09554447, 0.14435749, 0.01460516,
       0.6187524 , 0.02186665, 0.00755891, 0.01344624, 0.0816377 ,
       0.5896456 , 0.03219341, 0.02503085, 0.30215317, 0.06882916,
       0.7763741 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e580>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1b78860; to 'JaxprTracer' at 0x7c31a1b787c0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  96%|█████████▋| 82/85 [03:49<00:08,  2.71s/it, loss=0.176]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.00237124 0.00809815 0.0257764  0.00952391 0.01013313 0.0162442
 1.3592317  0.22116922 0.08951537 0.02286475 0.0067755  0.01078633
 0.07256032 0.0134106  0.00579013 0.09334099], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.00237124, 0.00809815, 0.0257764 , 0.00952391, 0.01013313,
       0.0162442 , 1.3592317 , 0.22116922, 0.08951537, 0.02286475,
       0.0067755 , 0.01078633, 0.07256032, 0.0134106 , 0.00579013,
       0.09334099], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a119e600>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1cdabb0; to 'JaxprTracer' at 0x7c31a1cdbec0>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  98%|█████████▊| 83/85 [03:52<00:05,  2.77s/it, loss=0.123]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.05903723 0.0115864  0.04984961 0.26655737 0.00420401 0.0136945
 0.00524016 0.02963238 0.01452879 0.0869207  0.08367481 0.04726694
 0.06515171 0.01812163 0.0299308  0.00613745], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.05903723, 0.0115864 , 0.04984961, 0.26655737, 0.00420401,
       0.0136945 , 0.00524016, 0.02963238, 0.01452879, 0.0869207 ,
       0.08367481, 0.04726694, 0.06515171, 0.01812163, 0.0299308 ,
       0.00613745], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1caffa0>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a13fcdb0; to 'JaxprTracer' at 0x7c31a1f3a020>], out_avals=[ShapedArray(float32[16])], primitive=pjit

Training:  99%|█████████▉| 84/85 [03:55<00:02,  2.69s/it, loss=0.0495]

<bound method _forward_method_to_aval.<locals>.meth of Traced<ConcreteArray([0.11868417 0.00444175 0.01744053 0.04708511 0.06250921 0.01929958
 0.8214077  0.0372411  0.10214975 0.02849609 0.01699023 0.09417078
 1.2994682  0.01295389 0.00976028 0.01452715], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([0.11868417, 0.00444175, 0.01744053, 0.04708511, 0.06250921,
       0.01929958, 0.8214077 , 0.0372411 , 0.10214975, 0.02849609,
       0.01699023, 0.09417078, 1.2994682 , 0.01295389, 0.00976028,
       0.01452715], dtype=float32)
  tangent = Traced<ShapedArray(float32[16])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[16]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7c31a1b86f20>, in_tracers=(Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>, Traced<ShapedArray(float32[16]):JaxprTrace(level=1/0)>), out_tracer_refs=[<weakref at 0x7c31a1bd8310; to 'JaxprTracer' at 0x7c31a1bdbe20>], out_avals=[ShapedArray(float32[16])], primitive=pji

Training: 100%|██████████| 85/85 [03:58<00:00,  2.80s/it, loss=0.169]


Evaluating after epoch 3...


Evaluating: 100%|██████████| 21/21 [00:12<00:00,  1.73it/s]

Accuracy: 0.8631
Precision: 0.8840
Recall: 0.8649
Confusion Matrix:
[[130  21]
 [ 25 160]]





# Final Evaluation

In [14]:
eval_results = after_evaluate_model(model, state.params, test_dataset)

# Print the evaluation results
print(f"Accuracy: {eval_results['accuracy']:.4f}")
print(f"Precision: {eval_results['precision']:.4f}")
print(f"Recall: {eval_results['recall']:.4f}")
print(f"Confusion Matrix:\n{eval_results['confusion_matrix']}")

Evaluating: 100%|██████████| 21/21 [00:12<00:00,  1.72it/s]

Accuracy: 0.8631
Precision: 0.8840
Recall: 0.8649
Confusion Matrix:
[[130  21]
 [ 25 160]]



