In [1]:
import os

import jsonlines
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [2]:
# Use seaborn's 'white' style for every figure drawn below.
sns.set_style(style='white')

In [3]:
# Where figures are written and where evaluation outputs are read from
# (relative to the notebook's working directory).
VIZ_DIR = '../viz'
OUTPUTS_DIR = '../outputs'

# Per-dataset token-count files living next to the evaluation outputs.
NUM_TOKENS_MEDIUM, NUM_TOKENS_LARGE = (
    os.path.join(OUTPUTS_DIR, f'num_tokens_{dataset}.jsonl')
    for dataset in ('medium', 'large')
)

# Every evaluation run loaded by this notebook: both baselines, the full
# (num_gen_layers x BOS variant) grid on "medium", and the BOS3 runs on "large".
EVAL_NAMES = [
    'baseline_medium',
    'baseline_large',
    *(
        f'split_{num_gen_layers}_BOS{i}_medium'
        for num_gen_layers in (4, 8, 12, 16, 20, 23)
        for i in range(4)
    ),
    *(f'split_{num_gen_layers}_BOS3_large' for num_gen_layers in (4, 8, 12)),
]

In [4]:
def read_output(eval_name: str) -> tuple[np.ndarray, pd.DataFrame]:
    """Load one evaluation's JSON-lines output from ``OUTPUTS_DIR``.

    Each record must contain a ``'block_names'`` key (discarded) and a
    ``'cross_entropy'`` list; all remaining keys are kept as summary fields.

    Args:
        eval_name: File stem of the output; ``<OUTPUTS_DIR>/<eval_name>.jsonl`` is read.

    Returns:
        A pair ``(loss, summary)``: ``loss`` is an array built from the
        per-record cross-entropy lists, right-padded with NaN to the longest
        list; ``summary`` is a DataFrame with one row per record holding the
        remaining fields.
    """
    output_path = os.path.join(OUTPUTS_DIR, f'{eval_name}.jsonl')
    ce_rows = []
    records = []

    with jsonlines.open(output_path) as reader:
        for record in reader:
            del record['block_names']
            ce_rows.append(record.pop('cross_entropy'))
            records.append(record)

    # Records can have different numbers of blocks; pad with NaN so the
    # rows form a rectangular array.
    width = max(len(row) for row in ce_rows)
    padded_rows = [row + [np.nan] * (width - len(row)) for row in ce_rows]

    return np.array(padded_rows), pd.DataFrame(records)

In [5]:
# Keep only the loss array of every evaluation; the summary frame returned
# alongside it is discarded here.
losses = dict(
    (eval_name, read_output(eval_name)[0])
    for eval_name in EVAL_NAMES
)

## Finding the best BOS usage strategy

In [6]:
# One group per num_gen_layers value; each group lists the four BOS variants
# (BOS0..BOS3) of that split on the "medium" runs.
BOS_groups = [
    [f'split_{num_gen_layers}_BOS{i}_medium' for i in range(4)]
    for num_gen_layers in (4, 8, 12, 16, 20, 23)
]

In [7]:
def calc_better_rate(a, b) -> float:
    """Average per-row fraction of positions where ``a <= b``.

    For every row (last axis), the number of positions with ``a <= b`` is
    divided by the number of non-NaN entries of ``a``; the per-row fractions
    are then averaged. A comparison involving NaN evaluates to False, so
    NaN positions never count as "better".

    Args:
        a: Array of losses, may contain NaN.
        b: Array of losses compared element-wise against ``a``.

    Returns:
        Mean of the per-row fractions.
    """
    valid_per_row = np.count_nonzero(~np.isnan(a), axis=-1)
    wins_per_row = np.count_nonzero(a <= b, axis=-1)
    return np.mean(wins_per_row / valid_per_row)


def compare_within_group(group) -> tuple[plt.Figure, plt.Axes]:
    """Draw a heatmap of pairwise ``calc_better_rate`` scores for ``group``.

    Cell (row, column) holds ``calc_better_rate(losses[row], losses[column])``.

    Args:
        group: Eval names; each one is looked up in the notebook-level
            ``losses`` dict and used as a tick label.

    Returns:
        The figure and the axes holding the annotated heatmap.
    """
    # Build the score matrix row by row: each row's eval is compared
    # against every eval of the group (itself included).
    scores = []
    for row_name in group:
        row_loss = losses[row_name]
        scores.append([
            calc_better_rate(row_loss, losses[column_name])
            for column_name in group
        ])

    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap_options = dict(
        annot=True,
        square=True,
        xticklabels=group,
        yticklabels=group,
        cmap=sns.cm.rocket_r,
        ax=ax,
    )
    sns.heatmap(data=scores, **heatmap_options)
    ax.set_title('loss[row] <= loss[column] proportion')

    return fig, ax

