In [None]:
# Imports and global plotting configuration shared by every figure cell below.
import os, sys
import matplotlib
# Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3 fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
import matplotlib.pyplot as plt
matplotlib.rcParams['figure.dpi'] = 300
# Render mathtext in the regular (upright) font and never call an external LaTeX install.
matplotlib.rcParams.update({'mathtext.default': 'regular'})
plt.rcParams['text.usetex'] = False
# NOTE(review): sys, FormatStrFormatter, patches and importlib are not referenced in the
# cells visible here — confirm they are used further down before removing them.
from matplotlib.ticker import FormatStrFormatter
import matplotlib.patches as patches

# NOTE(review): star import — presumably provides the helpers used below
# (find_txt_filenames_new, save_csv_UPMEM, save_csv_UPMEM_benchmark,
# process_benchmark_DPU_csv, process_DPU_csv, process_baseline_csv); confirm in plots_utils.
from plots_utils import *

import seaborn as sns
import numpy as np
import pandas as pd

import importlib

In [None]:
# Root directory of the artifact; the UPMEM result directories read below live under it.
source_path = "<please enter your source path>"
# Output directory for paper artefacts: <source_path>/paper_plots.
dest_path = f"{source_path}/paper_plots"

# Store the results as .csv files

## UPMEM

In [None]:
# Gather the per-run UPMEM result logs of each dataset into one CSV per dataset.
# NOTE(review): Criteo is read from `postprocessed_results_UPMEM` while YFCC100M-HNfc6
# is read from `results/results_UPMEM` — confirm this asymmetry is intentional.
upmem_result_sources = [
    ('PIM_Criteo.csv', '/postprocessed_results_UPMEM/Criteo'),
    ('PIM_YFCC100M-HNfc6.csv', '/results/results_UPMEM/YFCC100M-HNfc6'),
]
for save_path, results_subdir in upmem_result_sources:
    filenames_list, file_paths_list = find_txt_filenames_new(source_path + results_subdir)
    save_csv_UPMEM(save_path, filenames_list, file_paths_list)

In [None]:
# Gather the per-run UPMEM benchmark logs (per-phase timings) of each dataset into one CSV.
benchmark_sources = [
    ('benchmark_PIM_Criteo.csv', '/results/benchmark_UPMEM/Criteo'),
    ('benchmark_PIM_YFCC100M-HNfc6.csv', '/results/benchmark_UPMEM/YFCC100M-HNfc6'),
]
for save_path, benchmark_subdir in benchmark_sources:
    filenames_list, file_paths_list = find_txt_filenames_new(source_path + benchmark_subdir)
    save_csv_UPMEM_benchmark(save_path, filenames_list, file_paths_list)

## Baselines

In [None]:
# Merge the LR and SVM baseline results for YFCC100M-HNfc6 into a single CSV.
# Paths are derived from the configured `source_path` / `dest_path` rather than a
# hard-coded absolute path, so the notebook runs on any machine.
baseline_dir = os.path.join(source_path, 'baseline')
df_baseline_yfcc_lr = pd.read_csv(os.path.join(baseline_dir, 'baseline_yfcc_lr.csv'))
df_baseline_yfcc_svm = pd.read_csv(os.path.join(baseline_dir, 'baseline_yfcc_svm.csv'))
# The baseline CSVs label some runs 'ma1'; relabel them 'ma'
# (presumably the same scheme as 'ma' elsewhere — confirm).
for baseline_frame in (df_baseline_yfcc_lr, df_baseline_yfcc_svm):
    baseline_frame.loc[baseline_frame['dist_type'] == 'ma1', 'dist_type'] = 'ma'
# ignore_index=True already yields a fresh RangeIndex, so no reset_index() is needed.
df_baseline_yfcc = pd.concat([df_baseline_yfcc_lr, df_baseline_yfcc_svm], ignore_index=True)
df_baseline_yfcc.to_csv(os.path.join(dest_path, 'baseline_yfcc.csv'), index=False)

In [None]:
# Merge the LR and SVM baseline results for Criteo into a single CSV.
# Paths are derived from the configured `source_path` / `dest_path` rather than a
# hard-coded absolute path, so the notebook runs on any machine.
baseline_dir = os.path.join(source_path, 'baseline')
df_baseline_criteo_lr = pd.read_csv(os.path.join(baseline_dir, 'baseline_criteo_lr.csv'))
df_baseline_criteo_svm = pd.read_csv(os.path.join(baseline_dir, 'baseline_criteo_svm.csv'))
# The baseline CSVs label some runs 'ma1'; relabel them 'ma'
# (presumably the same scheme as 'ma' elsewhere — confirm).
for baseline_frame in (df_baseline_criteo_lr, df_baseline_criteo_svm):
    baseline_frame.loc[baseline_frame['dist_type'] == 'ma1', 'dist_type'] = 'ma'
# ignore_index=True already yields a fresh RangeIndex, so no reset_index() is needed.
df_baseline_criteo = pd.concat([df_baseline_criteo_lr, df_baseline_criteo_svm], ignore_index=True)
df_baseline_criteo.to_csv(os.path.join(dest_path, 'baseline_criteo.csv'), index=False)

# §II. Background & Motivation

## §II-C. Motivation

#### Figure 2: Per global epoch comparison of measured throughput (a) and total data movement (b) for all studied algorithms (MA-SGD, GA-SGD, and ADMM) on the UPMEM PIM system for the Criteo dataset.

In [None]:
# Load the UPMEM benchmark results for Criteo (CSV written by save_csv_UPMEM_benchmark above).
df_dpu = process_benchmark_DPU_csv("benchmark_PIM_Criteo.csv")

In [None]:
# Figure 2: measured throughput (a) and total data movement (b) per global epoch for
# MA-SGD, GA-SGD and ADMM — LR on Criteo, 2048 DPUs, weak scaling, 1 local epoch.
os.makedirs("./output", exist_ok=True)  # savefig below does not create missing directories

final_df = df_dpu.copy()
final_df = final_df[(final_df['scaling_type'] == 'weak')]
final_df = final_df[(final_df['nr_procs'] == 2048)]
final_df = final_df[(final_df['num_local_epochs'] == 1)]
final_df = final_df[(final_df['batch_size'] == 2048) | (final_df['batch_size'] == 262144)]
# .copy() so the relabelling below writes to an independent frame (no SettingWithCopyWarning).
final_df = final_df.copy()
final_df.loc[final_df['model_type'] == 'lr', 'model_type'] = 'LR'
final_df = final_df[final_df['model_type'] == 'LR']

# Long format: one row per (run, metric) so that each metric becomes an x category.
columns_to_melt_bandwidth = ['CPU_and_DPU_band', 'M_and_W_band']
columns_to_melt_data = ['CPU_and_DPU_data', 'M_and_W_data']
id_vars_bandwidth = [col for col in final_df.columns if col not in columns_to_melt_bandwidth]
id_vars_data = [col for col in final_df.columns if col not in columns_to_melt_data]

df_melted_bandwidth = final_df.melt(id_vars=id_vars_bandwidth, value_vars=columns_to_melt_bandwidth,
                                    var_name='TimeType', value_name='Value')
df_melted_data = final_df.melt(id_vars=id_vars_data, value_vars=columns_to_melt_data,
                               var_name='TimeType', value_name='Value')

hue_order = ['MA-SGD', 'GA-SGD', 'ADMM']
# Display names of the two x categories, in the same order as the melted columns.
new_x_labels = ['Comm. with\nParameter Server', 'PIM']
# (melted data, x category order, y-axis label, panel tag) for each subplot.
panels = [
    (df_melted_bandwidth, columns_to_melt_bandwidth, "Measured\nThroughput (GB/s)", '(a)'),
    (df_melted_data, columns_to_melt_data, "Total Data\nMovement (GB)", '(b)'),
]

fig, axs = plt.subplots(1, 2, figsize=(14, 4.05))

for ax, (melted, x_order, _, _) in zip(axs, panels):
    # Explicit `order` pins each metric to the tick label assigned below.
    sns.barplot(x='TimeType', y='Value', hue='algorithm', hue_order=hue_order, order=x_order,
                data=melted, edgecolor='black', linewidth=1, ax=ax, width=0.6)
    ax.legend_.remove()

# One shared figure legend; keep the first handle seen for each algorithm.
handle_by_label = {}
for ax in axs:
    for handle, label in zip(*ax.get_legend_handles_labels()):
        handle_by_label.setdefault(label, handle)
ordered_handles = [handle_by_label[alg] for alg in hue_order]

