<a href="https://colab.research.google.com/github/Jyothiraditya135/Einops_implementation/blob/main/Einops_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### The implementation

In [179]:
#Necessary imports
import re
import numpy as np
from typing import List, Union, Tuple, Dict

In [180]:
def identify_implicit_axes(input_shape: List[int],
                           left_str: str,
                           params: Dict[str, int]) -> Dict[str, int]:
    """
    Return axis lengths: the user-supplied ``params`` plus any inferred lengths.

    For every parenthesised group on the left side with exactly one axis of
    unknown length, that length is inferred from the matching input
    dimension. Lengths supplied for plain (ungrouped) left axes are checked
    against the input shape.

    Args:
        input_shape: shape of the input tensor.
        left_str: left-hand side of the pattern (the text before '->').
        params: axis lengths supplied by the caller.

    Returns:
        A new dict mapping axis name -> length (``params`` is not modified).

    Raises:
        ValueError: on nested brackets, a group with more than one unknown
            axis, a group whose known lengths do not divide the input
            dimension, a supplied length that disagrees with the input
            shape, or a plain axis name that was already inferred from an
            earlier group.
    """
    left_dim_ltrs = re.findall(r'\(.*?\)|\S+', left_str)
    rank = len(input_shape)

    # Start from the user-supplied lengths; inferred ones are added below.
    map_ltrs_dims = dict(params)

    # Offset between a token's position in left_dim_ltrs and the input axis it
    # describes. It becomes non-zero once an ellipsis covering != 1 axes is seen.
    ell_len = 0

    for idx, a in enumerate(left_dim_ltrs):
        i = idx + ell_len

        if a == '...':
            # The ellipsis covers rank - (len(left_dim_ltrs) - 1) input axes.
            ell_len = rank - len(left_dim_ltrs)
            continue

        # If pattern and tensor ranks disagree, i may fall outside the shape.
        # Skip inference then (determine_steps reports the rank mismatch)
        # rather than raising IndexError or wrapping around with a negative i.
        in_range = 0 <= i < rank

        if a.startswith('(') and a.endswith(')'):
            inner = a[1:-1]
            if '(' in inner or ')' in inner:
                raise ValueError("Brackets inside brackets not allowed")
            names = inner.split()

            # Distinct names without a supplied length ('...' is left for
            # determine_steps to reject as an invalid identifier).
            unknown = list(dict.fromkeys(
                n for n in names if n not in params and n != '...'))
            if len(unknown) > 1:
                raise ValueError(f"Could not infer sizes for {unknown}")

            if unknown and in_range:
                known = 1
                for n in names:
                    if n in params:
                        known *= params[n]
                if known <= 0 or input_shape[i] % known != 0:
                    raise ValueError(
                        f"Shape mismatch, can't divide axis of length "
                        f"{input_shape[i]} in chunks of {known}")
                map_ltrs_dims[unknown[0]] = input_shape[i] // known

        elif a in params:
            # A length was supplied for a plain axis: it must match the input.
            if in_range and input_shape[i] != params[a]:
                raise ValueError(
                    f"Shape mismatch, {input_shape[i]} != {params[a]} for axis {a}")

        elif a in map_ltrs_dims:
            # The name was already inferred from an earlier group on the left.
            raise ValueError(f"Indexing expression contains duplicate dimension {a}")

    return map_ltrs_dims

In [181]:
def tokenize_and_parse(pattern: str) -> List[str]:
    """
    Tokenize one side of a pattern.

    Whitespace-separated names become plain strings. A parenthesised group
    "(a b)" becomes the list of its member names.
    """
    return [
        chunk[1:-1].split() if chunk.startswith("(") and chunk.endswith(")") else chunk
        for chunk in re.findall(r"\([^()]+\)|\S+", pattern)
    ]

In [182]:
def flatten(tokens: List[str]) -> List[str]:
    """
    Remove one level of nesting.

    Group tokens (lists) are spliced in place of the group; plain tokens are
    kept unchanged.
    """
    return [
        name
        for token in tokens
        for name in (token if isinstance(token, list) else (token,))
    ]

