In [None]:
import os

# These switches are set before `import torch` so they are already present in the
# environment when the CUDA / Triton stacks get initialised.
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"  # "0" leaves CUDA kernel launches asynchronous
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  # presumably makes Triton print autotuning results — confirm against the Triton docs
import torch

from einops import rearrange
from tqdm import tqdm

# Wide lines and a large element threshold so tensors displayed below are not summarised.
torch.set_printoptions(linewidth=300, threshold=100000)

In [None]:
# Reload imported modules automatically so edits to the project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys

# Put the directory two levels above the working directory on the import path so the
# `flashrnn` imports below resolve (presumably the repository root — depends on where
# the kernel was started; TODO confirm).
sys.path.append("../..")
from flashrnn.flashrnn import flashrnn

from flashrnn.flashrnn.vanilla_fwbw.fw import forward_sequence, slstm_pointwise_fw
from flashrnn.flashrnn.vanilla_fwbw.fwbw import slstm_pt_fwbw
from flashrnn.flashrnn.triton_fused.fwbw import slstm_tr_fwbw

# Match sLSTM triton kernel to torch version

In [None]:
# Problem size for the numerical comparison (kept small on purpose).
device = "cuda"
dtype = torch.float32  # dtype the reference tensors are created in
TGT_DTYPE = torch.bfloat16  # dtype every per-implementation copy is cast to

B, T = 16, 23  # batch size, sequence length
NG, NS = 4, 4  # gates (NGI == NGR), states (unbound further down as h, c, n, m)
NH, D = 5, 32  # number of heads, input/hidden (embedding) dimension

In [None]:
# Fix the global RNG seed so the random tensors created in the following cells are
# reproducible from run to run.
torch.manual_seed(0)
# Wx = torch.zeros([B, T, NG, NH, D], device=device, dtype=dtype)
# Wx[:, :, 0, :, :] = 1.0 # input gate
# Wx[:, :, 1, :, :] = 2.0 # forget gate
# Wx[:, :, 2, :, :] = 3.0 # cell gate
# # Wx[1, 2, 2, :, 5] = 500.
# Wx[:, :, 3, :, :] = 4.0 # output gate
# R = torch.zeros([NG, NH, D, D], device=device, dtype=dtype)
# # R[0, :, :, :] = 1.0 # input gate
# # R[1, :, :, :] = 2.0 # forget gate
# # R[2, :, :, :] = 3.0 # cell gate
# # R[2, :, 1, 1] = 1.11
# # R[3, :, :, :] = 4.0 # output gate
# b = torch.zeros([NG, NH, D], device=device, dtype=dtype)
# # b[0, :, :] = 1.0
# # b[1, :, :] = 2.0
# # b[2, :, :] = 3.0
# # b[3, :, :] = 4.0
# states_initial = torch.zeros([NS, B, NH, D], device=device, dtype=dtype)
# states_initial[0, :, :, :] = 1.0
# states_initial[0, 0, :, 1] = 50.0
# states_initial[1, :, :, :] = 2.0

In [None]:
# Random test tensors; R is divided by sqrt(D).
Wx = torch.randn(B, T, NG, NH, D, device=device, dtype=dtype)
R = torch.randn(NG, NH, D, D, device=device, dtype=dtype) / D**0.5
b = torch.randn(NG, NH, D, device=device, dtype=dtype)
states_initial = torch.zeros(NS, B, NH, D, device=device, dtype=dtype)
#! Note all states but the h state must be zero otherwise numerics do not match
# (explicitly zero state indices 1..3, even though the tensor already starts out as zeros)
states_initial[[1, 2, 3]] = 0.0

## [Direct Function Call] Check for numerical correctness

### torch autograd

In [None]:
# Independent TGT_DTYPE leaf copies of the inputs for the autograd reference run,
# so gradients are collected separately from the other implementations.
Wx_mpt_ag, R_mpt_ag, b_mpt_ag, states_initial_mpt_ag = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

In [None]:
# Reference forward pass: step through the sequence with the pointwise sLSTM function.
h_mpt_ag, hlast_mpt_ag = forward_sequence(
    Wx=Wx_mpt_ag,
    R=R_mpt_ag,
    b=b_mpt_ag,
    states_initial=states_initial_mpt_ag,
    forward_pointwise=slstm_pointwise_fw,
    output_gates_and_states_initial=False,
)
tuple(out.shape for out in (h_mpt_ag, hlast_mpt_ag))

