In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os


# Base font size; every Matplotlib text size configured below is an offset from it.
font_size_base = 14

# rcParams key -> offset added to font_size_base.
_font_size_offsets = {
    'font.size': 0,
    'axes.titlesize': 2,
    'axes.labelsize': 0,
    'xtick.labelsize': -2,
    'ytick.labelsize': -2,
    'legend.fontsize': -4,
    'figure.titlesize': 4,
}
plt.rcParams.update({
    rc_key: font_size_base + offset
    for rc_key, offset in _font_size_offsets.items()
})
# Directory that receives the saved roofline figures.
output_dir = "./roofline_plots"


# ---- Accelerator parameters ----
MEM_BANDWIDTH_GB_s = 400                # Memory bandwidth in GB/s
CLOCK_FREQ_MHZ = 940                    # Accelerator clock frequency in MHz
FLOPS_PER_PE_PER_CYCLE = 2              # assumed, not given
SRAM_CAPACITY_BYTES = 4 * 1024 ** 2     # 4MB Shared Global Buffer from arch-2d.yaml

# 2D PE array (rows x columns)
PE_ROWS_2D = 256
PE_COLS_2D = 256
NUM_PES_2D = PE_ROWS_2D * PE_COLS_2D

# 1D PE array
NUM_PES_1D = 256


def _peak_gflops(num_pes):
    """Return the peak throughput in GFLOPs/s of an array with ``num_pes`` PEs."""
    # PEs * FLOPs/cycle * MHz is in MFLOPs/s; dividing by 1000 converts to GFLOPs/s.
    return (num_pes * FLOPS_PER_PE_PER_CYCLE * CLOCK_FREQ_MHZ) / 1000


# Peak perf
PEAK_PERF_2D_GFLOPS = _peak_gflops(NUM_PES_2D)
PEAK_PERF_1D_GFLOPS = _peak_gflops(NUM_PES_1D)

# Ridge point: operational intensity (FLOPs/Byte) at which the bandwidth
# slope reaches the compute roof.
RIDGE_POINT_2D = PEAK_PERF_2D_GFLOPS / MEM_BANDWIDTH_GB_s
RIDGE_POINT_1D = PEAK_PERF_1D_GFLOPS / MEM_BANDWIDTH_GB_s

# ---- Model shapes ----
# bert
N_SEQ_REF_BERT = 1024
D_MODEL_BERT = 1024
N_HEADS_BERT = 16
D_HEAD_BERT = D_MODEL_BERT // N_HEADS_BERT
D_FF_BERT = 4 * D_MODEL_BERT

# Pythia-14M
N_SEQ_REF_PYTHIA = 2048
D_MODEL_PYTHIA = 128
N_HEADS_PYTHIA = 4
D_HEAD_PYTHIA = D_MODEL_PYTHIA // N_HEADS_PYTHIA
D_FF_PYTHIA = 4 * D_MODEL_PYTHIA

# OPT-66B
N_SEQ_REF_OPT = 2048
D_MODEL_OPT = 9216
N_HEADS_OPT = 72
D_HEAD_OPT = D_MODEL_OPT // N_HEADS_OPT
D_FF_OPT = 36864

# ---- Per-element cost assumptions ----
BYTES_PER_ELEMENT_BF16 = 2
GELU_FLOPS_PER_ELEMENT = 10
LAYERNORM_FLOPS_PER_ELEMENT = 5

def get_achieved_perf(oi, peak_perf_for_type, bandwidth, ridge_point_for_type):
    """Return the roofline performance bound at operational intensity ``oi``.

    Args:
        oi: Operational intensity of the workload.
        peak_perf_for_type: Compute roof returned at or above the ridge point.
        bandwidth: Slope of the memory roof; multiplied by ``oi`` below the ridge.
        ridge_point_for_type: Intensity threshold separating the two regimes.

    Returns:
        ``oi * bandwidth`` when ``oi`` is strictly below the ridge point,
        otherwise ``peak_perf_for_type``.
    """
    is_memory_bound = oi < ridge_point_for_type
    return oi * bandwidth if is_memory_bound else peak_perf_for_type