In [183]:
def substitute_ellipsis(flat_tokens: List[str], rank: int) -> Tuple[List[str], List[str]]:
    """
    Replace a single "..." token by one placeholder name per dimension it covers.

    ``flat_tokens`` may be a flat list of axis names or a top-level token list
    whose groups are nested lists. Every element counts as one dimension.
    For example, with rank = 5 and 2 explicit tokens the ellipsis is replaced
    by 3 placeholders: ["_e0", "_e1", "_e2"].

    Args:
        flat_tokens: tokens of one side of the pattern.
        rank: number of dimensions the tokens must account for.

    Returns:
        A tuple ``(tokens, placeholders)``:
        - ``tokens``: the input tokens with "..." substituted.
        - ``placeholders``: the inserted placeholder names. This list is empty
          if there was no ellipsis, in which case ``tokens`` is the input
          list itself.

    Raises:
        ValueError: if there is more than one ellipsis, or more explicit
            tokens than ``rank``.
    """

    if flat_tokens.count('...') > 1:
        raise ValueError("More than one ellipsis in pattern!")

    if "..." in flat_tokens:

        pos = flat_tokens.index("...")
        other_toks = [tok for tok in flat_tokens if tok != "..."]
        num_placeholders = rank - len(other_toks)

        if num_placeholders < 0:
            raise ValueError("More explicit tokens than tensor rank!")

        placeholders = [f"_e{i}" for i in range(num_placeholders)]

        return (flat_tokens[:pos] + placeholders + flat_tokens[pos+1:]), placeholders

    return flat_tokens, []

In [184]:
def is_valid_axis_name(name: str) -> bool:
    """
    Return True if ``name`` is a valid axis identifier.

    A valid identifier is ASCII letters, digits and underscores, not starting
    with a digit.
    """
    # fullmatch returns a Match object or None; convert so the result really is a bool.
    return re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", name) is not None

def validate_tokens(tokens: List[str]):
    """
    Raise ValueError naming the first token that is not a valid axis identifier.
    """
    for token in tokens:
        if not is_valid_axis_name(token):
            raise ValueError(f"Invalid axis identifier: '{token}'")

In [185]:
def parse_pattern(pattern: str) -> Tuple[List[str], List[str]]:
    """
    Split a pattern such as "(h w) c -> h w c" into its two tokenized sides.

    Returns:
        ``(left_tokens, right_tokens)``, each as produced by tokenize_and_parse
        (axis names, with parenthesised groups as nested lists).

    Raises:
        ValueError: if the pattern does not contain exactly one '->'.
    """
    arrow_count = pattern.count('->')
    if arrow_count == 0:
        raise ValueError("Pattern must contain '->'")
    if arrow_count > 1:
        # Report this directly instead of the tuple-unpacking error split() would cause.
        raise ValueError("Pattern must contain exactly one '->'")

    left_str, right_str = pattern.split("->")

    left_tokens = tokenize_and_parse(left_str.strip())
    right_tokens = tokenize_and_parse(right_str.strip())

    return left_tokens, right_tokens

In [186]:
def _expand_right_ellipsis(right_tokens: List, placeholders: List[str]) -> List:
    """
    Return ``right_tokens`` with '...' replaced by the ellipsis placeholder axes.

    A top-level '...' is spliced in place. A '...' inside a group is expanded
    within that group, so the group's other members are kept.

    Raises:
        ValueError: if the right side contains no ellipsis or more than one.
    """
    expanded = []
    found = 0
    for tok in right_tokens:
        if isinstance(tok, list):
            group = []
            for name in tok:
                if name == '...':
                    found += 1
                    group.extend(placeholders)
                else:
                    group.append(name)
            expanded.append(group)
        elif tok == '...':
            found += 1
            expanded.extend(placeholders)
        else:
            expanded.append(tok)

    if found == 0:
        raise ValueError("Ellipsis found in left side, but not on the right side of the pattern")
    if found > 1:
        raise ValueError("More than one ellipsis in pattern!")
    return expanded


