In [None]:
import sys
import torch
from flashrnn import FlashRNNConfig, flashrnn
from flashrnn.flashrnn import _zero_state
import numpy as np

In [None]:
def torch_dtype_to_str(dtype: torch.dtype):
    """Return the name of ``dtype`` without the leading ``"torch."``.

    ``torch.float`` is mapped explicitly to ``"float32"``; every other dtype
    is converted with ``str()`` and the first ``len("torch.")`` characters
    are dropped (e.g. ``torch.bfloat16`` -> ``"bfloat16"``).
    """
    prefix_len = len("torch.")
    return "float32" if dtype == torch.float else str(dtype)[prefix_len:]

def create_inputs(
    batch_size: int,
    sequence_size: int,
    num_heads: int,
    head_dim: int,
    function: str,
    create_states: bool = True,
    dtype: torch.dtype = torch.float16,
    device="cuda",
    **kwargs,
):
    """Draw random tensors for one ``flashrnn`` call.

    A ``FlashRNNConfig`` is built for ``function`` and its gate counts fix the
    tensor shapes:

    * ``Wx``: ``[batch_size, sequence_size, num_gates_w, num_heads, head_dim]``,
      standard normal.
    * ``R``: ``[num_gates_r, num_heads, head_dim, head_dim]``, standard normal
      divided by ``head_dim ** 0.5``.
    * ``b``: ``[num_gates_t, num_heads, head_dim]``, standard normal.

    A states tensor is always obtained from ``_zero_state(config, Wx)`` and
    asserted to have ``dtype``, even when it is not returned.

    Extra keyword arguments (``**kwargs``) are accepted and ignored.

    Returns:
        ``(Wx, states, R, b)`` when ``create_states`` is true, otherwise
        ``(Wx, R, b)``.
    """
    config = FlashRNNConfig(
        batch_size=batch_size,
        num_heads=num_heads,
        function=function,
        head_dim=head_dim,
        dtype=torch_dtype_to_str(dtype),
    )
    gates_w, gates_r, gates_t = (
        config.num_gates_w,
        config.num_gates_r,
        config.num_gates_t,
    )
    tensor_opts = {"device": device, "dtype": dtype}

    # The draw order (Wx, then R, then b) determines the values obtained for a
    # given seed, so it is kept fixed.
    Wx = torch.randn(
        [batch_size, sequence_size, gates_w, num_heads, head_dim], **tensor_opts
    )
    R = torch.randn([gates_r, num_heads, head_dim, head_dim], **tensor_opts) / head_dim ** 0.5
    b = torch.randn([gates_t, num_heads, head_dim], **tensor_opts)

    states = _zero_state(config, Wx)
    assert states.dtype == dtype

    return (Wx, states, R, b) if create_states else (Wx, R, b)


In [None]:
# Fix the global torch seed before any random tensors are drawn.
torch.manual_seed(0)

# Problem size used below: batch size (B), sequence length (S),
# number of heads (NH) and head dimension (DH).
B, S = 1, 512
# Alternative head layout, currently disabled:
# NH = 4
# DH = 64
NH, DH = 1, 768

In [None]:
# float64 reference inputs for the LSTM; with create_states=False the result
# is the 3-tuple (Wx, R, b) without a states tensor.
inputs_fp64 = create_inputs(
    B,
    S,
    NH,
    DH,
    function="lstm",
    dtype=torch.float64,
    create_states=False,
)

In [None]:
# Precision of the run that is compared against the float64 reference.
dtype_target = torch.bfloat16
# Disabled alternative: draw independent random inputs directly in dtype_target.
# inputs_dtype =  create_inputs(
#     batch_size=B,
#     sequence_size=S,
#     num_heads=NH,
#     head_dim=DH,
#     function="lstm",
#     dtype=dtype_target,
#     create_states=False,
# )
# Cast copies of the float64 inputs so both runs start from the same values
# (up to rounding to dtype_target).
# Materialised as a tuple on purpose: a generator expression is exhausted by
# the first `*inputs_dtype` unpacking, so re-running the cell that consumes it
# would silently pass no tensors to flashrnn.
inputs_dtype = tuple(x.clone().to(dtype_target) for x in inputs_fp64)

In [None]:
# Reference run: LSTM on the float64 inputs with the "vanilla" backend.
res_fp64 = flashrnn(*inputs_fp64, function="lstm", backend="vanilla")

