# **Indexing Arrays in MLX**

MLX follows NumPy-like indexing behavior, allowing **integer indexing, slicing, ellipsis (`...`), new axis insertion, and array-based indexing**. However, there are key differences, such as **lack of bounds checking** and **no support for boolean masks yet**.

---

## **1. Basic Indexing in MLX**

MLX indexing works similarly to NumPy's `numpy.ndarray`, allowing integer-based indexing and slicing.

### **Integer and Slice Indexing**
- Index a **single element** using an integer.
- Use **negative indexing** to access elements from the end.
- Use **slicing** with `[start:stop:step]` format.


```python
import mlx.core as mx

# Create a 1D array with values from 0 to 9
arr = mx.arange(10)
print(f"Original Array: {arr}")

# Index a single element
print(f"arr[3]: {arr[3]}")  # Access the 4th element (0-based index)

# Negative indexing
print(f"arr[-2]: {arr[-2]}")  # Second-to-last element

# Slicing with a step
print(f"arr[2:8:2]: {arr[2:8:2]}")  # Every second element from index 2 up to (but not including) 8
```

```
Original Array: array([0, 1, 2, ..., 7, 8, 9], dtype=int32)
arr[3]: array(3, dtype=int32)
arr[-2]: array(8, dtype=int32)
arr[2:8:2]: array([2, 4, 6], dtype=int32)
```


## **2. Indexing Multi-Dimensional Arrays**

MLX supports **multi-dimensional indexing**, including the **ellipsis (`...`)** to simplify selection.

### **Multi-Dimensional Indexing**
- You can index multi-dimensional arrays in MLX just like NumPy.
- The **ellipsis (`...`)** stands in for as many full slices (`:`) as needed to cover the dimensions you do not index explicitly.


```python
# Create a 3D array of shape (2, 2, 2)
arr = mx.arange(8).reshape(2, 2, 2)
print("Original Array:")
print(arr)

# Take index 0 along the last axis, keeping every entry of the first two axes
print("arr[:, :, 0]:")
print(arr[:, :, 0])

# Use an ellipsis to achieve the same result
print("arr[..., 0]:")
print(arr[..., 0])
```

```
Original Array:
array([[[0, 1],
        [2, 3]],
       [[4, 5],
        [6, 7]]], dtype=int32)
arr[:, :, 0]:
array([[0, 2],
       [4, 6]], dtype=int32)
arr[..., 0]:
array([[0, 2],
       [4, 6]], dtype=int32)
```


## **3. Adding a New Axis**

You can add a new axis to an array by indexing with `None`. This increases the number of dimensions in an array.

### **Expanding Dimensions**
- Indexing a one-dimensional array with `None` (as in `arr[None]`) produces a two-dimensional array of shape `(1, n)`.


```python
arr = mx.arange(8)
print(f"Original shape: {arr.shape}")  # (8,)

# Add a new leading axis
arr_new_axis = arr[None]
print(f"New shape: {arr_new_axis.shape}")  # (1, 8)
```

```
Original shape: (8,)
New shape: (1, 8)
```


This method is useful for adding a batch dimension or converting a vector into a row matrix.

## **4. Indexing with Another Array**

MLX allows advanced indexing, where one array is used to select values from another array.

### **Array-Based Indexing**
Instead of selecting elements using individual integers, an array of indices can be used to retrieve multiple values at once.

#### **Example: Selecting Elements with an Index Array**
1. Create an array from 0 to 9.
2. Use an index array to select specific elements.



```python
arr = mx.arange(10)  # Array: [0, 1, 2, ..., 9]
idx = mx.array([5, 7])  # Index positions
print(f"Selected values: {arr[idx]}")  # Select the elements at indices 5 and 7
```

```
Selected values: array([5, 7], dtype=int32)
```



This is useful when the positions to select are only known at runtime, for example when they come from another array operation.

---

## **5. Useful Functions for Indexing**

MLX provides additional functions for working with array indices.

### **Using the `take()` Function**
The `take()` function allows selecting elements at specific indices.

#### **Example: Selecting Elements Using `take()`**
1. Create an array of five elements.
2. Use `take()` to select values at specific indices.



```python
arr = mx.array([10, 20, 30, 40, 50])
idx = mx.array([2, 4])
selected = mx.take(arr, idx)
print(f"Selected elements: {selected}")
```

```
Selected elements: array([30, 50], dtype=int32)
```



### **Using the `take_along_axis()` Function**
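The `take_along_axis()` function selects elements along a given axis using an index array with the same number of dimensions as the input, mirroring NumPy's `numpy.take_along_axis`. It pairs naturally with functions such as `argmax()` and `argsort()` that return indices along an axis.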

---

## **6. Differences from NumPy**

MLX indexing is similar to NumPy but has key differences.

### **Key Differences**
1. **No Bounds Checking**
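   - In NumPy, indexing out of bounds raises an `IndexError`; in MLX, out-of-bounds indexing is undefined behavior.
   - MLX performs no automatic check on indices. The MLX documentation explains that exceptions cannot propagate from the GPU, and checking array indices before launching a kernel would be very inefficient.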

2. **No Boolean Mask Indexing Yet**
   - MLX does not currently support selecting elements based on boolean conditions.
   - Example of unsupported functionality:

     ```python
     arr[arr > 5]  # not supported in MLX
     ```

3. **Limited Support for Operations That Change Shape**
   - MLX does not yet support functions like `numpy.nonzero()` or the single-input version of `numpy.where()`, because their output shape depends on the values in the input.

---

## **7. In-Place Updates**

MLX supports in-place updates through indexed assignment (`a[i] = v`). An update is visible through every variable that refers to the same array object.

### **Example: Modifying an Array in Place**
1. Create an array with three elements.
2. Modify one of the elements directly.



```python
a = mx.array([1, 2, 3])
a[2] = 0
print(f"Modified Array: {a}")  # Updates the last element to 0
```

```
Modified Array: array([1, 2, 0], dtype=int32)
```



### **Effect on References**
Assigning an array to another variable does not copy it: both names refer to the same array, so an in-place update through one is visible through the other.



```python
a = mx.array([1, 2, 3])
b = a  # b now refers to the same data as a
b[2] = 0  # Modify b
print(f"b: {b}")  # Shows updated values
print(f"a: {a}")  # Also shows updated values (same reference)
```

```
b: array([1, 2, 0], dtype=int32)
a: array([1, 2, 0], dtype=int32)
```



Since `b = a` binds both names to the same array object, an update made through either name is visible through both.

---

## **8. Gradients with In-Place Updates**

MLX supports automatic differentiation, even when using in-place updates.

### **Computing Gradients with In-Place Updates**
1. Define a function that modifies an array element.
2. Compute the gradient of the function.



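```python
def fun(x, idx):
    x[idx] = 2.0  # Overwrite x[idx] with a constant
    return x.sum()

# Gradient of fun with respect to its first argument, x
dfdx = mx.grad(fun)(mx.array([1.0, 2.0, 3.0]), mx.array([1]))
print(f"Gradient: {dfdx}")
```

```
Gradient: array([1, 0, 1], dtype=float32)
```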



### **Explanation**
- The gradient is `1` everywhere except at the overwritten position (`idx = 1`), where it is `0`.
- This is correct: after `x[idx] = 2.0`, the sum no longer depends on the original value of `x[idx]`, so changing that input has no effect on the output.

---

## **9. Summary**

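| Feature | MLX Support | Notes |
|---------|-------------|-------|
| Integer Indexing | ✅ Yes | Works like NumPy; returns a 0-d array |
| Negative Indexing | ✅ Yes | Supports negative indices |
| Slicing | ✅ Yes | `[start:stop:step]` format |
| Multi-Dimensional Indexing | ✅ Yes | Supports `...` syntax |
| New Axis (`None`) | ✅ Yes | Inserts a dimension of size 1 |
| Array-Based Indexing | ✅ Yes | Also available via `take()` and `take_along_axis()` |
| Boolean Mask Indexing | ❌ No | Not yet supported |
| Value-Dependent Output Shapes | ❌ No | e.g. `nonzero()`, single-input `where()` |
| Bounds Checking | ❌ No | Out-of-bounds indexing is undefined behavior |
| In-Place Updates | ✅ Yes | Visible through every reference to the same array |
| Gradient Computation with In-Place Updates | ✅ Yes | Correctly computes gradients |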

---

## **Key Takeaways**
1. MLX indexing is similar to NumPy but lacks bounds checking and boolean mask support.
2. In-place updates are visible through every variable that refers to the same array object.
3. Gradients flow correctly through in-place updates; positions overwritten with constants receive a gradient of `0`.

This guide covers **indexing in MLX**, including **multi-dimensional selection, array-based indexing, in-place updates, and gradient computation**.
