In [1]:
import numpy as np
from typing import Dict, List, Tuple, Union, Optional

class EinopsError(Exception):
    """Custom exception for einops-related errors with detailed messages"""
    pass

def parse_axes(axes_str: str) -> List[Union[str, Tuple[str]]]:
    """Parse axes string into components with strict validation."""
    if not isinstance(axes_str, str):
        raise EinopsError("Axes specification must be a string")

    axes = []
    current_group = []
    in_parentheses = False
    current_token = ""
    seen_ellipsis = False

    for char in axes_str.strip():
        if char == " ":
            if in_parentheses:
                if current_token:
                    if not current_token.isalnum():
                        raise EinopsError(f"Invalid axis name '{current_token}'")
                    current_group.append(current_token)
                    current_token = ""
            else:
                if current_token:
                    if current_token == "...":
                        if seen_ellipsis:
                            raise EinopsError("Only one ellipsis allowed per pattern")
                        seen_ellipsis = True
                        axes.append(current_token)
                    else:
                        if not current_token.isalnum():
                            raise EinopsError(f"Invalid axis name '{current_token}'")
                        axes.append(current_token)
                    current_token = ""
        elif char == "(":
            if in_parentheses:
                raise EinopsError("Nested parentheses not allowed")
            in_parentheses = True
            if current_token:
                raise EinopsError(f"Unexpected token before '(': '{current_token}'")
        elif char == ")":
            if not in_parentheses:
                raise EinopsError("Unmatched ')'")
            in_parentheses = False
            if current_token:
                if not current_token.isalnum():
                    raise EinopsError(f"Invalid axis name '{current_token}'")
                current_group.append(current_token)
            if not current_group:
                raise EinopsError("Empty parentheses group")
            axes.append(tuple(current_group))
            current_group = []
            current_token = ""
        else:
            if char not in (".", "_") and not char.isalnum():
                raise EinopsError(f"Invalid character '{char}' in pattern")
            current_token += char

    if in_parentheses:
        raise EinopsError("Unmatched '(' in pattern")
    if current_token:
        if current_token == "...":
            if seen_ellipsis:
                raise EinopsError("Only one ellipsis allowed per pattern")
            axes.append(current_token)
        else:
            if not current_token.isalnum():
                raise EinopsError(f"Invalid axis name '{current_token}'")
            if in_parentheses:
                current_group.append(current_token)
                axes.append(tuple(current_group))
            else:
                axes.append(current_token)

    # Validate no numbers except '1' for repetition
    for ax in axes:
        if isinstance(ax, tuple):
            for name in ax:
                if name.isdigit() and name != "1":
                    raise EinopsError("Numeric values other than '1' are not allowed")
        elif ax != "..." and ax.isdigit() and ax != "1":
            raise EinopsError("Numeric values other than '1' are not allowed")

    return axes