def plot_roofline_generic(fig, ax, title, workload_data_list, show_1d_roofline=True, fixed_xlim=None, fixed_ylim=None):
    """Draw a log-log roofline chart with workload markers on ``ax``.

    Args:
        fig: Figure that owns ``ax``. Not used by this function; kept so
            existing call sites keep working.
        ax: Matplotlib axes to draw on.
        title: Axes title.
        workload_data_list: Sequence of dicts. Each may carry 'flops',
            'bytes', 'pe_type' ('2d' or '1d'), 'label', 'color' and 'marker'.
            Entries with non-positive flops or bytes are skipped, as are
            entries whose 'pe_type' is neither '2d' nor '1d'.
        show_1d_roofline: When True the 1D roof and '1d' workloads are drawn;
            when False both are omitted.
        fixed_xlim: Optional (min, max) x-axis limits. When given, the roof
            lines are sampled out to these limits so they span the whole
            visible axis.
        fixed_ylim: Optional (min, max) y-axis limits.
    """
    # this is hacky, for prettier graph , dont change
    oi_min_plot_default = 0.01
    oi_max_plot_default = 30000

    # Span over which the roofs are sampled. Start from the defaults and widen
    # to the requested x-limits; otherwise a roof would stop short of the axis
    # edge whenever fixed_xlim reaches past the default span.
    oi_lo, oi_hi = oi_min_plot_default, oi_max_plot_default
    if fixed_xlim:
        x_lo, x_hi = fixed_xlim
        if x_lo is not None and x_lo > 0:  # a log axis cannot start at <= 0
            oi_lo = min(oi_lo, x_lo)
        if x_hi is not None:
            oi_hi = max(oi_hi, x_hi)

    # Sample points: span ends, ridge points (the kinks in the roofs) and
    # every valid workload intensity.
    sample_ois = [oi_lo, oi_hi, RIDGE_POINT_2D]
    if show_1d_roofline:
        sample_ois.append(RIDGE_POINT_1D)
    for wl_data in workload_data_list:
        if wl_data.get('bytes', 0) > 0 and wl_data.get('flops', 0) > 0:
            sample_ois.append(wl_data['flops'] / wl_data['bytes'])

    oi_range = np.unique(sample_ois)  # sorted, de-duplicated
    oi_range = oi_range[(oi_range >= oi_lo) & (oi_range <= oi_hi)]

    # Plot 2D Roofline
    roofline_perf_2d = np.minimum(PEAK_PERF_2D_GFLOPS, oi_range * MEM_BANDWIDTH_GB_s)
    ax.plot(oi_range, roofline_perf_2d, label=f'2D PE Array Roof (Peak: {PEAK_PERF_2D_GFLOPS/1000:.1f} TFLOPs/s)', color='black', linestyle='-', linewidth=2.5)

    # Plot 1D Roofline (delete in 2,3,4)
    if show_1d_roofline:
        roofline_perf_1d = np.minimum(PEAK_PERF_1D_GFLOPS, oi_range * MEM_BANDWIDTH_GB_s)
        ax.plot(oi_range, roofline_perf_1d, label=f'1D PE Array Roof (Peak: {PEAK_PERF_1D_GFLOPS:.0f} GFLOPs/s)', color='dimgray', linestyle='--', linewidth=2.5)

    for wl_data in workload_data_list:
        flops = wl_data.get('flops', 0)
        bytes_val = wl_data.get('bytes', 0)
        # Guard before dividing: a missing or zero 'bytes' entry would raise
        # ZeroDivisionError, and a non-positive intensity cannot be drawn on
        # a log axis anyway.
        if flops <= 0 or bytes_val <= 0:
            continue

        oi = flops / bytes_val
        pe_type = wl_data.get('pe_type', 'unknown')

        if pe_type == '2d':
            achieved_perf = get_achieved_perf(oi, PEAK_PERF_2D_GFLOPS, MEM_BANDWIDTH_GB_s, RIDGE_POINT_2D)
        elif pe_type == '1d' and show_1d_roofline:
            achieved_perf = get_achieved_perf(oi, PEAK_PERF_1D_GFLOPS, MEM_BANDWIDTH_GB_s, RIDGE_POINT_1D)
        else:
            # Skip 1D points if 1D roof is hidden; unknown PE types have no roof.
            continue

        if achieved_perf > 0:
            ax.plot(oi, achieved_perf, marker=wl_data.get('marker', 'o'), markersize=11, linestyle='None',
                    label=f"{wl_data.get('label','Unknown')} ({pe_type.upper()})", color=wl_data.get('color', 'blue'),
                    markeredgewidth=0.8, markeredgecolor='black')  # edge for clarity

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel('Operational Intensity (FLOPs/Byte)')
    ax.set_ylabel('Performance (GFLOPs/s)')
    ax.set_title(title)

    # for final plot - they should have same limit
    if fixed_xlim:
        ax.set_xlim(fixed_xlim)

    if fixed_ylim:
        ax.set_ylim(fixed_ylim)

    handles, labels = ax.get_legend_handles_labels()
    # Switch to a two-column legend once there are more than six entries.
    legend_cols = 1 if len(handles) <= 6 else 2
    legend_fontsize = plt.rcParams['legend.fontsize']

    ax.legend(handles=handles, labels=labels, fontsize=legend_fontsize, loc='lower right', ncol=legend_cols, framealpha=0.8)
    ax.grid(True, which="both", ls="--", alpha=0.4, color='gray')  # Lighter grid


