In [1]:
import sys
import torch
import sys

sys.path.append("..")
from flashrnn.config import FlashRNNConfig
from flashrnn.flashrnn import _zero_state,flashrnn
import numpy as np



In [2]:
def torch_dtype_to_str(dtype: torch.dtype):
    if dtype == torch.float:
        return "float32"
    else:
        return str(dtype)[6:]

def create_inputs(
    batch_size: int,
    sequence_size: int,
    num_heads: int,
    head_dim: int,
    function: str,
    create_states: bool = True,
    dtype: torch.dtype = torch.float16,
    device="cuda",
    **kwargs,
):
    cfg = FlashRNNConfig(
        batch_size=batch_size,
        num_heads=num_heads,
        function=function,
        head_dim=head_dim,
        dtype=torch_dtype_to_str(dtype),
    )

    num_gates_w = cfg.num_gates_w
    num_gates_r = cfg.num_gates_r
    num_gates_t = cfg.num_gates_t

    Wx = torch.randn(
        [batch_size, sequence_size, num_gates_w, num_heads, head_dim],
        device=device,
        dtype=dtype,
    )
    R = torch.randn(
        [num_gates_r, num_heads, head_dim, head_dim],
        device=device,
        dtype=dtype,
    ) / head_dim ** (0.5)
    b = torch.randn(
        [num_gates_t, num_heads, head_dim],
        device=device,
        dtype=dtype,
    )
    states = _zero_state(cfg, Wx)
    assert states.dtype == dtype

    if create_states:
        return Wx, states, R, b
    else:
        return Wx, R, b


In [16]:
torch.manual_seed(0)
B = 1
S = 1
# NH = 4
# DH = 64
NH = 1
DH = 768

In [17]:
inputs_fp64 = create_inputs(
    batch_size=B,
    sequence_size=S,
    num_heads=NH,
    head_dim=DH,
    function="lstm",
    dtype=torch.float64,
    create_states=False,
)

In [18]:
dtype_target = torch.bfloat16
# inputs_dtype =  create_inputs(
#     batch_size=B,
#     sequence_size=S,
#     num_heads=NH,
#     head_dim=DH,
#     function="lstm",
#     dtype=dtype_target,
#     create_states=False,
# )
inputs_dtype = (x.clone().to(dtype_target) for x in inputs_fp64)

In [19]:
res_fp64 = flashrnn(*inputs_fp64, function="lstm", backend="vanilla")

h:  torch.Size([1, 1, 2, 1, 768])
last_h:  torch.Size([1, 1, 2, 1, 768])


In [20]:
res_dtype = flashrnn(*inputs_dtype, function="lstm", backend="cuda_fused")

states,in multilayer :  torch.Size([2, 1, 1, 1, 768])
states input[:, 0],in multilayer :  torch.Size([2, 1, 1, 768])
Wx.shape:  torch.Size([1, 1, 1, 768, 4])
R.shape:  torch.Size([1, 768, 4, 768])
b.shape:  torch.Size([1, 768, 4])
state.shape:  torch.Size([2, 1, 1, 1, 768])
h:  torch.Size([2, 1, 1, 1, 768])
last_h:  torch.Size([2, 1, 1, 1, 768])


In [21]:
baseline_np = res_fp64[0].cpu()
target_np = res_dtype[0].to(dtype=torch.float64).cpu()

In [22]:
baseline_np.shape

torch.Size([2, 1, 1, 1, 768])

In [23]:
target_np.shape

torch.Size([2, 1, 1, 1, 768])

In [24]:
target_np[0].reshape(B, S, -1).shape

torch.Size([1, 1, 768])

In [25]:
def compute_errors_c_h(baseline, target, sequence_length, batch_size):
    bl_h = baseline[0].reshape(batch_size, sequence_length, -1)
    bl_c = baseline[1].reshape(batch_size, sequence_length, -1)

    tg_h = target[0].reshape(batch_size, sequence_length, -1)
    tg_c = target[1].reshape(batch_size, sequence_length, -1)

    c_err = np.abs(bl_c - tg_c)
    h_err = np.abs(bl_h - tg_h)
    return c_err, h_err
def max_abs_diff(a, b):
    return (a - b).abs().max().item()


def mean_abs_diff(a, b):
    return (a - b).abs().mean().item()


def max_ref_diff(a, b, eps=1e-8):
    return ((a - b).abs() / (b.abs() + eps)).max().item()


def mean_ref_diff(a, b, eps=1e-8):
    return ((a - b).abs() / (b.abs() + eps)).mean().item()

In [26]:
c_err, h_err = compute_errors_c_h(baseline_np, target_np, S, B)
bl_h = baseline_np[0].reshape(B, S, -1)
bl_c = baseline_np[1].reshape(B, S, -1)

tg_h = target_np[0].reshape(B, S, -1)
tg_c = target_np[1].reshape(B, S, -1)
print(c_err)
print(h_err)

