# Implement einops from scratch

## Introduction & Assignment Overview

This notebook presents a custom implementation of the core functionality of `einops.rearrange`, developed using only Python and the NumPy library, as required by the assignment. The goal is to replicate the intuitive and powerful tensor manipulation capabilities offered by the original `einops` library without importing or relying on it.

The `einops` library, particularly its `rearrange` operation, provides a concise, readable syntax (the "pattern string") for performing complex tensor operations like reshaping, transposition, axis splitting, axis merging, and axis repetition. This implementation aims to deliver these capabilities.
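
To make these operations concrete, the sketch below shows the plain NumPy calls that a few typical patterns correspond to. The patterns in the comments and the array shape are illustrative assumptions; nothing here uses this notebook's implementation.

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # axes: a=2, b=3, c=4

# Transposition: "a b c -> c a b"
transposed = x.transpose(2, 0, 1)

# Merging: "a b c -> (a b) c"
merged = x.reshape(2 * 3, 4)

# Splitting: "a b (c1 c2) -> a b c1 c2" with c1=2
split = x.reshape(2, 3, 2, 2)

# Repeating: "a b c -> a b c r" with r=5 (new trailing axis, values copied)
repeated = np.repeat(x[..., np.newaxis], 5, axis=-1)

print(transposed.shape, merged.shape, split.shape, repeated.shape)
# (4, 2, 3) (6, 4) (2, 3, 2, 2) (2, 3, 4, 5)
```

A pattern string expresses the same intent in one place, which is what the rest of the notebook sets out to reproduce.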

**Core Requirements Addressed:**

1.  **`rearrange` Function:** Implementing the function `rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray`.
2.  **Supported Operations:** Enabling reshaping, transposition, splitting, merging, and repeating of axes based on the pattern.
3.  **Pattern Parsing:** Developing a parser to interpret the input/output pattern string, including handling composite axes (parentheses) and ellipsis (`...`) for batch dimensions.
4.  **Error Handling:** Incorporating robust validation for patterns, tensor shapes, and `axes_lengths`, providing clear and informative error messages.
5.  **Performance:** Optimizing the implementation by minimizing intermediate tensor operations and efficiently parsing patterns.
6.  **Testing & Documentation:** Providing comprehensive unit tests and clear documentation (including a README.md and code comments/docstrings).

This notebook walks through the implementation (helper functions, parsing logic, and the main `rearrange` function) and the test suite used to verify correctness across typical use cases and edge cases.


## Setup

In [234]:
import numpy as np
import functools
import keyword
import warnings
from typing import Dict, List, Tuple, Set, Optional, Union, Any, Iterable

# Create a simple test wrapper for consistent testing
def test_case(description):
    """Decorator for test cases to provide consistent formatting"""
    def decorator(func):
        @functools.wraps(func)  # keep the wrapped test's name and docstring
        def wrapper(*args, **kwargs):
            print(f"\n📋 TEST: {description}")
            print("-" * 50)
            try:
                result = func(*args, **kwargs)
                print("✅ Test Passed")
                return result
            except Exception as e:
                print(f"❌ Test Failed: {str(e)}")
                raise
        return wrapper
    return decorator

print("Environment set up successfully!")


Environment set up successfully!


# **Implementation Section**

## Exception Handling & Utilities

Implements foundational components: custom EinopsError, dimension product calculation, and handling for anonymous numeric axes in patterns.

In [235]:
class EinopsError(Exception):
    """Exception class for einops-specific errors"""
    pass

def _product(dimensions: Iterable[int]) -> int:
    """Calculate product of dimensions efficiently"""
    dim_list = list(dimensions)
    if not dim_list:
        return 1
    result = np.prod(dim_list)
    return int(result)

class AnonymousAxis:
    """Class for representing anonymous axes (like '2' or '3' in patterns)"""
    def __init__(self, value: str):
        self.value = int(value)
        if self.value <= 1:
            if self.value == 1:
                raise EinopsError("No need to create anonymous axis of length 1")
            else:
                raise EinopsError(f"Anonymous axis should have positive length, not {self.value}")

    def __repr__(self):
        return f"{self.value}-axis"

    def __str__(self):
        return str(self.value)

    def __eq__(self, other):
        if isinstance(other, AnonymousAxis):
            return self.value == other.value
        return False

    def __hash__(self):
        return hash(self.value)
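
A side note on `_product`: `np.prod` computes in a fixed-width integer dtype, so a product that exceeds that width can wrap around silently. Realistic tensor shapes are unlikely to reach that range, but if exact integers matter, the standard library's `math.prod` (Python 3.8+) is one alternative. The helper name `exact_product` below is hypothetical and not part of the notebook.

```python
import math
from typing import Iterable

def exact_product(dimensions: Iterable[int]) -> int:
    """Product of dimensions using arbitrary-precision Python ints."""
    # math.prod returns its start value (1) for an empty iterable.
    return math.prod(int(d) for d in dimensions)

print(exact_product([2, 3, 4]))                 # 24
print(exact_product([]))                        # 1
print(exact_product([2**40, 2**40]) == 2**80)   # True
```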

## Pattern Parsing Logic

Develops the system to parse einops pattern strings, validate syntax, handle parentheses/ellipsis, and cache parsed results.

In [236]:
_ellipsis = "…"  # Single unicode symbol

class ParsedExpression:
    """
    Non-mutable structure that contains information about one side of an einops pattern expression
    and keeps some information important for downstream operations.
    """
    def __init__(self, expression: str, allow_underscore: bool = False):
        self.has_ellipsis = False
        self.has_ellipsis_parenthesized = None
        self.identifiers = set()
        self.has_non_unitary_anonymous_axes = False
        self.composition = []

        # Handle ellipsis notation
        if "..." in expression:
            if expression.count("...") != 1:
                raise EinopsError("Expression may contain only one ellipsis (...)")
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        bracket_group = None

        def add_axis_name(x):
            if x in self.identifiers and x != _ellipsis:
                if not (allow_underscore and x == "_"):
                    raise EinopsError(f'Duplicate dimension "{x}" in expression')

            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = x.isdecimal()
                if is_number and int(x) == 1:
                    # Handle anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append([])
                    return

                is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore)
                if not (is_number or is_axis_name):
                    raise EinopsError(f"Invalid axis identifier: {x}\n{reason}")

                if is_number:
                    x = AnonymousAxis(x)

                self.identifiers.add(x)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True

                if bracket_group is None:
                    self.composition.append([x])
                else:
                    bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in "() ":
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None

                if char == "(":
                    if bracket_group is not None:
                        raise EinopsError("Nested parentheses are not allowed")
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise EinopsError("Unbalanced parentheses")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif char.isalnum() or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError(f"Unknown character '{char}'")

        if bracket_group is not None:
            raise EinopsError(f"Unbalanced parentheses in expression: '{expression}'")

        if current_identifier is not None:
            add_axis_name(current_identifier)

    @staticmethod
    def check_axis_name_return_reason(name: str, allow_underscore: bool = False):
        """
        Check if a name is a valid axis name and return reason if not.
        """
        if not name.isidentifier():
            return False, "Not a valid Python identifier"
        elif name[0] == "_" or name[-1] == "_":
            if name == "_" and allow_underscore:
                return True, ""
            return False, "Axis name should not start or end with underscore"
        else:
            if keyword.iskeyword(name):
                warnings.warn(f"Using Python keywords as axis names is discouraged: {name}", RuntimeWarning)
            return True, ""

@functools.lru_cache(maxsize=128)
def parse_pattern(pattern: str) -> Tuple[ParsedExpression, ParsedExpression]:
    """
    Parse einops pattern string with caching for performance.
    Returns:
        Tuple of ParsedExpression objects for input and output patterns.
    """
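    # Exactly one separator is required; more than one would break the unpacking below
    if pattern.count('->') != 1:
        raise EinopsError("Pattern must contain exactly one '->'")

    left, right = pattern.split('->')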
    left = left.strip()
    right = right.strip()

    return ParsedExpression(left), ParsedExpression(right)
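
The core idea of the parser can be condensed into a much smaller sketch: walk the expression, collect identifiers, and group whatever sits between parentheses. The function `tokenize_side` below is a hypothetical, simplified stand-in for `ParsedExpression` (no ellipsis, anonymous axes, or name validation). It returns tuples so that the result is hashable and safe to share through `functools.lru_cache`.

```python
import functools
from typing import Tuple

@functools.lru_cache(maxsize=128)
def tokenize_side(expression: str) -> Tuple[Tuple[str, ...], ...]:
    """Split one side of a pattern into groups of axis names."""
    groups = []
    current_group = None  # a list while inside parentheses, else None
    # Pad parentheses with spaces so that str.split() isolates them as tokens.
    for token in expression.replace("(", " ( ").replace(")", " ) ").split():
        if token == "(":
            if current_group is not None:
                raise ValueError("Nested parentheses are not allowed")
            current_group = []
        elif token == ")":
            if current_group is None:
                raise ValueError("Unbalanced parentheses")
            groups.append(tuple(current_group))
            current_group = None
        elif current_group is not None:
            current_group.append(token)
        else:
            groups.append((token,))
    if current_group is not None:
        raise ValueError("Unbalanced parentheses")
    return tuple(groups)

first = tokenize_side("b c (h w)")
second = tokenize_side("b c (h w)")
print(first)            # (('b',), ('c',), ('h', 'w'))
print(first is second)  # True: the second call is served from the cache
```

The same caveat applies to `parse_pattern`: callers receive the same `ParsedExpression` objects on every cache hit, so those objects should be treated as read-only.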


## Shape Reconstruction Logic

Contains the logic to interpret parsed patterns against input shapes, infer dimensions, and determine necessary tensor transformations.

In [237]:
def identify_axes(parsed_expression):
    """
    Identify all axes in a ParsedExpression, excluding ellipsis.

    Args:
        parsed_expression (ParsedExpression): The parsed expression object.

    Returns:
        Set[str]: A set of axis names.
    """
    axes = set()
    # Iterate over each group in the composition
    for group in parsed_expression.composition:
        if isinstance(group, list):
            # Add each axis (as string) that is not the ellipsis
            for axis in group:
                if axis != _ellipsis:
                    axes.add(str(axis))
    return axes

def handle_ellipsis(input_dims, explicit_dims):
    """
    Calculate the number of dimensions represented by ellipsis.

    Args:
        input_dims (int): Total number of dimensions in the input tensor.
        explicit_dims (int): Number of explicitly defined dimensions in the pattern.

    Returns:
        int: Number of dimensions represented by ellipsis.
    """
    # Ellipsis represents the extra dimensions not explicitly defined
    return max(0, input_dims - explicit_dims)

def optimize_operations(init_shapes, axes_permutation, final_shapes):
    """
    Optimize operations by collapsing consecutive axes when possible.

    Args:
        init_shapes (List[int]): Initial shape after first reshape.
        axes_permutation (List[int]): Permutation of axes for transpose.
        final_shapes (List[int]): Final shape after last reshape.

    Returns:
        Tuple[List[int], List[int], List[int]]: Optimized initial shape, axes permutation, and final shape.
    """
    if init_shapes is None or axes_permutation is None:
        return init_shapes, axes_permutation, final_shapes

    # Iterate backwards over the initial shape to collapse consecutive axes
    for i in range(len(init_shapes) - 1, 0, -1):
        # Check if consecutive axes are maintained in the permutation order
        if i in axes_permutation and i - 1 in axes_permutation and \
           axes_permutation.index(i) == axes_permutation.index(i - 1) + 1:
            # Multiply the sizes together and remove the collapsed axis
            init_shapes[i - 1] *= init_shapes[i]
            init_shapes.pop(i)
            # Update permutation to reflect the removed axis
            axes_permutation = [ax if ax < i else ax - 1 for ax in axes_permutation if ax != i]

    # If the permutation is an identity, we don't need to perform any transpose
    if axes_permutation == list(range(len(axes_permutation))):
        axes_permutation = None

    return init_shapes, axes_permutation, final_shapes

def reconstruct_from_shape(input_shape, input_expr, output_expr, axes_lengths):
    """
    Reconstruct dimensions for all axes based on input shape and pattern.

    This function infers axis lengths from the input shape and user-provided axes_lengths,
    handles ellipsis replacement, composite axes, and computes the initial and final shapes
    needed for the rearrangement.

    Args:
        input_shape (Tuple[int]): The shape of the input tensor.
        input_expr (ParsedExpression): Parsed expression for input pattern.
        output_expr (ParsedExpression): Parsed expression for output pattern.
        axes_lengths (dict): User-provided axis lengths for dynamic dimensions.

    Returns:
        Tuple[List[int], List[int], List[int], Dict[int, int]]:
            - Initial shape for first reshape.
            - Permutation for transposition (if necessary).
            - Final shape for the final reshape.
            - Dictionary of any added axes (axes present in output but not input).
    """
    # Count dimensions defined explicitly (ignoring ellipsis)
    explicit_input_dims = sum(1 for group in input_expr.composition if isinstance(group, list))

    # Create mapping from axis to its length from both anonymous and user-provided axes
    axis_to_length = {}

    # Process all anonymous axes (numeric axes) and add their known sizes
    for group in input_expr.composition:
        if isinstance(group, list):
            for axis in group:
                if isinstance(axis, AnonymousAxis):
                    axis_to_length[axis] = axis.value

    # Incorporate user-provided axes_lengths into the mapping
    for axis_name, length in axes_lengths.items():
        for axis in input_expr.identifiers.union(output_expr.identifiers):
            if str(axis) == axis_name:
                axis_to_length[axis] = length
                break

    # Handle ellipsis if it exists in the input pattern
    if input_expr.has_ellipsis:
        # Find the position of ellipsis in the input composition
        ellipsis_position = next((i for i, group in enumerate(input_expr.composition) if group == _ellipsis), None)
        if ellipsis_position is None:
            raise EinopsError("Ellipsis not found in input composition despite has_ellipsis being True")

        # Calculate how many dimensions the ellipsis should represent
        ellipsis_dims = max(0, len(input_shape) - explicit_input_dims)
        # Generate placeholder axis names for the ellipsis dimensions
        ellipsis_axes = [f'_e{i}' for i in range(ellipsis_dims)]

        # Replace the ellipsis in the input composition with the generated axes
        input_composition = []
        for i, group in enumerate(input_expr.composition):
            if i == ellipsis_position:
                for ax in ellipsis_axes:
                    input_composition.append([ax])
            else:
                input_composition.append(group)

        # Process ellipsis in the output expression similarly, if present
        if output_expr.has_ellipsis:
            output_ellipsis_position = next((i for i, group in enumerate(output_expr.composition) if group == _ellipsis), None)
            if output_ellipsis_position is None:
                raise EinopsError("Ellipsis not found in output composition despite has_ellipsis being True")

            output_composition = []
            for i, group in enumerate(output_expr.composition):
                if i == output_ellipsis_position:
                    for ax in ellipsis_axes:
                        output_composition.append([ax])
                else:
                    output_composition.append(group)
        else:
            output_composition = output_expr.composition

        # Update axis lengths for the ellipsis axes based on the input shape
        for i, axis in enumerate(ellipsis_axes):
            if ellipsis_position + i < len(input_shape):
                axis_to_length[axis] = input_shape[ellipsis_position + i]
            else:
                raise EinopsError("Ellipsis expands to more dimensions than available in input shape")
    else:
        input_composition = input_expr.composition
        output_composition = output_expr.composition

    # Process composite axes to infer lengths for unknown axes within a group
    composed_axis_elements = {}
    current_input_pos = 0
    for group_idx, group in enumerate(input_composition):
        if isinstance(group, list):
            if len(group) == 1:
                # For single axes, map directly from input shape if not already set
                axis = group[0]
                if axis not in axis_to_length:
                    axis_to_length[axis] = input_shape[current_input_pos]
            elif len(group) > 1:
                # For composite axes, calculate the product and infer any unknown axis
                composed_axis_elements[group_idx] = group
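                total = input_shape[current_input_pos]
                known_product = 1
                unknown_axes = []
                for axis in group:
                    if axis in axis_to_length:
                        known_product *= axis_to_length[axis]
                    else:
                        unknown_axes.append(axis)
                if len(unknown_axes) > 1:
                    raise EinopsError(f"Cannot infer sizes for multiple unknown axes: {unknown_axes}")
                # Fail loudly instead of letting floor division truncate silently
                if known_product == 0 or total % known_product != 0:
                    raise EinopsError(
                        f"Cannot split dimension of length {total} into parts of total length {known_product}"
                    )
                if len(unknown_axes) == 1:
                    axis_to_length[unknown_axes[0]] = total // known_product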
            current_input_pos += 1
            if current_input_pos >= len(input_shape):
                break

    # Ensure that all axes in the output composition have known lengths
    for group in output_composition:
        if isinstance(group, list):
            for axis in group:
                if axis not in axis_to_length:
                    axis_str = str(axis)
                    if axis_str in axes_lengths:
                        axis_to_length[axis] = axes_lengths[axis_str]
                    else:
                        raise EinopsError(f"Size for axis {axis_str} is unknown. Please provide its length.")

    # Build the initial shape for the first reshape operation
    init_shape = []
    for group in input_composition:
        if isinstance(group, list):
            if len(group) == 0:  # Empty group is treated as a singleton dimension
                init_shape.append(1)
            elif len(group) == 1:
                axis = group[0]
                init_shape.append(axis_to_length[axis])
            else:
                # For composite axes, multiply the sizes of the individual axes
                init_shape.append(_product([axis_to_length[axis] for axis in group]))

    # Helper function to flatten the composition to a list of axes
    def _flatten_axes(composition):
        result = []
        for group in composition:
            if isinstance(group, list):
                result.extend(group)
        return result

    # Build a transformation recipe for complex operations (e.g., patch extraction)
    def _build_transformation_recipe():
        flat_input_axes = _flatten_axes(input_composition)
        flat_output_axes = _flatten_axes(output_composition)
        composed_axes = {}
        for i, group in enumerate(input_composition):
            if isinstance(group, list) and len(group) > 1:
                composed_axes[i] = group
        if composed_axes and len(flat_input_axes) != len(flat_output_axes):
            perm = []
            expanded_shape = []
            expanded_input_axes = []
            # Expand composite axes into their individual components
            for i, group in enumerate(input_composition):
                if isinstance(group, list):
                    if len(group) == 1:
                        expanded_input_axes.append(group[0])
                        expanded_shape.append(axis_to_length[group[0]])
                    elif len(group) > 1:
                        for axis in group:
                            expanded_input_axes.append(axis)
                            expanded_shape.append(axis_to_length[axis])
            for axis in flat_output_axes:
                if axis in expanded_input_axes:
                    perm.append(expanded_input_axes.index(axis))
            final_expanded_shape = [axis_to_length[axis] for axis in flat_output_axes]
            return expanded_shape, perm, final_expanded_shape
        return None

    recipe = _build_transformation_recipe()
    if recipe:
        expanded_shape, permutation, final_shape = recipe
        return expanded_shape, permutation, final_shape, {}

    # Standard transformation: calculate permutation and final shape for simple cases
    input_axes_flat = _flatten_axes(input_composition)
    output_axes_flat = _flatten_axes(output_composition)

    axes_permutation = None
    common_axes = [axis for axis in input_axes_flat if axis in output_axes_flat]
    if common_axes and len(common_axes) > 1:
        input_order = [axis for axis in input_axes_flat if axis in common_axes]
        output_order = [axis for axis in output_axes_flat if axis in common_axes]
        if input_order != output_order and len(input_order) == len(output_order):
            axes_permutation = []
            for axis in output_order:
                axes_permutation.append(input_axes_flat.index(axis))

    final_shape = []
    for group in output_composition:
        if isinstance(group, list):
            if len(group) == 0:
                final_shape.append(1)
            else:
                final_shape.append(_product([axis_to_length[axis] for axis in group]))

    # Identify any added axes present in output but not in input
    added_axes = {}
    for i, axis in enumerate(output_axes_flat):
        if axis not in input_axes_flat:
            added_axes[i] = axis_to_length[axis]

    return init_shape, axes_permutation, final_shape, added_axes
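
Two steps of the reconstruction are easier to see in isolation: expanding the ellipsis into one placeholder axis per uncovered dimension, and inferring the single unknown axis of a composite group by division. The sketch below uses hypothetical helper names (`expand_ellipsis`, `infer_unknown`) and plain lists of strings rather than `ParsedExpression` objects, with the ellipsis represented as the group `["..."]`. It raises an error when the known sizes do not divide the dimension evenly.

```python
from typing import Dict, List

ELLIPSIS = "..."

def expand_ellipsis(groups: List[List[str]], ndim: int) -> List[List[str]]:
    """Replace the ["..."] group with one placeholder axis per extra dimension."""
    explicit = sum(1 for g in groups if g != [ELLIPSIS])
    extra = ndim - explicit
    if extra < 0:
        raise ValueError("Pattern names more dimensions than the tensor has")
    expanded = []
    for g in groups:
        if g == [ELLIPSIS]:
            expanded.extend([f"_e{i}"] for i in range(extra))
        else:
            expanded.append(g)
    return expanded

def infer_unknown(total: int, group: List[str], known: Dict[str, int]) -> Dict[str, int]:
    """Return sizes for every axis in a composite group of total length `total`."""
    unknown = [axis for axis in group if axis not in known]
    if len(unknown) > 1:
        raise ValueError(f"Cannot infer more than one axis: {unknown}")
    known_product = 1
    for axis in group:
        if axis in known:
            known_product *= known[axis]
    if known_product == 0 or total % known_product != 0:
        raise ValueError(f"{total} is not divisible by {known_product}")
    sizes = {axis: known[axis] for axis in group if axis in known}
    if unknown:
        sizes[unknown[0]] = total // known_product
    elif known_product != total:
        raise ValueError(f"Axis sizes multiply to {known_product}, expected {total}")
    return sizes

print(expand_ellipsis([["..."], ["h"], ["w"]], ndim=4))
# [['_e0'], ['_e1'], ['h'], ['w']]
print(infer_unknown(12, ["h", "w"], {"h": 3}))
# {'h': 3, 'w': 4}
```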


## Main Rearrange Function

Defines the core rearrange function, integrating parsing, shape calculation, input validation, and executing NumPy tensor operations.

In [238]:
def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    Rearranges elements in a tensor according to the provided pattern.

    Parameters:
        tensor: numpy array to rearrange.
        pattern: rearrangement pattern with '->' separating input and output.
        **axes_lengths: additional specifications for dynamic dimensions.

    Returns:
        A numpy array with rearranged elements.
    """
    # Validate inputs
    if not isinstance(tensor, np.ndarray):
        raise TypeError("Input tensor must be a numpy array")
    if not isinstance(pattern, str):
        raise TypeError("Pattern must be a string")
    if '->' not in pattern:
        raise EinopsError("Pattern must contain '->' to separate input and output")


    # Parse the pattern into input and output expressions
    try:
        input_expr, output_expr = parse_pattern(pattern)
    except Exception as e:
        raise EinopsError(f"Failed to parse pattern '{pattern}': {e}") from e

    # Validate that the input tensor's dimensions match the pattern (if no ellipsis is present)
    input_dims = len(tensor.shape)
    explicit_dims = sum(1 for group in input_expr.composition if isinstance(group, list))
    if not input_expr.has_ellipsis and input_dims != explicit_dims:
        raise EinopsError(
            f"Input tensor has {input_dims} dimensions, but pattern has {explicit_dims} dimensions"
        )

    # Reconstruct target shapes and axis permutation using the input shape and provided axes lengths
    try:
        init_shape, axes_permutation, final_shape, added_axes = reconstruct_from_shape(
            tensor.shape, input_expr, output_expr, axes_lengths
        )
    except Exception as e:
        raise EinopsError(f"Error reconstructing shape: {e}") from e

    result = tensor

    # Reshape to the computed initial shape
    if init_shape is not None:
        try:
            result = result.reshape(init_shape)
        except ValueError as e:
            raise EinopsError(f"Cannot reshape tensor of shape {tensor.shape} to {init_shape}: {e}")

    # Apply transpose if a permutation is specified
    if axes_permutation is not None:
        try:
            result = result.transpose(axes_permutation)
        except ValueError as e:
            raise EinopsError(f"Cannot transpose tensor with axes {axes_permutation}: {e}")


    # Insert and repeat new axes as needed (for axes present in output but not in input)
    if added_axes:
        for pos, length in sorted(added_axes.items()):
            try:
                # Insert a new axis of size 1 at the given position
                new_shape = list(result.shape)
                new_shape.insert(pos, 1)
                result = result.reshape(new_shape)

                # Repeat along the new axis to reach the desired length
                repeats = [1] * len(new_shape)
                repeats[pos] = length
                result = np.tile(result, repeats)
            except Exception as e:
                raise EinopsError(f"Error when adding axis at position {pos} with length {length}: {e}")

    # Finally, reshape to the final target shape
    if final_shape is not None:
        try:
            result = result.reshape(final_shape)
        except ValueError as e:
            raise EinopsError(f"Cannot reshape tensor of shape {result.shape} to {final_shape}: {e}")

    return result
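
The three NumPy steps that `rearrange` chains together (reshape to elementary axes, transpose, reshape to the output groups) can be traced by hand for a single pattern. The sketch below hard-codes the pattern `(a b) c -> (b c) a` with `a=2`; the shapes and the permutation are worked out manually for illustration and are not produced by the code above.

```python
import numpy as np

a, b, c = 2, 3, 2
x = np.arange(a * b * c).reshape(a * b, c)   # input axes: (a b) c

step1 = x.reshape(a, b, c)          # split the composite axis: a b c
step2 = step1.transpose(1, 2, 0)    # reorder to the output order: b c a
result = step2.reshape(b * c, a)    # merge into the output groups: (b c) a

# Check every element against the index arithmetic the pattern implies.
for ia in range(a):
    for ib in range(b):
        for ic in range(c):
            assert result[ib * c + ic, ia] == x[ia * b + ib, ic]

print(result.shape)  # (6, 2)
```

The final reshape after a transpose generally forces NumPy to copy the data, because the transposed view is no longer contiguous in memory.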


# **Test Section**

## Utility Functions and Error Classes

These tests validate the foundational components (error handling, utility operations, and anonymous-axis management) to confirm that the base functionality is correct.

In [239]:
# Test the exception class
@test_case("EinopsError class")
def test_einops_error():
    try:
        raise EinopsError("This is a test error")
    except EinopsError as e:
        print(f"EinopsError caught successfully: {e}")
        return True

# Test the _product function
@test_case("_product utility function")
def test_product():
    assert _product([2, 3, 4]) == 24, "Product calculation is incorrect"
    assert _product([]) == 1, "Empty list should return 1"
    print(f"_product([2, 3, 4]) = {_product([2, 3, 4])}")
    print(f"_product([]) = {_product([])}")
    return True

# Test the AnonymousAxis class
@test_case("AnonymousAxis class")
def test_anonymous_axis():
    axis = AnonymousAxis("3")
    print(f"Created axis: {axis}")

    # Test error handling for invalid axes
    try:
        axis_invalid = AnonymousAxis("0")
        assert False, "Should have raised an error for value 0"
    except EinopsError as e:
        print(f"Correctly caught error for value 0: {e}")

    try:
        axis_invalid = AnonymousAxis("1")
        assert False, "Should have raised an error for value 1"
    except EinopsError as e:
        print(f"Correctly caught error for value 1: {e}")

    # Test equality
    axis2 = AnonymousAxis("3")
    axis3 = AnonymousAxis("4")
    print(f"Equality check: {axis} == {axis2} is {axis == axis2}")
    print(f"Inequality check: {axis} == {axis3} is {axis == axis3}")

    return True

# Run the tests
test_einops_error()
test_product()
test_anonymous_axis()


📋 TEST: EinopsError class
--------------------------------------------------
EinopsError caught successfully: This is a test error
✅ Test Passed

📋 TEST: _product utility function
--------------------------------------------------
_product([2, 3, 4]) = 24
_product([]) = 1
✅ Test Passed

📋 TEST: AnonymousAxis class
--------------------------------------------------
Created axis: 3
Correctly caught error for value 0: Anonymous axis should have positive length, not 0
Correctly caught error for value 1: No need to create anonymous axis of length 1
Equality check: 3 == 3 is True
Inequality check: 3 == 4 is False
✅ Test Passed


True

## Pattern Parsing and Caching

These tests cover parsing of pattern strings: handling parentheses, ellipsis, and anonymous axes, caching parsed results, and identifying input/output axis arrangements correctly.

In [240]:
# Test the ParsedExpression class for basic parsing
@test_case("ParsedExpression - Basic Parsing")
def test_parsed_expression_basic():
    expr = ParsedExpression("a b c")
    print(f"Composition: {expr.composition}")
    assert len(expr.composition) == 3
    assert expr.composition[0] == ['a']
    assert expr.composition[1] == ['b']
    assert expr.composition[2] == ['c']
    return True

# Test handling composite axes with parentheses
@test_case("ParsedExpression - Parentheses")
def test_parsed_expression_parentheses():
    expr = ParsedExpression("a (b c) d")
    print(f"Composition: {expr.composition}")
    assert len(expr.composition) == 3
    assert expr.composition[0] == ['a']
    assert expr.composition[1] == ['b', 'c']
    assert expr.composition[2] == ['d']
    return True

# Test ellipsis recognition
@test_case("ParsedExpression - Ellipsis")
def test_parsed_expression_ellipsis():
    expr = ParsedExpression("... a b")
    print(f"Composition: {expr.composition}")
    assert len(expr.composition) == 3
    assert expr.composition[0] == _ellipsis
    assert expr.has_ellipsis
    assert not expr.has_ellipsis_parenthesized
    return True

# Test anonymous axes
@test_case("ParsedExpression - Anonymous Axis")
def test_parsed_expression_anonymous_axis():
    expr = ParsedExpression("a 2 b")
    print(f"Composition: {expr.composition}")
    assert len(expr.composition) == 3
    assert isinstance(expr.composition[1][0], AnonymousAxis)
    assert expr.has_non_unitary_anonymous_axes
    return True

# Test error handling in parsing
@test_case("ParsedExpression - Error Handling")
def test_parsed_expression_errors():
    # Test unbalanced parentheses
    try:
        ParsedExpression("a (b c")
        assert False, "Should have raised error for unbalanced parentheses"
    except EinopsError as e:
        print(f"✓ Caught expected error: {e}")

    # Test nested parentheses
    try:
        ParsedExpression("a (b (c d))")
        assert False, "Should have raised error for nested parentheses"
    except EinopsError as e:
        print(f"✓ Caught expected error: {e}")

    # Test invalid character
    try:
        ParsedExpression("a b@c")
        assert False, "Should have raised error for invalid character"
    except EinopsError as e:
        print(f"✓ Caught expected error: {e}")

    # Test multiple ellipses
    try:
        ParsedExpression("... a ... b")
        assert False, "Should have raised error for multiple ellipses"
    except EinopsError as e:
        print(f"✓ Caught expected error: {e}")

    return True

# Test pattern caching
@test_case("Pattern Parsing and Caching")
def test_pattern_parsing():
    pattern = "b c (h w) -> (b c) h w"
    input_expr, output_expr = parse_pattern(pattern)

    print(f"Input composition: {input_expr.composition}")
    print(f"Output composition: {output_expr.composition}")

    assert len(input_expr.composition) == 3
    assert len(output_expr.composition) == 3

    # Verify caching
    cached_input_expr, cached_output_expr = parse_pattern(pattern)
    assert input_expr is cached_input_expr, "Cache miss for input expression"
    assert output_expr is cached_output_expr, "Cache miss for output expression"
    print("Cache hit verified!")

    return True

# Test the identification of axes
@test_case("Identify axes in patterns")
def test_identify_axes():
    expr = ParsedExpression("a (b c) ...")
    axes = identify_axes(expr)
    print(f"Identified axes: {axes}")
    assert axes == {"a", "b", "c"}, "Axes identification failed"
    return True

# Test ellipsis dimension handling
@test_case("Ellipsis dimension handling")
def test_ellipsis_handling():
    input_dims = 5
    explicit_dims = 3
    ellipsis_dims = handle_ellipsis(input_dims, explicit_dims)
    print(f"Ellipsis dimensions: {ellipsis_dims}")
    assert ellipsis_dims == 2, "Ellipsis dimension handling failed"
    return True

# Run the tests
test_parsed_expression_basic()
test_parsed_expression_parentheses()
test_parsed_expression_ellipsis()
test_parsed_expression_anonymous_axis()
test_parsed_expression_errors()
test_pattern_parsing()
test_identify_axes()
test_ellipsis_handling()


📋 TEST: ParsedExpression - Basic Parsing
--------------------------------------------------
Composition: [['a'], ['b'], ['c']]
✅ Test Passed

📋 TEST: ParsedExpression - Parentheses
--------------------------------------------------
Composition: [['a'], ['b', 'c'], ['d']]
✅ Test Passed

📋 TEST: ParsedExpression - Ellipsis
--------------------------------------------------
Composition: ['…', ['a'], ['b']]
✅ Test Passed

📋 TEST: ParsedExpression - Anonymous Axis
--------------------------------------------------
Composition: [['a'], [2-axis], ['b']]
✅ Test Passed

📋 TEST: ParsedExpression - Error Handling
--------------------------------------------------
✓ Caught expected error: Unbalanced parentheses in expression: 'a (b c'
✓ Caught expected error: Nested parentheses are not allowed
✓ Caught expected error: Unknown character '@'
✓ Caught expected error: Expression may contain only one ellipsis (...)
✅ Test Passed

📋 TEST: Pattern Parsing and Caching
------------------------------------

True
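
The caching test above relies on repeated calls with the same pattern returning the very same parsed objects. A minimal sketch of how that identity can arise, assuming the cache is built on `functools.lru_cache`; `split_pattern` is a hypothetical stand-in for `parse_pattern` that only splits on `->` and whitespace, so it does not handle parentheses or ellipsis:

```python
import functools


@functools.lru_cache(maxsize=256)
def split_pattern(pattern: str) -> tuple:
    """Hypothetical stand-in for parse_pattern: split on '->', then on whitespace."""
    left, right = pattern.split("->")
    return tuple(left.split()), tuple(right.split())


first = split_pattern("b c h w -> b h w c")
second = split_pattern("b c h w -> b h w c")

# The second call is served from the cache, so both names refer to one object.
print(first is second)
print(split_pattern.cache_info())
```

Because the cached value is shared between callers, this approach only stays safe if the parsed objects are treated as immutable.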

## Shape Reconstruction
These tests check the shape bookkeeping behind `rearrange`. For a pure permutation, a merge, and a split, `reconstruct_from_shape` must return the expected initial shape and final shape, and for the permutation case the expected axis order.
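
The three values checked here describe a reshape, transpose, reshape sequence. The sketch below shows how such a plan could be applied with plain NumPy; `apply_plan` is a hypothetical helper, not the notebook's code, and it ignores the `added_axes` value used for repetition:

```python
import numpy as np


def apply_plan(x, init_shape, axes_permutation, final_shape):
    """Apply a (reshape, transpose, reshape) plan.

    A permutation of None means the axes keep their order, as in the
    merge and split cases.
    """
    y = x.reshape(init_shape)
    if axes_permutation is not None:
        y = y.transpose(axes_permutation)
    return y.reshape(final_shape)


x = np.arange(24).reshape(2, 3, 4)

# Plan for 'a b c -> c a b'
moved = apply_plan(x, [2, 3, 4], [2, 0, 1], [4, 2, 3])

# Plan for 'a b c -> a (b c)'
merged = apply_plan(x, [2, 3, 4], None, [2, 12])

print(moved.shape, merged.shape)
```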

In [241]:
# Test shape reconstruction with a simple case
@test_case("Basic shape reconstruction")
def test_basic_shape_reconstruction():
    # Create a simple example
    input_shape = (2, 3, 4)
    input_expr = ParsedExpression("a b c")
    output_expr = ParsedExpression("c a b")
    axes_lengths = {}

    # Reconstruct shapes
    init_shape, axes_permutation, final_shape, added_axes = reconstruct_from_shape(
        input_shape, input_expr, output_expr, axes_lengths
    )

    print(f"Input shape: {input_shape}")
    print(f"Initial shape: {init_shape}")
    print(f"Axes permutation: {axes_permutation}")
    print(f"Final shape: {final_shape}")
    print(f"Added axes: {added_axes}")

    assert init_shape == [2, 3, 4], "Initial shape incorrect"
    assert axes_permutation == [2, 0, 1], "Axes permutation incorrect"
    assert final_shape == [4, 2, 3], "Final shape incorrect"

    return True

# Test shape reconstruction with a merge case
@test_case("Shape reconstruction with merged axes")
def test_merge_shape_reconstruction():
    input_shape = (2, 3, 4)
    input_expr = ParsedExpression("a b c")
    output_expr = ParsedExpression("a (b c)")
    axes_lengths = {}

    init_shape, axes_permutation, final_shape, added_axes = reconstruct_from_shape(
        input_shape, input_expr, output_expr, axes_lengths
    )

    print(f"Input shape: {input_shape}")
    print(f"Initial shape: {init_shape}")
    print(f"Axes permutation: {axes_permutation}")
    print(f"Final shape: {final_shape}")

    assert init_shape == [2, 3, 4], "Initial shape incorrect"
    assert final_shape == [2, 12], "Final shape incorrect"

    return True

# Test shape reconstruction with a split case
@test_case("Shape reconstruction with split axes")
def test_split_shape_reconstruction():
    input_shape = (2, 12)
    input_expr = ParsedExpression("a (b c)")
    output_expr = ParsedExpression("a b c")
    axes_lengths = {'b': 3, 'c': 4}

    init_shape, axes_permutation, final_shape, added_axes = reconstruct_from_shape(
        input_shape, input_expr, output_expr, axes_lengths
    )

    print(f"Input shape: {input_shape}")
    print(f"Initial shape: {init_shape}")
    print(f"Axes permutation: {axes_permutation}")
    print(f"Final shape: {final_shape}")

    assert init_shape == [2, 12], "Initial shape incorrect"
    assert final_shape == [2, 3, 4], "Final shape incorrect"

    return True

# Run the tests
test_basic_shape_reconstruction()
test_merge_shape_reconstruction()
test_split_shape_reconstruction()


📋 TEST: Basic shape reconstruction
--------------------------------------------------
Input shape: (2, 3, 4)
Initial shape: [2, 3, 4]
Axes permutation: [2, 0, 1]
Final shape: [4, 2, 3]
Added axes: {}
✅ Test Passed

📋 TEST: Shape reconstruction with merged axes
--------------------------------------------------
Input shape: (2, 3, 4)
Initial shape: [2, 3, 4]
Axes permutation: None
Final shape: [2, 12]
✅ Test Passed

📋 TEST: Shape reconstruction with split axes
--------------------------------------------------
Input shape: (2, 12)
Initial shape: [2, 12]
Axes permutation: None
Final shape: [2, 3, 4]
✅ Test Passed


True

## Tensor Operations

These tests validate the basic operations: transposition, reshaping, identity, and general axis permutation, each compared against the equivalent NumPy call. They also check that a non-array input and a pattern without `->` are rejected.
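
The expected values in the permutation tests are `np.transpose` index tuples. As a rough illustration of where those tuples come from, the sketch below derives one from the axis names on each side of the arrow; `permutation_from_names` is a hypothetical helper that only covers plain axes (no parentheses, ellipsis, or numeric axes):

```python
import numpy as np


def permutation_from_names(input_axes, output_axes):
    """Return, for each output position, the index of that axis in the input."""
    if sorted(input_axes) != sorted(output_axes):
        raise ValueError("input and output must name the same axes")
    return tuple(input_axes.index(name) for name in output_axes)


# 'a b c d -> b d a c'
perm = permutation_from_names("a b c d".split(), "b d a c".split())
x = np.random.rand(2, 3, 4, 5)

print(perm)                          # (1, 3, 0, 2)
print(np.transpose(x, perm).shape)   # (3, 5, 2, 4)
```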

In [242]:
# Test the rearrange function with a simple transpose
@test_case("Basic transpose operation")
def test_basic_transpose():
    x = np.random.rand(3, 4)
    result = rearrange(x, 'h w -> w h')

    # Verify result
    assert result.shape == (4, 3)
    assert np.array_equal(result, x.T)
    print(f"Input shape: {x.shape}, Output shape: {result.shape}")
    return True

# Test input validation
@test_case("Input validation")
def test_input_validation():
    # Test non-numpy array
    try:
        rearrange([1, 2, 3], 'a -> a')
        assert False, "Should have raised TypeError for non-numpy array"
    except TypeError as e:
        print(f"✓ Correctly caught error: {e}")

    # Test invalid pattern format
    try:
        x = np.random.rand(3, 4)
        rearrange(x, 'h w')
        assert False, "Should have raised error for missing arrow"
    except EinopsError as e:
        print(f"✓ Correctly caught error: {e}")

    return True

# Test transpose operation
@test_case("Transpose operation")
def test_transpose():
    # 2D transpose
    x = np.random.rand(3, 4)
    result = rearrange(x, 'h w -> w h')

    assert result.shape == (4, 3)
    assert np.array_equal(result, x.T)
    print(f"2D transpose: {x.shape} -> {result.shape}")

    # 3D transpose
    x = np.random.rand(2, 3, 4)
    result = rearrange(x, 'a b c -> c b a')

    assert result.shape == (4, 3, 2)
    assert np.array_equal(result, np.transpose(x, (2, 1, 0)))
    print(f"3D transpose: {x.shape} -> {result.shape}")

    return True

# Test reshape operation
@test_case("Reshape operation")
def test_reshape():
    # Flatten a 2D tensor
    x = np.random.rand(3, 4)
    result = rearrange(x, 'h w -> (h w)')

    assert result.shape == (12,)
    assert np.array_equal(result, x.reshape(-1))
    print(f"Flatten: {x.shape} -> {result.shape}")

    # Reshape 3D to 2D
    x = np.random.rand(2, 3, 4)
    result = rearrange(x, 'a b c -> a (b c)')

    assert result.shape == (2, 12)
    assert np.array_equal(result, x.reshape(2, -1))
    print(f"3D to 2D: {x.shape} -> {result.shape}")

    # Reshape 1D to 2D
    x = np.random.rand(12)
    result = rearrange(x, '(h w) -> h w', h=3)

    assert result.shape == (3, 4)
    assert np.array_equal(result, x.reshape(3, 4))
    print(f"1D to 2D: {x.shape} -> {result.shape}")

    return True

# Test identity transformation
@test_case("Identity transformation")
def test_identity():
    # Simple identity
    x = np.random.rand(3, 4)
    result = rearrange(x, 'h w -> h w')

    assert result.shape == (3, 4)
    assert np.array_equal(result, x)
    print(f"2D identity: {x.shape} maintained")

    # More complex identity with reordering then restoring
    x = np.random.rand(2, 3, 4)
    intermediate = rearrange(x, 'a b c -> c b a')
    result = rearrange(intermediate, 'c b a -> a b c')

    assert result.shape == (2, 3, 4)
    assert np.array_equal(result, x)
    print(f"3D identity through multiple transformations: {x.shape} maintained")

    return True

# Test permutation operation
@test_case("Permutation operation")
def test_permutation():
    # 3D permutation
    x = np.random.rand(2, 3, 4)
    result = rearrange(x, 'a b c -> b a c')

    assert result.shape == (3, 2, 4)
    assert np.array_equal(result, np.transpose(x, (1, 0, 2)))
    print(f"3D permutation: {x.shape} -> {result.shape}")

    # 4D permutation
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, 'a b c d -> b d a c')

    assert result.shape == (3, 5, 2, 4)
    assert np.array_equal(result, np.transpose(x, (1, 3, 0, 2)))
    print(f"4D permutation: {x.shape} -> {result.shape}")

    # Verify complex permutation against numpy operations
    x = np.random.rand(1, 2, 3, 4, 5)
    result = rearrange(x, 'a b c d e -> e a c b d')

    expected = np.transpose(x, (4, 0, 2, 1, 3))
    assert result.shape == (5, 1, 3, 2, 4)
    assert np.array_equal(result, expected)
    print(f"5D complex permutation: {x.shape} -> {result.shape}")

    return True

# Run the tests
test_basic_transpose()
test_input_validation()
test_transpose()
test_reshape()
test_identity()
test_permutation()


📋 TEST: Basic transpose operation
--------------------------------------------------
Input shape: (3, 4), Output shape: (4, 3)
✅ Test Passed

📋 TEST: Input validation
--------------------------------------------------
✓ Correctly caught error: Input tensor must be a numpy array
✓ Correctly caught error: Pattern must contain '->' to separate input and output
✅ Test Passed

📋 TEST: Transpose operation
--------------------------------------------------
2D transpose: (3, 4) -> (4, 3)
3D transpose: (2, 3, 4) -> (4, 3, 2)
✅ Test Passed

📋 TEST: Reshape operation
--------------------------------------------------
Flatten: (3, 4) -> (12,)
3D to 2D: (2, 3, 4) -> (2, 12)
1D to 2D: (12,) -> (3, 4)
✅ Test Passed

📋 TEST: Identity transformation
--------------------------------------------------
2D identity: (3, 4) maintained
3D identity through multiple transformations: (2, 3, 4) maintained
✅ Test Passed

📋 TEST: Permutation operation
--------------------------------------------------
3D permutat

True

## Advanced Axis Manipulations
These tests cover splitting, merging, combined split-and-merge patterns, ellipsis handling, and axis repetition, including anonymous (numeric) axes and lengths inferred from the tensor shape.
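
Splitting works when at most one length inside a parenthesized group is left unspecified, because the remaining length follows from division. A minimal sketch of that rule, using a hypothetical helper `infer_axis_lengths` (the notebook's own inference code may be organized differently):

```python
from math import prod


def infer_axis_lengths(total, axis_names, known):
    """Resolve the lengths of the axes inside one parenthesized group.

    total      -- length of the composite axis, e.g. 30 for '(b h w)'
    axis_names -- names inside the parentheses, in order
    known      -- lengths supplied by the caller, e.g. {'b': 2, 'h': 3}
    """
    lengths = {name: known[name] for name in axis_names if name in known}
    unknown = [name for name in axis_names if name not in known]
    if len(unknown) > 1:
        raise ValueError(f"Cannot infer more than one length: {unknown}")

    known_product = prod(lengths.values())  # prod of nothing is 1
    if not unknown:
        if known_product != total:
            raise ValueError(f"Lengths multiply to {known_product}, expected {total}")
        return lengths

    if known_product == 0 or total % known_product != 0:
        raise ValueError(f"{total} is not divisible by {known_product}")
    lengths[unknown[0]] = total // known_product
    return lengths


# '(b h w) c' on a tensor of shape (30, 40) with b=2, h=3
print(infer_axis_lengths(30, ["b", "h", "w"], {"b": 2, "h": 3}))
# {'b': 2, 'h': 3, 'w': 5}
```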

In [243]:
# Test splitting axes
@test_case("Splitting axes")
def test_split_axes():
    """Test various ways of splitting axes"""

    # Basic split: (h*w, c) -> (h, w, c)
    x = np.random.rand(12, 10)
    result = rearrange(x, '(h w) c -> h w c', h=3)

    assert result.shape == (3, 4, 10)
    expected = x.reshape(3, 4, 10)
    assert np.array_equal(result, expected)
    print(f"Basic split: {x.shape} -> {result.shape}")

    # Split with an anonymous (numeric) axis: (3*w, c) -> (3, w, c)
    x = np.random.rand(12, 10)
    result = rearrange(x, '(3 w) c -> 3 w c')

    assert result.shape == (3, 4, 10)
    expected = x.reshape(3, 4, 10)
    assert np.array_equal(result, expected)
    print(f"Split with anonymous axes: {x.shape} -> {result.shape}")

    # Multiple splits: (b*c, h*w) -> (b, c, h, w)
    x = np.random.rand(6, 15)
    result = rearrange(x, '(b c) (h w) -> b c h w', b=2, c=3, h=3)

    assert result.shape == (2, 3, 3, 5)
    expected = x.reshape(2, 3, 3, 5)
    assert np.array_equal(result, expected)
    print(f"Multiple splits: {x.shape} -> {result.shape}")

    # Split with dimension inference: w = 30 / (b * h) = 5 is inferred
    x = np.random.rand(30, 40)
    result = rearrange(x, '(b h w) c -> b h w c', b=2, h=3)

    assert result.shape == (2, 3, 5, 40)
    expected = x.reshape(2, 3, 5, 40)
    assert np.array_equal(result, expected)
    print(f"Split with dimension inference: {x.shape} -> {result.shape}")

    return True

# Test merging axes
@test_case("Merging axes")
def test_merge_axes():
    """Test various ways of merging axes"""

    # Basic merge: (h, w, c) -> (h*w, c)
    x = np.random.rand(3, 4, 10)
    result = rearrange(x, 'h w c -> (h w) c')

    assert result.shape == (12, 10)
    expected = x.reshape(12, 10)
    assert np.array_equal(result, expected)
    print(f"Basic merge: {x.shape} -> {result.shape}")

    # Multiple merges: (b, c, h, w) -> (b*c, h*w)
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, 'b c h w -> (b c) (h w)')

    assert result.shape == (6, 20)
    expected = x.reshape(6, 20)
    assert np.array_equal(result, expected)
    print(f"Multiple merges: {x.shape} -> {result.shape}")

    # Merge with reordering: (b, h, w, c) -> (c, b*h*w)
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, 'b h w c -> c (b h w)')

    assert result.shape == (5, 24)
    expected = x.transpose(3, 0, 1, 2).reshape(5, 24)
    assert np.array_equal(result, expected)
    print(f"Merge with reordering: {x.shape} -> {result.shape}")

    # Merge with singleton dimension: (b, 1, h, w) -> (b, h, w)
    x = np.random.rand(2, 1, 3, 4)
    result = rearrange(x, 'b 1 h w -> b h w')

    assert result.shape == (2, 3, 4)
    expected = x.reshape(2, 3, 4)
    assert np.array_equal(result, expected)
    print(f"Merge with singleton dimension: {x.shape} -> {result.shape}")

    return True

# Test combined splitting and merging
@test_case("Combined splitting and merging")
def test_combined_split_merge():
    """Test operations that both split and merge axes"""

    # Split and merge: (b, c*h, w) -> (b, c, h*w)
    x = np.random.rand(2, 6, 5)
    result = rearrange(x, 'b (c h) w -> b c (h w)', c=2)

    assert result.shape == (2, 2, 15)
    # Manually compute the expected result
    expected = x.reshape(2, 2, 3, 5).reshape(2, 2, 15)
    assert np.array_equal(result, expected)
    print(f"Split and merge: {x.shape} -> {result.shape}")

    # Pure reordering, no split or merge: (b, c, h, w) -> (b, c, w, h)
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, 'b c h w -> b c w h')

    assert result.shape == (2, 3, 5, 4)
    expected = x.transpose(0, 1, 3, 2)
    assert np.array_equal(result, expected)
    print(f"Complex rearrangement: {x.shape} -> {result.shape}")

    return True

# Test ellipsis handling
@test_case("Ellipsis handling")
def test_ellipsis_handling():
    """Test handling of ellipsis for batch dimensions"""

    # Basic ellipsis: apply operation to last dimensions
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, '... h w -> ... (h w)')

    assert result.shape == (2, 3, 20)
    expected = x.reshape(2, 3, 20)
    assert np.array_equal(result, expected)
    print(f"Basic ellipsis: {x.shape} -> {result.shape}")

    # Ellipsis at end: apply operation to first dimensions
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, 'b c ... -> (b c) ...')

    assert result.shape == (6, 4, 5)
    expected = x.reshape(6, 4, 5)
    assert np.array_equal(result, expected)
    print(f"Ellipsis at end: {x.shape} -> {result.shape}")


    # Ellipsis with varying batch dimensions
    arrs = [np.random.rand(1, 4, 5),
            np.random.rand(2, 4, 5),
            np.random.rand(3, 4, 5)]

    # Process each array
    results = []
    for arr in arrs:
        results.append(rearrange(arr, '... h w -> ... (h w)'))

    # Check first result
    assert results[0].shape == (1, 20)
    assert np.array_equal(results[0], arrs[0].reshape(1, 20))

    # Check second result
    assert results[1].shape == (2, 20)
    assert np.array_equal(results[1], arrs[1].reshape(2, 20))

    # Check third result
    assert results[2].shape == (3, 20)
    assert np.array_equal(results[2], arrs[2].reshape(3, 20))

    print("Ellipsis with varying batch dimensions: Correct for all inputs")

    return True

# Test repeating axes
@test_case("Repeating axes")
def test_repeating_axes():
    """Test repeating of axes"""

    # Basic repeat: (a, 1, b) -> (a, c, b)
    x = np.random.rand(3, 1, 5)
    result = rearrange(x, 'a 1 b -> a c b', c=4)

    assert result.shape == (3, 4, 5)
    expected = np.repeat(x, 4, axis=1)
    assert np.array_equal(result, expected)
    print(f"Basic repeat: {x.shape} -> {result.shape}")

    # Multiple repeats: (a, 1, 1, b) -> (a, c, d, b)
    x = np.random.rand(3, 1, 1, 5)
    result = rearrange(x, 'a 1 1 b -> a c d b', c=4, d=2)

    assert result.shape == (3, 4, 2, 5)
    expected = np.repeat(np.repeat(x, 4, axis=1), 2, axis=2)
    assert np.array_equal(result, expected)
    print(f"Multiple repeats: {x.shape} -> {result.shape}")

    # Adding a new dimension: (a, b) -> (a, 1, b)
    x = np.random.rand(3, 5)
    result = rearrange(x, 'a b -> a 1 b')

    assert result.shape == (3, 1, 5)
    expected = x.reshape(3, 1, 5)
    assert np.array_equal(result, expected)
    print(f"Adding a new dimension: {x.shape} -> {result.shape}")

    # Adding and repeating: (a, b) -> (a, c, b)
    x = np.random.rand(3, 5)
    result = rearrange(x, 'a b -> a c b', c=4)

    assert result.shape == (3, 4, 5)
    expected = x.reshape(3, 1, 5).repeat(4, axis=1)
    assert np.array_equal(result, expected)
    print(f"Adding and repeating: {x.shape} -> {result.shape}")

    return True

# Run all tests
test_split_axes()
test_merge_axes()
test_combined_split_merge()
test_ellipsis_handling()
test_repeating_axes()



📋 TEST: Splitting axes
--------------------------------------------------
Basic split: (12, 10) -> (3, 4, 10)
Split with anonymous axes: (12, 10) -> (3, 4, 10)
Multiple splits: (6, 15) -> (2, 3, 3, 5)
Split with dimension inference: (30, 40) -> (2, 3, 5, 40)
✅ Test Passed

📋 TEST: Merging axes
--------------------------------------------------
Basic merge: (3, 4, 10) -> (12, 10)
Multiple merges: (2, 3, 4, 5) -> (6, 20)
Merge with reordering: (2, 3, 4, 5) -> (5, 24)
Merge with singleton dimension: (2, 1, 3, 4) -> (2, 3, 4)
✅ Test Passed

📋 TEST: Combined splitting and merging
--------------------------------------------------
Split and merge: (2, 6, 5) -> (2, 2, 15)
Complex rearrangement: (2, 3, 4, 5) -> (2, 3, 5, 4)
✅ Test Passed

📋 TEST: Ellipsis handling
--------------------------------------------------
Basic ellipsis: (2, 3, 4, 5) -> (2, 3, 20)
Ellipsis at end: (2, 3, 4, 5) -> (6, 4, 5)
Ellipsis with varying batch dimensions: Correct for all inputs
✅ Test Passed

📋 TEST: Repeating

True
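
The repeat tests above compare against `np.repeat`, which copies the data. For a length-1 axis, `np.broadcast_to` yields the same values as a read-only view. The sketch below only contrasts the two NumPy calls; it makes no claim about which one the notebook's `rearrange` uses internally:

```python
import numpy as np

x = np.random.rand(3, 1, 5)

# 'a 1 b -> a c b' with c=4, expressed two ways in plain NumPy
by_repeat = np.repeat(x, 4, axis=1)           # new, writeable array
by_broadcast = np.broadcast_to(x, (3, 4, 5))  # read-only view, no copy

print(np.array_equal(by_repeat, by_broadcast))
print(by_repeat.flags.writeable, by_broadcast.flags.writeable)
```

If the caller is expected to modify the result in place, the copying variant (or an explicit `.copy()` of the broadcast view) is the safer choice.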

## Error and Exception Handling in Rearrangement
These tests check that `rearrange` raises `EinopsError` for malformed patterns, for patterns whose number of axes does not match the tensor, and for missing or inconsistent axis lengths. Each caught message is printed so its wording can be inspected.
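
The tests in this section repeat a `try` / `assert False` / `except` pattern. One way to factor it out is a small context manager, sketched below. `expect_error` and `fake_rearrange` are hypothetical names, and the `EinopsError` class here is a local stand-in so the sketch runs on its own:

```python
from contextlib import contextmanager


class EinopsError(RuntimeError):
    """Local stand-in for the notebook's EinopsError."""


@contextmanager
def expect_error(exc_type, label):
    """Pass if the body raises exc_type; fail with AssertionError otherwise."""
    try:
        yield
    except exc_type as e:
        print(f"✓ Caught expected error for {label}: {e}")
    else:
        raise AssertionError(f"Should have raised {exc_type.__name__} for {label}")


def fake_rearrange(pattern):
    """Toy function that only validates the arrow, for demonstration."""
    if "->" not in pattern:
        raise EinopsError("Pattern must contain '->' to separate input and output")
    return pattern


with expect_error(EinopsError, "missing arrow"):
    fake_rearrange("h w")
```

Because the failure is raised in the `else` branch, it cannot be swallowed by the `except` clause, even when `exc_type` is as broad as `Exception`.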

In [244]:
# Test invalid pattern errors
@test_case("Invalid pattern errors")
def test_invalid_pattern_errors():
    """Test error handling for invalid pattern strings"""

    # Test missing arrow
    try:
        rearrange(np.random.rand(3, 4), 'h w')
        assert False, "Should have raised error for missing arrow"
    except EinopsError as e:
        print(f"✓ Caught expected error for missing arrow: {e}")

    # Test missing input part
    try:
        rearrange(np.random.rand(3, 4), '-> h w')
        assert False, "Should have raised error for missing input part"
    except EinopsError as e:
        print(f"✓ Caught expected error for missing input part: {e}")

    # Test missing output part
    try:
        rearrange(np.random.rand(3, 4), 'h w ->')
        assert False, "Should have raised error for missing output part"
    except EinopsError as e:
        print(f"✓ Caught expected error for missing output part: {e}")

    # Test multiple arrows
    try:
        rearrange(np.random.rand(3, 4), 'h w -> h w -> h w')
        assert False, "Should have raised error for multiple arrows"
    except EinopsError as e:
        print(f"✓ Caught expected error for multiple arrows: {e}")

    return True

# Test mismatched shape errors
@test_case("Mismatched shape errors")
def test_mismatched_shape_errors():
    """Test error handling for mismatched tensor shapes"""

    # Test extra dimension in pattern
    try:
        rearrange(np.random.rand(3, 4), 'h w d -> h w d')
        assert False, "Should have raised error for extra dimension"
    except EinopsError as e:
        print(f"✓ Caught expected error for extra dimension: {e}")

    # Test missing input dimension
    try:
        rearrange(np.random.rand(3, 4, 5), 'h w -> h w')
        assert False, "Should have raised error for missing input dimension"
    except EinopsError as e:
        print(f"✓ Caught expected error for missing input dimension: {e}")

    # This should NOT raise an error (ellipsis handles extra dimensions)
    try:
        result = rearrange(np.random.rand(3, 4, 5), '... w -> ... w')
        print("✓ Successfully used ellipsis to handle extra dimensions")
    except EinopsError as e:
        assert False, f"Should not have raised an error when using ellipsis: {e}"

    return True

# Test missing dimension specification errors
@test_case("Missing dimension specification errors")
def test_missing_dimension_specification_errors():
    """Test error handling for missing dimension specifications"""

    # Test not providing required dimensions
    try:
        rearrange(np.random.rand(12, 10), '(h w) c -> h w c')
        assert False, "Should have raised error for missing dimension specification"
    except EinopsError as e:
        print(f"✓ Caught expected error for missing dimension specification: {e}")

    # Test inconsistent dimensions
    try:
        rearrange(np.random.rand(12, 10), '(h w) c -> h w c', h=5)
        assert False, "Should have raised error for inconsistent dimensions"
    except EinopsError as e:
        print(f"✓ Caught expected error for inconsistent dimensions: {e}")

    # Test multiple unknown dimensions
    try:
        rearrange(np.random.rand(12, 10), '(h w d) c -> h w d c', h=3)
        assert False, "Should have raised error for multiple unknown dimensions"
    except EinopsError as e:
        print(f"✓ Caught expected error for multiple unknown dimensions: {e}")

    # This should work (only one unknown dimension that can be inferred)
    try:
        result = rearrange(np.random.rand(12, 10), '(h w) c -> h w c', h=3)
        print(f"✓ Successfully inferred the dimension w as {result.shape[1]}")
    except EinopsError as e:
        assert False, f"Should not have raised an error when one dimension can be inferred: {e}"

    return True

# Run all tests
test_invalid_pattern_errors()
test_mismatched_shape_errors()
test_missing_dimension_specification_errors()



📋 TEST: Invalid pattern errors
--------------------------------------------------
✓ Caught expected error for missing arrow: Pattern must contain '->' to separate input and output
✓ Caught expected error for missing input part: Input tensor has 2 dimensions, but pattern has 0 dimensions
✓ Caught expected error for missing output part: Cannot reshape tensor of shape (3, 4) to []: cannot reshape array of size 12 into shape ()
✓ Caught expected error for multiple arrows: Failed to parse pattern 'h w -> h w -> h w': too many values to unpack (expected 2)
✅ Test Passed

📋 TEST: Mismatched shape errors
--------------------------------------------------
✓ Caught expected error for extra dimension: Input tensor has 2 dimensions, but pattern has 3 dimensions
✓ Caught expected error for missing input dimension: Input tensor has 3 dimensions, but pattern has 2 dimensions
✓ Successfully used ellipsis to handle extra dimensions
✅ Test Passed

📋 TEST: Missing dimension specification errors
--------

True

## Edge and Corner Cases
These tests cover singleton dimensions, empty parentheses, extreme aspect ratios, scalar inputs, zero-sized dimensions, duplicate or malformed axis names, and attempts to repeat a non-singleton axis.
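
For reference, the zero-sized and scalar cases below have direct NumPy equivalents, shown in this short sketch. It uses plain NumPy only and illustrates the expected shapes, not the notebook's implementation:

```python
import numpy as np

# 'a b c -> (a c) b' on a tensor with a zero-sized axis
x = np.zeros((3, 0, 5))
y = x.transpose(0, 2, 1).reshape(15, 0)
print(y.shape, y.size)   # (15, 0) 0

# '... -> () ()' on a 0-dimensional (scalar) array
s = np.array(5.0)
t = s.reshape(1, 1)
print(s.ndim, t.shape)   # 0 (1, 1)
```

Both reshapes pass explicit target shapes, which keeps the element count check unambiguous even when the array is empty.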

In [245]:
# Test edge cases
@test_case("Edge cases and corner cases")
def test_edge_cases():
    """Test edge cases like singleton dimensions and empty parentheses"""

    # Singleton dimensions
    x = np.random.rand(3, 1, 5)
    result = rearrange(x, 'a 1 c -> a c')
    assert result.shape == (3, 5)
    print(f"Remove singleton dimension: {x.shape} -> {result.shape}")

    # Empty parentheses for adding a singleton dimension
    x = np.random.rand(3, 5)
    result = rearrange(x, 'a c -> a () c')
    assert result.shape == (3, 1, 5)
    print(f"Add singleton dimension with empty parentheses: {x.shape} -> {result.shape}")

    # Handle tensor with all dimensions of size 1
    x = np.random.rand(1, 1, 1)
    result = rearrange(x, '1 1 1 -> 1 1')
    assert result.shape == (1, 1)
    print(f"Reshape tensor of all ones: {x.shape} -> {result.shape}")

    # Handle tensors with extreme shape differences
    x = np.random.rand(100, 1)
    result = rearrange(x, 'a 1 -> 1 a')
    assert result.shape == (1, 100)
    print(f"Reshape with extreme shape difference: {x.shape} -> {result.shape}")

    # Zero-dimensional tensor to N-dimensional tensor (scalar to tensor)
    x = np.array(5.0)  # Scalar
    try:
        result = rearrange(x, '... -> () ()')
    except EinopsError as e:
        # Only a rejection by rearrange itself counts as "not supported"
        print(f"Note: Scalar to tensor conversion not supported: {e}")
    else:
        # Checked outside the try block so a wrong shape fails the test
        assert result.shape == (1, 1)
        print(f"Scalar to 2D tensor: {x.shape} -> {result.shape}")

    return True

# Test additional edge cases
@test_case("Additional Edge Cases")
def test_additional_edge_cases():
    """Test additional edge cases mentioned in the prompt"""

    # 1. Duplicate Axis Names
    try:
        x = np.random.rand(3, 4)
        rearrange(x, 'a a -> b c')
        assert False, "Should have raised error for duplicate axis name 'a'"
    except EinopsError as e:
        print(f"✓ Caught expected error for duplicate axis names: {e}")

    try:
        x = np.random.rand(3, 4, 5)
        rearrange(x, 'a b c -> (a a) c')
        assert False, "Should have raised error for duplicate axis name 'a' in output"
    except EinopsError as e:
        print(f"✓ Caught expected error for duplicate axis in output composition: {e}")


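    # 2. Axis Name Constraints
    # Note: judging by the recorded output, 'a(x) b' is read as three axes
    # (a, (x), b), so it is rejected for a rank mismatch with the 2D input
    # rather than for an invalid character.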
    try:
        x = np.random.rand(3, 4)
        rearrange(x, 'a(x) b -> b a(x)')
        assert False, "Should have raised error for invalid character in axis name"
    except EinopsError as e:
        print(f"✓ Caught expected error for invalid character in axis name: {e}")

    try:
        x = np.random.rand(3, 4)
        rearrange(x, 'a-b c -> c a-b')
        assert False, "Should have raised error for invalid character in axis name"
    except EinopsError as e:
        print(f"✓ Caught expected error for invalid character in axis name: {e}")

    # 3. Repetition Errors (repeating non-singleton dimensions)
    try:
        x = np.random.rand(2, 3, 4)
        rearrange(x, 'a b c -> a d c', d=5)
        assert False, "Should have raised error for repeating non-singleton dimension"
    except EinopsError as e:
        print(f"✓ Caught expected error for repeating non-singleton dimension: {e}")

    # 4. Zero-sized dimensions
    # No try/except here: catching Exception would also swallow a failed
    # assert, so the test could never fail.
    x = np.zeros((5, 0, 4))
    result = rearrange(x, 'a b c -> a c b')
    assert result.shape == (5, 4, 0)
    print(f"✓ Successfully rearranged tensor with zero-sized dimension: {x.shape} -> {result.shape}")

    # Test with complex pattern
    x = np.zeros((3, 0, 5))
    result = rearrange(x, 'a b c -> (a c) b')
    assert result.shape == (15, 0)
    print(f"✓ Successfully rearranged complex pattern with zero-sized dimension: {x.shape} -> {result.shape}")

    return True

# Run additional edge case tests
test_edge_cases()
test_additional_edge_cases()



📋 TEST: Edge cases and corner cases
--------------------------------------------------
Remove singleton dimension: (3, 1, 5) -> (3, 5)
Add singleton dimension with empty parentheses: (3, 5) -> (3, 1, 5)
Reshape tensor of all ones: (1, 1, 1) -> (1, 1)
Reshape with extreme shape difference: (100, 1) -> (1, 100)
Scalar to 2D tensor: () -> (1, 1)
✅ Test Passed

📋 TEST: Additional Edge Cases
--------------------------------------------------
✓ Caught expected error for duplicate axis names: Failed to parse pattern 'a a -> b c': Duplicate dimension "a" in expression
✓ Caught expected error for duplicate axis in output composition: Failed to parse pattern 'a b c -> (a a) c': Duplicate dimension "a" in expression
✓ Caught expected error for invalid character in axis name: Input tensor has 2 dimensions, but pattern has 3 dimensions
✓ Caught expected error for invalid character in axis name: Failed to parse pattern 'a-b c -> c a-b': Unknown character '-'
✓ Caught expected error for repeating no

True

# Usage

In [246]:
import numpy as np

# 1) Transpose
x = np.random.rand(3, 4)
result = rearrange(x, 'h w -> w h')

print(result.shape)

# 2) Split an axis
x = np.random.rand(12, 10)
result = rearrange(x, '(h w) c -> h w c', h=3)

print(result.shape)

# 3) Merge axes
x = np.random.rand(3, 4, 5)
result = rearrange(x, 'a b c -> (a b) c')

print(result.shape)

# 4) Repeat an axis
x = np.random.rand(3, 1, 5)
result = rearrange(x, 'a 1 c -> a b c', b=4)

print(result.shape)

# 5) Handle batch dimensions
x = np.random.rand(2, 3, 4, 5)
result = rearrange(x, '... h w -> ... (h w)')

print(result.shape)

(4, 3)
(3, 4, 10)
(12, 5)
(3, 4, 5)
(2, 3, 20)