def _sram_penalized_weight_bytes(weight_bytes, description, size_note="", debug_prints=False):
    """Return ``weight_bytes`` scaled up when it does not fit in SRAM.

    When ``weight_bytes`` exceeds ``SRAM_CAPACITY_BYTES`` the traffic is
    multiplied by the overflow ratio ``weight_bytes / SRAM_CAPACITY_BYTES``;
    otherwise the value is returned unchanged.

    Args:
        weight_bytes: Size of one weight matrix in bytes.
        description: Text naming the matrix in the optional INFO message.
        size_note: Optional text appended after the size in that message.
        debug_prints: When True, print an INFO line if the penalty applies.
    """
    if weight_bytes > SRAM_CAPACITY_BYTES:
        penalty_factor = weight_bytes / SRAM_CAPACITY_BYTES
        if debug_prints:
            print(f"INFO: {description} weights ({weight_bytes/1024/1024:.2f}MB{size_note}) exceed SRAM ({SRAM_CAPACITY_BYTES/1024/1024:.2f}MB). Penalty: {penalty_factor:.2f}x")
        return weight_bytes * penalty_factor
    return weight_bytes


def calculate_transformer_layer_components_fusemax(model_prefix, n_seq, d_model, n_heads, d_ff,
                                                 attn_color='red', mlp_color='green', overall_color='blue',
                                                 attn_marker='o', mlp_marker='o', overall_marker='D',
                                                 debug_prints=False):
    """Estimate FLOPs and memory traffic for one transformer layer.

    Args:
        model_prefix: Name prepended to each component label.
        n_seq: Sequence length.
        d_model: Model (hidden) dimension.
        n_heads: Number of attention heads; the head size is
            ``d_model // n_heads``.
        d_ff: Feed-forward inner dimension.
        attn_color, mlp_color, overall_color: Marker colours for the
            three components.
        attn_marker, mlp_marker, overall_marker: Marker shapes for the
            three components.
        debug_prints: When True, print an INFO line for every weight matrix
            that exceeds ``SRAM_CAPACITY_BYTES``. Defaults to False.

    Returns:
        A list of three dicts, in the order ``'<prefix> Attn'``,
        ``'<prefix> MLP'``, ``'<prefix> Overall'``. Each dict has the keys
        'label', 'flops', 'bytes', 'pe_type' (always '2d'), 'color' and
        'marker'.
    """
    d_head = d_model // n_heads
    components = []

    # --- Core attention -------------------------------------------------
    # One n_seq x n_seq x d_head matmul per head; counted twice below
    # (presumably QK^T and AV, going by the variable name).
    flops_qk_av_per_head = 2 * (n_seq**2 * d_head)
    flops_softmax_per_head = 3 * (n_seq**2)
    total_flops_core_attn = n_heads * (flops_qk_av_per_head * 2 + flops_softmax_per_head)
    # Traffic is four n_seq x d_model tensors (presumably Q, K, V and the
    # output -- TODO confirm); no n_seq x n_seq intermediate is counted.
    bytes_core_attn = (4 * n_seq * d_model) * BYTES_PER_ELEMENT_BF16
    components.append({
        'label': f'{model_prefix} Attn', 'flops': total_flops_core_attn, 'bytes': bytes_core_attn,
        'pe_type': '2d', 'color': attn_color, 'marker': attn_marker
    })

    # --- MLP: GEMM1 -> activation -> GEMM2 --------------------------------
    flops_gemm1_mlp = 2 * n_seq * d_ff * d_model
    flops_activation_mlp = GELU_FLOPS_PER_ELEMENT * n_seq * d_ff
    flops_gemm2_mlp = 2 * n_seq * d_model * d_ff
    total_flops_mlp = flops_gemm1_mlp + flops_activation_mlp + flops_gemm2_mlp

    mlp_gemm1_weight_bytes = d_model * d_ff * BYTES_PER_ELEMENT_BF16
    mlp_gemm2_weight_bytes = d_ff * d_model * BYTES_PER_ELEMENT_BF16
    adjusted_mlp_gemm1_w_bytes = _sram_penalized_weight_bytes(
        mlp_gemm1_weight_bytes, f"{model_prefix} MLP GEMM1", debug_prints=debug_prints)
    adjusted_mlp_gemm2_w_bytes = _sram_penalized_weight_bytes(
        mlp_gemm2_weight_bytes, f"{model_prefix} MLP GEMM2", debug_prints=debug_prints)

    bytes_mlp_weights_total_adjusted = adjusted_mlp_gemm1_w_bytes + adjusted_mlp_gemm2_w_bytes
    # Two n_seq x d_model activation tensors; the n_seq x d_ff intermediate
    # is not counted.
    bytes_mlp_activations = ((n_seq * d_model) + (n_seq * d_model)) * BYTES_PER_ELEMENT_BF16
    bytes_mlp = bytes_mlp_activations + bytes_mlp_weights_total_adjusted

    components.append({
        'label': f'{model_prefix} MLP', 'flops': total_flops_mlp, 'bytes': bytes_mlp,
        'pe_type': '2d', 'color': mlp_color, 'marker': mlp_marker
    })

    # --- Whole layer: 4 projections + core attention + MLP + 2 LayerNorms -
    flops_attn_projections = 4 * (2 * n_seq * d_model**2)
    flops_layernorm_op = LAYERNORM_FLOPS_PER_ELEMENT * n_seq * d_model
    total_flops_overall = flops_attn_projections + total_flops_core_attn + total_flops_mlp + 2 * flops_layernorm_op

    one_attn_proj_weight_bytes = (d_model**2) * BYTES_PER_ELEMENT_BF16
    adjusted_one_attn_proj_w_bytes = _sram_penalized_weight_bytes(
        one_attn_proj_weight_bytes, f"{model_prefix} Attn Proj", size_note=" per proj",
        debug_prints=debug_prints)

    bytes_overall_weights_attn_proj_adjusted = adjusted_one_attn_proj_w_bytes * 4
    # Two n_seq x d_model activation tensors are counted for the whole layer.
    bytes_overall_activations = (2 * n_seq * d_model) * BYTES_PER_ELEMENT_BF16
    total_bytes_overall = bytes_overall_activations + bytes_overall_weights_attn_proj_adjusted + bytes_mlp_weights_total_adjusted

    components.append({
        'label': f'{model_prefix} Overall', 'flops': total_flops_overall, 'bytes': total_bytes_overall,
        'pe_type': '2d', 'color': overall_color, 'marker': overall_marker
    })
    return components


