# Plot Figure 1

In [None]:
import lamindb as ln
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt 

# Connect to the LaminDB instance named below.
ln.connect("laminlabs/arrayloader-benchmarks")
# Apply seaborn's default theme to every figure in this notebook.
sns.set_theme()
# Presumably the identity (stem uid + version) under which ln.track() registers this
# notebook — confirm against the lamindb version in use.
ln.settings.transform.stem_uid = "faAhgiIDemaP"
ln.settings.transform.version = "4.1"
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_formats = ['svg']

In [None]:
# Start lamindb run tracking for this notebook (presumably using the transform
# settings assigned in the setup cell — confirm).
ln.track()

In [None]:
# Batch size quoted in the title of the combined figure; nothing else in this notebook reads it.
BATCH_SIZE = 1024

## Replot MappedCollection, Merlin, cellxgene-census benchmarks

In [None]:
# Look up the artifact by its uid; `.one()` yields a single record.
artifact = ln.Artifact.filter(uid="WDNVolxzqPiZ2Mtus9vJ").one()
artifact

In [None]:
# Load the artifact into memory. The cells below read the columns `method`, `epoch`
# and `batch_times` from the result.
df = artifact.load()
df

In [None]:
# Make `method` categorical, with the categories listed in order of first appearance.
df["method"] = df["method"].astype(pd.CategoricalDtype(df["method"].unique()))

The dataframe index enumerates the batch-loading operations within an epoch.

In [None]:
# Plot only 1% of the rows to keep the box/strip panel light.
# Fixed seed so the subsample, and therefore the first panel, is identical on every run.
df_subsampled = df.sample(frac=0.01, random_state=0)

In [None]:
# Sanity check: (rows, columns) of the 1% subsample.
df_subsampled.shape

In [None]:
def panel1(ax=None):
    """Draw per-batch loading times from `df_subsampled`, grouped by method and split by epoch.

    A box plot (outliers hidden) is overlaid with the individual subsampled points as a
    dodged strip plot; the y axis is logarithmic.

    Args:
        ax: Matplotlib axes passed through to seaborn; ``None`` leaves the choice of
            axes to seaborn.
    """
    # Arguments common to both layers, so the strips line up with their boxes.
    shared = dict(data=df_subsampled, x="method", y="batch_times", hue="epoch", legend=False)
    ax = sns.boxplot(**shared, showfliers=False, ax=ax)
    ax = sns.stripplot(**shared, dodge=True, size=2, jitter=0.1, alpha=0.7, palette='dark:black', ax=ax)
    ax.set(ylabel="time per batch (s)", yscale="log", xlabel=None)


panel1()

In [None]:
# Same quantity as the box/strip panel, but as bars over the full (non-subsampled) frame.
ax = sns.barplot(
    data=df,
    x="method",
    y="batch_times",
    hue="epoch",
)
ax.set_ylabel("time per batch (s)")
ax.set(yscale="log", xlabel=None)

## Convert to per epoch statistics

In [None]:
# Re-cast `method` with a fixed category order; values not in this list would become NaN.
# NOTE(review): df_subsampled was drawn before this cell and keeps the earlier
# order-of-appearance categories, so panel1 may order the methods differently from the
# panels built after this point — confirm.
df["method"] = df["method"].astype(
    pd.CategoricalDtype(["Merlin", "MappedCollection", "Census"])
)

In [None]:
# Sum the remaining columns per (method, epoch); summing `batch_times` gives the total
# loading time of each epoch. The result is indexed by those two keys.
# `method` is categorical, so pass `observed` explicitly: pandas >= 2.1 warns when it is
# left implicit and plans to flip the default. False keeps the behaviour this cell had.
epoch_stats = df.groupby(["method", "epoch"], observed=False).sum()

In [None]:
# Inspect the per-(method, epoch) sums before the unit conversion in the next cell.
epoch_stats

In [None]:
# Seconds -> hours, in place. Not idempotent: re-running this cell on its own divides by
# 3600 again, so re-run the groupby cell first.
epoch_stats /= 3600

In [None]:
def panel2(ax=None):
    """Bar plot of `batch_times` from `epoch_stats` (hours per epoch) by method, coloured by epoch.

    Args:
        ax: Matplotlib axes passed through to seaborn; ``None`` leaves the choice of
            axes to seaborn.
    """
    bars = sns.barplot(data=epoch_stats, x="method", y="batch_times", hue="epoch", ax=ax)
    bars.set(xlabel=None, ylabel="time per epoch (h)")


panel2()

## Convert to samples per second

In [None]:
# 10e6 == 1e7: the 10M samples quoted in the figure title ("10M x 20k array").
# epoch_stats is in hours after the `/= 3600` cell, so `* 3600` turns it back into seconds.
# The `batch_times` column keeps its name but now holds samples per second.
samples_per_second = 10e6 / (epoch_stats * 3600)
samples_per_second

In [None]:
def panel3(ax=None):
    """Bar plot of the `batch_times` column of `samples_per_second` by method, coloured by epoch.

    Args:
        ax: Matplotlib axes passed through to seaborn; ``None`` leaves the choice of
            axes to seaborn.
    """
    bars = sns.barplot(data=samples_per_second, x="method", y="batch_times", hue="epoch", ax=ax)
    bars.set(xlabel=None, ylabel="samples per second (avg per epoch)")


panel3()

## One figure

In [None]:
# One row, three columns: per-batch times, per-epoch times, throughput.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))

# Draw each panel into its own column, left to right.
for draw_panel, panel_ax in zip((panel1, panel2, panel3), axs):
    draw_panel(panel_ax)

fig.suptitle(
    f"Loading batches of size {BATCH_SIZE} out of a 10M x 20k array across 5 epochs",
    fontsize=12,
)

# The rect is passed as [left, bottom, right, top] in figure coordinates.
plt.tight_layout(rect=[0, 0.03, 1, 1.05])
plt.show()

In [None]:
# Close the tracked run. Judging by its name, the flag asserts that the notebook was
# saved beforehand — confirm against the lamindb docs for the version in use.
ln.finish(i_saved_the_notebook=True)