In [None]:
import os

# These environment variables are set before `import torch` so they are already
# in place when the CUDA / Triton runtimes are initialised.
# "0" keeps CUDA kernel launches asynchronous; switch to "1" when debugging kernel errors.
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
# Ask Triton to print the outcome of kernel autotuning.
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"
import torch

from einops import rearrange
from tqdm import tqdm

# Wide lines and a high summarisation threshold: tensors with up to 100000
# elements are printed in full instead of being elided.
torch.set_printoptions(linewidth=300, threshold=100000)

In [None]:
%load_ext autoreload
%autoreload 2
import sys

# Put the directory two levels up on the import path so the `flashrnn.*` imports
# in this cell resolve (presumably the repository checkout root — confirm).
sys.path.append("../..")
from flashrnn.flashrnn import flashrnn

from flashrnn.flashrnn.vanilla_fwbw.fw import forward_sequence, lstm_pointwise_fw
from flashrnn.flashrnn.vanilla_fwbw.fwbw import lstm_pt_fwbw
from flashrnn.flashrnn.triton_fused.fwbw import lstm_tr_fwbw

# Match LSTM triton kernel to torch version

In [None]:
# Problem configuration for the numerical-correctness checks.
device = "cuda"
# Creation dtype and the dtype the per-implementation copies are cast to.
dtype = TGT_DTYPE = torch.float32
B, T = 3, 23  # batch size, sequence length
NG = 4  # number of gates (NGI == NGR)
NH = 5  # number of heads
D = 32  # input/hidden (embedding) dimension
NS = 2  # number of states (c, h)

In [None]:
# Fix the global RNG seed so the random test inputs are reproducible.
torch.manual_seed(0)
# Disabled alternative inputs: constant values per gate (plus a few outliers),
# which make kernel output easier to inspect by eye. Note that the next cell
# assigns random tensors to the same names, so it must be skipped if these are
# re-enabled.
# Wx = torch.zeros([B, T, NG, NH, D], device=device, dtype=dtype)
# Wx[:, :, 0, :, :] = 1.0 # input gate
# Wx[:, :, 1, :, :] = 2.0 # forget gate
# Wx[:, :, 2, :, :] = 3.0 # cell gate
# # Wx[1, 2, 2, :, 5] = 500.
# Wx[:, :, 3, :, :] = 4.0 # output gate
# R = torch.zeros([NG, NH, D, D], device=device, dtype=dtype)
# R[0, :, :, :] = 1.0 # input gate
# R[1, :, :, :] = 2.0 # forget gate
# R[2, :, :, :] = 3.0 # cell gate
# R[2, :, 1, 1] = 1.11
# R[3, :, :, :] = 4.0 # output gate
# b = torch.zeros([NG, NH, D], device=device, dtype=dtype)
# b[0, :, :] = 1.0
# b[1, :, :] = 2.0
# b[2, :, :] = 3.0
# b[3, :, :] = 4.0
# states_initial = torch.zeros([NS, B, NH, D], device=device, dtype=dtype)
# states_initial[0, :, :, :] = 1.0
# states_initial[0, 0, :, 1] = 50.0
# states_initial[1, :, :, :] = 2.0

In [None]:
# Shared random test inputs. The randn calls stay in this order so the values
# drawn from the seeded generator are unchanged.
Wx = torch.randn((B, T, NG, NH, D), dtype=dtype, device=device)
# Recurrent weights, divided by sqrt(D).
R = torch.randn((NG, NH, D, D), dtype=dtype, device=device) / (D**0.5)
b = torch.randn((NG, NH, D), dtype=dtype, device=device)
# All-zero initial states.
states_initial = torch.zeros((NS, B, NH, D), dtype=dtype, device=device)

## [Direct Function Call] Check for numerical correctness

### torch autograd

In [None]:
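	# • Wx.clone(): creates a copy of Wx so that later operations do not affect the original tensor.
	# • .to(TGT_DTYPE): converts the tensor to the target dtype (e.g. torch.float32 or torch.float16).
	# • .detach(): detaches the tensor from the current computation graph so its history is not tracked.
	# • .requires_grad_(True): enables gradient computation for the tensor.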

# Independent leaf copies of the shared inputs for the autograd reference run.
Wx_mpt_ag = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_mpt_ag = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_mpt_ag = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_mpt_ag = states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)

In [None]:
# Run the LSTM forward pass with flashrnn's pointwise PyTorch implementation;
# gradients for this run come from plain autograd.
# - h_mpt_ag: states for the whole sequence. NOTE(review): the original note gave
#   the shape as (sequence_length, batch_size, hidden_dim), but later cells unbind
#   dim 1 into two tensors and compare against a "t ns b nh d" layout — confirm.
# - hlast_mpt_ag: the last state(s) of the sequence; the original note gave
#   (batch_size, hidden_dim) — TODO confirm against the shapes printed below.
h_mpt_ag, hlast_mpt_ag = forward_sequence(
    states_initial=states_initial_mpt_ag,
    Wx=Wx_mpt_ag,
    R=R_mpt_ag,
    b=b_mpt_ag,
    forward_pointwise=lstm_pointwise_fw,
    output_gates_and_states_initial=False,
)
h_mpt_ag.shape, hlast_mpt_ag.shape  # , gates_mpt_ag.shape

In [None]:
# Split dim 1 of the autograd output into its two per-state tensors and show their shapes.
hst_mpt_ag, cst_mpt_ag = torch.unbind(h_mpt_ag, dim=1)
(hst_mpt_ag.shape, cst_mpt_ag.shape)

In [None]:
# Backpropagate a scalar (sum over hst_mpt_ag) to populate .grad on the *_mpt_ag leaves.
torch.sum(hst_mpt_ag).backward()

In [None]:
# R_mpt_ag.grad

### torch obw

In [None]:
# Independent leaf copies of the shared inputs for the lstm_pt_fwbw run.
Wx_mpt_obw = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_mpt_obw = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_mpt_obw = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_mpt_obw = states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)

In [None]:
# The LSTM implemented in PyTorch? (open question left by the author)
# NOTE(review): lstm_pt_fwbw is imported from vanilla_fwbw.fwbw; presumably it
# pairs the forward pass with a hand-written backward — confirm.
h_mpt_obw, hlast_mpt_obw = lstm_pt_fwbw(
    states_initial=states_initial_mpt_obw,
    Wx=Wx_mpt_obw,
    R=R_mpt_obw,
    b=b_mpt_obw,
    autocast_kernel_dtype="float32",
)

In [None]:
# Split dim 1 of the lstm_pt_fwbw output into its two per-state tensors and show their shapes.
hst_mpt_obw, cst_mpt_obw = torch.unbind(h_mpt_obw, dim=1)
(hst_mpt_obw.shape, cst_mpt_obw.shape)

In [None]:
# Same scalar loss as for the autograd run, so the resulting gradients are comparable.
torch.sum(hst_mpt_obw).backward()

In [None]:
# Largest absolute forward deviation between the autograd and lstm_pt_fwbw runs.
torch.abs(hst_mpt_ag - hst_mpt_obw).max()

In [None]:
# Largest absolute deviation of the Wx gradients (autograd vs. lstm_pt_fwbw).
torch.abs(Wx_mpt_ag.grad - Wx_mpt_obw.grad).max()

In [None]:
# Largest absolute deviation of the R gradients (autograd vs. lstm_pt_fwbw).
torch.abs(R_mpt_ag.grad - R_mpt_obw.grad).max()

In [None]:
# Largest absolute deviation of the bias gradients (autograd vs. lstm_pt_fwbw).
torch.abs(b_mpt_ag.grad - b_mpt_obw.grad).max()

In [None]:
# Largest absolute deviation of the initial-state gradients (autograd vs. lstm_pt_fwbw).
torch.abs(states_initial_mpt_ag.grad - states_initial_mpt_obw.grad).max()

In [None]:
# Dump the leaf tensors used by the autograd run (printed in full, see set_printoptions).
Wx_mpt_ag, R_mpt_ag, b_mpt_ag, states_initial_mpt_ag

In [None]:
# Side-by-side dump of the R gradients: autograd run, then lstm_pt_fwbw run.
R_mpt_ag.grad, R_mpt_obw.grad

In [None]:
# Side-by-side dump of the Wx gradients: autograd run, then lstm_pt_fwbw run.
Wx_mpt_ag.grad, Wx_mpt_obw.grad

In [None]:
# Side-by-side dump of the bias gradients: autograd run, then lstm_pt_fwbw run.
b_mpt_ag.grad, b_mpt_obw.grad

In [None]:
# Initial-state gradients of both runs, followed by the gradient's shape.
(
    states_initial_mpt_ag.grad,
    states_initial_mpt_obw.grad,
    states_initial_mpt_ag.grad.shape,
)

### triton impl

In [None]:
# Independent leaf copies of the shared inputs for the Triton run.
Wx_mtr = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_mtr = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_mtr = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_mtr = (
    states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)
)

In [None]:
# The LSTM implemented in Triton (direct call of the fused forward/backward function).
h_mtr, hlast_mtr = lstm_tr_fwbw(
    states_initial=states_initial_mtr,
    Wx=Wx_mtr,
    R=R_mtr,
    b=b_mtr,
    autocast_kernel_dtype="float32",
)
hlast_mtr.shape, h_mtr.shape

In [None]:
# Largest absolute deviation of the Triton output from the autograd reference, plus both shapes.
(torch.abs(h_mtr - h_mpt_ag).max(), h_mtr.shape, h_mpt_ag.shape)

In [None]:
# Shapes of the last-state outputs: Triton run, then autograd reference.
hlast_mtr.shape, hlast_mpt_ag.shape

In [None]:
# Largest absolute deviation of the last-state output (Triton vs. autograd reference).
torch.abs(hlast_mtr - hlast_mpt_ag).max()

In [None]:
# Split dim 1 of the Triton output into its two per-state tensors and show their shapes.
hst_mtr, cst_mtr = torch.unbind(h_mtr, dim=1)
(hst_mtr.shape, cst_mtr.shape)

In [None]:
# Same scalar loss as for the reference runs; populates .grad on the *_mtr leaves.
torch.sum(hst_mtr).backward()

In [None]:
# Largest absolute deviation of the Wx gradients (Triton vs. autograd reference).
torch.abs(Wx_mtr.grad - Wx_mpt_ag.grad).max()

In [None]:
# Pairwise max absolute deviation of the R gradients:
# Triton vs. autograd, Triton vs. lstm_pt_fwbw, autograd vs. lstm_pt_fwbw; then the gradient shape.
(
    torch.abs(R_mtr.grad - R_mpt_ag.grad).max(),
    torch.abs(R_mtr.grad - R_mpt_obw.grad).max(),
    torch.abs(R_mpt_ag.grad - R_mpt_obw.grad).max(),
    R_mtr.grad.shape,
)

In [None]:
# Element-wise R-gradient difference (Triton minus autograd) for index [0, 0] of the first two dims.
torch.sub(R_mtr.grad, R_mpt_ag.grad)[0, 0]

In [None]:
# Pairwise max absolute deviation of the bias gradients:
# Triton vs. autograd, Triton vs. lstm_pt_fwbw, lstm_pt_fwbw vs. autograd.
(
    torch.abs(b_mtr.grad - b_mpt_ag.grad).max(),
    torch.abs(b_mtr.grad - b_mpt_obw.grad).max(),
    torch.abs(b_mpt_obw.grad - b_mpt_ag.grad).max(),
)

In [None]:
# Largest absolute deviation of the initial-state gradients (Triton vs. autograd reference).
torch.abs(states_initial_mtr.grad - states_initial_mpt_ag.grad).max()

In [None]:
# states_initial_mtr.grad[0], states_initial_mpt_ag.grad[0]

In [None]:
# Shape of the initial-state gradient produced by the Triton run.
states_initial_mtr.grad.shape

## [flashrnn integration] Integrate LSTM torch_fwbw + triton fused into flashrnn

In [None]:
# Independent leaf copies of the shared inputs for the flashrnn (vanilla_fwbw backend) run.
Wx_frnn = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_frnn = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_frnn = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_frnn = (
    states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)
)

In [None]:
# Run the same LSTM through the high-level `flashrnn` entry point with the
# "vanilla_fwbw" backend.
# NOTE: `states=None` is passed, so the `states_initial_frnn` copy is not used by this call.
h_frnn, hlast_frnn = flashrnn(
    Wx=Wx_frnn,
    R=R_frnn,
    b=b_frnn,
    states=None,  # states_initial_frnn,
    function="lstm",
    backend="vanilla_fwbw",
    dtype="float32",
)
h_frnn.shape, h_mpt_ag.shape

In [None]:
# Backpropagate the sum of the first entry along dim 0 of the flashrnn output.
torch.sum(h_frnn[0]).backward()

In [None]:
# Reorder the axes from (ns, b, t, nh, d) to (t, ns, b, nh, d) so the result can
# be subtracted element-wise from h_mpt_ag.
h_frnn_sh = rearrange(h_frnn, "ns b t nh d -> t ns b nh d")
h_frnn_sh.shape

In [None]:
# Largest absolute deviation of the flashrnn vanilla_fwbw output from the autograd reference.
torch.abs(h_frnn_sh - h_mpt_ag).max()

In [None]:
# Shapes of the rearranged flashrnn output and of the autograd reference.
h_frnn_sh.shape, h_mpt_ag.shape

In [None]:
# Independent leaf copies of the shared inputs for the flashrnn (triton_fused backend) run.
Wx_frnn_tr = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_frnn_tr = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_frnn_tr = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_frnn_tr = states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)

In [None]:
# Run the same LSTM through the high-level `flashrnn` entry point, this time with
# the "triton_fused" backend.
# NOTE: `states=None` is passed, so the `states_initial_frnn_tr` copy is not used by this call.
h_frnn_tr, hlast_frnn_tr = flashrnn(
    Wx=Wx_frnn_tr,
    R=R_frnn_tr,
    b=b_frnn_tr,
    states=None,  # states_initial_frnn_tr,
    function="lstm",
    backend="triton_fused",
    dtype="float32",
)
# Display the shape of the tensor computed in this cell (the cell used to show
# h_frnn_sh, i.e. the vanilla_fwbw result, by copy-paste mistake).
h_frnn_tr.shape, h_mpt_ag.shape

In [None]:
# Backpropagate the sum of the first entry along dim 0 of the triton_fused output.
torch.sum(h_frnn_tr[0]).backward()

In [None]:
# Reorder the axes from (ns, b, t, nh, d) to (t, ns, b, nh, d) so the result can
# be subtracted element-wise from h_mpt_ag.
h_frnn_tr_sh = rearrange(h_frnn_tr, "ns b t nh d -> t ns b nh d")

In [None]:
# Max and mean absolute deviation of the flashrnn triton_fused output from the autograd reference.
torch.abs(h_frnn_tr_sh - h_mpt_ag).max(), torch.abs(h_frnn_tr_sh - h_mpt_ag).mean()

In [None]:
# Shape of the flashrnn vanilla_fwbw output before rearranging.
h_frnn.shape

## Quick speed check

In [None]:
# Benchmark configuration; re-binds the names used by the correctness checks.
device = "cuda"
dtype = torch.float32  # dtype the raw tensors are created in
TGT_DTYPE = torch.bfloat16  # dtype the benchmark inputs are cast to
B, T = 16, 1024  # batch size, sequence length
NG = 4  # number of gates (NGI == NGR)
NH = 1  # number of heads (the author also noted 4 as an alternative)
D = 64  # input/hidden (embedding) dimension
NS = 2  # number of states (c, h)

# Number of warm-up and measured iterations per benchmark.
WARMUP_ITERS, ITERS = 50, 5000

In [None]:
# Random benchmark inputs; unlike the correctness section, the initial states are random too.
# The order of the randn calls is preserved.
Wx = torch.randn((B, T, NG, NH, D), dtype=dtype, device=device)
# Recurrent weights, divided by sqrt(D).
R = torch.randn((NG, NH, D, D), dtype=dtype, device=device) / (D**0.5)
b = torch.randn((NG, NH, D), dtype=dtype, device=device)
states_initial = torch.randn((NS, B, NH, D), dtype=dtype, device=device)

In [None]:
# Leaf copies (cast to TGT_DTYPE) for the PyTorch autograd baseline.
Wx_mpt_ag = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_mpt_ag = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_mpt_ag = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_mpt_ag = states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)

# Leaf copies (cast to TGT_DTYPE) shared by the Triton and CUDA-fused benchmark bodies.
Wx_mtr = Wx.detach().clone().to(TGT_DTYPE).requires_grad_(True)
R_mtr = R.detach().clone().to(TGT_DTYPE).requires_grad_(True)
b_mtr = b.detach().clone().to(TGT_DTYPE).requires_grad_(True)
states_initial_mtr = (
    states_initial.detach().clone().to(TGT_DTYPE).requires_grad_(True)
)

In [None]:
# pytorch autograd baseline
def lstm_pt_autograd():
    """Benchmark body: pointwise PyTorch LSTM forward followed by an autograd backward.

    Reads the module-level *_mpt_ag leaf tensors and returns nothing; the
    backward pass accumulates into their `.grad`.
    """
    all_states, _last_state = forward_sequence(
        states_initial=states_initial_mpt_ag,
        Wx=Wx_mpt_ag,
        R=R_mpt_ag,
        b=b_mpt_ag,
        forward_pointwise=lstm_pointwise_fw,
        output_gates_and_states_initial=False,
    )
    # Only the first of the two tensors along dim 1 contributes to the loss.
    first_state, _second_state = all_states.unbind(dim=1)
    scalar_loss = first_state.sum()
    scalar_loss.backward()


# triton fused kernel
# Open question from the author: what is the difference between lstm_tr_fwbw and
# flashrnn's "triton_fused" backend?
def lstm_triton(with_backward: bool = False):
    """Benchmark body: Triton LSTM via a direct `lstm_tr_fwbw` call.

    Reads the module-level *_mtr leaf tensors and returns nothing.

    Args:
        with_backward: when True, also backpropagate the sum of the first tensor
            obtained by unbinding dim 1 of the output. Defaults to False
            (forward only), which is what this function measured before, when
            the backward call was commented out. Note that `lstm_cuda_fused`
            and `lstm_pt_fused_cuda` always include a backward pass.

    NOTE(review): the kernel is asked for autocast_kernel_dtype="float32" while
    the *_mtr inputs are cast to TGT_DTYPE — confirm this is intended.
    """
    h_mtr, _ = lstm_tr_fwbw(
        states_initial=states_initial_mtr,
        Wx=Wx_mtr,
        R=R_mtr,
        b=b_mtr,
        autocast_kernel_dtype="float32",
    )
    if with_backward:
        hst_mtr, _ = h_mtr.unbind(dim=1)
        hst_mtr.sum().backward()


def lstm_triton_frnn(with_backward: bool = False):
    """Benchmark body: Triton LSTM via the `flashrnn` entry point ("triton_fused").

    Reads the module-level *_mtr leaf tensors and returns nothing. No explicit
    initial states are passed (`states=None`).

    Args:
        with_backward: when True, also backpropagate the sum of the first entry
            along dim 0 of the output. Defaults to False (forward only), which
            is what this function measured before, when the backward call was
            commented out.
    """
    h_mtr, _ = flashrnn(
        Wx=Wx_mtr,
        R=R_mtr,
        b=b_mtr,
        states=None,  # states_initial_mtr,
        function="lstm",
        backend="triton_fused",
        dtype="bfloat16",
    )
    if with_backward:
        h_mtr[0].sum().backward()


# cuda fused kernel
def lstm_cuda_fused():
    """Benchmark body: forward + backward through flashrnn's "cuda_fused" LSTM backend.

    Reads the module-level *_mtr leaf tensors (no explicit initial states are
    passed) and returns nothing.
    """
    result = flashrnn(
        Wx=Wx_mtr,
        R=R_mtr,
        b=b_mtr,
        function="lstm",
        dtype="bfloat16",
        backend="cuda_fused",
    )
    # Loss: sum over the first entry of the first returned element.
    loss_source = result[0][0]
    loss_source.sum().backward()


# PyTorch's built-in single-layer LSTM as an additional baseline. It has its own
# randomly initialised parameters (not R / b) and runs in `dtype`, not TGT_DTYPE.
torch_lstm = torch.nn.LSTM(
    input_size=D,
    hidden_size=D,
    num_layers=1,
    bias=True,
    batch_first=False,
    bidirectional=False,
)
# Parameters are created first and only then moved/cast, as in the original order.
torch_lstm = torch_lstm.to(device=device, dtype=dtype)
# Time-major random input that tracks gradients.
pt_in = torch.randn((T, B, D), device=device, dtype=dtype, requires_grad=True)
print(torch_lstm)
print(pt_in.shape)


def lstm_pt_fused_cuda():
    """Benchmark body: forward + backward through the `torch.nn.LSTM` baseline.

    Uses the module-level `torch_lstm` and `pt_in`; returns nothing.
    """
    result = torch_lstm(pt_in)
    # First element of the returned tuple is the per-step output sequence.
    output_seq = result[0]
    output_seq.sum().backward()

In [None]:
def run_benchmark(fn, label):
    """Run `fn` for WARMUP_ITERS warm-up calls, then ITERS measured calls.

    Progress and throughput are reported by tqdm under "Warmup - <label>" and
    "Main - <label>".

    Args:
        fn: zero-argument benchmark body to call repeatedly.
        label: human-readable name shown in the progress bars.
    """
    for _ in tqdm(range(WARMUP_ITERS), desc=f"Warmup - {label}"):
        fn()
    # CUDA kernel launches are asynchronous: drain the queue so leftover
    # warm-up work is not attributed to the measured loop.
    torch.cuda.synchronize()
    for _ in tqdm(range(ITERS), desc=f"Main - {label}"):
        fn()
    # Drain again so this benchmark's tail does not leak into the next one.
    torch.cuda.synchronize()


run_benchmark(lstm_triton, "Triton")
run_benchmark(lstm_triton_frnn, "Triton (flashrnn)")

# PyTorch autograd baseline (disabled):
# run_benchmark(lstm_pt_autograd, "Torch")

run_benchmark(lstm_cuda_fused, "CUDA fused")
run_benchmark(lstm_pt_fused_cuda, "Torch CUDA fused")

In [None]:
# Progress-bar output pasted from an earlier run, kept for reference.
# NOTE(review): it shows 1000 iterations, so it predates the current ITERS setting.
# Main - Triton:   2%|▏         | 22/1000 [00:00<00:04, 213.30it/s]
# Main - Triton: 100%|██████████| 1000/1000 [00:14<00:00, 68.19it/s]