<a href="https://colab.research.google.com/github/atharva753/SarvamLLM_Assignment-2/blob/main/einops.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import re
import numpy as np
from typing import Any

def _parse_side(side: str) -> list:
    """Tokenise one side of a ``rearrange`` pattern.

    Each returned item is either the string ``'...'`` (ellipsis), an axis name
    (``str``; ``'1'`` denotes an anonymous size-1 axis), or a ``list`` of such
    strings for a parenthesised group.

    Raises:
        ValueError: for an empty group ``()``, nested or unbalanced
            parentheses, or more than one ellipsis on this side.
    """
    parsed = []
    # Alternatives are tried in order: a flat "(...)" group, an ellipsis, a
    # bare name, or a stray parenthesis (whatever the group alternative could
    # not match, i.e. nesting or imbalance).
    for token in re.findall(r"\([^()]*\)|\.\.\.|[^\s()]+|[()]", side):
        if token in ("(", ")"):
            raise ValueError(f"Unbalanced or nested parentheses in pattern side '{side.strip()}'.")
        if token.startswith("("):
            inner = token[1:-1].split()
            if not inner:
                raise ValueError("Empty group '()' is not allowed in the pattern.")
            parsed.append(inner)
        else:
            parsed.append(token)
    n_ellipsis = sum(
        tok.count("...") if isinstance(tok, list) else int(tok == "...")
        for tok in parsed
    )
    if n_ellipsis > 1:
        raise ValueError("Only one ellipsis '...' is allowed on each side of the pattern.")
    return parsed


def _split_group_lengths(names: list, merged_size: int, axes_lengths: dict) -> list:
    """Return the length of each axis of an input group such as ``(h w)``.

    Lengths are taken from ``axes_lengths`` (``'1'`` is always 1). At most one
    length may be missing; it is inferred so the product equals
    ``merged_size``.

    Raises:
        ValueError: if more than one length is missing, or the lengths cannot
            multiply to ``merged_size``.
    """
    lengths = [1 if name == "1" else axes_lengths.get(name) for name in names]
    missing = [i for i, length in enumerate(lengths) if length is None]
    if len(missing) > 1:
        raise ValueError(f"Cannot infer more than one missing axis length in group {names}")
    known = 1
    for length in lengths:
        if length is not None:
            known *= length
    if not missing:
        if known != merged_size:
            raise ValueError("Provided axes lengths do not match the merged axis size.")
        return lengths
    missing_name = names[missing[0]]
    # A zero known product leaves the missing length undetermined (and would
    # divide by zero below).
    if known == 0 or merged_size % known != 0:
        raise ValueError(
            f"Cannot infer length of axis '{missing_name}': merged axis size "
            f"{merged_size} is not divisible by {known}."
        )
    lengths[missing[0]] = merged_size // known
    return lengths


def _resolve_input_axes(tokens: list, shape: tuple, axes_lengths: dict) -> tuple:
    """Split the input side of a pattern into elementary axes.

    Returns ``(labels, sizes)`` with one entry per elementary axis, in input
    order. Named axes keep their name. Ellipsis axes get labels ``('...', i)``,
    which cannot clash with string names. Anonymous ``'1'`` axes get ``None``.

    Raises:
        ValueError: if the token count does not fit ``shape``, an ellipsis is
            used inside an input group, a ``'1'`` axis is not size 1, or a
            given axis length disagrees with the tensor.
    """
    if any(isinstance(tok, list) and "..." in tok for tok in tokens):
        raise ValueError("Ellipsis inside parentheses is not allowed on the input side.")
    ndim = len(shape)
    explicit_count = sum(1 for tok in tokens if tok != "...")
    if explicit_count > ndim:
        raise ValueError("The tensor has fewer dimensions than the explicit tokens specified in the input pattern.")
    if "..." not in tokens and explicit_count != ndim:
        raise ValueError("Parsed input token count does not match tensor dimensions.")

    labels, sizes = [], []
    dim = 0
    for tok in tokens:
        if tok == "...":
            n_batch = ndim - explicit_count
            labels.extend(("...", i) for i in range(n_batch))
            sizes.extend(shape[dim:dim + n_batch])
            dim += n_batch
            continue
        if isinstance(tok, list):
            names = tok
            lengths = _split_group_lengths(tok, shape[dim], axes_lengths)
        else:
            names = [tok]
            lengths = [shape[dim]]
        for name, length in zip(names, lengths):
            if name == "1":
                if length != 1:
                    raise ValueError(f"Anonymous axis '1' matches a dimension of size {length}.")
                labels.append(None)
            else:
                expected = axes_lengths.get(name)
                if expected is not None and expected != length:
                    raise ValueError(f"Axis '{name}' has length {length}, but {expected} was given.")
                labels.append(name)
            sizes.append(length)
        dim += 1
    return labels, sizes


