In [9]:
import numpy as np
from typing import List, Set, Union, Dict, Optional, Tuple
from collections import OrderedDict

In [10]:
"""
Parser
"""

class EinopsError(Exception):
    """Raised when an einops pattern is malformed or does not fit the tensor it is applied to."""

class ParsedExpression:
    """
    Parses einops pattern expression to apply operations.
    Arguments:
        expression (str): The einops pattern to parse
    Attributes:
        has_ellipsis (bool): Whether the pattern contains an ellipsis
        identifiers (Set[str]): Set of unique dimension names in the pattern
            (the ellipsis is not included)
        composition (List[Union[List[str], str]]): Parsed structure of the
            pattern. Each top-level entry is either a list of axis names (a
            one-element list for a bare axis, several names for a
            parenthesised group) or the marker string "_ellipsis_".
        actual_dim_count (int): Number of top-level entries in the pattern
            (bare axes plus groups); an ellipsis counts as one entry.
    Raises:
        EinopsError: If the expression is empty, contains an invalid character
            or axis name, repeats an axis name, has unbalanced, nested or empty
            parentheses, contains more than one ellipsis, places the ellipsis
            inside parentheses or glues it to an axis name, or uses the
            reserved name "_ellipsis_".
    """
    def __init__(self, expression: str):
        self.has_ellipsis = False
        self.identifiers: Set[str] = set()
        self.composition: List[Union[List[str], str]] = []
        self.actual_dim_count = 0

        # Validate expression is not empty
        if not expression or expression.isspace():
            raise EinopsError("Expression cannot be empty")

        # "..." is rewritten to the "_ellipsis_" placeholder below, so a
        # literal placeholder typed by the user would be indistinguishable
        # from a real ellipsis.
        if "_ellipsis_" in expression:
            raise EinopsError("'_ellipsis_' is reserved and cannot be used in a pattern")

        # Handle ellipsis first
        if "..." in expression:
            if expression.count("...") > 1:
                raise EinopsError("Multiple ellipsis not allowed in pattern")
            expression = expression.replace("...", "_ellipsis_")
            self.has_ellipsis = True

        tokens: List[str] = []
        current_token = ""
        for char in expression:
            # Check for invalid special characters
            if not (char.isalnum() or char in "() _." or char.isspace()):
                raise EinopsError(f"Invalid character in pattern: '{char}'")

            if char in "()" or char.isspace():
                # Any whitespace (not only a plain space) or parenthesis ends
                # the current token; parentheses are tokens in their own right.
                if current_token:
                    tokens.append(current_token)
                    current_token = ""
                if char in "()":
                    tokens.append(char)
            else:
                current_token += char
        if current_token:
            tokens.append(current_token)

        # Axis names collected for the currently open "(...)" group, or None
        # when outside parentheses.
        bracket_group: Optional[List[str]] = None

        def add_axis_name(name: str):
            if name == "_ellipsis_":
                if bracket_group is not None:
                    raise EinopsError("Ellipsis inside parenthesis not allowed")
                self.composition.append("_ellipsis_")
                return

            if "_ellipsis_" in name:
                # The placeholder can only come from "...", so this token is an
                # ellipsis glued to other characters (e.g. "a..." or "...b").
                raise EinopsError(
                    f"Ellipsis must be separated from axis names: '{name.replace('_ellipsis_', '...')}'"
                )

            if not name:
                raise EinopsError("Empty axis name not allowed")
            elif name[0].isdigit():
                raise EinopsError(f"Axis name cannot start with a number: '{name}'")
            elif not all(c.isalnum() or c == '_' for c in name):
                raise EinopsError(f"Invalid axis name: {name}")

            if name in self.identifiers:
                # A repeated axis has no single size/position it could map to.
                raise EinopsError(f"Duplicate axis name in pattern: '{name}'")

            if bracket_group is not None:
                bracket_group.append(name)
            else:
                self.composition.append([name])

            self.identifiers.add(name)

        # Process tokens
        for token in tokens:
            if token == "(":
                if bracket_group is not None:
                    raise EinopsError("Nested parentheses not allowed")
                bracket_group = []
            elif token == ")":
                if bracket_group is None:
                    raise EinopsError("Unmatched closing parenthesis")
                if not bracket_group:
                    raise EinopsError("Empty parentheses not allowed")
                self.composition.append(bracket_group)
                self.actual_dim_count += 1
                bracket_group = None
            else:
                add_axis_name(token)
                # Names inside parentheses are counted once, when ")" closes the group.
                if bracket_group is None:
                    self.actual_dim_count += 1

        if bracket_group is not None:
            raise EinopsError("Unclosed parenthesis")