In [None]:
# Split the stacked output along dim 1 into its four state sequences (h, c, n, m).
hst_mpt_ag, cst_mpt_ag, nst_mpt_ag, mst_mpt_ag = torch.unbind(h_mpt_ag, dim=1)
tuple(seq.shape for seq in (hst_mpt_ag, cst_mpt_ag))

In [None]:
# Scalar "loss" = sum over the h-state sequence; autograd fills the .grad of the leaf inputs.
torch.sum(hst_mpt_ag).backward()

In [None]:
# R_mpt_ag.grad

### torch obw

In [None]:
# Fresh TGT_DTYPE leaf copies of the inputs for the torch forward/backward (slstm_pt_fwbw) run.
Wx_mpt_obw, R_mpt_obw, b_mpt_obw, states_initial_mpt_obw = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

In [None]:
# Same computation via the torch implementation that ships its own backward pass.
h_mpt_obw, hlast_mpt_obw = slstm_pt_fwbw(
    Wx=Wx_mpt_obw,
    R=R_mpt_obw,
    b=b_mpt_obw,
    states_initial=states_initial_mpt_obw,
    autocast_kernel_dtype="float32",
)

In [None]:
# Split the stacked output along dim 1 into the h, c, n and m state sequences.
hst_mpt_obw, cst_mpt_obw, nst_mpt_obw, mst_mpt_obw = torch.unbind(h_mpt_obw, dim=1)
tuple(seq.shape for seq in (hst_mpt_obw, cst_mpt_obw))

In [None]:
# Largest absolute deviation of the h-state sequence between the two torch variants.
torch.abs(hst_mpt_ag - hst_mpt_obw).max()

In [None]:
# Largest absolute deviation for the remaining states, in the order (c, n, m).
tuple(
    (ag_seq - obw_seq).abs().max()
    for ag_seq, obw_seq in (
        (cst_mpt_ag, cst_mpt_obw),
        (nst_mpt_ag, nst_mpt_obw),
        (mst_mpt_ag, mst_mpt_obw),
    )
)

In [None]:
# Display the complete m-state sequence of the slstm_pt_fwbw run for visual inspection
# (the print options set at the top keep it from being summarised).
mst_mpt_obw

In [None]:
# Same scalar loss as in the autograd run, so the gradients are directly comparable.
torch.sum(hst_mpt_obw).backward()

In [None]:
# Max abs difference of the Wx gradients: autograd vs. custom backward.
torch.abs(Wx_mpt_ag.grad - Wx_mpt_obw.grad).max()

In [None]:
# Max abs difference of the R gradients: autograd vs. custom backward.
torch.abs(R_mpt_ag.grad - R_mpt_obw.grad).max()

In [None]:
# Max abs difference of the bias gradients: autograd vs. custom backward.
torch.abs(b_mpt_ag.grad - b_mpt_obw.grad).max()

In [None]:
# A large value is expected here: the custom backward pass does not propagate
# gradients through the m state, whereas autograd does.
torch.abs(states_initial_mpt_ag.grad - states_initial_mpt_obw.grad).max()

In [None]:
# Wx_mpt_ag, R_mpt_ag, b_mpt_ag, states_initial_mpt_ag

In [None]:
# R_mpt_ag.grad, R_mpt_obw.grad

In [None]:
# Wx_mpt_ag.grad, Wx_mpt_obw.grad

In [None]:
# Bias gradients of both torch variants side by side.
tuple(param.grad for param in (b_mpt_ag, b_mpt_obw))

In [None]:
# Initial-state gradients of both torch variants, followed by their shape.
(
    *(state.grad for state in (states_initial_mpt_ag, states_initial_mpt_obw)),
    states_initial_mpt_ag.grad.shape,
)

### triton impl

In [None]:
# Fresh TGT_DTYPE leaf copies of the inputs for the Triton kernel run.
Wx_mtr, R_mtr, b_mtr, states_initial_mtr = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

In [None]:
# Forward pass through the fused Triton implementation.
h_mtr, hlast_mtr = slstm_tr_fwbw(
    Wx=Wx_mtr,
    R=R_mtr,
    b=b_mtr,
    states_initial=states_initial_mtr,
    autocast_kernel_dtype="float32",
)

In [None]:
# Output shapes of the Triton run next to the autograd reference.
tuple(out.shape for out in (h_mtr, hlast_mtr, h_mpt_ag))

