In [1]:
import os
import torch
from torch import nn
from tqdm import tqdm

import sys

sys.path.append("..")
# sys.path.append(os.path.abspath(os.path.join(__file__, "../../..")))


from flashrnn.frameworks.cuda_alternating.lstm import LSTMCuda
from flashrnn.frameworks.cuda_fused.lstm import LSTMFused


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.manual_seed(0)


device = "cuda"
dtype = torch.bfloat16
dtype_str = "bfloat16"
###
# Config
input_size = 1
hidden_size = 32
batch = 8
num_layers = 1
num_head = 1
num_gate = 4
requires_grad = True
total_elems = 4*hidden_size*hidden_size

R_g = torch.randn(total_elems, device=device,
            dtype=dtype,
            requires_grad=requires_grad,)  # 随机初始化一次

file path: /mnt/second/qinhaoping/repo/flashRNN/main/flashrnn/flashrnn
file path: /mnt/second/qinhaoping/repo/flashRNN/main/flashrnn/flashrnn


In [2]:

def initialize_ref_lstm_constant(ref_lstm, value=1.0):
    with torch.no_grad():
        for name, param in ref_lstm.named_parameters():
            param.fill_(value)


def sync_from_pytorch_lstm(my_lstm: LSTMFused, ref_lstm: nn.LSTM, fused: bool):

    """
    同步 nn.LSTM 的第一层权重到自定义的 LSTMFused。
    要求：
    - my_lstm.num_heads == 1
    - my_lstm.num_layers == 1
    - ref_lstm.num_layers == 1，单向
    """
    assert my_lstm.num_heads == 1, "只能同步 num_heads == 1 的模型"
    assert my_lstm.num_layers == 1, "只能同步单层模型"
    assert (
        ref_lstm.num_layers == 1 and not ref_lstm.bidirectional
    ), "只支持同步单层单向 LSTM"

    H = my_lstm.hidden_size
    I = my_lstm.linear.in_features  # 输入维度
    # 初始化 ref_lstm 的 recurrent 权重（保持 nn.Parameter 类型）
    with torch.no_grad():


        # R = R=torch.cat([torch.ones((total_elems//2,),device=device,dtype=dtype),torch.full((total_elems//2,),2.0,device=device,dtype=dtype)],dim=0)

        ref_lstm.weight_hh_l0.data = R_g.view(4*H, H)
        # ref_lstm.weight_hh_l0.data = torch.randn(
        #     4 * H, H, device=device, dtype=dtype, requires_grad=requires_grad
        # )
        print("R:",R_g)


    with torch.no_grad():
        # ref_lstm.bias_ih_l0.zero_()
        # ref_lstm.bias_hh_l0.zero_()
        # ========== 1. 同步 Linear 权重 ==========
        # ref: weight_ih_l0: [4H, I]
        my_lstm.linear.weight.copy_(ref_lstm.weight_ih_l0)  # [4H, I]
        my_lstm.linear.bias.copy_(ref_lstm.bias_ih_l0)  # [4H]

        # ========== 2. 同步 Recurrent 权重 R ==========
        weight_hh = ref_lstm.weight_hh_l0  # shape [4H, H]
        gates = torch.split(weight_hh, H, dim=0)  # 4 tensors of shape [H, H]
        stacked = torch.stack(gates, dim=0)  # [4, H, H]
        R = stacked.unsqueeze(0).permute(0, 2, 1, 3).contiguous()  # [1, H, 4, H]
        my_lstm.recurrents[0].copy_(R)
        print(R)

        # ========== 3. 同步 bias ==========
        if fused:
            total_bias = ref_lstm.bias_hh_l0  # shape [4H]
            gates_b = torch.split(total_bias, H, dim=0)  # 4 tensors of shape [H]
            b_stacked = (
                torch.stack(gates_b, dim=0).unsqueeze(0).permute(0, 2, 1)
            )  # [1, H, 4]
            my_lstm.biases[0].copy_(b_stacked)
        else:
            total_bias =  ref_lstm.bias_hh_l0  # shape [4H]
            gates_b = torch.split(total_bias, H, dim=0)  # 4 tensors of shape [H]
            b_stacked = (
                torch.stack(gates_b, dim=0).unsqueeze(0).permute(0, 1, 2)
            )  # [1, H, 4]
            my_lstm.biases[0].copy_(b_stacked)

        # ========== 验证是否同步成功 ==========
        # [4H, I]
        diff_w = (my_lstm.linear.weight - ref_lstm.weight_ih_l0).abs().max()
        print(f"[Check] Linear weight max abs diff: {diff_w:.2e}")

        # [4H]
        expected_bias = ref_lstm.bias_ih_l0
        diff_b = (my_lstm.linear.bias - expected_bias).abs().max()
        print(f"[Check] Linear bias   max abs diff: {diff_b:.2e}")

        # [1, H, 4, H] -> [4, H, H]
        R = my_lstm.recurrents[0].permute(2, 1, 3, 0).squeeze(3)  # [4, H, H]
        R_flat = torch.cat([R[i] for i in range(4)], dim=0)  # [4H, H]
        diff_R = (R_flat - ref_lstm.weight_hh_l0).abs().max()
        print(f"[Check] Recurrent weight max abs diff: {diff_R:.2e}")

        # [1, H, 4] -> [4H]
        b_my = my_lstm.biases[0].permute(2, 1, 0).reshape(-1)  # [4H]
        b_ref = ref_lstm.bias_hh_l0
        diff_bias = (b_my - b_ref).abs().max()
        print(f"[Check] Bias max abs diff: {diff_bias:.2e}")
    print("[✓] LSTMFused 参数成功同步自 nn.LSTM。")


