<a href="https://colab.research.google.com/github/atharva753/SarvamLLM_Assignment-2/blob/main/Optimised_einops.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive into the Colab runtime at /content/drive.
# NOTE(review): nothing in the visible code reads from the mounted drive —
# confirm whether this cell is still needed.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
import numpy as np
import re
from typing import Dict, List, Tuple, Optional, Union, Any, Set
from functools import lru_cache


@lru_cache(maxsize=128)
def parse_pattern(pattern: str) -> Tuple[str, str]:
    """
    Split an einops-style pattern into its left-hand and right-hand sides.

    Results are memoized, so a pattern that was already seen is not split again.

    Args:
        pattern: The full pattern string (e.g., 'b c h w -> b h w c')

    Returns:
        A tuple of (input_pattern, output_pattern), each stripped of
        surrounding whitespace

    Raises:
        ValueError: If the pattern has no '->' or has more than one
    """
    if '->' not in pattern:
        raise ValueError("Pattern must contain '->' to separate input and output dimensions")

    # Cut at the first arrow; a second arrow can then only live in the tail.
    lhs, _, rhs = pattern.partition('->')
    if '->' in rhs:
        raise ValueError("Pattern must contain exactly one '->'")

    return lhs.strip(), rhs.strip()


# Regex patterns for more efficient parsing
# COMPOSITE_AXIS_PATTERN captures the text between the parentheses of a
# composite axis such as '(h w)'; only word characters and whitespace are
# allowed inside.
COMPOSITE_AXIS_PATTERN = re.compile(r'\(([\w\s]+)\)')
# AXIS_PATTERN tokenizes one side of a pattern into parenthesized groups,
# bare axis names, and the '...' ellipsis.
AXIS_PATTERN = re.compile(r'(\(\w+(?:\s+\w+)*\)|\w+|\.\.\.)')


@lru_cache(maxsize=128)
def parse_axis(axis: str) -> Tuple[str, bool, List[str]]:
    """
    Parse a single axis specification, handling parentheses for merging/splitting.

    Results are memoized. The returned components list is the cached object
    itself, so callers must treat it as read-only.

    Args:
        axis: The axis specification (e.g., '(h w)', 'b', etc.)

    Returns:
        A tuple of (axis_name, is_composite, components). For a simple axis
        the components list holds just the axis itself.

    Raises:
        ValueError: If the axis is wrapped in parentheses but is not a single
            well-formed group of names
    """
    if not (axis.startswith('(') and axis.endswith(')')):
        return axis, False, [axis]

    # fullmatch, not match: a string like '(a b) c (d)' also starts with '('
    # and ends with ')', and a prefix match would silently keep only 'a b'.
    match = COMPOSITE_AXIS_PATTERN.fullmatch(axis)
    if match is None:
        raise ValueError(f"Invalid composite axis format: {axis}")

    # str.split() with no argument already ignores surrounding whitespace.
    return axis, True, match.group(1).split()


@lru_cache(maxsize=64)
def parse_axes(pattern: str) -> List[Tuple[str, bool, List[str]]]:
    """
    Parse all axes on one side of a pattern.

    The side is tokenized with AXIS_PATTERN, which recognizes parenthesized
    groups, bare axis names, and the '...' ellipsis. Results are memoized and
    the returned list is the cached object itself, so callers must treat it
    as read-only.

    Args:
        pattern: A pattern string (either input or output part)

    Returns:
        A list of parsed axes in pattern order; the ellipsis, if present,
        appears as ('...', False, ['...'])

    Raises:
        ValueError: If the pattern contains more than one ellipsis, or
            contains characters that do not belong to any axis (for example
            an unbalanced parenthesis or a comma)
    """
    # AXIS_PATTERN already yields '...' as its own token, so the ellipsis
    # needs no separate splitting step.
    tokens = AXIS_PATTERN.findall(pattern)

    if tokens.count('...') > 1:
        raise ValueError("Pattern can contain at most one ellipsis (...)")

    # findall skips anything it cannot match; without this check a typo such
    # as '(h w' would be accepted and parsed as two unrelated simple axes.
    leftover = AXIS_PATTERN.sub('', pattern).strip()
    if leftover:
        raise ValueError(
            f"Pattern '{pattern}' contains characters that are not part of any axis: {leftover!r}"
        )

    return [parse_axis(token) for token in tokens]


