# Evaluation of ReDASH
## Imports

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
# Base style sheet. The 'seaborn-v0_8-*' names are only available from
# matplotlib 3.6 onwards, so the colormap registry used below is available too.
plt.style.use('seaborn-v0_8-whitegrid')

import matplotlib as mpl
import tol_colors as tc

# Default line/bar colour cycle: the 'bright' colour set from tol_colors.
plt.rc('axes', prop_cycle=plt.cycler('color', list(tc.tol_cset('bright'))))
# Register the 'iridescent' colormap under that name and make it the image
# default. plt.cm.register_cmap was deprecated in matplotlib 3.7 and removed
# in 3.9; the colormap registry is its replacement. force=True lets this cell
# be re-executed without failing on the already registered name.
mpl.colormaps.register(tc.tol_cmap('iridescent'), name='iridescent', force=True)
plt.rc('image', cmap='iridescent')

# rcParams shared by all plots except the runtime distribution plots.
tex_fonts = {
    # Typeset every text element with LaTeX, using a serif (Times) face.
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Times"],
    # Font sizes in points: base text, axis labels/titles, legend, tick labels.
    "font.size": 16,
    "axes.labelsize": 17,
    "axes.titlesize": 17,
    "legend.fontsize": 16,
    "xtick.labelsize": 15,
    "ytick.labelsize": 15,
    # Default line thickness.
    "lines.linewidth": 2,
}

plt.rcParams.update(tex_fonts)

## Utils

In [None]:
def get_latest_file(path, suffix):
    """Return the most recently modified entry of *path* ending in *suffix*.

    The entries of the directory are ordered by modification time, newest
    first, and the first one whose string path ends with ``suffix`` is
    returned as a ``str``. ``None`` is returned when nothing matches.
    """
    newest_first = sorted(Path(path).iterdir(), key=os.path.getmtime, reverse=True)
    candidates = (str(entry) for entry in newest_first)
    return next((name for name in candidates if name.endswith(suffix)), None)

def relative_std(x):
    """Return the NaN-aware standard deviation of *x* divided by its mean.

    Only ``pd.Series`` input is accepted; anything else raises ``TypeError``.
    The type guard follows the discussion in
    https://github.com/pandas-dev/pandas/issues/33517
    """
    if isinstance(x, pd.Series):
        return np.nanstd(x) / np.nanmean(x)
    raise TypeError

## Micro-Benchmarks

#### Rescaling old vs new

In [None]:
rescaling = get_latest_file("micro_benchmarks/data", "_rescaling.csv")
print(rescaling)

# Aggregate the remaining columns per (dimensions, type, use_legacy_scaling)
# and move use_legacy_scaling from the row index into the columns.
rescaling_data = (
    pd.read_csv(rescaling, skipinitialspace=True)
    .drop(columns=["run", "crt_base_size", "gpu_mem_usage"])
    .groupby(["dimensions", "type", "use_legacy_scaling"])
    .aggregate(['mean', 'std', relative_std])
    .unstack([-1])
)
rescaling_data

In [None]:
# Mean runtime and its standard deviation for the rows of type "CPU".
runtime_stats = rescaling_data["runtime"]
mean_cpu = runtime_stats["mean"].xs("CPU", level="type")
std_cpu = runtime_stats["std"].xs("CPU", level="type")

# Grouped bar chart: one group per input size, one bar per scaling variant.
fig, ax = plt.subplots()
mean_cpu.plot(kind="bar", ax=ax, yerr=std_cpu, capsize=4)
ax.set_xlabel("Inputs")
ax.set_ylabel("Runtime (ms)")
ax.set_xticklabels(mean_cpu.index, rotation=45)

# Replace the automatic legend entries (the use_legacy_scaling values) with
# readable names; this relies on the non-legacy column coming first.
handles = ax.get_legend_handles_labels()[0]
ax.legend(handles, ['New Scaling', 'Legacy Scaling'])

fig.tight_layout()
plt.show()
fig.savefig("micro_benchmarks/data/rescaling_runtime.pdf", format='pdf', bbox_inches='tight')

### CPU-Scalability

#### Rescaling (1-16 Threads)

In [None]:
path_rescaling_scale = get_latest_file("micro_benchmarks/data/scalability", "_rescaling.csv")
print(path_rescaling_scale)

# Aggregate per (dims, nr_threads, l), then move nr_threads into the columns
# so that the rows are indexed by (dims, l).
rescaling_scale_data = (
    pd.read_csv(path_rescaling_scale, skipinitialspace=True)
    .drop(columns=["run", "crt_base_size", "type"])
    .groupby(["dims", "nr_threads", 'l'])
    .aggregate(['mean', 'std', relative_std])
    .unstack([-2])
)
rescaling_scale_data

