<a href="https://colab.research.google.com/github/atharva753/SarvamLLM_Assignment-2/blob/main/Unit_tests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
%%writefile your_module.py
import re
import numpy as np
from typing import Any

_ELLIPSIS = "..."
_UNIT_AXIS = "1"
# One token per match: a parenthesised group, an ellipsis, an axis name, or a
# stray parenthesis (matched only so it can be reported as an error).
_TOKEN_RE = re.compile(r"\([^)]*\)|\.\.\.|[^\s()]+|[()]")


def _product(values) -> int:
    """Return the integer product of *values* (1 for an empty iterable)."""
    result = 1
    for value in values:
        result *= value
    return result


def _check_axis_name(name: str) -> None:
    """Raise ValueError unless *name* is a Python identifier or the unit axis '1'."""
    if name != _UNIT_AXIS and not name.isidentifier():
        raise ValueError(f"Invalid axis name '{name}' in pattern.")


def _parse_side(side: str) -> list:
    """Tokenize one side of a pattern.

    Each item of the returned list is '...', an axis name, or a list of axis
    names for a parenthesised group.
    """
    parsed = []
    for token in _TOKEN_RE.findall(side):
        if token in ("(", ")"):
            raise ValueError(f"Unbalanced parentheses in pattern side '{side}'.")
        if token == _ELLIPSIS:
            parsed.append(token)
        elif token.startswith("("):
            inner = token[1:-1]
            if "(" in inner:
                raise ValueError(f"Nested groups are not supported: '{token}'.")
            names = inner.split()
            if not names:
                raise ValueError("Empty group '()' is not allowed in the pattern.")
            for name in names:
                if name == _ELLIPSIS:
                    raise ValueError("Ellipsis '...' inside a group is not supported.")
                _check_axis_name(name)
            parsed.append(names)
        else:
            _check_axis_name(token)
            parsed.append(token)
    if parsed.count(_ELLIPSIS) > 1:
        raise ValueError("Ellipsis '...' may appear at most once per side.")
    return parsed


def _axis_names(tokens: list) -> list:
    """Return every axis name mentioned in parsed *tokens*, ellipsis excluded."""
    names = []
    for tok in tokens:
        if isinstance(tok, list):
            names.extend(tok)
        elif tok != _ELLIPSIS:
            names.append(tok)
    return names


def _check_no_duplicates(names: list) -> None:
    """Raise ValueError if a named axis (anything but '1') occurs twice."""
    seen = set()
    for name in names:
        if name != _UNIT_AXIS and name in seen:
            raise ValueError(
                f"Axis '{name}' appears more than once on the same side of the pattern."
            )
        seen.add(name)


def _decompose_input(shape: tuple, tokens: list, axes_lengths: dict) -> tuple:
    """Expand '...' and split groups on the input side.

    Returns ``(labels, sizes)``: one label and one length per elementary input
    axis, in order. Batch axes produced by '...' are labelled '...0', '...1',
    and so on. Those labels can never collide with user axis names, which
    must be identifiers.
    """
    n_batch = len(shape) - sum(1 for tok in tokens if tok != _ELLIPSIS)
    expanded = []
    for tok in tokens:
        if tok == _ELLIPSIS:
            # range() of a negative count is empty; the length check below reports it.
            expanded.extend(f"{_ELLIPSIS}{i}" for i in range(n_batch))
        else:
            expanded.append(tok)
    if len(expanded) != len(shape):
        raise ValueError("Parsed input token count does not match tensor dimensions.")

    labels, sizes = [], []
    for tok, dim in zip(expanded, shape):
        if isinstance(tok, list):
            # A unit axis inside a group always has length 1.
            factors = [1 if name == _UNIT_AXIS else axes_lengths.get(name) for name in tok]
            unknown = [i for i, f in enumerate(factors) if f is None]
            known = _product(f for f in factors if f is not None)
            if len(unknown) > 1:
                raise ValueError(f"Cannot infer more than one missing axis length in group {tok}")
            if unknown:
                missing = unknown[0]
                if known == 0:
                    raise ValueError(
                        f"Cannot infer axis '{tok[missing]}': the other factors multiply to 0."
                    )
                if dim % known != 0:
                    raise ValueError(
                        f"Inferred size for axis '{tok[missing]}' does not divide merged axis size"
                    )
                factors[missing] = dim // known
            elif known != dim:
                raise ValueError("Provided axes lengths do not match the merged axis size.")
            labels.extend(tok)
            sizes.extend(factors)
        else:
            if tok == _UNIT_AXIS and dim != 1:
                raise ValueError(
                    f"Axis '1' must have length 1, but the tensor axis has length {dim}."
                )
            given = axes_lengths.get(tok)
            if given is not None and given != dim:
                raise ValueError(f"Axis '{tok}' has length {dim}, but axes_lengths gives {given}.")
            labels.append(tok)
            sizes.append(dim)
    return labels, sizes


