<a href="https://colab.research.google.com/github/Ayushichadha/Einops-implementation-/blob/main/Sarvam_Assignment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import numpy as np
import re
from functools import reduce
import operator
from typing import List, Tuple, Dict, Union, Any, Optional

def rearrange(tensor: np.ndarray, pattern: str, debug: bool = False, **axes_lengths) -> np.ndarray:
    """
    Rearranges the tensor according to an Einstein notation–inspired pattern.

    The input side is first expanded into "elementary" axes: groups are split,
    and '...' stands for however many axes the explicit tokens leave over.
    Those axes are then permuted into the order in which they appear on the
    output side. New axes are inserted and repeated. Finally, the axes inside
    each output group are merged with a single reshape. Permuting *before*
    merging is what makes patterns such as "a b c -> c (a b)" or
    "h w -> (w h)" actually move the data instead of only relabelling the
    shape.

    Supports:
      • Transposition (e.g., "h w -> w h", "... h w -> ... w h")
      • Splitting axes (e.g., "(h w) c -> h w c", with h provided)
      • Merging axes in any order (e.g., "a b c -> (a b) c", "a b c -> c (b a)")
      • Repeating axes (e.g., "a 1 c -> a b c" or "h w -> h (w b)", with b provided)
      • Ellipsis for batch dimensions (e.g., "... h w -> ... (h w)")

    Input axes that are absent from the output are only allowed when they have
    size 1 (they are squeezed away). '1' on either side denotes an anonymous
    unit axis.

    Parameters:
      tensor (np.ndarray): The input tensor (anything accepted by np.asarray).
      pattern (str): The rearrangement pattern string.
      debug (bool): If True, print debug information.
      **axes_lengths: Additional dimensions needed for splitting or repeating.

    Returns:
      np.ndarray: The rearranged tensor. It may be a view of the input when no
      axis had to be repeated.

    Raises:
      ValueError: If the pattern format is invalid or dimensions don't match.
      IndexError: If the tensor doesn't have enough dimensions for the pattern.
    """
    tensor = np.asarray(tensor)
    if '->' not in pattern:
        raise ValueError("Pattern must contain '->' to separate input and output axes")
    if pattern.count('->') != 1:
        raise ValueError("Pattern must contain exactly one '->'")

    input_pattern, output_pattern = map(str.strip, pattern.split('->'))
    input_tokens = parse_tokens(input_pattern)
    output_tokens = parse_tokens(output_pattern)

    # A named axis may appear at most once per side. '1' is an anonymous unit
    # axis and may repeat freely. The output is checked first, as before.
    for side, side_tokens in (("output", output_tokens), ("input", input_tokens)):
        seen = set()
        for token in flatten_list(side_tokens):
            if token == '1':
                continue
            if token in seen:
                raise ValueError(f"Duplicate token '{token}' found in {side} pattern")
            seen.add(token)

    if any(isinstance(tok, list) and '...' in tok for tok in input_tokens):
        raise ValueError("Ellipsis '...' is not supported inside a group in the input pattern")

    # Only a top-level '...' counts; each other token (plain or group) owns one axis.
    if '...' in input_tokens:
        explicit_dims = sum(1 for tok in input_tokens if tok != '...')
        ellipsis_dims = tensor.ndim - explicit_dims
        if ellipsis_dims < 0:
            raise IndexError(f"Not enough dimensions in tensor shape {list(tensor.shape)} for pattern {input_tokens}")
    else:
        ellipsis_dims = 0

    # After this call every input token (and group member) owns exactly one axis.
    tensor, input_mapping = process_input_groups(tensor, input_tokens, axes_lengths, ellipsis_dims)

    # Give every elementary input axis a unique id. Named axes use their
    # string; ellipsis axes and anonymous '1' axes get tuple ids, which can
    # never collide with a string axis name.
    ellipsis_ids = [('...', k) for k in range(ellipsis_dims)]
    in_ids: List[Any] = []
    for tok in input_tokens:
        if tok == '...':
            in_ids.extend(ellipsis_ids)
            continue
        for name in (tok if isinstance(tok, list) else [tok]):
            in_ids.append(('1', len(in_ids)) if name == '1' else name)
    sizes = dict(zip(in_ids, tensor.shape))

    # Axis sizes fixed by the pattern (literals) or by the caller must agree with the data.
    for axis_id, size in sizes.items():
        if not isinstance(axis_id, str):
            if axis_id[0] == '1' and size != 1:
                raise ValueError(f"Axis '1' in the input pattern has size {size}, expected 1")
            continue
        expected = int(axis_id) if re.fullmatch(r'\d+', axis_id) else axes_lengths.get(axis_id)
        if expected is not None and size != expected:
            raise ValueError(f"Axis '{axis_id}' has size {size}, but {expected} was expected")

    # Resolve the output side into groups of elementary ids. Each entry of
    # out_groups becomes one output dimension. New (repeated) axes get their
    # size recorded in new_sizes.
    new_sizes: Dict[Any, int] = {}
    out_groups: List[List[Any]] = []
    for tok in output_tokens:
        members: List[Any] = []
        for name in (tok if isinstance(tok, list) else [tok]):
            if name == '...':
                members.extend(ellipsis_ids)
            elif name == '1':
                unit_id = ('1', 'new', len(new_sizes))
                new_sizes[unit_id] = 1
                members.append(unit_id)
            elif name in sizes:
                members.append(name)
            elif re.fullmatch(r'\d+', name):
                new_sizes[name] = int(name)
                members.append(name)
            elif name in axes_lengths:
                new_sizes[name] = int(axes_lengths[name])
                members.append(name)
            else:
                raise ValueError(f"Unknown token '{name}' in output pattern")
        if tok == '...':
            # A top-level ellipsis keeps each batch axis as its own output dimension.
            out_groups.extend([axis_id] for axis_id in members)
        else:
            out_groups.append(members)
    out_ids = [axis_id for group in out_groups for axis_id in group]

    kept = set(out_ids)
    for axis_id, size in sizes.items():
        if axis_id not in kept and size != 1:
            label = axis_id if isinstance(axis_id, str) else axis_id[0]
            raise ValueError(f"Axis '{label}' of size {size} is missing from the output pattern")

    # 1. Squeeze away unit input axes that the output does not use.
    drop = tuple(pos for pos, axis_id in enumerate(in_ids) if axis_id not in kept)
    if drop:
        tensor = np.squeeze(tensor, axis=drop)
    remaining = [axis_id for axis_id in in_ids if axis_id in kept]

    # 2. Permute the surviving axes into output order.
    position = {axis_id: pos for pos, axis_id in enumerate(remaining)}
    perm = [position[axis_id] for axis_id in out_ids if axis_id in sizes]
    tensor = np.transpose(tensor, perm)

    # 3. Insert new axes at their output positions (ascending, so earlier
    #    positions are already final) and repeat them to the requested length.
    for pos, axis_id in enumerate(out_ids):
        if axis_id in new_sizes:
            tensor = np.expand_dims(tensor, pos)
            if new_sizes[axis_id] != 1:
                tensor = np.repeat(tensor, new_sizes[axis_id], axis=pos)

    # 4. Merge the members of each output group with one reshape.
    elementary = dict(zip(out_ids, tensor.shape))
    output_shape = tuple(
        reduce(operator.mul, (elementary[axis_id] for axis_id in group), 1)
        for group in out_groups
    )

    if debug:
        print(f"[Debug] Input pattern tokens: {input_tokens}")
        print(f"[Debug] Output pattern tokens: {output_tokens}")
        print(f"[Debug] Inferred input mapping: {input_mapping}")
        print(f"[Debug] Axis permutation: {perm}")
        print(f"[Debug] Reshaping to: {output_shape}")

    return tensor.reshape(output_shape)