def _resolve_output_groups(tokens: list, batch_labels: list) -> list:
    """Turn the output side into one group of labels per output dimension.

    Each group lists the elementary-axis labels merged into that dimension,
    in order. ``None`` marks an anonymous size-1 axis. A top-level ``'...'``
    yields one group per batch axis; ``'...'`` inside parentheses merges all
    batch axes into that group.
    """
    groups = []
    for tok in tokens:
        if tok == "...":
            groups.extend([label] for label in batch_labels)
        elif isinstance(tok, list):
            group = []
            for name in tok:
                if name == "...":
                    group.extend(batch_labels)
                else:
                    group.append(None if name == "1" else name)
            groups.append(group)
        else:
            groups.append([None if tok == "1" else tok])
    return groups


def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
    """Rearrange the axes of ``tensor`` according to an einops-style pattern.

    The pattern has the form ``"<input> -> <output>"``. Each side is a
    whitespace-separated list of:

    * axis names (``h``, ``w``, ...), each appearing exactly once on each side;
    * parenthesised groups ``(h w)``. On the input side a group splits one
      tensor dimension into several axes; all but at most one of its lengths
      must be given in ``axes_lengths``. On the output side a group merges
      axes into one dimension;
    * ``1``, an anonymous size-1 axis. It is dropped when it appears on the
      input side (the matching dimension must have size 1) and inserted when
      it appears on the output side;
    * ``...``, at most once per side, standing for all dimensions not matched
      explicitly. It must appear on both sides or on neither; on the output
      side it may be used inside a group, e.g. ``b (...)``.

    Args:
        tensor: array to rearrange; array-likes are converted with
            ``np.asanyarray``.
        pattern: the rearrangement pattern.
        **axes_lengths: lengths of named axes. Needed to split input groups,
            and checked against the tensor for every axis they name.

    Returns:
        The rearranged array (possibly a view of ``tensor``).

    Raises:
        ValueError: if the pattern is malformed or inconsistent with the shape
            of ``tensor`` or with ``axes_lengths``.
    """
    tensor = np.asanyarray(tensor)
    try:
        input_pattern, output_pattern = pattern.split("->")
    except (AttributeError, TypeError, ValueError):
        raise ValueError("Invalid pattern: pattern must contain '->' separating input and output.") from None
    input_tokens = _parse_side(input_pattern)
    output_tokens = _parse_side(output_pattern)

    def _has_ellipsis(tokens: list) -> bool:
        return any(tok == "..." or (isinstance(tok, list) and "..." in tok) for tok in tokens)

    if _has_ellipsis(input_tokens) != _has_ellipsis(output_tokens):
        raise ValueError("Ellipsis '...' must appear on both sides of the pattern or on neither.")

    labels, sizes = _resolve_input_axes(input_tokens, tensor.shape, axes_lengths)
    batch_labels = [label for label in labels if isinstance(label, tuple)]
    out_groups = _resolve_output_groups(output_tokens, batch_labels)

    in_named = [label for label in labels if label is not None]
    out_named = [label for group in out_groups for label in group if label is not None]
    for side, named in (("input", in_named), ("output", out_named)):
        seen = set()
        for label in named:
            if label in seen:
                raise ValueError(f"Axis '{label}' appears more than once in the {side} pattern.")
            seen.add(label)
    in_set, out_set = set(in_named), set(out_named)
    for label in out_named:
        if label not in in_set:
            raise ValueError(f"Output label '{label}' not found in input pattern.")
    for label in in_named:
        if label not in out_set:
            # rearrange only permutes/reshapes; silently keeping or dropping a
            # non-trivial axis would change the number of elements or rank.
            raise ValueError(f"Input axis '{label}' does not appear in the output pattern.")

    size_of = {label: size for label, size in zip(labels, sizes) if label is not None}
    axis_of = {label: axis for axis, label in enumerate(in_named)}
    # Splitting input groups and dropping anonymous size-1 axes is a pure
    # reshape: neither changes the C-order of the elements.
    elementary = tensor.reshape(tuple(size_of[label] for label in in_named))
    transposed = np.transpose(elementary, axes=[axis_of[label] for label in out_named])

    # One final reshape merges output groups and inserts anonymous 1s.
    out_shape = []
    for group in out_groups:
        length = 1
        for label in group:
            if label is not None:
                length *= size_of[label]
        out_shape.append(length)
    return transposed.reshape(tuple(out_shape))

if __name__ == "__main__":
    # Smoke demo: swap the two axes of a random 3x4 matrix and print the 4x3 result.
    demo_matrix = np.random.rand(3, 4)
    swapped = rearrange(demo_matrix, "h w -> w h")
    print(swapped)


[[0.56726569 0.26351639 0.18395007]
 [0.02702079 0.4047904  0.81889507]
 [0.09378598 0.75626025 0.79940272]
 [0.50546525 0.84770438 0.90001044]]
