In [None]:
import torch
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

In [None]:
%load_ext autoreload
%autoreload 2
from plot_results import plot_runtime_results, savefig
from plot_config import (
    col_order_lstm_fw,
    col_order_lstm_fwbw,
    col_order_slstm_fw,
    col_order_slstm_fwbw,
    FIGSIZE_2COL,
    GRIDSPEC_KWARGS,
    save_path
)


# Plot head dimension experiments

In [None]:
# Result table of the head-dimension experiment that every plot below is built from
# (file name encodes: batch 16, T 1024, dtype bfloat16).
# Earlier result sets, kept for switching back:
# DATA_FILE = "../../outputs_speed_exps_h100_v1/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
# DATA_FILE = "../../flashrnn/outputs_speed_exps_h100_v2/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
DATA_FILE = (
    "../../flashrnn/outputs_speed_exps_h100_v3/head_dimension_exp/head_dimension_exp--batch-16--T-1024--dtype-bfloat16.csv"
)


In [None]:
# Load the benchmark CSV selected by DATA_FILE; the Styler object is the cell's
# last expression so the whole table is rendered as rich output.
head_dim_df = pd.read_csv(filepath_or_buffer=DATA_FILE)
head_dim_df.style

## LSTM

In [None]:
# Forward-pass view of the results. DataFrame.filter keeps a column when the
# regex matches its name, i.e. when the name
#   * contains "DH" or "NH" (the grouping columns), or
#   * ends in "++fw" and starts with an "lstm" kernel variant
#     (vanilla_fwbw / triton_fused / cuda_fused / cuda), with
#     "attention_causal--fa2" or with "nn.LSTM".
# The pattern is a raw string so that "\+" is passed to the regex engine
# unchanged instead of being parsed as an (invalid) Python escape sequence.
# LSTM-kernels-only alternative:
#   regex=r"DH|NH|^lstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)\+\+fw$"
head_dim_lstm_fw_df = head_dim_df.filter(
    regex=r"DH|NH|(^lstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)|^attention_causal--fa2.*|^nn.LSTM.*)\+\+fw$"
)
head_dim_lstm_fw_df

In [None]:
# Preview only: the LSTM forward column order extended by the attention (fa2) and
# nn.LSTM (pytorch) forward columns. The result is displayed, not stored.
col_order_lstm_fw+["attention_causal--fa2++fw", "nn.LSTM--pytorch++fw"]

In [None]:
# Standalone forward-pass runtime figure for the LSTM kernels, grouped by (NH, DH).
# NOTE(review): slow_cols / slow_cols_offset presumably control how the much slower
# vanilla column is drawn relative to the y-range — confirm in plot_results.
f = plot_runtime_results(
    filename="head_dim--lstm--fw",
    data_df=head_dim_lstm_fw_df,
    group_cols=["NH", "DH"],
    # Only the LSTM kernel order is passed; the extension
    # + ["attention_causal--fa2++fw", "nn.LSTM--pytorch++fw"] is left disabled.
    plot_column_order=col_order_lstm_fw,
    yticks=[0, 5, 10],
    slow_cols=["lstm--vanilla_fwbw++fw"],
    slow_cols_offset=16.0,
)
f

In [None]:
# Forward+backward view of the results: keep the grouping columns (names
# containing "DH" or "NH") and every column that starts with "lstm" and ends in
# one of the kernel variants followed by "++fwbw".
# Raw string: "\+" must reach the regex engine as-is; in a normal string literal
# it is an invalid Python escape sequence and triggers a warning.
head_dim_lstm_fwbw_df = head_dim_df.filter(
    regex=r"DH|NH|^lstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)\+\+fwbw$"
)
head_dim_lstm_fwbw_df

In [None]:
# Standalone forward+backward runtime figure for the LSTM kernels, grouped by (NH, DH).
# NOTE(review): slow_cols / slow_cols_offset presumably control how the much slower
# vanilla column is drawn relative to the y-range — confirm in plot_results.
f = plot_runtime_results(
    filename="head_dim--lstm--fwbw",
    data_df=head_dim_lstm_fwbw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_lstm_fwbw,
    yticks=[0, 5, 10, 15, 20, 25],
    slow_cols=["lstm--vanilla_fwbw++fwbw"],
    slow_cols_offset=30.0,
)
f

In [None]:
# Combined LSTM figure: forward runtimes in the left panel, forward+backward in
# the right panel, saved once under the name "head_dim--lstm".
f, (ax_left, ax_right) = plt.subplots(
    nrows=1, ncols=2, figsize=FIGSIZE_2COL, gridspec_kw=GRIDSPEC_KWARGS
)

# Legend settings used verbatim for both panels; each call gets its own copy.
lstm_legend_args = {
    "loc": "lower center",
    "ncol": 2,
    "bbox_to_anchor": (0.0, 0.97, 1.0, 0.102),
    "frameon": False,
    "facecolor": "white",
}

