# Summarising compute speed

`benchmark.py` was run against the REFSOIL collection -- 960 whole bacterial genomes. We separately recorded times to:
- load the sequences from compressed files and convert them into `SeqRecord` objects
- identify the divergent set

Each condition was run 5 times with the sequences randomly drawn without replacement.

# Synopsis

Compute time scales approximately linearly with the number of sequences for both the divergent `prep` and `max` steps.

In [None]:
import plotly.express as px
import project_path
from cogent3 import load_table, make_table


def _cast_element(row):
    try:
        return tuple(row)
    except TypeError:
        return int(row)


def group_by(table, columns, quant_col, sort_cols):
    """Summarise ``quant_col`` for every distinct value of ``columns``.

    Parameters
    ----------
    table
        object providing ``distinct_values``, ``filtered`` and ``columns``
        (presumably a cogent3 table -- confirm against the caller)
    columns
        names of the columns whose distinct values define the groups
    quant_col
        name of the column to summarise
    sort_cols
        column(s) the returned table is sorted by

    Returns
    -------
    A table with one row per group holding the group value(s), the mean
    and sample standard deviation (``ddof=1``) of ``quant_col``, and the
    number of rows in the group (column ``n``).
    """
    columns = tuple(columns)
    single_column = len(columns) == 1
    summary_rows = []
    for key in table.distinct_values(columns):
        members = table.filtered(lambda x: _cast_element(x) == key, columns=columns)
        values = members.columns[quant_col]
        # a lone grouping column gives a bare key rather than a sequence
        prefix = [key] if single_column else list(key)
        summary_rows.append(
            [*prefix, values.mean(), values.std(ddof=1), values.shape[0]],
        )

    header = [*columns, f"mean_{quant_col}", f"std_{quant_col}", "n"]
    summary = make_table(header=header, rows=summary_rows)
    return summary.sorted(columns=sort_cols)


# Raw benchmark timings.
table = load_table("benchmark.tsv")

# Rows for the "prep" command, reduced to the two columns being summarised.
prep_rows = table.filtered(lambda x: x == "prep", columns="command")
st = prep_rows.get_columns(("numseqs", "time(s)"))

# Mean / standard deviation of time for each number of sequences.
prep_time = group_by(st, ("numseqs",), "time(s)", "numseqs")

In [None]:
# Mean "prep" time against number of sequences, with standard-deviation
# error bars and an OLS trend line.
prep_labels = {"mean_time(s)": "Seconds", "numseqs": "Number of sequences"}
px.scatter(
    prep_time,
    x="numseqs",
    y="mean_time(s)",
    error_y="std_time(s)",
    trendline="ols",
    labels=prep_labels,
)

In [None]:
# Timing rows for the "max" command. The value must be compared with the
# string "max": the bare name ``max`` is the builtin function, which never
# equals a command label, so that comparison would select no rows.
st = table.filtered(lambda x: x == "max", columns="command").get_columns(
    ("numseqs", "k", "time(s)"),
)
# Mean / standard deviation of time for each (numseqs, k) combination.
max_time = group_by(st, ("numseqs", "k"), "time(s)", ("numseqs", "k"))
# NOTE(review): this is a substring test, so it keeps k values of 2 and 8
# but would also keep 28 -- confirm which values are intended.
st = max_time.filtered(lambda x: str(x) in "28", columns="k")
# k is converted to strings before being used for colour
# (presumably so plotly treats it as discrete categories -- confirm).
st.columns["k"] = st.columns["k"].astype(str)

tickfont = dict(size=16)
titlefont = dict(size=20)
legend = dict(title=dict(text="<i>k</i>"), font=dict(size=17), tracegroupgap=10)

fig = px.scatter(
    st,
    x="numseqs",
    y="mean_time(s)",
    error_y="std_time(s)",
    color="k",
    labels={"mean_time(s)": "Seconds", "numseqs": "Number of sequences"},
    trendline="ols",
)
# The axis title font is given through the nested ``title.font`` attribute;
# the flat ``titlefont`` axis attribute is deprecated.
fig.update_layout(
    height=600,
    width=1400,
    xaxis=dict(
        title=dict(text="Number of Sequences", font=titlefont),
        tickfont=tickfont,
    ),
    yaxis=dict(
        title=dict(text="Mean time (seconds)", font=titlefont),
        tickfont=tickfont,
    ),
    legend=legend,
)
fig.update_traces(marker=dict(size=12))

# Save the figure into the project's figure directory.
outpath = project_path.FIG_DIR / "compute_time.png"
fig.write_image(outpath)

In [None]:
# Summary rows for 200 sequences only.
st = max_time.filtered(lambda x: str(x) == "200", columns="numseqs")
# numseqs is converted to strings before being used for colour
st.columns["numseqs"] = st.columns["numseqs"].astype(str)

# Mean "max" time against k, with standard-deviation error bars and an OLS
# trend line.
k_labels = {"mean_time(s)": "Seconds", "k": "<i>k</i>"}
px.scatter(
    st,
    x="k",
    y="mean_time(s)",
    error_y="std_time(s)",
    color="numseqs",
    trendline="ols",
    labels=k_labels,
)