In [1]:
from fastfusion.frontend.specification import Specification
from fastfusion import Specification
from fastfusion.mapper.FFM import make_pmappings, join_pmappings
import fastfusion as ff
import time
from joblib import Memory
import pandas as pd

# Set logging to warning
import logging
logging.getLogger().setLevel(logging.WARNING)

CACHE_DIR = ".cache"
memory = Memory(CACHE_DIR, compress=True)

CACHE = True

def cache(f):
    """Decorator: memoize ``f`` on disk through joblib when ``CACHE`` is on.

    ``CACHE`` is consulted once, at decoration time; flipping it later does
    not change functions that have already been decorated. When caching is
    off, ``f`` is returned unchanged.
    """
    return memory.cache(f) if CACHE else f

In [2]:
# ff.set_n_parallel_jobs(32)

# spec = Specification.from_yaml(
#     "../../examples/arches/tpu_v4i_like.arch.yaml",
#     "../../examples/workloads/gpt3_6.7B.workload.yaml",
#     jinja_parse_data=dict(BATCH_SIZE=64, N_TOKENS=65536)
# )
# einsum_name = "QK"

# spec.arch.nodes["LocalBuffer"].constraints.tensors.keep = "input | output"
# spec.arch.nodes["GlobalBuffer"].constraints.tensors.keep = "input | output"
# spec.arch.nodes["MainMemory"].constraints.tensors.keep = "All"
# spec.mapper.ffm.metrics = ff.Metrics.ENERGY | ff.Metrics.LATENCY

# t0 = time.perf_counter()
# pmappings = make_pmappings(spec, einsum_names=[einsum_name])
# t1 = time.perf_counter()
# print(f"Time taken: {t1 - t0} seconds")
# data = pd.concat(x.mappings.data for x in pmappings.einsum2pmappings[einsum_name])
# edps = data["Total<SEP>energy"] * data["Total<SEP>latency"]
# best = data.iloc[edps.argmin()]
# joined = join_pmappings(spec, pmappings, require_all_einsums=True)
# joined[0].render()

In [3]:
# Workload names that run_ffm knows how to set up.
WORKLOADS = ["mha_qk", "gemm"]

# Parallelism granted to the multi-threaded baselines
# (Timeloop mapper threads, ZigZag/LOMA mapping jobs).
N_CPUS = 16

@cache
def run_ffm(workload_name):
    """Map one Einsum with FFM and return the metrics of its lowest-EDP pmapping.

    Parameters
    ----------
    workload_name : str
        ``"mha_qk"`` (the ``QK`` Einsum of the GPT-3 6.7B workload, rendered
        with BATCH_SIZE=64, N_TOKENS=65536) or ``"gemm"`` (``Matmul0`` of the
        matmuls workload with M=KN=16384 and a single Einsum).

    Returns
    -------
    tuple
        ``(energy, latency, seconds)``: ``Total<SEP>energy`` and
        ``Total<SEP>latency`` of the pmapping that minimizes their product,
        plus the wall-clock time ``make_pmappings`` took.

    Raises
    ------
    ValueError
        If ``workload_name`` is not one of the names above.
    """
    # Run FFM with a single parallel job (presumably so its runtime is a
    # single-thread measurement the baselines can be scaled against).
    ff.set_n_parallel_jobs(1)

    # workload name -> (workload YAML, Jinja parameters, Einsum to map)
    workload_setups = {
        "mha_qk": (
            "../../examples/workloads/gpt3_6.7B.workload.yaml",
            dict(BATCH_SIZE=64, N_TOKENS=65536),
            "QK",
        ),
        "gemm": (
            "../../examples/workloads/matmuls.workload.yaml",
            dict(M=16384, KN=16384, N_EINSUMS=1),
            "Matmul0",
        ),
    }
    setup = None
    for known_name, candidate in workload_setups.items():
        if workload_name == known_name:
            setup = candidate
            break
    if setup is None:
        raise ValueError(f"Unknown workload: {workload_name}")

    workload_yaml, jinja_data, einsum_name = setup
    spec = Specification.from_yaml(
        "../../examples/arches/tpu_v4i_like.arch.yaml",
        workload_yaml,
        jinja_parse_data=jinja_data,
    )

    # Loosen the array's spatial-utilization floors, then pin which tensors
    # each memory level may keep.
    for reuse_kind in ("reuse_input", "reuse_output"):
        spec.arch["ArrayDummy"].constraints.spatial[reuse_kind].min_utilization = 0
    spec.arch.nodes["LocalBuffer"].constraints.tensors.keep = "input | output"
    spec.arch.nodes["LocalBuffer"].constraints.spatial["Z"].reuse = "All"
    spec.arch.nodes["GlobalBuffer"].constraints.tensors.keep = "input | output"
    spec.arch.nodes["MainMemory"].constraints.tensors.keep = "All"
    spec.mapper.ffm.metrics = ff.Metrics.ENERGY | ff.Metrics.LATENCY

    started = time.perf_counter()
    pmappings = make_pmappings(spec, einsum_names=[einsum_name])
    elapsed = time.perf_counter() - started
    print(f"Time taken: {elapsed} seconds")

    # Pool every pmapping found for the Einsum and pick the one with the smallest EDP.
    mappings_df = pd.concat(
        pm.mappings.data for pm in pmappings.einsum2pmappings[einsum_name]
    )
    edp = mappings_df["Total<SEP>energy"] * mappings_df["Total<SEP>latency"]
    best_row = mappings_df.iloc[edp.argmin()]
    return best_row["Total<SEP>energy"], best_row["Total<SEP>latency"], elapsed