In [None]:
path_rescaling_scale_MORE_INPUT_SIZES = get_latest_file("micro_benchmarks/data/scalability", "_rescaling_more_inputs.csv")
print(path_rescaling_scale_MORE_INPUT_SIZES)

# Aggregate per (dims, nr_threads, l), then move nr_threads into the columns
# so that the rows are indexed by (dims, l).
rescaling_scale_data_MORE_INPUT_SIZES = (
    pd.read_csv(path_rescaling_scale_MORE_INPUT_SIZES, skipinitialspace=True)
    .drop(columns=["run", "crt_base_size", "type"])
    .groupby(["dims", "nr_threads", 'l'])
    .aggregate(['mean', 'std', relative_std])
    .unstack([-2])
)
rescaling_scale_data_MORE_INPUT_SIZES

In [None]:
path_rescaling_scale_DASH = get_latest_file("micro_benchmarks/data/scalability", "IMPORTED_DASH_RESCALING_SCALABILITY.csv")
print(path_rescaling_scale_DASH)

# Aggregate per (dims, nr_threads, l), then move nr_threads into the columns
# so that the rows are indexed by (dims, l).
rescaling_scale_data_DASH = (
    pd.read_csv(path_rescaling_scale_DASH, skipinitialspace=True)
    .drop(columns=["run", "crt_base_size", "type"])
    .groupby(["dims", "nr_threads", 'l'])
    .aggregate(['mean', 'std', relative_std])
    .unstack([-2])
)
rescaling_scale_data_DASH

In [None]:
path_rescaling_scale_DASH_MORE_INPUT_SIZES = get_latest_file("micro_benchmarks/data/scalability", "IMPORTED_DASH_RESCALING_SCALABILITY_MORE_INPUT_SIZES.csv")
print(path_rescaling_scale_DASH_MORE_INPUT_SIZES)

# Aggregate per (dims, nr_threads, l), then move nr_threads into the columns
# so that the rows are indexed by (dims, l).
rescaling_scale_data_DASH_MORE_INPUT_SIZES = (
    pd.read_csv(path_rescaling_scale_DASH_MORE_INPUT_SIZES, skipinitialspace=True)
    .drop(columns=["run", "crt_base_size", "type"])
    .groupby(["dims", "nr_threads", 'l'])
    .aggregate(['mean', 'std', relative_std])
    .unstack([-2])
)
rescaling_scale_data_DASH_MORE_INPUT_SIZES

#### Plot

#### Rescaling (threads, input sizes)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Two-panel figure comparing ReDASH and DASH rescaling runtimes:
#   left  - runtime over the thread count for a single `dims` value,
#   right - runtime over the number of inputs at 16 threads.

# `dims` values present in both the ReDASH and the DASH scalability data.
# Only the smallest common value is kept; it selects the left panel's data.
dims_common = sorted(set(rescaling_scale_data.index.get_level_values('dims')).intersection(
                       set(rescaling_scale_data_DASH.index.get_level_values('dims'))))[:1]
d = dims_common[0]

# One colour per ℓ and one marker/line style per framework. Both panels use
# the same table: the figure legend below is de-duplicated by label, so a
# given label must look identical in both panels for the legend to be valid.
color_l3 = "tab:blue"
color_l5 = "tab:orange"
l_values = [3, 5]
color_by_l = {3: color_l3, 5: color_l5}
framework_style = {
    "ReDASH": {"marker": 'o', "linestyle": '-'},
    "DASH": {"marker": '*', "linestyle": '--'},
}

fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12,4))

# =====================
# LEFT SUBPLOT: scaling curves (line plots with error bars)
# =====================
for framework, frame in (("ReDASH", rescaling_scale_data), ("DASH", rescaling_scale_data_DASH)):
    for l in l_values:
        # Row (dims, ℓ); the remaining index of the selected statistics is
        # the thread count.
        data = frame.loc[(d, l)]
        runtime_mean = data['runtime']['mean']
        runtime_std = data['runtime']['std']
        threads = runtime_mean.index.astype(int)
        ax_left.errorbar(threads, runtime_mean, yerr=runtime_std, color=color_by_l[l],
                         capsize=4, label=f"{framework}, $\\ell={l}$",
                         **framework_style[framework])

ax_left.set_xlabel("Thread Count")
ax_left.set_xticks(np.arange(1, 17))
ax_left.set_ylabel("Runtime (ms)")
ax_left.grid(False)

