In [None]:
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

In [None]:
%load_ext autoreload
%autoreload 2
from plot_results import plot_runtime_results, plot_runtime_results_fwbw, savefig
from plot_config import (
    col_order_lstm_wbl_fw,
    col_order_lstm_fw,
    col_order_lstm_wbl_fwbw,
    col_order_lstm_fwbw,
    col_order_slstm_fw,
    col_order_slstm_fwbw,
    FIGSIZE_2COL,
    GRIDSPEC_KWARGS,
    save_path
)


# Plot head dimension experiments

In [None]:
# CSV with the head-dimension sweep results that every cell below reads.
# Alternative result directories are kept here for quick switching:
#   "../../outputs_speed_exps_h100_v1/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
#   "../../outputs_speed_exps_h100_v2/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
#   "../../outputs_speed_exps_h100_v3/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
#   "../../outputs_speed_exps_v5_h100nvl/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
DATA_FILE = (
    "../../outputs_speed_exps_h100sxm_v5/head_dimension_exp/"
    "head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
)


In [None]:
# Load the sweep results and order the rows by the "DH" column.
# The former bare `head_dim_df.style` statement was dropped: mid-cell its value
# is discarded, so it displayed nothing and only built a Styler object.
# Sorting is chained (no `inplace=True`) so re-running the cell is idempotent.
head_dim_df = pd.read_csv(DATA_FILE).sort_values(by="DH")

## LSTM

In [None]:
# Select DH/NH plus the "++fw" columns whose names start with lstm (selected
# kernel variants), attention_causal--fa2, nn.LSTM or haste.LSTM.
# Raw strings keep `\+` a regex escape; in a normal literal it is an invalid
# Python escape sequence that triggers a warning.
# NOTE(review): `float16*` means "float1" followed by zero or more "6", and the
# unescaped `.` matches any character. The pattern is kept byte-identical so
# the selected columns do not change -- tighten deliberately if desired.
head_dim_lstm_fw_df = head_dim_df.filter(
    regex=(
        r"DH|NH|("
        r"^lstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)"
        r"|^attention_causal--fa2.*"
        r"|^nn.LSTM--pytorch-float16*"
        r"|^haste.LSTM--pytorch-float32"
        r")\+\+fw$"
    )
)
head_dim_lstm_fw_df

In [None]:
# Same selection as the "++fw" frame, but for columns ending in "++fwbw".
# Raw strings keep `\+` a regex escape; in a normal literal it is an invalid
# Python escape sequence that triggers a warning.
# NOTE(review): `float16*` means "float1" followed by zero or more "6"; the
# pattern is kept byte-identical so the selected columns do not change.
head_dim_lstm_fwbw_df = head_dim_df.filter(
    regex=(
        r"DH|NH|("
        r"^lstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)"
        r"|^attention_causal--fa2.*"
        r"|^nn.LSTM--pytorch-float16*"
        r"|^haste.LSTM--pytorch-float32"
        r")\+\+fwbw$"
    )
)
head_dim_lstm_fwbw_df