def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
    """
    Robust tensor rearrangement supporting:
    - Reshaping (a b -> b a)
    - Splitting ((h w) c -> h w c)
    - Merging (a b c -> (a b) c)
    - Repeating (a 1 c -> a b c)
    - Ellipsis (... h w -> ... (h w))

    Args:
        tensor: Input numpy array
        pattern: Rearrangement pattern string
        **axes_lengths: Required sizes for new axes

    Returns:
        Rearranged numpy array

    Raises:
        EinopsError: For any invalid pattern or shape mismatch
    """
    # Input validation
    if not isinstance(tensor, np.ndarray):
        raise EinopsError(f"Input must be numpy array, got {type(tensor)}")
    if not isinstance(pattern, str):
        raise EinopsError("Pattern must be a string")
    if not pattern:
        raise EinopsError("Pattern cannot be empty")

    # Split pattern
    try:
        input_str, output_str = [s.strip() for s in pattern.split("->", 1)]
        if not input_str or not output_str:
            raise ValueError
    except ValueError:
        raise EinopsError("Pattern must contain exactly one '->' with non-empty sides")

    # Parse axes with strict validation
    try:
        input_axes = parse_axes(input_str)
        output_axes = parse_axes(output_str)
    except EinopsError as e:
        raise EinopsError(f"Invalid pattern: {str(e)}")

    # Validate ellipsis consistency
    input_has_ellipsis = "..." in input_axes
    output_has_ellipsis = "..." in output_axes
    if input_has_ellipsis != output_has_ellipsis:
        raise EinopsError("Ellipsis must appear in both or neither input and output")

    # Process input dimensions with shape validation
    shape_map = {}
    ellipsis_dims = []
    current_dim = 0
    input_axes_flat = []

    for ax in input_axes:
        if ax == "...":
            remaining_dims = len(tensor.shape) - (len(input_axes) - 1)
            if remaining_dims < 0:
                raise EinopsError(
                    f"Tensor has {len(tensor.shape)} dims but pattern needs at least "
                    f"{len(input_axes)-1} for ellipsis"
                )
            ellipsis_dims = list(tensor.shape[current_dim:current_dim+remaining_dims])
            current_dim += remaining_dims
            input_axes_flat.extend([f"__batch_{i}" for i in range(remaining_dims)])
        elif isinstance(ax, tuple):
            if current_dim >= len(tensor.shape):
                raise EinopsError(
                    f"Tensor has only {len(tensor.shape)} dims but pattern needs more"
                )
            group_size = tensor.shape[current_dim]
            provided_product = 1
            for name in ax:
                if name == "1":
                    continue
                if name in axes_lengths:
                    provided_product *= axes_lengths[name]
                elif name.isdigit():
                    provided_product *= int(name)

            # Calculate remaining size after accounting for specified axes
            remaining_size = group_size // provided_product
            if group_size % provided_product != 0:
                raise EinopsError(
                    f"Cannot split axis of size {group_size} into product {provided_product}"
                )

            for name in ax:
                if name == "1":
                    shape_map[name] = 1
                elif name in axes_lengths:
                    shape_map[name] = axes_lengths[name]
                elif name.isdigit():
                    shape_map[name] = int(name)
                else:
                    shape_map[name] = remaining_size
                    remaining_size = 1
            current_dim += 1
            input_axes_flat.extend(ax)
        else:
            if current_dim >= len(tensor.shape):
                raise EinopsError(
                    f"Tensor has only {len(tensor.shape)} dims but pattern needs more"
                )
            if ax == "1":
                if tensor.shape[current_dim] != 1:
                    raise EinopsError(f"Axis marked as 1 must have size 1, got {tensor.shape[current_dim]}")
                shape_map[ax] = 1
            else:
                shape_map[ax] = tensor.shape[current_dim]
            current_dim += 1
            input_axes_flat.append(ax)

    # Verify all input dimensions were consumed
    if current_dim != len(tensor.shape):
        raise EinopsError(
            f"Pattern doesn't match tensor shape. Pattern uses {current_dim} dims "
            f"but tensor has {len(tensor.shape)}"
        )

    # Process output axes and validate
    output_shape = []
    output_axes_flat = []
    axis_positions = {ax: i for i, ax in enumerate(input_axes_flat) if ax != "..."}

    for ax in output_axes:
        if ax == "...":
            output_shape.extend(ellipsis_dims)
            output_axes_flat.extend([f"__batch_{i}" for i in range(len(ellipsis_dims))])
        elif isinstance(ax, tuple):
            group_size = 1
            for name in ax:
                if name == "1":
                    group_size *= 1
                elif name in shape_map:
                    group_size *= shape_map[name]
                elif name in axes_lengths:
                    group_size *= axes_lengths[name]
                else:
                    raise EinopsError(f"Unknown axis '{name}' in output pattern")
            output_shape.append(group_size)
            output_axes_flat.extend(ax)
        else:
            if ax == "1":
                output_shape.append(1)
            elif ax in shape_map:
                output_shape.append(shape_map[ax])
            elif ax in axes_lengths:
                output_shape.append(axes_lengths[ax])
            else:
                raise EinopsError(f"Unknown axis '{ax}' in output pattern")
            output_axes_flat.append(ax)

    # Handle axis repetition and transposition
    temp_tensor = tensor
    if input_axes_flat != output_axes_flat:
        # Handle repetition first
        for i, (in_ax, out_ax) in enumerate(zip(input_axes_flat, output_axes_flat)):
            if in_ax == "1" and out_ax != "1":
                if out_ax not in axes_lengths:
                    raise EinopsError(f"Must specify size for new axis '{out_ax}'")
                temp_tensor = np.repeat(temp_tensor, axes_lengths[out_ax], axis=i)

        # Then handle transposition
        try:
            transpose_order = [input_axes_flat.index(ax) for ax in output_axes_flat]
            temp_tensor = np.transpose(temp_tensor, transpose_order)
        except ValueError as e:
            raise EinopsError(f"Axis mismatch between input and output: {str(e)}")

    # Final reshape
    try:
        return temp_tensor.reshape(output_shape)
    except ValueError as e:
        raise EinopsError(f"Shape mismatch during reshape: {str(e)}")

