In [None]:
import copy
import itertools
from typing import Union
from fastfusion.frontend._set_parse import InvertibleSet
from fastfusion.frontend.arch import Leaf, Storage
from fastfusion.frontend.workload.workload_spec import RankVariable, Tensor
from fastfusion import Specification

# # Example mapping node
# type: "temporal"
# rank: P
# # Choose one of the following cases
# # Case 1.a
# tile_shape: 3   # will make tile shapes with shape 3
# # Case 1.b
# tile_shape: null # will create a sympy symbol to represent tile shape and use that
# # Case 2.a
# factor: 3       # will make 3 tiles, shaped as evenly as possible
# # Case 2.b
# factor: null    # will create a sympy symbol to represent the factor, then same as 2.a
# # Case 3   (I'm only showing null from now on)
# tile_pattern:
#   stride: null
#   initial_shape: null  # This will create tiles like this: [0, 1, ..., initial_shape - 1], [initial_shape, ..., initial_shape + stride - 1], [initial_shape + stride, ..., initial_shape + 2*stride - 1], ...
# # Case 4
# tile_pattern:
#   stride: null
#   shape: null      # This will create tiles like this: [0, 1, ..., shape-1], [stride, stride+1, ..., stride + shape-1], [2*stride, 2*stride + 1, ..., 2*stride + shape - 1], ...

#         choices = list(integer_factorizations_to_n_parts(rank_size, len(loops)))

# Tile shape constraint: Applies to all tensor(s) in a storage node for which that tile shape is relevant
# Loop bound constraint: Only for spatial

# Load the architecture, the workload, and the workload's rename table.
spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full_new.yaml",
    "workloads/mha_full_new.renames.yaml",
)
workload, renames = spec.workload, spec.renames

# Everything below is computed for this single einsum.
einsum_name = "K"
einsum = workload.einsums[einsum_name]
rank_variables, tensors = einsum.rank_variables, einsum.tensors
symbol_table = workload.get_constraint_symbol_table(einsum_name, renames)
# Any one entry of the symbol table; the rest of the script only uses it for
# its `to_my_space` conversion.
first_value = next(iter(symbol_table.values()))
arch_nodes = spec.get_flattened_architecture()
tensor2rank_variables = einsum.tensor2rank_variables
# Names of the storage nodes, in flattened-architecture order.
storage_order = [
    node.name for node in filter(lambda n: isinstance(n, Storage), arch_nodes)
]
# Every rank variable is given the same hard-coded size of 16.
rank_variable_to_size = dict.fromkeys(rank_variables, 16)