In [None]:
def modify_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose nn.LSTM / haste.LSTM columns are NaN unless DH == 768.

    Only the four columns listed below are touched, and only if they are
    present; every other column is returned unchanged. The input frame is
    never modified.

    Args:
        df: Frame with a "DH" column (only read when one of the listed
            columns exists).

    Returns:
        A new DataFrame with the masked values.
    """
    result = df.copy()
    # NOTE(review): presumably these entries are only meaningful at DH == 768 --
    # confirm against the benchmark setup.
    masked_columns = (
        "nn.LSTM--pytorch-float16++fw",
        "nn.LSTM--pytorch-float16++fwbw",
        "haste.LSTM--pytorch-float32++fw",
        "haste.LSTM--pytorch-float32++fwbw",
    )
    for column in masked_columns:
        if column not in result.columns:
            continue
        # Blank out every row whose head dimension differs from 768.
        result.loc[result["DH"] != 768., column] = float("nan")
    return result

In [None]:
# Apply the DH masking from modify_df and order the rows by "DH".
df = modify_df(head_dim_lstm_fw_df).sort_values(by="DH")

# Copy of `df` without the nn.LSTM forward column, remaining NaNs filled with 0.2.
# Fix: the previous `fillna(0.2, inplace=True)` returned None (so `df_n` was
# always None) and acted on a temporary selection, so nothing was ever filled.
# `df` itself stays unfilled, exactly as before.
df_n = df.loc[:, df.columns != "nn.LSTM--pytorch-float16++fw"].fillna(0.2)

In [None]:
# Show `df` (the modify_df output, sorted by "DH") for visual inspection.
df

In [None]:
# Column names used as grouping keys for the plots.
group_cols = ["DH", "NH"]

In [None]:
# The nn.LSTM / haste.LSTM columns are appended to the configured plot order and
# are also the ones excluded from NaN filling on their panel.
lstm_extra_cols_fw = [
    "nn.LSTM--pytorch-float16++fw",
    "haste.LSTM--pytorch-float32++fw",
]
lstm_extra_cols_fwbw = [
    "nn.LSTM--pytorch-float16++fwbw",
    "haste.LSTM--pytorch-float32++fwbw",
]

# Two-panel LSTM figure: "++fw" columns on the left, "++fwbw" columns on the right.
f = plot_runtime_results_fwbw(
    # left panel
    df_left=head_dim_lstm_fw_df,
    col_order_left=col_order_lstm_fw + ["attention_causal--fa2++fw", *lstm_extra_cols_fw],
    slow_cols_left=["lstm--vanilla_fwbw++fw"],
    slow_cols_offset_left=25,
    yticks_left=[0, 5, 10, 15],
    fillna_exclude_cols_left=lstm_extra_cols_fw,
    # right panel
    df_right=head_dim_lstm_fwbw_df,
    col_order_right=col_order_lstm_fwbw + ["attention_causal--fa2++fwbw", *lstm_extra_cols_fwbw],
    slow_cols_right=["lstm--vanilla_fwbw++fwbw"],
    slow_cols_offset_right=50.0,
    yticks_right=[0, 5, 10, 20, 30, 40],
    fillna_exclude_cols_right=lstm_extra_cols_fwbw,
    # shared settings
    group_cols=["DH", "NH"],
    modify_df_func=modify_df,
    filename_wo_ending="head_dim--lstm",
)
f

## sLSTM

In [None]:
# Select DH/NH plus the sLSTM columns ending in "++fw" for the listed variants.
# Raw strings keep `\+` a regex escape; in a normal literal it is an invalid
# Python escape sequence that triggers a warning. The pattern is byte-identical.
head_dim_slstm_fw_df = head_dim_df.filter(
    regex=(
        r"DH|NH|"
        r"^slstm.*(vanilla_fwbw|van|triton_fused|cuda_fused|cuda|haste)\+\+fw$"
    )
)
head_dim_slstm_fw_df

In [None]:
# Stand-alone sLSTM plot of the "++fw" columns.
f = plot_runtime_results(
    data_df=head_dim_slstm_fw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_slstm_fw,
    # the vanilla column is passed as a "slow" column with its own offset
    slow_cols=["slstm--vanilla_fwbw++fw"],
    slow_cols_offset=17.0,
    yticks=[0, 5, 10],
    filename="head_dim--slstm--fw",
)
f

In [None]:
# Select DH/NH plus the sLSTM columns ending in "++fwbw" for the listed variants.
# Raw strings keep `\+` a regex escape; in a normal literal it is an invalid
# Python escape sequence that triggers a warning. The pattern is byte-identical.
head_dim_slstm_fwbw_df = head_dim_df.filter(
    regex=(
        r"DH|NH|"
        r"^slstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)\+\+fwbw$"
    )
)
head_dim_slstm_fwbw_df

In [None]:
# Stand-alone sLSTM plot of the "++fwbw" columns.
f = plot_runtime_results(
    data_df=head_dim_slstm_fwbw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_slstm_fwbw,
    # the vanilla column is passed as a "slow" column with its own offset
    slow_cols=["slstm--vanilla_fwbw++fwbw"],
    slow_cols_offset=40.0,
    yticks=[0, 5, 10, 15, 20, 25],
    filename="head_dim--slstm--fwbw",
)
f

In [None]:
# Combined sLSTM figure: "++fw" columns on the left axis, "++fwbw" on the right.
f, (ax_left, ax_right) = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=FIGSIZE_2COL,
    gridspec_kw=GRIDSPEC_KWARGS,
)

# Both panels use the same legend placement; each call receives its own copy.
slstm_panel_legend_args = {
    "loc": "lower center",
    "ncol": 2,
    "bbox_to_anchor": (0.0, 0.97, 1.0, 0.102),
    "frameon": False,
    "facecolor": "white",
}

# NOTE(review): only this left-panel call passes `filename`; the right-panel
# call does not, and the combined figure is saved explicitly below. Confirm
# that this call does not also write "head_dim--slstm--fw", the name used by
# the stand-alone forward plot.
f = plot_runtime_results(
    data_df=head_dim_slstm_fw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_slstm_fw,
    slow_cols=["slstm--vanilla_fwbw++fw"],
    slow_cols_offset=17.0,
    yticks=[0, 5, 10, 15],
    filename="head_dim--slstm--fw",
    legend_args=dict(slstm_panel_legend_args),
    ax=ax_left,
)
f = plot_runtime_results(
    data_df=head_dim_slstm_fwbw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_slstm_fwbw,
    slow_cols=["slstm--vanilla_fwbw++fwbw"],
    slow_cols_offset=40.0,
    yticks=[0, 5, 10, 15, 20, 25],
    legend_args=dict(slstm_panel_legend_args),
    ax=ax_right,
)
# `f` is whatever the last plot call returned; that object is what gets saved.
savefig(f, savedir=save_path, name="head_dim--slstm")
f