def rearrange(tensor: np.ndarray, pattern: str, **axes_lengths: int) -> np.ndarray:
    """Rearrange the axes of *tensor* according to an einops-style *pattern*.

    The pattern has the form ``"<input> -> <output>"``. Each side is a
    space-separated list of:

    * axis names (Python identifiers);
    * ``(a b ...)`` groups. On the input side a group splits one tensor axis
      into several; at most one member length may be omitted, and it is then
      inferred. On the output side a group merges adjacent axes (row-major);
    * ``...`` (at most once per side), standing for all otherwise unnamed
      input axes;
    * ``1``, a unit axis. An input ``1`` axis that is not used in the output
      is dropped. An output name that does not occur in the input consumes an
      unused input ``1`` axis and repeats it ``axes_lengths[name]`` times.

    Args:
        tensor: Array to rearrange.
        pattern: Rearrangement pattern, e.g. ``"(h w) c -> h w c"``.
        **axes_lengths: Lengths for axes that cannot be read from the tensor
            shape (group members, repeated axes). Lengths given for plain
            input axes are checked against the tensor.

    Returns:
        The rearranged array.

    Raises:
        ValueError: Malformed pattern, shape/length mismatch, duplicated axes,
            input axes missing from the output, or ``axes_lengths`` entries
            for axes that do not appear in the pattern.
        KeyError: An output-only (repeated) axis has no entry in ``axes_lengths``.
    """
    left, arrow, right = pattern.partition("->")
    if not arrow:
        raise ValueError("Invalid pattern: missing '->'")
    if "->" in right:
        raise ValueError("Invalid pattern: more than one '->'")
    input_tokens = _parse_side(left)
    output_tokens = _parse_side(right)

    input_names = _axis_names(input_tokens)
    output_names = _axis_names(output_tokens)
    _check_no_duplicates(input_names)
    _check_no_duplicates(output_names)
    stray = sorted(set(axes_lengths) - set(input_names) - set(output_names))
    if stray:
        raise ValueError(f"axes_lengths given for axes that are not in the pattern: {stray}")

    labels, sizes = _decompose_input(tensor.shape, input_tokens, axes_lengths)

    # Flatten the output side. widths[k] is the number of elementary axes the
    # k-th final output axis merges (1 for plain axes and batch axes).
    batch = [label for label in labels if label.startswith(_ELLIPSIS)]
    output_flat, widths = [], []
    for tok in output_tokens:
        if isinstance(tok, list):
            output_flat.extend(tok)
            widths.append(len(tok))
        elif tok == _ELLIPSIS:
            output_flat.extend(batch)
            widths.extend([1] * len(batch))
        else:
            output_flat.append(tok)
            widths.append(1)

    # Choose the input axis that feeds each output position. Names map to
    # themselves; an output-only name borrows the first unused input '1' axis
    # and is repeated later.
    positions = {}
    for idx, label in enumerate(labels):
        positions.setdefault(label, []).append(idx)
    used = set()
    perm = []
    repeats = []  # (output position, output label) pairs to np.repeat
    for out_pos, label in enumerate(output_flat):
        if label in positions:
            free = [i for i in positions[label] if i not in used]
            if not free:
                raise ValueError(f"Could not resolve label '{label}'")
        else:
            free = [i for i in positions.get(_UNIT_AXIS, []) if i not in used]
            if not free:
                raise ValueError(f"Output label '{label}' not found in input pattern.")
            repeats.append((out_pos, label))
        used.add(free[0])
        perm.append(free[0])

    # Unused input axes may only be unit axes (they are dropped); anything
    # else would silently leak extra dimensions into the result.
    leftover = [i for i in range(len(labels)) if i not in used]
    for i in leftover:
        label = labels[i]
        if label == _UNIT_AXIS:
            continue
        if label.startswith(_ELLIPSIS):
            raise ValueError("Ellipsis '...' in the input must also appear in the output.")
        raise ValueError(f"Input axis '{label}' does not appear in the output pattern.")

    for _, label in repeats:
        if label not in axes_lengths:
            raise KeyError(f"Missing axes length for repeated singleton axis '{label}'")

    # Split input groups into separate axes *before* permuting.
    out = tensor.reshape(sizes)
    out = np.transpose(out, perm + leftover)
    # The leftover axes all have length 1 and sit last, so this reshape drops them.
    out = out.reshape([sizes[i] for i in perm])
    for out_pos, label in repeats:
        out = np.repeat(out, axes_lengths[label], axis=out_pos)

    # Merge output groups: adjacent axes collapse in row-major order.
    final_shape = []
    start = 0
    for width in widths:
        final_shape.append(_product(out.shape[start:start + width]))
        start += width
    return out.reshape(final_shape)

