In [7]:
import mlx.core as mx
import mlx.nn as nn

In [8]:
# Create input tensor and layer
# Seed MLX's global PRNG first so the random input and the layer's randomly
# initialised parameters are identical on every Restart & Run All.
mx.random.seed(0)
x = mx.random.normal((1, 10))  # Example input: a single row of 10 values
layer = nn.Linear(10, 5)  # Example layer: 10 input dims -> 5 output dims

In [9]:
# Example scalar loss: add up everything the layer outputs
def loss_fn(x, layer):
    """Return the sum of the outputs of ``layer`` applied to ``x``.

    Args:
        x: Input that is passed to ``layer``.
        layer: Callable applied to ``x``.

    Returns:
        The result of ``mx.sum`` over ``layer(x)``.
    """
    outputs = layer(x)
    return mx.sum(outputs)

In [10]:
# Compute the gradient of loss with respect to input
# argnums=0 selects loss_fn's first positional argument (x) as the variable
# to differentiate against; layer is passed through unchanged.
grad_fn = mx.grad(loss_fn, argnums=0)  # Gradient function w.r.t. the input x
dx = grad_fn(x, layer)  # d(loss)/dx evaluated at the current x and layer
print(dx)

array([[0.535994, 0.335815, -0.0523646, ..., -0.10516, 0.389606, -0.120561]], dtype=float32)


In [11]:
# Compute the gradient of loss with respect to layer parameters
# argnums=1 selects loss_fn's second positional argument (layer) as the
# variable to differentiate against. The printed result is a dict keyed by
# parameter name ('weight', 'bias').
grad_fn_params = mx.grad(loss_fn, argnums=1)  # Gradient function w.r.t. layer
dlayer = grad_fn_params(x, layer)  # d(loss)/d(parameter) for each layer parameter
print(dlayer)

{'weight': array([[-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742]], dtype=float32), 'bias': array([1, 1, 1, 1, 1], dtype=float32)}


In [12]:
# Compute the gradient of loss with respect to layer parameters (Alt.)
# nn.value_and_grad(layer, loss_fn) wraps loss_fn so that calling it returns a
# (loss value, gradients) pair, with the gradients taken with respect to the
# layer's trainable parameters.
grad_fn_params = nn.value_and_grad(layer, loss_fn)  # For layer parameter gradients
# Bind the loss to a real name rather than `_`: in IPython/Jupyter `_` is the
# "last output" variable, and assigning to it stops the shell from updating it.
loss_value, dlayer = grad_fn_params(x, layer)
print(dlayer)

{'weight': array([[-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742],
       [-0.757283, 0.135378, -0.308067, ..., -0.00997045, 0.140571, 0.39742]], dtype=float32), 'bias': array([1, 1, 1, 1, 1], dtype=float32)}
