# Einops from Scratch Implementation
## **Implementation of rearrange operation for NumPy arrays**



In [1]:
import numpy as np
import re
from typing import Dict, List, Tuple, Set

### Core Exception and Helper Functions

In [2]:
class EinopsError(ValueError):
    """Error raised for invalid patterns or axis sizes in einops-style operations.

    Subclasses ``ValueError`` so callers that already catch ``ValueError``
    also catch this exception.
    """

def _collect_identifiers(pattern: str) -> Set[str]:
    """Extract unique identifiers from pattern string"""
    clean_pattern = pattern.replace('(', ' ').replace(')', ' ').replace('...', ' ')
    return {token for token in clean_pattern.split() if token}

def validate_axes_lengths(pattern: str, axes_lengths: Dict[str, int]) -> None:
    """Reject axis-length keywords that name no axis of ``pattern``.

    Args:
        pattern: Full pattern string.
        axes_lengths: Mapping of axis name to the size supplied by the caller.

    Raises:
        EinopsError: If at least one key of ``axes_lengths`` is not a token
            of the pattern.
    """
    known_axes = _collect_identifiers(pattern)
    unexpected = set(axes_lengths).difference(known_axes)
    if not unexpected:
        return
    raise EinopsError(f"Unexpected axes_lengths keys: {unexpected}")

### Pattern Parsing

In [3]:
def parse_pattern(pattern: str) -> Tuple[List[str], List[str]]:
    """Split a pattern at its arrow and tokenize both sides.

    A token is a parenthesised group (kept whole, e.g. ``'(h w)'``), an
    ellipsis (``'...'``) or a run of non-whitespace characters.

    Args:
        pattern: Pattern string of the form ``'<input> -> <output>'``.

    Returns:
        ``(input_tokens, output_tokens)``.

    Raises:
        EinopsError: If the pattern contains no ``->`` or more than one.
    """
    sides = pattern.split('->')
    if len(sides) < 2:
        raise EinopsError(f"Pattern must contain '->': {pattern}")
    if len(sides) > 2:
        # Previously this surfaced as a bare tuple-unpacking ValueError.
        raise EinopsError(f"Pattern must contain exactly one '->': {pattern}")

    def tokenize(expr: str) -> List[str]:
        # Pad the ellipsis so it is split off from adjacent names ("a..." -> "a", "...").
        expr = expr.replace('...', ' ... ')
        return re.findall(r'\(.*?\)|\.\.\.|\S+', expr.strip())

    input_part, output_part = sides
    return tokenize(input_part), tokenize(output_part)

### Main Rearrange Implementation