if __name__ == "__main__":
    # Smoke test when the module is run as a script. numpy is already imported
    # at module level, and a deterministic input keeps the output reproducible.
    demo = np.arange(12).reshape(3, 4)
    print(rearrange(demo, "h w -> w h"))


Overwriting your_module.py


In [None]:
import numpy as np
import unittest
from your_module import rearrange

class TestRearrangeFunction(unittest.TestCase):
    """Unit tests for the einops-style ``rearrange`` function."""

    def test_transpose(self):
        x = np.arange(12).reshape(3, 4)
        result = rearrange(x, "h w -> w h")
        expected = np.transpose(x)
        np.testing.assert_array_equal(result, expected)

    def test_split_axis(self):
        x = np.arange(120).reshape(12, 10)
        result = rearrange(x, "(h w) c -> h w c", h=3)
        expected = np.reshape(x, (3, 4, 10))
        np.testing.assert_array_equal(result, expected)

    def test_merge_axes(self):
        x = np.arange(60).reshape(3, 4, 5)
        result = rearrange(x, "a b c -> (a b) c")
        expected = np.reshape(x, (3 * 4, 5))
        np.testing.assert_array_equal(result, expected)

    def test_repeat_axis(self):
        x = np.arange(15).reshape(3, 1, 5)
        result = rearrange(x, "a 1 c -> a b c", b=4)
        expected = np.repeat(x, 4, axis=1)
        np.testing.assert_array_equal(result, expected)

    def test_drop_unit_axis(self):
        # An input '1' axis that is absent from the output is dropped.
        x = np.arange(15).reshape(3, 1, 5)
        result = rearrange(x, "a 1 c -> a c")
        np.testing.assert_array_equal(result, x.reshape(3, 5))

    def test_batch_dimensions(self):
        x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
        result = rearrange(x, "... h w -> ... (h w)")
        expected = x.reshape(2, 3, 4 * 5)
        np.testing.assert_array_equal(result, expected)

    def test_invalid_pattern(self):
        x = np.arange(6).reshape(2, 3)
        with self.assertRaises(ValueError):
            rearrange(x, "a b -> a b c")

    def test_axis_missing_from_output(self):
        x = np.arange(12).reshape(3, 4)
        with self.assertRaises(ValueError):
            rearrange(x, "h w -> h")

    def test_missing_axes_length(self):
        x = np.arange(120).reshape(12, 10)
        with self.assertRaises(ValueError):
            rearrange(x, "(h w) c -> h w c")

    def test_missing_repeat_length(self):
        x = np.arange(15).reshape(3, 1, 5)
        with self.assertRaises(KeyError):
            rearrange(x, "a 1 c -> a b c")

    def test_extra_axes_length(self):
        x = np.arange(12).reshape(3, 4)
        with self.assertRaises(ValueError):
            rearrange(x, "h w -> w h", x=3)

    def test_complex_combined_operations(self):
        x = np.arange(6 * 4 * 5).reshape(6, 4, 5)
        result = rearrange(x, "(a b) c d -> a (c d) b", a=2)
        # Split axis 0 into (a=2, b=3), move b last, then merge (c d).
        # A plain x.reshape(2, -1, 3) has the right shape but the wrong element order.
        expected = x.reshape(2, 3, 4, 5).transpose(0, 2, 3, 1).reshape(2, 20, 3)
        np.testing.assert_array_equal(result, expected)

if __name__ == '__main__':
    # In a notebook, sys.argv holds the kernel's own arguments, so a dummy
    # argv is passed. exit=False stops unittest from calling sys.exit and
    # killing the kernel.
    unittest.main(exit=False, argv=[''])