In [None]:
# Shape of the autograd reference output, for comparison with the Triton output.
h_mpt_ag.shape

In [None]:
# h_mtr, h_mpt_ag

In [None]:
# Max abs difference over all states and time steps: Triton vs. torch custom backward run.
torch.abs(h_mtr - h_mpt_obw).max()

In [None]:
# Max abs difference of the second return value (hlast): Triton vs. autograd reference.
torch.abs(hlast_mtr - hlast_mpt_ag).max()

In [None]:
# Split the stacked Triton output along dim 1 into the h, c, n and m state sequences.
hst_mtr, cst_mtr, nst_mtr, mst_mtr = torch.unbind(h_mtr, dim=1)
tuple(seq.shape for seq in (hst_mtr, cst_mtr, nst_mtr, mst_mtr))

In [None]:
# Same scalar loss as for the torch runs; triggers the Triton backward pass.
torch.sum(hst_mtr).backward()

In [None]:
# Wx gradient error of the Triton run, against (custom backward, autograd).
tuple((Wx_mtr.grad - ref.grad).abs().max() for ref in (Wx_mpt_obw, Wx_mpt_ag))

In [None]:
# R gradient error of the Triton run, against (custom backward, autograd).
tuple((R_mtr.grad - ref.grad).abs().max() for ref in (R_mpt_obw, R_mpt_ag))

In [None]:
# Bias gradient error of the Triton run, against (custom backward, autograd).
tuple((b_mtr.grad - ref.grad).abs().max() for ref in (b_mpt_obw, b_mpt_ag))

In [None]:
# Initial-state gradient error of the Triton run. The comparison with the custom
# backward run only covers state indices 0..2 (index 3 is presumably the m state,
# which the custom backward does not differentiate through); the autograd
# comparison covers all states.
(
    torch.abs(states_initial_mtr.grad[:3] - states_initial_mpt_obw.grad[:3]).max(),
    torch.abs(states_initial_mtr.grad - states_initial_mpt_ag.grad).max(),
)

In [None]:
# (b_mtr.grad, b_mpt_obw.grad)

In [None]:
# Display the full initial-state gradient of the autograd reference run.
states_initial_mpt_ag.grad

## [flashrnn integration] Compare the sLSTM `vanilla_fwbw` and `triton_fused` backends through the `flashrnn` entry point

In [None]:
# Fresh TGT_DTYPE leaf copies of the inputs for the flashrnn "vanilla_fwbw" run.
Wx_frnn, R_frnn, b_frnn, states_initial_frnn = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

In [None]:
# sLSTM through the flashrnn entry point with the "vanilla_fwbw" backend.
h_frnn, hlast_frnn = flashrnn(
    function="slstm",
    backend="vanilla_fwbw",
    dtype="bfloat16",
    Wx=Wx_frnn,
    R=R_frnn,
    b=b_frnn,
    states=None,  # states_initial_frnn,
)
tuple(out.shape for out in (h_frnn, h_mpt_ag))

In [None]:
# Backward through the first state (index 0 of the leading axis) of the flashrnn output.
torch.sum(h_frnn[0]).backward()

In [None]:
# Reorder the axes from (ns, b, t, nh, d) to (t, ns, b, nh, d) so the flashrnn output
# can be subtracted element-wise from h_mpt_ag in the next cells.
h_frnn_sh = rearrange(h_frnn, "ns b t nh d -> t ns b nh d")
h_frnn_sh.shape

In [None]:
# (max, mean) absolute deviation of the flashrnn vanilla_fwbw output from the autograd reference.
torch.abs(h_frnn_sh - h_mpt_ag).max(), torch.abs(h_frnn_sh - h_mpt_ag).mean()

In [None]:
# After the rearrange both tensors are expected to show the same shape.
tuple(out.shape for out in (h_frnn_sh, h_mpt_ag))

In [None]:
# Fresh TGT_DTYPE leaf copies of the inputs for the flashrnn "triton_fused" run.
Wx_frnn_tr, R_frnn_tr, b_frnn_tr, states_initial_frnn_tr = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