# =====================
# RIGHT SUBPLOT: runtime per input count (equally spaced x-axis)
# =====================
# Mean and standard deviation of the runtime at 16 threads, indexed by (dims, ℓ).
redash_runtime = rescaling_scale_data_MORE_INPUT_SIZES["runtime"]["mean"].loc[:, 16]
redash_std     = rescaling_scale_data_MORE_INPUT_SIZES["runtime"]["std"].loc[:, 16]
dash_runtime   = rescaling_scale_data_DASH_MORE_INPUT_SIZES["runtime"]["mean"].loc[:, 16]
dash_std       = rescaling_scale_data_DASH_MORE_INPUT_SIZES["runtime"]["std"].loc[:, 16]

# Input counts to show; they are drawn at equally spaced positions and only
# labelled with their real value.
inputs = [128, 256, 1024, 2048, 4096, 8192, 16384]
positions = np.arange(len(inputs))

for framework, mean_16, std_16 in (("ReDASH", redash_runtime, redash_std),
                                   ("DASH", dash_runtime, dash_std)):
    for l in l_values:
        ax_right.errorbar(positions, mean_16.loc[inputs, l], yerr=std_16.loc[inputs, l],
                          color=color_by_l[l], capsize=4, label=f"{framework}, $\\ell={l}$",
                          **framework_style[framework])

ax_right.set_xlabel("Number of Inputs")
ax_right.set_ylabel("Runtime (ms)")
ax_right.set_xticks(positions)
ax_right.set_xticklabels(inputs)
# ax_right.set_yscale('log')
ax_right.grid(False)

# =====================
# Shared legend
# =====================
# Merge the legend entries of both panels, keeping the first handle per label.
unique = {}
for axis in (ax_left, ax_right):
    for handle, label in zip(*axis.get_legend_handles_labels()):
        unique.setdefault(label, handle)

fig.legend(unique.values(), unique.keys(), loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.02))
# Leave the top 7% of the figure free for the legend.
fig.tight_layout(rect=[0, 0, 1, 0.93])
# plt.show()

fig.savefig("micro_benchmarks/data/merged_runtime_comparison.pdf", format='pdf', bbox_inches='tight')

## Model Benchmarks

In [None]:
path_models = get_latest_file("model_benchmarks/data", "_garbled_models.csv")
# path_models = get_latest_file("model_benchmarks/data", "_sgx_models.csv")
print(path_models)

# Keep the rows with relu_acc == 100, aggregate per (model, optimize_bases)
# and move optimize_bases into the columns.
models_data = (
    pd.read_csv(path_models, skipinitialspace=True)
    .loc[lambda frame: frame["relu_acc"] == 100]
    .drop(columns=["target_crt_base_size", "label", "infered_label", "relu_acc", "type"])
    .groupby(["model", "optimize_bases"])
    .aggregate(['mean', 'std', relative_std])
    .unstack()
)
models_data

### Add available data of other frameworks

In [None]:
# Mean runtime per model. Resetting the index turns `model` into an ordinary
# column and gives the rows integer positions that can serve as x ticks.
df = models_data["runtime"]["mean"].reset_index()
# Reference runtimes from the gnn paper, one value per row of `df`
# (this requires `df` to contain exactly two models).
df["DASH (CPU)"] = [10263, 23959]
# df["DASH (GPU)"] = [1332, 1443]
df

### Comparison of ReDASH's model runtimes against other frameworks

In [None]:
import matplotlib.patches as mpatches

# One bar chart per model: restrict `df` to the two models that are plotted
# (titled "f" and "F" below).
df_filtered = df[df["model"].isin(["MODEL_F_GNNP_POOL_REPL", "MODEL_F_MINIONN_POOL_REPL"])]

# Create a 1x2 grid of subplots, one axis per model.
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Every column except "model" holds the runtimes of one setup.
runtime_columns = [col for col in df.columns if col != "model"]
# Swap the first two columns so that the ReDASH bars are switched.
runtime_columns[0], runtime_columns[1] = runtime_columns[1], runtime_columns[0]
# Legend text per column; the optimize_bases values 0 and 1 get readable names.
runtime_labels = [("ReDASH w/ CPM-Bases (CPU)" if col == 0
                   else "ReDASH w/ Optimized Bases (CPU)" if col == 1
                   else col) for col in runtime_columns]

# Colour/hatch table shared with later cells. It is created here only if no
# other cell has defined it yet, and before any loop, so that the legend code
# below can rely on it even when `df_filtered` has no rows.
if 'color_hatch_combinations' not in globals():
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    hatches = ['', '//', '\\\\', '||', '--', '++', 'xx', 'oo', 'OO', '..', '**']
    color_hatch_combinations = [(colors[i % (len(colors)-1)], h) for i, h in enumerate(hatches)]

