# Plot Figure 1

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Apply seaborn's default theme to every matplotlib figure drawn in this notebook.
sns.set_theme()
# Render inline figures as SVG.
%config InlineBackend.figure_formats = ['svg']

In [None]:
# Batch size shown in the figure titles below. Presumably this must match the batch
# size used when the benchmark parquet was produced — confirm.
BATCH_SIZE = 1024

## Replot MappedCollection, Merlin, cellxgene-census benchmarks

In [None]:
df = pd.read_parquet('WDNVolxzqPiZ2Mtus9vJ.parquet')
# Fix the method order once, at load time, so every plot in the notebook (including
# the exported panel-1 PNGs) lists methods in the same order as the final figure.
# Methods not in the preferred list are kept and appended in file order.
_preferred_methods = ["Merlin", "MappedCollection", "Census"]
_present_methods = df.method.unique().tolist()
df.method = df.method.astype(pd.CategoricalDtype(
    [m for m in _preferred_methods if m in _present_methods]
    + [m for m in _present_methods if m not in _preferred_methods]
))
df

The DataFrame index enumerates the batch-loading operations within each epoch, i.e. it is the batch number within that epoch.

In [None]:
def panel1(ax=None, frac=0.01, stripplot=True, random_state=None):
    """Box plot (plus optional strip plot) of per-batch load times by method and epoch.

    Draws from a random subsample of the global ``df``.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Axes to draw on; when None, seaborn draws on the current axes.
    frac : float
        Fraction of the rows of ``df`` to sample (1 keeps every row).
    stripplot : bool
        Overlay the individual sampled batch times as dark points.
    random_state : int, numpy Generator/RandomState, optional
        Passed to ``df.sample`` so the subsample (and therefore the figure) can be
        made reproducible. None keeps the previous unseeded behavior.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    df_subsampled = df.sample(frac=frac, random_state=random_state)
    print(f"Subsampled: {df_subsampled.shape}")
    ax = sns.boxplot(df_subsampled, x="method", y="batch_times", hue="epoch", ax=ax, showfliers=False, legend=False)
    if stripplot:
        # dodge=True places each epoch's points over that epoch's box.
        ax = sns.stripplot(df_subsampled, x="method", y="batch_times", hue="epoch", ax=ax, legend=False, dodge=True, size=2, jitter=0.1, alpha=0.7, palette='dark:black')
    ax.set_ylabel("time per batch (s)")
    ax.set(yscale="log", xlabel=None)
    return ax

In [None]:
# Save five PNGs of panel 1, each drawn from an independent random `frac` subsample.
frac = 0.05
for i in range(5):
    # A fresh figure per file, handed to panel1 explicitly instead of relying on
    # pyplot's "current figure".
    fig, ax = plt.subplots()
    panel1(ax=ax, frac=frac)
    ax.set(ylim=(.01, 10))
    # Encode the sample percentage in the name (frac=0.05 -> "f1_5p_<i>.png") so the
    # file names stay accurate if frac is changed.
    fig.savefig(f'f1_{round(frac * 100)}p_{i}.png')
    # Close instead of clf(): releases the figure and leaves no empty figure for the
    # inline backend to render below this cell.
    plt.close(fig)

In [None]:
# Every row of df (frac=1), box plot only: no per-point strip-plot overlay.
panel1(frac=1, stripplot=False)

In [None]:
# 10% random subsample with the strip-plot overlay.
panel1(frac=0.1)

In [None]:
# Default 1% random subsample with the strip-plot overlay.
panel1()

In [None]:
def batches_histplot(method, ax=None):
    """Stacked histogram (log-scaled x axis) of one method's batch times, stacked by epoch.

    Reads the global ``df``; returns the Axes seaborn drew on.
    """
    method_rows = df[df['method'] == method]
    return sns.histplot(
        method_rows, x='batch_times', hue='epoch',
        multiple='stack', log_scale=True, ax=ax,
    )

In [None]:
# Violin plots of batch times (log scale) per method, split by epoch, on a 10% subsample.
# NOTE(review): df.sample() is unseeded here, so the subsample changes on every run.
sns.violinplot(df.sample(frac=0.1), x='method', y='batch_times', hue='epoch', log_scale=True)

In [None]:
# Same as above with all epochs pooled per method.
sns.violinplot(df.sample(frac=0.1), x='method', y='batch_times', log_scale=True)

In [None]:
# Number of batch rows per method.
df.method.value_counts()

In [None]:
# Number of batch rows per (method, epoch); the next cell derives nbatches from these counts.
df[['method', 'epoch']].value_counts()

In [None]:
# Batches per (method, epoch): take the most common count rather than assuming every
# group has the same number of batches.
nbatches = df[['method', 'epoch']].value_counts().value_counts().index[0]
# Derive the epoch count from the data instead of hard-coding it.
nepochs = df['epoch'].nunique()
# Later cells loop over range(nepochs) and compare df.epoch against those ints.
assert set(df['epoch'].unique()) == set(range(nepochs)), "expected epochs numbered 0..nepochs-1"
nbatches, nepochs

In [None]:
def batch_times_df(d):
    """Running total of batch times for one (method, epoch) group, fastest batches first.

    Parameters
    ----------
    d : DataFrame
        Rows of a single group; only the ``batch_times`` column (seconds per batch)
        is used.

    Returns
    -------
    DataFrame
        A single ``batch_times`` column holding the cumulative sum of the sorted
        batch times, indexed by ``batch`` = 0 .. len(d) - 1.
    """
    d = d.batch_times.sort_values().cumsum().reset_index(drop=True)
    d.index.name = 'batch'
    # Return a DataFrame, not a Series. When every group returns a Series with the
    # same index (every group has the same number of batches), groupby.apply pivots
    # the results into a wide frame (one column per batch) and the 'batch' sort
    # below fails. DataFrame results are always stacked.
    return d.to_frame()

# Select batch_times so apply doesn't receive the grouping columns. observed=True
# because `method` is categorical: skip (method, epoch) combinations with no rows.
cdf = df.groupby(['method', 'epoch'], observed=True)[['batch_times']].apply(batch_times_df)
cdf = cdf.reset_index().sort_values(['batch', 'epoch', 'method']).reset_index(drop=True)
cdf

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
# Per-(method, epoch) trace label, e.g. "Census (e0)". Built with vectorized string
# concatenation instead of a row-wise apply(axis=1), which is slow on a frame with
# one row per batch.
cdf['trace'] = cdf['method'].astype(str) + ' (e' + cdf['epoch'].astype(str) + ')'

In [None]:
# Line colors for epoch_cdfs, one per method in its default `methods` order.
colors = [ 'red', 'green', 'blue' ]
def epoch_cdfs(methods=('Merlin', 'MappedCollection', 'Census'), palette=tuple(colors)):
    """Interactive plot of cumulative batch time per (method, epoch) from the global ``cdf``.

    Each line is the running total of one epoch's batch times with batches sorted
    fastest first, converted from seconds to minutes.

    Parameters
    ----------
    methods : sequence of str
        Methods to plot, in legend order.
    palette : sequence of color strings
        Line color per entry of ``methods`` (cycled if shorter). Bound when this cell
        runs, so a later rebinding of the global ``colors`` does not change the figure.

    Returns
    -------
    plotly Figure
    """
    fig = go.Figure()
    for idx, method in enumerate(methods):
        color = palette[idx % len(palette)]
        dm = cdf[cdf.method == method]
        for epoch in range(nepochs):
            d = dm[dm.epoch == epoch]
            fig.add_trace(go.Scatter(
                name=method,
                # One legend entry per method (shown on epoch 0); the shared group makes
                # clicking it toggle every epoch's curve, not just epoch 0's.
                legendgroup=method,
                showlegend=epoch == 0,
                hovertemplate=f'%{{y:.1f}} {method} ({epoch})',
                x=d['batch'],
                # cumulative seconds -> minutes
                y=d['batch_times'] / 60,
                mode='lines',
                line=dict(color=color),
            ))
    return fig.update_layout(
        hovermode='x',
        title=dict(x=0.5, text=f'Batch-time CDFs ({nbatches:,} batches x {nepochs} epochs)'),
        height=800,
    ).update_xaxes(
        title=dict(text="Batch #",),
    ).update_yaxes(
        title=dict(text="Cumulative time (m)"),
    )

In [None]:
# Render the interactive cumulative batch-time figure.
epoch_cdfs()

In [None]:
def epoch_batches(method, epoch):
    """Interactive scatter of every batch time for a single method and epoch.

    The x axis is the index of the selected rows of the global ``df``.
    """
    in_run = (df['method'] == method) & (df['epoch'] == epoch)
    return px.scatter(df[in_run], y='batch_times')

In [None]:
import plotly
# Rebind the global `colors` (replacing the earlier 3-color list) to plotly's default
# qualitative palette; any function reading the global `colors` at call time sees it.
colors = plotly.colors.DEFAULT_PLOTLY_COLORS
colors

In [None]:
def epochs_batches(method, batch_range=None, period=None, epochs=nepochs, palette=None):
    """Scatter each batch's load time for one method, one subplot row per epoch.

    Parameters
    ----------
    method : str
        Value of ``df.method`` to plot.
    batch_range : (start, end), optional
        Positional slice ``iloc[start:end]`` of batches to keep within each epoch.
    period : int, optional
        If given, split each epoch's batches by ``batch # % period`` into one colored
        trace per residue; otherwise draw one trace per epoch.
    epochs : int
        Plot epochs ``0 .. epochs - 1``. Defaults to ``nepochs`` as it was when this
        function was defined.
    palette : sequence of color strings, optional
        Colors for the residue traces, cycled when ``period`` exceeds its length.
        Defaults to plotly's default qualitative palette.

    Returns
    -------
    plotly Figure
    """
    if palette is None:
        # Don't read the global `colors`: it is rebound between cells (a 3-color list
        # earlier on), and indexing it by residue raised IndexError for period > 3.
        palette = plotly.colors.DEFAULT_PLOTLY_COLORS
    dm = df[df.method == method]
    fig = make_subplots(rows=epochs, cols=1, subplot_titles=[f'Epoch {epoch}' for epoch in range(epochs)])
    for epoch in range(epochs):
        # Renumber this epoch's batches from 0; the index becomes the x axis.
        de = dm[dm.epoch == epoch].reset_index(drop=True)
        de.index.name = 'Batch #'
        if batch_range:
            start, end = batch_range
            de = de.iloc[start:end]
        if period:
            residue = de.index.to_series() % period
            for res in range(period):
                dr = de[residue == res]
                fig.add_trace(
                    go.Scatter(
                        x=dr.index,
                        y=dr.batch_times,
                        name=f'{res}mod{period}',
                        # One legend entry per residue that toggles it in every epoch row.
                        legendgroup=f'{res}mod{period}',
                        mode='markers',
                        marker=dict(size=4, color=palette[res % len(palette)]),
                        showlegend=epoch == 0,
                    ),
                    row=epoch + 1, col=1,
                )
        else:
            fig.add_trace(
                go.Scatter(
                    x=de.index,
                    y=de.batch_times,
                    name=f'Epoch {epoch}',
                    mode='markers',
                    marker=dict(size=4),
                ),
                row=epoch + 1, col=1,
            )
    fig.update_layout(
        title=dict(x=0.5, text=f'{method}: batch times ({epochs} epochs x {nbatches} batches)'),
        height=400 + 120 * epochs,
        # Legend entries are batch-number residues in period mode, epochs otherwise.
        legend=dict(title=dict(text='Batch #' if period else 'Epoch')),
    )
    return fig

In [None]:
# Census batch times with one color per (batch # mod 10) residue.
epochs_batches('Census', period=10)

In [None]:
# Census batch times, one trace per epoch.
epochs_batches('Census')

In [None]:
# Merlin batch times, one trace per epoch.
epochs_batches('Merlin')

In [None]:
# MappedCollection batch times, one trace per epoch.
epochs_batches('MappedCollection')

In [None]:
# Census batch times for epoch 0 alone.
epoch_batches('Census', 0)

In [None]:
# Census batch times for epoch 1 alone.
epoch_batches('Census', 1)

In [None]:
# Census batch times for epoch 2 alone.
epoch_batches('Census', 2)

In [None]:
import numpy as np

plt.clf()
# Transparent axes backgrounds let the overlapping rows show through each other.
# Use axes_style as a context manager rather than sns.set_theme(style="white", ...):
# set_theme would leave this style active for every figure drawn after this cell,
# instead of the theme set at the top of the notebook.
with sns.axes_style("white", rc={"axes.facecolor": (0, 0, 0, 0)}):
    # One ridge row per method, colored by method
    g = sns.FacetGrid(df, row="method", hue='method', aspect=4, height=2)

    # Draw each density twice: a filled KDE, then a white outline on top
    g.map(sns.kdeplot, "batch_times", clip_on=False, bw_adjust=.5, log_scale=True, fill=True, alpha=1, linewidth=1.5)
    g.map(sns.kdeplot, "batch_times", clip_on=False, bw_adjust=.5, log_scale=True, color="w", lw=2)

    # Baseline under each ridge. color=None makes refline() use the hue mapping;
    # without it, refline draws in its default grey.
    g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

    # Label each row with its method name in axes coordinates, and give every row
    # the same x range.
    def label(x, color, label):
        ax = plt.gca()
        ax.text(0, 0.2, label, color=color,
                ha="left", va="center", transform=ax.transAxes)
        ax.set(xlim=(5e-3, 5e2))

    g.map(label, "batch_times")

    # Negative hspace makes adjacent rows overlap
    g.figure.subplots_adjust(hspace=-.25)

    # Remove axes details that don't play well with overlap
    g.set_titles("")
    g.set(yticks=[], ylabel="")
    g.despine(bottom=True, left=True)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(12, 4))

# One stacked batch-time histogram per loader, left to right.
for method_ax, method in zip(axs, ["Merlin", "MappedCollection", "Census"]):
    batches_histplot(method, method_ax)

fig.suptitle(f"Loading batches of size {BATCH_SIZE} out of a 10M x 20k array across 5 epochs", fontsize=12)

plt.tight_layout(rect=[0, 0.03, 1, 1.05])
plt.show()

In [None]:
# Merlin's stacked batch-time histogram on its own, via the shared helper.
batches_histplot('Merlin')

In [None]:
# Mean batch time per method and epoch (seaborn's default estimator). The error bars
# are the ("pi", 90) percentile interval: the central 90% of batch times.
ax = sns.barplot(df, x="method", y="batch_times", hue="epoch", errorbar=("pi", 90))
ax.set_ylabel("time per batch (s)")
ax.set(yscale="log", xlabel=None)

## Convert to per-epoch statistics

In [None]:
# Fix the display order of methods for the remaining panels. This mutates df in
# place, so any earlier cell re-run after this one also uses this order.
df.method = df.method.astype(pd.CategoricalDtype(["Merlin", "MappedCollection", "Census"]))

In [None]:
# Total batch-loading time (seconds) per (method, epoch). Only batch_times is needed
# downstream; observed=True because `method` is categorical, so empty method/epoch
# combinations are skipped and pandas' observed-default FutureWarning is avoided.
epoch_stats = df.groupby(["method", "epoch"], observed=True)[["batch_times"]].sum()

In [None]:
# Per-(method, epoch) totals; in seconds when run in order (the next cell converts to hours).
epoch_stats

In [None]:
# Per-epoch totals in hours. Recomputed from df instead of `epoch_stats /= 3600` so
# that re-running this cell doesn't divide by 3600 a second time.
epoch_stats = df.groupby(["method", "epoch"], observed=True)[["batch_times"]].sum() / 3600

In [None]:
def panel2(ax=None):
    """Bar chart of total time per epoch, in hours, for each method (one bar per epoch).

    Reads the global ``epoch_stats``, which must already be converted to hours.

    Returns
    -------
    matplotlib Axes
        The axes drawn on, matching ``panel1``'s return value.
    """
    ax = sns.barplot(epoch_stats, x="method", y="batch_times", hue="epoch", ax=ax)
    ax.set_ylabel("time per epoch (h)")
    ax.set(xlabel=None)
    return ax

panel2();

## Convert to samples per second

In [None]:
# epoch_stats holds hours at this point, so "* 3600" converts back to seconds.
# 10e6 (= 10M) is presumably the number of samples loaded per epoch, matching the
# "10M x 20k array" in the figure titles — confirm.
samples_per_second = 10e6 / (epoch_stats * 3600)
samples_per_second

In [None]:
def panel3(ax=None):
    """Bar chart of average throughput (samples/s) per epoch for each method.

    Reads the global ``samples_per_second`` frame.

    Returns
    -------
    matplotlib Axes
        The axes drawn on, matching ``panel1``'s return value.
    """
    ax = sns.barplot(samples_per_second, x="method", y="batch_times", hue="epoch", ax=ax)
    ax.set_ylabel("samples per second (avg per epoch)")
    ax.set(xlabel=None)
    return ax

panel3();

## All three panels in one figure

In [None]:
# Seed numpy's global RNG: panel1 subsamples df with df.sample(), which draws from it
# when no random_state is given, so this makes the panel-1 subsample (and the
# figure) reproducible across runs.
np.random.seed(0)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))

panel1(axs[0])
panel2(axs[1])
panel3(axs[2])

fig.suptitle(f"Loading batches of size {BATCH_SIZE} out of a 10M x 20k array across {nepochs} epochs", fontsize=12)

plt.tight_layout(rect=[0, 0.03, 1, 1.05])
plt.show()