# **Scratch Implementation of Einops**

# **Importing dependencies**

In [1]:
import re
import numpy as np
from typing import Dict, List, Tuple, Optional

# **Pattern Parsing**

In [2]:
def parse_pattern(pattern: str) -> Tuple[List[str], List[str]]:
    """Parse a rearrange pattern into its input and output axis tokens.

    Each side of the ``->`` is tokenised into a list whose items are one of:
    a parenthesised group kept as a single string (e.g. ``'(h w)'``), the
    ellipsis ``'...'``, or a plain axis name.

    Args:
        pattern: Pattern string such as ``'b (h w) c -> b h w c'``.

    Returns:
        A ``(input_axes, output_axes)`` tuple of token lists.

    Raises:
        ValueError: If the pattern does not contain exactly one ``->``, or if
            parentheses on either side are unbalanced or nested.
    """
    if '->' not in pattern:
        raise ValueError(f"Pattern must contain '->', got '{pattern}'")

    sides = pattern.split('->')
    if len(sides) != 2:
        raise ValueError(
            f"Pattern must contain exactly one '->', got '{pattern}'"
        )
    input_part, output_part = (side.strip() for side in sides)

    def split_axes(s: str) -> List[str]:
        # The regex below only tokenises correctly when every '(' is closed
        # by a matching ')' and groups are not nested, so check that first
        # instead of silently dropping or mis-grouping axes.
        depth = 0
        for char in s:
            if char == '(':
                depth += 1
                if depth > 1:
                    raise ValueError(
                        f"Nested parentheses are not supported, got '{s}'"
                    )
            elif char == ')':
                depth -= 1
                if depth < 0:
                    raise ValueError(f"Unbalanced parentheses in '{s}'")
        if depth != 0:
            raise ValueError(f"Unbalanced parentheses in '{s}'")

        # Split into components while preserving parenthesised groups
        return re.findall(r'\(.*?\)|\.\.\.|\w+', s)

    input_axes = split_axes(input_part)
    output_axes = split_axes(output_part)

    return input_axes, output_axes

# **Validating the shapes**

In [10]:
def validate_shape(input_shape: Tuple[int, ...],
                  input_axes: List[str],
                  axes_lengths: Dict[str, int]) -> Dict[str, int]:
    """
    Validate the input shape against the pattern and axes lengths.

    Every token in ``input_axes`` other than ``'...'`` consumes exactly one
    dimension of ``input_shape``; ``'...'`` absorbs whatever dimensions are
    left over. For a parenthesised group such as ``'(h w)'`` the product of
    the inner lengths must equal the size of that dimension. At most one
    inner length per group may be omitted from ``axes_lengths``; it is then
    inferred by dividing the dimension size by the product of the others.

    Args:
        input_shape: Shape of the tensor being rearranged.
        input_axes: Tokens of the input side of the pattern.
        axes_lengths: Axis lengths supplied by the caller (not modified).

    Returns:
        A new dictionary of all axis lengths (including inferred ones).

    Raises:
        ValueError: If the number of dimensions does not match the pattern,
            a supplied length contradicts the shape, a group's size does not
            match its dimension, or a missing length cannot be inferred.
    """
    axis_lengths = axes_lengths.copy()
    shape_idx = 0

    for axis in input_axes:
        if axis == '...':
            # Handle ellipsis (batch dimensions): it takes every dimension
            # not claimed by one of the other tokens.
            remaining_dims = len(input_shape) - len(input_axes) + 1
            if remaining_dims < 0:
                raise ValueError("Not enough dimensions in input tensor for pattern")
            shape_idx += remaining_dims
            continue

        if shape_idx >= len(input_shape):
            raise ValueError("Input tensor has fewer dimensions than pattern")
        dim_size = input_shape[shape_idx]
        shape_idx += 1

        if '(' in axis:
            # Split axis like (h w)
            inner_axes = axis[1:-1].split()
            unknown_axes = [a for a in inner_axes if a not in axis_lengths]
            # Only a single missing length is uniquely determined; an
            # ellipsis inside a group has no length that could be inferred.
            if len(unknown_axes) > 1 or '...' in unknown_axes:
                raise ValueError(
                    f"Missing length for axis '{unknown_axes[0]}' in split pattern"
                )

            known_size = 1
            for inner_axis in inner_axes:
                if inner_axis in axis_lengths:
                    known_size *= axis_lengths[inner_axis]

            if unknown_axes:
                missing_axis = unknown_axes[0]
                if known_size <= 0 or dim_size % known_size != 0:
                    raise ValueError(
                        f"Cannot infer length of axis '{missing_axis}': dimension of size "
                        f"{dim_size} is not divisible by {known_size} (product of the "
                        f"known axes in {inner_axes})"
                    )
                axis_lengths[missing_axis] = dim_size // known_size
            elif dim_size != known_size:
                raise ValueError(
                    f"Split axis size mismatch: expected {known_size} (product of {inner_axes}), "
                    f"got {dim_size}"
                )
        elif axis in axis_lengths:
            # Simple axis with a caller-supplied length: it must agree
            if axis_lengths[axis] != dim_size:
                raise ValueError(
                    f"Length mismatch for axis '{axis}': "
                    f"expected {axis_lengths[axis]}, got {dim_size}"
                )
        else:
            # Simple axis: take its length from the tensor
            axis_lengths[axis] = dim_size

    if shape_idx != len(input_shape):
        raise ValueError("Input tensor has more dimensions than pattern")

    return axis_lengths

