In [None]:
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

In [None]:
%load_ext autoreload
%autoreload 2
from plot_results import plot_runtime_results, plot_runtime_results_fwbw, savefig
from plot_config import (
    col_order_lstm_fw,
    col_order_lstm_fwbw,
    col_order_slstm_fw,
    col_order_slstm_fwbw,
    FIGSIZE_2COL,
    GRIDSPEC_KWARGS,
    save_path
)

In [None]:
# Benchmark CSVs (B=16, bfloat16) from the `outputs_speed_exps_h100sxm_v5` sequence-length sweep.
# Earlier runs (`outputs_speed_exps_h100_v1`, `outputs_speed_exps_h100_v3`) wrote files with the
# same names under their own `sequence_length_exp/` folder; point SEQUENCE_LENGTH_EXP_DIR there
# to re-plot one of them.
SEQUENCE_LENGTH_EXP_DIR = "../../outputs_speed_exps_h100sxm_v5/sequence_length_exp"
DATA_FILE_DH64_NH12 = f"{SEQUENCE_LENGTH_EXP_DIR}/sequence_length_exp--dh-64--nh-12--B-16--dtype-bfloat16.csv"
DATA_FILE_DH768_NH1 = f"{SEQUENCE_LENGTH_EXP_DIR}/sequence_length_exp--dh-768--nh-1--B-16--dtype-bfloat16.csv"

# Plot sequence length experiments DH=64, NH=12

In [None]:
# Raw benchmark table for DH=64, NH=12; `.style` renders it as an HTML table in the notebook.
sequence_length_dh64_nh12_df = pd.read_csv(filepath_or_buffer=DATA_FILE_DH64_NH12)
sequence_length_dh64_nh12_df.style

## LSTM