# ------------------- Helper Functions -------------------

def parse_tokens(pattern: str) -> List[Union[str, List[str]]]:
    """
    Parse a pattern string into tokens, handling groups in parentheses.

    Tokens are separated by whitespace. A parenthesized group becomes a nested
    list of its member tokens (parsed by ``parse_group``). An opening '(' also
    ends any token that directly precedes it.

    Args:
        pattern (str): The pattern string to parse.

    Returns:
        List[Union[str, List[str]]]: A list of tokens, where each token is either a string or a list of tokens (for grouped dimensions).

    Raises:
        ValueError: If there are mismatched parentheses in the pattern, including
                    a ')' that closes no open group.
    """
    tokens: List[Union[str, List[str]]] = []
    token = ""
    i = 0
    while i < len(pattern):
        c = pattern[i]
        if c == '(':
            if token:
                tokens.append(token)
                token = ""
            group, j = parse_group(pattern, i)
            tokens.append(group)
            i = j
        elif c == ')':
            # Every legitimate ')' is consumed by parse_group, so one seen here is unmatched.
            raise ValueError("Mismatched parentheses in pattern")
        elif c.isspace():
            if token:
                tokens.append(token)
                token = ""
            i += 1
        else:
            token += c
            i += 1
    if token:
        tokens.append(token)
    return tokens

