In [263]:
import plotly.graph_objects as go
import numpy as np

In [264]:
# PARAMS VS MACS
# Parameter counts and MAC counts per model, six entries each (presumably one per
# entry of `model_stages` — confirm).
# NOTE(review): every array here is stored smallest-to-largest, i.e. presumably from the
# most-pruned model up to BASE — the reverse of `model_stages` order. TODO confirm.
ef_params = np.array([5.17, 7.34, 9.95, 12.94, 16.33, 20.18])  # Number of parameters (in millions)
ef_macs = np.array([0.24, 0.34, 0.46, 0.60, 0.76, 0.94])      # MACs (in billions)

# NOTE(review): mob_macs is rounded to two decimals, so four stages collapse onto 0.01 and
# two onto 0.02; record more precision (or use millions) for a meaningful plot.
mob_params = np.array([0.40, 0.56, 0.76, 0.98, 1.23, 1.5])  # Number of parameters (in millions)
mob_macs = np.array([0.01, 0.01, 0.01, 0.01, 0.02, 0.02])      # MACs (in billions)

# NOTE(review): res_params / res_macs are identical, value for value, to ef_params / ef_macs,
# which looks like a copy-paste error. Re-measure the ResNet-18 numbers before trusting
# the ResNet "PARAMS VS MACS" and "PARAMS VS EPOCH TIME" plots.
res_params = np.array([5.17, 7.34, 9.95, 12.94, 16.33, 20.18])  # Number of parameters (in millions)
res_macs = np.array([0.24, 0.34, 0.46, 0.60, 0.76, 0.94])      # MACs (in billions)

In [265]:
# PARAMS VS EPOCH TIME
# Average epoch time per model, six entries each. The unit is not stated here —
# presumably seconds, like the total-train-time plots; TODO confirm.
# NOTE(review): these appear to be in `model_stages` order (BASE first). The plotting
# cells reverse them with `yneworder=new_order` before pairing them with the *_params
# arrays.
# NOTE(review): ef_epoch[0] is ~2.6x ef_epoch[1], while the other models vary far less
# between stages — worth confirming it is not inflated by one-off overhead.
ef_epoch = np.array([563.95635, 213.123686, 205.6690, 176.361886, 155.5, 138.540963])

mob_epoch = np.array([68.3157, 63.762096, 62.477608, 60.734002, 60.8, 60.79])

res_epoch = np.array([71.6299, 79.027618, 75.391776, 64.61705, 68.7056, 61.3322])

In [266]:
# TOTAL TRAIN TIME
# Total training time per model, one entry per `model_stages` label. The plotting cells
# plot these against `model_stages` with a y-axis labelled in seconds.
ef_total_train_time = np.array([1873.67082, 1030.09381, 989.312281, 899.70634, 779.863616, 706.81691])

# NOTE(review): mob_total_train_time[3] is the round integer 300 among otherwise
# fractional values — confirm it is a real measurement and not a placeholder.
mob_total_train_time = np.array([342.2679, 303.72387, 306.630741, 300, 298.745293, 301.907569])

res_total_train_time = np.array([350.7032, 375.856494, 351.52805, 315.1210, 313.4378, 297.1388])

In [267]:
# MAX MEMORY ALLOCATED DURING INFERENCE
# Peak memory allocated during inference per model, one entry per `model_stages` label.
# The plotting cells label these values in MB.
ef_max_mem_allocated = np.array([294.5303, 200.229, 187.292, 175.3643, 165.4092, 156.9355])

mob_max_mem_allocated = np.array([148.678, 141.3115, 140.9175, 140.5005, 139.0293, 138.4209])

res_max_mem_allocated = np.array([104.9175, 50.3745, 41.9199, 34.1973, 28.0264, 21.9897])

In [268]:
# MODELS THROUGHPUT
# Inference throughput per model, one entry per `model_stages` label. The plotting cells
# label these values as images per second (IMG/S).
ef_throughput = np.array([1531.879, 1593.3982, 1755.7384, 1902.4423, 2337.9873, 3116.4105])

mob_throughput = np.array([15878.39, 16855.295, 18143.9315, 19043.8873, 20409.3648, 21415.5001])

res_throughput = np.array([5551.8762, 5267.255, 5775.6364, 6358.7822, 6351.1721, 11016.1805])

In [269]:
# MODELS STAGE
# x-axis labels for every per-stage metric plot: the unpruned BASE model followed by
# PRUNED X1 ... PRUNED X5.
model_stages = np.array(['BASE'] + [f'PRUNED X{round_idx}' for round_idx in range(1, 6)])