# NOTE(review): f is rebound to whatever plot_runtime_results returns; savefig
# below relies on that being the figure owning ax_left/ax_right — confirm.
f = plot_runtime_results(
    ax=ax_left,
    data_df=head_dim_lstm_fw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_lstm_fw,
    yticks=[0, 5, 10],
    slow_cols=["lstm--vanilla_fwbw++fw"],
    slow_cols_offset=16.0,
    legend_args=dict(lstm_legend_args),
)
f = plot_runtime_results(
    ax=ax_right,
    data_df=head_dim_lstm_fwbw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_lstm_fwbw,
    yticks=[0, 5, 10, 15, 20, 25],
    slow_cols=["lstm--vanilla_fwbw++fwbw"],
    slow_cols_offset=30.0,
    legend_args=dict(lstm_legend_args),
)
savefig(f, savedir=save_path, name="head_dim--lstm")
f

## sLSTM

In [None]:
# Forward-pass view for sLSTM: keep the grouping columns (names containing "DH"
# or "NH") and every column that starts with "slstm" and ends in one of the
# kernel variants (vanilla_fwbw / van / triton_fused / cuda_fused / cuda)
# followed by "++fw".
# Raw string: "\+" must reach the regex engine as-is; in a normal string literal
# it is an invalid Python escape sequence and triggers a warning.
head_dim_slstm_fw_df = head_dim_df.filter(
    regex=r"DH|NH|^slstm.*(vanilla_fwbw|van|triton_fused|cuda_fused|cuda)\+\+fw$"
)
head_dim_slstm_fw_df

In [None]:
# Standalone forward-pass runtime figure for the sLSTM kernels, grouped by (NH, DH).
# NOTE(review): slow_cols / slow_cols_offset presumably control how the much slower
# vanilla column is drawn relative to the y-range — confirm in plot_results.
f = plot_runtime_results(
    filename="head_dim--slstm--fw",
    data_df=head_dim_slstm_fw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_slstm_fw,
    yticks=[0, 5, 10],
    slow_cols=["slstm--vanilla_fwbw++fw"],
    slow_cols_offset=17.0,
)
f

In [None]:
# Forward+backward view for sLSTM: keep the grouping columns (names containing
# "DH" or "NH") and every column that starts with "slstm" and ends in one of the
# kernel variants followed by "++fwbw".
# Raw string: "\+" must reach the regex engine as-is; in a normal string literal
# it is an invalid Python escape sequence and triggers a warning.
head_dim_slstm_fwbw_df = head_dim_df.filter(
    regex=r"DH|NH|^slstm.*(vanilla_fwbw|triton_fused|cuda_fused|cuda)\+\+fwbw$"
)
head_dim_slstm_fwbw_df

In [None]:
# Standalone forward+backward runtime figure for the sLSTM kernels, grouped by (NH, DH).
# NOTE(review): slow_cols / slow_cols_offset presumably control how the much slower
# vanilla column is drawn relative to the y-range — confirm in plot_results.
f = plot_runtime_results(
    filename="head_dim--slstm--fwbw",
    data_df=head_dim_slstm_fwbw_df,
    group_cols=["NH", "DH"],
    plot_column_order=col_order_slstm_fwbw,
    yticks=[0, 5, 10, 15, 20, 25],
    slow_cols=["slstm--vanilla_fwbw++fwbw"],
    slow_cols_offset=40.0,
)
f

In [None]:
# Combined sLSTM figure: forward runtimes in the left panel, forward+backward in
# the right panel, saved once under the name "head_dim--slstm".
f, (ax_left, ax_right) = plt.subplots(
    1, 2, figsize=FIGSIZE_2COL, gridspec_kw=GRIDSPEC_KWARGS
)

# Legend settings used verbatim for both panels; each call gets its own copy.
slstm_legend_args = {
    "loc": "lower center",
    "ncol": 2,
    "bbox_to_anchor": (0.0, 0.97, 1.0, 0.102),
    "frameon": False,
    "facecolor": "white",
}

# No `filename` is passed to the per-panel calls (matching the combined LSTM
# figure): the name "head_dim--slstm--fw" belongs to the standalone forward
# figure and must not be reused for this two-panel figure.
# NOTE(review): f is rebound to whatever plot_runtime_results returns; savefig
# below relies on that being the figure owning ax_left/ax_right — confirm.
f = plot_runtime_results(
    data_df=head_dim_slstm_fw_df,
    slow_cols=["slstm--vanilla_fwbw++fw"],
    slow_cols_offset=17.0,
    group_cols=["NH", "DH"],
    yticks=[0, 5, 10, 15],
    plot_column_order=col_order_slstm_fw,
    legend_args=dict(slstm_legend_args),
    ax=ax_left,
)
f = plot_runtime_results(
    data_df=head_dim_slstm_fwbw_df,
    slow_cols=["slstm--vanilla_fwbw++fwbw"],
    slow_cols_offset=40.0,
    group_cols=["NH", "DH"],
    yticks=[0, 5, 10, 15, 20, 25],
    plot_column_order=col_order_slstm_fwbw,
    legend_args=dict(slstm_legend_args),
    ax=ax_right,
)
savefig(f, savedir=save_path, name="head_dim--slstm")
f