In [1]:
import json
import plotly.graph_objects as go
import plotly.express as px
import numpy as np

In [2]:
# Timing summary, nested as dataset -> transformer -> {"time_mean", "time_std"}.
filepath = "../data/times_proper_batch.json"
# Pin the encoding so the read does not depend on the platform's locale default.
with open(filepath, "r", encoding="utf-8") as f:
    data = json.load(f)
data

{'multiemo': {'Unbabel/xlm-roberta-comet-small': {'time_mean': 1.0302825735211372,
   'time_std': 0.0059691365088282424},
  'xlm-roberta-base': {'time_mean': 4.462515539169312,
   'time_std': 0.024626607176193596},
  'sentence-transformers/LaBSE': {'time_mean': 4.377058901071549,
   'time_std': 0.018637889210165594},
  'microsoft/xtremedistil-l6-h256-uncased': {'time_mean': 0.5637715818583965,
   'time_std': 0.0036845343983389917}},
 'persent': {'Unbabel/xlm-roberta-comet-small': {'time_mean': 1.7373228252530097,
   'time_std': 0.006571901037244942},
  'xlm-roberta-base': {'time_mean': 7.057367500782013,
   'time_std': 0.02131077117226314},
  'sentence-transformers/LaBSE': {'time_mean': 7.093593764781952,
   'time_std': 0.0243420378095007},
  'microsoft/xtremedistil-l6-h256-uncased': {'time_mean': 1.1139679727554321,
   'time_std': 0.0040236054272918004}}}

In [3]:
# Display names for the checkpoints that get plotted. Checkpoints missing from
# this mapping (e.g. "sentence-transformers/LaBSE") are skipped by get_values.
transformers = {
    "Unbabel/xlm-roberta-comet-small": "RoBERTa-small",
    "xlm-roberta-base": "RoBERTa-base",
    "microsoft/xtremedistil-l6-h256-uncased": "XtremeDistil",
}

# Per-dataset timing tables, keyed by transformer identifier.
multiemo_data, persent_data = data["multiemo"], data["persent"]

In [4]:
def get_values(data, name_map=None):
    """Collect display names and mean times for the transformers to plot.

    Args:
        data: Mapping of transformer identifier -> stats mapping. Only the
            "time_mean" entry of the plotted transformers is read.
        name_map: Optional mapping of transformer identifier -> display
            name. Defaults to the notebook-level ``transformers`` mapping.
            Identifiers missing from it are skipped.

    Returns:
        Tuple ``(names, means)`` of two parallel lists, in the iteration
        order of ``data``.

    Raises:
        KeyError: If a plotted transformer has no "time_mean" entry.
    """
    if name_map is None:
        name_map = transformers
    names, means = [], []
    for transformer, values in data.items():
        # Filter first, so entries that are not plotted are never inspected
        # and cannot fail on a missing "time_mean".
        if transformer not in name_map:
            continue
        names.append(name_map[transformer])
        means.append(values["time_mean"])
    return names, means

In [5]:
# Display names and mean times for the "persent" timings.
p_names, p_means = get_values(persent_data)

In [6]:
# Display names and mean times for the "multiemo" timings.
m_names, m_means = get_values(multiemo_data)

In [10]:
def plot_data(names, values, title):
    """Build a bar chart with one labelled bar per transformer.

    Args:
        names: Category labels for the x axis, one per bar.
        values: Bar heights. Each value is also printed on its bar,
            formatted with the ``.3`` format spec.
        title: Accepted but not applied; the layout sets no title.

    Returns:
        The configured ``go.Figure`` (400 px high, 800 px wide).
    """
    bar_labels = [f"{value:.3}" for value in values]
    bars = go.Bar(x=names, y=values, text=bar_labels, textposition="auto")

    figure = go.Figure()
    figure.add_trace(bars)

    label_font = {"family": "Times New Roman", "size": 16}
    figure.update_layout(
        xaxis_title="Transformer",
        yaxis_title="Time [ms]",
        font=label_font,
        height=400,
        width=800,
    )
    return figure

In [11]:
# Bar chart of the "persent" mean times. The "persent" title argument is passed
# through, but plot_data does not put a title on the figure.
fig = plot_data(p_names, p_means, "persent")
fig.show()
# Export the same figure as a PDF, relative to the current working directory.
# NOTE(review): static export presumably needs plotly's image-export backend
# (e.g. kaleido) installed; confirm.
fig.write_image("time_plot.pdf")

In [9]:
# Bar chart of the "multiemo" mean times; the returned figure is the cell's
# last expression.
# NOTE(review): this cell's execution count (9) precedes the plot_data
# definition cell (10), so the notebook was run out of order. Confirm that
# Restart & Run All reproduces it.
plot_data(m_names, m_means, "multiemo")