In [None]:
# Forward-only (`++fw`) columns for the LSTM triton_fused / cuda_fused / cuda variants and
# `attention_causal--fa2*`, plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+); the regex text itself is unchanged.
sequence_length_dh64_nh12_lstm_fw_df = sequence_length_dh64_nh12_df.filter(
    regex=r"(^lstm.*(triton_fused|cuda_fused|cuda)|^attention_causal--fa2.*)\+\+fw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh64_nh12_lstm_fw_df

In [None]:
# Forward+backward (`++fwbw`) columns for the LSTM triton_fused / cuda_fused / cuda variants and
# `attention_causal--fa2*`, plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+); the regex text itself is unchanged.
sequence_length_dh64_nh12_lstm_fwbw_df = sequence_length_dh64_nh12_df.filter(
    regex=r"(^lstm.*(triton_fused|cuda_fused|cuda)|^attention_causal--fa2.*)\+\+fwbw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh64_nh12_lstm_fwbw_df

In [None]:
# LSTM / attention runtime for DH=64, NH=12, grouped by `T`:
# forward-only data in the left panel, forward+backward data in the right panel.
f = plot_runtime_results_fwbw(
    filename_wo_ending="sequence_length_dh64_nh12--lstm",
    group_cols=["T"],
    df_left=sequence_length_dh64_nh12_lstm_fw_df,
    yticks_left=[0, 5, 10, 15],
    df_right=sequence_length_dh64_nh12_lstm_fwbw_df,
    yticks_right=[0, 5, 10, 20, 30, 40],
)
f

## sLSTM

In [None]:
# Forward-only (`++fw`) columns for the sLSTM triton_fused / cuda_fused / cuda variants,
# plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+); the regex text itself is unchanged.
sequence_length_dh64_nh12_slstm_fw_df = sequence_length_dh64_nh12_df.filter(
    regex=r"^slstm.*(triton_fused|cuda_fused|cuda)\+\+fw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh64_nh12_slstm_fw_df

In [None]:
# sLSTM forward-only runtime for DH=64, NH=12, grouped by `T`.
f = plot_runtime_results(
    filename="sequence_length_dh64_nh12--slstm--fw",
    data_df=sequence_length_dh64_nh12_slstm_fw_df,
    group_cols=["T"],
    yticks=[0, 5, 10, 15, 20],
    plot_column_order=None,
    # No columns are flagged as slow (empty list, zero offset).
    slow_cols=[],
    slow_cols_offset=0.0,
)
f

In [None]:
# Forward+backward (`++fwbw`) columns for the sLSTM triton_fused / cuda_fused / cuda variants,
# plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+); the regex text itself is unchanged.
sequence_length_dh64_nh12_slstm_fwbw_df = sequence_length_dh64_nh12_df.filter(
    regex=r"^slstm.*(triton_fused|cuda_fused|cuda)\+\+fwbw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh64_nh12_slstm_fwbw_df

In [None]:
# sLSTM forward+backward runtime for DH=64, NH=12, grouped by `T`.
f = plot_runtime_results(
    filename="sequence_length_dh64_nh12--slstm--fwbw",
    data_df=sequence_length_dh64_nh12_slstm_fwbw_df,
    group_cols=["T"],
    yticks=[0, 10, 20, 30, 40],
    plot_column_order=None,
    # No columns are flagged as slow (empty list, zero offset).
    slow_cols=[],
    slow_cols_offset=0.0,
)
f

In [None]:
# Two-panel sLSTM runtime figure for DH=64, NH=12: forward-only (left) and
# forward+backward (right), saved via `savefig` as one file.
f, (ax_left, ax_right) = plt.subplots(
    1, 2, figsize=FIGSIZE_2COL, gridspec_kw=GRIDSPEC_KWARGS
)

# Draw each panel into its pre-created axes. The helper's return value is deliberately NOT
# assigned to `f`: `f` must stay bound to the two-panel figure created above, because that
# is the figure passed to `savefig` and displayed at the end of the cell.
plot_runtime_results(
    data_df=sequence_length_dh64_nh12_slstm_fw_df,
    slow_cols=[],
    slow_cols_offset=0.0,
    group_cols=["T"],
    yticks=[0, 5, 10, 15, 20],
    plot_column_order=None,
    ax=ax_left,
)
plot_runtime_results(
    data_df=sequence_length_dh64_nh12_slstm_fwbw_df,
    slow_cols=[],
    slow_cols_offset=0.0,
    group_cols=["T"],
    yticks=[0, 10, 20, 30, 40],
    plot_column_order=None,
    ax=ax_right,
)

savefig(f, savedir=save_path, name="sequence_length_dh64_nh12--slstm")
f

# Plot sequence length experiments DH=768, NH=1

In [None]:
# Raw benchmark table for DH=768, NH=1; `.style` renders it as an HTML table in the notebook.
sequence_length_dh768_nh1_df = pd.read_csv(filepath_or_buffer=DATA_FILE_DH768_NH1)
sequence_length_dh768_nh1_df.style

## LSTM

In [None]:
# Forward-only (`++fw`) columns for the LSTM cuda_fused / cuda variants, `nn.LSTM--pytorch-float16*`
# and `haste*`, plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+). The dot in `nn\.LSTM` is escaped so only a literal `nn.LSTM` prefix matches.
sequence_length_dh768_nh1_lstm_fw_df = sequence_length_dh768_nh1_df.filter(
    regex=r"(^lstm.*(cuda_fused|cuda)|^nn\.LSTM--pytorch-float16.*|^haste.*)\+\+fw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh768_nh1_lstm_fw_df

In [None]:
# Forward+backward (`++fwbw`) columns for the LSTM cuda_fused / cuda variants,
# `nn.LSTM--pytorch-float16*` and `haste*`, plus any label that is a standalone `T` token.
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+). The dot in `nn\.LSTM` is escaped so only a literal `nn.LSTM` prefix matches.
sequence_length_dh768_nh1_lstm_fwbw_df = sequence_length_dh768_nh1_df.filter(
    regex=r"(^lstm.*(cuda_fused|cuda)|^nn\.LSTM--pytorch-float16.*|^haste.*)\+\+fwbw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh768_nh1_lstm_fwbw_df

In [None]:
# LSTM runtime for DH=768, NH=1, grouped by `T`:
# forward-only data in the left panel, forward+backward data in the right panel.
f = plot_runtime_results_fwbw(
    filename_wo_ending="sequence_length_dh768_nh1--lstm",
    group_cols=["T"],
    df_left=sequence_length_dh768_nh1_lstm_fw_df,
    yticks_left=[0, 5, 10, 15, 20, 25],
    df_right=sequence_length_dh768_nh1_lstm_fwbw_df,
    yticks_right=[0, 5, 10, 20, 30, 40, 50, 60, 70],
)
f

## sLSTM

In [None]:
# Forward-only (`++fw`) columns for the sLSTM cuda_fused / cuda variants,
# plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+); the regex text itself is unchanged.
sequence_length_dh768_nh1_slstm_fw_df = sequence_length_dh768_nh1_df.filter(
    regex=r"^slstm.*(cuda_fused|cuda)\+\+fw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh768_nh1_slstm_fw_df

In [None]:
# sLSTM forward-only runtime for DH=768, NH=1, grouped by `T`.
f = plot_runtime_results(
    filename="sequence_length_dh768_nh1--slstm--fw",
    data_df=sequence_length_dh768_nh1_slstm_fw_df,
    group_cols=["T"],
    yticks=[0, 5, 10, 15, 20],
    plot_column_order=None,
    # No columns are flagged as slow (empty list, zero offset).
    slow_cols=[],
    slow_cols_offset=0.0,
)
f

In [None]:
# Forward+backward (`++fwbw`) columns for the sLSTM cuda_fused / cuda variants,
# plus any label that is a standalone `T` token (the grouping column).
# Raw string: `\+`, `\w`, `\d` are invalid escapes in a plain literal (SyntaxWarning on
# Python 3.12+); the regex text itself is unchanged.
sequence_length_dh768_nh1_slstm_fwbw_df = sequence_length_dh768_nh1_df.filter(
    regex=r"^slstm.*(cuda_fused|cuda)\+\+fwbw$|(?<![\w\d])T(?![\w\d])"
)
sequence_length_dh768_nh1_slstm_fwbw_df

In [None]:
# sLSTM forward+backward runtime for DH=768, NH=1, grouped by `T`.
f = plot_runtime_results(
    filename="sequence_length_dh768_nh1--slstm--fwbw",
    data_df=sequence_length_dh768_nh1_slstm_fwbw_df,
    group_cols=["T"],
    yticks=[0, 10, 20, 30, 40, 50],
    plot_column_order=None,
    # No columns are flagged as slow (empty list, zero offset).
    slow_cols=[],
    slow_cols_offset=0.0,
)
f

In [None]:
# Two-panel sLSTM runtime figure for DH=768, NH=1: forward-only (left) and
# forward+backward (right), saved via `savefig` as one file.
f, (ax_left, ax_right) = plt.subplots(
    1, 2, figsize=FIGSIZE_2COL, gridspec_kw=GRIDSPEC_KWARGS
)

# Draw each panel into its pre-created axes. The helper's return value is deliberately NOT
# assigned to `f`: `f` must stay bound to the two-panel figure created above, because that
# is the figure passed to `savefig` and displayed at the end of the cell.
plot_runtime_results(
    data_df=sequence_length_dh768_nh1_slstm_fw_df,
    slow_cols=[],
    slow_cols_offset=0.0,
    group_cols=["T"],
    yticks=[0, 5, 10, 15, 20],
    plot_column_order=None,
    ax=ax_left,
)
plot_runtime_results(
    data_df=sequence_length_dh768_nh1_slstm_fwbw_df,
    slow_cols=[],
    slow_cols_offset=0.0,
    group_cols=["T"],
    yticks=[0, 10, 20, 30, 40, 50],
    plot_column_order=None,
    ax=ax_right,
)

savefig(f, savedir=save_path, name="sequence_length_dh768_nh1--slstm")
f