In [11]:
"""
Rearrange function
"""

def _get_shape_dict(tensor: np.ndarray, source_parsed: ParsedExpression, named_sizes: Dict[str, int]) -> Dict[str, int]:
    """
    Create a dictionary mapping dimension names to their sizes.

    Walks the source pattern alongside ``tensor.shape``. A bare axis takes
    the size of its tensor dimension. A group whose axes are all known is
    checked against its dimension. A group with exactly one unknown axis
    has that axis inferred by division. An ellipsis consumes the dimensions
    left over once the entries after it are accounted for.

    Arguments:
        tensor (np.ndarray): The input tensor
        source_parsed (ParsedExpression): The parsed expression
        named_sizes (Dict[str, int]): Additional named dimensions and their sizes
    Returns:
        shape_dict (Dict[str, int]): Dictionary mapping dimension names to their
            sizes (a new dict; ``named_sizes`` itself is not modified)
    Raises:
        EinopsError: If the pattern has more or fewer dimensions than the
            tensor, a size is inconsistent, a group cannot be split evenly,
            or a group contains more than one unknown axis.
    """
    shape_dict = named_sizes.copy()
    composition = source_parsed.composition
    ndim = len(tensor.shape)
    current_dim = 0  # index of the next unconsumed tensor dimension

    for position, item in enumerate(composition):
        if isinstance(item, list):
            if current_dim >= ndim:
                raise EinopsError("Pattern has more dimensions than tensor")
            total_size = tensor.shape[current_dim]
            current_dim += 1

            if len(item) == 1:
                axis_name = item[0]
                if axis_name not in shape_dict:
                    shape_dict[axis_name] = total_size
                elif shape_dict[axis_name] != total_size:
                    raise EinopsError(f"Inconsistent size for dimension {axis_name}: got {total_size}, expected {shape_dict[axis_name]}")
                continue

            unknown_dims = [dim for dim in item if dim not in shape_dict]

            if not unknown_dims:
                product = np.prod(tuple(shape_dict[dim] for dim in item))
                if product != total_size:
                    raise EinopsError(f"Shape mismatch: {item} product {product} != {total_size}")
            elif len(unknown_dims) == 1:
                known_product = np.prod(tuple(shape_dict[dim] for dim in item if dim in shape_dict))
                if known_product == 0:
                    # Dividing by zero is undefined, so the unknown axis cannot be recovered.
                    raise EinopsError(f"Cannot infer size of {unknown_dims[0]}: known axes in {item} have total size 0")
                if total_size % known_product != 0:
                    raise EinopsError(f"Cannot divide dimension size {total_size} by {known_product}")
                shape_dict[unknown_dims[0]] = total_size // known_product
            else:
                raise EinopsError(f"Cannot infer sizes for multiple unknown dimensions in {item}")

        elif item == "_ellipsis_":
            # The dimensions at the end of the tensor are reserved for the
            # entries that follow the ellipsis in the pattern.
            trailing_dims = sum(1 for x in composition[position + 1:] if isinstance(x, list))
            ellipsis_dims = ndim - current_dim - trailing_dims
            if ellipsis_dims < 0:
                raise EinopsError("Pattern has more dimensions than tensor")
            current_dim += ellipsis_dims

    if current_dim != ndim:
        raise EinopsError("Pattern has fewer dimensions than tensor")

    return shape_dict