# Style of the j-th runtime column, identical for the bars and the legend.
bar_styles = [color_hatch_combinations[j % len(color_hatch_combinations)]
              for j in range(len(runtime_columns))]
positions = list(range(len(runtime_columns)))
short_titles = {"F_GNNP_POOL_REPL": "f", "F_MINIONN_POOL_REPL": "F"}

for ax_model, (_, row) in zip(axs, df_filtered.iterrows()):
    for position, col, (c, h) in zip(positions, runtime_columns, bar_styles):
        ax_model.bar(position, row[col], color=c, hatch=h, width=0.8)
    # Keep the ticks but hide their labels; the legend identifies the bars.
    ax_model.set_xticks(positions)
    ax_model.set_xticklabels([])

    # Drop the "MODEL_" prefix and use the short model name where one exists.
    title = row["model"].replace("MODEL_", "")
    ax_model.set_title(short_titles.get(title, title))
    ax_model.set_ylabel("Runtime (ms)")

# Single figure-level legend built from custom patches.
legend_handles = [mpatches.Patch(facecolor=c, hatch=h, label=label)
                  for (c, h), label in zip(bar_styles, runtime_labels)]

# Place the legend at the top centre and keep the top 7% of the figure free for it.
fig.legend(handles=legend_handles, loc='upper center', bbox_to_anchor=(0.5, 0.98), ncol=len(runtime_columns))
fig.tight_layout(rect=[0, 0, 1, 0.93])
fig.savefig('model_benchmarks/data/models_runtime_comparison_subplots.pdf', format='pdf', bbox_inches='tight')
plt.show()


### Runtime Distribution over Layer-Types

In [None]:
path_models = get_latest_file("model_benchmarks/data", "IMPORTED_REDASH_RUNTIME_DISTRIBUTION_EVALUATION.csv")
print(path_models)
runtime_dist_data = pd.read_csv(path_models, skipinitialspace=True)
# Keep only the rows with relu_acc == 100. The filtered frame has to be
# assigned back; a bare boolean selection leaves `runtime_dist_data` unchanged.
runtime_dist_data = runtime_dist_data[runtime_dist_data["relu_acc"] == 100]

runtime_dist_data = runtime_dist_data.drop(columns=["target_crt_base_size", "relu_acc", "type"])
runtime_dist_data = runtime_dist_data.groupby(["model", "layer", "optimize_bases"]).aggregate(['mean', 'std', relative_std])

# Mean runtime with one column per model and rows indexed by
# (layer, optimize_bases). Layers that a model does not contain show up as
# NaN after unstacking and are set to 0.
runtime_dist_data = runtime_dist_data.unstack(0)["runtime"]["mean"].fillna(0)
runtime_dist_data

In [None]:
path_models = get_latest_file("model_benchmarks/data", "IMPORTED_DASH_RUNTIME_DISTRIBUTION_EVALUATION.csv")
print(path_models)

# Keep the rows with relu_acc == 100 whose type is not "GPU", then aggregate
# per (model, layer).
dash_layer_stats = (
    pd.read_csv(path_models, skipinitialspace=True)
    .loc[lambda frame: (frame["relu_acc"] == 100) & (frame["type"] != "GPU")]
    .drop(columns=["target_crt_base_size", "relu_acc", "type"])
    .groupby(["model", "layer"])
    .aggregate(['mean', 'std', relative_std])
)

# Mean runtime with one column per model and one row per layer; layers that
# a model does not contain are set to 0.
runtime_dist_data_DASH = dash_layer_stats.unstack(0)["runtime"]["mean"].fillna(0)
runtime_dist_data_DASH

In [None]:
def _load_non_gpu_layer_means(csv_suffix):
    """Load a runtime distribution CSV and return ``(path, mean runtimes)``.

    The newest file in model_benchmarks/data ending in ``csv_suffix`` is read
    and its path printed. Rows with relu_acc == 100 and a type other than
    "GPU" are aggregated per (model, layer); the returned frame holds the
    mean runtime with one column per model, one row per layer and 0 for
    layers a model does not contain.
    """
    csv_path = get_latest_file("model_benchmarks/data", csv_suffix)
    print(csv_path)
    frame = pd.read_csv(csv_path, skipinitialspace=True)
    keep = (frame["relu_acc"] == 100) & (frame["type"] != "GPU")
    frame = frame[keep].drop(columns=["target_crt_base_size", "relu_acc", "type"])
    stats = frame.groupby(["model", "layer"]).aggregate(['mean', 'std', relative_std])
    means = stats.unstack(0)["runtime"]["mean"]
    means.fillna(0, inplace=True)
    return csv_path, means


