In [1]:
import torch
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import scienceplots as sp
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Export figures as PGF so they are typeset by LaTeX inside the paper.
matplotlib.use("pgf")

# LaTeX-related rc settings applied globally for every figure in this notebook.
_pgf_rc = {
    # Compile text with pdflatex (matches the document's TeX engine).
    "pgf.texsystem": "pdflatex",
    "font.family": "serif",
    "text.usetex": True,
    # Don't let matplotlib configure fonts from its own rc settings.
    "pgf.rcfonts": False,
}
matplotlib.rcParams.update(_pgf_rc)

In [3]:
# Load per-model latency measurements and give the columns display names.
# load_time is scaled by 1E3 so it shares the millisecond axis used for the
# plot below (raw values presumably in seconds — TODO confirm with the producer
# of latency.csv).
df = (
    pd.read_csv('./latency.csv')
    .assign(load_time=lambda frame: frame['load_time'] * 1E3)
    .rename(columns={
        'load_time': 'Load Time',
        'inference_time': 'Inference Time',
        'swap_time': 'Swap Time',
        'model_name': 'Model',
    })
)
df

Unnamed: 0,Model,Load Time,Inference Time,pruned_inference_time,Swap Time
0,VGG-11,718.0,74.38016,1,73.24
1,VGG-13,777.0,104.989527,1,74.24
2,VGG-16,1156.0,157.0711,1,146.51
3,VGG-19,1580.0,209.052128,1,127.62


In [4]:
# Reshape to long format: one row per (Model, timing metric) pair, with the
# metric name in `variable` and the duration in `value`, as seaborn's `hue` expects.
# NOTE: this rebinds `df`; the wide timing columns no longer exist afterwards,
# so re-running this cell alone raises a KeyError — re-run the loading cell first.
df = df.melt(
    id_vars='Model',
    value_vars=['Load Time', 'Inference Time', 'Swap Time'],
)
df

Unnamed: 0,Model,variable,value
0,VGG-11,Load Time,718.0
1,VGG-13,Load Time,777.0
2,VGG-16,Load Time,1156.0
3,VGG-19,Load Time,1580.0
4,VGG-11,Inference Time,74.38016
5,VGG-13,Inference Time,104.989527
6,VGG-16,Inference Time,157.0711
7,VGG-19,Inference Time,209.052128
8,VGG-11,Swap Time,73.24
9,VGG-13,Swap Time,74.24


In [5]:
# Grouped bar plot of load / inference / swap durations per VGG variant,
# exported as PGF for inclusion in the LaTeX document.
x_axis_label = 'Model'
y_axis_label = 'Duration (ms)'
output_path = 'vgg_barplot.pgf'

with plt.style.context(['science', 'ieee']):
    # Create a fresh figure inside the style context so the style's figure size
    # applies, and so re-running this cell never draws on top of an old figure
    # (the pgf backend keeps figures open, and plt.gca() would reuse them).
    fig, ax = plt.subplots()
    sns.barplot(data=df, x='Model', y='value', hue='variable', ax=ax)

    # Single-row legend centred just above the axes.
    ax.legend(title='', loc='lower center', bbox_to_anchor=(0.5, 1.06), ncol=3,
              fancybox=True, shadow=True, borderaxespad=0)

    ax.set_xlabel(x_axis_label)
    ax.set_ylabel(y_axis_label)

    # Remove the right/top spines and turn off ticks on those sides.
    sns.despine(ax=ax)
    ax.tick_params(axis='both', which='both', top=False, right=False)

    # Dashed horizontal grid at the major y ticks, drawn behind the bars.
    # Unlike calling axhline at every get_yticks() value, this cannot add lines
    # outside the visible range or trigger y-autoscaling that stretches the limits.
    ax.yaxis.grid(True, color='gray', linestyle='--', linewidth=0.5)
    ax.set_axisbelow(True)

    fig.savefig(output_path, bbox_inches='tight')

# pgf is a non-interactive backend, so nothing is shown inline; release the
# figure once it has been written to disk.
plt.close(fig)