In [12]:
import copy
import itertools
from typing import Union
from fastfusion.frontend._set_parse import InvertibleSet
from fastfusion.frontend.arch import Leaf, Storage
from fastfusion.frontend.workload.workload_spec import RankVariable, Tensor
from fastfusion import Specification

# Load the architecture, the workload and the tensor renames into a single
# Specification object.
spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full_new.yaml",
    "workloads/mha_full_new.renames.yaml",
)
workload = spec.workload
renames = spec.renames

# Everything below explores mappings for this one einsum only.
einsum_name = "QK"
einsum = workload.einsums[einsum_name]
rank_variables = einsum.rank_variables
tensors = einsum.tensors
symbol_table = workload.get_constraint_symbol_table(einsum_name, renames)
# Any symbol-table value will do: it is only used for its to_my_space()
# method when building sets further down.
first_value = next(iter(symbol_table.values()))
arch_nodes = spec.get_flattened_architecture()
tensor2rank_variables = einsum.tensor2rank_variables
# Names of the Storage nodes in architecture order. The code below compares
# positions in this list and treats the smaller index as the outer storage.
storage_order = [n.name for n in arch_nodes if isinstance(n, Storage)]

from itertools import chain, combinations
def powerset(iterable):
    """Return an iterator over every subset of *iterable*, as tuples.

    Subsets come out smallest first:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = tuple(iterable)
    subsets_of_each_size = (
        combinations(pool, size) for size in range(len(pool) + 1)
    )
    return chain.from_iterable(subsets_of_each_size)
        
def make_storage_choices_one_level(node: Leaf, symbol_table: dict[str, InvertibleSet]):
    """Yield every set of tensors that ``node`` may keep, with a symbol table.

    A node that is not a ``Storage`` yields the single pair
    ``(set(), symbol_table)``. For a ``Storage`` node the parsed "keep"
    constraint is always kept, the parsed "bypass" constraint is never kept,
    and every subset of the remaining einsum tensors is tried in turn.

    Args:
        node: Architecture node to generate choices for.
        symbol_table: Constraint symbol table; must contain the key "All".

    Yields:
        ``(keep_choice, new_symbol_table)``. ``keep_choice`` holds shallow
        copies of the kept tensors, each tagged with ``_storage_name``,
        ``_required`` and ``_uneven``. ``new_symbol_table`` is a fresh shallow
        copy of ``symbol_table`` with ``node.name`` mapped to the kept set.

    Raises:
        KeyError: if the keep or bypass constraint names tensors outside
            ``symbol_table["All"]``, or if the two constraints intersect.
    """
    if not isinstance(node, Storage):
        yield set(), symbol_table
        return

    storage_constraints = node.constraints.storage._parse(symbol_table)
    must_keep = first_value.to_my_space(storage_constraints["keep"])
    must_bypass = first_value.to_my_space(storage_constraints["bypass"])

    all_tensors = symbol_table["All"]
    unknown_keep = must_keep - all_tensors
    if unknown_keep:
        raise KeyError(f"Keep constraint for {node.name} includes tensors that are not in the einsum: {unknown_keep}")
    unknown_bypass = must_bypass - all_tensors
    if unknown_bypass:
        # Report the same difference that was tested, not a difference
        # against the module-level ``tensors``.
        raise KeyError(f"Bypass constraint for {node.name} includes tensors that are not in the einsum: {unknown_bypass}")
    if must_keep & must_bypass:
        raise KeyError(f"Keep and bypass constraints for {node.name} intersect: {must_keep & must_bypass}")

    may_keep = tensors - must_bypass - must_keep

    for subset in powerset(may_keep):
        # One table per choice, so a table yielded earlier is not overwritten
        # when the next choice is generated.
        new_symbol_table = copy.copy(symbol_table)
        subset = first_value.to_my_space(set(subset))
        named_choice = first_value.to_my_space(subset | must_keep)
        # So users can do MainMemory().tensors(). Optional. The set is bound
        # as a default argument because a plain closure would follow later
        # rebindings of the loop variable and return a different set.
        named_choice.tensors = lambda bound=named_choice: bound
        new_symbol_table[node.name] = named_choice
        assert not any(isinstance(k, str) for k in named_choice)
        # Tag copies, so the tensors stored in the symbol table stay untouched.
        keep_choice = named_choice.to_my_space({copy.copy(t) for t in named_choice})
        for t in keep_choice:
            t._storage_name = node.name
            t._required = t in must_keep
            t._uneven = node.constraints.storage.uneven
        yield keep_choice, new_symbol_table

def make_storage_choices_all_levels(nodes: list[Storage], symbol_table: dict[str, InvertibleSet]):
    """Yield every combination of per-level storage choices for ``nodes``.

    Nodes that are not ``Storage`` instances are skipped. Each yielded item
    is ``(choices, symbol_table)``, where ``choices`` maps a storage name to
    the set chosen for it and ``symbol_table`` is the table produced after
    the last level was processed.
    """
    # Skip ahead to the first Storage node.
    first_storage = 0
    while first_storage < len(nodes) and not isinstance(nodes[first_storage], Storage):
        first_storage += 1
    pending = nodes[first_storage:]

    if not pending:
        yield {}, symbol_table
        return

    head, tail = pending[0], pending[1:]
    for head_choice, head_table in make_storage_choices_one_level(head, symbol_table):
        for tail_choices, tail_table in make_storage_choices_all_levels(tail, head_table):
            yield {**tail_choices, head.name: head_choice}, tail_table

def get_storage_loop_blocks(i: int, mapping: list[str], first_non_fused_loop: int) -> tuple[list[Tensor], list[RankVariable], list[Tensor]]:
    """Split the neighbourhood of ``mapping[i]`` into three adjacent slices.

    ``mapping[i]`` must be a ``RankVariable``. The contiguous run of rank
    variables around it is located, then the contiguous run of ``Tensor``
    entries directly before that run and the one directly after it.

    Returns:
        ``(left_tensors, ranks, right_tensors)`` as slices of ``mapping``.
        The start of the rank run is raised to ``first_non_fused_loop`` when
        that is larger; the same raised index also ends ``left_tensors``.

    Raises:
        ValueError: if ``mapping[i]`` is not a ``RankVariable``.
    """
    if not isinstance(mapping[i], RankVariable):
        raise ValueError(f"Index {i} must point to a RankVariable")

    last_index = len(mapping) - 1

    # Grow the rank-variable run outwards from i, first left, then right.
    rank_lo = i
    while rank_lo > 0 and isinstance(mapping[rank_lo - 1], RankVariable):
        rank_lo -= 1
    rank_hi = i
    while rank_hi < last_index and isinstance(mapping[rank_hi + 1], RankVariable):
        rank_hi += 1

    # Tensors sitting directly before the run ...
    tensor_lo = rank_lo
    while tensor_lo > 0 and isinstance(mapping[tensor_lo - 1], Tensor):
        tensor_lo -= 1
    # ... and directly after it.
    tensor_hi = rank_hi
    while tensor_hi < last_index and isinstance(mapping[tensor_hi + 1], Tensor):
        tensor_hi += 1

    rank_lo = max(rank_lo, first_non_fused_loop)
    return (
        mapping[tensor_lo:rank_lo],
        mapping[rank_lo:rank_hi + 1],
        mapping[rank_hi + 1:tensor_hi + 1],
    )


def enumerate_rank_variable_sets(mapping: list[str]) -> list[tuple[int, "RankVariableSet"]]:
    """Return ``(index, entry)`` for each RankVariableSet in ``mapping``, in order."""
    found = []
    for position, entry in enumerate(mapping):
        if isinstance(entry, RankVariableSet):
            found.append((position, entry))
    return found

def enumerate_tensor_sets(mapping: list[str]) -> list[tuple[int, "TensorSet"]]:
    """Return ``(index, entry)`` for each TensorSet in ``mapping``, in order."""
    found = []
    for position, entry in enumerate(mapping):
        if isinstance(entry, TensorSet):
            found.append((position, entry))
    return found

def is_valid(mapping, recurse=True):
    """Return whether ``mapping`` passes the pruning rules below.

    ``mapping`` is a list whose entries are TensorSet / RankVariableSet
    objects. The rules read the module-level ``tensor2rank_variables``,
    ``storage_order``, ``tensors`` and ``rank_variables``, plus the
    ``_storage_name``, ``_required`` and ``_must_be_here`` attributes of the
    tensors found in the mapping.

    Args:
        mapping: The (possibly partial) mapping to check.
        recurse: When true, additionally reject the mapping if adding one
            more rank variable to one of its rank-variable sets still gives
            a mapping that passes with ``recurse=False``. The sets are
            mutated temporarily and restored afterwards.

    Returns:
        False as soon as one rule is violated, True otherwise.
    """
    start_from = 0  # Reassigned before its first use further down.

    # Above the first appearance of a tensor, all rank variables must be irrelevant
    # NOTE(review): the check below returns False when a rank variable is NOT
    # in tensor2rank_variables[t], which looks like the opposite of the
    # sentence above -- confirm which one is intended.
    # NOTE(review): ``[:i]`` keeps the first ``i`` rank-variable sets, whereas
    # ``i`` is an index into ``mapping``; presumably "sets located before
    # index i" was meant -- confirm.
    seen_tensors = set()
    for i, ts in enumerate_tensor_sets(mapping):
        for t in ts.tensors():
            if t in seen_tensors:
                continue
            seen_tensors.add(t)
            for _, rs in enumerate_rank_variable_sets(mapping)[:i]:
                if any(r not in tensor2rank_variables[t] for r in rs):
                    return False

    # If there are two instances of a tensor in the same tensor set, the outer
    # storage must be required. Otherwise we should have bypassed the outer.
    for _, ts in enumerate_tensor_sets(mapping):
        for t1 in ts.tensors():
            for t2 in ts.tensors():
                if t1.name != t2.name:
                    continue
                s1, s2 = t1._storage_name, t2._storage_name
                s1_idx, s2_idx = storage_order.index(s1), storage_order.index(s2)
                if s1_idx < s2_idx and not t1._required:
                    return False
                if s2_idx < s1_idx and not t2._required:
                    return False

    # # for i in range(start_from, len(mapping)):
    # #     if not isinstance(mapping[i], Tensor) or mapping[i]._must_be_here:
    # #         continue
    # #     for j in range(i+1, len(mapping)):
    # #         if isinstance(mapping[j], RankVariable):
    # #             if mapping[j] in tensor2rank_variables[mapping[i]]:
    # #                 return False
    # #             break

    # Minimize tile size for a given reuse: If a tensor is stored immediately
    # above a relevant rank variable, it should have been stored below.
    # ``[:1]`` restricts the check to the first rank-variable set found at or
    # after position i.
    for i, ts in enumerate_tensor_sets(mapping):
        for _, rs in enumerate_rank_variable_sets(mapping[i:])[:1]:
            for t in ts.tensors():
                if t._must_be_here:
                    continue
                for r in rs:
                    if r in tensor2rank_variables[t]:
                        return False

    # # No further rules until all backing storages have been seen
    # start_from becomes the index of the tensor set at which the number of
    # distinct tensors seen reaches len(tensors), or len(mapping) if that
    # never happens.
    seen_tensors = set()
    start_from = len(mapping)
    for i in range(len(mapping)):
        if isinstance(mapping[i], TensorSet):
            seen_tensors.update(mapping[i].tensors())
            if len(seen_tensors) == len(tensors):
                start_from = i
                break

    # # Maximize reuse for a given tile size: Each tensor must be immediately
    # # above an irrelevant rank variable. This is disabled for fused loops
    # # because increasing reuse may increase tensor lifetime.
    # # for i in range(start_from, len(mapping)):
    # #     if not isinstance(mapping[i], Tensor):
    # #         continue
    # #     for j in range(i+1, len(mapping)):
    # #         if isinstance(mapping[j], RankVariable):
    # #             if mapping[j] in tensor2rank_variables[mapping[i]]:
    # #                 return False
    # #             break

    # Maximize reuse for a given tile size: If a tensor is stored immediately
    # below an irrelevant rank variable, it should have been stored above. Disabled
    # for fused loops because increasing reuse may increase tensor lifetime.
    for i, rs in enumerate_rank_variable_sets(mapping):
        if i < start_from:
            continue
        for _, ts in enumerate_tensor_sets(mapping[i:])[:1]:
            for t in ts.tensors():
                if t._must_be_here:
                    continue
                for r in rs:
                    if r not in tensor2rank_variables[t]:
                        return False

    # # Maximize number of loops: If there is a set of rank variables and the right
    # # tensor group is complete, then it should contain the intersection of all
    # # left-irrelevant and right-relevant rank variables.
    # indent = sum(len(s) for s in mapping)
    # for i, rs in enumerate_rank_variable_sets(mapping):
    #     if i < start_from:
    #         continue
    #     prev_tensors, next_tensors = None, None
    #     for j, ts in enumerate_tensor_sets(mapping):
    #         if j <= i:
    #             prev_tensors = ts
    #         if next_tensors is None and j > i:
    #             next_tensors = ts
    #             break

    #     if prev_tensors is None or next_tensors is None:
    #         continue

    #     if j == len(mapping) - 1: # Right storage node may be incomplete
    #         continue

    #     if prev_tensors is None or next_tensors is None:
    #         continue

    #     left_irrelevent = set.intersection(*(tensor2rank_variables[t] for t in prev_tensors.tensors() if not t._must_be_here), rank_variables)
    #     right_relevant = set.intersection(*(tensor2rank_variables[t] for t in next_tensors.tensors() if not t._must_be_here))
    #     intersection = left_irrelevent & right_relevant
    #     if len(mapping) > 3 and str(mapping[1]) == "{b, d, h}":
    #         print("AHH")
    #     if len(intersection) > len(rs):
    #         # print("  " * indent + f'Drop due to {i}:'.ljust(15) + f'{mapping}')
    #         # if i == 3 and str(mapping[1]) == "{b, e, h, m}":
    #         #     print(f'{i}: {mapping}')
    #         return False
    # # print("  " * indent + f'Passed:'.ljust(15) + f'{mapping}')

    # Maximality check: if one more rank variable can be added to a
    # rank-variable set and the result still passes the rules above, this
    # mapping is rejected.
    if recurse:
        for i, rs in enumerate_rank_variable_sets(mapping):
            # Can't do this if it's the last entry in the mapping
            # OR the next storage is incomplete
            if i >= len(mapping) - 2:
                continue
            for r in rank_variables:
                if r in rs:
                    continue
                # Try the enlarged set in place, then undo the change.
                rs.add(r)
                valid = is_valid(mapping, recurse=False)
                rs.remove(r)
                if valid:
                    return False

    return True

class TensorSet(set):
    """Set of ``(tensor, storage_name)`` tuples that prints in sorted order."""

    def __str__(self):
        rendered = ", ".join(map(str, sorted(self)))
        return "{" + rendered + "}"

    def __repr__(self):
        return str(self)

    def tensors(self) -> list[Tensor]:
        """Return the first item of every element (iteration order of the set)."""
        firsts = []
        for entry in self:
            firsts.append(entry[0])
        return firsts

class RankVariableSet(set):
    """Set of rank variables that prints its members in sorted order."""

    def __str__(self):
        rendered = ", ".join(map(str, sorted(self)))
        return "{" + rendered + "}"

    def __repr__(self):
        return str(self)

def add_to_mapping(mapping: list[str], choice: Union[Tensor, RankVariable]):
    """Add ``choice`` to the last set of ``mapping``, opening a new set if needed.

    A tuple ``choice`` belongs in a TensorSet, anything else in a
    RankVariableSet. A new set of the right kind is appended when
    ``mapping`` is empty or when its last entry is of the other kind.

    Args:
        mapping: Mapping to extend; modified in place.
        choice: The entry to add.
    """
    target_type = TensorSet if isinstance(choice, tuple) else RankVariableSet
    # ``not mapping`` guards the ``mapping[-1]`` lookup on an empty mapping,
    # which would otherwise raise IndexError.
    if not mapping or not isinstance(mapping[-1], target_type):
        mapping.append(target_type())
    mapping[-1].add(choice)

def pop_from_mapping(mapping: list[str], choice: Union[Tensor, RankVariable]):
    """Remove ``choice`` from the last set of ``mapping``; drop the set if it empties.

    A tuple ``choice`` is expected in a TensorSet, anything else in a
    RankVariableSet.

    Raises:
        ValueError: if the last entry of ``mapping`` is not of the expected kind.
    """
    expected_kind = TensorSet if isinstance(choice, tuple) else RankVariableSet
    tail = mapping[-1]
    if not isinstance(tail, expected_kind):
        raise ValueError(f"Last element of mapping is not a {expected_kind.__name__}")
    tail.remove(choice)
    if len(tail) == 0:
        mapping.pop()
    
def recursive_build_mapping(
    mapping: list[str],
    nodes: list[Storage],
    tensor2remaining: dict[Tensor, int],
):
    """Depth-first enumeration of mappings that place every remaining tensor.

    At each step one candidate is added to ``mapping``: either a tensor that
    still has entries in ``tensor2remaining`` or one of the module-level
    ``rank_variables``. The search continues only while ``is_valid``
    accepts the partial mapping.

    Args:
        mapping: Partial mapping; extended and restored in place.
        nodes: Passed through to the recursive calls; not read here.
        tensor2remaining: Maps each tensor to the list of its entries that
            still have to be placed; entries are popped and pushed back in
            place.

    Yields:
        ``mapping`` itself once nothing remains to be placed. The same list
        object is mutated again after the consumer resumes the generator,
        so copy it if it has to be kept.
    """
    if not tensor2remaining or not any(tensor2remaining.values()):
        yield mapping
        return

    choices = [t for t, remaining in tensor2remaining.items() if remaining]
    choices += rank_variables

    for choice in choices:
        # No duplicate ranks
        if mapping and choice in mapping[-1]:
            continue

        is_tensor = choice in tensors
        to_add = choice
        if is_tensor:
            # Index with the key itself, the same way ``choices`` was built
            # above, rather than with ``choice.name``.
            prev = tensor2remaining[choice].pop()
            to_add = (prev, prev._storage_name)

        add_to_mapping(mapping, to_add)
        if is_valid(mapping):
            yield from recursive_build_mapping(mapping, nodes, tensor2remaining)
        pop_from_mapping(mapping, to_add)

        if is_tensor:
            # Put the entry back so sibling branches see the same state.
            tensor2remaining[choice].append(prev)

def valid_storage_order(mapping: list[str]):
    """Check the storage ordering of same-named tensors across ``mapping``.

    Every pair of entries ``(mapping[i], mapping[j])`` with ``i <= j`` is
    compared. For two tensors with the same name the mapping is rejected if:

    * the entries are the same or adjacent, and the tensor whose storage
      comes first in ``storage_order`` is not ``_required``; or
    * the later entry uses a storage that comes earlier in ``storage_order``.

    Returns:
        True when no pair is rejected, False otherwise.
    """
    for upper_pos, upper_entry in enumerate(mapping):
        for lower_pos in range(upper_pos, len(mapping)):
            lower_entry = mapping[lower_pos]
            back_to_back = upper_pos == lower_pos or upper_pos == lower_pos - 1
            pairs = itertools.product(upper_entry.tensors(), lower_entry.tensors())
            for upper_t, lower_t in pairs:
                if upper_t.name != lower_t.name:
                    continue

                upper_level = storage_order.index(upper_t._storage_name)
                lower_level = storage_order.index(lower_t._storage_name)

                # If a tensor is stored in two levels back-to-back, then we
                # should have bypassed the outer storage if possible.
                if back_to_back:
                    if upper_level < lower_level and not upper_t._required:
                        return False
                    if lower_level < upper_level and not lower_t._required:
                        return False

                # Ensure order
                if upper_pos < lower_pos and lower_level < upper_level:
                    return False
    return True
            
def recursive_order_storage_choices(
    mapping: list[str],
    nodes: list[Storage],
    remaining_choices: list,
):
    """Yield every ordering of ``remaining_choices`` appended to ``mapping``.

    A choice is appended only if ``valid_storage_order`` accepts the
    extended mapping. Both lists are modified in place while the search
    runs, and the yielded object is ``mapping`` itself, so copy it if it
    has to outlive the next step of the generator. ``nodes`` is only passed
    through to the recursive calls.
    """
    if not remaining_choices:
        yield mapping
        return

    # Iterate over a snapshot: the list is edited inside the loop.
    for candidate in tuple(remaining_choices):
        mapping.append(candidate)
        remaining_choices.remove(candidate)
        if valid_storage_order(mapping):
            yield from recursive_order_storage_choices(mapping, nodes, remaining_choices)
        # Undo both edits before trying the next candidate.
        mapping.pop()
        remaining_choices.append(candidate)

def insert_temporal_loops(mapping: list[str]):
    """Insert a RankVariableSet after every entry of ``mapping``, in place.

    For the entry at each visited position the inserted set starts as all
    module-level ``rank_variables``, minus the rank variables of every
    tensor in that entry that is not ``_must_be_here``. If another entry
    follows, the set is then intersected with the rank variables of each of
    its tensors that is not ``_must_be_here``.

    Returns:
        The same ``mapping`` list, now with the inserted sets.
    """
    position = 0
    while position < len(mapping):
        has_successor = position < len(mapping) - 1
        loops = set(rank_variables)

        for tensor in mapping[position].tensors():
            if tensor._must_be_here:
                continue
            loops -= tensor2rank_variables[tensor]

        if has_successor:
            for tensor in mapping[position + 1].tensors():
                if tensor._must_be_here:
                    continue
                loops &= tensor2rank_variables[tensor]

        mapping.insert(position + 1, RankVariableSet(loops))
        # Step over the entry just handled and the set just inserted.
        position += 2
    return mapping


# When true, a mapping whose storages straddle a spatial fanout is still
# yielded (without spatial loops for that node) instead of being dropped.
UNEVEN_CROSSES_SPATIAL_BOUNDARIES = False
def insert_spatial_loops(mapping: list[str], nodes: list[Leaf]):
    """Yield ``mapping`` with "X"/"Y" sets inserted for nodes that have fanout.

    Nodes are handled one at a time, recursively. A node whose
    ``spatial.fanout_X`` and ``spatial.fanout_Y`` are both not greater than 1
    is skipped. The yielded object is ``mapping`` itself; the inserted sets
    are removed again once the consumer resumes the generator, so use or
    copy each result before asking for the next one.
    """
    if not nodes:
        yield mapping
        return
    node, remaining_nodes = nodes[0], nodes[1:]

    if not (node.spatial.fanout_X > 1 or node.spatial.fanout_Y > 1):
        yield from insert_spatial_loops(mapping, remaining_nodes)
        return

    # Insert fanout below all storage nodes above this one and above all storage
    # nodes at or below this one
    last_over, first_under = 0, len(mapping)
    for position, tensor_set in enumerate_tensor_sets(mapping):
        for tensor in tensor_set.tensors():
            if storage_order.index(tensor._storage_name) < storage_order.index(node.name):
                last_over = max(last_over, position)
            else:
                first_under = min(first_under, position)

    if last_over >= first_under:
        if UNEVEN_CROSSES_SPATIAL_BOUNDARIES:
            yield mapping
        return

    insert_point = first_under - 1
    has_y = node.spatial.fanout_Y > 1
    has_x = node.spatial.fanout_X > 1
    # Y goes in first and X second at the same index, so X ends up before Y.
    for present, label in ((has_y, "Y"), (has_x, "X")):
        if present:
            mapping.insert(insert_point, RankVariableSet([label]))
    yield from insert_spatial_loops(mapping, remaining_nodes)
    # Take the inserted sets out again, one pop per set that was inserted.
    for present in (has_x, has_y):
        if present:
            mapping.pop(insert_point)

# Driver: enumerate the storage choices for every level, order them into
# mappings, insert temporal and spatial loops, and print each result.
# (Back-to-back storages of the same tensor with an optional outer storage
# are rejected inside valid_storage_order.)
uneven_storages = [n for n in arch_nodes if n.constraints.storage.uneven]  # not read again in this cell
storage_choice_options = [s for s, _ in make_storage_choices_all_levels(arch_nodes, symbol_table)]
import time

t0 = time.time()  # only referenced by the commented-out driver below
mappings_count = 0  # only referenced by the commented-out driver below
main_memory = arch_nodes[0]  # the first flattened node is treated as main memory
n_mappings = 0

for i, storage_choices in enumerate(storage_choice_options):
    # if i != 0:
    #     continue
    print(f"{i}/{len(storage_choice_options)}: {storage_choices}")
    flattened_storage_choices = []
    # The mapping always starts with the tensors kept in main memory.
    mapping_base = [TensorSet(
        (t, t._storage_name)
        for t in storage_choices[main_memory.name]
    )]
    for t in mapping_base[-1]:
        t[0]._must_be_here = True

    # Every other (tensor, storage) pair becomes its own single-element
    # TensorSet whose position in the mapping is still to be decided.
    for k, v in storage_choices.items():
        if k != main_memory.name:
            flattened_storage_choices.extend(TensorSet([(t, t._storage_name)]) for t in v)
            for t in v:
                t._must_be_here = False

    for mapping in recursive_order_storage_choices(mapping_base, arch_nodes, flattened_storage_choices):
        # print(mapping)
        # Copy first: the generator yields the list it keeps mutating, and
        # insert_temporal_loops inserts into its argument in place.
        mapping = [x for x in mapping]
        mapping = insert_temporal_loops(mapping)
        for mapping2 in insert_spatial_loops(mapping, arch_nodes):
            print(f"\t{mapping2}")
            n_mappings += 1

print(f"Total mappings: {n_mappings}")
#     if i != 1:
#         continue
#     # mapping_base = [TensorSet(storage_choices[main_memory.name])]
#     main_memory = arch_nodes[0]
#     mapping_base = [TensorSet(
#         (t, t._storage_name)
#         for t in storage_choices[main_memory.name]
#     )]
#     for t in mapping_base[-1]:
#         t[0]._must_be_here = True
#     tensor2choices = {t: [] for t in tensors}
#     for k, v in storage_choices.items():
#         for t in v:
#             if k != main_memory.name:
#                 tensor2choices[t].append(t)
#                 t._must_be_here = False

#     print(f'{i}/{len(storage_choice_options)}: {storage_choices}')

#     for mapping in recursive_build_mapping(mapping_base, arch_nodes, tensor2choices):
#         mappings_count += 1
#         elapsed_time = time.time() - t0
#         if mappings_count % 10000 == 0:
#             print(f'Mappings/second: {mappings_count / elapsed_time}')
#         print(f"\t{mapping}")

# print(f"Total mappings: {mappings_count}")
# print(f'Mappings/second: {mappings_count / elapsed_time}')


# # [{(I, 'MainMemory'), (Q, 'MainMemory'), (WQ, 'MainMemory')}, {b, e, h, m

2025-05-11 19:58:39 INFO        Loading yaml file architecture/four_level.arch.yaml
2025-05-11 19:58:39 INFO        Found top key variables in architecture/four_level.arch.yaml
2025-05-11 19:58:39 INFO        Found top key architecture in architecture/four_level.arch.yaml
2025-05-11 19:58:39 INFO        Found top key compound_components in architecture/four_level.arch.yaml
2025-05-11 19:58:39 INFO        Loading yaml file workloads/mha_full_new.yaml
2025-05-11 19:58:39 INFO        Found top key workload in workloads/mha_full_new.yaml
2025-05-11 19:58:39 INFO        Loading yaml file workloads/mha_full_new.renames.yaml
2025-05-11 19:58:39 INFO        Found top key renames in workloads/mha_full_new.renames.yaml
2025-05-11 19:58:39 INFO        Loading yaml file /root/.config/fastfusion/config.yaml
2025-05-11 19:58:39 INFO        Found top key version in /root/.config/fastfusion/config.yaml
2025-05-11 19:58:39 INFO        Found top key environment_variables in /root/.config/fastfusion/conf

0/27: {'Register': {K}, 'LocalBuffer': {K, Q, QK}, 'GlobalBuffer': {K, Q, QK}, 'MainMemory': {}}
	[{}, {b, h, m, p}, {(QK, 'GlobalBuffer')}, {e}, {(K, 'GlobalBuffer')}, {m}, {(Q, 'GlobalBuffer')}, {X}, {p}, {(QK, 'LocalBuffer')}, {e}, {(K, 'LocalBuffer')}, {m}, {(Q, 'LocalBuffer')}, {X}, {Y}, {p}, {(K, 'Register')}, {m}]
	[{}, {b, h, m, p}, {(QK, 'GlobalBuffer')}, {e}, {(K, 'GlobalBuffer')}, {m}, {(Q, 'GlobalBuffer')}, {X}, {p}, {(QK, 'LocalBuffer')}, {e}, {(Q, 'LocalBuffer')}, {p}, {(K, 'LocalBuffer')}, {X}, {Y}, {}, {(K, 'Register')}, {m}]
	[{}, {b, h, m, p}, {(QK, 'GlobalBuffer')}, {e}, {(K, 'GlobalBuffer')}, {m}, {(Q, 'GlobalBuffer')}, {X}, {p}, {(K, 'LocalBuffer')}, {m}, {(Q, 'LocalBuffer')}, {p}, {(QK, 'LocalBuffer')}, {X}, {Y}, {e}, {(K, 'Register')}, {m}]
	[{}, {b, h, m, p}, {(QK, 'GlobalBuffer')}, {e}, {(K, 'GlobalBuffer')}, {m}, {(Q, 'GlobalBuffer')}, {X}, {p}, {(K, 'LocalBuffer')}, {m}, {(QK, 'LocalBuffer')}, {e}, {(Q, 'LocalBuffer')}, {X}, {Y}, {p}, {(K, 'Register')}, {m}]