In [7]:
import numpy as np


def run_robustness_tests():
    print("🚀 Running Robustness and Edge Case Tests")

    # 1. Axis Repetition Edge Cases
    print("\n🔁 Testing Axis Repetition Edge Cases:")
    try:
        # Case 1a: Basic repetition
        x = np.random.rand(3, 1, 5)
        result = rearrange(x, 'a 1 c -> a b c', b=4)
        assert result.shape == (3, 4, 5)
        print("✅ 1a: Basic repetition (3,1,5)->(3,4,5) works")

        # Case 1b: Non-1 dimension attempt (should fail)
        x = np.random.rand(3, 2, 5)
        try:
            rearrange(x, 'a 1 c -> a b c', b=4)
            print("❌ 1b: Failed to catch non-1 dimension")
        except EinopsError:
            print("✅ 1b: Correctly blocked non-1 dimension repetition")

        # Case 1c: Missing size specification
        try:
            rearrange(x, 'a 1 c -> a b c')
            print("❌ 1c: Failed to catch missing size")
        except EinopsError:
            print("✅ 1c: Correctly caught missing size specification")
    except Exception as e:
        print(f"❌ Axis Repetition tests failed: {str(e)}")

    # 2. Splitting Axes Edge Cases
    print("\n✂️ Testing Axis Splitting Edge Cases:")
    try:
        # Case 2a: Basic splitting
        x = np.random.rand(6, 4)
        result = rearrange(x, '(h w) c -> h w c', h=2)
        assert result.shape == (2, 3, 4)
        print("✅ 2a: Basic splitting (6,4)->(2,3,4) works")

        # Case 2b: Non-divisible split (should fail)
        try:
            rearrange(x, '(h w) c -> h w c', h=5)
            print("❌ 2b: Failed to catch non-divisible split")
        except EinopsError:
            print("✅ 2b: Correctly caught non-divisible split (6/5)")

        # Case 2c: Nested splitting
        x = np.random.rand(24, 10)
        result = rearrange(x, '((h1 h2) w) c -> h1 h2 w c', h1=2, h2=3)
        assert result.shape == (2, 3, 4, 10)
        print("✅ 2c: Nested splitting (24,10)->(2,3,4,10) works")
    except Exception as e:
        print(f"❌ Axis Splitting tests failed: {str(e)}")

    # 3. Merging Axes Edge Cases
    print("\n🔄 Testing Axis Merging Edge Cases:")
    try:
        # Case 3a: Basic merging
        x = np.random.rand(2, 3, 4)
        result = rearrange(x, 'a b c -> (a b) c')
        assert result.shape == (6, 4)
        print("✅ 3a: Basic merging (2,3,4)->(6,4) works")

        # Case 3b: Empty merging pattern
        try:
            rearrange(x, 'a b c -> () c')
            print("❌ 3b: Failed to catch empty merge")
        except EinopsError:
            print("✅ 3b: Correctly caught empty merge pattern")

        # Case 3c: Partial merging
        x = np.random.rand(2, 3, 4, 5)
        result = rearrange(x, 'a b ... -> a (b ...)')
        assert result.shape == (2, 3*4*5)
        print("✅ 3c: Partial merging (2,3,4,5)->(2,60) works")
    except Exception as e:
        print(f"❌ Axis Merging tests failed: {str(e)}")

    # 4. Batch Dimensions Edge Cases
    print("\n📦 Testing Batch Dimension Edge Cases:")
    try:
        # Case 4a: Basic ellipsis
        x = np.random.rand(2, 3, 4, 5)
        result = rearrange(x, '... h w -> ... (h w)')
        assert result.shape == (2, 3, 20)
        print("✅ 4a: Basic ellipsis (2,3,4,5)->(2,3,20) works")

        # Case 4b: Ellipsis at start
        result = rearrange(x, '... -> ...')
        assert result.shape == (2, 3, 4, 5)
        print("✅ 4b: Ellipsis-only pattern works")

        # Case 4c: Mismatched ellipsis
        try:
            rearrange(x, '... h w -> h w ...')
            print("❌ 4c: Failed to catch ellipsis mismatch")
        except EinopsError:
            print("✅ 4c: Correctly caught ellipsis position mismatch")
    except Exception as e:
        print(f"❌ Batch Dimension tests failed: {str(e)}")

    # 5. Transposition Edge Cases
    print("\n🔄 Testing Transposition Edge Cases:")
    try:
        # Case 5a: Basic transposition
        x = np.random.rand(3, 4)
        result = rearrange(x, 'h w -> w h')
        assert result.shape == (4, 3)
        print("✅ 5a: Basic transposition (3,4)->(4,3) works")

        # Case 5b: Identity transposition
        result = rearrange(x, 'h w -> h w')
        assert result.shape == (3, 4)
        print("✅ 5b: Identity transposition works")

        # Case 5c: Invalid transposition
        try:
            rearrange(x, 'h w -> x y')
            print("❌ 5c: Failed to catch invalid axes")
        except EinopsError:
            print("✅ 5c: Correctly caught invalid axis names")
    except Exception as e:
        print(f"❌ Transposition tests failed: {str(e)}")

    # 6. Extreme Edge Cases
    print("\n⚠️ Testing Extreme Edge Cases:")
    try:
        # Case 6a: Empty array
        x = np.array([]).reshape(0, 2, 3)
        result = rearrange(x, 'a b c -> c b a')
        assert result.shape == (3, 2, 0)
        print("✅ 6a: Empty array handling works")

        # Case 6b: Single element array
        x = np.array([[[42]]])
        result = rearrange(x, 'a b c -> c b a')
        assert result[0,0,0] == 42
        print("✅ 6b: Single element array works")

        # Case 6c: Invalid pattern syntax
        try:
            rearrange(x, 'invalid pattern ->')
            print("❌ 6c: Failed to catch invalid syntax")
        except EinopsError:
            print("✅ 6c: Correctly caught invalid pattern syntax")
    except Exception as e:
        print(f"❌ Extreme Edge Cases tests failed: {str(e)}")

    print("\n🎉 Robustness testing completed!")

