# Array-loader batch timings

Inspecting batch timings from Fig 1, panel 1 of Lamin.ai's ["Training foundation models on large collections of scRNA-seq data"][blog post]: 

![](fig1panel1.svg)

- Blog post direct link: "[A large-scale benchmark]"
- Source notebook [on lamin.ai](https://lamin.ai/laminlabs/arrayloader-benchmarks/transform/faAhgiIDemaP4BB5), [in this repo](./Plot%20Figure%201.ipynb)

[A large-scale benchmark]: https://lamin.ai/blog/arrayloader-benchmarks#a-large-scale-benchmark
[blog post]: https://lamin.ai/blog/arrayloader-benchmarks#non-sharded-loading-from-local-array-backends

In [1]:
import pandas as pd
from IPython.display import Markdown

from os.path import splitext
from sys import stderr
def err(msg):
    stderr.write(msg)
    stderr.write('\n')

## Inspect `batch_times` distribution for MappedCollection, Merlin, and CELLxGENE Census benchmarks

In [2]:
df = pd.read_parquet('WDNVolxzqPiZ2Mtus9vJ.parquet')
df.method = df.method.astype(pd.CategoricalDtype(df.method.unique()))
df.index.name = 'batch'
df = df.reset_index()
df

| | batch | method | epoch | batch_times |
|---|---|---|---|---|
| 0 | 0 | Merlin | 0 | 0.586334 |
| 1 | 1 | Merlin | 0 | 0.015708 |
| 2 | 2 | Merlin | 0 | 0.189669 |
| 3 | 3 | Merlin | 0 | 0.171891 |
| 4 | 4 | Merlin | 0 | 0.184234 |
| ... | ... | ... | ... | ... |
| 148040 | 9865 | Census | 4 | 0.041398 |
| 148041 | 9866 | Census | 4 | 0.046678 |
| 148042 | 9867 | Census | 4 | 0.041140 |
| 148043 | 9868 | Census | 4 | 0.093823 |


After `reset_index()`, the `batch` column enumerates the batch-loading operations within each epoch.

Verify that it consists of consecutive integers starting at 0, for each {method, epoch}:

In [3]:
def fsck_batch_idx_series(batch_idxs):
    assert batch_idxs.tolist() == list(range(len(batch_idxs)))

def fsck_batch_idxs(df, col='batch'):
    df.groupby(['method', 'epoch'], observed=True)[col].apply(fsck_batch_idx_series)

fsck_batch_idxs(df)

Check number of batches for each {method,epoch}:

In [4]:
df[['method', 'epoch']].value_counts()

method            epoch
Merlin            0        9870
                  1        9870
                  2        9870
                  3        9870
                  4        9870
Census            0        9870
                  1        9870
                  2        9870
                  3        9870
                  4        9870
MappedCollection  0        9869
                  1        9869
                  2        9869
                  3        9869
                  4        9869
Name: count, dtype: int64

In [5]:
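# Most common per-{method,epoch} batch count (MappedCollection has one fewer: 9,869)
nbatches = df[['method', 'epoch']].value_counts().value_counts().index[0]
nepochs = df.epoch.nunique()
nbatches, nepochs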

(9870, 5)

In [6]:
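def batch_times_df(batch_times):
    """Sort one {method,epoch}'s batch times ascending and accumulate them."""
    n = len(batch_times)
    # Running total of batch times, fastest batch first
    batch_time_sum = batch_times.sort_values().cumsum().reset_index(drop=True).rename('batch_time_sum')
    batch_time_sum.index.name = 'batch_rank'
    total = batch_time_sum.iloc[-1]
    # Fraction of total time spent in the (batch_rank + 1) fastest batches
    time_frac = (batch_time_sum / total).rename('time_frac')
    batch_rank = batch_time_sum.index.to_series()
    # Fraction of all batches that those (batch_rank + 1) fastest batches represent
    batch_frac = ((batch_rank + 1) / n).rename('batch_frac')
    return pd.concat([ batch_time_sum, time_frac, batch_frac ], axis=1)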

cdf = (
    df
    .groupby(['method', 'epoch'], observed=True)
    ['batch_times']
    .apply(batch_times_df)
)
cdf['batch_time_sum_mins'] = cdf['batch_time_sum'] / 60
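# ratio: [average slower batch time] / [average faster batch time], splitting each epoch's sorted batches at batch_rank;
# undefined (0/0) at the last rank of each epoch, where no slower batches remain
cdf['ratio'] = cdf.batch_frac / cdf.time_frac * (1 - cdf.time_frac) / (1 - cdf.batch_frac)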
cdf

| method | epoch | batch_rank | batch_time_sum | time_frac | batch_frac | batch_time_sum_mins | ratio |
|---|---|---|---|---|---|---|---|
| Merlin | 0 | 0 | 0.013759 | 0.000004 | 0.000101 | 0.000229 | 25.448577 |
| Merlin | 0 | 1 | 0.027567 | 0.000008 | 0.000203 | 0.000459 | 25.405710 |
| Merlin | 0 | 2 | 0.041481 | 0.000012 | 0.000304 | 0.000691 | 25.328592 |
| Merlin | 0 | 3 | 0.055413 | 0.000016 | 0.000405 | 0.000924 | 25.283284 |
| Merlin | 0 | 4 | 0.069379 | 0.000020 | 0.000507 | 0.001156 | 25.244492 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| Census | 4 | 9865 | 7028.237910 | 0.995012 | 0.999595 | 117.137299 | 12.365318 |
| Census | 4 | 9866 | 7036.836652 | 0.996229 | 0.999696 | 117.280611 | 12.449583 |
| Census | 4 | 9867 | 7045.657804 | 0.997478 | 0.999797 | 117.427630 | 12.475525 |
| Census | 4 | 9868 | 7054.529921 | 0.998734 | 0.999899 | 117.575499 | 12.510464 |
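
To make the `ratio` column concrete, here is a minimal sketch on made-up data: nine 0.01 s batches and one 1 s batch. These values are hypothetical, not taken from the benchmark, and `sorted_cumulative` is an illustrative stand-in that mirrors the sort-and-accumulate step above. Splitting after the 9 fastest batches should give a ratio of about 100 (1 s vs. 0.01 s), and the final rank comes out as NaN because no slower batches remain.

```python
import pandas as pd

def sorted_cumulative(batch_times: pd.Series) -> pd.DataFrame:
    """Illustrative helper: sort times ascending, then compute cumulative fractions."""
    n = len(batch_times)
    cum = batch_times.sort_values().cumsum().reset_index(drop=True)
    time_frac = cum / cum.iloc[-1]                  # share of total time in the k fastest batches
    batch_frac = (cum.index.to_series() + 1) / n    # share of batches that those k batches represent
    return pd.DataFrame({'time_frac': time_frac, 'batch_frac': batch_frac})

# Hypothetical toy epoch: nine fast batches and one slow one
toy = pd.Series([0.01] * 9 + [1.0])
out = sorted_cumulative(toy)
bf, tf = out.batch_frac, out.time_frac

# [avg slower batch time] / [avg faster batch time]:
#   avg faster = tf * total / (bf * n), avg slower = (1 - tf) * total / ((1 - bf) * n)
out['ratio'] = bf / tf * (1 - tf) / (1 - bf)

print(out.ratio.iloc[8])   # split after the 9 fastest batches: approximately 100
print(out.ratio.iloc[-1])  # last rank: NaN (0/0)
```

Reading a curve at a given `batch_frac` therefore answers: "how much slower, on average, are the batches above this rank than the batches at or below it?"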


In [7]:
fsck_batch_idxs(cdf.reset_index(), 'batch_rank')

In [8]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import plotly
default_colors = plotly.colors.DEFAULT_PLOTLY_COLORS

In [9]:
# Notebook-wide default: return static images (False) or interactive figures (True)
i = False
def interactive():
    return i

def plot(
    fig,
    name,
    save=True,
    json=False,
    W=1200, H=1200,
    i=None, v=False,
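):
    """Optionally save `fig` to `name`, then return the figure (if interactive) or a Markdown image link to `name`."""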
    if save:
        err(f'Saving: {name}')
        fig.write_image(name, width=W, height=H)
        stem = splitext(name)[0]
        if json:
            json_path = f'{stem}.json'
            err(f'Saving: {json_path}')
            fig.write_json(json_path)
    if i is None:
        i = interactive()
    if i:
        if v:
            err("Returning interactive plot")
        return fig
    else:
        if v:
            err("Returning markdown image")
        return Markdown(f'![]({name})')

In [10]:
colors = [ 'red', 'green', 'blue' ]
def epoch_cdfs(
    name, x, y, title, xtitle, ytitle, hoverfmt,
    log_y=False, rng=None, bg='white',
    W=1000, H=800,
    grid='#ccc', rangemode=None,
    v=False, i=None,
):
    fig = go.Figure()
    for idx, method in enumerate(['Merlin', 'MappedCollection', 'Census']):
        dm = cdf.loc[method]
        for epoch in range(nepochs):
            d = dm.loc[epoch].reset_index()
            fig.add_trace(go.Scatter(
                name=method,
                hovertemplate='%%{y:%s} (epoch %d)' % (hoverfmt, epoch),
                x=d[x],
                y=d[y],
                mode='lines',
                marker=dict(
                    color=colors[idx],
                ),
                showlegend=epoch == 0,
            ))
    fig.update_layout(
        hovermode='x',
        title=dict(x=0.5, text=title),
        width=W,
        height=H,
        plot_bgcolor=bg,
    ).update_xaxes(
        title=dict(text=xtitle),
        gridcolor=grid,
        range=rng,
        linecolor=grid,
        rangemode=rangemode,
    ).update_yaxes(
        title=dict(text=ytitle),
        gridcolor=grid,
        range=rng,
        **(dict(type='log') if log_y else {}),
        linecolor=grid,
        rangemode=rangemode,
    )
    # Pass the caller's W/H through instead of re-hardcoding the defaults
    return plot(fig, name, W=W, H=H, v=v, i=i)

## Sorted+Cumulative batch times

In [11]:
epoch_cdfs(
    name='time_sums.png',
    x='batch_rank',
    y='batch_time_sum_mins',
    title=f'Cumulative batch times ({nbatches:,} batches x {nepochs} epochs)',
    xtitle="Batch #",
    ytitle="Total time (minutes)",
    hoverfmt='.1f',
    rangemode='tozero',
)

Saving: time_sums.png


![](time_sums.png)

### Batch-time distributions

In [12]:
epoch_cdfs(
    name='cdfs.png',
    x='batch_frac',
    y='time_frac',
    title=f'Batch-time CDFs ({nbatches:,} batches x {nepochs} epochs)',
    xtitle="Batch %",
    ytitle="Total time %",
    hoverfmt='.2f',
    rng=[-.005, 1.005],
)

Saving: cdfs.png


![](cdfs.png)

### [avg slower batch time] / [avg faster batch time]

In [13]:
epoch_cdfs(
    name='ratios.png',
    x='batch_frac',
    y='ratio',
    title='Ratio: [avg slower batch time] / [avg faster batch time]',
    xtitle="Batch %",
    ytitle="[avg slower batch] / [avg faster batch]",
    hoverfmt='.1f',
    rangemode='tozero',
    log_y=True,
    # i=True,
)

Saving: ratios.png


![](ratios.png)

## Census batch timings

In [14]:
df.index.name = 'batch'
df

| batch (index) | batch | method | epoch | batch_times |
|---|---|---|---|---|
| 0 | 0 | Merlin | 0 | 0.586334 |
| 1 | 1 | Merlin | 0 | 0.015708 |
| 2 | 2 | Merlin | 0 | 0.189669 |
| 3 | 3 | Merlin | 0 | 0.171891 |
| 4 | 4 | Merlin | 0 | 0.184234 |
| ... | ... | ... | ... | ... |
| 148040 | 9865 | Census | 4 | 0.041398 |
| 148041 | 9866 | Census | 4 | 0.046678 |
| 148042 | 9867 | Census | 4 | 0.041140 |
| 148043 | 9868 | Census | 4 | 0.093823 |


#### Plot helper

In [15]:
def epochs_batches(
    method,
    batch_range=None,
    period=None,
    epochs=nepochs,
    log=True,
    vertical_spacing=0.05,
    size=3,
    save=True,
    W=1200, H=1200,
    grid='#ccc', bg='white',
    i=None,
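):
    """Scatter-plot `method`'s per-batch times, one subplot per epoch; if `period` is set, color points by batch # mod `period`."""
    dm = df[df.method == method]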
    fig = make_subplots(
        rows=epochs, cols=1,
        x_title='Batch #',
        subplot_titles=[ f'Epoch {epoch}' for epoch in range(epochs) ],
        vertical_spacing=vertical_spacing,
    )

    name = f'{method.lower()}_batches'
    if period:
        name += f'_mod{period}'
    if batch_range:
        start, end = batch_range
        name += f'_{start}:{end}'
    name += '.png'

    mod_str = f', mod {period}' if period else ''

    for epoch in range(epochs):
        de = dm[dm.epoch == epoch].reset_index(drop=True)
        de.index.name = 'Batch #'
        if batch_range:
            start, end = batch_range
            de = de.iloc[start:end]
            batches_str = f', batches [{start}:{end})'
        else:
            # Use this method's own batch count (MappedCollection has 9,869, not 9,870)
            batches_str = f' x {len(de)} batches'
        if period:
            mod = (de.index.to_series() % period).rename('mod')
            for res in range(period):
                dr = de[mod == res]
                fig.add_trace(
                    go.Scatter(
                        x=dr.index,
                        y=dr.batch_times,
                        name=f'{res}mod{period}',
                        mode='markers',
                        marker=dict(size=size, color=default_colors[res]),
                        showlegend=epoch == 0,
                    ),
                    row=epoch + 1, col=1,
                )
        else:
            fig.add_trace(
                go.Scatter(
                    x=de.index,
                    y=de.batch_times,
                    name=f'Epoch {epoch}',
                    mode='markers',
                    marker=dict(size=size),
                ),
                row=epoch + 1, col=1,
            )
            mod_str = ''
    fig.update_layout(
        title=dict(x=0.5, text=f'{method}: {epochs} epochs{batches_str}{mod_str}'),
        height=200 + 200 * epochs,
        legend=dict(title=dict(text='Batch #')),
        plot_bgcolor=bg,
    ).update_yaxes(
        gridcolor=grid,
        linecolor=grid,
        title=dict(text='Time (s)'),
        **(dict(type="log") if log else {}),
    ).update_xaxes(
        gridcolor=grid,
        #linecolor=grid,
    )
    return plot(fig, name, save=save, i=i, W=W, H=H)

#### Every 10th batch is ≈100x slower, accounting for most of the total latency.

In [16]:
epochs_batches('Census', period=10)

Saving: census_batches_mod10.png


![](census_batches_mod10.png)
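
One way to quantify "accounting for most of the total latency" is to sum batch times by residue (batch # mod period) and divide by the total. The sketch below does this on synthetic data only: the timings (0.04 s normally, 4 s on every 10th batch) are assumptions chosen to resemble the plot, and `time_share_by_residue` is a hypothetical helper, not part of this notebook.

```python
import numpy as np
import pandas as pd

def time_share_by_residue(batch_times, period):
    """Fraction of total time spent in batches with each value of (batch # mod period)."""
    s = pd.Series(np.asarray(batch_times, dtype=float))
    residue = s.index.to_series() % period
    return s.groupby(residue).sum() / s.sum()

# Synthetic epoch (assumed values): 1,000 batches at 0.04 s, every 10th at 4 s
times = np.full(1000, 0.04)
times[::10] = 4.0

share = time_share_by_residue(times, period=10)
print(share.round(3))  # residue 0 holds 400 / 436 of the total time
```

Applied to one real epoch, the input would be that epoch's `batch_times` in load order, e.g. `df[(df.method == 'Census') & (df.epoch == 0)].batch_times`.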

### Census batch timings (detail)

- The top "line" of slow outliers recurs roughly every 10th batch, but slips by 1 every 40-50 batches (visible as runs of 4-5 dots of the same color)
- The middle blue line of ≈1s batches recurs more consistently, on every 10th batch
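
The "slip" described above can also be checked numerically: take the positions of the slow batches, reduce them mod 10, and measure how long each residue persists. A minimal sketch follows, using a made-up schedule (slow batches spaced 10, 10, 10, 11 apart) as a stand-in for the real data; the threshold and `residue_runs` helper are likewise assumptions for illustration.

```python
from itertools import groupby

import numpy as np

def residue_runs(batch_times, threshold, period=10):
    """Return (residue, run length) pairs for consecutive slow batches sharing a residue."""
    slow = np.flatnonzero(np.asarray(batch_times) > threshold)
    residues = (slow % period).tolist()
    return [(r, len(list(g))) for r, g in groupby(residues)]

# Hypothetical schedule: slow batches every 10 batches, slipping by 1 on every 4th gap
gaps = np.array([10, 10, 10, 11] * 5)
slow_idx = np.concatenate([[0], np.cumsum(gaps)])

times = np.full(slow_idx[-1] + 1, 0.04)  # fast batches (assumed 0.04 s)
times[slow_idx] = 4.0                    # slow batches (assumed 4 s)

runs = residue_runs(times, threshold=1.0)
print(runs)  # runs of 4 per residue, then a trailing partial run
```

In the plot, each run corresponds to a stretch of same-colored dots along the top line; runs of length 4-5 would match a slip of 1 every 40-50 batches.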

In [17]:
epochs_batches('Census', period=10, batch_range=(1200, 1800), size=6)

Saving: census_batches_mod10_1200:1800.png


![](census_batches_mod10_1200:1800.png)

## Merlin batch timings

- The 1st epoch is noisy, but still shows clear every-10th-batch artifacts
- In the other epochs, every 10th batch is ≈30x slower

In [18]:
epochs_batches('Merlin', period=10)

Saving: merlin_batches_mod10.png


![](merlin_batches_mod10.png)

### Merlin batch timings (detail)

In [19]:
epochs_batches('Merlin', period=10, batch_range=(1200, 1800), size=6)

Saving: merlin_batches_mod10_1200:1800.png


![](merlin_batches_mod10_1200:1800.png)

## MappedCollection batch timings

Outliers occur every 7th batch (as opposed to every 10th, as seen in Census and Merlin).
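
A period like 7 can be guessed without eyeballing the plot by taking the most common gap between slow batches. The sketch below is one simple heuristic, not how the value above was necessarily found: "slow" is assumed to mean more than 5x the median batch time, `estimate_period` is a hypothetical helper, and the data is synthetic.

```python
import numpy as np

def estimate_period(batch_times, factor=5.0):
    """Most common gap between batches slower than `factor` x the median; None if < 2 slow batches."""
    times = np.asarray(batch_times, dtype=float)
    slow = np.flatnonzero(times > factor * np.median(times))
    if len(slow) < 2:
        return None
    gaps = np.diff(slow)
    return int(np.bincount(gaps).argmax())

# Synthetic epochs (assumed values): one slow batch every 7th and every 10th batch, respectively
every_7th = np.full(700, 0.02)
every_7th[3::7] = 0.6
every_10th = np.full(1000, 0.04)
every_10th[::10] = 4.0

print(estimate_period(every_7th))   # 7
print(estimate_period(every_10th))  # 10
```

Because it takes the mode of the gaps, occasional slips by 1 (as in the Census detail plots) should not change the estimate as long as most gaps equal the true period.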

In [20]:
epochs_batches('MappedCollection', period=7)

Saving: mappedcollection_batches_mod7.png


![](mappedcollection_batches_mod7.png)

### MappedCollection batch timings (detail)

In [21]:
epochs_batches('MappedCollection', period=7, batch_range=(1200, 1800), size=6)

Saving: mappedcollection_batches_mod7_1200:1800.png


![](mappedcollection_batches_mod7_1200:1800.png)