path_eigen, runtime_dist_data_EIGEN = _load_non_gpu_layer_means(
    "IMPORTED_EIGEN_runtime_distribution_evaluation.csv")

path_int32_eigen, runtime_dist_data_INT32_EIGEN = _load_non_gpu_layer_means(
    "IMPORTED_INT32_EIGEN_runtime_distribution_evaluation.csv")

path_int32_noeigen, runtime_dist_data_INT32_NOEIGEN = _load_non_gpu_layer_means(
    "IMPORTED_INT32_NOEIGEN_runtime_distribution_evaluation.csv")
runtime_dist_data_INT32_NOEIGEN

In [None]:
# The two models that are plotted (labelled "f" and "F" in the figures).
model_names = ["MODEL_F_GNNP_POOL_REPL", "MODEL_F_MINIONN_POOL_REPL"]

# runtime_dist_data is indexed by (layer, optimize_bases). Grouping on the
# first level and summing collapses optimize_bases, leaving one value per
# layer and model.
# NOTE(review): this adds the runtimes of both optimize_bases settings together - confirm intended.
redash_selected = runtime_dist_data[model_names]
df_redash = redash_selected.groupby(level=0).sum()
# runtime_dist_data_DASH is already indexed by layer only.
df_dash = runtime_dist_data_DASH[model_names]

# Layer order of each data set, exactly as it appears in its index.
layer_order_redash = df_redash.index.tolist()
layer_order_dash = df_dash.index.tolist()

# Share of each layer in a model's total, in percent.
def compute_percentages(df, models, layers):
    """Return ``{layer: [percentage for each model in models]}``.

    The total of a model is the sum of its whole column in ``df`` (not only
    of the requested ``layers``). A model whose total is 0 gets 0 for every
    layer.
    """
    shares = {}
    for layer in layers:
        shares[layer] = []
    for model in models:
        column_total = df[model].sum()
        for layer in layers:
            if column_total != 0:
                share = df.loc[layer, model] / column_total * 100
            else:
                share = 0
            shares[layer].append(share)
    return shares

# Per-layer runtime shares in percent, one list entry per model in `model_names`.
redash_percent = compute_percentages(df_redash, model_names, layer_order_redash)
dash_percent = compute_percentages(df_dash, model_names, layer_order_dash)

# Colour/hatch table; also used by the stacked runtime plots in later cells.
# The last colour of the cycle is deliberately left out by the modulo.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
hatches = ['', '//', '\\\\', '||', '--', '++', 'xx', 'oo', 'OO', '..', '**']
color_hatch_combinations = [(colors[i % (len(colors)-1)], hatches[i % len(hatches)]) for i in range(max(len(layer_order_redash), len(layer_order_dash)))]

# Assign the styles by layer name rather than by position, so that a layer
# looks the same in both panels and in the legend even when the two data sets
# list their layers in a different order. ReDASH layers come first, followed
# by any layer that only occurs in the DASH data.
all_layers = list(dict.fromkeys(layer_order_redash + layer_order_dash))
style_by_layer = {layer: (colors[i % (len(colors)-1)], hatches[i % len(hatches)])
                  for i, layer in enumerate(all_layers)}

# Create subplots: left for ReDASH, right for DASH, sharing the model axis.
fig, axs = plt.subplots(1, 2, figsize=(12, 3), sharey=True)
fig.tight_layout(pad=3)

panels = (
    (axs[0], "ReDASH", layer_order_redash, redash_percent),
    (axs[1], "DASH", layer_order_dash, dash_percent),
)
for ax_panel, title, layer_order, percent in panels:
    # Running left edge of the next segment, one entry per model.
    left = [0] * len(model_names)
    for layer in layer_order:
        c, h = style_by_layer[layer]
        ax_panel.barh(model_names, percent[layer], left=left, label=layer, color=c, hatch=h)
        left = [offset + share for offset, share in zip(left, percent[layer])]
    ax_panel.set_title(title)
    # Raw string: the backslash escapes the percent sign for LaTeX rendering.
    ax_panel.set_xlabel(r"Runtime share (\%)")
axs[0].set_ylabel("Model")

# One figure-level legend covering every layer that is drawn in either panel.
all_handles = [plt.matplotlib.patches.Patch(facecolor=style_by_layer[layer][0],
                                            hatch=style_by_layer[layer][1],
                                            label=layer)
               for layer in all_layers]
fig.legend(handles=all_handles, loc='upper center', ncol=len(all_layers), bbox_to_anchor=(0.5, 1.1))

fig.savefig("model_benchmarks/data/runtime_distribution_Fmodels.pdf", format='pdf', bbox_inches='tight')
plt.show()