# **Rearrange Function**

In [30]:
def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths) -> np.ndarray:
    """
    Reorganize tensor according to the specified pattern.

    The pattern has the form ``'<input axes> -> <output axes>'``. On the
    input side a parenthesised group such as ``(h w)`` splits one dimension
    into several axes; on the output side it merges several axes into one
    dimension. ``...`` stands for all dimensions not named explicitly.

    Behaviour for axes that do not line up between the two sides:

    * An output axis that is absent from the input is a *new* axis. Its
      length must be given in ``axes_lengths`` and the data is repeated
      along it. Inside an output group an unknown name with no given length
      is treated as length 1.
    * An input axis that is absent from the output is dropped if its length
      is 1; otherwise a ``ValueError`` is raised.

    Args:
        tensor: Array to rearrange.
        pattern: Rearrangement pattern.
        **axes_lengths: Lengths of named axes (needed for split groups and
            for new axes).

    Returns:
        The rearranged array.

    Raises:
        ValueError: If the pattern is malformed, does not match the tensor's
            shape, repeats an axis on either side, contains more than one
            ``...`` on the input side, drops a non-unit input axis, or names
            an unknown output axis.
    """
    # Parse the pattern
    input_axes, output_axes = parse_pattern(pattern)

    # Validate input shape and get all axis lengths
    axis_lengths = validate_shape(tensor.shape, input_axes, axes_lengths)

    if input_axes.count('...') > 1:
        raise ValueError("Input pattern may contain at most one '...'")

    # Every input token except '...' consumes exactly one dimension, so the
    # ellipsis covers whatever is left. Its dimensions get the synthetic
    # names '...0', '...1', ... (these cannot clash with parsed axis names).
    if '...' in input_axes:
        ellipsis_ndim = len(tensor.shape) - len(input_axes) + 1
        ellipsis_names = [f'...{i}' for i in range(ellipsis_ndim)]
    else:
        ellipsis_names = []

    # Step 1: Work out the fully split shape and where each elementary axis
    # lives in it.
    split_shape = []
    axis_positions = {}

    def register(name: str, length: int) -> None:
        if name in axis_positions:
            raise ValueError(f"Axis '{name}' appears more than once in input pattern")
        axis_positions[name] = len(split_shape)
        split_shape.append(length)

    dim = 0  # index of the tensor dimension the current token consumes
    for axis in input_axes:
        if axis == '...':
            for name in ellipsis_names:
                register(name, tensor.shape[dim])
                dim += 1
        elif '(' in axis:
            for inner_axis in axis[1:-1].split():
                register(inner_axis, axis_lengths[inner_axis])
            dim += 1
        else:
            register(axis, axis_lengths[axis])
            dim += 1

    # Compare the shapes themselves, not just the number of dimensions: a
    # pattern like '() (a b)' changes the shape while keeping the rank.
    current_tensor = tensor
    if tuple(split_shape) != tuple(tensor.shape):
        current_tensor = np.reshape(current_tensor, split_shape)

    # Step 2: Walk the output pattern. ``flat_axes`` lists every elementary
    # output axis as (source position, length); the position is None for a
    # new axis. ``output_shape`` is the final shape after merging groups.
    flat_axes = []
    output_shape = []

    for axis in output_axes:
        if axis == '...':
            for name in ellipsis_names:
                position = axis_positions[name]
                flat_axes.append((position, split_shape[position]))
                output_shape.append(split_shape[position])
        elif '(' in axis:
            # Merge axes
            size = 1
            for inner_axis in axis[1:-1].split():
                if inner_axis in axis_positions:
                    position = axis_positions[inner_axis]
                    length = split_shape[position]
                else:
                    # New axis inside a group; without a given length it
                    # is treated as length 1.
                    position = None
                    length = axis_lengths.get(inner_axis, 1)
                flat_axes.append((position, length))
                size *= length
            output_shape.append(size)
        elif axis in axis_positions:
            # Existing axis
            position = axis_positions[axis]
            flat_axes.append((position, split_shape[position]))
            output_shape.append(split_shape[position])
        else:
            # New axis (repetition)
            if axis not in axis_lengths:
                raise ValueError(f"Unknown axis '{axis}' in output pattern")
            flat_axes.append((None, axis_lengths[axis]))
            output_shape.append(axis_lengths[axis])

    # Step 3: Build the permutation of the split axes.
    transpose_axes = [position for position, _ in flat_axes if position is not None]
    used_positions = set(transpose_axes)
    if len(used_positions) != len(transpose_axes):
        raise ValueError("An input axis appears more than once in the output pattern")

    for name, position in axis_positions.items():
        if position in used_positions:
            continue
        if split_shape[position] != 1:
            raise ValueError(
                f"Input axis '{name}' of length {split_shape[position]} "
                f"is missing from the output pattern"
            )
        # A unit axis does not affect element order, so park it at the end;
        # the reshape below removes it.
        transpose_axes.append(position)

    # Perform the transpose only if it actually reorders something
    if transpose_axes != list(range(len(transpose_axes))):
        current_tensor = np.transpose(current_tensor, transpose_axes)

    # Step 4: Insert new axes as length 1, repeat the data along them, then
    # merge groups into the final output shape.
    carried_shape = [1 if position is None else length for position, length in flat_axes]
    expanded_shape = [length for _, length in flat_axes]

    current_tensor = np.reshape(current_tensor, carried_shape)
    if carried_shape != expanded_shape:
        # broadcast_to yields a read-only view; copy so the caller gets an
        # ordinary writable array.
        current_tensor = np.broadcast_to(current_tensor, expanded_shape).copy()

    # Reshape to the final output shape
    return np.reshape(current_tensor, output_shape)