if __name__ == "__main__":
    run_robustness_tests()

🚀 Running Robustness and Edge Case Tests

🔁 Testing Axis Repetition Edge Cases:
❌ Axis Repetition tests failed: Axis mismatch between input and output: 'b' is not in list

✂️ Testing Axis Splitting Edge Cases:
✅ 2a: Basic splitting (6,4)->(2,3,4) works
✅ 2b: Correctly caught non-divisible split (6/5)
❌ Axis Splitting tests failed: Invalid pattern: Nested parentheses not allowed

🔄 Testing Axis Merging Edge Cases:
✅ 3a: Basic merging (2,3,4)->(6,4) works
✅ 3b: Correctly caught empty merge pattern
❌ Axis Merging tests failed: Invalid pattern: Invalid axis name '...'

📦 Testing Batch Dimension Edge Cases:
✅ 4a: Basic ellipsis (2,3,4,5)->(2,3,20) works
✅ 4b: Ellipsis-only pattern works
❌ 4c: Failed to catch ellipsis mismatch

🔄 Testing Transposition Edge Cases:
✅ 5a: Basic transposition (3,4)->(4,3) works
✅ 5b: Identity transposition works
✅ 5c: Correctly caught invalid axis names

⚠️ Testing Extreme Edge Cases:
✅ 6a: Empty array handling works
✅ 6b: Single element array works
✅ 6c: Correc