In [4]:
import pytimeloop.timeloopfe.v4 as tl
import pytimeloop.timeloopfe.common.backend_calls as tl_backend_calls
from paths import DATA_DIR, TIMELOOP_CONFIG_DIR, TIMELOOP_WORKLOAD_DIR
import shutil

@cache
def run_timeloop(workload_name, time_limit, use_hint=False):
    """Run the Timeloop mapper on ``workload_name`` for a fixed time budget.

    The mapper is launched in the background, allowed to search for
    ``time_limit`` seconds, then stopped. Its output directory is parsed for
    the resulting energy and latency.

    Parameters
    ----------
    workload_name : str
        Stem of the workload file ``TIMELOOP_WORKLOAD_DIR / f"{workload_name}.yaml"``.
    time_limit : float
        Search budget in seconds. If dividing it by ``N_CPUS`` still leaves
        more than 3 s, it is divided by ``N_CPUS`` and the mapper runs with
        ``N_CPUS`` threads.
    use_hint : bool
        Use the ``tpu_like_hint.yaml`` config and a separate ``_hint`` output
        directory.

    Returns
    -------
    tuple
        ``(energy, latency, time_limit)``, where ``time_limit`` is the
        (possibly divided) wall-clock budget the mapper actually ran for.
    """
    suffix = "_hint" if use_hint else ""
    out_dir = "outputs/timeloop" + suffix
    # Start from an empty output directory so nothing from a previous run is parsed.
    shutil.rmtree(out_dir, ignore_errors=True)
    spec = tl.Specification.from_yaml_files(
        TIMELOOP_CONFIG_DIR / ("tpu_like" + suffix + ".yaml"),
        TIMELOOP_WORKLOAD_DIR / f"{workload_name}.yaml",
        TIMELOOP_CONFIG_DIR / "tpu_like.ert.yaml",
    )
    # Very large so that, presumably, the timer below rather than this limit
    # ends the search. TODO confirm against the Timeloop mapper docs.
    spec.mapper.evaluated_size = 1000000000

    # Spread the budget over N_CPUS threads, but only if each thread would
    # still get more than 3 seconds.
    if time_limit / N_CPUS > 3:
        time_limit /= N_CPUS
        spec.mapper.num_threads = N_CPUS

    print(f'Running timeloop with {spec.mapper.num_threads} threads for {time_limit} seconds')
    proc = tl.call_mapper(spec, output_dir=out_dir, return_proc=True, log_to=f"{out_dir}/mapper.log")
    try:
        time.sleep(time_limit)
    finally:
        # Stop the mapper even if the wait is interrupted (e.g. a kernel
        # interrupt), so it does not keep running in the background.
        tl.call_stop(proc)
    proc.wait()
    # The mapper has been observed to exit nonzero after succeeding, so its
    # exit status is ignored and success is reported to the parser.
    proc_result = 0
    # Placeholder area/energy reference tables for the parser.
    # NOTE(review): presumably _parse_output expects these files to exist; confirm.
    with open(f"{out_dir}/timeloop-mapper.ART.yaml", "w") as f:
        f.write("ART: {version: 0.4, tables: [{name: x.x, area: 1}]}")
    with open(f"{out_dir}/timeloop-mapper.ERT.yaml", "w") as f:
        f.write("ERT: {version: 0.4, tables: []}")
    result = tl_backend_calls._parse_output(spec, out_dir, proc_result, for_model=False)
    return (
        result.energy,
        result.latency,
        time_limit
    )