In [14]:
def _compute_output_shape(tensor: np.ndarray, target_parsed: ParsedExpression,
                          source_parsed: ParsedExpression, shape_dict: Dict[str, int]) -> Tuple[List[int], List[int], List[int]]:
    """
    Compute shapes and permutations for the output tensor.

    The input is first reshaped so every source group is split into its
    elementary axes (``intermediate_shape``). It is then transposed by
    ``permutation`` into target order, and finally reshaped into
    ``final_shape``.

    Arguments:
        tensor (np.ndarray): The input tensor
        target_parsed (ParsedExpression): The parsed target expression
        source_parsed (ParsedExpression): The parsed source expression
        shape_dict (Dict[str, int]): Dictionary mapping dimension names to their sizes
    Returns:
        intermediate_shape (List[int]): Shape after ungrouping dimensions but before permutation
        final_shape (List[int]): Final target shape after regrouping dimensions
        permutation (List[int]): Indices into ``intermediate_shape`` giving the
            order in which dimensions must be transposed
    Raises:
        KeyError: If a source axis is missing from ``shape_dict``.
        EinopsError: If the target uses an axis absent from the source, uses an
            ellipsis the source does not have, or the element counts disagree.
    """
    intermediate_shape: List[int] = []
    final_shape: List[int] = []
    permutation: List[int] = []
    source_positions: Dict[str, int] = {}
    ellipsis_dims: Optional[List[int]] = None
    composition = source_parsed.composition
    # Index into tensor.shape. It differs from the index into
    # intermediate_shape as soon as a group is split into several axes,
    # so the two must be tracked separately.
    tensor_dim = 0

    # Handle source pattern
    for position, item in enumerate(composition):
        if isinstance(item, list):
            # Bare axes and groups alike: each elementary axis gets its own
            # intermediate dimension, while the group consumes one tensor dim.
            for axis in item:
                source_positions[axis] = len(intermediate_shape)
                intermediate_shape.append(shape_dict[axis])
            tensor_dim += 1
        elif item == "_ellipsis_":
            trailing_dims = sum(1 for x in composition[position + 1:] if isinstance(x, list))
            ellipsis_end = len(tensor.shape) - trailing_dims
            ellipsis_sizes = list(tensor.shape[tensor_dim:ellipsis_end])
            first = len(intermediate_shape)
            ellipsis_dims = list(range(first, first + len(ellipsis_sizes)))
            intermediate_shape.extend(ellipsis_sizes)
            tensor_dim = ellipsis_end

    # Build output shape and permutation
    for item in target_parsed.composition:
        if isinstance(item, list):
            size = 1
            for axis in item:
                if axis not in source_positions:
                    raise EinopsError(f"Unknown dimension: {axis}")
                size *= shape_dict[axis]
                permutation.append(source_positions[axis])
            final_shape.append(size)
        elif item == "_ellipsis_":
            if ellipsis_dims is None:
                raise EinopsError("Ellipsis found in target pattern but not in source pattern")
            # Sizes come from intermediate_shape, which ellipsis_dims indexes.
            final_shape.extend(intermediate_shape[d] for d in ellipsis_dims)
            permutation.extend(ellipsis_dims)

    # Validate final shape
    total_in = np.prod(tuple(intermediate_shape))
    total_out = np.prod(tuple(final_shape))
    if total_in != total_out:
        raise EinopsError(f"Cannot reshape array of size {total_in} into shape {tuple(final_shape)}")

    return intermediate_shape, final_shape, permutation