leg = fig.legend(ordered_handles, hue_order, loc='upper center', bbox_to_anchor=(0.54, 1),
                 ncol=3, fontsize=17, columnspacing=0.8, frameon=True, edgecolor="black")

for ax, (_, _, ylabel, tag) in zip(axs, panels):
    ax.grid(axis='y', linestyle='--')
    ax.set_axisbelow(True)
    ax.tick_params(axis="y", direction="in", which='both', labelsize=20)
    ax.tick_params(axis="x", length=0, width=0, labelsize=15)

    ax.set_ylim(1, 150000)
    ax.set_yscale('log')

    ax.set_ylabel(ylabel, fontsize=18, weight="bold")
    ax.set_xlabel("")
    ax.set_xticklabels(new_x_labels, fontsize=18)

    for spine in ['top', 'bottom', 'left', 'right']:
        ax.spines[spine].set_linewidth(1.3)

    ax.text(0.5, -0.3, tag, transform=ax.transAxes, fontsize=18, va='top', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.80, wspace=0.23)
plt.savefig("./output/Fig_2.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_2.png", bbox_inches='tight', pad_inches=0.01)

# §V. Evaluation

## §V-A. Evaluation of YFCC100M-HNfc6

### PIM Performance Breakdown.

#### Figure 4: Per global epoch training time breakdown into Comm./Sync. Para. Server, PIM Comp., and PIM data movement time for LR (a) and SVM (b).

In [None]:
# Load the UPMEM benchmark results for YFCC100M-HNfc6 (CSV written by save_csv_UPMEM_benchmark above).
df_dpu = process_benchmark_DPU_csv("benchmark_PIM_YFCC100M-HNfc6.csv")

In [None]:
# Figure 4: per-global-epoch time breakdown (Comm./Sync. with the parameter server,
# PIM compute, PIM data movement, total) for LR (a) and SVM (b) on YFCC100M-HNfc6
# — 2048 DPUs, weak scaling, 1 local epoch, batch sizes 8 and 4096.
os.makedirs("./output", exist_ok=True)  # savefig below does not create missing directories

final_df = df_dpu.copy()
final_df = final_df[(final_df['scaling_type'] == 'weak')]
final_df = final_df[(final_df['nr_procs'] == 2048)]
final_df = final_df[(final_df['num_local_epochs'] == 1)]
final_df = final_df[(final_df['batch_size'] == 8) | (final_df['batch_size'] == 4096)]
# .copy() so the relabelling below writes to an independent frame (no SettingWithCopyWarning).
final_df = final_df.copy()
final_df.loc[final_df['model_type'] == 'lr', 'model_type'] = 'LR'
final_df.loc[final_df['model_type'] == 'svm', 'model_type'] = 'SVM'

columns_to_melt = ['total_communication_time', 'DPU_compute_time', 'M_and_W_time', 'total_time']
id_vars = [col for col in final_df.columns if col not in columns_to_melt]

df_melted = final_df.melt(id_vars=id_vars, value_vars=columns_to_melt,
                          var_name='TimeType', value_name='Value')

hue_order = ['MA-SGD', 'GA-SGD', 'ADMM']
model_types = ["LR", "SVM"]
# Display names for columns_to_melt, in the same order.
new_x_labels = ['Comm./Sync.\nPara. Server', 'PIM\nComp.', 'PIM Data\nMovement', 'Total']

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(22, 6.6))

for ax, model in zip(axs, model_types):
    # Explicit `order` pins each time component to the tick label assigned below.
    g = sns.barplot(x='TimeType', y='Value', hue='algorithm', hue_order=hue_order, order=columns_to_melt,
                    data=df_melted[df_melted["model_type"] == model], edgecolor='black', linewidth=1,
                    ax=ax, width=0.7)
    g.legend_.remove()

for i in range(2):
    axs[i].set_title(f"{model_types[i]}", fontsize=30, weight="bold")
    if i == 0:
        # One shared figure legend built from the LR panel's algorithm handles.
        handles, labels = axs[0].get_legend_handles_labels()
        handles, labels = handles[:3], labels[:3]
        l = fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.54, 1), ncol=3, fontsize=28, edgecolor="black", title=None, frameon=True)
    axs[i].set_yscale('log')

    axs[i].grid(axis='y', linestyle='--')
    axs[i].set_axisbelow(True)
    axs[i].tick_params(axis="y", direction="in", which='both', labelsize=30)
    axs[i].tick_params(axis="x", length=0, width=0, labelsize=6)

    if i == 0:
        axs[i].set_ylabel("Per Global Epoch\nTraining Time (s)", fontsize=34, weight="bold")

    axs[i].set_xlabel("")
    axs[i].set_ylim(0.01, 100)
    axs[i].set_xticklabels(new_x_labels, fontsize=32)

    if i == 1:
        # Right panel uses the same y range as the left one; hide its y labels.
        axs[i].set_ylabel("")
        axs[i].set_yticklabels([])

    for axis in ['top', 'bottom', 'left', 'right']:
        axs[i].spines[axis].set_linewidth(1.3)

axs[0].text(0.5, -0.5, '(a)', transform=axs[0].transAxes, fontsize=30, va='top', ha='center')
axs[1].text(0.5, -0.5, '(b)', transform=axs[1].transAxes, fontsize=30, va='top', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.77)
plt.subplots_adjust(wspace=0.03, hspace=0.05)
plt.savefig("./output/Fig_4.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_4.png", bbox_inches='tight', pad_inches=0.01)

### Algorithm Selection.

#### Figure 5: Comparison of various models (LR (a) and SVM (b)), algorithms (MA-SGD, GA-SGD, ADMM, and mini-batch SGD), and architectures (PIM, CPU, and GPU). We study the test accuracy (at the last global epoch) and total training time (10 global epochs).

In [None]:
# Combine the PIM (UPMEM) results with the baseline results for YFCC100M-HNfc6.
df_dpu = process_DPU_csv("PIM_YFCC100M-HNfc6.csv")
df_baseline_1 = process_baseline_csv("baseline_yfcc.csv")
frames_to_combine = [df_dpu, df_baseline_1]
df = pd.concat(
    frames_to_combine,
    ignore_index=True,
)

In [None]:
# Figure 5: test accuracy vs. total training time on YFCC100M-HNfc6.
# Columns: LR (a) and SVM (b); rows: PIM (DPU), CPU, GPU.
os.makedirs("./output", exist_ok=True)  # savefig below does not create missing directories

final_df = df.copy()
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(6, 4.3))

hue_order = ['MA-SGD', 'GA-SGD', 'ADMM', 'SGD']
style_order = ['DPU', 'CPU', 'GPU']
marker_dict = {'MA-SGD': 'o', 'GA-SGD': 's', 'ADMM': '^', 'SGD': 'd'}


for i, model in enumerate(["lr", "svm"]):
  for j, arch in enumerate(style_order):
    for dataset in ["yfcc"]:
      # Conditions shared by every selected run: this model/dataset, 1 local epoch, weak scaling.
      common = (
        (final_df["model_type"] == model) &
        (final_df["dataset"] == dataset) &
        (final_df["num_local_epochs"] == 1) &
        (final_df['scaling_type'] == 'weak')
      )
      # MA-SGD and ADMM with batch size 8 on 128 or 2048 processes ...
      small_batch = (
        final_df["algorithm"].isin(['ADMM', 'MA-SGD']) &
        final_df["nr_procs"].isin([2048, 128]) &
        (final_df['batch_size'] == 8)
      )
      # ... plus any algorithm with batch size 4096 on 1, 128 or 2048 processes.
      large_batch = (
        final_df["nr_procs"].isin([2048, 128, 1]) &
        (final_df['batch_size'] == 4096)
      )
      # .copy(): the relabelling below must modify an independent frame, not a slice of final_df.
      data = final_df.loc[common & (small_batch | large_batch)].copy()

      # GPU rows are shown as mini-batch SGD and their elapsed time is divided by 1000
      # (presumably ms -> s; confirm against the baseline CSV). The mask is built from
      # `data` itself so it is aligned with the rows being modified.
      is_gpu = data['architecture'] == 'GPU'
      data.loc[is_gpu, 'algorithm'] = 'SGD'
      data.loc[is_gpu, 'total_elapsed_time'] /= 1000
      g1 = sns.lineplot(data[data["architecture"] == arch], x="total_elapsed_time", y="test_accuracy", hue="algorithm", hue_order=hue_order, style="algorithm", style_order=hue_order, markers=marker_dict, dashes=False, errorbar=None, ax=axs[j, i], palette='deep', markersize=3.65, linewidth=1, markeredgewidth=0)

      axs[0, i].set_title(f"{model.upper()}", fontsize=10, weight="bold")


