<a href="https://colab.research.google.com/github/Kkumar-007/einflux/blob/main/einnflux_unittest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Unit Testing Einflux**

This notebook contains unit tests for the `einflux` module (imported from the `einflux` package), written with Python's built-in `unittest` framework.


## Import necessary libraries

In [98]:
import numpy as np
import unittest
from einflux import rearrange, parse_pattern, expand_dims, get_permutation, reshape_for_output, RearrangeError

## Custom TestResult class that also records passing tests

In [99]:
class ColobTestResult(unittest.TextTestResult):
    """TextTestResult that also remembers which tests passed.

    The base result class does not keep a list of passing tests. This
    subclass exposes one as ``successes`` so callers can inspect it after a
    run.
    """

    def __init__(self, stream, descriptions, verbosity):
        super().__init__(stream, descriptions, verbosity)
        # Filled in by addSuccess, in the order tests finish.
        self.successes = []

    def addSuccess(self, test):
        # Keep the base class's normal bookkeeping/reporting, then record it.
        super().addSuccess(test)
        self.successes.append(test)

## Custom TestRunner class

In [100]:
class ColabTestRunner(unittest.TextTestRunner):
    """TextTestRunner preconfigured for use inside a notebook.

    Defaults:
      * ``verbosity=2`` - one line per test, including its docstring summary.
      * ``resultclass=ColobTestResult`` - a result that also records passing
        tests. Callers may override it by passing ``resultclass=...``.
      * ``stream=sys.stdout`` - unittest itself defaults to ``sys.stderr``.
        Notebook front ends render stdout and stderr separately, so a report
        written to stderr shows up out of order relative to surrounding
        ``print()`` calls. Pass ``stream=...`` explicitly to override.
    """

    def __init__(self, verbosity=2, **kwargs):
        import sys  # local import keeps this cell self-contained

        # Look up sys.stdout at construction time so a stdout replaced by the
        # notebook kernel (or by contextlib.redirect_stdout) is honoured.
        if kwargs.get("stream") is None:
            kwargs["stream"] = sys.stdout
        # setdefault (instead of a hard-coded keyword) lets callers supply
        # their own resultclass without a "multiple values" TypeError.
        kwargs.setdefault("resultclass", ColobTestResult)
        super().__init__(verbosity=verbosity, **kwargs)

## Function to run a single test method with framed output

In [101]:
def run_test_case(test_case_instance, test_method_name):
    """
    Run one test method of a TestCase class between two separator rules.

    Args:
        test_case_instance: The TestCase *class* (despite the parameter name);
            it is instantiated here with ``test_method_name``.
        test_method_name: Name of the test method to run

    Returns:
        The result object returned by ``ColabTestRunner.run``
    """
    separator = "-" * 80
    print(f"Running: {test_method_name}")
    print(separator)
    single_test = test_case_instance(test_method_name)
    outcome = ColabTestRunner(verbosity=2).run(single_test)
    print(separator)
    return outcome

## Function to run all tests in a test case

In [102]:
def run_test_suite(test_case_class, title=None):
    """
    Load and run every test of a TestCase class, then print a closing rule.

    Args:
        test_case_class: The TestCase class whose tests are loaded
        title: Optional heading. When truthy it is printed, followed by an
            80-character ``=`` rule, before the tests run.

    Returns:
        The result object returned by ``ColabTestRunner.run``
    """
    heading_rule = "=" * 80
    if title:
        print(f"\n{title}")
        print(heading_rule)

    loaded_suite = unittest.TestLoader().loadTestsFromTestCase(test_case_class)
    outcome = ColabTestRunner(verbosity=2).run(loaded_suite)
    print("-" * 80)
    return outcome

## TestParsePattern class