In [None]:
# Run under test: the same inputs cast to dtype_target, with the "cuda_fused" backend.
res_dtype = flashrnn(*inputs_dtype, function="lstm", backend="cuda_fused")

In [None]:
# Take element 0 of each flashrnn result and move it to host memory as a NumPy array.
# NOTE(review): element 0 is presumably the per-timestep states; confirm against flashrnn's return value.
baseline_np = res_fp64[0].cpu().numpy()
# The low-precision result is upcast to float64 before the NumPy conversion.
target_np = res_dtype[0].to(dtype=torch.float64).cpu().numpy()

In [None]:
# Sanity check: shape of the reference array.
baseline_np.shape

In [None]:
# Sanity check: shape of the array from the run under test.
target_np.shape

In [None]:
# Shape of entry 0 after reshaping to (B, S, -1), i.e. the same reshape compute_errors_c_h applies.
target_np[0].reshape(B, S, -1).shape

In [None]:
def compute_errors_c_h(baseline, target, sequence_length, batch_size):
    """Element-wise absolute error between two stacked state arrays.

    Entry 0 along the first axis of ``baseline`` and ``target`` is treated as
    the hidden state ``h`` and entry 1 as the cell state ``c``. Each entry is
    reshaped to ``(batch_size, sequence_length, -1)`` before the difference
    is taken.
    NOTE(review): the reshape assumes batch-major memory order; confirm
    against flashrnn's output layout before using ``batch_size > 1``.

    Returns:
        ``(c_err, h_err)`` -- note that the cell-state error comes first.
    """
    view = (batch_size, sequence_length, -1)

    # Evaluated in the order baseline-h, baseline-c, target-h, target-c.
    bl_h, bl_c, tg_h, tg_c = (
        states[idx].reshape(view) for states in (baseline, target) for idx in (0, 1)
    )

    return np.abs(bl_c - tg_c), np.abs(bl_h - tg_h)
    

In [None]:
# Absolute cell-state (c) and hidden-state (h) errors between the reference run and the run under test.
c_err, h_err = compute_errors_c_h(baseline_np, target_np, S, B)

In [None]:
from plot.diff_lineplot import plot_error_statistics_over_time_single, plot_error_statistics_over_time_per_batchhead
from flashrnn.speed_experiments.plot_config import (
        FONTSIZE,
        FONTSIZE_SMALL,
        FONTSIZE_TICKS,
        FIGSIZE,
        style_dict,
        save_path,
    )
import matplotlib as mpl

In [None]:
# Plot the hidden-state error statistics. The rc overrides apply only inside
# this `with` block, so global matplotlib settings are left untouched.
with mpl.rc_context(
    rc={
        # Text is rendered through LaTeX, which needs a working LaTeX install.
        "text.usetex": True,
        # Font sizes come from the shared plot_config constants.
        "font.size": FONTSIZE,
        "axes.titlesize": FONTSIZE,
        "axes.labelsize": FONTSIZE,
        "xtick.labelsize": FONTSIZE_TICKS,
        "ytick.labelsize": FONTSIZE_TICKS,
        "legend.fontsize": FONTSIZE_SMALL,
        "lines.markersize": 4.0,  # * default: 6.0
    }
):
    fig = plot_error_statistics_over_time_per_batchhead(
        title="LSTM Hidden State Error",
        errors=h_err,
        percentiles=[50, 90, 100],
        ema_alpha=0.7,
        add_mean=True,
    )


In [None]:
# Display the object returned by the plotting helper.
fig

In [None]:
# `fig` is indexed, so the plotting helper returns a sequence; element 0 is
# presumably the matplotlib Figure (verify against plot.diff_lineplot).
# Saved as a PDF in the current working directory with a tight bounding box.
fig[0].savefig("./lstm_hidden_state_error.pdf", bbox_inches="tight")

In [None]:
# with mpl.rc_context(
#     rc={
#         "text.usetex": True,
#         "font.size": FONTSIZE,
#         "axes.labelsize": FONTSIZE,
#         "legend.fontsize": FONTSIZE_SMALL,
#         "xtick.labelsize": FONTSIZE_TICKS,
#         "ytick.labelsize": FONTSIZE_TICKS,
#         "axes.titlesize": FONTSIZE,
#         "lines.markersize": 4.0,  # * default: 6.0
#     }
# ):
#     fig = plot_error_statistics_over_time_per_batchhead(
#         errors=c_err,
#         percentiles=[50, 90, 100],
#         title="LSTM Cell State Error",
#         add_mean=True,
#         ema_alpha=0.5,
#     )