# ---- Plot 1 data: BERT components (FuseMax / Unfused / FLAT) plus 1D-array ops ----
bert_components_fusemax = calculate_transformer_layer_components_fusemax(
    "BERT", N_SEQ_REF_BERT, D_MODEL_BERT, N_HEADS_BERT, D_FF_BERT,
    attn_color='firebrick', mlp_color='forestgreen', overall_color='royalblue',
    attn_marker='o', mlp_marker='o', overall_marker='D'
)
# Pick each component out of the result by the keyword in its label.
bert_attn_fusemax_data, bert_mlp_fusemax_data, bert_overall_fusemax_data = (
    next(c for c in bert_components_fusemax if keyword in c["label"])
    for keyword in ("Attn", "MLP", "Overall")
)


def _plot1_point(label, flops, bytes_val, pe_type, color, marker):
    """Build one workload dict in the layout plot_roofline_generic reads."""
    return {'label': label, 'flops': flops, 'bytes': bytes_val,
            'pe_type': pe_type, 'color': color, 'marker': marker}


# Element counts shared by several of the byte/FLOP estimates below.
_bert_seq_x_model = N_SEQ_REF_BERT * D_MODEL_BERT
_bert_seq_x_ff = N_SEQ_REF_BERT * D_FF_BERT
_bert_seq_squared = N_SEQ_REF_BERT ** 2