In [270]:
def plot_chart(x=None, y=None, title=None, xhovertext=None, yhovertext=None, xtitle=None, ytitle=None, xinverted=False, yinverted=False, xneworder=None, yneworder=None, chart_type='scatter'):
    """Plot ``y`` against ``x`` with Plotly and overlay a least-squares trend line.

    Parameters
    ----------
    x, y : array-like
        Values for the horizontal and vertical axes (converted with ``np.asarray``).
        After any reordering they must have the same length. ``x`` may be numeric
        or categorical (e.g. stage labels).
    title : str, optional
        Figure title.
    xhovertext, yhovertext : str, optional
        Labels prefixed to the x / y value in each point's hover text.
    xtitle, ytitle : str, optional
        Axis titles.
    xinverted, yinverted : bool, default False
        Reverse the direction of the corresponding axis.
    xneworder, yneworder : sequence of int, optional
        Index arrays applied to ``x`` / ``y`` before plotting. Each applies only to
        its own axis, so passing just one of them re-pairs the values.
    chart_type : {'scatter', 'line'}, default 'scatter'
        'scatter' draws markers only; 'line' draws lines and markers.

    Raises
    ------
    ValueError
        If ``chart_type`` is not 'scatter' or 'line', or if ``x`` and ``y`` have
        different lengths after reordering.

    Returns None; the figure is displayed with ``fig.show()``.
    """
    # Determine mode based on chart type
    if chart_type == 'scatter':
        mode = 'markers'
    elif chart_type == 'line':
        mode = 'lines+markers'
    else:
        raise ValueError("Invalid chart_type. Use 'scatter' or 'line'.")

    # Accept plain lists as well: fancy indexing and `.dtype` below need ndarrays.
    x = np.asarray(x)
    y = np.asarray(y)

    # Reorder x and y if new orders are provided (np.asarray so that a tuple is
    # treated as a list of indices, not as a multi-dimensional index).
    if xneworder is not None:
        x = x[np.asarray(xneworder)]
    if yneworder is not None:
        y = y[np.asarray(yneworder)]

    # zip() below would silently truncate mismatched inputs and Plotly would draw
    # misaligned pairs, so fail loudly instead.
    if len(x) != len(y):
        raise ValueError(
            f"x and y must have the same length after reordering (got {len(x)} and {len(y)})."
        )

    # Create scatter or line plot
    trace = go.Scatter(
        x=x,
        y=y,
        mode=mode,
        marker=dict(color='blue', size=10),
        name='Data points',
        text=[f'{xhovertext}: {x_val}<br>{yhovertext}: {y_val}' for x_val, y_val in zip(x, y)],  # Hover text
        hoverinfo='text'
    )

    # Create figure
    fig = go.Figure(data=[trace])

    # Fit a linear trend line when y is numeric. Numeric x is regressed on directly;
    # categorical x (e.g. stage labels) is regressed on its position 0..n-1.
    if np.issubdtype(y.dtype, np.number) and len(y) >= 2:
        if np.issubdtype(x.dtype, np.number):
            predictor = x.astype(float)
        else:
            predictor = np.arange(len(x), dtype=float)
        # A degree-1 fit needs at least two distinct predictor values; this also
        # skips predictors containing NaN (comparisons with NaN are False).
        if predictor.min() < predictor.max():
            slope, intercept = np.polyfit(predictor, y.astype(float), 1)
            trend_line = go.Scatter(
                x=x,
                y=slope * predictor + intercept,
                mode='lines',
                line=dict(color='red', width=2),
                name='Trend line'
            )
            fig.add_trace(trend_line)

    # Customize layout to rotate axis labels.
    # NOTE(review): 'Arial, bold' is a comma-separated CSS font-family list, so 'bold'
    # is read as a fallback family name rather than a font weight — confirm intent.
    fig.update_layout(
        title=title,
        xaxis_title=xtitle,
        yaxis_title=ytitle,
        xaxis=dict(
            tickangle=45,
            title_font=dict(size=18, family='Arial, bold'),
            tickfont=dict(size=14, family='Arial, bold'),
            autorange='reversed' if xinverted else True
        ),
        yaxis=dict(
            tickangle=45,
            title_font=dict(size=18, family='Arial, bold'),
            tickfont=dict(size=14, family='Arial, bold'),
            autorange='reversed' if yinverted else True
        ),
        plot_bgcolor='white'
    )

    # Show plot
    fig.show()

In [271]:
# Define the new order of indices: positions 0..5 reversed, used to flip a
# six-element per-stage series end to end.
new_order = list(range(5, -1, -1))

In [272]:
plot_chart(
    x=ef_params,
    y=ef_macs,
    title='EFFICIENTNET V2 S: PARAMS VS MACS',
    xhovertext='PARAMS',
    yhovertext='MACS',
    xtitle='NUMBER OF PARAMETERS (MILLION)',
    ytitle='MACS (BILLION)',
)

In [273]:
plot_chart(
    x=res_params,
    y=res_macs,
    title='RESNET 18: PARAMS VS MACS',
    xhovertext='PARAMS',
    yhovertext='MACS',
    xtitle='NUMBER OF PARAMETERS (MILLION)',
    ytitle='MACS (BILLION)',
)

In [274]:
plot_chart(
    x=mob_params,
    y=mob_macs,
    title='MOBILE NET V3 S: PARAMS VS MACS',
    xhovertext='PARAMS',
    yhovertext='MACS',
    xtitle='NUMBER OF PARAMETERS (MILLION)',
    ytitle='MACS (BILLION)',
)