In [None]:
# sLSTM through the flashrnn entry point with the "triton_fused" backend.
h_frnn_tr, hlast_frnn_tr = flashrnn(
    Wx=Wx_frnn_tr,
    R=R_frnn_tr,
    b=b_frnn_tr,
    states=None,  # states_initial_frnn_tr,
    function="slstm",
    backend="triton_fused",
    dtype="float32",
)
# Show the shape of the tensor this cell just produced (the cell previously displayed
# the stale, already-rearranged vanilla_fwbw result), mirroring the vanilla_fwbw cell.
h_frnn_tr.shape, h_mpt_ag.shape

In [None]:
# Backward through the first state (index 0 of the leading axis) of the triton_fused output.
torch.sum(h_frnn_tr[0]).backward()

In [None]:
# Reorder the axes from (ns, b, t, nh, d) to (t, ns, b, nh, d) so the triton_fused
# output can be subtracted element-wise from h_mpt_ag in the next cell.
h_frnn_tr_sh = rearrange(h_frnn_tr, "ns b t nh d -> t ns b nh d")

In [None]:
# (max, mean) absolute deviation of the flashrnn triton_fused output from the autograd reference.
torch.abs(h_frnn_tr_sh - h_mpt_ag).max(), torch.abs(h_frnn_tr_sh - h_mpt_ag).mean()

In [None]:
# Shape of the vanilla_fwbw flashrnn output.
# NOTE(review): this cell sits among the triton_fused checks; `h_frnn_tr.shape` may
# have been intended — confirm.
h_frnn.shape

In [None]:
# Wx gradient error of the flashrnn triton_fused run, against (autograd, custom backward).
tuple(
    (Wx_frnn_tr.grad - ref.grad).abs().max()
    for ref in (Wx_mpt_ag, Wx_mpt_obw)
)

In [None]:
# Bias gradient error of the flashrnn triton_fused run, against (autograd, custom backward).
tuple(
    (b_frnn_tr.grad - ref.grad).abs().max()
    for ref in (b_mpt_ag, b_mpt_obw)
)

In [None]:
from flashrnn.tests.utils import model_test

In [None]:
# Run the project's own comparison helper on the same problem size: triton_fused is
# checked against vanilla_fwbw, forward and backward, with loose tolerances.
model_test(
    function="slstm",
    backend="triton_fused",
    backend_cmp="vanilla_fwbw",
    dtype=dtype,
    batch_size=B,
    sequence_size=T,
    num_heads=NH,
    head_dim=D,
    include_backward=True,
    tensor_compare_kwargs={"atol": 0.5, "rtol": 1.0},
)

## Quick speed check

In [None]:
# Problem size for the speed check (overrides the values used for the numerical checks).
device = "cuda"
dtype = torch.float32  # dtype the reference tensors are created in
TGT_DTYPE = torch.bfloat16  # dtype every per-implementation copy is cast to

B, T = 16, 1024  # batch size, sequence length
NG, NS = 4, 4  # gates (NGI == NGR), states (h, c, n, m)
NH, D = 1, 64  # number of heads (1 here; 4 was an alternative), input/hidden (embedding) dimension

### benchmark loop lengths
WARMUP_ITERS, ITERS = 50, 1000

In [None]:
# Random benchmark tensors at the larger problem size; R is divided by sqrt(D).
Wx = torch.randn(B, T, NG, NH, D, device=device, dtype=dtype)
R = torch.randn(NG, NH, D, D, device=device, dtype=dtype) / D**0.5
b = torch.randn(NG, NH, D, device=device, dtype=dtype)
states_initial = torch.zeros(NS, B, NH, D, device=device, dtype=dtype)

In [None]:
# TGT_DTYPE leaf copies for the autograd baseline ...
Wx_mpt_ag, R_mpt_ag, b_mpt_ag, states_initial_mpt_ag = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

# ... and a separate set for the Triton / flashrnn benchmark functions.
Wx_mtr, R_mtr, b_mtr, states_initial_mtr = (
    src.clone().to(TGT_DTYPE).detach().requires_grad_(True)
    for src in (Wx, R, b, states_initial)
)

In [None]:
# pytorch autograd baseline
def slstm_pt_autograd():
    """Run one forward + backward pass of the torch sLSTM, differentiated by autograd.

    Reads the module-level ``*_mpt_ag`` tensors; gradients are written to their ``.grad``.
    """
    all_states, _last_states = forward_sequence(
        Wx=Wx_mpt_ag,
        R=R_mpt_ag,
        b=b_mpt_ag,
        states_initial=states_initial_mpt_ag,
        forward_pointwise=slstm_pointwise_fw,
        output_gates_and_states_initial=False,
    )
    # Only the first of the four state sequences feeds the scalar loss.
    h_seq, _, _, _ = all_states.unbind(dim=1)
    h_seq.sum().backward()