In [4]:
def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    Rearrange tensor dimensions according to an einops-style pattern.

    The input side of the pattern is matched against ``tensor.shape``; every
    parenthesised group is split into its component axes, the axes are
    permuted into the order requested by the output side, and output groups
    are merged again.  Additional conventions of this implementation:

    * ``...`` stands for all dimensions not named explicitly and must appear
      on both sides or on neither.
    * ``1`` is an anonymous axis of length 1.  It may be dropped from, or
      added to, the output any number of times.
    * A named axis that appears only in the output is created by repetition;
      its length must be passed through ``axes_lengths``.
    * A named input axis of length 1 whose ``axes_lengths`` entry is larger
      is repeated up to that length.
    * A named input axis may be omitted from the output only if its length
      is 1.

    Args:
        tensor: Input numpy array
        pattern: Einops pattern string (e.g., 'b c h w -> b h w c')
        **axes_lengths: Named dimension sizes

    Returns:
        Rearranged numpy array

    Raises:
        EinopsError: If the pattern is malformed, does not match the rank or
            shape of ``tensor``, leaves a group with more than one unknown
            size, or refers to an output axis whose size is unknown.
    """
    # Initial validation
    validate_axes_lengths(pattern, axes_lengths)
    input_dims, output_dims = parse_pattern(pattern)

    # Pattern validation: a token containing a parenthesis must be exactly
    # one "( ... )" group with no further parentheses inside.
    for dim in input_dims + output_dims:
        if '(' in dim or ')' in dim:
            inner = dim[1:-1]
            wrapped = dim.startswith('(') and dim.endswith(')')
            if not wrapped or '(' in inner or ')' in inner:
                raise EinopsError("Mismatched parentheses in pattern")
    if input_dims.count('...') > 1:
        raise EinopsError("Multiple ellipses in input pattern")
    if output_dims.count('...') > 1:
        raise EinopsError("Multiple ellipses in output pattern")
    if ('...' in input_dims) != ('...' in output_dims):
        raise EinopsError("Ellipsis must appear in both input and output patterns if used")

    shape = tensor.shape
    ndim = len(shape)

    # Replace the ellipsis by one placeholder axis per covered dimension.
    if '...' in input_dims:
        n_ellipsis = ndim - (len(input_dims) - 1)
        if n_ellipsis < 0:
            raise EinopsError("Not enough dimensions in tensor")
        # The tokenizer always splits "..." away from names, so a placeholder
        # of the form "...<i>" can never clash with a user-chosen axis name.
        ellipsis_axes = [f'...{i}' for i in range(n_ellipsis)]
        in_idx = input_dims.index('...')
        out_idx = output_dims.index('...')
        input_dims = input_dims[:in_idx] + ellipsis_axes + input_dims[in_idx + 1:]
        output_dims = output_dims[:out_idx] + ellipsis_axes + output_dims[out_idx + 1:]

    if len(input_dims) > ndim:
        raise EinopsError("Not enough dimensions in tensor")
    if len(input_dims) < ndim:
        raise EinopsError("Too many dimensions in tensor")

    # Elementary (un-grouped) named axes of the input, in memory order.
    input_axes = []
    actual_sizes = {}  # axis name -> length the axis has in ``tensor``
    target_sizes = {}  # axis name -> length the axis has in the result

    def add_input_axis(name, actual, target):
        if name in actual_sizes:
            raise EinopsError(f"Duplicate axis name in input pattern: {name}")
        input_axes.append(name)
        actual_sizes[name] = actual
        target_sizes[name] = target

    # Process input dimensions
    for total_size, dim in zip(shape, input_dims):
        if dim == '1':
            if total_size != 1:
                raise EinopsError(f"Size mismatch for {dim}: expected 1, got {total_size}")
            continue  # anonymous unit axis: removed by the reshape below

        if dim.startswith('('):
            # Unit components do not contribute to the group size.
            components = [comp for comp in dim[1:-1].split() if comp != '1']
            known_size = 1
            unknown = None
            for comp in components:
                if comp in axes_lengths:
                    known_size *= axes_lengths[comp]
                elif unknown is not None:
                    raise EinopsError(f"Multiple unknown dimensions in {dim}")
                else:
                    unknown = comp

            if unknown is None:
                if total_size != known_size:
                    raise EinopsError(f"Size mismatch for {dim}: expected {known_size}, got {total_size}")
            elif known_size <= 0 or total_size % known_size != 0:
                # known_size <= 0 is rejected here to avoid a ZeroDivisionError.
                raise EinopsError(f"Cannot divide dimension size {total_size} by {known_size}")

            for comp in components:
                size = total_size // known_size if comp == unknown else axes_lengths[comp]
                add_input_axis(comp, size, size)
        else:
            if dim in axes_lengths and total_size != axes_lengths[dim] and total_size != 1:
                raise EinopsError(f"Dimension size mismatch for {dim}")
            # A length-1 axis with a larger declared length is repeated later.
            add_input_axis(dim, total_size, axes_lengths.get(dim, total_size))

    # Process output dimensions
    output_axes = []  # elementary named axes in output order
    seen_output = set()
    final_shape = []
    for dim in output_dims:
        components = dim[1:-1].split() if dim.startswith('(') else [dim]
        group_size = 1
        for comp in components:
            if comp == '1':
                continue  # contributes a factor of 1 to the group
            if comp in seen_output:
                raise EinopsError(f"Duplicate axis name in output pattern: {comp}")
            if comp not in target_sizes:
                # Axis that exists only in the output: needs an explicit length.
                if comp not in axes_lengths:
                    raise EinopsError(f"Unknown axis in output pattern: {comp}")
                target_sizes[comp] = axes_lengths[comp]
            seen_output.add(comp)
            output_axes.append(comp)
            group_size *= target_sizes[comp]
        final_shape.append(group_size)

    # Input axes absent from the output can only be dropped when they hold
    # a single element; anything else would silently discard data.
    for name in input_axes:
        if name not in seen_output and actual_sizes[name] != 1:
            raise EinopsError(f"Axis {name} is missing from the output pattern")
    kept_axes = [name for name in input_axes if name in seen_output]
    position = {name: index for index, name in enumerate(kept_axes)}

    try:
        # 1. Split groups and drop unit axes.
        current = tensor.reshape([actual_sizes[name] for name in kept_axes])

        # 2. Permute surviving axes into output order.
        order = [position[name] for name in output_axes if name in position]
        if order != sorted(order):
            current = np.transpose(current, order)

        # 3. Insert a unit axis for every output-only name, then repeat each
        #    axis whose requested length differs from its current length.
        current = current.reshape([actual_sizes.get(name, 1) for name in output_axes])
        for axis, name in enumerate(output_axes):
            if current.shape[axis] != target_sizes[name]:
                current = np.repeat(current, target_sizes[name], axis=axis)

        # 4. Merge output groups and add anonymous unit axes.
        return current.reshape(final_shape)
    except ValueError as e:
        raise EinopsError(f"Cannot reshape tensor to shape {final_shape}: {str(e)}") from e



# Comprehensive Test Suite

- Basic Operations Tests

In [5]:
def test_basic_transpose():
    """Swapping two named axes must swap the two extents of the shape."""
    matrix = np.random.rand(95, 24)
    swapped = rearrange(matrix, 'h w -> w h')
    assert swapped.shape == (24, 95), f"Expected (24, 95), got {swapped.shape}"

def test_basic_transposition_with_values():
    """A 2x2 axis swap must yield the transposed values; both arrays are printed on mismatch."""
    source = np.array([[1, 2], [3, 4]])
    expected = np.array([[1, 3], [2, 4]])
    result = rearrange(source, 'a b -> b a')

    matches = np.array_equal(result, expected)
    if not matches:
        # Diagnostic output to make the failure easy to read in the notebook.
        print("Test failed!")
        print("Expected output:\n", expected)
        print("Actual output:\n", result)
    assert matches, "Value-based transposition failed"

def test_identity_operation():
    """A pattern with identical sides must return the input values unchanged."""
    original = np.random.rand(2, 3, 4)
    unchanged = rearrange(original, 'a b c -> a b c')
    assert np.array_equal(original, unchanged), "Identity operation failed"




- Dimension Splitting and Merging Tests



In [6]:
def test_axis_splitting():
    """Splitting a length-12 axis with c=3 must give extents 4 and 3."""
    grid = np.random.rand(10, 12)
    split = rearrange(grid, 'h (w c) -> h w c', c=3)
    assert split.shape == (10, 4, 3), f"Expected (10, 4, 3), got {split.shape}"

def test_axis_merging():
    """Merging the first two axes must multiply their extents."""
    cube = np.random.rand(3, 4, 5)
    merged = rearrange(cube, 'a b c -> (a b) c')
    assert merged.shape == (12, 5), f"Expected (12, 5), got {merged.shape}"

def test_nested_splitting():
    """Splitting a length-3 axis with c1=3 must leave the inferred c2 at length 1."""
    data = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
    split = rearrange(data, 'b (c1 c2) h w -> b c1 c2 h w', c1=3)
    assert split.shape == (2, 3, 1, 4, 5), "Nested splitting failed"

def test_complex_merge_split():
    """Re-ordering inside a merged axis and then splitting it must give the expected shape."""
    table = np.random.rand(24, 10)
    reordered = rearrange(table, '(a b) c -> (b a) c', a=4)
    result = rearrange(reordered, '(a b) c -> a b c', a=6)
    assert result.shape == (6, 4, 10), "Merge-split combination failed"

def test_implicit_anonymous_dims():
    """With a=2 and c=3 given, the remaining factor b of a length-24 axis must be inferred as 4."""
    table = np.random.rand(24, 10)
    split = rearrange(table, '(a b c) d -> a b c d', a=2, c=3)
    assert split.shape == (2, 4, 3, 10), "Implicit dimension inference failed"

- Ellipsis Tests


In [7]:
def test_basic_ellipsis():
    """Leading dimensions covered by the ellipsis are kept while the last two are merged."""
    batch = np.random.rand(2, 3, 4, 5)
    flattened = rearrange(batch, '... h w -> ... (h w)')
    assert flattened.shape == (2, 3, 20), f"Expected (2, 3, 20), got {flattened.shape}"

def test_leading_ellipsis():
    """The dimensions covered by the ellipsis can be moved behind the named axes."""
    batch = np.random.rand(2, 3, 4, 5)
    moved = rearrange(batch, '... h w -> h w ...')
    assert moved.shape == (4, 5, 2, 3), "Leading ellipsis handling failed"

def test_ellipsis_edge_cases():
    """An ellipsis-only pattern must be the identity for 1-D and 4-D inputs."""
    cases = (
        ((5,), "Single dimension ellipsis failed"),
        ((2, 3, 4, 5), "Multiple dimension ellipsis failed"),
    )
    for extents, failure_message in cases:
        sample = np.random.rand(*extents)
        echoed = rearrange(sample, '... -> ...')
        assert np.array_equal(sample, echoed), failure_message

def test_valid_transpose_with_ellipsis():
    """Two leading named axes are swapped while the ellipsis keeps the trailing dimensions."""
    batch = np.random.rand(2, 3, 4, 5)
    swapped = rearrange(batch, 'a b ... -> b a ...')
    assert swapped.shape == (3, 2, 4, 5), "Transpose with ellipsis failed"

def test_parametric_ellipsis():
    """Split factors h2 and w2 are moved into the channel group behind an ellipsis."""
    volume = np.random.rand(10, 20, 30, 40)
    packed = rearrange(
        volume,
        '... (h h2) (w w2) c -> ... h w (c h2 w2)',
        h2=2,
        w2=3,
    )
    expected_channels = 40 * 2 * 3
    assert packed.shape == (10, 10, 10, expected_channels), "Parametric ellipsis failed"


 - Repeating Singleton Tests
 - Advanced and Performance Tests


In [8]:
def test_repeating():
    """A unit axis in the input can be replaced by a new axis b of the requested length."""
    source = np.random.rand(3, 1, 5)
    repeated = rearrange(source, 'a 1 c -> a b c', b=4)
    assert repeated.shape == (3, 4, 5), f"Expected (3, 4, 5), got {repeated.shape}"

def test_implicit_repeat():
    """Moving the unit axis to the end must equal a plain numpy transpose."""
    source = np.random.rand(3, 1, 5)
    moved = rearrange(source, 'a 1 c -> a c 1')
    reference = source.transpose(0, 2, 1)
    assert np.array_equal(moved, reference), "Implicit repeat failed"

def test_singleton_dimensions():
    """The output may contain more unit axes than the input."""
    source = np.random.rand(3, 1, 5)
    padded = rearrange(source, 'a 1 c -> a c 1 1')
    assert padded.shape == (3, 5, 1, 1), "Singleton dimension handling failed"

# Advanced and Performance Tests
def test_advanced_reshaping():
    """2x2 spatial blocks of a (2, 32, 32, 3) batch are folded into the channel axis."""
    images = np.random.rand(2, 32, 32, 3)
    folded = rearrange(
        images,
        'b (h h2) (w w2) c -> b h w (c h2 w2)',
        h2=2,
        w2=2,
    )
    assert folded.shape == (2, 16, 16, 12), "Advanced reshaping failed"

def test_performance_sensitive_operations():
    """A transpose of a 1000x1000 array must complete and keep the shape."""
    big = np.ones((1000, 1000))
    flipped = rearrange(big, 'h w -> w h')
    assert flipped.shape == (1000, 1000), "Large tensor operation failed"



-  Error Handling Tests
-  Edge Cases Tests


In [9]:
# Error Handling Tests
def test_basic_errors():
    """Too many pattern axes and an indivisible split must both raise EinopsError."""
    # Three pattern axes against a 2-D array.
    rank_error_raised = False
    try:
        rearrange(np.random.rand(3, 4), 'a b c -> c b a')
    except EinopsError:
        rank_error_raised = True
    assert rank_error_raised, "Failed to catch dimension mismatch"

    # A length-4 axis cannot be split with b=3.
    split_error_raised = False
    try:
        rearrange(np.random.rand(3, 4), 'a (b c) -> a b c', b=3)
    except EinopsError:
        split_error_raised = True
    assert split_error_raised, "Failed to catch missing dimension"

def test_error_insufficient_dims():
    """A pattern with more axes than the array has must report 'Not enough dimensions'."""
    caught = None
    try:
        rearrange(np.random.rand(3, 4), 'a b c -> c b a')
    except EinopsError as exc:
        caught = exc
    assert caught is not None, "Failed to catch insufficient dimensions"
    assert "Not enough dimensions" in str(caught), "Wrong error message"

def test_error_multiple_unknowns():
    """A group with several unspecified components must report 'Multiple unknown dimensions'."""
    caught = None
    try:
        rearrange(np.random.rand(12, 10), '(a b c) d -> a b c d')
    except EinopsError as exc:
        caught = exc
    assert caught is not None, "Failed to catch multiple unknowns"
    assert "Multiple unknown dimensions" in str(caught), "Wrong error message"


# Edge Cases Tests

def test_large_dimensions():
    """Transpose a (1000000, 2) array; a MemoryError is reported as a warning, not a failure."""
    try:
        tall = np.random.rand(1000000, 2)
        flipped = rearrange(tall, 'h w -> w h')
    except MemoryError:
        print("Warning: System memory insufficient for large dimension test")
        return
    assert flipped.shape == (2, 1000000), "Shape mismatch for large dimensions"

def test_different_dtypes():
    """The result dtype must match the input dtype for int32, float32 and bool arrays."""
    # Integer input
    ints = np.random.randint(0, 100, size=(3, 4), dtype=np.int32)
    ints_out = rearrange(ints, 'h w -> w h')
    assert ints_out.dtype == np.int32, "Should preserve integer dtype"

    # Floating point input
    floats = np.random.rand(3, 4).astype(np.float32)
    floats_out = rearrange(floats, 'h w -> w h')
    assert floats_out.dtype == np.float32, "Should preserve float dtype"

    # Boolean input
    flags = np.random.choice([True, False], size=(3, 4))
    flags_out = rearrange(flags, 'h w -> w h')
    assert flags_out.dtype == bool, "Should preserve boolean dtype"

def test_zero_dim_arrays():
    """An ellipsis-only pattern must accept an empty 1-D array and a 0-d array."""
    # Test empty array
    x = np.array([])
    try:
        result = rearrange(x, '... -> ...')
    except Exception as e:
        # Only the call itself is guarded: a failing shape assertion below
        # must not be re-reported as "got error".  An explicit raise is used
        # because ``assert False`` is removed when Python runs with -O.
        raise AssertionError(f"Should handle empty arrays, got error: {str(e)}") from e
    assert result.shape == x.shape, "Shape mismatch for empty array"

    # Test zero-dimensional array
    x = np.array(5)
    result = rearrange(x, '... -> ...')
    assert result.shape == (), "Shape mismatch for scalar array"


### **Running all tests**

In [10]:
def run_all_tests():
    """Call every module-level callable named ``test_*`` and print one status line per test.

    An ``AssertionError`` is reported as a failure, any other exception as
    an unexpected error; neither stops the remaining tests.
    """
    # Discover the tests in the module namespace, in definition order.
    discovered = []
    for name, candidate in globals().items():
        if name.startswith('test_') and callable(candidate):
            discovered.append(candidate)

    for case in discovered:
        try:
            case()
        except AssertionError as failure:
            print(f"❌ {case.__name__} failed: {str(failure)}")
        except Exception as error:
            print(f"❌ {case.__name__} failed with unexpected error: {str(error)}")
        else:
            print(f"✅ {case.__name__} passed")

# Run the whole suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_all_tests()

✅ test_basic_transpose passed
✅ test_basic_transposition_with_values passed
✅ test_identity_operation passed
✅ test_axis_splitting passed
✅ test_axis_merging passed
✅ test_nested_splitting passed
✅ test_complex_merge_split passed
✅ test_implicit_anonymous_dims passed
✅ test_basic_ellipsis passed
✅ test_leading_ellipsis passed
✅ test_ellipsis_edge_cases passed
✅ test_valid_transpose_with_ellipsis passed
✅ test_parametric_ellipsis passed
✅ test_repeating passed
✅ test_implicit_repeat passed
✅ test_singleton_dimensions passed
✅ test_advanced_reshaping passed
✅ test_performance_sensitive_operations passed
✅ test_basic_errors passed
✅ test_error_insufficient_dims passed
✅ test_error_multiple_unknowns passed
✅ test_large_dimensions passed
✅ test_different_dtypes passed
✅ test_zero_dim_arrays passed