def parse_group(pattern: str, start: int) -> Tuple[List[str], int]:
    """
    Parse a group enclosed in parentheses within a pattern string.

    Members are separated by whitespace. Nested groups are not supported.

    Args:
        pattern (str): The full pattern string.
        start (int): The starting index of the opening parenthesis.

    Returns:
        Tuple[List[str], int]: A tuple containing (tokens_list, end_index), where tokens_list is the list of
                               parsed tokens inside the parentheses and end_index is the index after the closing parenthesis.

    Raises:
        ValueError: If ``pattern[start]`` is not '(', the group is never closed
                    (mismatched parentheses), or the group contains a nested '('.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if pattern[start] != '(':
        raise ValueError(f"Expected '(' at index {start} of pattern {pattern!r}")
    tokens: List[str] = []
    token = ""
    i = start + 1
    while i < len(pattern):
        c = pattern[i]
        if c == ')':
            if token:
                tokens.append(token)
            return tokens, i + 1
        elif c == '(':
            # Otherwise '(' would be silently folded into a member name like "(a".
            raise ValueError("Nested parentheses are not supported in pattern")
        elif c.isspace():
            if token:
                tokens.append(token)
                token = ""
            i += 1
        else:
            token += c
            i += 1
    raise ValueError("Mismatched parentheses in pattern")

def flatten_list(tokens: List[Any]) -> List[str]:
    """
    Flatten arbitrarily nested token lists into a single flat list.

    Only ``list`` instances are descended into. Every other item (strings,
    tuples, ...) is emitted as a leaf, in left-to-right order.

    Args:
        tokens (List[Any]): Tokens, possibly containing nested lists.

    Returns:
        List[str]: The leaves of ``tokens`` in their original order.
    """
    return [
        leaf
        for item in tokens
        for leaf in (flatten_list(item) if isinstance(item, list) else (item,))
    ]

def process_input_groups(tensor: np.ndarray,
                         tokens: List[Union[str, List[str]]],
                         axes_lengths: Dict[str, int],
                         ellipsis_dims: int) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Split grouped input axes so that every input token owns exactly one axis.

    Walks ``tokens`` left to right against ``tensor.shape``:
      • ``'...'`` consumes ``ellipsis_dims`` axes unchanged;
      • a group such as ``['h', 'w']`` consumes one axis and splits it into one
        axis per member. At most one member may have an unknown size; that
        size is inferred;
      • a plain token consumes one axis unchanged. If the token is an integer
        literal or appears in ``axes_lengths``, the axis must have that size.

    Args:
        tensor (np.ndarray): The input tensor.
        tokens (List[Union[str, List[str]]]): The parsed input pattern tokens.
        axes_lengths (Dict[str, int]): Dictionary of known dimension sizes.
        ellipsis_dims (int): Number of dimensions represented by ellipsis.

    Returns:
        Tuple[np.ndarray, Dict[str, int]]: ``(reshaped_tensor, mapping)``, where
        ``mapping`` maps each token to its size. ``mapping['...']`` (if present)
        is the list of ellipsis axis sizes.

    Raises:
        IndexError: If the tensor doesn't have enough dimensions for the pattern.
        ValueError: If a group cannot be split, an axis size contradicts an integer
                    literal or ``axes_lengths``, or not all dimensions are used.
    """
    shape = list(tensor.shape)
    if ellipsis_dims < 0:
        raise IndexError(f"Not enough dimensions in tensor shape {shape} for pattern {tokens}")

    new_shape: List[int] = []
    axis_index = 0
    mapping: Dict[str, Any] = {}

    for token in tokens:
        if token == '...':
            ellipsis_end = axis_index + ellipsis_dims
            # Slicing past the end would silently truncate the ellipsis.
            if ellipsis_end > len(shape):
                raise IndexError(f"Not enough dimensions in tensor shape {shape} for pattern {tokens}")
            new_shape.extend(shape[axis_index:ellipsis_end])
            mapping['...'] = shape[axis_index:ellipsis_end]
            axis_index = ellipsis_end
        elif isinstance(token, list):
            if axis_index >= len(shape):
                raise IndexError(f"Not enough dimensions in tensor shape {shape} for pattern {tokens}")
            total = shape[axis_index]
            dims: List[Optional[int]] = []

            for subtoken in token:
                if subtoken in axes_lengths:
                    dims.append(axes_lengths[subtoken])
                    mapping[subtoken] = axes_lengths[subtoken]
                elif re.match(r'^\d+$', subtoken):
                    dims.append(int(subtoken))
                    mapping[subtoken] = int(subtoken)
                else:
                    dims.append(None)

            unknown = dims.count(None)
            if unknown > 1:
                raise ValueError(f"Multiple unspecified dimensions in group {token}")
            if unknown == 1:
                specified = reduce(operator.mul, [d for d in dims if d is not None], 1)
                # A zero product would divide by zero and leaves the unknown size ambiguous.
                if specified == 0 or total % specified != 0:
                    raise ValueError(f"Cannot split dimension {total} with specified product {specified}")
                inferred = total // specified
                dims = [inferred if d is None else d for d in dims]
                for subtoken, size in zip(token, dims):
                    if (subtoken not in mapping and subtoken not in axes_lengths
                            and not re.match(r'^\d+$', subtoken)):
                        mapping[subtoken] = size

            if reduce(operator.mul, dims, 1) != total:
                raise ValueError(f"Product of dimensions {dims} does not match size {total}")

            new_shape.extend(dims)
            axis_index += 1
        else:
            if axis_index >= len(shape):
                raise IndexError(f"Not enough dimensions in tensor shape {shape} for pattern {tokens}")
            size = shape[axis_index]
            if re.match(r'^\d+$', token):
                expected = int(token)
            elif token in axes_lengths:
                expected = axes_lengths[token]
            else:
                expected = size
            # Without this check the mapping would record a size the data does not have.
            if expected != size:
                raise ValueError(f"Axis '{token}' has size {size}, but {expected} was expected")
            new_shape.append(size)
            mapping[token] = expected
            axis_index += 1

    if axis_index != len(shape):
        raise ValueError(f"Not all dimensions in {shape} were used in the pattern {tokens}")

    return tensor.reshape(new_shape), mapping