# triton fused kernel
def slstm_triton():
    """Run one forward + backward pass of the fused Triton sLSTM (direct function call).

    Reads the module-level ``*_mtr`` tensors; gradients are written to their ``.grad``.
    """
    all_states, _last_states = slstm_tr_fwbw(
        Wx=Wx_mtr,
        R=R_mtr,
        b=b_mtr,
        states_initial=states_initial_mtr,
        autocast_kernel_dtype="float32",
    )
    # Only the first of the four state sequences feeds the scalar loss.
    h_seq, _, _, _ = all_states.unbind(dim=1)
    h_seq.sum().backward()


def slstm_triton_frnn():
    """Run one forward + backward pass of the sLSTM via flashrnn's "triton_fused" backend.

    Reads the module-level ``*_mtr`` tensors; gradients are written to their ``.grad``.
    """
    all_states, _last_states = flashrnn(
        function="slstm",
        backend="triton_fused",
        dtype="bfloat16",
        Wx=Wx_mtr,
        R=R_mtr,
        b=b_mtr,
        states=None,  # states_initial_mtr,
    )
    # Index 0 of the leading axis is the state the scalar loss is built from.
    first_state = all_states[0]
    first_state.sum().backward()


# cuda fused kernel
def slstm_cuda_fused():
    """Run one forward + backward pass of the sLSTM via flashrnn's "cuda_fused" backend.

    Reads the module-level ``*_mtr`` tensors; gradients are written to their ``.grad``.
    """
    all_states = flashrnn(
        function="slstm",
        backend="cuda_fused",
        dtype="bfloat16",
        Wx=Wx_mtr,
        R=R_mtr,
        b=b_mtr,
    )[0]
    # First entry of the first return value is the state the scalar loss is built from.
    all_states[0].sum().backward()


# PyTorch's built-in single-layer, unidirectional LSTM as an additional point of comparison,
# together with a time-major (T, B, D) random input that requires grad.
torch_lstm = torch.nn.LSTM(
    input_size=D,
    hidden_size=D,
    num_layers=1,
    bias=True,
    batch_first=False,
    bidirectional=False,
).to(device=device, dtype=dtype)
pt_in = torch.randn([T, B, D], device=device, dtype=dtype, requires_grad=True)
print(torch_lstm)
print(pt_in.shape)


def lstm_pt_fused_cuda():
    """Run one forward + backward pass of the built-in ``torch.nn.LSTM`` on ``pt_in``."""
    output_seq = torch_lstm(pt_in)[0]  # first element of the (output, (h_n, c_n)) result
    output_seq.sum().backward()

In [None]:
def _timed_loop(fn, n_iters, desc):
    """Call ``fn`` ``n_iters`` times under a tqdm bar whose timing includes the GPU work.

    CUDA kernels are launched asynchronously, so a plain ``for _ in tqdm(...)`` loop
    mostly measures how fast work is *queued*, and work still pending from one loop
    leaks into the timing of the next. Waiting for the device before the bar closes
    makes the final elapsed time and overall rate cover the actual execution.
    """
    with tqdm(total=n_iters, desc=desc) as pbar:
        for _ in range(n_iters):
            fn()
            pbar.update()
        torch.cuda.synchronize()


def _benchmark(fn, name):
    """Run ``WARMUP_ITERS`` untimed-in-spirit warmup calls of ``fn``, then ``ITERS`` measured calls."""
    _timed_loop(fn, WARMUP_ITERS, f"Warmup - {name}")
    _timed_loop(fn, ITERS, f"Main - {name}")


_benchmark(slstm_triton, "Triton")
_benchmark(slstm_triton_frnn, "Triton frnn")

# Optional baselines — uncomment to include them in the comparison.
# _benchmark(slstm_pt_autograd, "Torch")
# _benchmark(slstm_cuda_fused, "CUDA fused")
# _benchmark(lstm_pt_fused_cuda, "Torch CUDA fused")

In [None]:
# Warmup - Triton: 100%|██████████| 50/50 [00:00<00:00, 1141.98it/s]
# Main - Triton:   4%|▎         | 37/1000 [00:00<00:14, 67.93it/s]
# Main - Triton: 100%|██████████| 1000/1000 [00:17<00:00, 58.80it/s]