# Baseline traffic estimates. The 2.5, 1.75 and 2.2 factors are fixed
# scaling constants; their derivation is not shown here -- TODO confirm.
bytes_core_attn_unfused_bert = bert_attn_fusemax_data['bytes'] * 2.5
mlp_activation_bytes_bert_unfused = (
    2 * _bert_seq_x_model + 4 * _bert_seq_x_ff
) * BYTES_PER_ELEMENT_BF16
mlp_weight_bytes_bert_unfused = (2 * D_MODEL_BERT * D_FF_BERT) * BYTES_PER_ELEMENT_BF16
bytes_mlp_unfused_bert = mlp_activation_bytes_bert_unfused + mlp_weight_bytes_bert_unfused
bytes_core_attn_flat_bert = bert_attn_fusemax_data['bytes'] * 1.75
bytes_mlp_flat_bert = (bert_mlp_fusemax_data['bytes'] + bytes_mlp_unfused_bert) / 2.2

# Element-wise ops placed on the 1D array: each element is counted twice in
# the byte totals.
flops_softmax_1d_bert = 3 * _bert_seq_squared
bytes_softmax_1d_bert = 2 * (_bert_seq_squared * BYTES_PER_ELEMENT_BF16)
flops_gelu_1d_bert = GELU_FLOPS_PER_ELEMENT * _bert_seq_x_ff
bytes_gelu_1d_bert = 2 * (_bert_seq_x_ff * BYTES_PER_ELEMENT_BF16)
flops_layernorm_1d_bert = LAYERNORM_FLOPS_PER_ELEMENT * _bert_seq_x_model
bytes_layernorm_1d_bert = 2 * (_bert_seq_x_model * BYTES_PER_ELEMENT_BF16)

workloads_for_plot1 = [
    {**bert_attn_fusemax_data, 'label': 'Attn (FuseMax-BERT)'},
    {**bert_mlp_fusemax_data, 'label': 'MLP (FuseMax-BERT)'},
    {**bert_overall_fusemax_data, 'label': 'Overall (FuseMax-BERT)'},
    _plot1_point('Attn (Unfused-BERT)', bert_attn_fusemax_data['flops'],
                 bytes_core_attn_unfused_bert, '2d', 'darkorange', 'X'),
    _plot1_point('MLP (Unfused-BERT)', bert_mlp_fusemax_data['flops'],
                 bytes_mlp_unfused_bert, '2d', 'gold', 'X'),
    _plot1_point('Attn (FLAT-BERT)', bert_attn_fusemax_data['flops'],
                 bytes_core_attn_flat_bert, '2d', 'lightcoral', 'P'),
    _plot1_point('MLP (FLAT-BERT)', bert_mlp_fusemax_data['flops'],
                 bytes_mlp_flat_bert, '2d', 'yellowgreen', 'P'),
    # Softmax totals are per head, so scale by the head count.
    _plot1_point('Softmax (1D)', flops_softmax_1d_bert * N_HEADS_BERT,
                 bytes_softmax_1d_bert * N_HEADS_BERT, '1d', 'magenta', 's'),
    _plot1_point('GELU (1D)', flops_gelu_1d_bert, bytes_gelu_1d_bert, '1d', 'cyan', 's'),
    _plot1_point('LayerNorm (1D)', flops_layernorm_1d_bert, bytes_layernorm_1d_bert,
                 '1d', 'orange', 's'),
]