In [15]:
def rearrange(tensor: np.ndarray, pattern: str, **named_sizes: int) -> np.ndarray:
    """
    Rearrange tensor dimensions according to the pattern.

    The pattern has the form ``"source -> target"``. Source axes are matched
    against ``tensor.shape``. Axes may be permuted, merged with parentheses,
    or split (with the help of ``named_sizes``), and a single ``...`` stands
    for any remaining dimensions. Every source axis must appear in the target.

    Arguments:
        tensor (np.ndarray): The input tensor
        pattern (str): The einops pattern to apply
        **named_sizes (int): Sizes of named dimensions that cannot be inferred
            from the tensor shape alone (e.g. ``a=2`` for ``"(a b) c -> a b c"``)
    Returns:
        np.ndarray: The rearranged tensor
    Raises:
        EinopsError: If the pattern is malformed or incompatible with the tensor.
    """
    arrow_count = pattern.count('->')
    if arrow_count == 0:
        raise EinopsError("Pattern must contain '->'")
    if arrow_count > 1:
        raise EinopsError("Pattern must contain exactly one '->'")

    source, target = pattern.split('->')
    source_parsed = ParsedExpression(source.strip())
    target_parsed = ParsedExpression(target.strip())

    if target_parsed.has_ellipsis and not source_parsed.has_ellipsis:
        raise EinopsError("Ellipsis found in target pattern but not in source pattern")

    # Validate dimension counts match tensor shape
    ndim = len(tensor.shape)
    source_dims = source_parsed.actual_dim_count
    if source_parsed.has_ellipsis:
        # actual_dim_count includes the ellipsis entry itself, which may
        # cover zero or more dimensions; only the explicit entries are fixed.
        explicit_dims = source_dims - 1
        if explicit_dims > ndim:
            raise EinopsError(f"Pattern has more dimensions ({explicit_dims} plus ellipsis) than tensor ({ndim})")
    elif source_dims < ndim:
        raise EinopsError(f"Pattern has fewer dimensions ({source_dims}) than tensor ({ndim})")
    elif source_dims > ndim:
        raise EinopsError(f"Pattern has more dimensions ({source_dims}) than tensor ({ndim})")

    # Get dimension sizes
    try:
        shape_dict = _get_shape_dict(tensor, source_parsed, named_sizes)
    except ValueError as e:
        raise EinopsError(f"Cannot infer sizes: {e}") from e

    unknown_dims = target_parsed.identifiers - source_parsed.identifiers
    if unknown_dims:
        raise EinopsError(f"Unknown dimension(s): {', '.join(sorted(unknown_dims))}")

    # A rearrangement may not silently drop axes (a size-1 axis would
    # otherwise pass the element-count check and vanish).
    dropped_dims = source_parsed.identifiers - target_parsed.identifiers
    if dropped_dims:
        raise EinopsError(f"Dimension(s) missing from target pattern: {', '.join(sorted(dropped_dims))}")

    # Compute shapes and permutation
    try:
        intermediate_shape, final_shape, permutation = _compute_output_shape(tensor, target_parsed, source_parsed, shape_dict)
    except KeyError as e:
        raise EinopsError(f"Unknown dimension: {e}") from e

    # Perform the rearrangement
    try:
        # Identity order needs no transpose: a single reshape suffices.
        if permutation == list(range(len(permutation))):
            return tensor.reshape(final_shape)
        return tensor.reshape(intermediate_shape).transpose(permutation).reshape(final_shape)
    except ValueError as e:
        raise EinopsError(f"Shape Mismatch: {e}") from e

In [23]:
"""
UNIT TESTS
"""

import unittest
import numpy as np
from typing import List, Tuple