tensor([[[2.3620e-04, 6.7739e-04, 4.9857e-05, 4.1151e-03, 2.2125e-03,
          4.0267e-03, 1.3897e-04, 1.2696e-03, 6.7078e-04, 3.4403e-03,
          2.5916e-03, 6.2382e-04, 1.2335e-03, 1.6246e-04, 1.2681e-04,
          2.6080e-03, 1.0106e-03, 4.2873e-04, 1.1488e-03, 3.2763e-04,
          4.7295e-03, 3.1081e-03, 8.1583e-04, 4.3883e-04, 3.7182e-03,
          4.1386e-04, 2.1239e-04, 2.3438e-04, 6.2551e-04, 1.9348e-03,
          1.4134e-04, 3.8097e-04, 4.8781e-03, 7.5305e-04, 2.5220e-03,
          2.0614e-05, 9.0908e-04, 1.0386e-04, 1.9265e-03, 2.0101e-03,
          1.1603e-03, 4.5543e-03, 3.2066e-03, 3.0658e-03, 1.4219e-05,
          2.2696e-03, 1.5822e-03, 5.5062e-04, 2.7431e-03, 9.1109e-04,
          2.7916e-03, 5.8505e-03, 3.7486e-03, 7.1793e-03, 3.6239e-03,
          5.7160e-04, 1.4090e-03, 1.3744e-03, 7.7148e-04, 9.4602e-04,
          7.2507e-04, 4.7415e-04, 3.0770e-04, 1.2552e-03, 9.8809e-04,
          4.9032e-03, 5.7575e-04, 1.4807e-03, 5.5181e-04, 2.5544e-03,
          2.0632e-03

  c_err = np.abs(bl_c - tg_c)
  h_err = np.abs(bl_h - tg_h)


In [27]:
print(f"[Forward] hn     max abs diff: {max_abs_diff(tg_h, bl_h):.2e}")
print(f"[Forward] hn mean abs diff: {mean_abs_diff(tg_h, bl_h):.2e}")
print(f"[Forward] hn max ref diff: {max_ref_diff(tg_h, bl_h):.2e}")
print(f"[Forward] hn mean ref diff: {mean_ref_diff(tg_h, bl_h):.2e}")

print(f"[Forward] cn     max abs diff: {max_abs_diff(tg_c, bl_c):.2e}")
print(f"[Forward] cn     mean abs diff: {mean_abs_diff(tg_c, bl_c):.2e}")
print(f"[Forward] cn     max ref diff: {max_ref_diff(tg_c, bl_c):.2e}")
print(f"[Forward] cn     mean ref diff: {mean_ref_diff(tg_c, bl_c):.2e}")

[Forward] hn     max abs diff: 7.61e-03
[Forward] hn mean abs diff: 1.04e-03
[Forward] hn max ref diff: 8.58e-01
[Forward] hn mean ref diff: 1.76e-02
[Forward] cn     max abs diff: 7.18e-03
[Forward] cn     mean abs diff: 1.55e-03
[Forward] cn     max ref diff: 6.81e-01
[Forward] cn     mean ref diff: 1.14e-02


In [15]:
from plot.diff_lineplot import plot_error_statistics_over_time_single, plot_error_statistics_over_time_per_batchhead
from flashrnn.speed_experiments.plot_config import (
        FONTSIZE,
        FONTSIZE_SMALL,
        FONTSIZE_TICKS,
        FIGSIZE,
        style_dict,
        save_path,
    )
import matplotlib as mpl

ModuleNotFoundError: No module named 'flashrnn.speed_experiments'

In [None]:
with mpl.rc_context(
    rc={
        "text.usetex": True,
        "font.size": FONTSIZE,
        "axes.labelsize": FONTSIZE,
        "legend.fontsize": FONTSIZE_SMALL,
        "xtick.labelsize": FONTSIZE_TICKS,
        "ytick.labelsize": FONTSIZE_TICKS,
        "axes.titlesize": FONTSIZE,
        "lines.markersize": 4.0,  # * default: 6.0
    }
):
    fig = plot_error_statistics_over_time_per_batchhead(
        errors=h_err,
        percentiles=[50, 90, 100],
        title="LSTM Hidden State Error",
        add_mean=True,
        ema_alpha=0.7,
    )


In [None]:
fig

In [None]:
fig[0].savefig("./lstm_hidden_state_error.pdf", bbox_inches="tight")

In [None]:
# with mpl.rc_context(
#     rc={
#         "text.usetex": True,
#         "font.size": FONTSIZE,
#         "axes.labelsize": FONTSIZE,
#         "legend.fontsize": FONTSIZE_SMALL,
#         "xtick.labelsize": FONTSIZE_TICKS,
#         "ytick.labelsize": FONTSIZE_TICKS,
#         "axes.titlesize": FONTSIZE,
#         "lines.markersize": 4.0,  # * default: 6.0
#     }
# ):
#     fig = plot_error_statistics_over_time_per_batchhead(
#         errors=c_err,
#         percentiles=[50, 90, 100],
#         title="LSTM Cell State Error",
#         add_mean=True,
#         ema_alpha=0.5,
#     )