# ---- Plot 2 data: BERT whole-layer point at three sequence lengths ----
seq_length_configs_plot2 = [
    {'label_suffix': 'N_seq=256 (Short)', 'N': 256, 'color': 'darkviolet', 'marker':'P'},
    {'label_suffix': f'N_seq={N_SEQ_REF_BERT} (Medium)', 'N': N_SEQ_REF_BERT, 'color': 'royalblue', 'marker':'D'},
    {'label_suffix': 'N_seq=4096 (Long)', 'N': 4096, 'color': 'saddlebrown', 'marker':'X'}
]


def _bert_overall_point(seq_cfg):
    """Return the relabelled BERT 'Overall' component for one sequence-length config."""
    layer_components = calculate_transformer_layer_components_fusemax(
        "BERT", seq_cfg['N'], D_MODEL_BERT, N_HEADS_BERT, D_FF_BERT,
        overall_color=seq_cfg['color'], overall_marker=seq_cfg['marker']
    )
    overall_entry = next(c for c in layer_components if "Overall" in c["label"])
    return {**overall_entry, 'label': f'BERT Overall ({seq_cfg["label_suffix"]})'}


workloads_for_plot2 = [_bert_overall_point(seq_cfg) for seq_cfg in seq_length_configs_plot2]

# ---- Plot 3 data: OPT and Pythia components at a common sequence length ----
N_SEQ_REF_OTHERS = 2048

_opt_plot3_style = dict(
    attn_color='purple', mlp_color='orange', overall_color='red',
    attn_marker='o', mlp_marker='o', overall_marker='D',
)
_pythia_plot3_style = dict(
    attn_color='green', mlp_color='indigo', overall_color='yellow',
    attn_marker='^', mlp_marker='^', overall_marker='P',
)

opt_components = calculate_transformer_layer_components_fusemax(
    "OPT", N_SEQ_REF_OTHERS, D_MODEL_OPT, N_HEADS_OPT, D_FF_OPT, **_opt_plot3_style
)
pythia_components = calculate_transformer_layer_components_fusemax(
    "Pythia", N_SEQ_REF_OTHERS, D_MODEL_PYTHIA, N_HEADS_PYTHIA, D_FF_PYTHIA, **_pythia_plot3_style
)

# OPT entries first, then Pythia, each in Attn / MLP / Overall order.
workloads_for_plot3 = [*opt_components, *pythia_components]

# ---- Plot 4 data: whole-layer point for every (model, sequence length) pair ----
seq_length_configs_plot4 = [
    {'label_suffix': 'N_seq=256 (Short)', 'N': 256, 'marker': 'P'},
    {'label_suffix': 'N_seq=4096 (Long)', 'N': 4096, 'marker': 'D'}
]
model_configs_plot4 = [
    {'prefix': 'BERT', 'd_model': D_MODEL_BERT, 'n_heads': N_HEADS_BERT, 'd_ff': D_FF_BERT, 'color': 'royalblue'},
    {'prefix': 'OPT', 'd_model': D_MODEL_OPT, 'n_heads': N_HEADS_OPT, 'd_ff': D_FF_OPT, 'color': 'indigo'},
    {'prefix': 'Pythia', 'd_model': D_MODEL_PYTHIA, 'n_heads': N_HEADS_PYTHIA, 'd_ff': D_FF_PYTHIA, 'color': 'darkslategray'}
]