class TestParsedExpression(unittest.TestCase):
    """Test cases for pattern parsing"""

    def test_basic_parsing(self):
        """Test basic pattern parsing without special cases"""
        expr = ParsedExpression('a b c')
        self.assertEqual(expr.composition, [['a'], ['b'], ['c']])
        self.assertEqual(expr.identifiers, {'a', 'b', 'c'})
        self.assertFalse(expr.has_ellipsis)
        self.assertEqual(expr.actual_dim_count, 3)

    def test_grouped_dimensions(self):
        """Test parsing of grouped dimensions"""
        expr = ParsedExpression('a (b c) d')
        self.assertEqual(expr.composition, [['a'], ['b', 'c'], ['d']])
        self.assertEqual(expr.identifiers, {'a', 'b', 'c', 'd'})
        self.assertFalse(expr.has_ellipsis)
        # A parenthesised group counts as a single dimension.
        self.assertEqual(expr.actual_dim_count, 3)

    def test_ellipsis(self):
        """Test ellipsis handling"""
        expr = ParsedExpression('a ... c')
        self.assertEqual(expr.composition, [['a'], '_ellipsis_', ['c']])
        self.assertTrue(expr.has_ellipsis)
        # The ellipsis is a placeholder, not a named axis.
        self.assertEqual(expr.identifiers, {'a', 'c'})

    def test_invalid_patterns(self):
        """Test various invalid pattern cases"""
        invalid_patterns = [
            'a ((b c)) d',  # Nested parentheses
            'a (b c d',     # Unclosed parenthesis
            'a b) c',       # Unmatched closing parenthesis
            'a ... ... c',  # Multiple ellipsis
            'a (... b) c',  # Ellipsis in parentheses
            '123',          # Axis name starting with a digit
            'a @b c',       # Invalid character
            'a-b c',        # Invalid character
            'a.b c',        # Dot outside an ellipsis
            'a () c',       # Empty parentheses
            '',             # Empty expression
            ' ',            # Whitespace only
        ]

        for pattern in invalid_patterns:
            with self.subTest(pattern=pattern):
                with self.assertRaises(EinopsError):
                    ParsedExpression(pattern)

    def test_whitespace_handling(self):
        """Test handling of various whitespace patterns"""
        patterns = [
            ('a  b   c', [['a'], ['b'], ['c']]),
            ('a(b c)d', [['a'], ['b', 'c'], ['d']]),
            (' a b c ', [['a'], ['b'], ['c']]),
        ]

        for pattern, expected in patterns:
            with self.subTest(pattern=pattern):
                expr = ParsedExpression(pattern)
                self.assertEqual(expr.composition, expected)

