In [None]:
import os
import pickle
from fastfusion import Specification
from fastfusion.mapper.FFM.exploration.mapper_multi_einsum import get_sims
from fastfusion.mapper.FFM.joining.simexplore import compress, decompress, join_sims

spec = Specification.from_yaml(
    "architecture/four_level.arch.yaml",
    "workloads/mha_full.workload.yaml",
    "workloads/mha_full.renames.yaml",
)

workload = spec.workload
renames = spec.renames

rank_variable_to_size = {r: 8 for r in "abcdefghijklmnopqrstuvwxyz"}

# import cProfile, io, pstats  # required if the profiling lines below are uncommented
# pr = cProfile.Profile()
# pr.enable()

# sims = get_single_einsum_sims(spec, "Q", rank_variable_to_size)

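def cache(filename):
    """Memoize a function's result in a pickle file.

    The cache key is only `filename`; call arguments are ignored, so delete
    the file whenever the spec or workload changes.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            # Reuse the pickled result if a previous run already saved it.
            if os.path.exists(filename):
                with open(filename, "rb") as f:
                    return pickle.load(f)
            result = func(*args, **kwargs)
            # Context managers close the handles so the pickle is flushed to disk.
            with open(filename, "wb") as f:
                pickle.dump(result, f)
            return result
        return wrapper
    return decorator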

# Uncomment the decorator to reuse the SIMs pickled in sims.pkl on later runs.
# @cache("sims.pkl")
def get_sims_with_cache():
    spec.estimate_energy_area()
    flattened_architecture = spec.get_flattened_architecture()
    sims = get_sims(spec, rank_variable_to_size, flattened_architecture)  # , pkl_cache="sims.pkl")
    recovery_map = compress(sims)
    return sims, recovery_map, flattened_architecture

sims, recovery_map, flattened_architecture = get_sims_with_cache()
data = join_sims(sims, spec, flattened_architecture)
decompress(recovery_map, sims, spec.workload.einsum_names)

# pr.disable()
# s = io.StringIO()
# ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
# ps.print_stats(30)  # Print top 30 time-consuming functions
# print(s.getvalue())

# TODO: Check for ranks not in the mapping and put them at the bottom
# TODO: What if there are no loops? 
# TODO: Set _must_exist for all backing storage nodes
# TODO: Constraint attacher
# TODO: Can't have tile size constraints on backing memory
# TODO: Einsum orders
# TODO: Copy Einsums
# TODO: Test dataflow constraints and order of storage nodes
# I'm doing the tile shape exploration now and trying to understand this note; I think I understand what you're saying.
# One request for the constraint code: if a constraint is an equality, just set the tile_shape attribute of the node (or the factor, or whatever is needed) to that value.
# The tile shape exploration assumes a particular mapspace (in most cases, tile shapes are factors of the full rank shape), so an equality may never be satisfied. E.g., if the constraint sets the tile shape to a non-factor value because you want a particular imperfect factorization, that value is never in the mapspace and you'll get no mappings.
# Setting the value directly is also more efficient, because the explorer doesn't have to find the equality by trial and error. For more complicated constraints, trial and error is better.
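The note above argues that an equality constraint on a tile shape should be applied by setting the value directly rather than by filtering the explored mapspace. The sketch below illustrates why, assuming (as the note says) that the candidate tile shapes are the factors of the full rank shape. `factor_tile_shapes` and `explore_tile_shapes` are hypothetical helpers written for this illustration, not fastfusion APIs.

```python
from typing import Callable, List, Optional


def factor_tile_shapes(rank_shape: int) -> List[int]:
    """Hypothetical mapspace: every tile shape that evenly divides the rank shape."""
    return [t for t in range(1, rank_shape + 1) if rank_shape % t == 0]


def explore_tile_shapes(
    rank_shape: int,
    constraint: Optional[Callable[[int], bool]] = None,
    fixed_tile_shape: Optional[int] = None,
) -> List[int]:
    """Return the tile shapes that survive exploration (hypothetical helper)."""
    if fixed_tile_shape is not None:
        # Equality applied directly: no search, and the value is kept even
        # when it does not divide the rank shape (imperfect factorization).
        return [fixed_tile_shape]
    candidates = factor_tile_shapes(rank_shape)
    if constraint is None:
        return candidates
    # Trial and error: test every candidate against the constraint.
    return [t for t in candidates if constraint(t)]


# Rank of size 12; we want a tile shape of 5, which does not divide 12.
print(explore_tile_shapes(12, constraint=lambda t: t == 5))  # []
print(explore_tile_shapes(12, fixed_tile_shape=5))           # [5]
# A non-equality constraint is still handled well by filtering.
print(explore_tile_shapes(12, constraint=lambda t: t <= 4))  # [1, 2, 3, 4]
```

Filtering the factor-based candidates can never produce 5, while pinning the value does so without any search. Inequalities and other predicates still fit the trial-and-error path.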

INFO        Loading yaml file architecture/four_level.arch.yaml
INFO        Found top key variables in architecture/four_level.arch.yaml
INFO        Found top key architecture in architecture/four_level.arch.yaml
INFO        Found top key component_classes in architecture/four_level.arch.yaml
INFO        Loading yaml file workloads/mha_full_new.yaml
INFO        Found top key workload in workloads/mha_full_new.yaml
INFO        Loading yaml file workloads/mha_full_new.renames.yaml
INFO        Found top key renames in workloads/mha_full_new.renames.yaml
INFO        Calculated "1024*1024*128*8" = 1073741824.
INFO        Calculated "1024*1024*32*8" = 268435456.
INFO        Calculated "0.5" = 0.5.
Generating SIMs:   0%|          | 0/132 [00:00<?, ?it/s]

[MainMemory WK], b-None, m-None, [GlobalBuffer K], d-None, [GlobalBuffer I], SX-b-None, SX-m-None, e-None, h-None, [LocalBuffer K,I], SY-d-None, SX-e-None, SX-h-None, d-None, e-None, h-None, [Register WK], Einsum K
Returning choices of shape (619164, 13)
Returning df: (619164, 20)
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.97x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.03x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.56x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.25x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1

Generating SIMs:   1%|          | 1/132 [00:01<03:52,  1.78s/it]

Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.93x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.46x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.33x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.42x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.22x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.18x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quic

Generating SIMs:   2%|▏         | 3/132 [00:03<01:51,  1.16it/s]

Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.57x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.62x faster
Time to make pareto with merge:  0.01. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.33x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.38x faster
[MainMemory WK], b-None, [GlobalBuffer K], d-None, [GlobalBuffer I], SX-b-None, SX-m-None, e-None, h-None, [LocalBuffer K,I], SY-d-None, SX-e-None, SX-h-None, d-None, e-None, h-None, [Register WK], Einsum K
Returning choices of shape (68796, 12)
Returning df: (68796, 19)
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Num

Generating SIMs:   3%|▎         | 4/132 [00:03<01:14,  1.72it/s]

Returning choices of shape (68796, 11)
Returning df: (68796, 18)
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.90x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.01x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.91x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.29x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  0.93x faster
Time to make pareto with merge:  0.00. Number of pareto points: 1
Time to make pareto with quick:  0.00. Number of pareto points: 1
Quick is  1.23x faster
Time 
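The log repeatedly compares two Pareto-front routines, "merge" and "quick", whose implementations are not shown here. As a hedged reference for what such a step computes, the sketch below is a plain O(n²) filter that keeps the non-dominated rows when every objective is minimized. It is not fastfusion's code, and the objective names in it are illustrative. Assuming the `Quick is ...x faster` line reports merge time divided by quick time, values below 1 (for example 0.56x) mean the quick routine was actually slower on that call.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def dominates(a: Sequence[float], b: Sequence[float]) -> bool:
    """True if a is no worse than b in every objective and better in at least one (minimization)."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


def pareto_front(points: List[Point]) -> List[Point]:
    """Keep the points that no other point dominates (O(n^2) reference version).

    Exact duplicates do not dominate each other, so both copies are kept.
    """
    front = []
    for i, p in enumerate(points):
        if not any(dominates(q, p) for j, q in enumerate(points) if j != i):
            front.append(p)
    return front


# Illustrative objectives per mapping: (energy, latency); lower is better for both.
mappings = [(4.0, 1.0), (2.0, 3.0), (3.0, 3.0), (1.0, 5.0)]
print(pareto_front(mappings))  # [(4.0, 1.0), (2.0, 3.0), (1.0, 5.0)]
```

Here (3.0, 3.0) is dropped because (2.0, 3.0) is no worse in latency and strictly better in energy.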