def determine_steps(left_tokens: List[str],
                    right_tokens: List[str],
                    input_shape: List[int],
                    params: Dict[str, int]) -> List[Dict]:
    """
    Translate a parsed pattern into the list of array operations that realise it.

    Args:
        left_tokens, right_tokens: output of parse_pattern. Axis names, with
            parenthesised groups as nested lists and '...' for an ellipsis.
        input_shape: shape of the input tensor.
        params: known axis lengths (user supplied plus inferred).

    Returns:
        Step dicts, to be applied in order by apply_transformation:
          1. 'split' {'axis', 'into'}: split grouped axes on the left.
          2. 'transpose' {'order'}: reorder the axes shared by both sides into
             right-side order. Omitted when already in order.
          3. 'repeat' {'token', 'count', 'axis'}: insert each right-only axis
             at its output position, with length params[token].
          4. 'merge' {'axes', 'into'}: merge grouped axes on the right. An
             empty right group '()' instead becomes a 'repeat' step with
             count 1, so it yields an axis of length 1.

    Raises:
        ValueError: on any of the following:
            - rank mismatch;
            - duplicate axes on either side;
            - invalid identifiers;
            - axes appearing on one side only without a supplied length;
            - unused entries in params;
            - inconsistent ellipsis usage.
    """
    steps = []

    if '...' in left_tokens:
        # Give every dimension covered by the ellipsis its own placeholder
        # name, so the rest of the logic only deals with named axes.
        left_tokens, placeholders = substitute_ellipsis(left_tokens, rank=len(input_shape))
        right_tokens = _expand_right_ellipsis(right_tokens, placeholders)
    elif '...' in flatten(right_tokens) and '...' not in flatten(left_tokens):
        raise ValueError("Ellipsis found in right side, but not on the left side of the pattern")

    left_flat = flatten(left_tokens)
    right_flat = flatten(right_tokens)

    if len(left_tokens) != len(input_shape):
        raise ValueError(f"Wrong shape: expected {len(left_tokens)} dims. Received {len(input_shape)}-dim tensor.")

    if len(left_flat) != len(set(left_flat)):
        raise ValueError("Invalid input, Pattern contains duplicate dimension")
    # A repeated output name would hand np.transpose a repeated axis (or
    # repeat a new axis twice); reject it explicitly.
    if len(right_flat) != len(set(right_flat)):
        raise ValueError("Invalid output, Pattern contains duplicate dimension")

    validate_tokens(left_flat)
    validate_tokens(right_flat)

    left_set = set(left_flat)
    right_set = set(right_flat)

    for item in left_flat:
        if item not in right_set:
            raise ValueError(f"Identifiers only on one side of expression: {item}")

    for ax in params:
        if ax not in left_set and ax not in right_set:
            raise ValueError(f"Axes: {ax} has not been used in transformation")

    # Step 1: split grouped axes on the left. Each split adds len(group) - 1
    # axes, which shifts the position of every later input axis.
    shift = 0
    for i, token in enumerate(left_tokens):
        if isinstance(token, list):
            steps.append({'op': 'split', 'axis': i + shift, 'into': token})
            shift += len(token) - 1

    # Step 2: axes present on both sides (in output order), and axes that only
    # exist on the right (these are created by repetition).
    common_tokens = [tok for tok in right_flat if tok in left_set]
    repeat_tokens = []
    for idx, tok in enumerate(right_flat):
        if tok in left_set:
            continue
        if tok not in params:
            raise ValueError(f"Identifiers only on one side of expression: {tok}")
        repeat_tokens.append((tok, idx))

    # Step 3: permute the split axes into output order.
    left_position = {tok: i for i, tok in enumerate(left_flat)}
    transpose_order = [left_position[tok] for tok in common_tokens]
    if transpose_order != list(range(len(left_flat))):
        steps.append({'op': 'transpose', 'order': transpose_order})

    # Step 4: insert new axes in ascending output position. All earlier output
    # axes already exist at that point, so out_axis is the insertion index.
    for token, out_axis in repeat_tokens:
        steps.append({'op': 'repeat', 'token': token, 'count': params[token], 'axis': out_axis})

    # Step 5: merge groups on the right. Once the groups before it have been
    # merged, a group starts at the count of top-level tokens preceding it.
    for position, token in enumerate(right_tokens):
        if not isinstance(token, list):
            continue
        if token:
            steps.append({'op': 'merge',
                          'axes': list(range(position, position + len(token))),
                          'into': token})
        else:
            # An empty group denotes an axis of length 1. Merging nothing
            # would drop it, so insert it explicitly.
            steps.append({'op': 'repeat', 'token': None, 'count': 1, 'axis': position})

    return steps