# One shared figure legend; keep the first handle seen for each algorithm across all panels.
handle_by_label = {}
for ax in axs.reshape(-1):
    for handle, label in zip(*ax.get_legend_handles_labels()):
        handle_by_label.setdefault(label, handle)
ordered_handles = [handle_by_label[alg] for alg in hue_order]

for ax in axs.reshape(-1):
    ax.legend().remove()

order_display = ['MA-SGD', 'GA-SGD', 'ADMM', 'mini-batch SGD']
fig.legend(ordered_handles, order_display, loc='upper center', bbox_to_anchor=(0.56, 1), ncol=4, fontsize=8, frameon=True, edgecolor="black", markerscale=1)


for ax in axs.reshape(-1):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(linestyle='--')
  ax.set_axisbelow(True)

  ax.tick_params(axis="y", direction="in", labelsize=7)
  ax.tick_params(axis="x", direction="in", labelsize=7)

  for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_linewidth(1)


# Per-row x range (upper limit, tick positions); rows follow style_order.
x_axis_by_row = [
  (320, list(range(0, 301, 50))),
  (1300, list(range(0, 1201, 200))),
  (750, list(range(0, 800, 100))),
]
for i in range(3):
  row_label = "PIM" if i == 0 else style_order[i]
  axs[i][0].set_ylabel(row_label, fontsize=10, weight="bold")

  x_max, x_ticks = x_axis_by_row[i]
  if i == 0:
    y_lim, y_ticks = (91.1, 97.5), list(range(92, 98, 1))
  else:
    y_lim, y_ticks = (92.25, 97.5), list(range(93, 98, 1))

  for j in range(2):
    axs[i][j].set_xlim(0, x_max)
    axs[i][j].set_xticks(x_ticks)
    axs[i][j].set_xticklabels(x_ticks, fontsize=8)

    axs[i][j].set_ylim(*y_lim)
    axs[i][j].set_yticks(y_ticks)
    axs[i][j].set_yticklabels(y_ticks, fontsize=8)


# In the CPU row, label the final point of the first drawn line in each column.
annotation_color = sns.color_palette()[0]
for col in range(2):
  line = axs[1, col].get_lines()[0]
  x_data = line.get_xdata()[-1]
  y_data = line.get_ydata()[-1]
  an = axs[1, col].annotate(f"Last epoch\nTest Accuracy: {y_data:.2f}%\nFinishes at {x_data:.0f}s",
          xy=(0, 0),
          xytext=(0.75, 0.35),
          textcoords='axes fraction',
          ha='center', va='bottom', fontsize=6, annotation_clip=False, weight="bold", bbox=dict(boxstyle="round,pad=0.05", facecolor="white", edgecolor="none", alpha=0.5))
  an.set_color(annotation_color)


fig.supxlabel(r"Total Training Time (s)", y=0.15, x=0.56, fontsize=12, weight="bold")

fig.supylabel(r"Test Accuracy (%)", y=0.572, x=0.03, fontsize=12, weight="bold")

axs[2,0].text(0.5, -0.45, '(a)', transform=axs[2,0].transAxes, fontsize=11, va='top', ha='center')
axs[2,1].text(0.5, -0.45, '(b)', transform=axs[2,1].transAxes, fontsize=11, va='top', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.89)
plt.subplots_adjust(wspace=0.145, hspace=0.2)
plt.savefig("./output/Fig_5.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_5.png", bbox_inches='tight', pad_inches=0.01)

### Batch Size.

#### Figure 6: Impact of batch size on total training time (10 global epochs) and test accuracy (at the last global epoch) for SVM MA-SGD (a), SVM GA-SGD (b), and LR ADMM (c).

In [None]:
# Combine the PIM (UPMEM) results with the baseline results for YFCC100M-HNfc6.
df_dpu = process_DPU_csv("PIM_YFCC100M-HNfc6.csv")
df_baseline_1 = process_baseline_csv("baseline_yfcc.csv")
frames_to_combine = [df_dpu, df_baseline_1]
df = pd.concat(
    frames_to_combine,
    ignore_index=True,
)

In [None]:
# Figure 6: impact of batch size on total training time (top row) and final test
# accuracy (bottom row) for SVM MA-SGD (a), SVM GA-SGD (b) and LR ADMM (c).
os.makedirs("./output", exist_ok=True)  # savefig below does not create missing directories

final_df = df.copy()
# PIM runs on 2048 DPUs only; non-DPU rows are kept regardless of process count.
final_df = final_df[(final_df['nr_procs'] == 2048) | (final_df['architecture'] != 'DPU')]
final_df = final_df[(final_df['batch_size'] != 1)]

deep_palette = sns.color_palette("deep")

green_color = deep_palette[2]
blue_color = deep_palette[0]
orange_color = deep_palette[1]

# Bar colour per process count (hue = nr_procs).
nr_procs_palette = {1: orange_color, 128: blue_color, 2048: green_color}
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(6, 4.6))

# (model, algorithm) shown in each column.
panels = [("svm", "MA-SGD"), ("svm", "GA-SGD"), ("lr", "ADMM")]
dataset = "yfcc"
for j, (model, algorithm) in enumerate(panels):
  # Last global epoch (index 9), 1 local epoch, weak scaling.
  data = final_df.loc[
    (final_df["model_type"] == model) &
    (final_df["algorithm"] == algorithm) &
    (final_df["dataset"] == dataset) &
    (final_df["num_local_epochs"] == 1) &
    (final_df['scaling_type'] == 'weak') &
    (final_df['g_epoch_id'] == 9)
  ]

  g1 = sns.barplot(data, x="batch_size", y="total_elapsed_time", hue="nr_procs", errorbar=None, ax=axs[0, j], palette=nr_procs_palette, edgecolor='black', linewidth=0.7, width=0.7)
  g2 = sns.barplot(data, x="batch_size", y="test_accuracy", hue="nr_procs", errorbar=None, ax=axs[1, j], palette=nr_procs_palette, edgecolor='black', linewidth=0.7, width=0.7)
  g1.legend_.remove()
  g2.legend_.remove()

  axs[0, j].set_title(f"{model.upper()} {algorithm}", fontsize=10, weight="bold")


for ax in axs.reshape(-1):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(axis='y', linestyle='--')
  ax.set_axisbelow(True)

  ax.tick_params(axis='y', length=0, width=0, labelsize=7)
  ax.tick_params(axis='x', length=0, width=0, labelsize=7)

  for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_linewidth(1.3)


# Shared legend: the first two hue handles (sorted nr_procs levels) are labelled CPU and PIM.
# NOTE(review): assumes the CPU run has fewer processes than the 2048-DPU PIM run — confirm.
handles, _ = axs[0][0].get_legend_handles_labels()
legend_labels = ["CPU", "PIM"]
if len(handles) < len(legend_labels):
  raise ValueError(f"expected at least {len(legend_labels)} hue handles, got {len(handles)}")
legend = fig.legend(handles[:len(legend_labels)], legend_labels, loc='upper center', bbox_to_anchor=(0.57, 1), ncol=2, fontsize=8, frameon=True, edgecolor="black")

# Training-time row y axis per column: (upper limit, tick stop, tick step).
time_axis = {0: (1100, 1001, 200), 1: (370, 301, 100), 2: (155, 151, 50)}