def _model_overall_point(model_cfg, seq_cfg):
    """Return the relabelled 'Overall' component for one model at one sequence length.

    The colour comes from the model config and the marker from the
    sequence-length config.
    """
    layer_components = calculate_transformer_layer_components_fusemax(
        model_cfg['prefix'], seq_cfg['N'], model_cfg['d_model'], model_cfg['n_heads'], model_cfg['d_ff'],
        overall_color=model_cfg['color'], overall_marker=seq_cfg['marker']
    )
    overall_entry = next(c for c in layer_components if "Overall" in c["label"])
    return {
        **overall_entry,
        'label': f'{model_cfg["prefix"]} Overall ({seq_cfg["label_suffix"]})'
    }


# Models vary in the outer loop, sequence lengths in the inner loop.
workloads_for_plot4 = [
    _model_overall_point(model_cfg, seq_cfg)
    for model_cfg in model_configs_plot4
    for seq_cfg in seq_length_configs_plot4
]


# fix limit - all plots are the same
all_workloads_combined = workloads_for_plot1 + workloads_for_plot2 + workloads_for_plot3 + workloads_for_plot4

# 1D points only count toward the limits when Plot 1 (the only plot drawn
# with the 1D roof) contains 1D workloads. Evaluated once, outside the loop.
has_1d_workloads = any(w.get('pe_type') == '1d' for w in workloads_for_plot1)

min_oi_global = float('inf')
max_oi_global = float('-inf')
min_perf_global = float('inf')
max_perf_global = float('-inf')

for wl in all_workloads_combined:
    flops = wl.get('flops', 0)
    bytes_val = wl.get('bytes', 0)
    pe_type = wl.get('pe_type', 'unknown')

    if bytes_val <= 0 or flops <= 0:
        continue

    oi = flops / bytes_val

    if pe_type == '2d':
        achieved_perf = get_achieved_perf(oi, PEAK_PERF_2D_GFLOPS, MEM_BANDWIDTH_GB_s, RIDGE_POINT_2D)
    elif pe_type == '1d' and has_1d_workloads:
        achieved_perf = get_achieved_perf(oi, PEAK_PERF_1D_GFLOPS, MEM_BANDWIDTH_GB_s, RIDGE_POINT_1D)
    else:
        # No roof applies, so the point is never drawn. Counting it would
        # record a performance of 0 and push the limits to the fallback.
        continue
    if achieved_perf <= 0:
        continue

    min_oi_global = min(min_oi_global, oi)
    max_oi_global = max(max_oi_global, oi)
    min_perf_global = min(min_perf_global, achieved_perf)
    max_perf_global = max(max_perf_global, achieved_perf)

# The y-range must always reach the roofs that are drawn.
max_perf_global = max(max_perf_global, PEAK_PERF_2D_GFLOPS)
if has_1d_workloads:
    max_perf_global = max(max_perf_global, PEAK_PERF_1D_GFLOPS)


# Padding, as a fraction of the data's span in decades.
oi_pad_factor_low = 0.1
oi_pad_factor_high = 0.3
perf_pad_factor_low = 0.1
perf_pad_factor_high = 0.2


def _padded_log_limits(lo, hi, pad_low, pad_high, fallback):
    """Return ``(lo, hi)`` widened by a fraction of their log10 span.

    ``fallback`` is returned when either bound is non-positive or
    non-finite, which happens when no workload contributed a value.
    """
    if not (np.isfinite(lo) and np.isfinite(hi) and lo > 0 and hi > 0):
        return fallback
    log_lo, log_hi = np.log10(lo), np.log10(hi)
    log_span = log_hi - log_lo
    # float() drops the np.float64 wrapper so the limits print cleanly.
    return (float(10 ** (log_lo - log_span * pad_low)),
            float(10 ** (log_hi + log_span * pad_high)))


final_oi_min, final_oi_max = _padded_log_limits(
    min_oi_global, max_oi_global, oi_pad_factor_low, oi_pad_factor_high,
    fallback=(0.01, 30000))
final_perf_min, final_perf_max = _padded_log_limits(
    min_perf_global, max_perf_global, perf_pad_factor_low, perf_pad_factor_high,
    fallback=(10, PEAK_PERF_2D_GFLOPS * 2))

