In [1]:
import numpy as np
from typing import Dict, List, Tuple, Union

class EinopsError(Exception):
    """Raised when a rearrange pattern is malformed or does not fit the tensor."""


def parse_axes(axes_str: str) -> List[Union[str, Tuple[str, ...]]]:
    """Split one side of a rearrange pattern into axis entries.

    Any whitespace separates entries. A parenthesised group becomes a tuple of
    the names inside it, and ``()`` becomes an empty tuple (an axis of length
    1). Tokens such as ``...`` or ``1`` are returned verbatim as strings.

    Example: ``"b (h w) c"`` -> ``["b", ("h", "w"), "c"]``.

    Raises:
        EinopsError: on nested or unbalanced parentheses, or when a token is
            glued directly to an opening parenthesis (``"a(b)"``).
    """
    axes: List[Union[str, Tuple[str, ...]]] = []
    group: List[str] = []
    in_group = False
    token = ""

    for char in axes_str:
        if char.isspace():
            if token:
                if in_group:
                    group.append(token)
                else:
                    axes.append(token)
                token = ""
        elif char == "(":
            if in_group:
                raise EinopsError("Nested parentheses not allowed")
            if token:
                raise EinopsError(f"Unexpected token before '(': {token}")
            in_group = True
        elif char == ")":
            if not in_group:
                raise EinopsError("Unmatched ')'")
            if token:
                group.append(token)
                token = ""
            # Keep empty groups: "()" denotes an axis of length 1.
            axes.append(tuple(group))
            group = []
            in_group = False
        else:
            token += char

    if in_group:
        raise EinopsError("Unmatched '('")
    if token:
        axes.append(token)
    return axes


def _is_anonymous(name: str) -> bool:
    """Return True for numeric axis tokens such as ``"1"``."""
    return name.isdecimal()


def _has_ellipsis(axes: List[Union[str, Tuple[str, ...]]]) -> bool:
    """Return whether ``...`` appears in *axes*; it may appear once, top level only."""
    for entry in axes:
        if isinstance(entry, tuple) and "..." in entry:
            raise EinopsError("Ellipsis inside parentheses is not supported")
    count = sum(1 for entry in axes if entry == "...")
    if count > 1:
        raise EinopsError("Ellipsis may appear at most once on each side")
    return count == 1


def _split_dimension(
    dim_size: int, group: Tuple[str, ...], axes_lengths: Dict[str, int]
) -> List[int]:
    """Return the length of each member of *group*; their product is *dim_size*.

    At most one member may have an unknown length; it is inferred by division.
    """
    member_sizes: List[int] = []
    unknown = -1  # index of the single member whose length must be inferred
    for name in group:
        if _is_anonymous(name):
            if int(name) != 1:
                raise EinopsError(
                    f"Non-unitary anonymous axis '{name}' is not allowed in the input"
                )
            member_sizes.append(1)
        elif name in axes_lengths:
            member_sizes.append(int(axes_lengths[name]))
        elif unknown >= 0:
            raise EinopsError(
                f"Cannot infer sizes of ({' '.join(group)}): more than one unknown axis"
            )
        else:
            unknown = len(member_sizes)
            member_sizes.append(1)  # placeholder, neutral in the product below

    known = 1
    for size in member_sizes:
        known *= size

    if unknown < 0:
        if known != dim_size:
            raise EinopsError(
                f"Shape mismatch for ({' '.join(group)}): expected {known}, got {dim_size}"
            )
    elif known == 0 or dim_size % known != 0:
        raise EinopsError(
            f"Cannot split dimension of size {dim_size} into ({' '.join(group)})"
        )
    else:
        member_sizes[unknown] = dim_size // known
    return member_sizes


def _resolve_input(
    shape: Tuple[int, ...],
    input_axes: List[Union[str, Tuple[str, ...]]],
    axes_lengths: Dict[str, int],
) -> Tuple[Dict[object, int], List[object], List[object]]:
    """Match *shape* against *input_axes*.

    Returns ``(sizes, keys, ellipsis_keys)``:
    - *keys* lists the elementary input axes in memory order, with anonymous
      unit axes omitted.
    - *sizes* maps each key to its length.
    - *ellipsis_keys* are the private keys standing in for the dimensions
      matched by ``...``.
    """
    ndim = len(shape)
    has_ellipsis = "..." in input_axes
    n_explicit = len(input_axes) - int(has_ellipsis)
    if has_ellipsis:
        if ndim < n_explicit:
            raise EinopsError("Not enough dimensions for ellipsis")
    elif ndim != n_explicit:
        raise EinopsError(
            f"Pattern expects {n_explicit} dimensions, but tensor has {ndim}"
        )

    sizes: Dict[object, int] = {}
    keys: List[object] = []
    ellipsis_keys: List[object] = []
    dim = 0
    for entry in input_axes:
        if entry == "...":
            # Tuple keys cannot collide with user axis names, which are strings.
            ellipsis_keys = [("...", i) for i in range(ndim - n_explicit)]
            for key in ellipsis_keys:
                sizes[key] = shape[dim]
                dim += 1
            keys.extend(ellipsis_keys)
            continue
        group = entry if isinstance(entry, tuple) else (entry,)
        for name, size in zip(group, _split_dimension(shape[dim], group, axes_lengths)):
            if _is_anonymous(name):
                continue  # unit axis: carries no data, simply dropped
            if name in sizes:
                raise EinopsError(f"Axis '{name}' appears more than once in the input")
            sizes[name] = size
            keys.append(name)
        dim += 1
    return sizes, keys, ellipsis_keys