def compute_output_shape(tokens: List[Union[str, List[str]]],
                         mapping: Dict[str, Any],
                         axes_lengths: Dict[str, int],
                         ellipsis_dims: int) -> Tuple[int, ...]:
    """
    Compute the output shape based on output tokens and dimension mappings.

    Size rules:
      • a top-level '...' expands to the ellipsis sizes from ``mapping['...']``,
        or to ``ellipsis_dims`` ones if absent;
      • a group's size is the product of its members. A '...' member
        contributes the product of the ellipsis sizes, and a member known only
        via ``axes_lengths`` contributes its full length;
      • a plain token known only via ``axes_lengths`` gets a placeholder size of
        1 (presumably to be expanded later by ``process_repeating_axes``).

    Args:
        tokens (List[Union[str, List[str]]]): The parsed output pattern tokens.
        mapping (Dict[str, Any]): Dictionary mapping dimension names to their sizes.
        axes_lengths (Dict[str, int]): Dictionary of known dimension sizes.
        ellipsis_dims (int): Number of dimensions represented by ellipsis.

    Returns:
        Tuple[int, ...]: The output shape as a tuple of integers.

    Raises:
        ValueError: If an unknown token is found in the output pattern.
    """
    ellipsis_shape = list(mapping.get('...', [1] * ellipsis_dims))
    shape: List[int] = []
    for token in tokens:
        if token == '...':
            shape.extend(ellipsis_shape)
        elif isinstance(token, list):
            prod = 1
            for subtoken in token:
                if subtoken == '...':
                    # mapping['...'] is a list; multiplying an int by it would repeat the list.
                    prod *= reduce(operator.mul, ellipsis_shape, 1)
                elif re.match(r'^\d+$', subtoken):
                    prod *= int(subtoken)
                elif subtoken in mapping:
                    prod *= mapping[subtoken]
                elif subtoken in axes_lengths:
                    prod *= axes_lengths[subtoken]
                else:
                    raise ValueError(f"Unknown token '{subtoken}' in output pattern")
            shape.append(prod)
        else:
            if re.match(r'^\d+$', token):
                shape.append(int(token))
            elif token in mapping:
                shape.append(mapping[token])
            elif token in axes_lengths:
                shape.append(1)
            else:
                raise ValueError(f"Unknown token '{token}' in output pattern")
    return tuple(shape)