class TestRearrange(unittest.TestCase):
    """Test cases for tensor rearrangement"""

    def setUp(self):
        """Set up common test tensors"""
        # arange-filled tensors make every element distinct, so value checks
        # detect misplaced elements and not just wrong shapes.
        self.tensor_2d = np.arange(6).reshape(2, 3)
        self.tensor_3d = np.arange(24).reshape(2, 3, 4)
        self.tensor_4d = np.arange(120).reshape(2, 3, 4, 5)

    def assert_shapes_equal(self, tensor: np.ndarray, expected_shape: Tuple[int, ...]):
        """Helper to assert tensor shapes match"""
        self.assertEqual(tensor.shape, expected_shape)

    def test_basic_permutations(self):
        """Test basic dimension permutations"""
        # The last field is the equivalent np.transpose axis order.
        cases = [
            (self.tensor_2d, 'a b -> b a', (3, 2), (1, 0)),
            (self.tensor_3d, 'a b c -> b c a', (3, 4, 2), (1, 2, 0)),
            (self.tensor_3d, 'a b c -> c a b', (4, 2, 3), (2, 0, 1)),
            (self.tensor_4d, 'a b c d -> d a b c', (5, 2, 3, 4), (3, 0, 1, 2)),
        ]

        for tensor, pattern, expected_shape, axes in cases:
            with self.subTest(pattern=pattern):
                result = rearrange(tensor, pattern)
                self.assert_shapes_equal(result, expected_shape)
                np.testing.assert_array_equal(result, tensor.transpose(axes))

    def test_merging_dimensions(self):
        """Test merging of dimensions"""
        cases = [
            (self.tensor_3d, 'a b c -> a (b c)', (2, 12)),
            (self.tensor_4d, 'a b c d -> (a b) (c d)', (6, 20)),
            (self.tensor_4d, 'a b c d -> a (b c d)', (2, 60)),
        ]

        for tensor, pattern, expected_shape in cases:
            with self.subTest(pattern=pattern):
                result = rearrange(tensor, pattern)
                self.assert_shapes_equal(result, expected_shape)
                # Merging adjacent axes in order is a plain C-order reshape.
                np.testing.assert_array_equal(result, tensor.reshape(expected_shape))

    def test_splitting_dimensions(self):
        """Test splitting of dimensions"""
        tensor = np.arange(24).reshape(4, 6)
        cases = [
            ('(a b) c -> a b c', {'a': 2}, (2, 2, 6)),
            ('a (b c) -> a b c', {'b': 2}, (4, 2, 3)),
        ]

        for pattern, sizes, expected_shape in cases:
            with self.subTest(pattern=pattern):
                result = rearrange(tensor, pattern, **sizes)
                self.assert_shapes_equal(result, expected_shape)
                # Splitting an axis in order is a plain C-order reshape.
                np.testing.assert_array_equal(result, tensor.reshape(expected_shape))

    def test_error_cases(self):
        """Test various error conditions"""
        tensor = np.zeros((2, 3, 4))
        error_cases = [
            ('a b -> a b c', {}),
            ('a b c d -> a b c', {}),
            ('(a b) c -> a b c', {}),
            ('a ... ... -> a', {}),
            ('a b -> b a c', {}),
            ('a b c -> a b d', {}),
            ('a b -> (a b) c', {}),
            ('a b c -> a (b d)', {}),
            ('', {}),
            ('a b c -> ', {}),
            ('-> a b c', {}),
            ('a b c', {}),
        ]

        for pattern, sizes in error_cases:
            with self.subTest(pattern=pattern):
                with self.assertRaises(EinopsError):
                    rearrange(tensor, pattern, **sizes)

    def test_dimension_inference(self):
        """Test automatic dimension size inference"""
        tensor = np.zeros((6, 8))
        result = rearrange(tensor, '(a b) c -> a b c', b=2)
        self.assert_shapes_equal(result, (3, 2, 8))

        with self.assertRaises(EinopsError):
            rearrange(tensor, '(a b) c -> a b c')

    def test_inconsistent_sizes(self):
        """Test handling of inconsistent dimension sizes"""
        tensor = np.zeros((4, 4))
        with self.assertRaises(EinopsError):
            rearrange(tensor, 'a a -> a a', a=2)

    def test_shape_mismatch(self):
        """Test handling of shape mismatches"""
        tensor = np.zeros((4, 4))
        with self.assertRaises(EinopsError):
            rearrange(tensor, '(a b) c -> a b c', a=3)

    def test_value_preservation(self):
        """Test that values are correctly preserved after rearrangement"""
        tensor = np.arange(24).reshape(2, 3, 4)
        result = rearrange(tensor, 'a b c -> c b a')
        np.testing.assert_array_equal(
            tensor[0, 1, 2],
            result[2, 1, 0]
        )
        # Check every element, not just one sample position.
        np.testing.assert_array_equal(result, tensor.transpose(2, 1, 0))

# Run the test suite when this file/cell is executed directly.
# argv=[""] keeps unittest from parsing the host process's command-line
# arguments (e.g. a notebook kernel's), and exit=False stops it from calling
# sys.exit() once the tests finish.
if __name__ == "__main__":
    unittest.main(exit=False, verbosity=2, argv=[""])

test_basic_parsing (__main__.TestParsedExpression.test_basic_parsing)
Test basic pattern parsing without special cases ... ok
test_ellipsis (__main__.TestParsedExpression.test_ellipsis)
Test ellipsis handling ... ok
test_grouped_dimensions (__main__.TestParsedExpression.test_grouped_dimensions)
Test parsing of grouped dimensions ... ok
test_invalid_patterns (__main__.TestParsedExpression.test_invalid_patterns)
Test various invalid pattern cases ... ok
test_whitespace_handling (__main__.TestParsedExpression.test_whitespace_handling)
Test handling of various whitespace patterns ... ok
test_basic_permutations (__main__.TestRearrange.test_basic_permutations)
Test basic dimension permutations ... ok
test_dimension_inference (__main__.TestRearrange.test_dimension_inference)
Test automatic dimension size inference ... ok
test_error_cases (__main__.TestRearrange.test_error_cases)
Test various error conditions ... ok
test_inconsistent_sizes (__main__.TestRearrange.test_inconsistent_sizes)
Test 