# Einops-like Rearrange Function Implementation

This notebook implements a `rearrange` function similar to `einops.rearrange` for NumPy arrays. The function allows flexible reshaping and transposing of arrays using a pattern string, such as 'h w -> w h' for transposition or 'b h w c -> b (h w) c' for merging axes.
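
For orientation, both example patterns correspond to ordinary NumPy calls. The sketch below uses plain NumPy and is independent of the notebook's `rearrange`. The array sizes are arbitrary choices for illustration.

```python
import numpy as np

# 'h w -> w h' is a plain transpose
x = np.arange(6).reshape(2, 3)          # h=2, w=3
xt = x.transpose(1, 0)                  # same as x.T
print(xt.shape)                         # (3, 2)

# 'b h w c -> b (h w) c' merges h and w; no transpose is needed because
# h and w are already adjacent and in the requested order
img = np.zeros((10, 28, 28, 3))
b, h, w, c = img.shape
flat = img.reshape(b, h * w, c)
print(flat.shape)                       # (10, 784, 3)
```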

### Features
- **Pattern Parsing**: Supports named axes, parentheses for grouping, a literal `1` for size-1 axes, and ellipsis (`...`) for batch dimensions.
- **Transformations**: Handles splitting axes, merging axes, and reordering via reshape and transpose operations.
- **Error Handling**: Checks for invalid patterns, shape mismatches, and missing or inconsistent axis lengths.
- **Tests**: Shape-based tests cover reordering, composition, decomposition, size-1 axes, ellipsis, empty arrays, and invalid patterns.
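
The parsing step can be sketched as follows. Each side of the pattern is split on parentheses and whitespace using the same `re.split` call that `_parse_expression` uses. The tokens are then grouped so that each group corresponds to one dimension. The helper names `tokenize` and `group_axes` are hypothetical. The sketch also deliberately omits the validation, the `1` handling, and the ellipsis placeholder that the real parser performs.

```python
import re

def tokenize(expression: str) -> list:
    # Split on parentheses and whitespace, keeping the parentheses as tokens
    parts = re.split(r'([\(\)\s])', expression.strip())
    return [t for t in parts if t and not t.isspace()]

def group_axes(tokens: list) -> list:
    # Turn a flat token list into one group per dimension (no error checking)
    composition, current = [], None
    for tok in tokens:
        if tok == '(':
            current = []
        elif tok == ')':
            composition.append(current)
            current = None
        elif current is None:
            composition.append([tok])   # a bare name is its own group
        else:
            current.append(tok)         # a name inside parentheses joins the open group
    return composition

tokens = tokenize('b (h w) c')
print(tokens)               # ['b', '(', 'h', 'w', ')', 'c']
print(group_axes(tokens))   # [['b'], ['h', 'w'], ['c']]
```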

Run the implementation cell below, then execute the test cells to verify correctness.

# Custom einops.rearrange Implementation

## Overview

This notebook implements a simplified version of the `einops.rearrange` operation for NumPy arrays. It supports:
- **Reshaping:** Splitting and merging of axes.
- **Transposition:** Reordering axes as specified by the pattern.
- **Size-1 axes:** A literal `1` on the left drops an input axis of length 1. On the right, it inserts a new axis of length 1. For example, `h w -> 1 h w` adds a leading axis.

Repeating a singleton axis to a larger size copies data and belongs to `einops.repeat`. An example is `a 1 c -> a b c` with `b=4`. This `rearrange` rejects such patterns because `b` appears only on the right side.
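
For comparison, the sketch below shows the plain NumPy operations that these size-1 patterns correspond to. It also shows how a repeat such as `a 1 c -> a b c` could be done with `np.repeat` outside of `rearrange`. It uses only NumPy, not the notebook's function.

```python
import numpy as np

x = np.zeros((32, 32))

# 'h w -> 1 h w': insert a new leading axis of length 1 (a view, no copy)
added = x[np.newaxis, :, :]            # equivalently np.expand_dims(x, 0)
print(added.shape)                     # (1, 32, 32)

# '1 h w c -> h w c': drop a leading axis that must have length 1
t = np.zeros((1, 32, 32, 5))
dropped = np.squeeze(t, axis=0)        # raises ValueError if axis 0 is not length 1
print(dropped.shape)                   # (32, 32, 5)

# 'a 1 c -> a b c' with b=4 is a repeat, not a rearrange: data is copied 4 times
a = np.random.rand(3, 1, 5)
repeated = np.repeat(a, 4, axis=1)
print(repeated.shape)                  # (3, 4, 5)
```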

## Design Decisions

- **Pattern Parsing:**  
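  The pattern string (e.g., `"b (h w) c -> b h w c"`) is split on `->` into two sides. Each side is tokenized into groups, one group per dimension. A literal `1` becomes an empty group, which marks an axis of length 1 that is dropped (left side) or inserted (right side). Ellipsis (`...`) is also supported for handling arbitrary batch dimensions.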
  
- **Recipe Preparation:**  
  The `_prepare_rearrange_recipe` function validates the pattern against the input tensor’s shape. It infers missing axis lengths, allowing at most one unknown axis per parenthesized group. It then determines:
  - The shape for an initial reshape (if any axis must be split or a size-1 axis dropped).
  - The axes permutation for reordering.
  - The final shape (if any axes must be merged or a size-1 axis inserted).

- **No Repeating:**  
  `rearrange` only moves existing elements, so it never calls `np.repeat`. Patterns that introduce a new axis name on the right (for example `h w -> h w c`) raise an `EinopsError`. Repeating is left to `np.repeat` or `einops.repeat`.

- **Error Handling:**  
  The implementation uses a custom `EinopsError` to report invalid patterns, mismatched dimensions, or missing/incorrect `axes_lengths`.
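
Under these design decisions, a pattern that splits, reorders, and merges compiles to at most three NumPy calls. The sketch below hard-codes that recipe for the single pattern `(h w) c -> w (h c)`. `rearrange_hw_c` is a hypothetical helper, not part of the notebook, and it raises a plain `ValueError` rather than `EinopsError`.

```python
import numpy as np

def rearrange_hw_c(x: np.ndarray, h: int) -> np.ndarray:
    """Hand-written recipe for the pattern '(h w) c -> w (h c)'."""
    hw, c = x.shape
    # Length inference: the group (h w) has one unknown axis, found by division
    if hw % h != 0:
        raise ValueError(f"{hw} is not divisible by h={h}")
    w = hw // h
    # Step 1: initial reshape splits (h w) into separate axes -> (h, w, c)
    y = x.reshape(h, w, c)
    # Step 2: permutation to the output order (w, h, c)
    y = y.transpose(1, 0, 2)
    # Step 3: final reshape merges the now-adjacent h and c axes
    return y.reshape(w, h * c)

x = np.arange(24).reshape(6, 4)   # (h w) = 6, c = 4
out = rearrange_hw_c(x, h=2)
print(out.shape)                  # (3, 8)
```

Because `h` is the outer factor of `(h w)`, row `j` of the output concatenates input rows `j` and `j + w`.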

## How to Run

1. **Implementation:**  
   The first cell contains the full implementation. Ensure that the cell runs without errors.

2. **Unit Tests:**  
   The second cell includes tests covering:
   - Transposition
   - Splitting of axes
   - Merging of axes
   - Adding and removing size-1 axes
   - Handling of batch dimensions via ellipsis
   - Empty arrays and zero-length dimensions
   - Invalid patterns that must raise `EinopsError`  
   Run this cell to execute all tests. Each test prints `[PASSED]` or `[FAILED]` together with the resulting shape or error.

3. **Usage Example:**  
   After running the implementation cell, call `rearrange` directly in your own cells:
   ```python
   import numpy as np

   x = np.random.rand(3, 4, 5)
   y = rearrange(x, 'a b c -> c (a b)')       # move c first, merge a and b
   print(y.shape)  # Expected output: (5, 12)

   z = rearrange(y, 'c (a b) -> a b c', a=3)  # split (a b) back; b is inferred as 4
   print(z.shape)  # Expected output: (3, 4, 5)
   print(np.array_equal(z, x))  # Expected output: True
   ```

# Implementation

In [33]:
import numpy as np
import re
from typing import List, Dict, Tuple, Optional, Set, Union
from collections import OrderedDict
import math

# Custom Error Class as required
class EinopsError(ValueError):
    """Custom error for einops operations."""
    pass

# --- Pattern Parsing Logic (Simplified from einops/parsing.py) ---

def _parse_expression(expression: str) -> Tuple[List[Union[List[str], str]], Set[str], bool, bool]:
    """
    Parses one side of the einops pattern (e.g., 'b c (h w)').

    Args:
        expression: The pattern string for one side.

    Returns:
        A tuple containing:
            - composition: List representing the structure (e.g., [['b'], ['c'], ['h', 'w']]). Ellipsis is represented by '...'.
            - identifiers: Set of unique axis names found.
            - has_ellipsis: Boolean indicating if ellipsis was present.
            - has_ellipsis_parenthesized: Boolean indicating if ellipsis was inside parentheses.

    Raises:
        EinopsError: If the pattern is invalid (e.g., duplicate axes, invalid names, misplaced ellipsis).
    """
    composition = []
    identifiers = set()
    has_ellipsis = False
    has_ellipsis_parenthesized = False
    current_group = None

    # Standardize whitespace and handle ellipsis representation
    expression = expression.strip()
    if '.' in expression:
        if '...' not in expression or expression.count('...') > 1 or expression.count('.') != 3:
             raise EinopsError("Expression may contain dots only inside ellipsis (...); only one ellipsis allowed.")
        expression = expression.replace('...', '_ellipsis_') # Use temporary placeholder
        has_ellipsis = True

    tokens = re.split(r'([\(\)\s])', expression)
    tokens = [t for t in tokens if t and not t.isspace()]

    i = 0
    while i < len(tokens):
        token = tokens[i]

        if token == '(':
            if current_group is not None:
                raise EinopsError("Nested parentheses are not allowed.")
            current_group = []
        elif token == ')':
            if current_group is None:
                raise EinopsError("Mismatched parentheses.")
            if not current_group:
                raise EinopsError("Empty parentheses are not allowed.")
            composition.append(current_group)
            current_group = None
        else:
            # Validate axis name
            if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$|^_ellipsis_$', token):
                 # Allow numbers only if they represent anonymous axes (not supported here for rearrange)
                 # For rearrange, we only expect named axes or ellipsis placeholder
                if not token.isdigit(): # Basic check; full validation in recipe
                    raise EinopsError(f"Invalid axis identifier: '{token}'. Use letters, numbers, underscores, or '...' (ellipsis).")
                # For rearrange, anonymous axes (numbers) are generally not used unless it's 1 (ignored)
                # We'll raise error later if needed, parser just flags it.
                if token.isdigit() and token != '1':
                     raise EinopsError(f"Numerical axes like '{token}' are not directly supported in rearrange for splitting/merging by name. Use parentheses.")


            axis_name = token.replace('_ellipsis_', '...') # Restore ellipsis symbol

            if axis_name == '...':
                if axis_name in identifiers:
                     raise EinopsError("Ellipsis (...) can only appear once.")
                has_ellipsis_parenthesized = (current_group is not None)
            elif axis_name in identifiers:
                raise EinopsError(f"Duplicate axis identifier: '{axis_name}'")

            identifiers.add(axis_name)

            if current_group is None:
                # If axis is '1', treat it as an empty composition (ignore)
                if axis_name == '1':
                    composition.append([])
                else:
                    composition.append([axis_name])
            else:
                 # Ignore '1' inside parentheses as well
                if axis_name != '1':
                    current_group.append(axis_name)


        i += 1

    if current_group is not None:
        raise EinopsError("Mismatched parentheses.")

    # Final check for ellipsis placement if parenthesized
    if has_ellipsis_parenthesized:
        # Find the group containing ellipsis
        ellipsis_group_found = False
        for group in composition:
            if isinstance(group, list) and '...' in group:
                if ellipsis_group_found:
                     raise EinopsError("Ellipsis (...) cannot appear in multiple groups.")
                ellipsis_group_found = True
                # Ellipsis must be the only element if inside parentheses on the left side (implied later in recipe)
                # No specific check needed here, handled during recipe creation/validation

    # Remove the temporary placeholder if it wasn't used (no ellipsis)
    identifiers.discard('_ellipsis_')

    return composition, identifiers, has_ellipsis, has_ellipsis_parenthesized


# --- Recipe Preparation (Simplified from einops.py _prepare_transformation_recipe) ---

def _prepare_rearrange_recipe(
    pattern: str,
    axes_lengths: Dict[str, int],
    shape: Tuple[int, ...]
) -> Tuple[Optional[Tuple[int, ...]], Optional[List[int]], Optional[Tuple[int, ...]]]:
    """
    Parses the pattern and input shape to create a "recipe" for rearrange.
    The recipe consists of shapes for initial reshape, axes permutation, and final reshape.

    Args:
        pattern: The einops pattern string (e.g., 'b h w c -> b c h w').
        axes_lengths: Dictionary mapping axis names to their lengths for decomposition (splitting grouped axes).
        shape: The shape of the input NumPy array.

    Returns:
        A tuple containing:
            - init_reshape_shape: Shape for the initial reshape (for decomposition), or None.
            - axes_permutation: Order of axes for transposition, or None.
            - final_reshape_shape: Shape for the final reshape (for merging or inserting size-1 axes), or None.

    Raises:
        EinopsError: For various issues like invalid patterns, mismatched shapes,
                     missing/incorrect axes_lengths.
    """
    if '->' not in pattern:
        raise EinopsError("Pattern must include '->' separator.")

    left_str, right_str = pattern.split('->')
    left_composition, left_identifiers, left_has_ellipsis, left_ellipsis_paren = _parse_expression(left_str)
    right_composition, right_identifiers, right_has_ellipsis, right_ellipsis_paren = _parse_expression(right_str)

    # --- Basic Validation ---
    if left_ellipsis_paren:
        raise EinopsError("Ellipsis (...) cannot be inside parentheses on the left side.")
    if left_has_ellipsis != right_has_ellipsis:
         raise EinopsError("Ellipsis (...) must appear on both sides or neither.")
    if left_identifiers != right_identifiers:
        diff = left_identifiers.symmetric_difference(right_identifiers)
        # Check if difference is only due to '1' axes which are ignored
        if diff != {'1'}:
             raise EinopsError(f"Identifiers must be the same on both sides for rearrange. Difference: {diff}")


    # --- Ellipsis Handling ---
    input_ndim = len(shape)
    ellipsis_ndim = 0
    if left_has_ellipsis:
        # Count non-ellipsis dimensions on the left
        non_ellipsis_dims_left = sum(1 for group in left_composition if group != ['...'])
        if input_ndim < non_ellipsis_dims_left:
            raise EinopsError(f"Input tensor has {input_ndim} dimensions, but pattern expects at least {non_ellipsis_dims_left} non-ellipsis dimensions.")
        ellipsis_ndim = input_ndim - non_ellipsis_dims_left

        # Replace '...' with actual ellipsis axes names ('_ellipsis_0', '_ellipsis_1', ...)
        ellipsis_axes = [f'_ellipsis_{i}' for i in range(ellipsis_ndim)]
        new_left_composition = []
        for group in left_composition:
            if group == ['...']:
                new_left_composition.extend([[axis] for axis in ellipsis_axes])
            else:
                new_left_composition.append(group)
        left_composition = new_left_composition

        new_right_composition = []
        for group in right_composition:
             if isinstance(group, list): # Only process actual groups (lists)
                 new_group = []
                 contains_ellipsis = False
                 for axis in group:
                     if axis == '...':
                         new_group.extend(ellipsis_axes)
                         contains_ellipsis = True
                     else:
                         new_group.append(axis)
                 # If the original group was just ['...'], extend directly
                 if group == ['...']:
                      new_right_composition.extend([[axis] for axis in ellipsis_axes])
                 else:
                     new_right_composition.append(new_group)
             elif group == '...': # Handle case where ellipsis is standalone on right
                 new_right_composition.extend([[axis] for axis in ellipsis_axes])
             else: # Should not happen if parsing is correct, but safeguard
                 new_right_composition.append(group)

        right_composition = new_right_composition

        # Update identifiers
        left_identifiers.remove('...')
        right_identifiers.remove('...')
        left_identifiers.update(ellipsis_axes)
        right_identifiers.update(ellipsis_axes)

    elif input_ndim != len(left_composition):
         raise EinopsError(f"Input tensor has {input_ndim} dimensions, but pattern expects {len(left_composition)}.")


    # --- Axis Length Inference and Validation ---
    axis_name_to_length: Dict[str, int] = {}
    inferred_axes: Set[str] = set()  # Track axes whose lengths are inferred from input shape

    # Populate known lengths from axes_lengths argument
    for name, length in axes_lengths.items():
        if name not in left_identifiers: # Also checks right_identifiers due to earlier check
            # Allow specification for axes only used in decomposition, check later
            pass # raise EinopsError(f"Axis '{name}' provided in axes_lengths but not found in pattern.")
        if not isinstance(length, int) or length <= 0:
            raise EinopsError(f"Axis '{name}' must have a positive integer length, got {length}.")
        axis_name_to_length[name] = length

    # Infer lengths from input shape and validate
    current_dim_index = 0
    axes_in_left_composition_flat = [] # Keep track of axes order and structure
    for group in left_composition:
        axes_in_group = group
        axes_in_left_composition_flat.extend(axes_in_group)
        input_dim_len = shape[current_dim_index]

        unknown_axes_in_group = [axis for axis in axes_in_group if axis not in axis_name_to_length]
        known_axes_in_group = [axis for axis in axes_in_group if axis in axis_name_to_length]

        product_of_known = 1
        for axis in known_axes_in_group:
            product_of_known *= axis_name_to_length[axis]

        if not unknown_axes_in_group: # All lengths in group are known/provided
            if product_of_known != input_dim_len:
                raise EinopsError(
                    f"Dimension mismatch for input axis {current_dim_index} (group {' '.join(group)}): "
                    f"Product of known axes lengths ({product_of_known}) != input dimension ({input_dim_len})."
                )
        elif len(unknown_axes_in_group) == 1: # Exactly one unknown length, can infer
            unknown_axis = unknown_axes_in_group[0]
            if input_dim_len % product_of_known != 0:
                 raise EinopsError(
                    f"Dimension mismatch for input axis {current_dim_index} (group {' '.join(group)}): "
                    f"Input dimension ({input_dim_len}) is not divisible by product of known axes ({product_of_known})."
                 )
            inferred_length = input_dim_len // product_of_known
            axis_name_to_length[unknown_axis] = inferred_length
            inferred_axes.add(unknown_axis) # Mark as inferred
            # Check if inferred length conflicts with axes_lengths if provided later (unlikely scenario)
            if unknown_axis in axes_lengths and axes_lengths[unknown_axis] != inferred_length:
                 raise EinopsError(f"Inferred length for axis '{unknown_axis}' ({inferred_length}) conflicts with provided axes_lengths ({axes_lengths[unknown_axis]})")
        else: # More than one unknown length in the group
             raise EinopsError(
                f"Cannot infer lengths for multiple axes ({', '.join(unknown_axes_in_group)}) "
                f"in group for input dimension {current_dim_index}. Provide lengths in axes_lengths."
             )

        current_dim_index += 1


    # --- Final Validation of axes_lengths ---
    # Ensure all axes in axes_lengths were actually used for decomposition/composition or were inferred
    provided_axes_set = set(axes_lengths.keys())
    all_elementary_axes = set(axis for group in left_composition for axis in group if axis != '...') | \
                           set(axis for group in right_composition for axis in group if axis != '...')
    all_elementary_axes.update(ellipsis_axes if left_has_ellipsis else []) # Add ellipsis axes if any

    # Check for provided lengths that were not needed/used
    unused_provided_axes = provided_axes_set - all_elementary_axes
    if unused_provided_axes:
        # This check might be too strict if user provides length for an axis that is only passed through
        # We relax it: only complain if an axis name is *completely* unknown to the pattern
        unknown_provided_axes = provided_axes_set - left_identifiers # Since left==right identifiers
        if unknown_provided_axes:
             raise EinopsError(f"Axes specified in axes_lengths not found in pattern: {unknown_provided_axes}")


    # --- Determine Reshape and Permutation Steps ---

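    # Step 1: Initial Reshape (Decomposition)
    # Reshape is needed if any input dimension maps to zero elementary axes (a literal '1',
    # which is dropped) or to more than one elementary axis (which is split).
    # Without the zero-axis case, '1 h w c -> h w c' would hit group[0] on an empty group below.
    needs_initial_reshape = any(len(group) != 1 for group in left_composition)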
    init_reshape_shape = None
    current_elementary_axis_index = 0
    elemental_axes_list_left = [] # Flat list of axes as they appear elementally after potential initial reshape

    if needs_initial_reshape:
        init_reshape_shape = []
        original_index = 0
        for group in left_composition:
            if len(group) == 1:
                 # If group has only one axis, keep its original dimension size
                 init_reshape_shape.append(shape[original_index])
                 elemental_axes_list_left.append(group[0])
            else:
                 # Zero axes (a literal '1') drop the size-1 dimension; several axes split it
                 group_lengths = [axis_name_to_length[axis] for axis in group]
                 init_reshape_shape.extend(group_lengths)
                 elemental_axes_list_left.extend(group)
            original_index += 1
    else:
        # No initial reshape, elemental axes directly correspond to input dimensions
         elemental_axes_list_left = [group[0] for group in left_composition]


    # Step 2: Axes Permutation
    # Determine the order of elementary axes as required by the right side.
    elemental_axes_list_right = []
    for group in right_composition:
         elemental_axes_list_right.extend(group)

    axes_permutation = None
    if elemental_axes_list_left != elemental_axes_list_right:
        try:
            axes_permutation = [elemental_axes_list_left.index(axis) for axis in elemental_axes_list_right]
        except ValueError as e:
             # This should theoretically not happen if identifiers match, but safeguard.
             raise EinopsError(f"Internal error during permutation calculation: Axis '{e}' not found.")

    # Step 3: Final Reshape (Merging / inserting size-1 axes)
    # Reshape is needed if any output dimension corresponds to zero elementary axes
    # (a literal '1', which inserts a size-1 axis) or to more than one elementary axis (merging).
    needs_final_reshape = any(len(group) != 1 for group in right_composition if isinstance(group, list))

    final_reshape_shape = []
    if needs_final_reshape or axes_permutation is not None or needs_initial_reshape : # Need final shape if any op happened
        for group in right_composition:
            if isinstance(group, list): # Should always be list after ellipsis expansion
                group_len = 1
                for axis in group:
                    # Handle repeating axes (axes present on right but not left - impossible for rearrange)
                    # Handle anonymous axes (like '2 c -> 2 c' - not applicable here)
                    # This logic primarily handles merging.
                    if axis not in axis_name_to_length:
                         # This case should be caught earlier by identifier mismatch
                         raise EinopsError(f"Internal error: Axis '{axis}' length not found for final shape calculation.")
                    group_len *= axis_name_to_length[axis]
                final_reshape_shape.append(group_len)
            # else: This part handles cases like '-> scalar', not applicable for standard rearrange

        # If the final shape is the same as the shape *after permutation*, no final reshape needed
        # Calculate the shape after potential initial reshape and permutation
        temp_shape_after_perm = list(init_reshape_shape) if init_reshape_shape else list(shape)
        if axes_permutation is not None:
            # We need the shape *before* permutation to apply permutation correctly
            shape_before_perm = list(init_reshape_shape) if init_reshape_shape else list(shape)
            temp_shape_after_perm = [shape_before_perm[i] for i in axes_permutation]


        if tuple(final_reshape_shape) == tuple(temp_shape_after_perm):
             final_reshape_shape = None # No final reshape op needed if shapes match
        elif not final_reshape_shape : # Handle case like 'b c -> ' which should be error
             raise EinopsError("Pattern implies removing all dimensions, which is not supported by rearrange.")

    else:
        # No operations needed at all (e.g., 'a b c -> a b c')
        final_reshape_shape = None # Ensure it's None if no ops

    # Convert shapes to tuples or keep as None
    init_reshape_shape = tuple(init_reshape_shape) if init_reshape_shape else None
    final_reshape_shape = tuple(final_reshape_shape) if final_reshape_shape else None

    # Optimization: If permutation is identity, set to None
    if axes_permutation == list(range(len(elemental_axes_list_left))):
        axes_permutation = None


    # --- Repeating ---
    # `rearrange` never repeats data: every elementary axis must appear on both sides
    # (only the literal '1' may differ), so the output has exactly as many elements as the input.
    # Patterns that introduce a new axis (e.g. 'h w -> h w c') are rejected by the identifier
    # check above; repeating belongs to a separate `repeat` function.

    return init_reshape_shape, axes_permutation, final_reshape_shape


# --- Main rearrange Function ---

def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    Replicates the core functionality of einops.rearrange for NumPy arrays.

    Allows for flexible reshaping, transposing, splitting, and merging of axes
    using Einstein notation-inspired syntax.

    Args:
        tensor: Input NumPy array.
        pattern: String defining the rearrangement operation.
                 Format: "left_side -> right_side"
                 - Axis names are identifiers: letters, digits, and underscores (e.g., 'h', 'w', 'batch', 'b1').
                 - Parentheses group axes for splitting/merging (e.g., '(h w)').
                 - Ellipsis '...' represents any number of batch dimensions.
        **axes_lengths: Keyword arguments specifying the lengths of axes
                        involved in splitting operations.
                        (e.g., rearrange(x, '(b1 b2) c -> b1 b2 c', b1=10))

    Returns:
        The rearranged NumPy array.

    Raises:
        EinopsError: If the pattern is invalid, shapes are incompatible,
                     or required axes_lengths are missing/incorrect.
        TypeError: If the input is not a NumPy array.

    Examples:
        >>> import numpy as np
        >>> # Transpose height and width
        >>> x = np.zeros((4, 5))
        >>> rearrange(x, 'h w -> w h').shape
        (5, 4)

        >>> # Split channels into groups
        >>> x = np.zeros((10, 12, 12)) # Batch, Pixels, Channels (12 = c * g)
        >>> rearrange(x, 'b p (c g) -> b p c g', g=4).shape
        (10, 12, 3, 4)

        >>> # Merge height and width
        >>> x = np.zeros((10, 28, 28, 3)) # Batch, Height, Width, Channels
        >>> rearrange(x, 'b h w c -> b (h w) c').shape
        (10, 784, 3)

        >>> # Reorder and merge with ellipsis
        >>> x = np.zeros((10, 20, 3, 4, 5)) # Ellipsis covers (10, 20)
        >>> rearrange(x, '... a b c -> ... c (a b)').shape
        (10, 20, 5, 12)

        >>> # Decompose and reorder
        >>> x = np.zeros((12, 10))
        >>> rearrange(x, '(h w) c -> h w c', h=3).shape
        (3, 4, 10)
    """
    if not isinstance(tensor, np.ndarray):
        raise TypeError(f"Input must be a NumPy array, got {type(tensor)}.")

    try:
        init_shape, permutation, final_shape = _prepare_rearrange_recipe(
            pattern, axes_lengths, tensor.shape
        )

        result = tensor
        # Apply initial reshape if needed (decomposition)
        if init_shape is not None:
             # Check if the total number of elements matches before reshaping
             if np.prod(tensor.shape) != np.prod(init_shape):
                 raise EinopsError(f"Cannot reshape array of size {np.prod(tensor.shape)} into shape {init_shape} (size {np.prod(init_shape)}) during initial decomposition. Pattern: '{pattern}', Input Shape: {tensor.shape}")
             result = result.reshape(init_shape)


        # Apply transposition if needed
        if permutation is not None:
            result = result.transpose(permutation)

        # Apply final reshape if needed (merging)
        if final_shape is not None:
             # Check element count before final reshape
             if np.prod(result.shape) != np.prod(final_shape):
                  raise EinopsError(f"Cannot reshape array of size {np.prod(result.shape)} into shape {final_shape} (size {np.prod(final_shape)}) during final merging. Pattern: '{pattern}', Input Shape: {tensor.shape}")

             result = result.reshape(final_shape)

        # --- Repeating Handling ---
        # Repeating like in einops.repeat ('h w -> h w c', c=3) is NOT handled here.
        # `rearrange` only rearranges existing elements. Cases requiring axes_lengths
        # like `'(h w) c -> h w c'` are handled by the decomposition logic.
        # If the final calculated shape requires *more* elements than available
        # after permutation, the reshape operation itself would fail (or the
        # recipe preparation should have caught it if possible).

        return result

    except EinopsError as e:
        # Add context to the error message, keeping the original exception chained
        raise EinopsError(f"Error processing pattern '{pattern}' for tensor shape {tensor.shape}: {e}") from e
    except Exception as e:
        # Report unexpected errors (NumPy failures or internal bugs) as EinopsError
        raise EinopsError(f"Unexpected error during rearrange operation for pattern '{pattern}': {e}") from e
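
The ellipsis handling inside `_prepare_rearrange_recipe` can be isolated as shown below. The left side determines how many dimensions `...` covers. Each covered dimension receives a generated name, following the `_ellipsis_{i}` scheme used in the implementation cell. On the right side, a standalone `...` becomes separate dimensions, while a `...` inside parentheses merges with its neighbours. The helper names `expand_left` and `expand_right` are hypothetical. Compositions use the list-of-groups form returned by `_parse_expression`.

```python
from typing import List, Tuple

Composition = List[List[str]]

def expand_left(composition: Composition, ndim: int) -> Tuple[Composition, List[str]]:
    # Every group except the ellipsis itself consumes exactly one input dimension
    n_named = sum(1 for group in composition if group != ['...'])
    if ndim < n_named:
        raise ValueError(f"need at least {n_named} dims, got {ndim}")
    names = [f'_ellipsis_{i}' for i in range(ndim - n_named)]
    expanded: Composition = []
    for group in composition:
        if group == ['...']:
            expanded.extend([name] for name in names)   # one dimension per name
        else:
            expanded.append(group)
    return expanded, names

def expand_right(composition: Composition, names: List[str]) -> Composition:
    expanded: Composition = []
    for group in composition:
        if group == ['...']:
            expanded.extend([name] for name in names)   # keep as separate dimensions
        else:
            # Inside parentheses the ellipsis axes join the group and are merged
            new_group: List[str] = []
            for axis in group:
                new_group.extend(names if axis == '...' else [axis])
            expanded.append(new_group)
    return expanded

# 'a b ... e -> a (b ...) e' on a 5D input: '...' covers two dimensions
left, names = expand_left([['a'], ['b'], ['...'], ['e']], ndim=5)
right = expand_right([['a'], ['b', '...'], ['e']], names)
print(left)   # [['a'], ['b'], ['_ellipsis_0'], ['_ellipsis_1'], ['e']]
print(right)  # [['a'], ['b', '_ellipsis_0', '_ellipsis_1'], ['e']]
```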

In [37]:
# --- Test Cases ---
# These test cases cover various scenarios including decomposition, composition,
# reordering, identity, adding/removing axes, and potential edge cases.

def run_tests():
    """Runs a series of test cases for the rearrange function."""
    print("--- Running rearrange Tests ---")
    test_passed_count = 0
    test_failed_count = 0

    def run_single_test(test_name, tensor, pattern, expected_shape, axes_lengths=None, expect_error=None):
        nonlocal test_passed_count, test_failed_count
        if axes_lengths is None:
            axes_lengths = {}
        print(f"\nRunning Test: {test_name}")
        print(f"  Input Shape: {tensor.shape}")
        print(f"  Pattern: '{pattern}'")
        if axes_lengths:
            print(f"  Axes Lengths: {axes_lengths}")

        try:
            result = rearrange(tensor, pattern, **axes_lengths)
            if expect_error:
                print(f"  [FAILED] Expected error ({expect_error}) but got shape {result.shape}")
                test_failed_count += 1
            elif result.shape == expected_shape:
                print(f"  [PASSED] Output Shape: {result.shape}")
                # Optional: Check actual values if needed, e.g., by reshaping back
                test_passed_count += 1
            else:
                print(f"  [FAILED] Expected Shape: {expected_shape}, Got Shape: {result.shape}")
                test_failed_count += 1

        except Exception as e:
            if expect_error and isinstance(e, expect_error):
                print(f"  [PASSED] Correctly caught expected error: {type(e).__name__}: {e}")
                test_passed_count += 1
            elif expect_error:
                print(f"  [FAILED] Expected error ({expect_error}) but got different error: {type(e).__name__}: {e}")
                test_failed_count += 1
            else:
                print(f"  [FAILED] Unexpected error: {type(e).__name__}: {e}")
                test_failed_count += 1

    # 1. Simple Reordering (Transpose)
    tensor1 = np.zeros((2, 16, 32, 3)) # B H W C
    run_single_test("Simple Reorder B H W C -> B C H W", tensor1, 'b h w c -> b c h w', (2, 3, 16, 32))
    run_single_test("Simple Reorder B H W C -> H W C B", tensor1, 'b h w c -> h w c b', (16, 32, 3, 2))

    # 2. Composition (Flattening)
    tensor2 = np.zeros((4, 8, 5, 6)) # A B C D
    run_single_test("Composition A B C D -> A (B C) D", tensor2, 'a b c d -> a (b c) d', (4, 40, 6))
    run_single_test("Composition A B C D -> (A B C D)", tensor2, 'a b c d -> (a b c d)', (4 * 8 * 5 * 6,))
    run_single_test("Composition A B C D -> A B (C D)", tensor2, 'a b c d -> a b (c d)', (4, 8, 30))

    # 3. Decomposition (Unflattening) - Requires axes_lengths
    tensor3 = np.zeros((10, 120)) # Batch, Features (10 * 12)
    run_single_test("Decomposition B (H W) -> B H W", tensor3, 'b (h w) -> b h w', (10, 10, 12), axes_lengths={'h': 10, 'w': 12})
    run_single_test("Decomposition B (C H W) -> B C H W", tensor3, 'b (c h w) -> b c h w', (10, 3, 10, 4), axes_lengths={'c': 3, 'h': 10, 'w': 4})
    # Test with partial specification (w inferred)
    run_single_test("Decomposition B (H W) -> B H W (infer W)", tensor3, 'b (h w) -> b h w', (10, 10, 12), axes_lengths={'h': 10})
    # Test with partial specification (h inferred)
    run_single_test("Decomposition B (H W) -> B H W (infer H)", tensor3, 'b (h w) -> b h w', (10, 10, 12), axes_lengths={'w': 12})

    # 4. Combined Composition and Decomposition
    tensor4 = np.zeros((5, 60, 3)) # Batch, (H W), Channels where H=10, W=6
    run_single_test("Combine B (H W) C -> (B H) W C", tensor4, 'b (h w) c -> (b h) w c', (50, 6, 3), axes_lengths={'h': 10}) # w=6 inferred
    tensor5 = np.zeros((2 * 3, 4 * 5)) # (A B), (C D)
    run_single_test("Combine (A B) (C D) -> A C B D", tensor5, '(a b) (c d) -> a c b d', (2, 4, 3, 5), axes_lengths={'a': 2, 'd': 5}) # b=3, c=4 inferred

    # 5. Identity Transformation
    run_single_test("Identity B H W C -> B H W C", tensor1, 'b h w c -> b h w c', tensor1.shape)

    # 6. Adding Dimension
    tensor6 = np.zeros((32, 32)) # H W
    run_single_test("Add Dimension H W -> 1 H W", tensor6, 'h w -> 1 h w', (1, 32, 32))
    run_single_test("Add Dimension H W -> H 1 W", tensor6, 'h w -> h 1 w', (32, 1, 32))

    # 7. Removing Dimension (Size 1)
    tensor7 = np.zeros((1, 32, 32, 5)) # 1 H W C
    run_single_test("Remove Dimension 1 H W C -> H W C", tensor7, '1 h w c -> h w c', (32, 32, 5))
    # Test removing non-leading dimension 1
    tensor8 = np.zeros((32, 1, 32, 5)) # H 1 W C
    run_single_test("Remove Dimension H 1 W C -> H W C", tensor8, 'h 1 w c -> h w c', (32, 32, 5))

    # 8. Ellipsis (...)
    tensor9 = np.zeros((10, 20, 30, 40, 50)) # A B C D E
    run_single_test("Ellipsis ... D E -> D ... E", tensor9, '... d e -> d ... e', (40, 10, 20, 30, 50))
    run_single_test("Ellipsis A B ... -> ... A B", tensor9, 'a b ... -> ... a b', (30, 40, 50, 10, 20))
    run_single_test("Ellipsis with Composition A B ... E -> A (B ...) E", tensor9, 'a b ... e -> a (b ...) e', (10, 20 * 30 * 40, 50))
    run_single_test("Ellipsis with Decomposition (A B) ... -> A B ...", np.zeros((6, 30, 40, 50)), '(a b) ... -> a b ...', (2, 3, 30, 40, 50), axes_lengths={'a': 2}) # b=3 inferred


    # --- Edge Cases ---

    # 9. Empty Array
    tensor_empty = np.zeros((0, 10))
    # Reordering axes still works; only the shape changes
    run_single_test("Edge Case: Empty Array Reorder", tensor_empty, 'z t -> t z', (10, 0))
    # Composition and decomposition can produce shapes containing 0
    run_single_test("Edge Case: Empty Array Composition", tensor_empty, 'z t -> (z t)', (0,))
    run_single_test("Edge Case: Empty Array Decomposition", np.zeros((0,)), '(z t) -> z t', (0, 10), axes_lengths={'t': 10}, expect_error=None) # z is inferred as 0 // 10 = 0, giving (0, 10)


    # 10. Zero-Size Non-Leading Dimension (also an empty array, but the 0 is not the first axis)
    tensor_zero_dim = np.zeros((5, 0, 10)) # A Z B
    run_single_test("Edge Case: Zero Dimension Reorder", tensor_zero_dim, 'a z b -> z a b', (0, 5, 10))
    run_single_test("Edge Case: Zero Dimension Composition", tensor_zero_dim, 'a z b -> a (z b)', (5, 0))
    # Decomposition involving zero dim
    tensor_zero_comp = np.zeros((5, 0)) # A (Z B)
    run_single_test("Edge Case: Zero Dimension Decomposition", tensor_zero_comp, 'a (z b) -> a z b', (5, 0, 10), axes_lengths={'z': 0, 'b': 10}) # Both lengths given explicitly
    run_single_test("Edge Case: Zero Dimension Decomposition (infer 0)", tensor_zero_comp, 'a (z b) -> a z b', (5, 0, 10), axes_lengths={'b': 10}) # z = 0 // 10 = 0; whether a zero length is inferred depends on the implementation


    # 11. Pattern Mismatch - Incorrect number of axes
    run_single_test("Error Case: Pattern Axis Number Mismatch (Input Too Short)", tensor1, 'b h w -> b c h w', None, expect_error=EinopsError) # 3 input axes for a 4-D tensor
    run_single_test("Error Case: Pattern Axis Number Mismatch (Input Too Long)", tensor1, 'b h w c d -> b c h w', None, expect_error=EinopsError) # 5 input axes for a 4-D tensor

    # 12. Pattern Mismatch - Dimension size mismatch during decomposition
    # tensor3 shape is (10, 120)
    run_single_test("Error Case: Decomposition Size Mismatch", tensor3, 'b (h w) -> b h w', None, axes_lengths={'h': 7, 'w': 10}, expect_error=EinopsError) # 7*10 != 120

    # 13. Missing axes_lengths for decomposition
    run_single_test("Error Case: Missing axis length for decomposition", tensor3, 'b (h w) -> b h w', None, expect_error=EinopsError) # h, w not provided

    # 14. Invalid axes_lengths values (non-integer or negative); the implementation should reject these (tests currently commented out)
    # run_single_test("Error Case: Invalid axis length type", tensor3, 'b (h w) -> b h w', None, axes_lengths={'h': 10.5}, expect_error=(EinopsError, TypeError)) # TypeError is also acceptable, depending on how the check is written
    # run_single_test("Error Case: Invalid axis length value", tensor3, 'b (h w) -> b h w', None, axes_lengths={'h': -10}, expect_error=EinopsError)

    # 15. Invalid Pattern Syntax
    run_single_test("Error Case: Invalid Pattern Syntax (->)", tensor1, 'b h w c', None, expect_error=(EinopsError, ValueError)) # Missing '->'
    run_single_test("Error Case: Invalid Pattern Syntax (Repeated Axis Output)", tensor1, 'b h w c -> b b h w', None, expect_error=EinopsError)
    run_single_test("Error Case: Invalid Pattern Syntax (Unknown Char)", tensor1, 'b h w c -> b ! h w', None, expect_error=(EinopsError, ValueError))

    # 16. Removing non-singleton dimension
    run_single_test("Error Case: Remove Non-Singleton Dim", tensor1, 'b h w c -> h w c', None, expect_error=EinopsError) # b is 2, not 1


    print("\n--- Test Summary ---")
    print(f"Tests Passed: {test_passed_count}")
    print(f"Tests Failed: {test_failed_count}")
    print("--------------------\n")

    return test_failed_count == 0

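# Illustrative sketch (hypothetical helper, not called by rearrange or by run_tests).
# It shows one common way to infer a single missing axis length inside a parenthesized
# group such as '(h w)', which is what the "infer" test cases above rely on. This is an
# assumption about the approach, not necessarily how _prepare_rearrange_recipe works.
def _infer_group_lengths_sketch(total, names, known):
    """Return {name: length} for one group of axis `names` whose product must equal `total`."""
    unknown = [n for n in names if n not in known]
    if len(unknown) > 1:
        # Mirrors the "Missing axis length for decomposition" error case: h and w both unknown.
        raise ValueError(f"Cannot infer more than one axis length in group {names}")
    known_product = 1
    for n in names:
        if n in known:
            known_product *= known[n]
    lengths = {n: known[n] for n in names if n in known}
    if not unknown:
        # Every length was given, so they must multiply out exactly (e.g. 7 * 10 != 120 fails).
        if known_product != total:
            raise ValueError(f"Group {names} has product {known_product}, expected {total}")
        return lengths
    if known_product == 0:
        # A zero-length known axis makes the remaining length ambiguous or impossible.
        raise ValueError(f"Cannot infer an axis next to a zero-length axis in group {names}")
    if total % known_product != 0:
        raise ValueError(f"Size {total} is not divisible by {known_product} for group {names}")
    # A zero-size group works naturally: 0 // 10 == 0 and 0 % 10 == 0, so z = 0.
    lengths[unknown[0]] = total // known_product
    return lengths

# e.g. _infer_group_lengths_sketch(120, ['h', 'w'], {'h': 10}) returns {'h': 10, 'w': 12}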

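# Illustrative sketch (hypothetical helper, not called by rearrange or by run_tests).
# It shows how '...' can be expanded into placeholder axis names once the input rank is
# known, which explains the expected shapes in the ellipsis tests above. It handles only
# a flat list of left-hand axis names (no parenthesized groups), and the '_ell0', '_ell1',
# ... placeholder names are an assumption made for this sketch.
def _expand_ellipsis_sketch(left_axes, ndim):
    """Replace '...' in a flat list of axis names with one placeholder per covered dimension."""
    if left_axes.count('...') > 1:
        raise ValueError("At most one ellipsis is allowed in a pattern side")
    if '...' not in left_axes:
        return list(left_axes)
    # Every named axis consumes one dimension; the ellipsis covers whatever remains.
    n_covered = ndim - (len(left_axes) - 1)
    if n_covered < 0:
        raise ValueError(f"Pattern names more axes than the input's {ndim} dimensions")
    idx = left_axes.index('...')
    placeholders = [f'_ell{i}' for i in range(n_covered)]
    return list(left_axes[:idx]) + placeholders + list(left_axes[idx + 1:])

# e.g. _expand_ellipsis_sketch(['...', 'd', 'e'], 5) gives ['_ell0', '_ell1', '_ell2', 'd', 'e'],
# so for tensor9 the ellipsis covers the sizes (10, 20, 30).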

if __name__ == "__main__":
    # If the rearrange components failed to load, calling them raises NotImplementedError, which is caught below
    try:
        # Example Usage (from Docstring)
        print("--- Running Examples ---")
        image = np.random.rand(4, 32, 32, 3) # B, H, W, C
        rearranged_image = rearrange(image, 'b h w c -> b c h w')
        print(f"Example 1: Image shape {image.shape} -> {rearranged_image.shape}")

        features = np.random.rand(10, 784) # B, Features (28*28)
        spatial_features = rearrange(features, 'b (h w) -> b h w', h=28) # w=28 inferred
        print(f"Example 2: Features shape {features.shape} -> {spatial_features.shape}")
        print("----------------------\n")

        # Run the comprehensive tests
        run_tests()
    except NotImplementedError:
        print("\nCannot run examples or tests because einops module components were not loaded correctly.")
    except Exception as e:
        print(f"\nAn unexpected error occurred during execution: {type(e).__name__}: {e}")

--- Running Examples ---
Example 1: Image shape (4, 32, 32, 3) -> (4, 3, 32, 32)
Example 2: Features shape (10, 784) -> (10, 28, 28)
----------------------

--- Running rearrange_numpy Tests ---

Running Test: Simple Reorder B H W C -> B C H W
  Input Shape: (2, 16, 32, 3)
  Pattern: 'b h w c -> b c h w'
  [PASSED] Output Shape: (2, 3, 16, 32)

Running Test: Simple Reorder B H W C -> H W C B
  Input Shape: (2, 16, 32, 3)
  Pattern: 'b h w c -> h w c b'
  [PASSED] Output Shape: (16, 32, 3, 2)

Running Test: Composition A B C D -> A (B C) D
  Input Shape: (4, 8, 5, 6)
  Pattern: 'a b c d -> a (b c) d'
  [PASSED] Output Shape: (4, 40, 6)

Running Test: Composition A B C D -> (A B C D)
  Input Shape: (4, 8, 5, 6)
  Pattern: 'a b c d -> (a b c d)'
  [PASSED] Output Shape: (960,)

Running Test: Composition A B C D -> A B (C D)
  Input Shape: (4, 8, 5, 6)
  Pattern: 'a b c d -> a b (c d)'
  [PASSED] Output Shape: (4, 8, 30)

Running Test: Decomposition B F -> B H W
  Input Shape: (10, 120)
  