In [26]:
# savefig does not create missing folders, so make sure the output
# directory exists before writing any figure (e.g. on a fresh checkout).
os.makedirs(VIZ_DIR, exist_ok=True)

# One BOS-comparison heatmap per num_gen_layers value, saved under VIZ_DIR.
for num_gen_layers, group in zip((4, 8, 12, 16, 20, 23), BOS_groups):
    fig, _ = compare_within_group(group)
    path = os.path.join(VIZ_DIR, f'split_{num_gen_layers}_BOS_comparison.png')
    fig.savefig(path, bbox_inches='tight')
    # The figure is only written to disk; close it so figures do not pile up.
    plt.close(fig)

In [8]:
# Runs compared in the rest of the notebook: the BOS3 variant of every split
# followed by the baseline, first for "medium", then for "large".
BEST_EVALS = [
    *(f'split_{num_gen_layers}_BOS3_medium' for num_gen_layers in (4, 8, 12, 16, 20, 23)),
    'baseline_medium',
    *(f'split_{num_gen_layers}_BOS3_large' for num_gen_layers in (4, 8, 12)),
    'baseline_large',
]

## Cross-entropy (CE) improvement rate

In [9]:
# Per eval: mean over samples of (CE in the first column) / (lowest CE in the row).
for eval_name in BEST_EVALS:
    loss = losses[eval_name]
    first_ce = loss[:, 0]
    best_ce = np.nanmin(loss, axis=-1)  # NaN padding is ignored
    print(f'{eval_name}: {np.mean(first_ce / best_ce):.04f}')

split_4_BOS3_medium: 1.0205
split_8_BOS3_medium: 1.0395
split_12_BOS3_medium: 1.0574
split_16_BOS3_medium: 1.1154
split_20_BOS3_medium: 1.4851
split_23_BOS3_medium: 2.8912
baseline_medium: 5.5676
split_4_BOS3_large: 1.0201
split_8_BOS3_large: 1.0391
split_12_BOS3_large: 1.0564
baseline_large: 3.1875


In [10]:
# Per eval: median over samples of (CE in the first column) / (lowest CE in the row).
for eval_name in BEST_EVALS:
    loss = losses[eval_name]
    first_ce = loss[:, 0]
    best_ce = np.nanmin(loss, axis=-1)  # NaN padding is ignored
    print(f'{eval_name}: {np.median(first_ce / best_ce):.04f}')

split_4_BOS3_medium: 1.0090
split_8_BOS3_medium: 1.0185
split_12_BOS3_medium: 1.0221
split_16_BOS3_medium: 1.0359
split_20_BOS3_medium: 1.0531
split_23_BOS3_medium: 1.1398
baseline_medium: 1.1734
split_4_BOS3_large: 1.0086
split_8_BOS3_large: 1.0180
split_12_BOS3_large: 1.0262
baseline_large: 1.1410


In [11]:
# Same ratio as the mean version, but measured against each row's 10% quantile
# of CE instead of its minimum.
for eval_name in BEST_EVALS:
    loss = losses[eval_name]
    first_ce = loss[:, 0]
    low_ce = np.nanquantile(loss, 0.1, axis=-1)  # NaN padding is ignored
    print(f'{eval_name}: {np.mean(first_ce / low_ce):.04f}')

split_4_BOS3_medium: 1.0103
split_8_BOS3_medium: 1.0295
split_12_BOS3_medium: 1.0455
split_16_BOS3_medium: 1.0956
split_20_BOS3_medium: 1.2868
split_23_BOS3_medium: 2.7142
baseline_medium: 4.8359
split_4_BOS3_large: 1.0041
split_8_BOS3_large: 1.0148
split_12_BOS3_large: 1.0245
baseline_large: 2.0095


## Optimal input size

In [13]:
# Index of the lowest baseline CE in every row, ignoring NaN padding; the +1
# turns the 0-based column index into a count.
optim_num_files_medium, optim_num_files_large = (
    np.nanargmin(losses[baseline_name], axis=-1) + 1
    for baseline_name in ('baseline_medium', 'baseline_large')
)