# **Testing the Implementation**

In [37]:
# Test transpose: swapping the two axes must agree with NumPy's own transpose
x = np.random.rand(3, 4)
result = rearrange(x, 'h w -> w h')
assert result.shape == (4, 3)
assert np.allclose(result, np.transpose(x))
print(result)

[[0.6490135  0.3274479  0.8457537 ]
 [0.28244014 0.8461175  0.51904594]
 [0.359951   0.66607217 0.8119557 ]
 [0.1013809  0.64236727 0.79222967]]


In [32]:
# Test split axis: '(h w)' on the input side unpacks one dimension into h and w
x = np.random.rand(12, 10)
result = rearrange(x, '(h w) c -> h w c', h=3, w=4)
assert result.shape == (3, 4, 10)
assert np.allclose(result, np.reshape(x, (3, 4, 10)))
print(result)

[[[0.86618109 0.71870061 0.11915996 0.60943997 0.07670577 0.24388805
   0.00765418 0.99089006 0.46813386 0.31119769]
  [0.63311909 0.41516128 0.93066064 0.98396487 0.04282848 0.88272992
   0.14556316 0.73334312 0.49904956 0.93595441]
  [0.31014712 0.57447275 0.65041448 0.41128217 0.29586901 0.43128662
   0.50279964 0.09390212 0.53314928 0.26923399]
  [0.48585885 0.14023704 0.392011   0.98901489 0.90668418 0.80087353
   0.93351851 0.71356172 0.91053733 0.76260652]]

 [[0.66888945 0.18455879 0.11608749 0.01211214 0.45403757 0.39398608
   0.38298733 0.86258221 0.58260407 0.92761327]
  [0.36163606 0.03115892 0.32533408 0.01032926 0.92362012 0.92768867
   0.85061328 0.74195318 0.56849735 0.53037461]
  [0.88034976 0.90679952 0.48642639 0.29329472 0.50742869 0.76666003
   0.88495202 0.16008531 0.38840925 0.87353935]
  [0.43083249 0.08728273 0.25248418 0.87272355 0.54386948 0.70051476
   0.29132395 0.17639452 0.71706701 0.28619366]]

 [[0.04825004 0.13809668 0.55202524 0.76461078 0.54999779 0.

In [33]:
# Test merge axes: '(a b)' on the output side folds a and b into one dimension
x = np.random.rand(3, 4, 5)
result = rearrange(x, 'a b c -> (a b) c')
assert result.shape == (12, 5)
assert np.allclose(result, np.reshape(x, (12, 5)))
print(result)

[[0.9148496  0.59127652 0.14826708 0.79155329 0.95299138]
 [0.35769754 0.06351252 0.36610184 0.89132907 0.1467685 ]
 [0.27039273 0.87821392 0.27431536 0.73372591 0.23527835]
 [0.53928156 0.68941155 0.11360783 0.97197337 0.94575449]
 [0.87573686 0.09715079 0.93164302 0.62845888 0.09547183]
 [0.44162956 0.09933135 0.83600656 0.01644651 0.06045645]
 [0.22793613 0.39439851 0.36695713 0.12326195 0.82492789]
 [0.15634991 0.80185122 0.83577359 0.51185084 0.04001497]
 [0.07617363 0.8900316  0.54838826 0.04371433 0.55442802]
 [0.39129073 0.13807535 0.861427   0.06868399 0.98151636]
 [0.07734323 0.10883025 0.93864894 0.38329702 0.1192459 ]
 [0.4361359  0.10640083 0.36834962 0.70085583 0.50325263]]


In [35]:
# Test batch dimensions: '...' carries the leading axes through unchanged
x = np.random.rand(2, 3, 4, 5)
result = rearrange(x, '... h w -> ... (h w)')
assert result.shape == (2, 3, 20)
assert np.allclose(result, np.reshape(x, (2, 3, 20)))
print(result)

[[[4.15572313e-01 2.38835348e-02 6.65907899e-01 3.46062246e-01
   8.81714724e-01 7.97118657e-01 7.43929803e-01 4.72589297e-01
   5.63246146e-01 7.97172286e-01 8.44393487e-01 6.97822364e-01
   7.13685387e-01 6.09572582e-01 2.30949353e-01 2.44469993e-01
   7.14987731e-01 8.74661384e-01 9.77104407e-01 7.86156157e-01]
  [5.14499340e-01 3.20879746e-01 2.44506102e-05 4.53126627e-01
   7.58683097e-01 5.32228349e-01 1.74705227e-01 9.81554110e-01
   4.93187105e-01 2.61297353e-01 8.25343895e-01 9.75495068e-01
   7.89686877e-01 8.65701137e-01 2.33733892e-01 2.46437919e-01
   8.19733305e-01 3.93362303e-01 3.82342268e-01 7.18621735e-02]
  [9.38267854e-01 4.77210244e-01 1.06793309e-01 1.36388199e-02
   8.73214908e-01 3.53244894e-01 8.04889755e-01 9.86367616e-01
   3.11351327e-01 2.66892192e-01 2.17720471e-01 5.00274945e-01
   8.82925050e-01 8.53211651e-01 1.03864628e-01 4.71305156e-01
   6.31099250e-01 2.08183719e-01 3.71454567e-01 5.53926456e-01]]

 [[7.67169712e-01 9.20964821e-01 2.84979659e-01 9.

In [36]:
# Test ellipsis combined with a split axis and a full reordering
x = np.random.rand(2, 3, 4, 6, 5)
result = rearrange(x, 'a b ... (c d) e -> ... a d b e c', c=2, d=3)
# After the split the axes are (a, b, batch, c, d, e) at positions 0..5.
# The output order (batch, a, d, b, e, c) is therefore the permutation
# (2, 0, 4, 1, 5, 3); no further reshape is needed.
expected = x.reshape(2, 3, 4, 2, 3, 5).transpose(2, 0, 4, 1, 5, 3)
assert result.shape == (4, 2, 3, 3, 5, 2)
assert np.allclose(result, expected)
print(result)
print(expected)

[[[[[[9.41121309e-01 2.40077355e-01]
     [4.21466320e-01 4.13215652e-01]
     [3.96743267e-01 4.45502734e-01]
     [6.98996956e-01 8.70382236e-02]
     [7.80700037e-01 9.73436078e-01]]

    [[2.73446673e-01 7.14094959e-01]
     [8.19063039e-01 1.36315909e-01]
     [6.30342419e-01 9.47319017e-01]
     [2.38604998e-01 9.35053845e-01]
     [3.01581394e-01 5.93496378e-01]]

    [[7.48282352e-02 4.81455575e-01]
     [4.80173627e-01 5.63335876e-01]
     [9.84366203e-01 5.59931063e-01]
     [3.28931517e-01 2.34435294e-01]
     [3.93581067e-01 6.11669988e-01]]]


   [[[5.20472979e-01 4.71860491e-01]
     [2.34741825e-01 4.20430728e-01]
     [3.98187937e-01 8.19057037e-01]
     [1.58664394e-01 5.19633776e-01]
     [3.86025798e-01 1.10965638e-01]]

    [[4.50586034e-01 4.77600085e-01]
     [7.64944076e-01 5.61355775e-01]
     [3.67544711e-01 4.16219858e-01]
     [4.54362812e-02 1.96111817e-01]
     [6.63518801e-01 7.50068286e-01]]

    [[6.88801795e-01 9.22153984e-01]
     [6.61117879e-01 5.165