for j in range(3):
  y_top, tick_stop, tick_step = time_axis[j]
  time_ticks = np.arange(0, tick_stop, tick_step).tolist()
  axs[0, j].set_ylim(0, y_top)
  axs[0, j].set_yticks(time_ticks)
  axs[0, j].set_yticklabels(time_ticks, fontsize=10)

  acc_ticks = np.arange(92, 97, 1).tolist()
  axs[1, j].set_ylim(91.75, 96.7)
  axs[1, j].set_yticks(acc_ticks)
  axs[1, j].set_yticklabels(acc_ticks, fontsize=10)

  # Write the value of very tall bars as rotated text (thresholds/positions are hand-tuned per panel).
  for p in axs[0, j].patches:
    height = p.get_height()
    if height > 350:
      if j == 0 or j == 2:
        if height > 1100:
          if height < 2000:
            an = axs[0, j].annotate('{:.2f}'.format(height),
                        xy=(0, 0),
                        xytext=(0.415, 0.76),
                        textcoords='axes fraction',
                        ha='center', va='center', rotation=270, annotation_clip=False, fontsize=9, weight="bold", bbox=dict(boxstyle="round,pad=0.001", facecolor="white", edgecolor="none", alpha=0.5))
            an.set_color(sns.color_palette("deep")[0])
          if height > 2000:
            an = axs[0, j].annotate('{:.2f}'.format(height),
                        xy=(0, 0),
                        xytext=(0.18, 0.76),
                        textcoords='axes fraction',
                        ha='center', va='center', rotation=270, annotation_clip=False, fontsize=9, weight="bold", bbox=dict(boxstyle="round,pad=0.001", facecolor="white", edgecolor="none", alpha=0.5))
            an.set_color(sns.color_palette("deep")[0])
      if j == 1:
        an = axs[0, j].annotate('{:.2f}'.format(height),
                  xy=(0, 0),
                  xytext=(0.18, 0.8125),
                  textcoords='axes fraction',
                  ha='center', va='center', rotation=270, annotation_clip=False, fontsize=8, weight="bold", bbox=dict(boxstyle="round,pad=0.01", facecolor="white", edgecolor="none", alpha=0.5))
        an.set_color(sns.color_palette("deep")[0])


# GA-SGD (middle column) uses large batches; the other columns use small ones.
for i in range(2):
    for j in range(3):
        if j != 1:
            axs[i, j].set_xticklabels([8, 16, 32, 64], fontsize=9)
        else:
            axs[i, j].set_xticklabels(["4K", "8K", "16K", "32K"], fontsize=9)


axs[0, 0].set_ylabel("Total Training\nTime (s)", fontsize=11, weight="bold")
axs[1, 0].set_ylabel("Test Accuracy (%)", fontsize=11, weight="bold")

fig.supxlabel(r"Batch Size", y=0.19, x=0.57, fontsize=11, weight="bold")

axs[1,0].text(0.5, -0.5, '(a)', transform=axs[1,0].transAxes, fontsize=11, va='bottom', ha='center')
axs[1,1].text(0.5, -0.5, '(b)', transform=axs[1,1].transAxes, fontsize=11, va='bottom', ha='center')
axs[1,2].text(0.5, -0.5, '(c)', transform=axs[1,2].transAxes, fontsize=11, va='bottom', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.875)
plt.subplots_adjust(wspace=0.24, hspace=0.15)
plt.savefig("./output/Fig_6.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_6.png", bbox_inches='tight', pad_inches=0.01)

### Weak Scaling.

#### Figure 7: Impact of weak scaling on total training time (10 global epochs) and test accuracy (at the last global epoch) for LR (a) and SVM (b).

In [None]:
# Load the PIM (UPMEM) results for YFCC100M-HNfc6; the weak-scaling figure uses no baselines.
df_dpu = process_DPU_csv("PIM_YFCC100M-HNfc6.csv")

In [None]:
# Figure 7: weak scaling on YFCC100M-HNfc6 — total training time (top row) and final
# test accuracy (bottom row) per algorithm and DPU count, for LR (a) and SVM (b).
os.makedirs("./output", exist_ok=True)  # savefig below does not create missing directories

final_df = df_dpu.copy()
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 4.9))
# GA-SGD: keep only 1 local epoch with batch size 8192 ...
final_df = final_df[(final_df['algorithm'] != 'GA-SGD') | ((final_df['num_local_epochs'] == 1) & (final_df['batch_size'] == 8192))]
# ... every other algorithm: keep only 1 local epoch with batch size 8.
final_df = final_df[(final_df['algorithm'] == 'GA-SGD') | ((final_df['num_local_epochs'] == 1) & (final_df['batch_size'] == 8))]
# Last global epoch (index 9) of weak-scaling runs only.
final_df = final_df[(final_df['g_epoch_id']==9)]
final_df = final_df[(final_df['scaling_type']=='weak')]

hue_order = [256, 512, 1024, 2048]
# One shade of blue per DPU count, in hue_order.
nr_dpus_palette = ["#4a90e2", "#0076d1", "#0066b5", "#044d85"]
category_order = ['MA-SGD', 'GA-SGD', 'ADMM']
for j, model in enumerate(["lr", "svm"]):
  for dataset in ["yfcc"]:
    data = final_df.loc[
      (final_df["model_type"] == model) &
      (final_df["dataset"] == dataset)
    ]

    g1 = sns.barplot(data, x="algorithm", y="total_elapsed_time", hue="nr_procs", hue_order=hue_order, order=category_order, errorbar=None, ax=axs[0, j], palette=nr_dpus_palette, edgecolor='black', linewidth=1, width=0.7)
    g2 = sns.barplot(data, x="algorithm", y="test_accuracy", hue="nr_procs", hue_order=hue_order, order=category_order, errorbar=None, ax=axs[1, j], palette=nr_dpus_palette, edgecolor='black', linewidth=1, width=0.7)
    if j != 0:
      g1.legend_.remove()
    g2.legend_.remove()

    axs[0, j].set_title(f"{model.upper()} (Weak Scaling)", fontsize=13, weight="bold")


# One shared figure legend; keep the first handle seen for each DPU count.
handle_by_label = {}
for ax in axs.reshape(-1):
    for handle, label in zip(*ax.get_legend_handles_labels()):
        handle_by_label.setdefault(label, handle)

order = [str(n) for n in hue_order]
ordered_handles = [handle_by_label[label] for label in order]

for ax in axs.reshape(-1):
    ax.legend().remove()

leg = fig.legend(ordered_handles, order, loc='upper center', bbox_to_anchor=(0.54, 1), ncol=4, fontsize=12, frameon=True, edgecolor="black")
leg.set_title("Nr. DPUs", prop={'size': 12, 'weight': 'bold'})

for i, ax in enumerate(axs.reshape(-1)):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(axis='y', linestyle='--')
  ax.set_axisbelow(True)

  # Tick label sizes are set once at the end of this loop body.
  ax.tick_params(axis="y", direction="in")
  ax.tick_params(axis='x', length=0, width=0)

  for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_linewidth(1.3)

  if i <= 1:  # top row: total training time
    ax.set_ylim(0, 225)
    ax.set_yticks(np.arange(0, 201, 100).tolist())
    ax.set_yticklabels(np.arange(0, 201, 100).tolist(), fontsize=11.5)

  if i > 1:  # bottom row: test accuracy
    ax.set_ylim(93, 97.5)
    ax.set_yticks(np.arange(93, 97.1, 1).tolist())
    ax.set_yticklabels(np.arange(93, 97.1, 1).tolist(), fontsize=11.5)

  if i == 1 or i == 3:  # right column uses the same y range; hide its y tick labels
    ax.set_yticklabels([])

  if i <= 1:  # top row: x tick labels only on the bottom row
    ax.set_xticklabels([])

  ax.tick_params(axis='x', labelsize=13.5)
  ax.tick_params(axis='y', labelsize=12)


axs[0, 0].set_ylabel("Total Training\nTime (s)", fontsize=12.5, weight="bold")
axs[1, 0].set_ylabel("Test\nAccuracy (%)", fontsize=12.5, weight="bold")

fig.supxlabel(r"Optimization Algorithm", y=0.205, x=0.55, fontsize=14, weight="bold")

axs[1,0].text(0.5, -0.6, '(a)', transform=axs[1,0].transAxes, fontsize=13, va='bottom', ha='center')
axs[1,1].text(0.5, -0.6, '(b)', transform=axs[1,1].transAxes, fontsize=13, va='bottom', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.815)
plt.subplots_adjust(wspace=0.045, hspace=0.17)
plt.savefig("./output/Fig_7.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_7.png", bbox_inches='tight', pad_inches=0.01)

### Strong Scaling.

#### Figure 8: Impact of strong scaling on total training time (10 global epochs) and test accuracy (at the last global epoch) for LR (a) and SVM (b).

In [None]:
# Load the PIM (UPMEM) results for YFCC100M-HNfc6; the strong-scaling figure uses no baselines.
df_dpu = process_DPU_csv("PIM_YFCC100M-HNfc6.csv")

In [None]:
# Figure 8: strong scaling on YFCC100M-HNfc6 — total training time (top row) and final
# test accuracy (bottom row) per algorithm and DPU count, for LR (a) and SVM (b).
os.makedirs("./output", exist_ok=True)  # savefig below does not create missing directories