In [187]:
def do_split(x: np.ndarray, op: Dict, params: Dict[str, int]) -> np.ndarray:
    """
    Reshape axis ``op['axis']`` of ``x`` into one axis per name in ``op['into']``.

    Each new axis takes its length from ``params``.

    Raises:
        ValueError: if a length is missing, or if the lengths do not multiply
            to the size of the axis being split.
    """
    target = op['axis']
    current = x.shape[target]

    pieces = []
    for name in op['into']:
        try:
            pieces.append(params[name])
        except KeyError:
            raise ValueError(f"Missing size parameter for token: {name}") from None

    if np.prod(pieces) != current:
        raise ValueError(f"Product of sizes {pieces} does not match dimension {current}")

    return np.reshape(x, [*x.shape[:target], *pieces, *x.shape[target + 1:]])

In [188]:
def do_transpose(x: np.ndarray, op: Dict) -> np.ndarray:
    """
    Permute the axes of ``x`` so that output axis k is input axis ``op['order'][k]``.
    """
    permutation = op['order']
    return np.transpose(x, axes=permutation)

In [189]:
def do_repeat(x: np.ndarray, op: Dict) -> np.ndarray:
    """
    Insert a length-1 axis at ``op['axis']`` and tile it ``op['count']`` times.
    """
    position = op['axis']
    copies = op['count']

    # Slicing a[:p] + (1,) + a[p:] inserts exactly like list.insert(p, 1).
    slotted_shape = x.shape[:position] + (1,) + x.shape[position:]

    tile_reps = [1] * len(slotted_shape)
    tile_reps[position] = copies

    return np.tile(np.reshape(x, slotted_shape), tile_reps)

In [190]:
def do_merge(x: np.ndarray, op: Dict) -> np.ndarray:
    """
    Merge (flatten) the axes listed in ``op['axes']`` into a single axis.

    The listed axes must be consecutive; their order in the list does not
    matter, and negative indices are allowed. The merged axis takes the
    position of the first of them. Its length is the product of the merged
    lengths, combined in row-major order. An empty ``op['axes']`` returns
    ``x`` unchanged.

    Raises:
        ValueError: if an axis is out of range, or if the axes are not
            consecutive. A single reshape cannot merge non-adjacent axes
            without scrambling the data.
    """
    ndim = x.ndim
    axes = []
    for ax in op['axes']:
        if not -ndim <= ax < ndim:
            raise ValueError(f"Axis {ax} is out of bounds for array of dimension {ndim}")
        axes.append(ax % ndim)
    axes.sort()

    if not axes:
        return x

    start, stop = axes[0], axes[-1] + 1
    if axes != list(range(start, stop)):
        raise ValueError(f"Only consecutive axes can be merged, got {op['axes']}")

    merged_size = int(np.prod(x.shape[start:stop]))
    new_shape = tuple(x.shape[:start]) + (merged_size,) + tuple(x.shape[stop:])
    return np.reshape(x, new_shape)