def max_abs_diff(a, b):
    return (a - b).abs().max().item()


def mean_abs_diff(a, b):
    return (a - b).abs().mean().item()


def max_ref_diff(a, b, eps=1e-8):
    return ((a - b).abs() / (b.abs() + eps)).max().item()


def mean_ref_diff(a, b, eps=1e-8):
    return ((a - b).abs() / (b.abs() + eps)).mean().item()

def print_lstm_all_params(lstm: nn.LSTM):
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            print("name :",name)
            print("param:",param)


In [3]:
seq_len = 2


# Models
ref_lstm = nn.LSTM(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=True,
    bidirectional=False,
).to(device=device, dtype=dtype)
my_lstm = LSTMCuda(input_size, hidden_size, num_layers).to(
    device=device, dtype=dtype
)
fused = False
initialize_ref_lstm_constant(ref_lstm)
# print_lstm_all_params(ref_lstm)
sync_from_pytorch_lstm(my_lstm, ref_lstm, fused)  # 同步权重

# Inputs
# x = torch.randn(batch, seq_len, input_size, device="cuda", requires_grad=False)
x = torch.randn(
    batch, seq_len, input_size, device="cuda", requires_grad=True, dtype=dtype
)
h0 = torch.zeros(
    num_layers, batch, hidden_size, device="cuda", requires_grad=True, dtype=dtype
)
c0 = torch.zeros(
    num_layers, batch, hidden_size, device="cuda", requires_grad=True, dtype=dtype
)

# Clone inputs for reference
x_ref = x.detach().clone().requires_grad_()
h0_ref = h0.detach().clone().requires_grad_()
c0_ref = c0.detach().clone().requires_grad_()

# Forward
# out_ref: [B,T,H]
# hn_ref \ cn_ref: [num_layers,B,H]
out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref, (h0_ref, c0_ref))
out_my, (hn_my, cn_my) = my_lstm(x, (h0, c0))
# out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref)
# out_my, (hn_my, cn_my) = my_lstm(x)
# print("out my: ", out_my)
# print("out ref:", out_ref)
print("out_my shape: ", out_my.shape)
print("out_ref shape: ", out_ref.shape)
# Backward
loss_my = out_my.sum()
loss_ref = out_ref.sum()
loss_my.backward()
loss_ref.backward()