def _get_dim_sizes(
    input_axes: List[Tuple[str, bool, List[str]]],
    tensor_shape: Tuple[int, ...],
    axes_lengths: Dict[str, int]
) -> Tuple[Dict[str, int], List[int], int]:
    """
    Calculate dimension sizes for all axes in one pass.

    Args:
        input_axes: Parsed input axes
        tensor_shape: Shape of the input tensor
        axes_lengths: Dictionary mapping axis names to their lengths

    Returns:
        Tuple of (dimension dict, ellipsis_dims, ellipsis_len)
    """
    dims_dict = {**axes_lengths}  # Start with provided axes_lengths
    tensor_dim_idx = 0
    ellipsis_dims = []

    for axis, is_composite, components in input_axes:
        if axis == '...':
            # Calculate how many dimensions are covered by ellipsis
            remaining_explicit_dims = sum(1 for ax, _, _ in input_axes if ax != '...')
            ellipsis_size = len(tensor_shape) - remaining_explicit_dims

            # Store ellipsis dimensions
            for j in range(ellipsis_size):
                if tensor_dim_idx < len(tensor_shape):
                    ellipsis_dims.append(tensor_shape[tensor_dim_idx])
                    tensor_dim_idx += 1
        elif is_composite:
            # For composite axes, get dimension and calculate component sizes
            if tensor_dim_idx < len(tensor_shape):
                total_size = tensor_shape[tensor_dim_idx]
                tensor_dim_idx += 1

                # Calculate sizes for components
                unknown_components = [c for c in components if c not in dims_dict]

                if len(unknown_components) > 1:
                    raise ValueError(f"Multiple unknown components in axis {axis}")

                if unknown_components:
                    # Calculate size for the single unknown component
                    known_product = 1
                    for c in components:
                        if c in dims_dict:
                            known_product *= dims_dict[c]

                    if known_product == 0:
                        raise ValueError(f"Cannot infer size for {unknown_components[0]}, product of known sizes is 0")

                    if total_size % known_product != 0:
                        raise ValueError(f"Total size {total_size} not divisible by known components product {known_product}")

                    dims_dict[unknown_components[0]] = total_size // known_product
                else:
                    # Verify all components have correct sizes
                    component_product = 1
                    for c in components:
                        component_product *= dims_dict[c]

                    if component_product != total_size:
                        raise ValueError(
                            f"Product of component sizes {component_product} doesn't match axis size {total_size}"
                        )
        else:
            # For simple axes, use the tensor dimension size
            if tensor_dim_idx < len(tensor_shape):
                dims_dict[axis] = tensor_shape[tensor_dim_idx]
                tensor_dim_idx += 1

    return dims_dict, ellipsis_dims, len(ellipsis_dims)


def _compute_transpose_and_shapes(
    input_axes: List[Tuple[str, bool, List[str]]],
    output_axes: List[Tuple[str, bool, List[str]]],
    tensor_shape: Tuple[int, ...],
    dims_dict: Dict[str, int],
    ellipsis_dims: List[int],
    ellipsis_len: int
) -> Tuple[List[int], List[int], List[int]]:
    """
    Compute intermediate reshape, transpose, and final reshape in one pass.

    Returns:
        Tuple of (initial_shape, transpose_indices, final_shape)
    """
    # First calculate the initial reshape (splitting composite axes)
    initial_shape = []
    axis_positions = {}
    pos = 0

    for axis, is_composite, components in input_axes:
        if axis == '...':
            # Add ellipsis dimensions
            for i, dim in enumerate(ellipsis_dims):
                initial_shape.append(dim)
                axis_positions[f"..._{i}"] = pos
                pos += 1
        elif is_composite:
            # Add component dimensions for composite axes
            for component in components:
                if component not in dims_dict:
                    raise ValueError(f"Size for component '{component}' not found")
                initial_shape.append(dims_dict[component])
                axis_positions[component] = pos
                pos += 1
        else:
            # Add dimension for simple axis
            initial_shape.append(dims_dict[axis])
            axis_positions[axis] = pos
            pos += 1

    # Now calculate the transposition order
    transpose_indices = []
    ellipsis_positions = [axis_positions[f"..._{i}"] for i in range(ellipsis_len)]

    for axis, is_composite, components in output_axes:
        if axis == '...':
            # Add all ellipsis positions in order
            transpose_indices.extend(ellipsis_positions)
        elif is_composite:
            # Add positions for all components
            for component in components:
                if component not in axis_positions:
                    raise ValueError(f"Component '{component}' not found in input axes")
                transpose_indices.append(axis_positions[component])
        else:
            # Add position for simple axis
            if axis not in axis_positions:
                raise ValueError(f"Axis '{axis}' not found in input axes")
            transpose_indices.append(axis_positions[axis])

    # Calculate final shape (merging composite axes in output)
    final_shape = []
    i = 0

    for axis, is_composite, components in output_axes:
        if axis == '...':
            # Add all ellipsis dimensions
            final_shape.extend(ellipsis_dims)
            i += ellipsis_len
        elif is_composite:
            # Calculate product of component dimensions
            product = 1
            for _ in components:
                product *= initial_shape[transpose_indices[i]]
                i += 1
            final_shape.append(product)
        else:
            # Add dimension for simple axis
            final_shape.append(initial_shape[transpose_indices[i]])
            i += 1

    return initial_shape, transpose_indices, final_shape