In [5]:
from paths import (
    DATA_DIR,
    ZIGZAG_MAPPING_DIR,
    ZIGZAG_ARCHITECTURE_DIR,
    ZIGZAG_WORKLOAD_DIR,
)
import time
import zigzag
from func_timeout import func_timeout, FunctionTimedOut
import zigzag
from zigzag import api
import uuid

@cache
def get_loma_multiproc_speed(**kwargs):
    """Measure how much faster LOMA runs with N_CPUS jobs than with one.

    Both timing runs use ``lpf_limit=7`` and a fresh ``cache_spoiler`` so
    that ``run_loma_single``'s cache cannot return an old timing. ``kwargs``
    (e.g. ``workload_name``, ``use_hint``) are forwarded to each run.

    Returns the ratio serial_seconds / parallel_seconds. When the body
    actually executes (not on a cache hit), it leaves the generator stage's
    ``N_JOBS`` set to ``N_CPUS``.
    """
    generator_stage = zigzag.stages.mapping.temporal_mapping_generator_stage

    def seconds_with(n_jobs):
        # Set the job count, then time one uncached LOMA run.
        generator_stage.N_JOBS = n_jobs
        _, _, seconds = run_loma_single(
            lpf_limit=7,
            cache_spoiler=uuid.uuid4(),
            **kwargs
        )
        return seconds

    serial_seconds = seconds_with(1)
    parallel_seconds = seconds_with(N_CPUS)
    return serial_seconds / parallel_seconds


@cache
def run_loma_single(workload_name, lpf_limit, use_hint=False, cache_spoiler=None):
    """Run one ZigZag (LOMA) EDP-optimizing search and time it.

    Parameters
    ----------
    workload_name : str
        Selects ``<workload_name>.yaml`` from ZIGZAG_WORKLOAD_DIR and
        ``tpu_<workload_name>.yaml`` from ZIGZAG_MAPPING_DIR.
    lpf_limit : int
        Forwarded to ZigZag as ``lpf_limit`` (presumably LOMA's
        loop-prime-factor limit; larger values search more).
    use_hint : bool
        Use the ``tpu_like_bypass_hint.yaml`` accelerator instead of
        ``tpu_like.yaml``.
    cache_spoiler : object
        Not used by the body; it only alters the ``@cache`` key, so passing a
        fresh value forces a real, re-timed run.

    Returns
    -------
    tuple
        ``(energy / 1e12, latency / 1.05e9, seconds)``.
        NOTE(review): the divisors presumably convert pJ to J and cycles to
        seconds at 1.05 GHz; confirm against the accelerator description.
    """
    # perf_counter is monotonic and meant for measuring elapsed time, unlike
    # time.time(), which follows the (adjustable) system clock.
    start = time.perf_counter()
    energy, latency, _ = api.get_hardware_performance_zigzag(
        workload=str(ZIGZAG_WORKLOAD_DIR / f"{workload_name}.yaml"),
        accelerator=str(ZIGZAG_ARCHITECTURE_DIR / ("tpu_like" + ("_bypass_hint" if use_hint else "") + ".yaml")),
        mapping=str(ZIGZAG_MAPPING_DIR / f"tpu_{workload_name}.yaml"),
        opt="EDP",
        lpf_limit=lpf_limit,
        enable_mix_spatial_mapping=True,
    )
    elapsed = time.perf_counter() - start
    return energy / 1e12, latency / 1.05e9, elapsed