In [None]:
# Runtime of each layer for each model, taken as-is from `df`.
def compute_totals(df, models, layers):
    """Return ``{layer: [df.loc[layer, model] for each model in models]}``.

    A layer that is not part of ``df.index`` contributes 0 for every model.
    """
    totals = {}
    for layer in layers:
        totals[layer] = []
    for model in models:
        for layer in layers:
            if layer in df.index:
                totals[layer].append(df.loc[layer, model])
            else:
                totals[layer].append(0)
    return totals

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


# Stacked per-layer runtime bars for three setups per model:
# ReDASH with optimize_bases == 1, ReDASH with optimize_bases == 0, and DASH.
df_redash_optim = runtime_dist_data[model_names].xs(1, level="optimize_bases")
df_redash_cpm   = runtime_dist_data[model_names].xs(0, level="optimize_bases")
df_dash         = runtime_dist_data_DASH[model_names]  # indexed by layer only

# Layers to plot: the sorted union of the layers of both ReDASH variants.
layer_order_redash = sorted(set(df_redash_optim.index) | set(df_redash_cpm.index))
layer_order_dash   = list(df_dash.index)
layers = layer_order_redash

# Absolute runtimes per layer, one list entry per model in `model_names`.
redash_optim_total = compute_totals(df_redash_optim, model_names, layers)
redash_cpm_total   = compute_totals(df_redash_cpm,   model_names, layers)
dash_total         = compute_totals(df_dash,         model_names, layers)

# print(redash_optim_total)
# print(redash_cpm_total)
# print(dash_total)

# Update the cpm and dash totals for model f manually with values from ReDASH paper (manual fix for ArXiV version)
# redash_cpm_total['approx_relu'][0] = 529.0
# redash_cpm_total['conv2d'][0] = 5135.0
# redash_cpm_total['dense'][0] = 0.5
# redash_cpm_total['rescale'][0] = 437.0
# dash_total['approx_relu'][0] = 538.0
# dash_total['conv2d'][0] = 6771.0
# dash_total['dense'][0] = 0.6
# dash_total['rescale'][0] = 2841.0

# Update the optim, cpm and dash totals for model F manually with values from ReDASH paper (manual fix for ArXiV version)
# redash_optim_total['approx_relu'][1] = 433.0
# redash_optim_total['conv2d'][1] = 2742.0
# redash_optim_total['dense'][1] = 0.2
# redash_optim_total['rescale'][1] = 455.0
# redash_cpm_total['approx_relu'][1] = 982.0
# redash_cpm_total['conv2d'][1] = 13393.0
# redash_cpm_total['dense'][1] = 0.6
# redash_cpm_total['rescale'][1] = 763.0
# dash_total['approx_relu'][1] = 981.0
# dash_total['conv2d'][1] = 17760.0
# dash_total['dense'][1] = 0.6
# dash_total['rescale'][1] = 5274.0

for totals in (redash_optim_total, redash_cpm_total, dash_total):
    print(totals)

# The setups that are compared, in plotting order, and their data.
datasets = ["ReDASH Optimized", "ReDASH CPM", "DASH"]
totals_by_dataset = {
    "ReDASH Optimized": redash_optim_total,
    "ReDASH CPM": redash_cpm_total,
    "DASH": dash_total,
}
n_datasets = len(datasets)
n_models = len(model_names)
bar_width = 0.25

# x position of the centre of each model's bar group.
indices = np.arange(n_models)

# Style per layer, taken from the colour/hatch table defined in an earlier cell.
layer_style = {}
for i, layer in enumerate(layers):
    layer_style[layer] = color_hatch_combinations[i % len(color_hatch_combinations)]

fig, ax = plt.subplots(figsize=(8, 4))

# One stacked bar per (setup, model); the bars of a model are centred on its index.
for i, dataset in enumerate(datasets):
    pos = indices + (i - (n_datasets - 1) / 2) * bar_width
    totals = totals_by_dataset[dataset]
    for j in range(n_models):
        ser = pd.Series({layer: totals.get(layer, [0]*n_models)[j] for layer in layers})
        bottom = 0
        for layer, val in ser.items():
            c, h = layer_style[layer]
            ax.bar(pos[j], val, bar_width, bottom=bottom, color=c, hatch=h, edgecolor='black')
            bottom += val

ax.set_xticks(indices)
# Short display names of the models.
model_labels = {"MODEL_F_GNNP_POOL_REPL": "f", "MODEL_F_MINIONN_POOL_REPL": "F"}
ax.set_ylabel("Online runtime (ms)")

# Bottom axis: one tick per bar, labelled with the setup number (I, II, III).
roman = ["I", "II", "III"]
tick_positions = [x + (i - (n_datasets - 1)/2)*bar_width
                  for x in indices for i in range(n_datasets)]
