In [1]:
from jax import numpy as jnp
from optax.losses import softmax_cross_entropy
from jax.nn import softmax, log_softmax

In [2]:
def safe_cross_entropy_loss(predictions, targets, epsilon=1e-12, axis=-1):
    """
    Computes the per-example cross-entropy between predicted probabilities and targets.

    This works in probability space, so predictions are clipped to
    ``[epsilon, 1 - epsilon]`` before the log. Without clipping, a predicted
    probability of exactly 0 would give ``log(0) = -inf``, and ``0 * -inf`` is
    ``nan`` for classes whose target is 0.

    :param predictions: Array of predicted probabilities (e.g. the output of
        ``softmax``), shape (..., num_classes)
    :param targets: Array of target probabilities (typically one-hot), same shape
        as ``predictions``
    :param epsilon: Small value used for clipping. For float32 inputs,
        ``1.0 - 1e-12`` rounds to ``1.0``, so with the default epsilon only the
        lower clip has any effect.
    :param axis: Axis holding the class dimension (default: last axis)
    :return: Array of per-example losses with ``axis`` reduced away, i.e. shape
        (...,). No mean over the batch is taken; reduce the result yourself if
        a scalar is needed.
    """
    # Keep probabilities strictly inside (0, 1) so log() stays finite.
    clipped = jnp.clip(predictions, epsilon, 1.0 - epsilon)

    # Cross-entropy summed over the class axis only; batch axes are preserved.
    return -jnp.sum(targets * jnp.log(clipped), axis=axis)

In [10]:
def cross_entropy_loss(*, logits, labels, axis=-1):
    """
    Computes the per-example cross-entropy loss from unnormalized logits.

    The logits are turned into log-probabilities with ``log_softmax`` and combined
    with ``labels`` as ``-sum(labels * log_probs)`` over ``axis``. Because the
    log-probabilities are computed directly, no clipping is needed even when
    ``softmax`` itself would return an exact 0 for some class.

    :param logits: Unnormalized network outputs (pre-softmax), shape (..., num_classes)
    :param labels: Target probabilities (typically one-hot), same shape as ``logits``
    :param axis: Axis holding the class dimension (default: last axis)
    :return: Array of per-example losses with ``axis`` reduced away, i.e. shape
        (...,). No mean over the batch is taken.
    """
    # Normalize logits into log-probabilities along the class axis.
    log_probs = log_softmax(logits, axis=axis)

    # Cross-entropy summed over the class axis only; batch axes are preserved.
    return -jnp.sum(labels * log_probs, axis=axis)

In [4]:
# Toy inputs of shape (2, 3, 2): two groups of three examples over two classes.
# The [100.0, -2.0] row is an extreme case where softmax saturates to exactly [1, 0].
logits = jnp.array(
    [
        [[1.0, -1.0], [2.0, -2.0], [-1.0, 1.0]],  # Batch 1
        [[-1.0, 1.0], [100.0, -2.0], [1.0, -1.0]],  # Batch 2
    ]
)
# One-hot targets matching `logits` element-for-element.
labels = jnp.array(
    [
        [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],  # Batch 1
        [[0.0, 1.0], [1.0, 0.0], [1.0, 0.0]],  # Batch 2
    ]
)


In [22]:
# Reference result from optax; reduces the class axis, leaving one loss per example.
loss_softmax = softmax_cross_entropy(labels=labels, logits=logits)
loss_softmax

Array([[ 0.12692805,  4.01815   ,  2.126928  ],
       [ 0.12692805, -0.        ,  0.12692805]], dtype=float32)

In [11]:
loss_cross = cross_entropy_loss(logits=logits, labels=labels)
# Sanity check: the hand-written log-softmax loss must match the optax reference.
assert jnp.allclose(loss_cross, loss_softmax), "cross_entropy_loss disagrees with optax"
loss_cross

Array([[ 0.12692805,  4.01815   ,  2.126928  ],
       [ 0.12692805, -0.        ,  0.12692805]], dtype=float32)

In [24]:
# Class probabilities for the probability-space loss; note the saturated
# example, where softmax returns exactly [1, 0].
y = softmax(logits, axis=-1)
y

Array([[[0.880797  , 0.11920292],
        [0.98201376, 0.01798621],
        [0.11920292, 0.880797  ]],

       [[0.11920292, 0.880797  ],
        [1.        , 0.        ],
        [0.880797  , 0.11920292]]], dtype=float32)

In [25]:
loss_cross = safe_cross_entropy_loss(predictions=y, targets=labels)
# The clipped probability-space loss should agree with the optax reference on
# these inputs (differences are float32 rounding only, well within tolerance).
assert jnp.allclose(loss_cross, loss_softmax), "safe_cross_entropy_loss disagrees with optax"
loss_cross

Array([[ 0.12692808,  4.01815   ,  2.126928  ],
       [ 0.12692808, -0.        ,  0.12692808]], dtype=float32)

In [1]:
import tasks
# list() materializes everything cue_accumulation_task yields for evaluation.
eval_batch = list(tasks.cue_accumulation_task(n_batches=64, batch_size=8))

In [2]:
batch = eval_batch[0]  # first evaluation batch, used for shape inspection below

In [5]:
batch["label"].shape  # presumably (batch_size, time_steps, num_classes) — TODO confirm against tasks

(8, 2550, 2)