In [None]:
import dotenv
env_file = '.env'
dotenv.load_dotenv(env_file, override=True)

import itertools

from pytimeloop.fastfusion.mapper.mapper2 import mapper, PeArrayConstraint, MacArrayConstraint

from tests.load_config_mixin import LoadConfigMixin
from tests.util import TEST_TMP_DIR

# Load the workload and architecture descriptions for this experiment.
config, spec = LoadConfigMixin.load_config([
    'cascaded_mm_multi_32.workload.yaml',
    'four_level.arch.yaml',
])

# Constraint objects handed to the mapper below.
pe_constraint = PeArrayConstraint(4)

# The three dicts are keyed by 'Fc1' .. 'Fc31' and name the matching
# 'Filter<n>', 'M<n>' and 'C<n>' entries for each key.
mac_constraint = MacArrayConstraint(
    64,
    64,
    dict((f'Fc{x}', f'Filter{x}') for x in range(1, 32)),
    dict((f'Fc{x}', f'M{x}') for x in range(1, 32)),
    dict((f'Fc{x}', f'C{x}') for x in range(1, 32)),
)

# Capacity per resource id; resource 0 has no numeric capacity (None).
resource2capacity = {0: None, 1: 16384 * 512, 2: 1024 * 512}

# Run the mapper; `result` is later iterated as einsum id -> {bucket: mappings}.
result = mapper(
    config,
    pe_constraint,
    mac_constraint,
    explore_glb_uneven=True,
    explore_pe_uneven=False,
    spec=spec,
    tmp_path=TEST_TMP_DIR,
    verbose_stream=None,
)

import pandas as pd
from pytimeloop.fastfusion.sim import SIM
from pytimeloop.fastfusion.pareto import Pareto