final_df = df_dpu.copy()
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 4.9))
# GA-SGD: keep only 1 local epoch with batch size 8192 ...
final_df = final_df[(final_df['algorithm'] != 'GA-SGD') | ((final_df['num_local_epochs'] == 1) & (final_df['batch_size'] == 8192))]
# ... every other algorithm: keep only 1 local epoch with batch size 8.
final_df = final_df[(final_df['algorithm'] == 'GA-SGD') | ((final_df['num_local_epochs'] == 1) & (final_df['batch_size'] == 8))]
# Last global epoch (index 9) of strong-scaling runs only.
final_df = final_df[(final_df['g_epoch_id']==9)]
final_df = final_df[(final_df['scaling_type']=='strong')]

hue_order = [256, 512, 1024, 2048]
# One shade of blue per DPU count, in hue_order.
nr_dpus_palette = ["#4a90e2", "#0076d1", "#0066b5", "#044d85"]
category_order = ['MA-SGD', 'GA-SGD', 'ADMM']
for j, model in enumerate(["lr", "svm"]):
  for dataset in ["yfcc"]:
    data = final_df.loc[
      (final_df["model_type"] == model) &
      (final_df["dataset"] == dataset)
    ]

    g1 = sns.barplot(data, x="algorithm", y="total_elapsed_time", hue="nr_procs", hue_order=hue_order, order=category_order, errorbar=None, ax=axs[0, j], palette=nr_dpus_palette, edgecolor='black', linewidth=1, width=0.7)
    g2 = sns.barplot(data, x="algorithm", y="test_accuracy", hue="nr_procs", hue_order=hue_order, order=category_order, errorbar=None, ax=axs[1, j], palette=nr_dpus_palette, edgecolor='black', linewidth=1, width=0.7)
    if j != 0:
      g1.legend_.remove()
    g2.legend_.remove()

    axs[0, j].set_title(f"{model.upper()} (Strong Scaling)", fontsize=13, weight="bold")

# One shared figure legend; keep the first handle seen for each DPU count.
handle_by_label = {}
for ax in axs.reshape(-1):
    for handle, label in zip(*ax.get_legend_handles_labels()):
        handle_by_label.setdefault(label, handle)

order = [str(n) for n in hue_order]
ordered_handles = [handle_by_label[label] for label in order]

for ax in axs.reshape(-1):
    ax.legend().remove()

leg = fig.legend(ordered_handles, order, loc='upper center', bbox_to_anchor=(0.54, 1), ncol=4, fontsize=12, frameon=True, edgecolor="black")
leg.set_title("Nr. DPUs", prop={'size': 12, 'weight': 'bold'})

for i, ax in enumerate(axs.reshape(-1)):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(axis='y', linestyle='--')
  ax.set_axisbelow(True)

  # Tick label sizes are set once at the end of this loop body.
  ax.tick_params(axis="y", direction="in")
  ax.tick_params(axis='x', length=0, width=0)

  for axis in ['top', 'bottom', 'left', 'right']:
    ax.spines[axis].set_linewidth(1.3)

  if i <= 1:  # top row: total training time
    ax.set_ylim(0, 160)
    ax.set_yticks(np.arange(0, 161, 50).tolist())
    ax.set_yticklabels(np.arange(0, 161, 50).tolist(), fontsize=11.5)

  if i > 1:  # bottom row: test accuracy
    ax.set_ylim(91, 96.5)
    ax.set_yticks(np.arange(91, 96.1, 1).tolist())
    ax.set_yticklabels(np.arange(91, 96.1, 1).tolist(), fontsize=11.5)

  if i == 1 or i == 3:  # right column uses the same y range; hide its y tick labels
    ax.set_yticklabels([])

  if i <= 1:  # top row: x tick labels only on the bottom row
    ax.set_xticklabels([])

  ax.tick_params(axis='x', labelsize=13.5)
  ax.tick_params(axis='y', labelsize=11)


# LR training-time panel: bars taller than 320 (far beyond the 0-160 y range) get their
# value printed at hand-tuned positions (first patch on the left, all others further right).
for i, p in enumerate(axs[0, 0].patches):
  height = p.get_height()
  if height > 320:
    x_pos = 0.13 if i == 0 else 0.795
    an = axs[0, 0].annotate('{:.2f}'.format(height),
                xy=(0, 0),
                xytext=(x_pos, 0.87),
                textcoords='axes fraction',
                ha='center', va='center', annotation_clip=False, fontsize=7.5, rotation=270)
    an.set_color("#4a90e2")

axs[0, 0].set_ylabel("Total Training\nTime (s)", fontsize=12.5, weight="bold")
axs[1, 0].set_ylabel("Test\nAccuracy (%)", fontsize=12.5, weight="bold")

fig.supxlabel(r"Optimization Algorithm", y=0.205, x=0.55, fontsize=14, weight="bold")

axs[1,0].text(0.5, -0.6, '(a)', transform=axs[1,0].transAxes, fontsize=13, va='bottom', ha='center')
axs[1,1].text(0.5, -0.6, '(b)', transform=axs[1,1].transAxes, fontsize=13, va='bottom', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.815)
plt.subplots_adjust(wspace=0.045, hspace=0.17)
plt.savefig("./output/Fig_8.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_8.png", bbox_inches='tight', pad_inches=0.01)

## §V-B. Evaluation of Criteo

### PIM Performance Breakdown.

#### Figure 9: Per global epoch training time breakdown into Comm./Sync. Para. Server, PIM Comp., and PIM data movement time for LR (a) and SVM (b).

In [None]:
# Load the UPMEM benchmark results for Criteo (CSV written by save_csv_UPMEM_benchmark above).
df_dpu = process_benchmark_DPU_csv("benchmark_PIM_Criteo.csv")

In [None]:
# Figure 9: per-global-epoch time breakdown of the PIM runs
# (weak scaling, 2048 DPUs, 1 local epoch), LR in panel (a), SVM in panel (b).
# NOTE(review): batch sizes 2048 and 262144 are both kept -- presumably the
# former for MA-SGD/ADMM and the latter for GA-SGD; confirm against the paper.
bench_mask = (
    (df_dpu['scaling_type'] == 'weak')
    & (df_dpu['nr_procs'] == 2048)
    & (df_dpu['num_local_epochs'] == 1)
    & df_dpu['batch_size'].isin([2048, 262144])
)
# .assign returns a fresh frame, so renaming the model labels no longer writes
# through .loc into a filtered slice (which raised SettingWithCopyWarning).
final_df = df_dpu.loc[bench_mask].assign(
    model_type=lambda d: d['model_type'].replace({'lr': 'LR', 'svm': 'SVM'})
)

# Long format: one row per (run, time component); each component becomes an
# x category, appearing in the order of columns_to_melt.
columns_to_melt = ['total_communication_time', 'DPU_compute_time', 'M_and_W_time', 'total_time']
id_vars = [col for col in final_df.columns if col not in columns_to_melt]

df_melted = final_df.melt(id_vars=id_vars, value_vars=columns_to_melt,
                          var_name='TimeType', value_name='Value')

hue_order = ['MA-SGD', 'GA-SGD', 'ADMM']
model_types = ["LR", "SVM"]
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(22, 6.6))

for ax, model_type in zip(axs, model_types):
    bars = sns.barplot(x='TimeType', y='Value', hue='algorithm', hue_order=hue_order,
                       data=df_melted[df_melted["model_type"] == model_type],
                       edgecolor='black', linewidth=1, ax=ax, width=0.7)
    bars.legend_.remove()

# One shared figure legend built from the LR panel's first three handles
# (one per algorithm in hue_order).
handles, labels = axs[0].get_legend_handles_labels()
fig.legend(handles[:3], labels[:3], loc="upper center", bbox_to_anchor=(0.54, 1), ncol=3,
           fontsize=28, edgecolor="black", title=None, frameon=True)

new_x_labels = ['Comm./Sync.\nPara. Server', 'PIM\nComp.', 'PIM Data\nMovement', 'Total']
for i, ax in enumerate(axs):
    ax.set_title(f"{model_types[i]}", fontsize=30, weight="bold")
    ax.set_yscale('log')

    ax.grid(axis='y', linestyle='--')
    ax.set_axisbelow(True)
    ax.tick_params(axis="y", direction="in", which='both', labelsize=30)
    ax.tick_params(axis="x", length=0, width=0, labelsize=6)

    if i == 0:
        ax.set_ylabel("Per Global Epoch\nTraining Time (s)", fontsize=34, weight="bold")

    ax.set_xlabel("")
    ax.set_ylim(0.01, 13000)
    ax.set_xticklabels(new_x_labels, fontsize=32)

    if i == 1:
        # Both panels use the same y-limits, so the right panel drops its y labels.
        ax.set_ylabel("")
        ax.set_yticklabels([])

    for side in ['top', 'bottom', 'left', 'right']:
        ax.spines[side].set_linewidth(1.3)