tick_labels = [roman[i] for _ in indices for i in range(n_datasets)]
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_xlabel("Setup")

# Top axis: a twin x-axis with the same limits that carries the model names.
ax_top = ax.twiny()
ax_top.set_xlim(ax.get_xlim())
ax_top.set_xticks(indices)
ax_top.set_xticklabels([model_labels[m] for m in model_names])
ax_top.xaxis.set_ticks_position('top')
ax_top.xaxis.set_label_position('top')
ax_top.set_xlabel("Model")
ax_top.set_zorder(-1)

# Display names of the layers in the legend.
layer_label = {
    'approx_relu': 'ReLU',
    'conv2d': 'Conv2d',
    'dense': 'Dense',
    'rescale': 'Scaling'
}

# Legend with one patch per layer.
legend_handles = []
for layer in layers:
    face, hatch = layer_style[layer]
    legend_handles.append(mpatches.Patch(facecolor=face, hatch=hatch, edgecolor='black',
                                         label=layer_label.get(layer, layer)))
ax.legend(handles=legend_handles, title="Layer", bbox_to_anchor=(1, 1))

fig.tight_layout()
ax.grid(False)

fig.savefig("model_benchmarks/data/vertical_combined_runtime_distribution.pdf", format='pdf', bbox_inches='tight')
plt.show()


In [None]:
# Stacked per-layer runtime bars for three ReDASH variants per model:
# the optimize_bases == 1 results, the EIGEN results and the INT32_EIGEN results.
models = ["MODEL_F_GNNP_POOL_REPL", "MODEL_F_MINIONN_POOL_REPL"]
dataset_labels = ["I", "I*", "I**"]

# Layers to plot: the sorted layers of the ReDASH optimize_bases == 1 results.
layers = sorted(df_redash_optim.index)  # e.g., ['approx_relu', 'conv2d', 'dense', 'rescale']

# Absolute runtimes per layer, one list entry per model in `model_names`.
redash_optim_total = compute_totals(df_redash_optim, model_names, layers)
redash_eigen_total   = compute_totals(runtime_dist_data_EIGEN, model_names, layers)
redash_int32_eigen_total         = compute_totals(runtime_dist_data_INT32_EIGEN, model_names, layers)

# Update the optim totals for model F manually with values from ReDASH paper (manual fix for ArXiV version)
# redash_optim_total['approx_relu'][1] = 433.0
# redash_optim_total['conv2d'][1] = 2742.0
# redash_optim_total['dense'][1] = 0.2
# redash_optim_total['rescale'][1] = 455.0

for totals in (redash_optim_total, redash_eigen_total, redash_int32_eigen_total):
    print(totals)

# The setups that are compared, in plotting order, and their data.
datasets = ["ReDASH Optimized", "ReDASH Optimized Eigen", "ReDASH Optimized Eigen Int32"]
totals_by_dataset = {
    "ReDASH Optimized": redash_optim_total,
    "ReDASH Optimized Eigen": redash_eigen_total,
    "ReDASH Optimized Eigen Int32": redash_int32_eigen_total,
}
n_datasets = len(datasets)
n_models = len(model_names)
bar_width = 0.25

# x position of the centre of each model's bar group.
indices = np.arange(n_models)

# Style per layer, taken from the colour/hatch table defined in an earlier cell.
layer_style = {}
for i, layer in enumerate(layers):
    layer_style[layer] = color_hatch_combinations[i % len(color_hatch_combinations)]

fig, ax = plt.subplots(figsize=(8, 4))

# One stacked bar per (setup, model); the bars of a model are centred on its index.
for i, dataset in enumerate(datasets):
    pos = indices + (i - (n_datasets - 1) / 2) * bar_width
    totals = totals_by_dataset[dataset]
    for j in range(n_models):
        ser = pd.Series({layer: totals.get(layer, [0]*n_models)[j] for layer in layers})
        bottom = 0
        for layer, val in ser.items():
            c, h = layer_style[layer]
            ax.bar(pos[j], val, bar_width, bottom=bottom, color=c, hatch=h, edgecolor='black')
            bottom += val

ax.set_xticks(indices)
# Short display names of the models.
model_labels = {"MODEL_F_GNNP_POOL_REPL": "f", "MODEL_F_MINIONN_POOL_REPL": "F"}
ax.set_ylabel("Online runtime (ms)")

# Bottom axis: one tick per bar, labelled with the setup name (I, I*, I**).
roman = list(dataset_labels)
tick_positions = [x + (i - (n_datasets - 1)/2)*bar_width
                  for x in indices for i in range(n_datasets)]
