In [74]:
import torch.nn as nn
import torch
import json
from card_embedding import *

import timeit
def benchmark(func, *args, **kwargs):
    """Time one call of ``func`` and report the peak CUDA memory it used.

    When CUDA is available the device is synchronized before and after the
    call so that asynchronously queued kernels are included in the timing,
    and the peak-memory counter is reset first. Without CUDA the call is
    simply timed on the host, which keeps the helper usable for pure-Python
    functions on CPU-only machines.

    Args:
        func: Callable to measure.
        *args, **kwargs: Forwarded unchanged to ``func``.

    Returns:
        Tuple ``(output, elapsed_seconds, peak_bytes)`` where ``output`` is
        whatever ``func`` returned and ``peak_bytes`` is
        ``torch.cuda.max_memory_allocated()`` (``0`` when CUDA is unavailable).
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Drain pending work so it is not attributed to ``func``.
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    if use_cuda:
        # Wait for kernels launched by ``func`` before stopping the clock.
        torch.cuda.synchronize()
    end = timeit.default_timer()
    peak_bytes = torch.cuda.max_memory_allocated() if use_cuda else 0
    return output, (end - begin), peak_bytes

In [2]:
# Size argument passed to both project constructors below
# (presumably the embedding dimension — TODO confirm against card_embedding).
d=512
# One holder instance is handed to both FilterConditionEmbedding objects.
shared = SharedEmbeddingHolder(d, device='cuda')
# Instance on which .compile() is called; used as the "compiled" variant in the timings.
filter_condition_embedding = FilterConditionEmbedding(shared, d, device='cuda')
filter_condition_embedding.compile()
# Second instance built from the same holder, with no .compile() call.
filter_condition_embedding_uncompiled = FilterConditionEmbedding(shared, d, device='cuda')

In [3]:
# Random benchmark inputs: one field type, comparison operator and value per row.
batch_size = 2000
field_type = torch.randint(0, 6, (batch_size,), device='cuda')
comparison_operator = torch.randint(0, 5, (batch_size,), device='cuda')

# `value` stays 0 unless the row's field type is 3, 4 or 5, in which case it
# is drawn uniformly from [0, upper_bound). Candidates are drawn for the
# whole batch and selected with a mask, so no per-row Python loop (and no
# per-row host/device synchronisation) is needed.
value = torch.zeros_like(field_type, device='cuda')
for type_code, upper_bound in ((3, 4), (4, 10), (5, 300)):
    candidates = torch.randint(0, upper_bound, (batch_size,), device='cuda')
    value = torch.where(field_type == type_code, candidates, value)
print(field_type)
print(comparison_operator)
print(value)



tensor([3, 0, 2,  ..., 5, 0, 1], device='cuda:0')
tensor([2, 0, 3,  ..., 0, 4, 2], device='cuda:0')
tensor([ 0,  0,  0,  ..., 33,  0,  0], device='cuda:0')


In [4]:
# Arguments shared by every call below.
embedding_inputs = (field_type, comparison_operator, value)

# Warm-up: run each implementation ten times before timing.
# NOTE(review): if FilterConditionEmbedding is an nn.Module and .compile() is
# nn.Module.compile, calling .forward directly may bypass the compiled call
# path — confirm that the "compiled" timings measure what is intended.
warmup_calls = (
    filter_condition_embedding.forward,
    filter_condition_embedding.forward_old,
    filter_condition_embedding_uncompiled.forward,
    filter_condition_embedding_uncompiled.forward_old,
)
for _ in range(10):
    for warmup_call in warmup_calls:
        _ = warmup_call(*embedding_inputs)
torch.cuda.synchronize()

# One timed call per implementation; only the elapsed time is printed.
timed_calls = (
    ("forward", filter_condition_embedding.forward),
    ("forward_uncompiled", filter_condition_embedding_uncompiled.forward),
    ("forward_old", filter_condition_embedding.forward_old),
    ("forward_old_uncompiled", filter_condition_embedding_uncompiled.forward_old),
)
for label, timed_call in timed_calls:
    print(f"time for {label}: ", benchmark(timed_call, *embedding_inputs)[1])

time for forward:  0.001107612000851077
time for forward_uncompiled:  0.001174913000795641
time for forward_old:  0.08688885000083246
time for forward_old_uncompiled:  0.08165832299891918


In [None]:
def _sample_nested_tree():
    """Return a fresh copy of the sample tree shaped ((a, b), (c, (d, e)))."""

    def leaf(value):
        return {"IsLeaf": True, "Value": value}

    def branch(*children):
        return {"IsLeaf": False, "Nested": list(children)}

    return branch(
        branch(leaf("a"), leaf("b")),
        branch(leaf("c"), branch(leaf("d"), leaf("e"))),
    )


# Batch of two structurally identical, independent trees.
nested_input = [_sample_nested_tree(), _sample_nested_tree()]


In [None]:
def traverse_nested(nested_input, path_list=None):
    """Yield ``(value, path)`` for every leaf of a batch of trees.

    Each top-level entry contributes its index as the first path component.
    A leaf sitting directly at the top level is reported with that
    one-element path; leaves below it get the path built by
    ``traverse_nested_nodes``.

    Args:
        nested_input: List of nodes (dicts with ``"IsLeaf"`` and either
            ``"Value"`` or ``"Nested"``).
        path_list: Optional list used as the working path prefix. It is
            mutated while iterating and restored as each entry finishes.
            Defaults to a fresh empty list per call.

    Yields:
        Tuples ``(value, path)`` where ``path`` is a new list of index strings.
    """
    # A fresh list per call: a shared default list would keep stale
    # components if a previous generator was abandoned before backtracking.
    if path_list is None:
        path_list = []

    for i, node in enumerate(nested_input):
        path_list.append(str(i))
        if node["IsLeaf"]:
            yield (node["Value"], path_list.copy())
        else:
            yield from traverse_nested_nodes(node["Nested"], path_list)
        path_list.pop()  # Backtrack


def traverse_nested_nodes(nested_input, path_list=None):
    """Yield ``(value, path)`` for the leaves below one inner node.

    A leaf is reported with the path of its parent (its own index is not
    appended); an inner child appends its index before its subtree is
    visited. Paths are lists rather than concatenated strings.

    Args:
        nested_input: Children of the node being visited.
        path_list: Optional list holding the parent's path. Mutated while
            iterating and restored as each inner child finishes. Defaults to
            a fresh empty list per call.

    Yields:
        Tuples ``(value, path)`` where ``path`` is a copy of the current path.
    """
    if path_list is None:
        path_list = []

    for i, node in enumerate(nested_input):
        if node["IsLeaf"]:
            yield (node["Value"], path_list.copy())
        else:
            path_list.append(str(i))
            yield from traverse_nested_nodes(node["Nested"], path_list)
            path_list.pop()  # Backtrack

def traverse_nested_optimized(nested_input, path_tuple=()):
    """Yield ``(value, path)`` per leaf of a batch, with paths as tuples.

    Every top-level entry extends ``path_tuple`` with its own index. A
    top-level leaf is reported with that extended path; an inner entry is
    handed to ``traverse_nested_nodes_optimized`` together with it.
    """
    for position, entry in enumerate(nested_input):
        entry_path = path_tuple + (str(position),)
        if not entry["IsLeaf"]:
            yield from traverse_nested_nodes_optimized(entry["Nested"], entry_path)
            continue
        yield (entry["Value"], entry_path)


def traverse_nested_nodes_optimized(nested_input, path_tuple=()):
    """Yield ``(value, path)`` for the leaves below one inner node.

    Leaves are reported with the parent's path ``path_tuple`` unchanged; an
    inner child is visited recursively with its index appended to the path.
    """
    for position, child in enumerate(nested_input):
        if child["IsLeaf"]:
            yield (child["Value"], path_tuple)
            continue
        yield from traverse_nested_nodes_optimized(
            child["Nested"], path_tuple + (str(position),)
        )

def traverse_nested_v2(nested_input, path_list=None):
    """Yield ``(value, path)`` per leaf of a batch; paths are emitted as tuples.

    The working path is kept in a mutable list (append / pop while walking)
    and converted to a tuple only when a leaf is emitted.

    Args:
        nested_input: List of nodes (dicts with ``"IsLeaf"`` and either
            ``"Value"`` or ``"Nested"``).
        path_list: Optional list used as the working path prefix. Mutated
            while iterating and restored as each entry finishes. Defaults to
            a fresh empty list per call.

    Yields:
        Tuples ``(value, path)`` where ``path`` is a tuple of index strings.
    """
    # A fresh list per call: a shared default list would keep stale
    # components if a previous generator was abandoned before backtracking.
    if path_list is None:
        path_list = []

    for i, node in enumerate(nested_input):
        path_list.append(str(i))
        if node["IsLeaf"]:
            yield (node["Value"], tuple(path_list))
        else:
            yield from traverse_nested_nodes_v2(node["Nested"], path_list)
        path_list.pop()  # Backtrack


def traverse_nested_nodes_v2(nested_input, path_list=None):
    """Yield ``(value, path)`` for the leaves below one inner node.

    A leaf is reported with its parent's path as a tuple; an inner child
    appends its index to the working list before its subtree is visited.

    Args:
        nested_input: Children of the node being visited.
        path_list: Optional list holding the parent's path. Mutated while
            iterating and restored as each inner child finishes. Defaults to
            a fresh empty list per call.

    Yields:
        Tuples ``(value, path)`` where ``path`` is a tuple snapshot of the
        current path.
    """
    if path_list is None:
        path_list = []

    for i, node in enumerate(nested_input):
        if node["IsLeaf"]:
            yield (node["Value"], tuple(path_list))
        else:
            path_list.append(str(i))
            yield from traverse_nested_nodes_v2(node["Nested"], path_list)
            path_list.pop()  # Backtrack

def flatten_nested_optimized(nested_input):
    """Split the leaves from ``traverse_nested_optimized`` into two lists.

    Returns:
        ``(flattened_input, groups)``: the leaf values in traversal order and,
        at the same positions, the tuple path yielded for each leaf.
    """
    flattened_input, groups = [], []
    for leaf_value, leaf_path in traverse_nested_optimized(nested_input):
        flattened_input += [leaf_value]
        groups += [leaf_path]
    return flattened_input, groups

def flatten_nested(nested_input):
    """Split the leaves from ``traverse_nested`` into two lists.

    Returns:
        ``(flattened_input, groups)``: the leaf values in traversal order and,
        at the same positions, the list path yielded for each leaf.
    """
    leaves = list(traverse_nested(nested_input))
    if not leaves:
        return [], []
    leaf_values, leaf_paths = zip(*leaves)
    return list(leaf_values), list(leaf_paths)

def flatten(nested_input, traverse_function):
    """Flatten ``nested_input`` using a caller-supplied traversal.

    Args:
        nested_input: Structure handed unchanged to ``traverse_function``.
        traverse_function: Callable returning an iterable of
            ``(value, group_index)`` pairs.

    Returns:
        ``(flattened_input, groups)``: the values and their group indices as
        two parallel lists, in iteration order.
    """
    pairs = [(value, group_index) for value, group_index in traverse_function(nested_input)]
    flattened_input = [value for value, _ in pairs]
    groups = [group_index for _, group_index in pairs]
    return flattened_input, groups

def combine(elements):
    """Join ``elements`` into one parenthesised ``+`` expression.

    A single element is returned as-is (no parentheses); any other number of
    elements, including zero, is rendered as ``"(e1+e2+...)"``.
    """
    if len(elements) != 1:
        return f"({'+'.join(elements)})"
    return elements[0]


# Example `groups` value (one path per flattened element):
# [['root'], ['root', '1'], ['root', '1'], ['root', '1'], ['root', '2', '0'], ['root', '2', '0'], ['root', '2']]


def is_prefix(prefix, test_list):
    """Return True when ``test_list`` starts with ``prefix``.

    The comparison is ``test_list[:len(prefix)] == prefix``, so both
    arguments must be the same sequence type for a match (a list never equals
    a tuple).
    """
    prefix_length = len(prefix)
    if prefix_length > len(test_list):
        return False
    return test_list[:prefix_length] == prefix

def add_to_stack(element, current_combination, current_groups, group_index):
    """Place ``element`` into the stack frame that belongs to ``group_index``.

    ``current_groups`` holds the path of every open frame and
    ``current_combination`` the elements collected for it, at the same
    position. If the top frame already has path ``group_index`` the element
    is appended to it. Otherwise ``break_down_stack`` first closes frames
    that are not a prefix of ``group_index``; the element then joins the top
    frame if it now matches, or opens a new frame. Both lists are mutated in
    place.
    """

    def top_frame_matches():
        return current_groups and group_index == current_groups[-1]

    if not top_frame_matches():
        break_down_stack(current_combination, current_groups, group_index)
        if not top_frame_matches():
            current_combination.append([element])
            current_groups.append(group_index)
            return
    current_combination[-1].append(element)


def break_down_stack(current_combination, current_groups, group_index):
    """Close every open frame whose path is not a prefix of ``group_index``.

    A closed frame is collapsed with ``combine`` and the resulting string is
    re-inserted via ``add_to_stack`` under the frame's path minus its last
    component, i.e. as one element of its parent group.
    """
    while True:
        if not current_groups or is_prefix(current_groups[-1], group_index):
            return
        combined = combine(current_combination.pop())
        parent_index = current_groups.pop()[:-1]
        add_to_stack(combined, current_combination, current_groups, parent_index)


def reduce(flattened_input, groups):
    """Fold flattened leaves back into one combined string per batch entry.

    ``groups[i]`` is the path of ``flattened_input[i]``. Elements are pushed
    onto a stack of open groups with ``add_to_stack``; a final
    ``break_down_stack`` against the empty path closes everything that is
    still open. Paths are expected to be lists, because that final flush
    compares against ``[]`` and ``is_prefix`` uses ``==``.

    The stack starts empty and frames are opened on demand. Seeding it with a
    placeholder frame for batch ``"0"`` would make an input whose first leaf
    belongs to another batch (for instance because the first batch entry has
    no leaves) yield a spurious ``"()"`` element.

    Args:
        flattened_input: Leaf values in traversal order.
        groups: Path of each leaf, parallel to ``flattened_input``.

    Returns:
        List with the combined expression of each batch entry that
        contributed at least one leaf, or ``[]`` when there are no groups.

    Raises:
        IndexError: If ``groups`` is longer than ``flattened_input``.
    """
    current_combination = []
    current_groups = []

    for i, group_index in enumerate(groups):
        add_to_stack(
            flattened_input[i], current_combination, current_groups, group_index
        )

    # Flush: the empty path closes every remaining frame down to the root.
    break_down_stack(current_combination, current_groups, [])
    if not current_combination:
        return []
    return current_combination.pop()


def add_to_stack_optimized(element, current_combination, current_groups, group_index):
    """Place ``element`` into the stack frame that belongs to ``group_index``.

    Same procedure as ``add_to_stack``, paired with
    ``break_down_stack_optimized``: append to the top frame when its path
    equals ``group_index``; otherwise close non-matching frames first, then
    either append to the (now matching) top frame or open a new one. Both
    lists are mutated in place.
    """
    matches_top = current_groups and group_index == current_groups[-1]
    if not matches_top:
        # Close finished groups, then look at the top of the stack again.
        break_down_stack_optimized(current_combination, current_groups, group_index)
        matches_top = current_groups and group_index == current_groups[-1]

    if matches_top:
        current_combination[-1].append(element)
    else:
        current_combination.append([element])
        current_groups.append(group_index)


def break_down_stack_optimized(current_combination, current_groups, group_index):
    """Close every open frame whose path is not a prefix of ``group_index``.

    Each closed frame is collapsed with ``combine`` and pushed back through
    ``add_to_stack_optimized`` under its path with the last component sliced
    off (``[:-1]``).
    """
    while current_groups:
        if is_prefix(current_groups[-1], group_index):
            break
        combined = combine(current_combination.pop())
        parent_index = current_groups.pop()[:-1]
        add_to_stack_optimized(combined, current_combination, current_groups, parent_index)


def reduce_optimized(flattened_input, groups):
    """Fold flattened leaves back into one combined string per batch entry.

    Tuple-path counterpart of ``reduce``: ``groups[i]`` is the tuple path of
    ``flattened_input[i]`` and the final flush compares against ``()``.

    The stack starts empty and frames are opened on demand. Seeding it with a
    placeholder frame for batch ``"0"`` would make an input whose first leaf
    belongs to another batch (for instance because the first batch entry has
    no leaves) yield a spurious ``"()"`` element.

    Args:
        flattened_input: Leaf values in traversal order.
        groups: Tuple path of each leaf, parallel to ``flattened_input``.

    Returns:
        List with the combined expression of each batch entry that
        contributed at least one leaf; ``[]`` for empty input.

    Raises:
        IndexError: If ``groups`` is longer than a non-empty
            ``flattened_input``.
    """
    if not flattened_input:
        return []

    current_combination = []
    current_groups = []

    for i, group_index in enumerate(groups):
        add_to_stack_optimized(
            flattened_input[i], current_combination, current_groups, group_index
        )

    # Flush: the empty path closes every remaining frame down to the root.
    break_down_stack_optimized(current_combination, current_groups, ())
    if not current_combination:
        return []
    return current_combination.pop()



In [None]:
# Flatten the test_2 fixture (defined in the test-suite cell of this notebook)
# with the v2 traversal and show the values and paths, one per line.
flattened, groups = flatten(test_2, traverse_nested_v2)
print(flattened, groups, sep="\n")


['a', 'b', 'c']
[('0',), ('0',), ('0',)]


IndexError: list index out of range

In [186]:
# Check that the tuple paths from the v2 traversal match the list paths from
# flatten_nested once converted to lists.
_, group_tuple = flatten(long_batch, traverse_nested_v2)
_, group_list = flatten_nested(long_batch)
group_tuple_as_lists = list(map(list, group_tuple))
print(group_tuple_as_lists)
print(group_list)
print(group_tuple_as_lists == group_list)


[['0'], ['1'], ['1'], ['1'], ['2', '0'], ['2', '0'], ['2'], ['3', '0'], ['3', '0'], ['3', '1'], ['3', '1'], ['4', '0'], ['4', '0'], ['4', '1'], ['4', '1', '1'], ['4', '1', '1'], ['5', '0', '0'], ['6'], ['6'], ['6'], ['6'], ['6'], ['6'], ['7'], ['7', '1'], ['7', '1'], ['7'], ['8', '0'], ['8', '1'], ['8', '1'], ['8', '1', '2'], ['9', '0'], ['9', '0'], ['9', '0'], ['9'], ['9', '2'], ['9', '2', '1'], ['9', '2', '1'], ['9', '2', '1'], ['9'], ['10'], ['10'], ['11', '0', '0', '0'], ['11', '0', '0', '0'], ['12', '0'], ['12', '0'], ['12', '0'], ['12', '0'], ['13', '0'], ['13', '0'], ['13', '1'], ['13', '1'], ['13', '2'], ['13', '2'], ['15', '0'], ['15', '0'], ['15', '0'], ['15', '0'], ['15', '0'], ['15'], ['16'], ['17', '0', '0'], ['17', '0', '0'], ['17', '0'], ['17'], ['18'], ['18', '1'], ['18', '1', '1'], ['18', '1', '1'], ['19'], ['19', '1'], ['19', '1'], ['19', '2', '0'], ['19', '2', '0'], ['19', '2'], ['19']]
[['0'], ['1'], ['1'], ['1'], ['2', '0'], ['2', '0'], ['2'], ['3', '0'], ['3', '0'

In [195]:
# Fixtures for the flatten / reduce test suite.
#
# Every fixture is a batch: a list of trees. A leaf is
# {"IsLeaf": True, "Value": <str>}; an inner node is
# {"IsLeaf": False, "Nested": [<children>]}.
# The "->" comment on each fixture is the result listed for it in
# run_tests(). A group with a single element is not parenthesised
# (see combine), so single-child chains collapse.


def _node(child):
    """Return ``child`` as a tree node; a bare string becomes a leaf."""
    if isinstance(child, str):
        return {"IsLeaf": True, "Value": child}
    return child


def _branch(*children):
    """Build an inner node holding ``children`` (strings become leaves)."""
    return {"IsLeaf": False, "Nested": [_node(child) for child in children]}


# Test Case 1: Single leaf node
# -> ["a"]
test_1 = [_node("a")]

# Test Case 2: Flat structure - all leaves directly under one node
# -> ["(a+b+c)"]
test_2 = [_branch("a", "b", "c")]

# Test Case 3: Two-level nesting - simple binary tree
# -> ["((a+b)+c)"]
test_3 = [_branch(_branch("a", "b"), "c")]

# Test Case 4: Two-level nesting - all nested
# -> ["((a+b)+(c+d))"]
test_4 = [_branch(_branch("a", "b"), _branch("c", "d"))]

# Test Case 5: Three-level nesting (same shape as nested_input's trees)
# -> ["((a+b)+(c+(d+e)))"]
test_5 = [_branch(_branch("a", "b"), _branch("c", _branch("d", "e")))]

# Test Case 6: Single child at each level (deep chain)
# -> ["a"]
test_6 = [_branch(_branch(_branch("a")))]

# Test Case 7: Wide branching (many siblings)
# -> ["(a+b+c+d+e+f)"]
test_7 = [_branch("a", "b", "c", "d", "e", "f")]

# Test Case 8: Mixed structure - leaves and nested at same level
# -> ["(a+(b+c)+d)"]
test_8 = [_branch("a", _branch("b", "c"), "d")]

# Test Case 9: Asymmetric tree - different depths on each side
# -> ["(a+(b+c+d))"]
test_9 = [_branch(_branch("a"), _branch("b", "c", _branch("d")))]

# Test Case 10: Complex mixed structure
# -> ["((a+b+c)+d+(e+(f+g+h))+i)"]
test_10 = [
    _branch(
        _branch("a", "b", "c"),
        "d",
        _branch("e", _branch("f", "g", "h")),
        "i",
    )
]

# Test Case 11: Multiple batch entries
# -> ["(a+b)", "((c+d)+e)", "f"]
test_11 = [
    _branch("a", "b"),
    _branch(_branch("c", "d"), "e"),
    _node("f"),
]

# Test Case 12: Deep nesting (4 levels)
# -> ["(a+b)"]
test_12 = [_branch(_branch(_branch(_branch("a", "b"))))]

# Test Case 13: Single nested with multiple leaves
# -> ["(a+b+c+d)"]
test_13 = [_branch(_branch("a", "b", "c", "d"))]

# Test Case 14: Three-way branching
# -> ["((a+b)+(c+d)+(e+f))"]
test_14 = [_branch(_branch("a", "b"), _branch("c", "d"), _branch("e", "f"))]

# Test Case 15: Edge case - inner node without children
# -> []
test_15 = [_branch()]

# Test Case 16: Very wide at second level
# -> ["((a+b+c+d+e)+f)"]
test_16 = [_branch(_branch("a", "b", "c", "d", "e"), "f")]

# Test Case 17: Multiple batch entries with varying complexity
# -> ["a", "(b+c)", "(d+e)"]
test_17 = [
    _node("a"),
    _branch("b", "c"),
    _branch(_branch("d", "e")),
]

# Test Case 18: Unbalanced tree - left heavy
# -> ["(((a+b)+c)+d)"]
test_18 = [_branch(_branch(_branch("a", "b"), "c"), "d")]

# Test Case 19: Unbalanced tree - right heavy
# -> ["(a+(b+(c+d)))"]
test_19 = [_branch("a", _branch("b", _branch("c", "d")))]

# Test Case 20: Maximum complexity - all combinations
# -> ["(a+(b+c)+((d+e)+f)+g)"]
test_20 = [
    _branch(
        "a",
        _branch("b", "c"),
        _branch(_branch("d", "e"), "f"),
        "g",
    )
]
# Test runner function
def run_tests(flatten_fn=None, reduce_fn=None, test_cases=None):
    """Run the flatten/reduce fixtures and print a per-test report.

    Args:
        flatten_fn: Callable ``batch -> (flattened, groups)``. Defaults to
            ``flatten_nested_optimized``.
        reduce_fn: Callable ``(flattened, groups) -> result``. Defaults to
            ``reduce_ultra_optimized``.
            NOTE(review): that function is not defined in the visible part of
            this notebook; pass e.g. ``reduce_fn=reduce_optimized`` to check
            another implementation.
        test_cases: Iterable of ``(name, batch, expected)`` triples. Defaults
            to the built-in table over ``test_1`` .. ``test_20``.

    Returns:
        None. Each test prints PASSED / FAILED / ERROR with its details,
        followed by a summary line with the totals.
    """
    if test_cases is None:
        test_cases = [
            ("Test 1: Single leaf", test_1, ["a"]),
            ("Test 2: Flat structure", test_2, ["(a+b+c)"]),
            ("Test 3: Two-level simple", test_3, ["((a+b)+c)"]),
            ("Test 4: Two-level all nested", test_4, ["((a+b)+(c+d))"]),
            ("Test 5: Three-level (original)", test_5, ["((a+b)+(c+(d+e)))"]),
            ("Test 6: Deep chain", test_6, ["a"]),
            ("Test 7: Wide branching", test_7, ["(a+b+c+d+e+f)"]),
            ("Test 8: Mixed structure", test_8, ["(a+(b+c)+d)"]),
            ("Test 9: Asymmetric tree", test_9, ["(a+(b+c+d))"]),
            ("Test 10: Complex mixed", test_10, ["((a+b+c)+d+(e+(f+g+h))+i)"]),
            ("Test 11: Multiple batches", test_11, ["(a+b)", "((c+d)+e)", "f"]),
            ("Test 12: Deep nesting", test_12, ["(a+b)"]),
            ("Test 13: Single nested multiple leaves", test_13, ["(a+b+c+d)"]),
            ("Test 14: Three-way branching", test_14, ["((a+b)+(c+d)+(e+f))"]),
            ("Test 15: Empty", test_15, []),
            ("Test 16: Very wide second level", test_16, ["((a+b+c+d+e)+f)"]),
            ("Test 17: Multiple batches with varying complexity", test_17, ["a", "(b+c)", "(d+e)"]),
            ("Test 18: Left heavy", test_18, ["(((a+b)+c)+d)"]),
            ("Test 19: Right heavy", test_19, ["(a+(b+(c+d)))"]),
            ("Test 20: Maximum complexity", test_20, ["(a+(b+c)+((d+e)+f)+g)"]),
        ]

    print("Running test suite...\n")
    passed = 0
    failed = 0

    for name, test_input, expected in test_cases:
        try:
            # Defaults are looked up inside the try so that a missing
            # implementation is reported per test as ERROR, exactly as when
            # the names were referenced directly here.
            flatten_impl = flatten_nested_optimized if flatten_fn is None else flatten_fn
            reduce_impl = reduce_ultra_optimized if reduce_fn is None else reduce_fn
            flattened, groups = flatten_impl(test_input)
            result = reduce_impl(flattened, groups)
            if result == expected:
                print(f"✓ {name}: PASSED")
                print(f"  Result: {result}")
                print(f"  Flattened: {flattened}")
                print(f"  Groups:    {groups}")
                passed += 1
            else:
                print(f"✗ {name}: FAILED")
                print(f"  Expected: {expected}")
                print(f"  Got:      {result}")
                print(f"  Flattened: {flattened}")
                print(f"  Groups:    {groups}")
                failed += 1
        except Exception as e:
            print(f"✗ {name}: ERROR - {e}")
            failed += 1
        print()

    print(f"\nSummary: {passed} passed, {failed} failed")

# Run the whole suite; the report is printed and nothing is returned.
run_tests()

Running test suite...

✓ Test 1: Single leaf: PASSED
  Result: ['a']
  Flattened: ['a']
  Groups:    [('0',)]

✗ Test 2: Flat structure: FAILED
  Expected: ['(a+b+c)']
  Got:      ['a', 'b', 'c']
  Flattened: ['a', 'b', 'c']
  Groups:    [('0',), ('0',), ('0',)]

✗ Test 3: Two-level simple: FAILED
  Expected: ['((a+b)+c)']
  Got:      ['(a+b)', 'c']
  Flattened: ['a', 'b', 'c']
  Groups:    [('0', '0'), ('0', '0'), ('0',)]

✗ Test 4: Two-level all nested: FAILED
  Expected: ['((a+b)+(c+d))']
  Got:      ['(a+b)', '(c+d)']
  Flattened: ['a', 'b', 'c', 'd']
  Groups:    [('0', '0'), ('0', '0'), ('0', '1'), ('0', '1')]

✗ Test 5: Three-level (original): FAILED
  Expected: ['((a+b)+(c+(d+e)))']
  Got:      ['(a+b)', '(c+(d+e))']
  Flattened: ['a', 'b', 'c', 'd', 'e']
  Groups:    [('0', '0'), ('0', '0'), ('0', '1'), ('0', '1', '1'), ('0', '1', '1')]

✓ Test 6: Deep chain: PASSED
  Result: ['a']
  Flattened: ['a']
  Groups:    [('0', '0', '0')]

✗ Test 7: Wide branching: FAILED
  Expected: 

In [72]:
# One batch made of the first tree of every fixture. test_11 and test_17
# hold several trees; only their first entry is used here.
long_batch = [
    fixture[0]
    for fixture in (
        test_1, test_2, test_3, test_4, test_5,
        test_6, test_7, test_8, test_9, test_10,
        test_11, test_12, test_13, test_14, test_15,
        test_16, test_17, test_18, test_19, test_20,
    )
]

In [116]:
# Single-shot timing of the list-based flatten against the tuple-based one.
result_old, time_old, _ = benchmark(flatten_nested, long_batch)
result_new, time_new, _ = benchmark(flatten_nested_optimized, long_batch)
speed_ratio = time_old / time_new
print(
    "Time taken for old: {} seconds, new: {} seconds, ratio: {}".format(
        time_old, time_new, speed_ratio
    )
)
print(result_old, result_new, sep="\n")

Time taken for old: 9.14209995244164e-05 seconds, new: 5.426000279840082e-05 seconds, ratio: 1.6848690528838457
(['a', 'a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'a', 'a', 'b', 'c', 'd', 'e', 'f', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'a', 'b', 'a', 'b', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'f', 'a', 'b', 'c', 'd', 'e', 'f', 'a', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'f', 'g'], [['0'], ['1'], ['1'], ['1'], ['2', '0'], ['2', '0'], ['2'], ['3', '0'], ['3', '0'], ['3', '1'], ['3', '1'], ['4', '0'], ['4', '0'], ['4', '1'], ['4', '1', '1'], ['4', '1', '1'], ['5', '0', '0'], ['6'], ['6'], ['6'], ['6'], ['6'], ['6'], ['7'], ['7', '1'], ['7', '1'], ['7'], ['8', '0'], ['8', '1'], ['8', '1'], ['8', '1', '2'], ['9', '0'], ['9', '0'], ['9', '0'], ['9'], ['9', '2'], ['9', '2', '1'], ['9', '2', '1'], ['9', '2', '1'], ['9'], ['10'], ['10'], ['11', '0', '0', '0'], ['11', '0', '0', '0'], ['

In [202]:
# total_time_old = 0
# total_time_new = 0
# N=100000
# for i in range(N):
#     result_old, time_old, _ = benchmark(reduce,flatten_nested(long_batch)[0], flatten_nested(long_batch)[1])
#     result_new, time_new, _ = benchmark(reduce_optimized,flatten_nested_optimized(long_batch)[0], flatten_nested_optimized(long_batch)[1])
#     total_time_old += time_old
#     total_time_new += time_new
def test_current():
    """Flatten ``long_batch`` with ``flatten_nested`` and fold it with ``reduce``."""
    return reduce(*flatten_nested(long_batch))


def test_optimized():
    """Flatten ``long_batch`` with the tuple-path pipeline and fold it with ``reduce_optimized``."""
    return reduce_optimized(*flatten_nested_optimized(long_batch))

def test_optimized_v2():
    """Flatten ``long_batch`` via ``traverse_nested_v2`` and fold it with ``reduce_ultra_optimized``.

    NOTE(review): ``reduce_ultra_optimized`` is not defined in the visible
    part of this notebook — confirm where it comes from.
    """
    return reduce_ultra_optimized(*flatten(long_batch, traverse_nested_v2))

# Pipelines compared below, in reporting order.
pipelines = (test_current, test_optimized, test_optimized_v2)

# Warm up: 100 rounds over all three pipelines.
for _ in range(100):
    for pipeline in pipelines:
        pipeline()

# Benchmark: total time of n calls per pipeline; per-call averages are printed.
n = 100000
time_current, time_optimized, time_optimized_v2 = [
    timeit.timeit(pipeline, number=n) for pipeline in pipelines
]
print("Current:     {:.4e}s".format(time_current / n))
print("Optimized:   {:.4e}s ({:.3f}x faster)".format(time_optimized / n, time_current / time_optimized))
print("Optimized v2: {:.4e}s ({:.3f}x faster)".format(time_optimized_v2 / n, time_current / time_optimized_v2))
# Size check; result_old / result_new are set by another cell.
print(len(result_old), len(result_new), len(long_batch), sep="\n")

Current:     6.3347e-05s
Optimized:   6.2097e-05s (1.020x faster)
Optimized v2: 5.4939e-05s (1.153x faster)
19
19
20


In [None]:
# Time flatten_nested twice per round, 100 rounds; only the last round's
# numbers are reported.
# NOTE(review): "old" and "new" both call flatten_nested, so the ratio
# compares the function with itself — confirm whether the second call was
# meant to be flatten_nested_optimized.
for i in range(100):
    (result_old, time_old, _), (result_new, time_new, _) = (
        benchmark(flatten_nested, long_batch),
        benchmark(flatten_nested, long_batch),
    )
speed_ratio = time_old / time_new
print(
    "Time taken for old: {} seconds, new: {} seconds, ratio: {}".format(
        time_old, time_new, speed_ratio
    )
)
print(result_old, result_new, sep="\n")

Time taken for old: 0.00010963100066874176 seconds, new: 9.046100058185402e-05 seconds, ratio: 1.2119145262995592
(['a', 'a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'a', 'a', 'b', 'c', 'd', 'e', 'f', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'a', 'b', 'a', 'b', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'f', 'a', 'b', 'c', 'd', 'e', 'f', 'a', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'e', 'f', 'g'], ['root', 'root1', 'root1', 'root1', 'root20', 'root20', 'root2', 'root30', 'root30', 'root31', 'root31', 'root40', 'root40', 'root41', 'root411', 'root411', 'root500', 'root6', 'root6', 'root6', 'root6', 'root6', 'root6', 'root7', 'root71', 'root71', 'root7', 'root80', 'root81', 'root81', 'root812', 'root90', 'root90', 'root90', 'root9', 'root92', 'root921', 'root921', 'root921', 'root9', 'root10', 'root10', 'root11000', 'root11000', 'root120', 'root120', 'root120', 'root120', 'root130', 'root130',

In [13]:
# Display the current value of `groups` (set by an earlier flatten call).
groups

[['root'],
 ['root', '1'],
 ['root', '1'],
 ['root', '1'],
 ['root', '2', '0'],
 ['root', '2', '0'],
 ['root', '2']]

In [14]:
# Does the path of element 6 form a prefix of the path of element 5?
is_prefix(groups[6], groups[5])


True