def process_repeating_axes(tensor: np.ndarray,
                          input_flat: List[str],
                          output_flat: List[str],
                          axes_lengths: Dict[str, int],
                          ellipsis_dims: Optional[int] = None) -> np.ndarray:
    """
    Process repeating axes in the output pattern.

    Tokens that appear in ``output_flat`` but not in ``input_flat``, and whose
    length is given in ``axes_lengths``, are repeated along their axis with
    ``np.repeat`` (the axis is expected to have size 1 beforehand).

    Axis positions assume one tensor axis per non-ellipsis output token, plus
    ``ellipsis_dims`` axes wherever '...' occurs. Outputs with merged groups
    are therefore not supported.

    Args:
        tensor (np.ndarray): The input tensor after reshaping.
        input_flat (List[str]): Flattened list of input tokens.
        output_flat (List[str]): Flattened list of output tokens.
        axes_lengths (Dict[str, int]): Dictionary of known dimension sizes.
        ellipsis_dims (Optional[int]): Number of axes that '...' stands for in
            the output. When None, it is inferred as the tensor's extra axes
            beyond the explicit output tokens (0 if '...' is absent).

    Returns:
        np.ndarray: The tensor with repeated dimensions as needed.
    """
    present = {token for token in input_flat if token != '...'}
    explicit_count = sum(1 for token in output_flat if token != '...')
    if ellipsis_dims is None:
        ellipsis_dims = max(tensor.ndim - explicit_count, 0) if '...' in output_flat else 0

    axis = 0
    for token in output_flat:
        if token == '...':
            # Axes after the ellipsis are shifted by the dimensions it expands to.
            axis += ellipsis_dims
            continue
        if token not in present and token in axes_lengths:
            tensor = np.repeat(tensor, axes_lengths[token], axis=axis)
        axis += 1
    return tensor

In [10]:
# ------------------- Unit Tests -------------------

if __name__ == "__main__":

    def _expect_error(exc_type, func, *args, **kwargs):
        """Call func(*args, **kwargs) and fail loudly unless it raises exc_type."""
        try:
            func(*args, **kwargs)
        except exc_type as e:
            print("Caught expected error:", e)
        else:
            raise AssertionError(f"Expected {exc_type.__name__}, but the call succeeded (args={args[1:]!r})")

    rng = np.random.default_rng(0)

    # Positive Tests: check values as well as shapes.
    x = rng.random((3, 4))
    y = rearrange(x, 'h w -> w h')
    assert y.shape == (4, 3) and np.array_equal(y, x.T)

    x = rng.random((12, 10))
    y = rearrange(x, '(h w) c -> h w c', h=3)
    assert y.shape == (3, 4, 10) and np.array_equal(y, x.reshape(3, 4, 10))

    x = rng.random((3, 4, 5))
    y = rearrange(x, 'a b c -> (a b) c')
    assert y.shape == (12, 5) and np.array_equal(y, x.reshape(12, 5))

    x = rng.random((3, 1, 5))
    y = rearrange(x, 'a 1 c -> a b c', b=4)
    assert y.shape == (3, 4, 5) and np.array_equal(y, np.repeat(x, 4, axis=1))

    x = rng.random((2, 3, 4, 5))
    y = rearrange(x, '... h w -> ... (h w)')
    assert y.shape == (2, 3, 20) and np.array_equal(y, x.reshape(2, 3, 20))

    print("✅ All positive tests passed.")

    # Negative Tests: each must raise, otherwise the run fails.
    _expect_error(ValueError, rearrange, rng.random((12, 10)), '(h w) c -> h w c')
    _expect_error(ValueError, rearrange, rng.random((3, 4)), 'a b -> (a b c)')
    _expect_error(ValueError, rearrange, rng.random((2, 2)), 'a b -> a a b')
    _expect_error(ValueError, rearrange, rng.random((3, 4)), 'h w')

    print("✅ All negative tests passed.")

✅ All positive tests passed.
Caught expected error: Multiple unspecified dimensions in group ['h', 'w']
Caught expected error: Unknown token 'c' in output pattern
Caught expected error: Duplicate token 'a' found in output pattern
✅ All negative tests passed.