tick_labels = [roman[i] for _ in indices for i in range(n_datasets)]
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_xlabel("Setup")

# Top axis: a twin x-axis with the same limits that carries the model names.
ax_top = ax.twiny()
ax_top.set_xlim(ax.get_xlim())
ax_top.set_xticks(indices)
ax_top.set_xticklabels([model_labels[m] for m in model_names])
ax_top.xaxis.set_ticks_position('top')
ax_top.xaxis.set_label_position('top')
ax_top.set_xlabel("Model")
ax_top.set_zorder(-1)

# Display names of the layers in the legend.
layer_label = {
    'approx_relu': 'ReLU',
    'conv2d': 'Conv2d',
    'dense': 'Dense',
    'rescale': 'Scaling'
}

# Legend with one patch per layer.
legend_handles = []
for layer in layers:
    face, hatch = layer_style[layer]
    legend_handles.append(mpatches.Patch(facecolor=face, hatch=hatch, edgecolor='black',
                                         label=layer_label.get(layer, layer)))
ax.legend(handles=legend_handles, title="Layer", bbox_to_anchor=(1, 1))

fig.tight_layout()
ax.grid(False)

fig.savefig("model_benchmarks/data/vertical_combined_runtime_distribution_EigenInt32.pdf", format='pdf', bbox_inches='tight')
plt.show()

## Ciphertext Comparison

In [None]:
# Moduli of the CRT base for each base size k: the first k primes.
cpm_dict = {
    3: [2,3,5],
    4: [2,3,5,7],
    5: [2,3,5,7,11],
    6: [2,3,5,7,11,13],
    7: [2,3,5,7,11,13,17],
    8: [2,3,5,7,11,13,17,19]
}

# Second base used by the DASH count, per base size k.
# NOTE(review): presumably the mixed-radix base - confirm against the paper.
mrs_dict = {
    3: [32],
    4: [26,3],
    5: [54,4,3],
    6: [60,125],
    7: [86,7,36,5],
    8: [92,7,6,125,4]
}

def redash_ciphertexts(k):
    """Return the ReDASH ciphertext count for a CRT base of size *k*.

    The base ``cpm_dict[k]`` is taken with its first and last modulus
    swapped; the result is the sum over ``i = 1 .. k-1`` of the moduli at
    positions ``i .. k-1``.

    The swap is done on a copy. Swapping the list stored in ``cpm_dict``
    itself would undo the swap on every second call and make the result
    depend on how often the function has been called.

    Raises ``KeyError`` if *k* is not a key of ``cpm_dict``.
    """
    crt_base = list(cpm_dict[k])
    crt_base[0], crt_base[-1] = crt_base[-1], crt_base[0]
    return sum(sum(crt_base[i:k]) for i in range(1, k))

def dash_cipher_texts(k):
    """Return the DASH ciphertext count for a CRT base of size *k*.

    With ``t = len(mrs_dict[k])`` the result is the sum of
      * ``t`` times the sum of the moduli in ``cpm_dict[k]``,
      * ``2 * k`` times the sum of ``m + 2 * t * (k - 1)`` over every
        entry ``m`` of ``mrs_dict[k]`` except the first,
      * the first entry of ``mrs_dict[k]``, and
      * ``2 * (k - 1)`` for the base change projection.

    Raises ``KeyError`` if *k* is not a key of both tables.
    """
    crt_base = cpm_dict[k]
    mrs_base = mrs_dict[k]
    t = len(mrs_base)

    crt_term = t * sum(crt_base[i] for i in range(k))
    mrs_tail_term = 2 * k * sum(m + 2 * t * (k - 1) for m in mrs_base[1:])
    mrs_head_term = mrs_base[0]
    base_change_projection_ciphertexts = 2 * (k - 1)

    return crt_term + mrs_tail_term + mrs_head_term + base_change_projection_ciphertexts

In [None]:
# Define k values corresponding to the data points
k_values = list(range(3, 9))

dash_data = [dash_cipher_texts(k) for k in k_values]
redash_data = [redash_ciphertexts(k) for k in k_values]

plt.figure(figsize=(5, 4))
plt.plot(k_values, redash_data, marker='o', linestyle='-', label="ReDASH")
plt.plot(k_values, dash_data, marker='*', linestyle='--', label="DASH")

plt.xlabel("RNS Size ($k$)")
plt.ylabel("Ciphertexts per Input")
plt.yscale('linear')  # set y-axis to a normal (linear) scale
plt.legend()
plt.grid(False)
plt.tight_layout()

plt.savefig("micro_benchmarks/data/ciphertexts_per_input.pdf", format="pdf", bbox_inches="tight")
plt.show()