In [191]:
def apply_transformation(x: np.ndarray, steps: List[Dict], params: Dict[str, int]) -> np.ndarray:
    """
    Apply each step (as produced by determine_steps) to ``x`` in order.

    Args:
        x: input array.
        steps: step dicts whose 'op' is 'split', 'transpose', 'repeat' or 'merge'.
        params: axis lengths; used by 'split' steps.

    Returns:
        The transformed array. With no steps, ``x`` itself is returned.

    Raises:
        ValueError: if a step has an unknown 'op'. Skipping it silently would
            return a wrongly shaped result.
    """
    for step in steps:
        op = step['op']
        if op == 'split':
            x = do_split(x, step, params)
        elif op == 'transpose':
            x = do_transpose(x, step)
        elif op == 'repeat':
            x = do_repeat(x, step)
        elif op == 'merge':
            x = do_merge(x, step)
        else:
            raise ValueError(f"Unknown transformation step: {op!r}")
    return x

In [192]:
# Walk through the pipeline by hand on a small example: parse the pattern,
# derive the transformation steps, then apply them.
pattern = "(h1 h2) w -> h1 b1 (w h2)"
# This walkthrough calls determine_steps directly (without
# identify_implicit_axes), so every split length (h1, h2) is given
# explicitly, together with the length of the new repeated axis b1.
params = {'b1': 2, 'h1': 1, 'h2': 2}

x = np.arange(2 * 3).reshape(2, 3)
print("Input x shape:", x.shape)
print("x =", x)

left_tokens, right_tokens = parse_pattern(pattern)
steps = determine_steps(left_tokens, right_tokens, x.shape, params)

print("Parsed Left Tokens:", left_tokens)
print("Parsed Right Tokens:", right_tokens)
print("Transformation Steps:")
for step in steps:
    print(step)

y = apply_transformation(x, steps, params)
print("Output shape:", y.shape)
print("y =", y)

Input x shape: (2, 3)
x = [[0 1 2]
 [3 4 5]]
Parsed Left Tokens: [['h1', 'h2'], 'w']
Parsed Right Tokens: ['h1', 'b1', ['w', 'h2']]
Transformation Steps:
{'op': 'split', 'axis': 0, 'into': ['h1', 'h2']}
{'op': 'transpose', 'order': [0, 2, 1]}
{'op': 'repeat', 'token': 'b1', 'count': 2, 'axis': 1}
{'op': 'merge', 'axes': [2, 3], 'into': ['w', 'h2']}
Output shape: (1, 2, 6)
y = [[[0 3 1 4 2 5]
  [0 3 1 4 2 5]]]


In [193]:
def rearrange(tensor, pattern, **axes_lengths):
    """
    Rearrange (and optionally repeat) the axes of ``tensor`` per an einops-style pattern.

    Example patterns: "h w c -> w h c", "(b1 b2) h w -> (b1 h) (b2 w)",
    "... -> ... c".

    Args:
        tensor: input array, or anything ``np.asanyarray`` accepts (e.g.
            nested lists). ndarray subclasses are passed through unchanged.
        pattern: "<left> -> <right>" expression. Groups "( ... )" and a
            single "..." are allowed.
        **axes_lengths: lengths of axes that cannot be read from the input
            shape. These are split sizes on the left (one unknown per group
            is inferred) and lengths of new repeated axes on the right.

    Returns:
        np.ndarray with the requested layout.

    Raises:
        ValueError: for malformed patterns or lengths inconsistent with the input.
    """
    tensor = np.asanyarray(tensor)

    left_tokens, right_tokens = parse_pattern(pattern)

    params = identify_implicit_axes(tensor.shape, pattern.split('->')[0], axes_lengths)

    steps = determine_steps(left_tokens, right_tokens, tensor.shape, params)

    return apply_transformation(tensor, steps, params)

In [194]:
import einops

### Examples

Transpose

In [195]:
# Transpose: swap the first two axes and compare with einops.rearrange.
a1 = np.random.randn(6, 3, 3)
b1 = rearrange(a1, 'h w c -> w h c')
c1 = einops.rearrange(a1, 'h w c -> w h c')
# print(a1)
# print(b1)
# print(c1)
print(np.array_equal(b1, c1))

True


Merge

