In [None]:
import time

from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.single_einsum_mapper import iterate_mappings_constraints

# # Example mapping node
# type: "temporal"
# rank: P
# # Choose one of the following cases
# # Case 1.a
# tile_shape: 3   # will make tiles of shape 3
# # Case 1.b
# tile_shape: null # will create a sympy symbol to represent the tile shape and use that
# # Case 2.a
# factor: 3       # will split the rank into 3 tiles, shaped as evenly as possible
# # Case 2.b
# factor: null    # will create a sympy symbol to represent the factor, then same as 2.a
# # Case 3   (only the null form is shown from here on)
# tile_pattern:
#   stride: null
#   initial_shape: null  # This will create tiles like this: [0, 1, ..., initial_shape - 1], [initial_shape, ..., initial_shape + stride - 1], [initial_shape + stride, ..., initial_shape + 2*stride - 1], ...
# # Case 4
# tile_pattern:
#   stride: null
#   shape: null      # This will create tiles like this: [0, 1, ..., shape-1], [stride, stride+1, ..., stride + shape-1], [2*stride, 2*stride + 1, ..., 2*stride + shape - 1], ...
#         choices = list(integer_factorizations_to_n_parts(rank_size, len(loops)))

# Tile shape constraint: applies to every tensor in a storage node for which that tile shape is relevant.
# Loop bound constraint: applies only to spatial loops.

# Load the architecture, workload, and renames into a single specification.
spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full_new.yaml",
    "workloads/mha_full_new.renames.yaml",
)

workload = spec.workload
renames = spec.renames

# Einsum to inspect and to explore. The same name drives both the lookup
# below and the mapping iteration, so the two cannot drift apart.
# NOTE(review): set to "Q" because that is the einsum the exploration loop
# actually ran; switch to "K" if that einsum was the intended target.
einsum_name = "Q"
einsum = workload.einsums[einsum_name]
rank_variables = einsum.rank_variables
tensors = einsum.tensors
# Every rank variable is assigned a size of 16 (not consumed by the loop below).
rank_variable_to_size = {r: 16 for r in rank_variables}

# NOTE: If there are two back-to-back storages for the same tensor and the
# outer one is optional, then the mapping is invalid.

t0 = time.time()
n_mappings = 0

# Optional profiling; uncomment together with the block after the loop.
# Requires `import cProfile, io, pstats`.
# pr = cProfile.Profile()
# pr.enable()

for i, (mapping, constraints) in enumerate(
    iterate_mappings_constraints(spec, einsum_name)
):
    print(f"{i}: {' '.join(c.compact_string() for c in mapping)}")
    n_mappings += 1

# Both counter names were initialized by this cell; keep them in agreement.
mappings_count = n_mappings
elapsed = time.time() - t0
print(f"Enumerated {n_mappings} mappings for einsum {einsum_name!r} in {elapsed:.2f} s")

# pr.disable()
# s = io.StringIO()
# ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
# ps.print_stats(30)  # Print top 30 time-consuming functions
# print(s.getvalue())

# TODO: Check for ranks not in the mapping and put them at the bottom.
# TODO: Handle the case where there are no loops.
# TODO: Set _must_exist for all backing storage nodes.
# TODO: Constraint attacher.
# TODO: Tile size constraints cannot be placed on backing memory; enforce this.
# Discussion: I'm working on the tile shape exploration now and want to make sure I understand this note. I think I understand what you're saying.
# Request for the constraint code: if a constraint is an equality, set the node's tile_shape attribute (or factor, or whatever is needed) directly to that value.
# Reason 1: the tile shape exploration assumes a particular mapspace (in most cases, tile shapes are factors of the full rank shape), so an equality may never be satisfied. For example, if a constraint sets the tile shape to a non-factor value because you want a particular imperfect factorization, that value is never in the mapspace and the exploration returns nothing.
# Reason 2: setting the value directly is also more efficient, because the explorer does not have to find the equality by trial and error. For more complicated constraints, trial and error is still the better approach.



2025-05-23 21:13:17 INFO        Loading yaml file architecture/four_level.arch.yaml
2025-05-23 21:13:17 INFO        Found top key variables in architecture/four_level.arch.yaml
2025-05-23 21:13:17 INFO        Found top key architecture in architecture/four_level.arch.yaml
2025-05-23 21:13:17 INFO        Found top key component_classes in architecture/four_level.arch.yaml
2025-05-23 21:13:17 INFO        Loading yaml file workloads/mha_full_new.yaml
2025-05-23 21:13:17 INFO        Found top key workload in workloads/mha_full_new.yaml
2025-05-23 21:13:17 INFO        Loading yaml file workloads/mha_full_new.renames.yaml
2025-05-23 21:13:17 INFO        Found top key renames in workloads/mha_full_new.renames.yaml
2025-05-23 21:13:17 INFO        Calculated "(450 / 64) / 8" = 0.87890625.
2025-05-23 21:13:17 INFO        Calculated "8 * width" = 65536.
2025-05-23 21:13:17 INFO        Calculated "1024*1024*128*8" = 1073741824.
2025-05-23 21:13:17 INFO        Calculated "1024*1024*32*8" = 26843545

0: [MainMemory WQ] [MainMemory I] e-None h-None [Register WQ] b-None e-None h-None m-None [GlobalBuffer Q] SX-b-None SX-d-None SX-e-None SX-h-None SX-m-None b-None e-None h-None m-None [LocalBuffer Q] SX-b-None SX-d-None SX-e-None SX-h-None SX-m-None SY-b-None SY-d-None SY-e-None SY-h-None SY-m-None b-None e-None h-None m-None
<fastfusion.mapper.FFM.exploration.single_einsum_mapper.LoopBoundsConstraintLambda object at 0x7fcf5b1a28c0>
<fastfusion.mapper.FFM.exploration.single_einsum_mapper.LoopBoundsConstraintLambda object at 0x7fcf5b1a3100>
<fastfusion.mapper.FFM.exploration.single_einsum_mapper.LoopBoundsConstraintLambda object at 0x7fcf5b1a2b30>
<fastfusion.mapper.FFM.exploration.single_einsum_mapper.LoopBoundsConstraintLambda object at 0x7fcf5b1a2800>
1: [MainMemory WQ] [MainMemory I] b-None e-None h-None m-None [GlobalBuffer Q] SX-b-None SX-d-None SX-e-None SX-h-None SX-m-None e-None h-None [Register WQ] e-None h-None [LocalBuffer Q] SX-b-None SX-d-None SX-e-None SX-h-None SX-m-Non