In [None]:
import fastfusion as ff
from math import log10

# Ask fastfusion to use 32 parallel jobs.
ff.set_n_parallel_jobs(32)

# Architecture and workload descriptions, plus the values handed to the loader
# as `jinja_parse_data`.
arch = "../../examples/arches/tpu_v4i_like.arch.yaml"
workload = "../../examples/workloads/gpt3_6.7B.workload.yaml"
jinja_parse_data = {"BATCH_SIZE": 64, "N_TOKENS": 65536}

spec = ff.Specification.from_yaml(arch, workload, jinja_parse_data=jinja_parse_data)

# Set the minimum utilization of both spatial reuse constraints on ArrayDummy to 0.
array_spatial = spec.arch["ArrayDummy"].constraints.spatial
for reuse_kind in ("reuse_input", "reuse_output"):
    array_spatial[reuse_kind].min_utilization = 0

# Tensor-placement constraints per memory level.
spec.arch["MainMemory"].constraints.tensors.keep = "All"
global_buffer_tensors = spec.arch["GlobalBuffer"].constraints.tensors
global_buffer_tensors.keep = "output | input | ~MainMemory"
global_buffer_tensors.may_keep = "weight | ~MainMemory"

# The LocalBuffer constraint is only applied when the architecture path names a tpu_v4i variant.
if "tpu_v4i" in arch:
    # Disabled experiment, kept for reference:
    # spec.arch["LocalBuffer"].constraints.spatial.append(ff.constraints.Spatial(name="Z", min_utilization=1))
    spec.arch["LocalBuffer"].constraints.tensors.keep = "input | output"

# Have the mapper track both energy and latency.
spec.mapper.ffm.metrics = ff.Metrics.ENERGY | ff.Metrics.LATENCY

print(spec.workload.shape)

def run_mapper(spec, count_option):
    """Generate pmappings for the "QK" Einsum under a given count option.

    Args:
        spec: Specification object. Its ``mapper.ffm`` settings are modified
            in place before the mapper runs.
        count_option: Value stored on ``spec.mapper.ffm`` as
            ``_count_option_for_mapsapce_size_evaluation`` (e.g. a tuple of
            option-name strings).

    Returns:
        The result of ``ff.mapper.FFM.make_pmappings`` for ``spec``.
    """
    # NOTE(review): "mapsapce" looks like a misspelling of "mapspace", but it is
    # the attribute name this notebook writes to; confirm against fastfusion
    # before renaming.
    ffm_settings = spec.mapper.ffm
    ffm_settings._count_option_for_mapsapce_size_evaluation = count_option

    pmappings = ff.mapper.FFM.make_pmappings(
        spec,
        cache_dir="/tmp/ff_cache",
        einsum_names=("QK",),
    )
    return pmappings


# Run the mapper five times, passing one fewer count option on each run. Each
# variable name records which option was dropped relative to the previous run.
pmappings_total = run_mapper(
    spec,
    (
        "redundant_dataplacements",
        "non_helpful_loops_for_loop_orders",
        "non_helpful_tile_shapes",
        "redundant_loop_orders",
    ),
)
pmappings_no_redundant_dataplacements = run_mapper(
    spec,
    (
        "non_helpful_loops_for_loop_orders",
        "non_helpful_tile_shapes",
        "redundant_loop_orders",
    ),
)
pmappings_no_non_helpful_tile_shapes = run_mapper(
    spec,
    (
        "non_helpful_loops_for_loop_orders",
        "redundant_loop_orders",
    ),
)
# The trailing comma is required: ("redundant_loop_orders") is a plain string,
# not a one-element tuple like the other calls pass.
pmappings_no_non_helpful_loops_for_loop_orders = run_mapper(
    spec,
    ("redundant_loop_orders",),
)
pmappings_evaluated = run_mapper(spec, ())



{'b': '0 <= b < 64', 'm': '0 <= m < 65536', 'p': '0 <= p < 65536', 'h': '0 <= h < 32', 'e': '0 <= e < 128', 'f': '0 <= f < 128', 'd': '0 <= d < 4096', 'c': '0 <= c < 16384', 'j': '0 <= j < 4096', 'g': '0 <= g < 4096'}


Generating pmapping templates for compute ScalarUnit Einsum QK: 0it [00:00, ?it/s]
Generating pmapping templates for compute MAC Einsum QK: 16it [00:00, 56.40it/s]
Generating jobs: 100%|██████████| 2/2 [00:01<00:00,  1.75it/s]
Generating pmappings:  56%|█████▋    | 9/16 [00:01<00:01,  6.62it/s]

KeyboardInterrupt: 

Generating pmappings:  62%|██████▎   | 10/16 [00:16<00:00,  6.62it/s]

In [5]:
# NOTE(review): hard-coded count; presumably the number of pmapping templates
# generated for QK. Confirm, or derive it from the mapper results.
n_dataplacements = 16
total = pmappings_total.n_total_pmappings()

# Ratio of the total count with loop-order options still included to the count
# with no options. This must reference the variable defined by the mapper runs
# (`pmappings_no_non_helpful_tile_shapes`); the bare name
# `no_non_helpful_tile_shapes` is undefined on a fresh kernel.
n_dataflows = (
    pmappings_no_non_helpful_tile_shapes.n_total_pmappings()
    / pmappings_evaluated.n_total_pmappings()
)
n_tile_shapes = total / n_dataflows


print(f"Number of dataflows: {n_dataflows}")
print(f"Number of tile shapes: {n_tile_shapes}")
print(f"Number of dataplacements: {n_dataplacements}")
print(f"Number of total mappings: {total}")
print(f"Tile shapes per dataplacement evaluated: {pmappings_evaluated.n_evaluated_pmappings() / n_dataplacements}")

Number of dataflows: 305979208437901.6
Number of tile shapes: 1.4774533907076272e+22
Number of dataplacements: 16
Number of total mappings: 4.5207001899261357e+36
Tile shapes per dataplacement evaluated: 113010.5