In [196]:
# Merge: combine h and w into one axis and compare with einops.rearrange.
a2 = np.random.randn(4, 4, 2)
b2 = rearrange(a2, 'h w c -> (h w) c')
c2 = einops.rearrange(a2, 'h w c -> (h w) c')
# print(a2)
# print(b2)
# print(c2)
print(np.array_equal(b2, c2))

True


In [197]:
# Merge every axis into a single one (full flatten) and compare with einops.
a3 = np.random.randn(4, 4, 2)
b3 = rearrange(a3, 'h w c -> (h w c)')
c3 = einops.rearrange(a3, 'h w c -> (h w c)')
# print(a3)
# print(b3)
# print(c3)
print(np.array_equal(b3, c3))

True


Transpose and Merge

In [198]:
# Transpose and merge: move h to the front and merge b with w.
a4 = np.random.randn(3, 4, 4, 2)
b4 = rearrange(a4, 'b h w c -> h (b w) c')
c4 = einops.rearrange(a4, 'b h w c -> h (b w) c')
# print(a4)
# print(b4)
# print(c4)
print(np.array_equal(b4, c4))

True


Implicitly infer the lengths of axes that are not given

In [199]:
# Only b1 is given; b2 must be inferred from the first input dimension.
a5 = np.random.randn(2, 3, 3)
b5 = rearrange(a5, "(b1 b2) h w -> b2 b1 h w", b1=1)
c5 = einops.rearrange(a5, "(b1 b2) h w -> b2 b1 h w", b1=1)
# print(a5)
# print(b5)
# print(c5)
print(np.array_equal(b5, c5))

True


In [200]:
# Inferred split (b2) followed by merges of the split axes with h and w.
a6 = np.random.randn(2, 3, 3)
b6 = rearrange(a6, "(b1 b2) h w -> (b1 h) (b2 w)", b1=1)
c6 = einops.rearrange(a6, "(b1 b2) h w -> (b1 h) (b2 w)", b1=1)
# print(a6)
# print(b6)
# print(c6)
print(np.array_equal(b6, c6))

True


Ellipsis handling and repeat

In [201]:
# Ellipsis plus a new trailing axis c; the reference here is einops.repeat,
# since einops.rearrange cannot create new axes.
a7 = np.random.rand(4, 1)
b7 = einops.repeat(a7, '... -> ... c', c=2)
c7 = rearrange(a7, '... -> ... c', c=2)
#print(a7)
# print(b7)
# print(c7)
print(np.array_equal(b7, c7))

True


In [202]:
# Ellipsis plus a new leading axis c, compared with einops.repeat.
a8 = np.random.rand(2, 6)
b8 = rearrange(a8, '... -> c ...', c=3)
c8 = einops.repeat(a8, '... -> c ...', c=3)
#print(a8)
# print(b8)
# print(c8)
print(np.array_equal(b8, c8))

True


Ellipsis handling and merge

In [203]:
# Leading axes are covered by the ellipsis; only the last two are merged.
a9 = np.random.rand(6, 3, 4, 2)
b9 = rearrange(a9, '... w c -> ... (w c)')
c9 = einops.rearrange(a9, '... w c -> ... (w c)')
#print(a9)
# print(b9)
# print(c9)
print(np.array_equal(b9, c9))

True


Ellipsis and merge (merging all axes covered by the ellipsis)

In [204]:
# Merge every axis covered by the ellipsis into one axis. No new axis is
# created, so the matching einops reference is einops.rearrange.
a10 = np.random.randn(2, 26, 26, 3)
b10 = rearrange(a10, 'b ... -> b (...)')
c10 = einops.rearrange(a10, 'b ... -> b (...)')
# print(a10)
# print(b10)
# print(c10)
print(np.array_equal(b10, c10))

True


Split

In [205]:
# Split the last axis into two axes with explicitly given lengths.
a11 = np.random.randn(3, 4)
b11 = rearrange(a11, 'h (w1 w2) -> h w1 w2', w1=2, w2=2)
c11 = einops.rearrange(a11, 'h (w1 w2) -> h w1 w2', w1=2, w2=2)
#print(a11)
# print(b11)
# print(c11)
print(np.array_equal(b11, c11))