In [103]:
class TestParsePattern(unittest.TestCase):
    """Unit tests for ``parse_pattern``.

    Each case pairs a pattern string with the list ``parse_pattern`` is
    expected to return.
    - Bare names stay strings.
    - Parenthesised groups become tuples.
    - Digits, ``*`` and ``...`` are kept as string tokens.
    """

    def _expect_parses(self, cases):
        # cases: sequence of (pattern, expected) pairs, checked in order; the
        # first mismatch fails the calling test.
        for pattern, expected in cases:
            self.assertEqual(parse_pattern(pattern), expected)

    def test_basic_pattern(self):
        """Test parsing of basic patterns"""
        self._expect_parses([
            ("a b c", ["a", "b", "c"]),
            ("a", ["a"]),
            ("a b", ["a", "b"]),
        ])

    def test_tuple_pattern(self):
        """Test parsing of patterns with tuples"""
        self._expect_parses([
            ("(a b) c", [("a", "b"), "c"]),
            ("a (b c)", ["a", ("b", "c")]),
            ("(a b) (c d)", [("a", "b"), ("c", "d")]),
        ])

    def test_mixed_pattern(self):
        """Test parsing of mixed patterns with tuples and digits"""
        self._expect_parses([
            ("a (b 2)", ["a", ("b", "2")]),
            ("(a 3) b", [("a", "3"), "b"]),
        ])

    def test_wildcard_pattern(self):
        """Test parsing of patterns with wildcards"""
        self._expect_parses([
            ("a * b", ["a", "*", "b"]),
            ("a ... b", ["a", "...", "b"]),
        ])

    def test_complex_pattern(self):
        """Test parsing of complex patterns"""
        self._expect_parses([
            ("(a b) * (c d)", [("a", "b"), "*", ("c", "d")]),
            ("a (b 2) ... (c 3)", ["a", ("b", "2"), "...", ("c", "3")]),
        ])


## TestExpandDims class

In [104]:
class TestExpandDims(unittest.TestCase):
    """Tests for the expand_dims function.

    ``expand_dims`` is called with a shape, a parsed structure and a
    ``shape_dict``. The tests check both its return value and the entries it
    writes into ``shape_dict``. Error cases expect ``RearrangeError``.
    """

    def test_basic_expand(self):
        """Test basic dimension expansion"""
        shape = [2, 3, 4]
        structure = ["a", "b", "c"]
        shape_dict = {}
        result = expand_dims(shape, structure, shape_dict)

        self.assertEqual(result, [2, 3, 4])
        # Sizes of previously unknown axes are recorded in shape_dict.
        self.assertEqual(shape_dict, {"a": 2, "b": 3, "c": 4})

    def test_tuple_expand_split(self):
        """Test tuple expansion when splitting dimensions"""
        shape = [6, 4]
        structure = [("a", "b"), "c"]
        shape_dict = {"a": 2, "b": 3}
        result = expand_dims(shape, structure, shape_dict, merging=False)

        self.assertEqual(result, [2, 3, 4])
        self.assertEqual(shape_dict, {"a": 2, "b": 3, "c": 4})

    def test_tuple_expand_merge(self):
        """Test tuple expansion when merging dimensions"""
        shape_dict = {"a": 2, "b": 3, "c": 4}

        # The (a b) group is merged into one axis of size 2 * 3.
        result = expand_dims([2, 3], [("a", "b")], shape_dict, merging=True)
        self.assertEqual(result, [6])

    def test_ellipsis_expand(self):
        """Test ellipsis dimension expansion"""
        shape = [2, 3, 4, 5, 6]
        structure = ["a", "...", "b"]
        shape_dict = {}
        result = expand_dims(shape, structure, shape_dict)

        self.assertEqual(result, [2, 3, 4, 5, 6])
        # The ellipsis captures the list of middle sizes under the "..." key.
        self.assertEqual(shape_dict, {"a": 2, "b": 6, "...": [3, 4, 5]})

    def test_error_dimensions(self):
        """Test dimension mismatch error"""
        shape = [2, 3]
        structure = ["a", "b", "c"]
        shape_dict = {}

        with self.assertRaises(RearrangeError):
            expand_dims(shape, structure, shape_dict)

    def test_error_product_mismatch(self):
        """Test product mismatch error"""
        shape = [5, 3]
        structure = [("a", "b"), "c"]
        shape_dict = {"a": 2, "b": 3}  # Product is 6, but shape has 5

        with self.assertRaises(RearrangeError):
            expand_dims(shape, structure, shape_dict, merging=False)