def _resolve_output(
    output_axes: List[Union[str, Tuple[str, ...]]],
    sizes: Dict[object, int],
    ellipsis_keys: List[object],
    axes_lengths: Dict[str, int],
) -> Tuple[List[object], List[int], List[int]]:
    """Return ``(keys, key_sizes, final_shape)`` describing the output side.

    *keys*/*key_sizes* list the elementary output axes in order. *final_shape*
    is the shape after merging each parenthesised group.
    """
    keys: List[object] = []
    key_sizes: List[int] = []
    final_shape: List[int] = []
    seen = set()
    for entry in output_axes:
        if entry == "...":
            ellipsis_sizes = [sizes[key] for key in ellipsis_keys]
            keys.extend(ellipsis_keys)
            key_sizes.extend(ellipsis_sizes)
            final_shape.extend(ellipsis_sizes)
            continue
        group = entry if isinstance(entry, tuple) else (entry,)
        group_size = 1
        for name in group:
            if _is_anonymous(name):
                key: object = object()  # fresh key: a new, unnamed axis
                size = int(name)
            else:
                if name in seen:
                    raise EinopsError(f"Axis '{name}' appears more than once in the output")
                seen.add(name)
                key = name
                if name in sizes:
                    size = sizes[name]
                elif name in axes_lengths:
                    size = int(axes_lengths[name])
                else:
                    raise EinopsError(f"Size for axis '{name}' not provided")
            keys.append(key)
            key_sizes.append(size)
            group_size *= size
        final_shape.append(group_size)
    return keys, key_sizes, final_shape


def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
    """Rearrange the axes of *tensor* according to an einops-style *pattern*.

    Supported operations (they can be combined in one pattern):
    - Transposition (``h w -> w h``)
    - Splitting (``(h w) c -> h w c`` with ``h=2``)
    - Merging (``a b c -> (a b) c``)
    - Unit axes: ``1`` or ``()`` in the input must have length 1 and are
      dropped. In the output they insert a length-1 axis.
    - Repetition: an output axis absent from the input takes its length from
      *axes_lengths* (or from its numeric value), and data is repeated along it
      (``a 1 c -> a b c`` with ``b=4``).
    - Ellipsis (``... h w -> ... (h w)``)

    Args:
        tensor: array (or array-like) to rearrange.
        pattern: ``"<input axes> -> <output axes>"``.
        **axes_lengths: lengths of axes that cannot be inferred from *tensor*.

    Returns:
        The rearranged array. It may be a view of *tensor* when no repetition
        is needed, as with ``numpy.reshape``/``transpose``.

    Raises:
        EinopsError: if the pattern is malformed or inconsistent with
            *tensor*'s shape or *axes_lengths*.
    """
    tensor = np.asarray(tensor)
    try:
        input_str, output_str = [s.strip() for s in pattern.split("->")]
    except ValueError:
        raise EinopsError("Pattern must contain exactly one '->'") from None

    input_axes = parse_axes(input_str)
    output_axes = parse_axes(output_str)
    if _has_ellipsis(input_axes) != _has_ellipsis(output_axes):
        raise EinopsError("Ellipsis must appear in both or neither input and output")

    sizes, input_keys, ellipsis_keys = _resolve_input(tensor.shape, input_axes, axes_lengths)
    output_keys, output_sizes, output_shape = _resolve_output(
        output_axes, sizes, ellipsis_keys, axes_lengths
    )

    position = {key: i for i, key in enumerate(input_keys)}
    output_set = set(output_keys)
    dropped = [key for key in input_keys if key not in output_set]
    if dropped:
        raise EinopsError(f"Input axes missing from output: {dropped}")

    # 1. Split input groups into elementary axes (row-major reshape).
    result = tensor.reshape([sizes[key] for key in input_keys])
    # 2. Reorder the elementary axes into output order (the actual transposition).
    result = result.transpose([position[key] for key in output_keys if key in position])
    # 3. Insert output-only axes as length 1, then repeat data along them.
    if len(output_keys) != len(input_keys):
        result = result.reshape(
            [size if key in position else 1 for key, size in zip(output_keys, output_sizes)]
        )
        result = np.broadcast_to(result, output_sizes).copy()
    # 4. Merge elementary axes into the output groups.
    return result.reshape(output_shape)

In [None]:
# Exercise each supported pattern once and print the resulting shape.
# Each entry: (input shape, pattern, axis lengths)  # expected output shape
demo_cases = [
    ((3, 1, 5), 'a 1 c -> a b c', {'b': 4}),        # axis repetition -> (3, 4, 5)
    ((6, 4), '(h w) c -> h w c', {'h': 2}),         # splitting       -> (2, 3, 4)
    ((2, 3, 4), 'a b c -> (a b) c', {}),            # merging         -> (6, 4)
    ((2, 3, 4, 5), '... h w -> ... (h w)', {}),     # batch dims      -> (2, 3, 20)
    ((3, 4), 'h w -> w h', {}),                     # transposition   -> (4, 3)
]

for demo_shape, demo_pattern, demo_lengths in demo_cases:
    x = np.random.rand(*demo_shape)
    result = rearrange(x, demo_pattern, **demo_lengths)
    print(result.shape)