True


Repeat

In [206]:
# Insert a new axis c of length 3 between h and w (compared with einops.repeat).
a12 = np.random.rand(3, 4)
b12 = rearrange(a12, 'h w -> h c w', c=3)
c12 = einops.repeat(a12, 'h w -> h c w', c=3)
#print(a12)
# print(b12)
# print(c12)
print(np.array_equal(b12, c12))

True


In [207]:
# Same as the previous cell with a different axis name and length.
a13 = np.random.randn(2, 3)
b13 = rearrange(a13, "h w -> h b1 w", b1=2)
c13 = einops.repeat(a13, "h w -> h b1 w", b1=2)
# print(a13)
# print(b13)
# print(c13)
print(np.array_equal(b13, c13))

True


Repeat and Merge

In [208]:
# Repeat and merge: create a new axis c of length 3 and merge it with w, so
# each element along w appears 3 times consecutively. einops.repeat is the
# reference implementation.
a14 = np.random.rand(3, 2)
b14 = rearrange(a14, 'h w -> h (w c)', c=3)
c14 = einops.repeat(a14, 'h w -> h (w c)', c=3)
# print(a14)
# print(b14)
# print(c14)
# Compare our result with the einops reference c14 (not b14 with itself).
print(np.array_equal(b14, c14))

True


### Mad designer gallery, inspired by the einops library

In [209]:
# Inferred split (b2) combined with merges and a trailing ellipsis.
a_1 = np.random.randn(2, 3, 3, 1)
c_1 = einops.rearrange(a_1, "(b1 b2) h w ... -> (b1 h) (b2 w) ...", b1=1)
b_1 = rearrange(a_1, "(b1 b2) h w ... -> (b1 h) (b2 w) ...", b1=1)
# print(a_1)
# print(b_1)
# print(c_1)
print(np.array_equal(b_1, c_1))

True


In [210]:
# Split, repeat (new axis b1) and merge in a single pattern; einops.repeat is
# the reference because a new axis is created.
a_2 = np.random.randn(2, 3)
b_2 = rearrange(a_2, "(h1 h2) w -> h1 b1 (w h2)", b1=2, h1=1, h2=2)
c_2 = einops.repeat(a_2, "(h1 h2) w -> h1 b1 (w h2)", b1=2, h1=1, h2=2)
print(np.array_equal(b_2, c_2))

True


In [211]:
# Three-way splits of h and w (h1 and w1 inferred), recombined into two merged axes.
a_3 = np.random.randn(2, 4, 8, 1)
c_3 = einops.rearrange(a_3, "b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c", h2=2, w2=2, w3=2, h3=2)
b_3 = rearrange(a_3, "b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c", h2=2, w2=2, w3=2, h3=2)
print(np.array_equal(b_3, c_3))

True


In [212]:
# Random batch of shape (6, 96, 96, 3), used as the input for the gallery patterns.
ims = np.random.randn(6, 96, 96, 3)

In [213]:
# Each pair applies the same pattern with our rearrange and with
# einops.rearrange to `ims` and prints whether the results are identical.
md1 = rearrange(ims, "(b1 b2) h w c -> (h b1) (w b2) c ", b1=2)
ans1 = einops.rearrange(ims, "(b1 b2) h w c -> (h b1) (w b2) c ", b1=2)
print(np.array_equal(md1, ans1))

md2 = rearrange(ims, "(b1 b2) h w c -> (h b1) (b2 w) c", b1=2)
ans2 = einops.rearrange(ims, "(b1 b2) h w c -> (h b1) (b2 w) c", b1=2)
print(np.array_equal(md2, ans2))

md3 = rearrange(ims, "b (h1 h2) (w1 w2) c -> (h1 w2) (b w1 h2) c", h2=8, w2=8)
ans3 = einops.rearrange(ims, "b (h1 h2) (w1 w2) c -> (h1 w2) (b w1 h2) c", h2=8, w2=8)
print(np.array_equal(md3, ans3))