## TestGetPermutation class

In [105]:
class TestGetPermutation(unittest.TestCase):
    """Unit tests for ``get_permutation``.

    Each case gives an input structure, an output structure and the known axis
    sizes, and compares the returned axis order with the expected indices.
    In those indices, grouped (tuple) axes and the ``*`` wildcard count as
    their individual expanded axes.
    """

    def _check(self, input_structure, output_structure, shape_dict, expected):
        permutation = get_permutation(input_structure, output_structure, shape_dict)
        self.assertEqual(permutation, expected)

    def test_basic_permutation(self):
        """Test basic dimension permutation"""
        self._check(
            ["a", "b", "c"],
            ["c", "a", "b"],
            {"a": 2, "b": 3, "c": 4},
            [2, 0, 1],
        )

    def test_tuple_permutation(self):
        """Test permutation with tuples"""
        self._check(
            ["a", ("b", "c")],
            [("c", "b"), "a"],
            {"a": 2, "b": 3, "c": 4},
            [2, 1, 0],
        )

    def test_wildcard_permutation(self):
        """Test permutation with wildcards"""
        self._check(
            ["a", "*", "b"],
            ["b", "*", "a"],
            {"a": 2, "b": 5, "*": [3, 4]},
            [3, 1, 2, 0],
        )

## TestReshapeForOutput class

In [106]:
class TestReshapeForOutput(unittest.TestCase):
    """Tests for the reshape_for_output function.

    Each test passes an array, an output structure and a dict of known axis
    sizes, and checks the returned list of output dimension sizes.

    Every fixture's element count equals the product of the expected output
    sizes. The tests therefore stay meaningful whether or not
    reshape_for_output validates the target size against ``x``.
    """

    def test_basic_reshape(self):
        """Test basic reshaping"""
        x = np.zeros((2, 3, 4))
        output_structure = ["a", "b", "c"]
        shape_dict = {"a": 2, "b": 3, "c": 4}

        result = reshape_for_output(x, output_structure, shape_dict)
        self.assertEqual(result, [2, 3, 4])

    def test_tuple_reshape(self):
        """Test reshaping with tuples (merging)"""
        x = np.zeros((2, 3, 4))
        output_structure = ["a", ("b", "c")]
        shape_dict = {"a": 2, "b": 3, "c": 4}

        result = reshape_for_output(x, output_structure, shape_dict)
        self.assertEqual(result, [2, 12])

    def test_numeric_reshape(self):
        """Test reshaping with numeric values"""
        # 12 elements == 2 * (2 * 3), consistent with the expected [2, 6]
        # output; a 24-element array could never be reshaped to [2, 6].
        x = np.zeros((2, 2, 3))
        output_structure = ["a", ("2", "b")]
        shape_dict = {"a": 2, "b": 3}

        result = reshape_for_output(x, output_structure, shape_dict)
        self.assertEqual(result, [2, 6])

    def test_wildcard_reshape(self):
        """Test reshaping with wildcards"""
        x = np.zeros((2, 3, 4, 5))
        output_structure = ["a", "*", "b"]
        shape_dict = {"a": 2, "b": 5, "*": [3, 4]}

        result = reshape_for_output(x, output_structure, shape_dict)
        self.assertEqual(result, [2, 3, 4, 5])

    def test_infer_dimension(self):
        """Test inferring a missing dimension"""
        x = np.zeros((2, 3, 4))  # 24 elements
        output_structure = ["a", "d"]
        shape_dict = {"a": 2}

        # "d" is unknown, so it must be inferred as 24 / 2 = 12.
        result = reshape_for_output(x, output_structure, shape_dict)
        self.assertEqual(result, [2, 12])

    def test_error_multiple_unknown(self):
        """Test error with multiple unknown dimensions"""
        x = np.zeros((2, 3, 4))
        output_structure = ["d", "e"]  # Both unknown
        shape_dict = {}

        with self.assertRaises(RearrangeError):
            reshape_for_output(x, output_structure, shape_dict)