axs[0].text(0.5, -0.5, '(a)', transform=axs[0].transAxes, fontsize=30, va='top', ha='center')
axs[1].text(0.5, -0.5, '(b)', transform=axs[1].transAxes, fontsize=30, va='top', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.77)
plt.subplots_adjust(wspace=0.03, hspace=0.05)
plt.savefig("./output/Fig_9.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_9.png", bbox_inches='tight', pad_inches=0.01)

### Algorithm Selection.

#### Figure 10: Comparison of various models (LR (a) and SVM (b)), algorithms (MA-SGD, GA-SGD, and ADMM), and architectures (PIM and CPU). We study the AUC score (at the last global epoch) and total training time (10 global epochs).

In [None]:
# Criteo results from both architectures in one long frame:
# the UPMEM (PIM) runs followed by the CPU baseline runs, re-indexed 0..n-1.
df_dpu = process_DPU_csv("PIM_Criteo.csv")
df_baseline_1 = process_baseline_csv("baseline_criteo.csv")
df = pd.concat(
    objs=(df_dpu, df_baseline_1),
    ignore_index=True,
)

In [None]:
# Figure 10: AUC score vs. total training time on Criteo.
# Columns: model (LR, SVM); rows: architecture (PIM/DPU, CPU); hue: algorithm.
final_df = df.copy()
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(5, 3.5))

hue_order = ['MA-SGD', 'GA-SGD', 'ADMM']
style_order = ['DPU', 'CPU']
marker_dict = {'MA-SGD': 'o', 'GA-SGD': 's', 'ADMM': '^'}

dataset = "criteo"
for i, model in enumerate(["lr", "svm"]):
  base = (
    (final_df["model_type"] == model) &
    (final_df["dataset"] == dataset) &
    (final_df["num_local_epochs"] == 1) &
    (final_df['scaling_type'] == 'weak')
  )
  # MA-SGD and ADMM: batch size 2048 on 2048 DPUs or 128 CPU processes.
  small_batch_runs = (
    final_df["algorithm"].isin(['ADMM', 'MA-SGD']) &
    final_df["nr_procs"].isin([2048, 128]) &
    (final_df['batch_size'] == 2048)
  )
  # Batch size 524288 for any algorithm (nr_procs 2048, 128 or 1).
  # NOTE(review): presumably only GA-SGD has runs at this batch size -- confirm.
  large_batch_runs = (
    final_df["nr_procs"].isin([2048, 128, 1]) &
    (final_df['batch_size'] == 524288)
  )
  data = final_df.loc[base & (small_batch_runs | large_batch_runs)]

  for j, arch in enumerate(style_order):
    g1 = sns.lineplot(data[data["architecture"] == arch], x="total_elapsed_time", y="test_accuracy", hue="algorithm", hue_order=hue_order, style="algorithm", style_order=hue_order,markers=marker_dict, dashes=False, errorbar=None, ax=axs[j, i],palette = 'deep', markersize=3.65, markeredgewidth= 0, linewidth=1)

  model_tmp = model.upper()
  axs[0, i].set_title(f"{model_tmp}",fontsize=8, weight="bold")


# Collect one handle per algorithm across all panels, then build a single
# figure legend in hue_order.
all_handles = []
all_labels = []
for i in range(2):
    for j in range(2):
        handles, labels = axs[i][j].get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            if label not in all_labels:
                all_handles.append(handle)
                all_labels.append(label)

order = ['MA-SGD', 'GA-SGD', 'ADMM']
ordered_handles = [next(h for h, l in zip(all_handles, all_labels) if l == alg) for alg in order]

for i in range(2):
    for j in range(2):
        axs[i][j].legend().remove()

fig.legend(ordered_handles, order, loc='upper center', bbox_to_anchor=(0.565, 1),ncol=4, fontsize=7,frameon=True, edgecolor="black", markerscale=1.0)


for i, ax in enumerate(axs.reshape(-1)):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(linestyle='--')
  ax.set_axisbelow(True)

  ax.tick_params(axis="y", direction="in", labelsize=6)
  ax.tick_params(axis="x", direction="in", labelsize=6)

  for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1)


# Per-row (architecture) y-range and ticks; every panel shares the x-range.
# Tick labels are formatted explicitly because raw np.arange floats can render
# as e.g. 0.7500000000000001 (previously only the PIM row was formatted).
row_y_axis = {
  0: ((0.6875, 0.7525), np.arange(0.69, 0.7501, 0.01)),
  1: ((0.7175, 0.76255), np.arange(0.72, 0.7601, 0.01)),
}
x_ticks = np.arange(0, 4301, 1000).tolist()
for i in range(2):
  if i == 0:
    axs[i][0].set_ylabel("PIM", fontsize=8, weight="bold")
  else:
    axs[i][0].set_ylabel(style_order[i], fontsize=8, weight="bold")

  y_limits, y_ticks = row_y_axis[i]
  for j in range(2):
    ax = axs[i][j]
    ax.set_xlim(0, 4250)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticks, fontsize=8)
    ax.set_ylim(*y_limits)
    ax.set_yticks(y_ticks.tolist())
    ax.set_yticklabels([f'{y:.2f}' for y in y_ticks], fontsize=8)


# Annotate the final point of line index 1 (assumed to be GA-SGD, the second
# entry of hue_order -- confirm) in each PIM panel. The text is coloured with
# that line's own colour so it always matches the plotted 'deep' palette.
for col, text_y in ((0, 0.505), (1, 0.41)):
  line = axs[0, col].get_lines()[1]
  x_data = line.get_xdata()[-1]
  y_data = line.get_ydata()[-1]
  an = axs[0, col].annotate(f"Last epoch\nAUC Score: {y_data:.2f}\nFinishes at {x_data:.0f}s",
          xy=(0, 0),
          xytext=(0.75, text_y),
          textcoords='axes fraction',
          ha='center', va='bottom', fontsize=6, annotation_clip=False, weight="bold", bbox=dict(boxstyle="round,pad=0.05", facecolor="white", edgecolor="none", alpha=0.5))
  an.set_color(line.get_color())


fig.supxlabel(r"Total Training Time (s)", y=0.18, x=0.58, fontsize=9, weight="bold")

fig.supylabel(r"AUC Score", y=0.58, x=0.05, fontsize=9, weight="bold")

axs[1,0].text(0.5, -0.35, '(a)', transform=axs[1,0].transAxes, fontsize=8, va='top', ha='center')
axs[1,1].text(0.5, -0.35, '(b)', transform=axs[1,1].transAxes, fontsize=8, va='top', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.875)
plt.subplots_adjust(wspace=0.19, hspace=0.27)
plt.savefig("./output/Fig_10.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_10.png", bbox_inches='tight', pad_inches=0.01)

### Batch Size.

#### Figure 11: Impact of batch size on total training time (10 global epochs) and AUC score (at the last global epoch) for SVM MA-SGD (a), SVM GA-SGD (b), and LR ADMM (c).

In [None]:
# Reload the Criteo PIM results and the CPU baseline so this section runs on
# its own; both frames are stacked and re-indexed 0..n-1.
df_dpu = process_DPU_csv("PIM_Criteo.csv")
df_baseline_1 = process_baseline_csv("baseline_criteo.csv")
df = pd.concat(
    objs=(df_dpu, df_baseline_1),
    ignore_index=True,
)

In [None]:
# Figure 11: batch-size sweep on Criteo. Top row: total training time; bottom
# row: AUC score; hue: nr_procs (CPU baseline vs. 2048-DPU PIM).
final_df = df.copy()
# Keep PIM runs on 2048 DPUs plus every non-DPU (CPU baseline) run.
final_df = final_df[(final_df['nr_procs'] == 2048) | (final_df['architecture'] != 'DPU')]

deep_palette = sns.color_palette("deep")

green_color = deep_palette[2]
blue_color = deep_palette[0]
orange_color = deep_palette[1]


nr_procs_palette = {1: orange_color, 128: blue_color, 2048: green_color}
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(7, 4.6))