@cache
def run_loma(workload_name, time_limit):
    """Give LOMA ever larger ``lpf_limit`` values until it exceeds ``time_limit``.

    Starting at ``lpf_limit=1``, each step runs ``run_loma_single`` under a
    ``func_timeout``. From ``lpf_limit=7`` on, LOMA uses N_CPUS jobs, and the
    wall-clock budget is divided by the measured multiprocessing speedup. The
    measured duration is multiplied back by that speedup, so durations stay in
    single-job terms. The search stops at the first run that times out or
    whose scaled duration exceeds ``time_limit``.

    Returns
    -------
    tuple
        ``(energy, latency, duration)`` of the last run that finished in time.

    Raises
    ------
    ValueError
        If not even ``lpf_limit=1`` finished in time.
    """
    generator_stage = zigzag.stages.mapping.temporal_mapping_generator_stage
    best = None
    lpf = 1
    while True:
        if lpf < 7:
            generator_stage.N_JOBS = 1
            speedup = 1
        else:
            generator_stage.N_JOBS = N_CPUS
            speedup = get_loma_multiproc_speed(
                workload_name=workload_name,
                use_hint=False,
            )
            print(f'LOMA multiproc speed is {speedup}')

        try:
            budget = time_limit / speedup
            print(f'Time limit is {budget}. Running LOMA:')
            run_energy, run_latency, run_seconds = func_timeout(
                budget,
                run_loma_single,
                kwargs=dict(workload_name=workload_name, lpf_limit=lpf),
            )
        except FunctionTimedOut:
            break

        # Convert the parallel wall-clock time back to single-job terms.
        run_seconds *= speedup
        if run_seconds > time_limit:
            break
        best = (run_energy, run_latency, run_seconds)
        lpf += 1

    if best is None:
        raise ValueError("Loma timed out")
    return best



In [6]:
runtime_limits = [1, 10, 100, 1000]
results_edp = {runtime_limit: {} for runtime_limit in runtime_limits}
results_latency = {runtime_limit: {} for runtime_limit in runtime_limits}
results_energy = {runtime_limit: {} for runtime_limit in runtime_limits}
workload_name = "mha_qk"

# FFM is the reference: its best EDP normalizes the results, and its runtime
# sets the budget the other mappers get.
ffm_energy, ffm_latency, ffm_duration = run_ffm(workload_name)
ffm_edp = ffm_energy * ffm_latency

def fmt(x, y):
    """Format ``x`` followed by its ratio to ``y``, e.g. ``"6.0 (2.00x)"``."""
    return f"{x} ({x / y:.2f}x)"

for runtime_limit in runtime_limits:
    # Each baseline gets `runtime_limit` times FFM's own runtime.
    runtime_limit_seconds = ffm_duration * runtime_limit

    # Timeloop, without and with the hint configuration
    for use_hint in [False, True]:
        label = "Timeloop" + (" + Hint" if use_hint else "")
        energy, latency, duration = run_timeloop(workload_name, time_limit=runtime_limit_seconds, use_hint=use_hint)
        edp = energy * latency
        print(f"With hint: {use_hint}, Timeloop EDP is {fmt(edp, ffm_edp)}. Latency is {fmt(latency, ffm_latency)}. Energy is {fmt(energy, ffm_energy)}.")
        results_edp[runtime_limit][label] = edp
        results_latency[runtime_limit][label] = latency
        results_energy[runtime_limit][label] = energy

    # ZigZag (LOMA)
    zigzag_energy, zigzag_latency, runtime = run_loma(workload_name, runtime_limit_seconds)
    zigzag_edp = zigzag_energy * zigzag_latency
    results_edp[runtime_limit]["ZigZag"] = zigzag_edp
    results_latency[runtime_limit]["ZigZag"] = zigzag_latency
    results_energy[runtime_limit]["ZigZag"] = zigzag_energy
    print(f"ZigZag EDP is {fmt(zigzag_edp, ffm_edp)}. Latency is {fmt(zigzag_latency, ffm_latency)}. Energy is {fmt(zigzag_energy, ffm_energy)}.")