In [14]:
def _load_cumulative_tokens(path):
    """Return, for each record of a JSON-lines file, the running sum of its 'num_tokens'."""
    with jsonlines.open(path) as reader:
        return [np.cumsum(dp['num_tokens']) for dp in reader]


cumnum_tokens_medium = _load_cumulative_tokens(NUM_TOKENS_MEDIUM)
cumnum_tokens_large = _load_cumulative_tokens(NUM_TOKENS_LARGE)

In [50]:
# Stack the "medium" samples first and the "large" samples after them, and
# remember which dataset every stacked entry came from.
optim_num_files = np.concatenate((optim_num_files_medium, optim_num_files_large))
cumnum_tokens = [*cumnum_tokens_medium, *cumnum_tokens_large]
dataset_used = (
    ['medium'] * len(optim_num_files_medium)
    + ['large'] * len(optim_num_files_large)
)

In [51]:
# Cumulative token count at the optimal number of files of every sample
# (file counts are 1-based, hence the -1 when indexing).
optim_num_tokens = np.array([
    sample_cum_tokens[num_files - 1]
    for num_files, sample_cum_tokens in zip(optim_num_files, cumnum_tokens)
])

In [56]:
# The jointplot's own scatter is drawn fully transparent (alpha=0): it is used
# for the joint axes and the marginal histograms only. The visible points come
# from the scatterplot drawn on the same axes, which adds hue and size.
joint = sns.jointplot(
    x=optim_num_tokens,
    y=optim_num_files,
    alpha=0,
    height=8,
    marginal_kws=dict(bins=50))
scatter = sns.scatterplot(
    x=optim_num_tokens,
    y=optim_num_files,
    hue=dataset_used,
    # Marker size: lowest baseline CE of each sample, in the same
    # medium-then-large order as the x / y data.
    size=np.concatenate([
        np.nanmin(losses['baseline_medium'], axis=-1), 
        np.nanmin(losses['baseline_large'], axis=-1),
    ]),
    alpha=0.7,
    edgecolor='black',
    ax=joint.ax_joint,
)

joint.fig.suptitle('Optimal input size distribution')
joint.ax_joint.set_xlabel('Optimal number of tokens')
joint.ax_joint.set_ylabel('Optimal number of files')

joint.ax_joint.set_xlim(left=0)
joint.ax_joint.set_ylim(bottom=0)
joint.ax_joint.grid(axis='x', color='black', alpha=0.125, linewidth=.5)
joint.ax_joint.grid(axis='y', color='black', alpha=0.125, linewidth=.5)

# Replace seaborn's combined legend by two separate ones. The first two
# handles are taken to be the hue (dataset) entries, the remaining ones the
# size entries. Creating the second legend replaces the first one on the axes,
# so the first is re-attached with add_artist.
scatter.legend_.remove()
handles, labels = scatter.get_legend_handles_labels()
hue_legend = joint.ax_joint.legend(handles[:2], labels[:2], title="Dataset", loc="upper left", bbox_to_anchor=(0.125, 1))
size_legend = joint.ax_joint.legend(handles[2:], labels[2:], title="CE Value", loc="upper left")
joint.ax_joint.add_artist(hue_legend)

# savefig does not create missing folders, so make sure the output
# directory exists before writing the figure (e.g. on a fresh checkout).
os.makedirs(VIZ_DIR, exist_ok=True)
path = os.path.join(VIZ_DIR, 'optimal_input_size_distribution.png')
joint.fig.savefig(path, bbox_inches='tight')
plt.close(joint.fig)

## Out-of-memory (OOM) occurrences

In [None]:
# TODO
# NOTE(review): summary_4 and summary_8 are not defined in the visible part of
# this notebook. Presumably they are the summary DataFrames (second return
# value of read_output) of the 4- and 8-layer split runs -- confirm and load
# them before running this cell.

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title('Generator OOM proportion')
ax.set_xlabel('Split')
ax.set_ylabel('OOM occurrence')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

# Label every bar with the split size its summary belongs to (4, 8) rather
# than with its enumerate() position (0, 1), and draw explicitly on `ax`.
# Each bar is the share of rows whose OOM column equals 'generator'.
_ = sns.barplot(
    {
        f'split_{num_gen_layers}': (summary.OOM == 'generator').mean()
        for num_gen_layers, summary in zip((4, 8), (summary_4, summary_8))
    },
    ax=ax,
)