# One column per panel: (a) SVM MA-SGD, (b) SVM GA-SGD, (c) LR ADMM.
panels = [("svm", "MA-SGD"), ("svm", "GA-SGD"), ("lr", "ADMM")]
dataset = "criteo"
for j, (model, algorithm) in enumerate(panels):
  data = final_df.loc[
    (final_df["model_type"] == model) &
    (final_df["algorithm"] == algorithm) &
    (final_df["dataset"] == dataset) &
    (final_df["num_local_epochs"] == 1) &
    (final_df['scaling_type']=='weak') &
    # g_epoch_id 9: presumably the last of the 10 global epochs (0-based).
    (final_df['g_epoch_id']==9)
  ]

  g1 = sns.barplot(data, x="batch_size", y="total_elapsed_time", hue="nr_procs",errorbar=None, ax=axs[0, j],palette=nr_procs_palette, edgecolor='black',linewidth=0.7,width=0.65)
  g2 = sns.barplot(data, x="batch_size", y="test_accuracy", hue="nr_procs",errorbar=None, ax=axs[1, j], palette=nr_procs_palette, edgecolor='black',linewidth=0.7, width=0.65)
  g1.legend_.remove()
  g2.legend_.remove()

  axs[0, j].set_title(f"{model.upper()} {algorithm}",fontsize=10, weight="bold")


for i, ax in enumerate(axs.reshape(-1)):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(axis='y',linestyle='--')
  ax.set_axisbelow(True)

  ax.tick_params(axis='y', length=0, width=0, labelsize=7)
  ax.tick_params(axis='x', length=0, width=0, labelsize=7)

  for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.3)


# Figure legend from panel (a): hue levels sort numerically, so the first
# handle is taken as CPU and the second as PIM.
# NOTE(review): if panel (a) ever held three nr_procs levels (1, 128, 2048)
# this mapping would mislabel the bars -- confirm it has exactly two.
handles, labels = axs[0][0].get_legend_handles_labels()
order = ["CPU", "PIM"]
ordered_handles = [next(h for h, l in zip(handles, order) if l == alg) for alg in order]
legend = fig.legend(ordered_handles, order, loc='upper center', bbox_to_anchor=(0.56, 1),ncol=2, fontsize=8,frameon=True, edgecolor="black")

# Per-column y axes: (time ylim top, tick stop, tick step) and (AUC ylim, AUC ticks).
time_axis = {0: (6150, 6001, 1000), 1: (6150, 6001, 1000), 2: (2650, 2501, 500)}
auc_axis = {
  0: ((0.7075, 0.7625), np.arange(0.71, 0.761, 0.01)),
  1: ((0.7175, 0.7625), np.arange(0.72, 0.761, 0.01)),
  2: ((0.7075, 0.7525), np.arange(0.71, 0.7501, 0.01)),
}
for j in range(3):
  ylim_top, tick_stop, tick_step = time_axis[j]
  time_ticks = np.arange(0, tick_stop, tick_step).tolist()
  axs[0, j].set_ylim(0, ylim_top)
  axs[0, j].set_yticks(time_ticks)
  axs[0, j].set_yticklabels(time_ticks, fontsize=9)

  auc_ylim, auc_ticks = auc_axis[j]
  axs[1, j].set_ylim(*auc_ylim)
  axs[1, j].set_yticks(auc_ticks.tolist())
  # Format explicitly: raw np.arange floats can render as e.g. 0.7300000000000001.
  axs[1, j].set_yticklabels([f'{t:.2f}' for t in auc_ticks], fontsize=9)


# GA-SGD (column 1) sweeps far larger batch sizes than MA-SGD / ADMM.
small_batch_labels = ["1K", "2K", "4K", "8K"]
large_batch_labels = ["131K", "262K", "524K", "1048K"]
for j in range(3):
  batch_labels = large_batch_labels if j == 1 else small_batch_labels
  axs[0, j].set_xticklabels(batch_labels, fontsize=9)
  axs[1, j].set_xticklabels(batch_labels, fontsize=9)


# Bars taller than 6000 s exceed the 6150 ylim of columns 0/1, so their value
# is printed as rotated text. The x offsets (axes fraction) are hand-tuned per patch index.
# NOTE(review): column 2 uses ylim 2650, so bars between 2650 and 6000 s there
# would be clipped without a label -- confirm none exist.
for j in range(3):
  for i, p in enumerate(axs[0, j].patches):
    height = p.get_height()

    if height > 6000:
      x_pos = (i-3) * (0.247 if i == 6 else 0.244)
      an = axs[0, j].annotate('{:.0f}'.format(height),
                  xy=(0, 0),
                  xytext=(x_pos, 0.81),
                  textcoords='axes fraction',
                  ha='center', va='center', annotation_clip=False, fontsize=9, rotation=270,weight="bold",bbox=dict(boxstyle="round,pad=0.001", facecolor="white", edgecolor="none", alpha=0.5))
      an.set_color(green_color)


axs[0, 0].set_ylabel("Total Training\nTime (s)", fontsize=11, weight="bold")
axs[1, 0].set_ylabel("AUC Score", fontsize=11, weight="bold")

fig.supxlabel(r"Batch Size", y=0.19, x=0.56, fontsize=12, weight="bold")

axs[1,0].text(0.5, -0.5, '(a)', transform=axs[1,0].transAxes, fontsize=11, va='bottom', ha='center')
axs[1,1].text(0.5, -0.5, '(b)', transform=axs[1,1].transAxes, fontsize=11, va='bottom', ha='center')
axs[1,2].text(0.5, -0.5, '(c)', transform=axs[1,2].transAxes, fontsize=11, va='bottom', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.875)
plt.subplots_adjust(wspace=0.24, hspace=0.15)
plt.savefig("./output/Fig_11.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_11.png", bbox_inches='tight', pad_inches=0.01)

### Weak Scaling.

#### Figure 12: Impact of weak scaling on total training time (10 global epochs) and AUC score (at the last global epoch) for LR (a) and SVM (b).

In [None]:
# Load the Criteo UPMEM (PIM) results written by save_csv_UPMEM earlier in this
# notebook; the loader presumably comes from plots_utils (star-imported in the
# first cell) -- confirm. Only PIM runs are used for the weak-scaling figure.
df_dpu = process_DPU_csv("PIM_Criteo.csv")

In [None]:
# Figure 12: weak scaling on Criteo. Top row: total training time; bottom row:
# AUC score; columns: LR and SVM; hue: number of DPUs.
final_df = df_dpu.copy()
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 4.9))
# GA-SGD keeps batch 262144, MA-SGD / ADMM keep batch 2048; always 1 local epoch.
final_df = final_df[(final_df['algorithm'] != 'GA-SGD') | ((final_df['num_local_epochs'] == 1) & (final_df['batch_size'] == 262144))]
final_df = final_df[(final_df['algorithm'] == 'GA-SGD') | ((final_df['num_local_epochs'] == 1) & (final_df['batch_size'] == 2048))]
# g_epoch_id 9: presumably the last of the 10 global epochs (0-based).
final_df = final_df[(final_df['g_epoch_id']==9)]
final_df = final_df[(final_df['scaling_type']=='weak')]

hue_order = [256, 512, 1024, 2048]
category_order = ['MA-SGD', 'GA-SGD', 'ADMM']
# One blue shade per DPU count, in hue_order.
dpu_palette = ["#4a90e2", "#0076d1", "#0066b5", "#044d85"]
for j, model in enumerate(["lr", "svm"]):
  for dataset in ["criteo"]:
    data = final_df.loc[
      (final_df["model_type"] == model) &
      (final_df["dataset"] == dataset)
    ]

    g1 = sns.barplot(data, x="algorithm", y="total_elapsed_time", hue="nr_procs", hue_order=hue_order, order=category_order, errorbar=None, ax=axs[0, j], palette=dpu_palette, edgecolor='black', linewidth=1, width=0.7)
    g2 = sns.barplot(data, x="algorithm", y="test_accuracy", hue="nr_procs", hue_order=hue_order, order=category_order, errorbar=None, ax=axs[1, j], palette=dpu_palette, edgecolor='black', linewidth=1, width=0.7)
    if j != 0:
      g1.legend_.remove()
    g2.legend_.remove()

    model_tmp = model.upper()

    axs[0, j].set_title(f"{model_tmp} (Weak Scaling)",fontsize=13, weight="bold")

# One handle per DPU count (first occurrence across panels) for a shared legend.
all_handles = []
all_labels = []
for i in range(2):
    for j in range(2):
        handles, labels = axs[i][j].get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            if label not in all_labels:
                all_handles.append(handle)
                all_labels.append(label)

order = ['256', '512', '1024', '2048']
ordered_handles = [next(h for h, l in zip(all_handles, all_labels) if l == alg) for alg in order]