md4 = rearrange(ims, "b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c", h2=2, w2=2, w3=2, h3=2)
ans4 = einops.rearrange(ims, "b (h1 h2 h3) (w1 w2 w3) c -> (h1 w2 h3) (b w1 h2 w3) c", h2=2, w2=2, w3=2, h3=2)
print(np.array_equal(md4, ans4))

md5 = rearrange(ims, "(b1 b2) (h1 h2) (w1 w2) c -> (h1 b1 h2) (w1 b2 w2) c", h1=3, w1=3, b2=3)
ans5 = einops.rearrange(ims, "(b1 b2) (h1 h2) (w1 w2) c -> (h1 b1 h2) (w1 b2 w2) c", h1=3, w1=3, b2=3)
print(np.array_equal(md5, ans5))

True
True
True
True
True


### A few error-handling cases

Error handling of invalid axis names

In [214]:
# 'h...' is not a valid axis name: an ellipsis must be its own token, so
# rearrange raises ValueError and the message is printed.
n_p = np.random.randn(6, 3, 3)
try:
  rearrange(n_p, 'h... w c -> w h... (c k)', k=2)
except ValueError as e:
  print(e)

Invalid axis identifier: 'h...'


Duplicate dimension error handling

In [215]:
from einops import EinopsError

In [216]:
# 'b' appears twice on the left side. Our rearrange raises ValueError and
# einops raises EinopsError; both messages are printed.
ae1 = np.random.rand(4, 3, 2)
try:
  be1 = rearrange(ae1, 'b b h -> b h b')
except ValueError as e:
  print(e)
try:
  ce1 = einops.rearrange(ae1, 'b b h -> b h b')
except EinopsError as e:
  print(e)

Invalid input, Pattern contains duplicate dimension
 Error while processing rearrange-reduction pattern "b b h -> b h b".
 Input tensor shape: (4, 3, 2). Additional info: {}.
 Indexing expression contains duplicate dimension "b"


Error handling of axes not specified or not interpretable

In [217]:
# The left axis 'w' does not appear on the right side, so both
# implementations reject the pattern; both messages are printed.
ae2 = np.random.rand(6, 3, 4, 2)
try:
  be2 = rearrange(ae2, '... w c -> ... w1 (w2 c)', w1 = 2, w2 = 2)
except ValueError as e:
  print(e)
try:
  ce2 = einops.rearrange(ae2, '... w c -> ... w1 (w2 c)', w1 = 2, w2 = 2)
except EinopsError as e:
  print(e)

Identifiers only on one side of expression: w
 Error while processing rearrange-reduction pattern "... w c -> ... w1 (w2 c)".
 Input tensor shape: (6, 3, 4, 2). Additional info: {'w1': 2, 'w2': 2}.
 Identifiers only on one side of expression (should be on both): {'w2', 'w1', 'w'}


In [221]:
# Malformed brackets on the right side must be rejected.
# NOTE(review): the two calls use different patterns. Ours uses nested
# brackets '((w c))'; the einops call uses an unbalanced 'w c)'. The printed
# messages therefore describe different errors. The w1/w2 keyword arguments
# are unused by both patterns. Confirm whether both calls were meant to use
# the same pattern.
ae2 = np.random.rand(6, 3, 4, 2)
try:
  be2 = rearrange(ae2, 'b h w c -> b h ((w c))', w1 = 2, w2 = 2)
except ValueError as e:
  print(e)
try:
  ce2 = einops.rearrange(ae2, 'b h w c -> b h w c)', w1 = 2, w2 = 2)
except EinopsError as e:
  print(e)

Invalid axis identifier: '((w'
 Error while processing rearrange-reduction pattern "b h w c -> b h w c)".
 Input tensor shape: (6, 3, 4, 2). Additional info: {'w1': 2, 'w2': 2}.
 Brackets are not balanced