# Never let the lower limits drop below these floors.
final_oi_min = max(0.01, final_oi_min)
final_perf_min = max(10, final_perf_min)

final_xlim = (final_oi_min, final_oi_max)
final_ylim = (final_perf_min, final_perf_max)

print("Global Axis Limits Set:")
print(f"  X-Axis (OI): ({final_xlim[0]:.2f}, {final_xlim[1]:.2f})")
print(f"  Y-Axis (Perf): ({final_ylim[0]:.2f}, {final_ylim[1]:.2f})")


os.makedirs(output_dir, exist_ok=True)


def _render_roofline_figure(title, workloads, show_1d_roofline, filename):
    """Draw one roofline figure with the shared axis limits, save it and close it.

    Returns the ``(fig, ax)`` pair so callers can keep a reference.
    """
    fig, ax = plt.subplots(figsize=(12, 9))
    plot_roofline_generic(fig, ax, title, workloads, show_1d_roofline=show_1d_roofline,
                          fixed_xlim=final_xlim, fixed_ylim=final_ylim)
    fig.savefig(os.path.join(output_dir, filename), bbox_inches='tight', dpi=500)
    plt.close(fig)  # release the figure once it is on disk
    return fig, ax


# Plot 1 (the only one drawn with the 1D roof)
fig1, ax1 = _render_roofline_figure(
    f'Plot 1: BERT Component Ops & Baselines (N_seq={N_SEQ_REF_BERT})',
    workloads_for_plot1, True, "roofline_plot_1_bert_baselines.jpeg")

# Plot 2
fig2, ax2 = _render_roofline_figure(
    'Plot 2: BERT Overall Layer - FuseMax (Varying Seq Length)',
    workloads_for_plot2, False, "roofline_plot_2_bert_varying_seq.jpeg")

# Plot 3
fig3, ax3 = _render_roofline_figure(
    f'Plot 3: OPT & Pythia Components - FuseMax (N_seq={N_SEQ_REF_OTHERS}, SRAM Constraint)',
    workloads_for_plot3, False, "roofline_plot_3_opt_pythia_components.jpeg")

# Plot 4
fig4, ax4 = _render_roofline_figure(
    'Plot 4: Model Overall Layer - FuseMax (Varying Seq Length, SRAM Constraint)',
    workloads_for_plot4, False, "roofline_plot_4_overall_comparison.jpeg")


Global Axis Limits Set:
  X-Axis (OI): (np.float64(0.3164996779046972), np.float64(55719.23518047689))
  Y-Axis (Perf): (np.float64(164.3499598051711), np.float64(410526.394420279))
Roofline plots saved to directory: ./roofline_plots

--- System & Roofline Parameters ---
Memory Bandwidth: 400 GB/s
2D PE Array: 256x256 PEs
  Peak Performance (2D): 123207.68 GFLOPs/s (123.21 TFLOPs/s)
  Ridge Point (2D): 308.02 FLOPs/Byte
1D PE Array: 1x256 PEs
  Peak Performance (1D): 481.28 GFLOPs/s
  Ridge Point (1D): 1.20 FLOPs/Byte

--- Workloads for Plot 1 (BERT Components, N_seq=1024) ---
  Attn (FuseMax-BERT)                           (2d): FLOPs=4.35e+09, Bytes=8.39e+06, OI=518.00, Est. Perf=123.2 TFLOPs/s
  MLP (FuseMax-BERT)                            (2d): FLOPs=1.72e+10, Bytes=3.77e+07, OI=456.22, Est. Perf=123.2 TFLOPs/s
  Overall (FuseMax-BERT)                        (2d): FLOPs=3.02e+10, Bytes=4.61e+07, OI=653.86, Est. Perf=123.2 TFLOPs/s
  Attn (Unfused-BERT)                           (2

Global Axis Limits Set:
  X-Axis (OI): (0.32, 55719.24)
  Y-Axis (Perf): (164.35, 410526.39)
Roofline plots saved to directory: ./roofline_plots