## TestRearrange class

In [None]:
class TestRearrange(unittest.TestCase):
    """Tests for the main rearrange function.

    Where the expected layout is easy to express in plain NumPy, inputs hold
    distinct values (``np.arange``). The result is then compared element-wise
    against an independently built ``reshape``/``transpose`` reference. This
    catches bugs that produce the right shape but put data in the wrong place,
    which all-zero inputs cannot detect. Error-path tests check only the
    exception type.
    """

    def test_basic_transpose(self):
        """Test basic dimension transposition"""
        x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
        result = rearrange(x, "a b c -> c a b")

        self.assertEqual(result.shape, (4, 2, 3))
        np.testing.assert_array_equal(result, x.transpose(2, 0, 1))

    def test_merge_dimensions(self):
        """Test merging dimensions"""
        x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
        result = rearrange(x, "a b c -> a (b c)")

        self.assertEqual(result.shape, (2, 12))
        # Merging adjacent axes in their original order is a row-major reshape.
        np.testing.assert_array_equal(result, x.reshape(2, 12))

    def test_combine_operations(self):
        """Test combining split, merge and transpose"""
        x = np.arange(6 * 8).reshape(6, 8)
        result = rearrange(x, "(a b) (c d) -> c (b d) a", a=2, b=3, c=4, d=2)

        self.assertEqual(result.shape, (4, 6, 2))
        # Split into (a, b, c, d), reorder to (c, b, d, a), then merge (b d).
        expected = x.reshape(2, 3, 4, 2).transpose(2, 1, 3, 0).reshape(4, 6, 2)
        np.testing.assert_array_equal(result, expected)

    def test_wildcard_dimensions(self):
        """Test using wildcard dimensions"""
        x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
        result = rearrange(x, "a * b -> b * a")

        self.assertEqual(result.shape, (5, 3, 4, 2))
        np.testing.assert_array_equal(result, x.transpose(3, 1, 2, 0))

    def test_ellipsis_dimensions(self):
        """Test using ellipsis dimensions"""
        x = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
        result = rearrange(x, "a ... b -> b ... a")

        self.assertEqual(result.shape, (6, 3, 4, 5, 2))
        np.testing.assert_array_equal(result, x.transpose(4, 1, 2, 3, 0))

    def test_numeric_dimensions(self):
        """Test using numeric dimensions"""
        x = np.zeros((6, 4))
        result = rearrange(x, "(a 2) b -> 2 (a b)", a=3)

        self.assertEqual(result.shape, (2, 12))

    def test_non_numpy_error(self):
        """Test error with non-numpy input"""
        x = [1, 2, 3]

        with self.assertRaises(TypeError):
            rearrange(x, "a -> a")

    def test_pattern_format_error(self):
        """Test error with invalid pattern format"""
        x = np.zeros((2, 3))

        with self.assertRaises(RearrangeError):
            rearrange(x, "a b")  # Missing arrow

    def test_unused_args_error(self):
        """Test error with unused shape dict arguments"""
        x = np.zeros((2, 3))

        with self.assertRaises(RearrangeError):
            rearrange(x, "a b -> b a", c=4)  # c is unused

    def test_dimension_mismatch_error(self):
        """Test error with dimension count mismatch"""
        x = np.zeros((2, 3))

        with self.assertRaises(RearrangeError):
            rearrange(x, "a b c -> a b c")  # Input has only 2 dims

    def test_product_mismatch_error(self):
        """Test error with product mismatch"""
        x = np.zeros((5, 3))

        with self.assertRaises(RearrangeError):
            rearrange(x, "(a b) c -> (a c) b", a=2, b=3)  # 2*3 = 6 != 5

    def test_real_world_examples(self):
        """Test some real-world examples"""
        # Convert batch-channel-height-width to batch-height-width-channel (BCHW -> BHWC)
        x = np.arange(32 * 3 * 128 * 128).reshape(32, 3, 128, 128)
        result = rearrange(x, "b c h w -> b h w c")
        self.assertEqual(result.shape, (32, 128, 128, 3))
        np.testing.assert_array_equal(result, x.transpose(0, 2, 3, 1))

        # Identity pattern applied to the output of a matrix product: the
        # values must come back unchanged. Deterministic inputs (rather than
        # np.random.rand) keep the test reproducible.
        x = np.arange(10 * 5, dtype=float).reshape(10, 5)
        y = np.arange(5 * 7, dtype=float).reshape(5, 7)
        product = np.einsum('ij,jk->ik', x, y)
        result = rearrange(product, "i j -> i j")
        self.assertEqual(result.shape, (10, 7))
        np.testing.assert_array_equal(result, product)

    def test_performance_large_array(self):
        """Test performance with a large array (not strictly a unit test)"""
        x = np.zeros((32, 64, 64, 3))  # Typical image batch
        result = rearrange(x, "b h w c -> b c h w")
        self.assertEqual(result.shape, (32, 3, 64, 64))