Count: 10000, fulltiling: ['S([0, 1] in 0)', 'T0 in 1', 'S({2} in 1)', 'T1 in 1', 'T2 in 1', 'T0 in 1', 'S({2} in 2)', 'S({1} in 2)', 'S({0} in 2)', 'T1 in 64', 'T2 in 64', 'T0 in 1', 'S1 in 1', 'S2 in 1', '{0: [1], 1: [1], 2: [1], 3: [4096.0]}', 'RESOURCE_0_LEVEL_0=2.05e+03', 'RESOURCE_1_LEVEL_1=2.56e+02', 'RESOURCE_2_LEVEL_1=1.04e+02', 'L=8.00e+00', 'E=4.69e+07']
Count: 10000, fulltiling: ['S([9] in 0)', 'T30 in 1', 'S({10} in 1)', 'S({8} in 1)', 'T31 in 1', 'T32 in 1', 'T30 in 1', 'S({10} in 2)', 'S({9} in 2)', 'S({8} in 2)', 'T31 in 64', 'T32 in 64', 'T30 in 1', 'S31 in 1', 'S32 in 1', '{0: [1], 1: [1], 2: [1], 3: [4096.0]}', 'RESOURCE_0_LEVEL_0=1.02e+03', 'RESOURCE_1_LEVEL_1=5.12e+02', 'RESOURCE_2_LEVEL_1=1.04e+02', 'L=8.00e+00', 'E=2.05e+07']
Count: 10000, fulltiling: ['S([11] in 0)', 'T37 in 1', 'S({12} in 1)', 'S({10} in 1)', 'T38 in 1', 'T39 in 1', 'T37 in 1', 'S({12} in 2)', 'S({11} in 2)', 'S({10} in 2)', 'T38 in 64', 'T39 in 64', 'T37 in 1', 'S38 in 1', 'S39 in 1', '{0: [1]

In [35]:
import plotly.express as px

def create_info_str(row):
    """Build the HTML hover text for one row of a mappings DataFrame.

    The text is the values of the row's "__Mappings" dict, one per line
    (joined with "<br>"), followed by a final line listing every other
    field of the row as "name=value", comma-separated.

    Args:
        row: a pandas Series (one DataFrame row) whose "__Mappings" entry
            is a dict of strings.

    Returns:
        The "<br>"-joined string described above.
    """
    def fs(x):
        # Numbers are shown in scientific notation with two decimals.
        # Values that do not support the ".2e" format spec (strings raise
        # ValueError, most other objects raise TypeError) are passed
        # through unchanged. Catching only those two keeps unrelated
        # errors and KeyboardInterrupt visible, which a bare `except:` hid.
        try:
            return f"{x:.2e}"
        except (TypeError, ValueError):
            return x

    content = list(row["__Mappings"].values())
    content.append(", ".join(f"{k}={fs(row[k])}" for k in row.index if k != "__Mappings"))
    return "<br>".join(content)

def plotly_show(df):
    """Show a scatter plot of Energy against Latency for one DataFrame.

    Args:
        df: DataFrame with "Latency" and "Energy" columns.
    """
    # The figure title is passed with the `title` keyword; the previous
    # `hyutitle` keyword is not accepted by px.scatter and made every call fail.
    fig = px.scatter(df, x='Latency', y='Energy', title='Energy vs Latency')
    fig.show()
    
def plotly_show2(dict_of_df):
    """Plot Energy against RESOURCE_2 usage, one marker trace per DataFrame.

    For every ``key -> DataFrame`` entry, x values come from the last column
    whose name contains "RESOURCE_2" and y values from the "Energy" column.
    Each point's text is produced by ``create_info_str``. Both axes are
    logarithmic.

    Args:
        dict_of_df: mapping from a trace label to a DataFrame that has an
            "Energy" column, a "__Mappings" column (read by
            ``create_info_str``) and at least one "RESOURCE_2" column.

    Raises:
        KeyError: if a frame has no column whose name contains "RESOURCE_2".
    """
    fig = px.scatter()
    y = 'Energy'
    for key, df in dict_of_df.items():
        # Resolve the x column per frame. Previously `x` was left unbound
        # (first frame) or silently carried over from the previous frame
        # when no matching column existed.
        resource_cols = [c for c in df.columns if "RESOURCE_2" in c]
        if not resource_cols:
            raise KeyError(f"no column containing 'RESOURCE_2' in frame {key!r}")
        x = resource_cols[-1]
        # Keep the text in a standalone Series instead of writing an "mp"
        # column into the caller's frame; otherwise a second call would
        # include the previous text as an "mp=..." field.
        info = df.apply(create_info_str, axis=1)
        fig.add_scatter(x=df[x], y=df[y], mode='markers', name=str(key), text=info)
    # The per-frame column names can differ, so the axis gets a generic label.
    x = "RESOURCE_2_LEVEL_..."
    fig.update_layout(xaxis_title=x, yaxis_title=y, xaxis_type="log", yaxis_type="log")
    fig.show()
    
def plotly_show_latency(dict_of_df):
    """Plot Energy against Latency, one marker trace per DataFrame.

    Each point's text is produced by ``create_info_str``. Both axes are
    logarithmic.

    Args:
        dict_of_df: mapping from a trace label to a DataFrame with
            "Latency", "Energy" and "__Mappings" columns.
    """
    fig = px.scatter()
    x = 'Latency'
    y = 'Energy'
    for key, df in dict_of_df.items():
        # Keep the text in a standalone Series instead of writing an "mp"
        # column into the caller's frame; otherwise calling this twice on
        # the same frame would include the previous text as an "mp=..." field.
        info = df.apply(create_info_str, axis=1)
        fig.add_scatter(x=df[x], y=df[y], mode='markers', name=str(key), text=info)
    fig.update_layout(xaxis_title=x, yaxis_title=y, xaxis_type="log", yaxis_type="log")
    fig.show()


# Turn every bucket of the first entry in `result` into a DataFrame, with
# missing values replaced by 0, and list the columns of the first frame.
r = {
    bucket: pd.DataFrame(rows).fillna(0)
    for bucket, rows in next(iter(result.values())).items()
}
print(next(iter(r.values())).columns)
# plotly_show2(r)

Index(['Latency', 'Energy', 'RESOURCE_0_LEVEL_0', 'RESOURCE_1_LEVEL_0',
       'RESOURCE_2_LEVEL_0', '__Mappings'],
      dtype='object')


In [4]:
# Size of the raw mapper output stored under key 0 of `result`:
# bucket count, total mapping count, and their ratio.
bucket_count = len(result[0])
print(f'Number of buckets: {bucket_count}')
mapping_count = sum(len(v) for v in result[0].values())
print(f'Number of mappings: {mapping_count}')
print(f'Mappings per bucket: {mapping_count / bucket_count}')

# Filled per einsum id by the loop further down in this cell.
r2 = {}
def paretofy(k, v):
    """Wrap one bucket of mappings in a SIM holding a Pareto object.

    Args:
        k: passed through as the first SIM argument (the bucket key at the
            call sites in this notebook).
        v: mapping records accepted by ``pd.DataFrame``.

    Returns:
        ``SIM(k, Pareto(frame))`` where ``frame`` is ``v`` as a DataFrame
        with missing values replaced by 0.
    """
    frame = pd.DataFrame(v).fillna(0)
    pareto = Pareto(frame)
    return SIM(k, pareto)
from joblib import Parallel, delayed
# Build one SIM per bucket for every einsum (joblib with a single worker).
for einsum_id, compat_dict in result.items():
    r2[einsum_id] = Parallel(n_jobs=1)(
        delayed(paretofy)(bucket, rows) for bucket, rows in compat_dict.items()
    )

# Same three statistics as above, now for the first einsum's SIM list.
sample = next(iter(r2.values()))
print(f'Number of buckets: {len(sample)}')
sample_mapping_count = sum(len(v.mappings[0].data) for v in sample)
print(f'Number of mappings: {sample_mapping_count}')
print(f'Mappings per bucket: {sample_mapping_count / len(sample)}')

Number of buckets: 61
Number of mappings: 151527
Mappings per bucket: 2484.0491803278687
Number of buckets: 61
Number of mappings: 404
Mappings per bucket: 6.622950819672131


In [31]:

def plotly_show3(list_of_sims):
    """Plot Energy against RESOURCE_2 usage, one marker trace per SIM.

    For every SIM, the data of its first ``mappings`` entry is plotted: x
    values come from the last column whose name contains "RESOURCE_2", y
    values from "Energy". The trace is named after ``sim.tilings[0]`` and
    the hover text comes from ``create_info_str``. Both axes are logarithmic.

    Args:
        list_of_sims: iterable of SIM objects exposing ``tilings`` and
            ``mappings`` (each mapping with a ``data`` attribute).

    Raises:
        KeyError: if a SIM's data has no column containing "RESOURCE_2".
    """
    fig = px.scatter()
    y = 'Energy'
    for sim in list_of_sims:
        key = sim.tilings[0]
        df = pd.DataFrame(sim.mappings[0].data)
        # Resolve the x column per SIM. Previously `x` was left unbound
        # (first SIM) or silently carried over from the previous SIM when
        # no matching column existed.
        resource_cols = [c for c in df.columns if "RESOURCE_2" in c]
        if not resource_cols:
            raise KeyError(f"no column containing 'RESOURCE_2' in data of {key!r}")
        x = resource_cols[-1]
        # Hover text is kept in a standalone Series rather than assigned as
        # an "mp" column. Depending on the pandas version, a frame built from
        # another frame may share its storage, so the assignment could leak
        # into the SIM's own data.
        info = df.apply(create_info_str, axis=1)
        fig.add_scatter(x=df[x], y=df[y], mode='markers', name=str(key), hovertext=info)
    # The per-SIM column names can differ, so the axis gets a generic label.
    x = "RESOURCE_2_LEVEL_..."
    fig.update_layout(xaxis_title=x, yaxis_title=y, xaxis_type="log", yaxis_type="log")
    fig.show()

def plotly_show4(list_of_sims):
    """Plot Energy against Latency, one marker trace per SIM.

    For every SIM, the data of its first ``mappings`` entry is plotted. The
    trace is named after ``sim.tilings[0]`` and the hover text comes from
    ``create_info_str``. Both axes are logarithmic.

    Args:
        list_of_sims: iterable of SIM objects exposing ``tilings`` and
            ``mappings`` (each mapping with a ``data`` attribute).
    """
    fig = px.scatter()
    x = 'Latency'
    y = 'Energy'
    for sim in list_of_sims:
        key = sim.tilings[0]
        df = pd.DataFrame(sim.mappings[0].data)
        # Hover text is kept in a standalone Series rather than assigned as
        # an "mp" column. Depending on the pandas version, a frame built from
        # another frame may share its storage, so the assignment could leak
        # into the SIM's own data.
        info = df.apply(create_info_str, axis=1)
        fig.add_scatter(x=df[x], y=df[y], mode='markers', name=str(key), hovertext=info)
    fig.update_layout(xaxis_title=x, yaxis_title=y, xaxis_type="log", yaxis_type="log")
    fig.show()

# Plot Energy vs. RESOURCE_2 usage for the SIM list stored first in `r2`.
# NOTE(review): relies on `r2` already being populated by another cell.
plotly_show3(next(iter(r2.values())))

In [6]:
# Rebuild the per-einsum SIM lists from scratch (joblib with a single worker).
r2 = {}
for einsum_id, compat_dict in result.items():
    r2[einsum_id] = Parallel(n_jobs=1)(
        delayed(paretofy)(bucket, rows) for bucket, rows in compat_dict.items()
    )

# Re-insert the entries in ascending einsum-id order.
r2 = dict(sorted(r2.items(), key=lambda item: item[0]))

# Report the raw (pre-Pareto) mapping count of every bucket of every einsum.
for einsum_id, compat_dict in result.items():
    n_mappings = sum(map(len, compat_dict.values()))
    print(f"Einsum {einsum_id} has {n_mappings} mappings in {len(compat_dict)} buckets")
    for k, v in compat_dict.items():
        print(f"\tBucket {k} has {len(v)} mappings")

# List the tensor names of the first SIM of each einsum, by position in `r2`.
sims = [*r2.values()]
for i, s in enumerate(sims):
    print(f'Einsum {i} has tensors: {s[0].tensor_names}')

Einsum 0 has 151527 mappings in 61 buckets
	Bucket Tiling(loops=, tensors=2(1,0)) has 8509 mappings
	Bucket Tiling(loops=0-1, tensors=2(1,1)) has 3706 mappings
	Bucket Tiling(loops=0-2, tensors=2(1,1)) has 2198 mappings
	Bucket Tiling(loops=0-4, tensors=2(1,1)) has 3319 mappings
	Bucket Tiling(loops=0-8, tensors=2(1,1)) has 4935 mappings
	Bucket Tiling(loops=0-16, tensors=2(1,1)) has 6722 mappings
	Bucket Tiling(loops=1-1, tensors=2(1,1)) has 3706 mappings
	Bucket Tiling(loops=1-2, tensors=2(1,1)) has 2198 mappings
	Bucket Tiling(loops=1-4, tensors=2(1,1)) has 3319 mappings
	Bucket Tiling(loops=1-8, tensors=2(1,1)) has 4935 mappings
	Bucket Tiling(loops=1-16, tensors=2(1,1)) has 6722 mappings
	Bucket Tiling(loops=0-1, 1-1, tensors=2(1,2)) has 1257 mappings
	Bucket Tiling(loops=0-1, 1-2, tensors=2(1,2)) has 953 mappings
	Bucket Tiling(loops=0-1, 1-4, tensors=2(1,2)) has 1405 mappings
	Bucket Tiling(loops=0-1, 1-8, tensors=2(1,2)) has 2086 mappings
	Bucket Tiling(loops=0-1, 1-16, tensors

In [29]:

# Capacity per resource id, passed to SIM.consolidate below. Resource 0 has
# no numeric capacity (None). Same values as the definition in the setup cell.
resource2capacity = {0: None, 1: 16384 * 512, 2: 1024 * 512}

def get_n_mappings(x):
    """Count the rows held by every mapping in a nested collection.

    Args:
        x: iterable of iterables of objects that expose ``mappings``, an
            iterable whose items each have a sized ``data`` attribute.

    Returns:
        The sum of ``len(mapping.data)`` over all mappings; 0 when there
        are none.
    """
    return sum(
        len(mapping.data)
        for group in x
        for member in group
        for mapping in member.mappings
    )


# Rebuild the per-einsum SIM lists from `result`, overwriting the entries of `r2`.
# NOTE(review): `r2` is not re-initialised in this cell, so it must already
# exist from an earlier cell; its key order is whatever that cell left behind.
for einsum_id, compat_dict in result.items():
    r2[einsum_id] = Parallel(n_jobs=1)(delayed(paretofy)(k, v) for k, v in compat_dict.items())
    
sims = list(r2.values())

# `s` accumulates the merged result; it starts as the first einsum's SIM list.
s = sims.pop(0)

# Per-iteration statistics, appended at the end of every loop pass.
nmappings = []
nbuckets = []


# Each pass merges the accumulated SIMs in `s` with the next einsum's SIMs (`ns`).
while sims:
    print("\n\n")
    print("\n\n" + "=" * 100 + f"\n{len(sims) + 1} Remaining\n" + "=" * 100)
    # Tensor names of the first SIM of every not-yet-merged einsum, taken
    # before (`live_tensors`) and after (`next_live_tensors`) removing `ns`.
    live_tensors = set.union(set(), *[sim[0].tensor_names for sim in sims])
    ns = sims.pop(0)
    next_live_tensors = set.union(set(), *[sim[0].tensor_names for sim in sims])

    # Called for its effect on each SIM; the return value is not used.
    for s2 in s:
        s2.consolidate(live_tensors, resource2capacity)
    # for i, s2 in enumerate(s):
    #     print(f"\tPREV {i} Tiling: {s2.tilings[0]}")
    # for i, s2 in enumerate(ns):
    #     print(f"\tNEXT {i} Tiling: {s2.tilings[0]}")

    # After these two calls `ns` is used as a mapping: key -> list of SIMs
    # (see `k in ns`, `ns.keys()` and `ns[k]` below).
    ns = SIM.combine_combineable(ns, next_live_tensors | s[0].tensor_names)
    ns = SIM.group_by_left(ns, s[0].tensor_names)
    print(f"\tNEXT: Combined by {sorted(next_live_tensors | s[0].tensor_names)}")
    print(f"\tNEXT: Grouped by {sorted(s[0].tensor_names)}")
    # for i, k in enumerate(ns):
    #     print(f"\t{i} Tiling: {k}")
    print(f"\tPREV: Combined by {sorted(live_tensors)}")
    print(f"\tPREV: Grouped by {sorted(live_tensors)}")
    # Likewise `s` becomes a mapping: key -> list of SIMs.
    s = SIM.combine_combineable(s, live_tensors)
    s = SIM.group_by_right(s, live_tensors)

    # Toggles the per-pair / no-match prints below (re-assigned on every pass).
    DO_PRINT = True

    # Debug dumps of the group keys; both files are overwritten on every pass,
    # so they only hold the keys of the last iteration.
    with open('s_keys.txt', 'w') as f:
        for key in sorted(s.keys()):
            f.write(f"{key}\n")

    with open('s2_keys.txt', 'w') as f:
        for key in sorted(ns.keys()):
            f.write(f"{key}\n")

    # For every key present on both sides, merge each (left, right) pair:
    # the left SIM is copied first so the grouped originals are not the
    # objects that merge_next is called on.
    combined: list[SIM] = []
    for k in s:
        if k in ns:
            for a, b in itertools.product(s[k], ns[k]):
                if DO_PRINT:
                    print(f"\t{a.tiling_str()} <--> {b.tiling_str()}")
                combined.append(a.copy())
                combined[-1].merge_next(b, set())
                # combined_keys.append()
        elif DO_PRINT:
            print(f"\tNo match for {s[k][0].tiling_str()}")
    # NOTE(review): `s` and `ns` are keyed mappings here, so `for s2 in s`
    # iterates keys and `len(s2)` is the length of a key, not of a SIM list;
    # confirm whether `.values()` was intended.
    print(f"\tCombining {sum(len(s2) for s2 in s)}({len(s)}) x {sum(len(s2) for s2 in ns)}({len(ns)}) -> {len(combined)}")
    # Report right-hand keys that found no partner on the left.
    for k in ns:
        if k not in s:
            if DO_PRINT:
                print(f"\tREVERSE: No match for {ns[k][0].tiling_str()}")


    # The merged list becomes the accumulated state for the next pass.
    s = combined
    print(f'Number of buckets: {len(s)}')
    # NOTE(review): the next two lines index `mappings[-1]` and `mappings[0]`
    # respectively, so the two figures may describe different mappings;
    # the division also fails with ZeroDivisionError when nothing was combined.
    print(f'Number of mappings: {sum(len(s2.mappings[-1].data) for s2 in s)}')
    print(f'Mappings per bucket: {sum(len(s2.mappings[0].data) for s2 in s) / len(s)}')
    nbuckets.append(len(s))
    # Total rows over every mapping of every merged SIM.
    nmappings.append(sum(len(m.data) for s2 in s for m in s2.mappings))
    







8 Remaining
	NEXT: Combined by [2, 4, 6, 8, 10, 12, 14]
	NEXT: Grouped by [2]
	PREV: Combined by [2, 4, 6, 8, 10, 12, 14]
	PREV: Grouped by [2, 4, 6, 8, 10, 12, 14]
	 || 2(1,0) <-->  || 4(1,0), 2(1,0)
	0-1 || 2(1,1) <--> 0-1 || 4(1,1), 2(1,1)
	0-2 || 2(1,1) <--> 0-2 || 4(1,1), 2(1,1)
	0-4 || 2(1,1) <--> 0-4 || 4(1,1), 2(1,1)
	0-8 || 2(1,1) <--> 0-8 || 4(1,1), 2(1,1)
	0-16 || 2(1,1) <--> 0-16 || 4(1,1), 2(1,1)
	1-1 || 2(1,1) <--> 1-1 || 4(1,1), 2(1,1)
	1-2 || 2(1,1) <--> 1-2 || 4(1,1), 2(1,1)
	1-4 || 2(1,1) <--> 1-4 || 4(1,1), 2(1,1)
	1-8 || 2(1,1) <--> 1-8 || 4(1,1), 2(1,1)
	1-16 || 2(1,1) <--> 1-16 || 4(1,1), 2(1,1)
	0-1,1-1 || 2(1,2) <--> 0-1,1-1 || 4(1,2), 2(1,2)
	0-1,1-2 || 2(1,2) <--> 0-1,1-2 || 4(1,2), 2(1,2)
	0-1,1-4 || 2(1,2) <--> 0-1,1-4 || 4(1,2), 2(1,2)
	0-1,1-8 || 2(1,2) <--> 0-1,1-8 || 4(1,2), 2(1,2)
	0-1,1-16 || 2(1,2) <--> 0-1,1-16 || 4(1,2), 2(1,2)
	0-2,1-1 || 2(1,2) <--> 0-2,1-1 || 4(1,2), 2(1,2)
	0-2,1-2 || 2(1,2) <--> 0-2,1-2 || 4(1,2), 2(1,2)
	0-2,1-4 || 2(1,2)

In [36]:
# Consolidate every merged SIM once more, this time with an empty
# live-tensor set, then plot them per bucket.
for s2 in s:
    s2.consolidate(set(), resource2capacity)
plotly_show4(s)

# Combine everything and keep the first resulting SIM; its first mapping's
# data is ordered by latency, with ties broken by energy.
s_final = SIM.combine_combineable(s, set())[0]
data = s_final.mappings[0].data.sort_values(by=["Latency", "Energy"])

# NOTE (author's reminder, unrelated to this cell):
# VLSI and ISSCC Add some content on superconducting and photonics to
# show Plan for a ~30 minute talk. Have some backup slides in case there
# are questions
print(f"# of mappings: {len(data)}")
plotly_show_latency({0: data})

# of mappings: 11


In [34]:
from plotly.subplots import make_subplots

import plotly.graph_objects as go

# One figure with two y-axes: bucket counts on the primary axis and mapping
# counts on the secondary axis, both indexed by merge step.
fig = make_subplots(specs=[[{"secondary_y": True}]])

count_series = (
    ('Number of Buckets', nbuckets, False),
    ('Number of Mappings', nmappings, True),
)

for label, counts, on_secondary in count_series:
    trace = go.Scatter(
        x=list(range(len(counts))),
        y=counts,
        mode='lines+markers',
        name=label,
    )
    fig.add_trace(trace, secondary_y=on_secondary)

fig.update_layout(
    title='Number of Buckets and Mappings',
    xaxis_title='Einsum Number',
)

# Each y-axis is titled with the name of the series drawn on it.
for label, counts, on_secondary in count_series:
    fig.update_yaxes(title_text=label, secondary_y=on_secondary)

fig.show()