R: tensor([-0.9258, -0.4258, -2.6406,  ...,  1.6016, -1.0703,  1.6016],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
tensor([[[[-0.9258, -0.4258, -2.6406,  ..., -0.8359,  1.3516, -0.2871],
          [-1.0312, -0.8906, -0.1914,  ..., -1.3125,  0.5312, -0.7266],
          [ 1.4922, -1.5703,  0.3320,  ...,  0.0315,  0.1030, -1.0156],
          [-0.1016, -0.3477,  0.6562,  ..., -2.2344, -0.3496,  0.9023]],

         [[-0.5977, -0.3281, -0.9102,  ...,  0.1211,  0.4727, -1.0859],
          [ 0.0554, -0.6758,  1.2969,  ...,  0.8711, -0.4844,  0.1631],
          [-1.1562,  1.7109,  0.4316,  ...,  1.0703,  1.0625,  0.9023],
          [ 0.7773,  0.2119, -0.6133,  ...,  1.3281, -0.8867,  1.1641]],

         [[-0.0334, -0.9727,  0.9570,  ...,  0.5234,  1.1719, -0.9570],
          [ 0.1157, -0.8398, -1.0078,  ...,  0.0693, -0.3086, -0.0967],
          [ 0.7930,  0.2129,  0.8047,  ...,  0.0074, -0.9102, -0.0879],
          [ 0.1953, -0.8555, -1.0312,  ..., -0.1104, -0.6406, -0.2

Using /home/qinhaoping/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/qinhaoping/.cache/torch_extensions/py311_cu118/lstm_HS32BS8NH1NS2DbDBbDRbDWbDGbDSbDAbNGR4NGW4NGI4NGT4SA1UDB1GRCV0GR/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module lstm_HS32BS8NH1NS2DbDBbDRbDWbDGbDSbDAbNGR4NGW4NGI4NGT4SA1UDB1GRCV0GR...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.
out_my shape:  torch.Size([8, 2, 32])
out_ref shape:  torch.Size([8, 2, 32])
grads: [torch.Size([2, 8, 1, 4, 32]), torch.Size([2, 8, 1, 32]), torch.Size([1, 32, 4, 32]), torch.Size([1, 4, 32])]


Loading extension module lstm_HS32BS8NH1NS2DbDBbDRbDWbDGbDSbDAbNGR4NGW4NGI4NGT4SA1UDB1GRCV0GR...


In [4]:
print(cn_ref)
print(cn_ref.shape)

tensor([[[ 0.2812,  1.3750, -0.3105,  0.2539,  0.6758,  0.5000,  0.8711,
          -0.0640,  1.6172,  0.1157,  1.1797,  0.0364, -0.5430,  1.4844,
          -0.4043,  0.7383,  1.3672,  1.1484,  0.8438,  0.0378,  0.1289,
           1.5625,  0.7305,  1.2188, -0.4492,  0.9258, -0.2324,  1.0312,
          -0.7500,  0.5625,  1.5859,  1.7578],
         [ 0.2578,  1.3281, -0.4453,  0.3945,  0.5312,  0.5938,  0.9258,
          -0.0072,  1.6406,  0.0222,  1.1328,  0.0179, -0.4844,  1.4844,
          -0.4707,  0.8477,  1.3203,  1.1016,  0.9219, -0.2422,  0.2236,
           1.5781,  0.6250,  1.1406, -0.4766,  0.8203, -0.1992,  1.0156,
          -0.8047,  0.4609,  1.5938,  1.8203],
         [ 1.1406,  1.7812,  0.6211,  0.1167,  1.2422,  0.5469,  1.0078,
          -0.0192,  1.8828,  0.5664,  1.4141,  0.0713, -0.6719,  1.8125,
           0.1934,  0.8828,  1.6797,  1.3516,  0.9414,  0.8281,  0.2100,
           1.8672,  1.2969,  1.6719, -0.1875,  1.4844, -0.0664,  1.4609,
          -0.4746,  0.8945,  1

In [5]:
print(cn_my)
print(cn_my.shape)

tensor([[[ 0.2812,  1.3750, -0.3105,  0.2539,  0.6758,  0.5000,  0.8711,
          -0.0640,  1.6172,  0.1157,  1.1797,  0.0364, -0.5430,  1.4844,
          -0.4043,  0.7383,  1.3672,  1.1484,  0.8438,  0.0378,  0.1289,
           1.5625,  0.7305,  1.2188, -0.4492,  0.9258, -0.2324,  1.0312,
          -0.7500,  0.5625,  1.5859,  1.7578],
         [ 0.2578,  1.3281, -0.4453,  0.3945,  0.5312,  0.5938,  0.9258,
          -0.0072,  1.6406,  0.0222,  1.1328,  0.0179, -0.4844,  1.4844,
          -0.4707,  0.8477,  1.3203,  1.1016,  0.9219, -0.2422,  0.2236,
           1.5781,  0.6250,  1.1406, -0.4766,  0.8203, -0.1992,  1.0156,
          -0.8047,  0.4609,  1.5938,  1.8203],
         [ 1.1406,  1.7812,  0.6211,  0.1167,  1.2422,  0.5469,  1.0078,
          -0.0192,  1.8828,  0.5664,  1.4141,  0.0713, -0.6719,  1.8125,
           0.1934,  0.8828,  1.6797,  1.3516,  0.9414,  0.8281,  0.2100,
           1.8672,  1.2969,  1.6719, -0.1875,  1.4844, -0.0664,  1.4609,
          -0.4746,  0.8945,  1

In [6]:
result1 = cn_my[0, :, 0].tolist()  # list of 16 elements
result2 = cn_ref[0, :, 0].tolist()  # list of 16 elements

print(result1)
print(result2)

[0.28125, 0.2578125, 1.140625, 0.1982421875, 1.21875, 0.70703125, 0.84765625, 0.734375]
[0.28125, 0.2578125, 1.140625, 0.1953125, 1.21875, 0.70703125, 0.84375, 0.734375]


In [7]:

# Output comparison
print(f"[Forward] Output max abs diff: {max_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max abs diff: {max_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     max abs diff: {max_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean abs diff: {mean_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean abs diff: {mean_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     mean abs diff: {mean_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output max ref diff: {max_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max ref diff: {max_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     max ref diff: {max_ref_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean ref diff: {mean_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean ref diff: {mean_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     mean ref diff: {mean_ref_diff(cn_my, cn_ref):.2e}")


[Forward] Output max abs diff: 3.91e-03
[Forward] hn     max abs diff: 3.91e-03
[Forward] cn     max abs diff: 7.81e-03
[Forward] Output mean abs diff: 1.05e-04
[Forward] hn     mean abs diff: 2.10e-04
[Forward] cn     mean abs diff: 4.86e-04
[Forward] Output max ref diff: 2.81e-01
[Forward] hn     max ref diff: 2.81e-01
[Forward] cn     max ref diff: 2.73e-01
[Forward] Output mean ref diff: 1.14e-03
[Forward] hn     mean ref diff: 2.29e-03
[Forward] cn     mean ref diff: 2.24e-03


In [8]:

# Gradients
print(f"[Grad] Input x     grad diff: {max_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {max_ref_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_ref_diff(x.grad, x_ref.grad):.2e}")

print(f"[Grad] h0          grad diff: {max_abs_diff(h0.grad, h0_ref.grad):.2e}")
print(f"[Grad] c0          grad diff: {max_abs_diff(c0.grad, c0_ref.grad):.2e}")

for (n1, p1), (n2, p2) in zip(
    my_lstm.named_parameters(), ref_lstm.named_parameters()
):
    if p1.grad is not None and p2.grad is not None:
        diff = max_abs_diff(p1.grad, p2.grad)
        print(f"[Grad] Param {n1:20s} grad diff: {diff:.2e}")

[Grad] Input x     grad diff: 3.12e-02
[Grad] Input x     grad diff: 2.44e-03
[Grad] Input x     grad diff: 6.96e-03
[Grad] Input x     grad diff: 6.79e-04
[Grad] h0          grad diff: 1.56e-02
[Grad] c0          grad diff: 1.56e-02
[Grad] Param linear.weight        grad diff: 3.91e-03


RuntimeError: The size of tensor a (128) must match the size of tensor b (32) at non-singleton dimension 1

In [None]:
from flashrnn.flashrnn import flashrnn


In [None]:
seq_len = 2
DH = hidden_size // num_head
NH=num_head

gate_linear = torch.nn.Linear(input_size, num_gate * hidden_size).to(
    device=device, dtype=dtype
)
with torch.no_grad():
    gate_linear.weight.fill_(1.0)
    if gate_linear.bias is not None:
        gate_linear.bias.fill_(1.0)

x = torch.randn(
    batch, seq_len, input_size, device="cuda", requires_grad=True, dtype=dtype
)
total_elems = 4*NH*DH*DH
# R=torch.cat([torch.ones((total_elems//2,),device=device,dtype=dtype),torch.full((total_elems//2,),2.0,device=device,dtype=dtype)],dim=0)
# R.requires_grad_(requires_grad)
print("R_g:",R_g)
R=R_g.view(num_gate, NH, DH, DH)
print(R.shape)
print(R)
# R = torch.randn(
#     [num_gate, NH, DH, DH],
#     # [NH,DH,num_gate,DH],

#     device=device,
#     dtype=dtype,
#     requires_grad=requires_grad,
# ) 

b = torch.ones(
    [num_gate, NH, DH],
    device=device,
    dtype=dtype,
    requires_grad=requires_grad,
)
R_mtr = R.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)
b_mtr = b.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)

Wx = gate_linear(x)
Wx = Wx.reshape(
    Wx.shape[0], Wx.shape[1], R.shape[0], R.shape[1], R.shape[2]
)
# Wx = Wx.reshape(
#     seq_len, batch, num_head, DH, num_gate
# )
Wx_mtr = Wx.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)


print(Wx.shape)
r=R
r=r.permute(1,3,0,2)
print(r)
print(r.shape)

In [None]:
# Models
ref_lstm = nn.LSTM(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=True,
    bidirectional=False,
).to(device=device, dtype=dtype)

initialize_ref_lstm_constant(ref_lstm)

# ========== 2. 同步 Recurrent 权重 R ==========
# 转换成 [4H, H] 形式，用于赋值给 ref_lstm.weight_hh_l0
# 步骤：
R_perm = R         # [4, NH, D, D]
R_reshaped = R_perm.reshape(4, hidden_size, hidden_size)  # [4, H, D]
weight_hh = R_reshaped.reshape(4 * hidden_size, hidden_size)  # [4H, D]
# 赋值到 ref_lstm 的 recurrent 权重
ref_lstm.weight_hh_l0.data.copy_(weight_hh)
print("ref_lstm.weight_hh_l0.data:",ref_lstm.weight_hh_l0.data)


# Inputs
# x = torch.randn(batch, seq_len, input_size, device="cuda", requires_grad=False)
h0 = torch.zeros(
    num_layers, batch, hidden_size, device="cuda", requires_grad=True, dtype=dtype
)
c0 = torch.zeros(
    num_layers, batch, hidden_size, device="cuda", requires_grad=True, dtype=dtype
)

# Clone inputs for reference
x_ref = x.detach().clone().requires_grad_()
h0_ref = h0.detach().clone().requires_grad_()
c0_ref = c0.detach().clone().requires_grad_()

# Forward
out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref, (h0_ref, c0_ref))      #
# out_my, last_h = flashrnn(Wx,R,b, function="lstm",backend="cuda_fused",dtype=dtype_str)
out_my, (hn_my, cn_my) = flashrnn(Wx,R,b, function="lstm",backend="cuda",dtype=dtype_str)
out_my=out_my.reshape(batch, seq_len, hidden_size)
hn_my=hn_my.reshape(num_layers, batch, hidden_size)
cn_my=cn_my.reshape(num_layers, batch, hidden_size)

# out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref)
# out_my, (hn_my, cn_my) = my_lstm(x)
# print("out my: ", out_my)
# print("out ref:", out_ref)
print("out_my shape: ", out_my.shape)    #  [NS, B,T,NH,D]
print("out_ref shape: ", out_ref.shape)  # [B,T,H]
# # Backward
# loss_my = out_my.sum()
# loss_ref = out_ref.sum()
# loss_my.backward()
# loss_ref.backward()

In [None]:
cn_my=cn_my.reshape(num_layers, batch, hidden_size)
print(cn_my.shape)

print(cn_my)

In [None]:
result1 = cn_my[0, :, 0].tolist()  # list of 16 elements
result2 = cn_ref[0, :, 0].tolist()  # list of 16 elements

print(result1)
print(result2)

In [None]:
# Output comparison
print(f"[Forward] Output max abs diff: {max_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max abs diff: {max_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     max abs diff: {max_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean abs diff: {mean_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean abs diff: {mean_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     mean abs diff: {mean_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output max ref diff: {max_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max ref diff: {max_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     max ref diff: {max_ref_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean ref diff: {mean_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean ref diff: {mean_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     mean ref diff: {mean_ref_diff(cn_my, cn_ref):.2e}")