## Example usage - Run a specific test

In [107]:
# Demo: run one named test method of TestParsePattern.
print("Running a single test case:")
run_test_case(test_case_instance=TestParsePattern, test_method_name="test_basic_pattern")

test_basic_pattern (__main__.TestParsePattern.test_basic_pattern)
Test parsing of basic patterns ... ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK


Running a single test case:
Running: test_basic_pattern
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------


<__main__.ColobTestResult run=1 errors=0 failures=0>

## Example usage - Run all tests in a class

In [108]:
# Demo: run every test in TestParsePattern under a heading.
print("\nRunning a full test suite:")
run_test_suite(test_case_class=TestParsePattern, title="Parse Pattern Tests")

test_basic_pattern (__main__.TestParsePattern.test_basic_pattern)
Test parsing of basic patterns ... ok
test_complex_pattern (__main__.TestParsePattern.test_complex_pattern)
Test parsing of complex patterns ... ok
test_mixed_pattern (__main__.TestParsePattern.test_mixed_pattern)
Test parsing of mixed patterns with tuples and digits ... ok
test_tuple_pattern (__main__.TestParsePattern.test_tuple_pattern)
Test parsing of patterns with tuples ... ok
test_wildcard_pattern (__main__.TestParsePattern.test_wildcard_pattern)
Test parsing of patterns with wildcards ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.020s

OK



Running a full test suite:

Parse Pattern Tests
--------------------------------------------------------------------------------


<__main__.ColobTestResult run=5 errors=0 failures=0>

## Example usage - Running selected tests

In [109]:
def run_selected_tests(test_class, test_methods):
    """
    Run a chosen subset of test methods from one TestCase class.

    Each method is run separately through ``run_test_case``. Every method
    therefore gets its own framed block and its own runner summary.

    Args:
        test_class: The TestCase class
        test_methods: Iterable of test method names, run in the given order
    """
    header = f"\nRunning selected tests from {test_class.__name__}:"
    print(header)
    print("=" * 80)

    for name in test_methods:
        run_test_case(test_class, name)

## Example usage of run_selected_tests

In [110]:
# Demo: run two specific TestParsePattern methods, each reported separately.
run_selected_tests(
    test_class=TestParsePattern,
    test_methods=["test_tuple_pattern", "test_wildcard_pattern"],
)

test_tuple_pattern (__main__.TestParsePattern.test_tuple_pattern)
Test parsing of patterns with tuples ... ok