________________________________________________________________________________
[Memory] Calling __main__--tmp-ipykernel-1267128337.run_timeloop...
run_timeloop('mha_qk', time_limit=36.448038906, use_hint=False)
Running timeloop with 1 threads for 36.448038906 seconds
____________________________________________________run_timeloop - 38.1s, 0.6min
With hint: False, Timeloop EDP is 37337707.59529264 (2399.08x). Latency is 4398.046511104 (268.80x). Energy is 8489.61180856637 (8.93x).
________________________________________________________________________________
[Memory] Calling __main__--tmp-ipykernel-1267128337.run_timeloop...
run_timeloop('mha_qk', time_limit=36.448038906, use_hint=True)
Running timeloop with 1 threads for 36.448038906 seconds
____________________________________________________run_timeloop - 37.4s, 0.6min
With hint: True, Timeloop EDP is 47459.82306030916 (3.05x). Latency is 30.360914905 (1.86x). Energy is 1563.1881716612306 (1.64x).
ZigZag EDP is 19941.06434067751

In [9]:
def get_latex_table(results):
    def floatfmt(x):
        if isinstance(x, str):
            return x
        if x > 100:
            return round(x)
        if x == 1:
            return int(x)
        return f"{x:.2f}"

    df = pd.DataFrame(results)

    for col in df.columns:
        df[col] = df[col].apply(floatfmt)

    cols = list(df.columns)
    runtime_labels = [f"$10^{{{x}}}$" for x in df.index]

    # Transpose it
    df = df.T
    df.columns = runtime_labels
    df[" "] = cols # Put this at the beginning
    df = df.reindex(columns=[" ", *df.columns[:-1]])

    # Add a row above the runtimes (all but the last columns, centered) that says

    latex = df.to_latex(
        index=5,
        float_format=floatfmt,
        column_format="l" + "c" * (len(df.columns) - 1)
    )

    # Insert "Runtime" row above all runtime columns (centered)
    # Identify the header row from pandas and inject our own.
    lines = latex.splitlines()
    for i, line in enumerate(lines):
        if line.startswith(' '):   # first header line
            # Extract runtime columns (skip the first " ")
            runtimes = df.columns[1:]
            num_runtime_cols = len(runtimes)

            # Center "Runtime" across all runtime columns
            # Use multicolumn spanning them
            runtime_header = (
                " & " +
                f"\\multicolumn{{{num_runtime_cols}}}{{c}}{{Best-Found EDP In X Runtime}} \\\\"
            )
            # Replace header line with: empty cell + our multicolumn
            lines.insert(i, "  " + runtime_header)
            break

    for i, line in enumerate(lines):
        if line.startswith("  & $10^{0}$"):
            lines[i] = line.replace("  & $10^{0}$", "Runtime & $10^{0}$")
            break

    for i, line in enumerate(lines):
        if line.startswith("Turbo-Charged"):
            n_replacements = len(runtime_labels) - 1
            to_replace = " & 1" * n_replacements + " \\\\"
            replace_with = f" & \\multicolumn{{{n_replacements}}}{{c}}{{Completed}} \\\\"
            lines[i] = lines[i].replace(to_replace, replace_with)
            break

    with open("outputs/results.tex", "w") as f:
        f.write("\n".join(lines))

import copy
import math

# get_latex_table expects {method: {k: value}}: each outer key becomes a row
# and each inner key k a "$10^{k}$" runtime column. results_edp is keyed the
# other way round ({runtime_limit: {method: edp}}) with power-of-ten limits,
# so transpose it, replace each limit by its base-10 exponent, and normalize
# every EDP to FFM's.
results = {}
for runtime_limit, edp_by_method in results_edp.items():
    exponent = round(math.log10(runtime_limit))
    assert 10 ** exponent == runtime_limit, f"runtime limit {runtime_limit} is not a power of ten"
    for method, edp in edp_by_method.items():
        results.setdefault(method, {})[exponent] = edp / ffm_edp

get_latex_table(results)