In [275]:
# Epoch times are reversed (yneworder) before pairing, presumably to match the params ordering.
plot_chart(
    x=ef_params,
    y=ef_epoch,
    title='EFFICIENTNET V2 S: PARAMS VS EPOCH TIME (AVG)',
    xhovertext='PARAMS',
    yhovertext='AVG EPOCH TIME',
    xtitle='NUMBER OF PARAMETERS (MILLION)',
    ytitle='EPOCH TIME (AVG)',
    yneworder=new_order,
)

In [276]:
# Epoch times are reversed (yneworder) before pairing, presumably to match the params ordering.
plot_chart(
    x=res_params,
    y=res_epoch,
    title='RESNET 18: PARAMS VS EPOCH TIME (AVG)',
    xhovertext='PARAMS',
    yhovertext='AVG EPOCH TIME',
    xtitle='NUMBER OF PARAMETERS (MILLION)',
    ytitle='EPOCH TIME (AVG)',
    yneworder=new_order,
)

In [277]:
# Epoch times are reversed (yneworder) before pairing, presumably to match the params ordering.
plot_chart(
    x=mob_params,
    y=mob_epoch,
    title='MOBILE NET V3 S: PARAMS VS EPOCH TIME (AVG)',
    xhovertext='PARAMS',
    yhovertext='AVG EPOCH TIME',
    xtitle='NUMBER OF PARAMETERS (MILLION)',
    ytitle='EPOCH TIME (AVG)',
    yneworder=new_order,
)

In [278]:
plot_chart(
    x=model_stages,
    y=ef_total_train_time,
    title='EFFICIENTNET V2 S: TOTAL TRAIN TIME',
    xhovertext='MODEL STAGES',
    yhovertext='TOTAL TIME (SEC)',
    xtitle='MODEL STAGES',
    ytitle='TOTAL TIME (SEC)',
    chart_type='line',
)

In [279]:
plot_chart(
    x=model_stages,
    y=res_total_train_time,
    title='RESNET 18: TOTAL TRAIN TIME',
    xhovertext='MODEL STAGES',
    yhovertext='TOTAL TIME (SEC)',
    xtitle='MODEL STAGES',
    ytitle='TOTAL TIME (SEC)',
    chart_type='line',
)

In [280]:
plot_chart(
    x=model_stages,
    y=mob_total_train_time,
    title='MOBILE NET V3 S: TOTAL TRAIN TIME',
    xhovertext='MODEL STAGES',
    yhovertext='TOTAL TIME (SEC)',
    xtitle='MODEL STAGES',
    ytitle='TOTAL TIME (SEC)',
    chart_type='line',
)

In [281]:
plot_chart(
    x=model_stages,
    y=ef_max_mem_allocated,
    title='EFFICIENTNET V2 S: MAX MEMORY ALLOCATED DURING INFERENCE',
    xhovertext='MODEL STAGES',
    yhovertext='TOTAL MEM (MB)',
    xtitle='MODEL STAGES',
    ytitle='TOTAL MEM (MB)',
    chart_type='line',
)

In [282]:
plot_chart(
    x=model_stages,
    y=res_max_mem_allocated,
    title='RESNET 18: MAX MEMORY ALLOCATED DURING INFERENCE',
    xhovertext='MODEL STAGES',
    yhovertext='TOTAL MEM (MB)',
    xtitle='MODEL STAGES',
    ytitle='TOTAL MEM (MB)',
    chart_type='line',
)

In [283]:
plot_chart(
    x=model_stages,
    y=mob_max_mem_allocated,
    title='MOBILE NET V3 S: MAX MEMORY ALLOCATED DURING INFERENCE',
    xhovertext='MODEL STAGES',
    yhovertext='TOTAL MEM (MB)',
    xtitle='MODEL STAGES',
    ytitle='TOTAL MEM (MB)',
    chart_type='line',
)

In [284]:
plot_chart(
    x=model_stages,
    y=ef_throughput,
    title='EFFICIENTNET V2 S: MODELS THROUGHPUT',
    xhovertext='MODEL STAGES',
    yhovertext='MODEL THROUGHPUT (IMG/S)',
    xtitle='MODEL STAGES',
    ytitle='MODEL THROUGHPUT (IMG/S)',
    chart_type='line',
)

In [285]:
plot_chart(
    x=model_stages,
    y=res_throughput,
    title='RESNET 18: MODELS THROUGHPUT',
    xhovertext='MODEL STAGES',
    yhovertext='MODEL THROUGHPUT (IMG/S)',
    xtitle='MODEL STAGES',
    ytitle='MODEL THROUGHPUT (IMG/S)',
    chart_type='line',
)

In [286]:
plot_chart(
    x=model_stages,
    y=mob_throughput,
    title='MOBILE NET V3 S: MODELS THROUGHPUT',
    xhovertext='MODEL STAGES',
    yhovertext='MODEL THROUGHPUT (IMG/S)',
    xtitle='MODEL STAGES',
    ytitle='MODEL THROUGHPUT (IMG/S)',
    chart_type='line',
)