def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    Reorder, split, and merge the axes of a tensor as described by a pattern,
    in the spirit of einops.rearrange.

    The work is done as reshape (split) -> transpose -> reshape (merge), with
    shortcuts when one or more of those steps would be a no-op.

    Args:
        tensor: Input tensor (numpy array)
        pattern: Rearrangement pattern (e.g., 'b c h w -> b h w c')
        **axes_lengths: Named sizes for axes (e.g., h=32, w=32)

    Returns:
        Rearranged tensor; the input object itself is returned when the
        pattern changes nothing

    Raises:
        ValueError: If an output component appears neither in the input
            pattern nor in axes_lengths (parsing and shape helpers may raise
            as well)
    """
    # Parsing is memoized by the helpers, so repeated patterns are cheap
    lhs, rhs = parse_pattern(pattern)
    src_axes = parse_axes(lhs)
    dst_axes = parse_axes(rhs)

    # Resolve the size of every named axis from the tensor shape
    sizes, ell_dims, ell_count = _get_dim_sizes(src_axes, tensor.shape, axes_lengths)

    # Every name used on the right must be known from the left or from kwargs
    known_names = {name for _, _, parts in src_axes for name in parts}
    for axis, _, parts in dst_axes:
        if axis == '...':
            continue
        for name in parts:
            if name == '...' or name in known_names or name in axes_lengths:
                continue
            raise ValueError(f"Output component '{name}' not found in input pattern or axes_lengths")

    split_shape, perm, merged_shape = _compute_transpose_and_shapes(
        src_axes, dst_axes, tensor.shape, sizes, ell_dims, ell_count
    )

    current_shape = tensor.shape
    identity_order = perm == list(range(len(perm)))

    # Shortcut 1: nothing moves and nothing is reshaped
    if current_shape == tuple(merged_shape) and identity_order:
        return tensor

    # Shortcut 2: a pure axis permutation, no split or merge on either side
    permuted_shape = [split_shape[p] for p in perm]
    if current_shape == tuple(split_shape) and merged_shape == permuted_shape:
        return np.transpose(tensor, perm)

    # General path: split axes, reorder them, then merge
    return np.transpose(tensor.reshape(split_shape), perm).reshape(merged_shape)


# Test functions with timing to show performance
import time

def time_test(func):
    """
    Decorator that prints how long each call to *func* took, in milliseconds.

    The wrapped function's return value is passed through unchanged. If the
    call raises, the exception propagates and no timing line is printed.

    Args:
        func: The callable to time

    Returns:
        A wrapper with the same name and docstring as *func*
    """
    # Local import: the module only imports lru_cache from functools.
    from functools import wraps

    @wraps(func)  # keep the test's __name__/__doc__ visible on the wrapper
    def wrapper(*args, **kwargs):
        # perf_counter is the clock intended for measuring short intervals;
        # time.time() is wall-clock time and can jump or have low resolution.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"{func.__name__} completed in {elapsed_ms:.2f} ms")
        return result
    return wrapper

@time_test
def test_transpose():
    """Swapping the two axes of a matrix must give its transpose."""
    rows, cols = 3, 4
    matrix = np.random.rand(rows, cols)

    swapped = rearrange(matrix, 'h w -> w h')

    assert swapped.shape == (cols, rows)
    assert np.allclose(swapped, np.transpose(matrix))
    print("Transpose test passed")

@time_test
def test_split_axis():
    """Splitting a leading axis of 12 into (3, 4) must match a plain reshape."""
    data = np.random.rand(12, 10)
    target_shape = (3, 4, 10)
    # Splitting without reordering is exactly what reshape does
    expected = data.reshape(target_shape)

    split = rearrange(data, '(h w) c -> h w c', h=3, w=4)

    assert split.shape == target_shape
    assert np.allclose(split, expected)
    print("Split axis test passed")

@time_test
def test_merge_axes():
    """Merging the first two axes must match a plain reshape."""
    data = np.random.rand(3, 4, 5)
    target_shape = (12, 5)
    # Merging adjacent axes in order is exactly what reshape does
    expected = data.reshape(target_shape)

    merged = rearrange(data, 'a b c -> (a b) c')

    assert merged.shape == target_shape
    assert np.allclose(merged, expected)
    print("Merge axes test passed")

@time_test
def test_ellipsis():
    """Leading axes covered by '...' must pass through while h and w merge."""
    data = np.random.rand(2, 3, 4, 5)
    target_shape = (2, 3, 20)
    # The leading axes stay put, so the result equals a plain reshape
    expected = data.reshape(target_shape)

    flattened = rearrange(data, '... h w -> ... (h w)')

    assert flattened.shape == target_shape
    assert np.allclose(flattened, expected)
    print("Ellipsis test passed")

@time_test
def test_complex_pattern():
    """Split, merge, and ellipsis combined: check both shape and contents."""
    x = np.random.rand(2, 3, 4, 5)
    result = rearrange(x, 'a b (c d) ... -> a (b c) ... d', d=1)
    assert result.shape == (2, 12, 5, 1)
    # Verify content as well as shape: split (c d) into 4 x 1, move d behind
    # the ellipsis axis, then merge b and c.
    expected = x.reshape(2, 3, 4, 1, 5).transpose(0, 1, 2, 4, 3).reshape(2, 12, 5, 1)
    # Rearranging only moves data, so the values must match exactly
    assert np.array_equal(result, expected)
    print("Complex pattern test passed")

@time_test
def test_repeated_pattern():
    """Test that a repeated pattern is served from the parsing cache."""
    x = np.random.rand(10, 32, 32, 3)
    pattern = 'b h w c -> b c h w'
    # First call - parsing happens (or is already cached from earlier use)
    result1 = rearrange(x, pattern)
    hits_before = parse_pattern.cache_info().hits
    # Second call - must be answered from the lru_cache on parse_pattern
    result2 = rearrange(x, pattern)
    assert parse_pattern.cache_info().hits > hits_before
    # Both calls must give the channels-first layout, not merely agree
    assert result1.shape == (10, 3, 32, 32)
    assert np.allclose(result1, np.transpose(x, (0, 3, 1, 2)))
    assert np.allclose(result1, result2)
    print("Repeated pattern test passed")

def run_all_tests():
    """Run every test once, in a fixed order, and report overall success."""
    suite = (
        test_transpose,
        test_split_axis,
        test_merge_axes,
        test_ellipsis,
        test_complex_pattern,
        test_repeated_pattern,
    )
    # A failing assertion propagates, so the final line prints only on success
    for case in suite:
        case()
    print("All tests passed!")

# Run the whole test suite when this code is executed directly.
if __name__ == "__main__":
    run_all_tests()

Transpose test passed
test_transpose completed in 20.46 ms
Split axis test passed
test_split_axis completed in 0.38 ms
Merge axes test passed
test_merge_axes completed in 0.39 ms
Ellipsis test passed
test_ellipsis completed in 0.24 ms
Complex pattern test passed
test_complex_pattern completed in 0.11 ms
Repeated pattern test passed
test_repeated_pattern completed in 1.95 ms
All tests passed!