----------------------------------------------------------------------
Ran 1 test in 0.003s

OK
test_wildcard_pattern (__main__.TestParsePattern.test_wildcard_pattern)
Test parsing of patterns with wildcards ... ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK



Running selected tests from TestParsePattern:
Running: test_tuple_pattern
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
Running: test_wildcard_pattern
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------


## Run all test suites in this notebook with a summary

In [111]:
def run_all_test_suites():
    """Run every einflux test suite and print an aggregate summary.

    Each suite is run through ``run_test_suite`` (framed, verbose output).
    The function then prints the totals of tests run, failures and errors,
    followed by an overall pass/fail verdict.

    Returns:
        True if every suite's result reports ``wasSuccessful()``, else False.
    """
    print("\nRUNNING ALL TEST SUITES")
    print("=" * 80)

    # Keep this list in sync with the TestCase classes defined in the notebook;
    # a suite missing here is silently left out of the "all tests" summary.
    test_classes = [
        (TestParsePattern, "Parse Pattern Tests"),
        (TestExpandDims, "Expand Dims Tests"),
        (TestGetPermutation, "Get Permutation Tests"),
        (TestReshapeForOutput, "Reshape For Output Tests"),
        (TestRearrange, "Rearrange Tests"),
    ]

    total_tests = 0
    total_failures = 0
    total_errors = 0
    all_passed = True

    for test_class, title in test_classes:
        result = run_test_suite(test_class, title)
        total_tests += result.testsRun
        total_failures += len(result.failures)
        total_errors += len(result.errors)
        # wasSuccessful() also treats unexpected successes as a failure,
        # which the failure/error counts alone would miss.
        all_passed = all_passed and result.wasSuccessful()

    print("\nTEST SUMMARY")
    print("=" * 80)
    print(f"Total tests run: {total_tests}")
    print(f"Failures: {total_failures}")
    print(f"Errors: {total_errors}")

    if all_passed:
        print("\n✅ ALL TESTS PASSED")
    else:
        print("\n❌ SOME TESTS FAILED")
    return all_passed


## Example usage with `unittest.main()`

In [114]:
if __name__ == '__main__':
    # Option 1: discover and run every TestCase defined in this module.
    #   argv: unittest.main() parses command-line arguments; in a notebook
    #     sys.argv presumably holds the kernel launcher's own arguments, so a
    #     placeholder argv[0] with no options is passed instead.
    #   exit=False: do not call sys.exit() after the run, so the session
    #     keeps going.
    unittest.main(
        argv=['first-arg-is-ignored'],
        verbosity=2,
        exit=False,
    )

    # Option 2: Run specific test suites
    # run_test_suite(TestParsePattern, "Parse Pattern Tests")

    # Option 3: Run all test suites with summary
    # run_all_test_suites()

test_basic_expand (__main__.TestExpandDims.test_basic_expand)
Test basic dimension expansion ... ok
test_ellipsis_expand (__main__.TestExpandDims.test_ellipsis_expand)
Test ellipsis dimension expansion ... ok
test_error_dimensions (__main__.TestExpandDims.test_error_dimensions)
Test dimension mismatch error ... ok
test_error_product_mismatch (__main__.TestExpandDims.test_error_product_mismatch)
Test product mismatch error ... ok
test_tuple_expand_merge (__main__.TestExpandDims.test_tuple_expand_merge)
Test tuple expansion when merging dimensions ... ok
test_tuple_expand_split (__main__.TestExpandDims.test_tuple_expand_split)
Test tuple expansion when splitting dimensions ... ok
test_basic_permutation (__main__.TestGetPermutation.test_basic_permutation)
Test basic dimension permutation ... ok
test_tuple_permutation (__main__.TestGetPermutation.test_tuple_permutation)
Test permutation with tuples ... ok
test_wildcard_permutation (__main__.TestGetPermutation.test_wildcard_permutation)
Tes