from itertools import chain, combinations
def powerset(iterable):
    """Return an iterator over every subset of *iterable* as tuples.

    Subsets are produced in order of increasing size:
    powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    pool = list(iterable)
    subsets_by_size = (combinations(pool, size) for size in range(len(pool) + 1))
    return chain.from_iterable(subsets_by_size)
        
def make_storage_choices_one_level(node: Leaf, symbol_table: dict[str, InvertibleSet]):
    """Yield every set of tensors that ``node`` may keep, one choice at a time.

    Args:
        node: Architecture node to make choices for. Nodes that are not
            ``Storage`` produce a single empty choice and the symbol table
            is passed through untouched.
        symbol_table: Constraint symbol table used to parse the node's
            keep/bypass constraints. It must contain an ``"All"`` entry.

    Yields:
        ``(keep_choice, new_symbol_table)`` pairs. ``keep_choice`` holds
        shallow copies of the kept tensors, each tagged with
        ``_storage_name``, ``_required`` and ``_uneven``.
        ``new_symbol_table`` is a fresh shallow copy of ``symbol_table`` with
        ``node.name`` mapped to the (untagged) kept set, so previously
        yielded tables are never modified by later choices.

    Raises:
        KeyError: If the keep or bypass constraint names tensors outside
            ``symbol_table["All"]``, or if the two constraints overlap.
    """
    if not isinstance(node, Storage):
        yield set(), symbol_table
        return

    # Snapshot the table before parsing, as the per-choice tables are derived
    # from the state the caller handed in.
    base_table = copy.copy(symbol_table)
    storage_constraints = node.constraints.storage._parse_keep_bypass(symbol_table)
    must_keep = first_value.to_my_space(storage_constraints["keep"])
    must_bypass = first_value.to_my_space(storage_constraints["bypass"])

    all_tensors = base_table["All"]
    unknown_keep = must_keep - all_tensors
    if unknown_keep:
        raise KeyError(f"Keep constraint for {node.name} includes tensors that are not in the einsum: {unknown_keep}")
    # Report exactly the set that was tested (the message previously
    # subtracted a different set than the condition did).
    unknown_bypass = must_bypass - all_tensors
    if unknown_bypass:
        raise KeyError(f"Bypass constraint for {node.name} includes tensors that are not in the einsum: {unknown_bypass}")
    overlap = must_keep & must_bypass
    if overlap:
        raise KeyError(f"Keep and bypass constraints for {node.name} intersect: {overlap}")

    # Tensors that are neither forced in nor forced out are free choices.
    may_keep = tensors - must_bypass - must_keep
    uneven = node.constraints.storage.uneven

    for subset in powerset(may_keep):
        optional_kept = first_value.to_my_space(set(subset))
        table_entry = first_value.to_my_space(optional_kept | must_keep)
        # So users can do MainMemory().tensors(). Optional. The default
        # argument binds this iteration's set; a plain closure would follow
        # the variable as later iterations rebind it.
        table_entry.tensors = lambda bound=table_entry: bound
        assert not any(isinstance(k, str) for k in table_entry)

        # One table per choice so that materializing the generator does not
        # leave every result pointing at the same, last-written dict.
        new_symbol_table = copy.copy(base_table)
        new_symbol_table[node.name] = table_entry

        # Tag copies so the same tensor can carry different annotations at
        # different storage levels.
        keep_choice = table_entry.to_my_space({copy.copy(t) for t in table_entry})
        for t in keep_choice:
            t._storage_name = node.name
            t._required = t in must_keep
            t._uneven = uneven
        yield keep_choice, new_symbol_table

def make_storage_choices_all_levels(nodes: list[Storage], symbol_table: dict[str, InvertibleSet]):
    """Yield every combination of per-level storage choices for ``nodes``.

    Leading entries of ``nodes`` that are not ``Storage`` are skipped. The
    first storage node's choices are combined recursively with the choices of
    the nodes after it.

    Yields:
        ``(choices, symbol_table)`` pairs, where ``choices`` maps storage node
        name to the chosen tensor set (deeper nodes first, this node last) and
        ``symbol_table`` is the table produced by the innermost level. With no
        storage node left, a single ``({}, symbol_table)`` pair is yielded.
    """
    first_storage = next(
        (pos for pos, candidate in enumerate(nodes) if isinstance(candidate, Storage)),
        None,
    )
    if first_storage is None:
        yield {}, symbol_table
        return

    head, deeper = nodes[first_storage], nodes[first_storage + 1:]
    for head_choice, head_table in make_storage_choices_one_level(head, symbol_table):
        for deeper_choices, final_table in make_storage_choices_all_levels(deeper, head_table):
            combined = dict(deeper_choices)
            combined[head.name] = head_choice
            yield combined, final_table

def enumerate_rank_variable_sets(mapping: list[str]) -> list[tuple[int, "RankVariableSet"]]:
    """Return ``(index, entry)`` for each ``RankVariableSet`` entry of ``mapping``."""
    return [
        indexed
        for indexed in enumerate(mapping)
        if isinstance(indexed[1], RankVariableSet)
    ]

def enumerate_tensor_sets(mapping: list[str]) -> list[tuple[int, "TensorSet"]]:
    """Return ``(index, entry)`` for each ``TensorSet`` entry of ``mapping``."""
    return [
        indexed
        for indexed in enumerate(mapping)
        if isinstance(indexed[1], TensorSet)
    ]

class TensorSet(set):
    """A set of ``(tensor, storage_name)`` pairs with a sorted, compact display."""

    def __str__(self):
        # Sort so the rendering does not depend on set iteration order.
        rendered = ", ".join(map(str, sorted(self)))
        return "{" + rendered + "}"

    def __repr__(self):
        return str(self)

    def tensors(self) -> list[Tensor]:
        """Return the first item of every member (the tensor of each pair)."""
        return list(map(lambda member: member[0], self))

class RankVariableSet(set):
    """A set of rank variables, optionally tagged as a spatial loop group.

    Args:
        *args: Forwarded to ``set`` (typically one iterable of rank variables).
        spatial: Spatial dimension label (the script uses ``"X"`` / ``"Y"``),
            or ``None`` for a non-spatial group.
        spatial_node: Name of the architecture node providing the fanout, or
            ``None``.
        **kwargs: Forwarded to ``set``.
    """

    def __init__(self, *args, spatial=None, spatial_node=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.spatial = spatial
        self.spatial_node = spatial_node

    def __str__(self):
        """Render the sorted members, followed by a spatial tag when present.

        Untagged sets render exactly as before (``{ a, b}``). Tagged sets
        append ``[node.dimension]`` after the closing brace. Previously the
        node name was computed but never shown, and the dimension was glued
        onto the last member (``{ a, bX}``), which read like a different rank
        variable.
        """
        members = ", ".join(str(r) for r in sorted(self))
        tag_parts = [str(part) for part in (self.spatial_node, self.spatial) if part]
        tag = f"[{'.'.join(tag_parts)}]" if tag_parts else ""
        return f"{{ {members}}}{tag}"

    def __repr__(self):
        return str(self)
    
def valid_storage_order(mapping: list[str]):
    """Check whether the storage entries of ``mapping`` are in an allowed order.

    Every pair of entries (including an entry with itself) is examined for
    tensors sharing a name. The mapping is rejected when:

    * the two entries are the same or adjacent, and the tensor stored at the
      earlier ``storage_order`` level is not marked ``_required``; or
    * a later entry stores the tensor at an earlier ``storage_order`` level
      than a preceding entry does.

    Returns:
        ``True`` if no pair violates the rules above, else ``False``.
    """
    positioned = list(enumerate(mapping))
    for (first_pos, first), (second_pos, second) in itertools.combinations_with_replacement(positioned, 2):
        # Same entry or immediate neighbours in the mapping.
        back_to_back = second_pos - first_pos <= 1
        for first_tensor, second_tensor in itertools.product(first.tensors(), second.tensors()):
            if first_tensor.name != second_tensor.name:
                continue

            first_storage = first_tensor._storage_name
            second_storage = second_tensor._storage_name
            first_level = storage_order.index(first_storage)
            second_level = storage_order.index(second_storage)

            # If a tensor is stored in two levels back-to-back, then we
            # should have bypassed the outer storage if possible.
            if back_to_back:
                if first_level < second_level and not first_tensor._required:
                    return False
                if second_level < first_level and not second_tensor._required:
                    return False

            # Ensure order
            if first_pos < second_pos and second_level < first_level:
                return False
    return True
            
def recursive_order_storage_choices(
    mapping: list[str],
    nodes: list[Storage],
    remaining_choices: list,
):
    if not remaining_choices:
        yield mapping
        return

    for choice in list(remaining_choices):
        mapping.append(choice)
        remaining_choices.remove(choice)
        if valid_storage_order(mapping):
            yield from recursive_order_storage_choices(mapping, nodes, remaining_choices)
        mapping.pop()
        remaining_choices.append(choice)

def insert_temporal_loops(mapping: list[str]):
    """Insert one non-spatial ``RankVariableSet`` after every entry of ``mapping``.

    ``mapping`` is modified in place and also returned. For each original
    entry the inserted set starts as all ``rank_variables`` and is narrowed:

    * rank variables of the entry's tensors are removed, except for tensors
      that are ``_must_be_here`` and were already seen in an earlier entry;
    * it is intersected with the rank variables of each tensor in the next
      entry that is not ``_must_be_here``;
    * it is intersected with the rank variables of every einsum tensor not
      yet seen up to and including this entry.
    """
    already_seen = set()
    entry_count = len(mapping)

    for step in range(entry_count):
        # Each earlier step inserted one loop set, so the step-th original
        # entry now sits at index 2 * step.
        here = 2 * step
        current_tensors = mapping[here].tensors()

        candidates = set(rank_variables)
        for tensor in current_tensors:
            if tensor._must_be_here and tensor in already_seen:
                continue
            candidates -= tensor2rank_variables[tensor]

        if step + 1 < entry_count:
            for tensor in mapping[here + 1].tensors():
                if not tensor._must_be_here:
                    candidates &= tensor2rank_variables[tensor]

        already_seen.update(current_tensors)
        for tensor in tensors - already_seen:
            candidates &= tensor2rank_variables[tensor]

        mapping.insert(here + 1, RankVariableSet(candidates))
    return mapping


# When False, a mapping is dropped if tensors stored above a fanout node and
# tensors stored at/below it are interleaved in the mapping.
UNEVEN_CROSSES_SPATIAL_BOUNDARIES = False
def insert_spatial_loops(mapping: list[str], nodes: list[Leaf]):
    """Yield ``mapping`` with spatial loop sets inserted for each fanout node.

    Nodes are handled recursively, first to last. A node whose ``fanout_X``
    and ``fanout_Y`` are both at most 1 is skipped. Otherwise spatial
    ``RankVariableSet`` entries ("X" before "Y") are inserted one position
    before the first ``TensorSet`` holding a tensor stored at or below the
    node in ``storage_order``.

    The same ``mapping`` list is modified in place and yielded; the entries
    inserted for a node are removed again after the recursive results for the
    remaining nodes have been consumed.

    Yields:
        ``mapping`` itself, once per surviving placement. Nothing is yielded
        for a node whose "above" and "at/below" tensor sets interleave unless
        ``UNEVEN_CROSSES_SPATIAL_BOUNDARIES`` is true, in which case the
        mapping is yielded as it stands.
    """
    if not nodes:
        yield mapping
        return
    node, later_nodes = nodes[0], nodes[1:]

    fan_x = node.spatial.fanout_X > 1
    fan_y = node.spatial.fanout_Y > 1
    if not (fan_x or fan_y):
        yield from insert_spatial_loops(mapping, later_nodes)
        return

    # Insert fanout below all storage nodes above this one and above all storage
    # nodes at or below this one
    last_over, first_under = 0, len(mapping)
    for position, tensor_set in enumerate_tensor_sets(mapping):
        for tensor in tensor_set.tensors():
            if storage_order.index(tensor._storage_name) < storage_order.index(node.name):
                last_over = max(last_over, position)
            else:
                first_under = min(first_under, position)

    if last_over >= first_under:
        if UNEVEN_CROSSES_SPATIAL_BOUNDARIES:
            yield mapping
        return

    # first_under > last_over >= 0 here, so insert_point is never negative.
    insert_point = first_under - 1

    tensors_not_seen_yet = set(tensors)
    for entry in mapping[:insert_point]:
        if isinstance(entry, TensorSet):
            tensors_not_seen_yet -= set(entry.tensors())

    # Only rank variables shared by every not-yet-seen tensor are used.
    loops = set(rank_variables)
    for tensor in tensors_not_seen_yet:
        loops &= tensor2rank_variables[tensor]

    # Y is inserted first so that X, inserted at the same index, precedes it.
    dimensions = [dim for dim, enabled in (("Y", fan_y), ("X", fan_x)) if enabled]
    for dim in dimensions:
        mapping.insert(insert_point, RankVariableSet(loops, spatial=dim, spatial_node=node.name))
    yield from insert_spatial_loops(mapping, later_nodes)
    for _ in dimensions:
        mapping.pop(insert_point)

# If there are two back-to-back storages for the same tensor & the outer is
# optional, then it is invalid.
# Architecture nodes whose storage constraints are flagged as uneven.
uneven_storages = list(filter(lambda n: n.constraints.storage.uneven, arch_nodes))
# Materialize every storage choice up front.
storage_choice_options = [*make_storage_choices_all_levels(arch_nodes, symbol_table)]
import time

t0 = time.time()
mappings_count = n_mappings = 0
# The first flattened architecture node is treated as main memory.
main_memory = arch_nodes[0]

# TODO: Check for ranks not in the mapping and put them at the bottom

# NOTE: n_mappings has only just been initialized at this point, so this
# prints the starting value rather than a final tally.
print(f"Total mappings: {n_mappings}")

def make_mapping_nodes(mapping: list, symbol_table: dict[str, InvertibleSet]):
    """Expand a mapping of tensor/rank-variable sets into a flat node list.

    Args:
        mapping: Sequence of ``RankVariableSet`` and ``TensorSet`` entries;
            entries of any other type are ignored.
        symbol_table: Symbol table used to parse each architecture node's
            storage and spatial constraints.

    Returns:
        A list with one ``TemporalMappingNode`` per member of each non-spatial
        ``RankVariableSet``, one ``SpatialMappingNode`` per member of each
        spatial one, and one ``StorageMappingNode`` per member of each
        ``TensorSet``, in mapping order.
    """
    # NOTE(review): the three parsed constraint tables below are built but not
    # consumed yet; they are kept because parsing may validate the constraints.
    node2storage_constraint = dict(
        (n.name, n.constraints.storage._parse_non_keep_bypass(symbol_table))
        for n in arch_nodes
    )
    node2spatial_X_constraint = dict(
        (n.name, n.constraints.get_spatial_constraint(for_X=True)._parse(symbol_table))
        for n in arch_nodes
    )
    node2spatial_Y_constraint = dict(
        (n.name, n.constraints.get_spatial_constraint(for_Y=True)._parse(symbol_table))
        for n in arch_nodes
    )

    mapping_nodes = []
    for entry in mapping:
        if isinstance(entry, RankVariableSet):
            if entry.spatial:
                mapping_nodes.extend(
                    SpatialMappingNode(rank=r, dimension=entry.spatial) for r in entry
                )
            else:
                mapping_nodes.extend(TemporalMappingNode(rank=r) for r in entry)
        elif isinstance(entry, TensorSet):
            mapping_nodes.extend(
                StorageMappingNode(tensor_name=pair[0].name) for pair in entry
            )
    return mapping_nodes

from combinatorics.integer import integer_factorizations_to_n_parts
import numpy as np

def explore_tile_shapes(mapping_nodes: list[MappingNode], rank_variables: list[str]):
    """Assign loop factors from the factorizations of each rank variable's size.

    For every rank variable, the spatial and temporal loop nodes over that
    rank are collected and each factorization returned by
    ``integer_factorizations_to_n_parts(size, number_of_loops)`` is written to
    the nodes' ``factor`` attributes in turn.

    NOTE: nothing is returned or yielded, and every factorization overwrites
    the previous one, so after the call the nodes hold the last factorization
    produced for their rank variable.

    Args:
        mapping_nodes: Mapping nodes to annotate; only ``SpatialMappingNode``
            and ``TemporalMappingNode`` instances are touched.
        rank_variables: Rank variables to process; sizes are looked up in
            ``rank_variable_to_size``.

    Raises:
        ValueError: If a rank variable has no loop node and its size is not 1.
    """
    # TODO: the draft began an unfinished per-loop-node check here (a bare
    # `if` with no condition, which made the file unparseable). It is left
    # out until the intended condition is known.

    for rank_variable in rank_variables:
        loop_nodes = [
            n for n in mapping_nodes
            if (isinstance(n, (SpatialMappingNode, TemporalMappingNode)) and rank_variable == n.rank)
        ]
        size = rank_variable_to_size[rank_variable]
        if not loop_nodes:
            if size != 1:
                raise ValueError(f"Rank variable {rank_variable} has no loops and size is not 1")
            # No loop nodes means there is nothing to assign a factor to.
            continue

        # presumably each choice is a sequence of len(loop_nodes) integers
        # whose product is `size` — TODO confirm against the helper.
        for choice in integer_factorizations_to_n_parts(size, len(loop_nodes)):
            for i, node in enumerate(loop_nodes):
                node.factor = choice[i]


# Enumerate every storage choice, every valid ordering of the stored tensors,
# and every placement of temporal and spatial loops. The per-choice table
# gets its own name so the module-level `symbol_table` is not overwritten.
for i, (storage_choices, choice_symbol_table) in enumerate(
    make_storage_choices_all_levels(arch_nodes, symbol_table)
):
    print(f"{i}/{len(storage_choice_options)}: {storage_choices}")

    # Main memory's tensors form the fixed first entry of every mapping.
    main_memory_tensors = storage_choices[main_memory.name]
    mapping_base = [TensorSet((t, t._storage_name) for t in main_memory_tensors)]
    for t in main_memory_tensors:
        t._must_be_here = True

    # Every other stored tensor becomes its own single-item entry to order.
    flattened_storage_choices = []
    for storage_name, kept_tensors in storage_choices.items():
        if storage_name == main_memory.name:
            continue
        for t in kept_tensors:
            t._must_be_here = False
            flattened_storage_choices.append(TensorSet([(t, t._storage_name)]))

    for mapping in recursive_order_storage_choices(mapping_base, arch_nodes, flattened_storage_choices):
        # Work on a copy: the ordering generator reuses its list while
        # backtracking, and loop insertion edits the list in place.
        mapping = insert_temporal_loops(list(mapping))
        for mapping2 in insert_spatial_loops(mapping, arch_nodes):
            print(f"\t{mapping2}")
            explore_tile_shapes(make_mapping_nodes(mapping2, choice_symbol_table), rank_variables)
            n_mappings += 1

# Report the tally once enumeration has actually finished.
print(f"Total mappings: {n_mappings}")

# TODO: What if there are no loops?


2025-05-12 16:35:02 INFO        Loading yaml file architecture/four_level.arch.yaml
2025-05-12 16:35:02 INFO        Found top key variables in architecture/four_level.arch.yaml
2025-05-12 16:35:02 INFO        Found top key architecture in architecture/four_level.arch.yaml
2025-05-12 16:35:02 INFO        Found top key compound_components in architecture/four_level.arch.yaml
2025-05-12 16:35:02 INFO        Loading yaml file workloads/mha_full_new.yaml
2025-05-12 16:35:03 INFO        Found top key workload in workloads/mha_full_new.yaml
2025-05-12 16:35:03 INFO        Loading yaml file workloads/mha_full_new.renames.yaml
2025-05-12 16:35:03 INFO        Found top key renames in workloads/mha_full_new.renames.yaml
2025-05-12 16:35:03 INFO        Loading yaml file /root/.config/fastfusion/config.yaml
2025-05-12 16:35:03 INFO        Found top key version in /root/.config/fastfusion/config.yaml
2025-05-12 16:35:03 INFO        Found top key environment_variables in /root/.config/fastfusion/conf

Total mappings: 0
0/48: {'Register': {WK}, 'LocalBuffer': {K}, 'GlobalBuffer': {K}, 'MainMemory': {I, WK}}
	[{(I, 'MainMemory'), (WK, 'MainMemory')}, { }, {(K, 'GlobalBuffer')}, { b, d, e, h, mX}, { }, {(K, 'LocalBuffer')}, { b, d, e, h, mX}, { b, d, e, h, mY}, { d}, {(WK, 'Register')}, { b, m}]