for i in range(2):
    for j in range(2):
        axs[i][j].legend().remove()

leg = fig.legend(ordered_handles, order, loc='upper center', bbox_to_anchor=(0.54, 1),ncol=4, fontsize=12,frameon=True, edgecolor="black")
leg.set_title("Nr. DPUs", prop={'size': 12, 'weight': 'bold'})

auc_ticks = np.arange(0.72, 0.7551, 0.01)
for i, ax in enumerate(axs.reshape(-1)):
  ax.set(xlabel=None, ylabel=None)
  ax.grid(axis='y',linestyle='--')
  ax.set_axisbelow(True)

  ax.tick_params(axis="y", direction="in", labelsize=7)
  ax.tick_params(axis='x', length=0, width=0, labelsize=7)


  for axis in ['top','bottom','left','right']:
    ax.spines[axis].set_linewidth(1.3)

  if i <= 1:
    ax.set_ylim(0, 6250)
    ax.set_yticks(np.arange(0, 6100, 2000).tolist())
    ax.set_yticklabels(np.arange(0, 6100, 2000).tolist(),fontsize=10)

  if i > 1:
    ax.set_ylim(0.72, 0.755)
    ax.set_yticks(auc_ticks.tolist())
    # Format explicitly: raw np.arange floats can render as e.g. 0.7400000000000001.
    ax.set_yticklabels([f'{t:.2f}' for t in auc_ticks],fontsize=10)


  if i == 1 or i == 3:
    ax.set_yticklabels([])

  if i <= 1:
    ax.set_xticklabels([])


  ax.tick_params(axis='x', labelsize=13.5)
  ax.tick_params(axis='y', labelsize=11)


# Top-row bars taller than 5000 s get their value printed as rotated text at a
# fixed, hand-tuned position (0.64, 0.71) in axes fraction.
for col in range(2):
  for i, p in enumerate(axs[0, col].patches):
    height = p.get_height()
    if height > 5000:
      an = axs[0, col].annotate('{:.0f}'.format(height),
                  xy=(0, 0),
                  xytext=(0.64, 0.71),
                  textcoords='axes fraction',
                  ha='center', va='center', annotation_clip=False, rotation=270,fontsize=12, weight='bold',bbox=dict(boxstyle="round,pad=0.01", facecolor="white", edgecolor="none", alpha=0.5))
      an.set_color("#044d85")



axs[0, 0].set_ylabel("Total Training\nTime (s)", fontsize=13, weight="bold")
axs[1, 0].set_ylabel("AUC Score", fontsize=13, weight="bold")

fig.supxlabel(r"Optimization Algorithm", y=0.205, x=0.55, fontsize=14, weight="bold")

axs[1,0].text(0.5, -0.6, '(a)', transform=axs[1,0].transAxes, fontsize=13, va='bottom', ha='center')
axs[1,1].text(0.5, -0.6, '(b)', transform=axs[1,1].transAxes, fontsize=13, va='bottom', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.815)
plt.subplots_adjust(wspace=0.045, hspace=0.17)
plt.savefig("./output/Fig_12.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_12.png", bbox_inches='tight', pad_inches=0.01)

### Strong Scaling.

#### Figure 13: Impact of strong scaling on total training time (10 global epochs) and AUC score (at the last global epoch) for LR (a) and SVM (b).

In [None]:
# Load the Criteo UPMEM (PIM) results written by save_csv_UPMEM earlier in this
# notebook; the loader presumably comes from plots_utils (star-imported in the
# first cell) -- confirm. Only PIM runs are used for the strong-scaling figure.
df_dpu = process_DPU_csv("PIM_Criteo.csv")

In [None]:
# Figure 13: strong scaling on Criteo. Top row: total training time; bottom
# row: AUC score; columns: LR and SVM; hue: number of DPUs.
final_df = df_dpu.copy()
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(9, 4.9))

# Single row mask: GA-SGD runs must use batch 262144 and the other algorithms
# batch 2048 (all with one local epoch); keep g_epoch_id 9 (presumably the
# last of the 10 global epochs) of the strong-scaling runs.
one_local_epoch = final_df['num_local_epochs'] == 1
is_ga_sgd = final_df['algorithm'] == 'GA-SGD'
keep_rows = (
  (~is_ga_sgd | (one_local_epoch & (final_df['batch_size'] == 262144)))
  & (is_ga_sgd | (one_local_epoch & (final_df['batch_size'] == 2048)))
  & (final_df['g_epoch_id'] == 9)
  & (final_df['scaling_type'] == 'strong')
)
final_df = final_df[keep_rows]

hue_order = [256, 512, 1024, 2048]
category_order = ['MA-SGD', 'GA-SGD', 'ADMM']
blue_shades = ["#4a90e2", "#0076d1", "#0066b5", "#044d85"]
shared_bar_style = dict(x="algorithm", hue="nr_procs", hue_order=hue_order, order=category_order,
                        errorbar=None, palette=blue_shades, edgecolor='black', linewidth=1, width=0.7)

for col, model in enumerate(["lr", "svm"]):
  subset = final_df.loc[(final_df["model_type"] == model) & (final_df["dataset"] == "criteo")]
  time_bars = sns.barplot(subset, y="total_elapsed_time", ax=axs[0, col], **shared_bar_style)
  auc_bars = sns.barplot(subset, y="test_accuracy", ax=axs[1, col], **shared_bar_style)
  # The top-left legend is left in place here; every panel's legend is removed below.
  if col != 0:
    time_bars.legend_.remove()
  auc_bars.legend_.remove()
  axs[0, col].set_title(f"{model.upper()} (Strong Scaling)", fontsize=13, weight="bold")

# First handle seen for each DPU-count label, scanning panels row by row.
first_handle_by_label = {}
for panel in axs.flat:
  panel_handles, panel_labels = panel.get_legend_handles_labels()
  for handle, label in zip(panel_handles, panel_labels):
    first_handle_by_label.setdefault(label, handle)

order = ['256', '512', '1024', '2048']
ordered_handles = [first_handle_by_label[label] for label in order]

for panel in axs.flat:
  panel.legend().remove()

leg = fig.legend(ordered_handles, order, loc='upper center', bbox_to_anchor=(0.54, 1),ncol=4, fontsize=12,frameon=True, edgecolor="black")
leg.set_title("Nr. DPUs", prop={'size': 12, 'weight': 'bold'})

time_ticks = np.arange(0, 2600, 500).tolist()
auc_ticks = [0.70, 0.71, 0.72, 0.73, 0.74, 0.75]
for idx, ax in enumerate(axs.flat):
  top_row = idx <= 1        # flat order: (0,0), (0,1), (1,0), (1,1)
  right_col = idx in (1, 3)

  ax.set(xlabel=None, ylabel=None)
  ax.grid(axis='y', linestyle='--')
  ax.set_axisbelow(True)

  ax.tick_params(axis="y", direction="in", labelsize=7)
  ax.tick_params(axis='x', length=0, width=0, labelsize=7)

  for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_linewidth(1.3)

  if top_row:
    ax.set_ylim(0, 2600)
    ax.set_yticks(time_ticks)
    ax.set_yticklabels(time_ticks, fontsize=10)
  else:
    ax.set_ylim(0.70, 0.755)
    ax.set_yticks(auc_ticks)
    ax.set_yticklabels([f'{tick:.2f}' for tick in auc_ticks], fontsize=10)

  if right_col:
    ax.set_yticklabels([])

  if top_row:
    ax.set_xticklabels([])

  ax.tick_params(axis='x', labelsize=13.5)
  ax.tick_params(axis='y', labelsize=11)

axs[0, 0].set_ylabel("Total Training\nTime (s)", fontsize=13, weight="bold")
axs[1, 0].set_ylabel("AUC Score", fontsize=13, weight="bold")

fig.supxlabel(r"Optimization Algorithm", y=0.205, x=0.55, fontsize=14, weight="bold")

# Sub-figure tags below the bottom row of panels.
for col, tag in enumerate(['(a)', '(b)']):
  axs[1, col].text(0.5, -0.6, tag, transform=axs[1, col].transAxes, fontsize=13, va='bottom', ha='center')

fig.tight_layout()
plt.subplots_adjust(top=0.815, wspace=0.045, hspace=0.17)
plt.savefig("./output/Fig_13.pdf", bbox_inches='tight', pad_inches=0.01)
plt.savefig("./output/Fig_13.png", bbox_inches='tight', pad_inches=0.01)