In [1]:
'''
Debug the difference between flashrnn and torch.nn.LSTM.
input projection (weight and bias): all ones
R: randn
b: all ones
x: randn
'''
import os
import torch
from torch import nn
from tqdm import tqdm

import sys

sys.path.append("..")
# sys.path.append(os.path.abspath(os.path.join(__file__, "../../..")))


from flashrnn.frameworks.cuda_alternating.gru import GRUCuda
from flashrnn.frameworks.cuda_alternating.lstm import LSTMCuda
from flashrnn.frameworks.cuda_fused.gru import GRUFused
from flashrnn.frameworks.cuda_fused.lstm import LSTMFused
from flashrnn.flashrnn import flashrnn
from flashrnn.flashrnn_fused import FlashRNNFuncGeneratorFused

# Debugging aid: with CUDA_LAUNCH_BLOCKING=1 kernel launches run synchronously,
# so a failing kernel is reported at the call that launched it.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Fixed seed so the randomly drawn tensors are the same on every run.
torch.manual_seed(0)

device = "cuda"
dtype = torch.bfloat16
# Derived from `dtype` so the string passed to flashrnn cannot drift from the
# torch dtype used for the tensors (torch.bfloat16 -> "bfloat16").
dtype_str = str(dtype).rsplit(".", 1)[-1]
###
# Config
input_size = 1
hidden_size = 32
batch = 8
num_layers = 1
num_head = 1
num_gates = 4  # gate pre-activations per hidden unit (nn.LSTM weights are [4H, ...])
requires_grad = True

file path: /mnt/second/qinhaoping/repo/flashRNN/main/flashrnn/flashrnn
file path: /mnt/second/qinhaoping/repo/flashRNN/main/flashrnn/flashrnn


In [2]:

def initialize_ref_lstm_constant(ref_lstm:nn.LSTM, value=1.0):
    """Overwrite every parameter of ``ref_lstm`` in place with ``value``.

    The fill happens under ``torch.no_grad()`` so it is not recorded by
    autograd. Returns ``None``.
    """
    with torch.no_grad():
        for tensor in ref_lstm.parameters():
            tensor.fill_(value)


def sync_from_pytorch_lstm(my_lstm: "LSTMFused", ref_lstm: nn.LSTM, fused: bool):
    """
    Copy the layer-0 weights of an ``nn.LSTM`` into a custom LSTM module.

    Requirements (enforced with assertions):
    - my_lstm.num_heads == 1
    - my_lstm.num_layers == 1
    - ref_lstm.num_layers == 1 and unidirectional

    Side effects:
    - ``ref_lstm.bias_ih_l0`` and ``ref_lstm.bias_hh_l0`` are zeroed in place
      first, so every bias written into ``my_lstm`` is zero as well.
    - Four ``[Check]`` lines with the max abs difference after the copy and a
      final confirmation line are printed.

    Args:
        my_lstm: target module. The attributes used are ``num_heads``,
            ``num_layers``, ``hidden_size``, ``linear`` (weight and bias),
            ``recurrents[0]`` (written as [1, H, 4, H]) and ``biases[0]``.
        ref_lstm: source ``nn.LSTM``.
        fused: selects the layout written to ``biases[0]``:
            ``True`` -> [1, H, 4], ``False`` -> [1, 4, H].
    """
    assert my_lstm.num_heads == 1, "只能同步 num_heads == 1 的模型"
    assert my_lstm.num_layers == 1, "只能同步单层模型"
    assert (
        ref_lstm.num_layers == 1 and not ref_lstm.bidirectional
    ), "只支持同步单层单向 LSTM"

    H = my_lstm.hidden_size

    with torch.no_grad():
        ref_lstm.bias_ih_l0.zero_()
        ref_lstm.bias_hh_l0.zero_()
        # The reference adds both bias vectors; their sum is what gets copied.
        total_bias = ref_lstm.bias_ih_l0 + ref_lstm.bias_hh_l0  # [4H]

        # ========== 1. Input projection ==========
        my_lstm.linear.weight.copy_(ref_lstm.weight_ih_l0)  # [4H, I]
        my_lstm.linear.bias.copy_(total_bias)  # [4H]

        # ========== 2. Recurrent weight R ==========
        # weight_hh_l0 is [4H, H]: split it into four [H, H] blocks and move
        # the block axis behind the first H axis.
        per_gate_R = torch.stack(
            torch.split(ref_lstm.weight_hh_l0, H, dim=0), dim=0
        )  # [4, H, H]
        R = per_gate_R.unsqueeze(0).permute(0, 2, 1, 3).contiguous()  # [1, H, 4, H]
        my_lstm.recurrents[0].copy_(R)

        # ========== 3. Bias ==========
        per_gate_b = torch.stack(
            torch.split(total_bias, H, dim=0), dim=0
        ).unsqueeze(0)  # [1, 4, H]
        if fused:
            b_stacked = per_gate_b.permute(0, 2, 1)  # [1, H, 4]
        else:
            b_stacked = per_gate_b  # [1, 4, H]
        my_lstm.biases[0].copy_(b_stacked)

        # ========== Verify the copy ==========
        # [4H, I]
        diff_w = (my_lstm.linear.weight - ref_lstm.weight_ih_l0).abs().max()
        print(f"[Check] Linear weight max abs diff: {diff_w:.2e}")

        # [4H]
        diff_b = (my_lstm.linear.bias - total_bias).abs().max()
        print(f"[Check] Linear bias   max abs diff: {diff_b:.2e}")

        # [1, H, 4, H] -> [4, H, H] -> [4H, H]
        R_back = my_lstm.recurrents[0].permute(2, 1, 3, 0).squeeze(3)
        R_flat = R_back.reshape(4 * H, H)
        diff_R = (R_flat - ref_lstm.weight_hh_l0).abs().max()
        print(f"[Check] Recurrent weight max abs diff: {diff_R:.2e}")

        # Flatten the stored bias back to [4H], undoing the layout written
        # above: only the fused layout needs its axes swapped first.
        b_my = my_lstm.biases[0]
        if fused:
            b_my = b_my.permute(2, 1, 0)  # [1, H, 4] -> [4, H, 1]
        b_my = b_my.reshape(-1)  # [4H]
        diff_bias = (b_my - total_bias).abs().max()
        print(f"[Check] Bias max abs diff: {diff_bias:.2e}")
    print("[✓] LSTMFused 参数成功同步自 nn.LSTM。")


def max_abs_diff(a, b):
    """Return the largest element-wise ``|a - b|`` as a Python number."""
    delta = torch.abs(a - b)
    return torch.max(delta).item()


def mean_abs_diff(a, b):
    """Return the mean of the element-wise ``|a - b|`` as a Python number."""
    delta = torch.abs(a - b)
    return torch.mean(delta).item()


def max_ref_diff(a, b, eps=1e-8):
    """Return the largest element-wise ``|a - b| / (|b| + eps)`` as a Python number.

    ``b`` is the reference; ``eps`` is added to its magnitude before dividing.
    """
    abs_err = torch.abs(a - b)
    scale = torch.abs(b) + eps
    return torch.max(abs_err / scale).item()


def mean_ref_diff(a, b, eps=1e-8):
    """Return the mean element-wise ``|a - b| / (|b| + eps)`` as a Python number.

    ``b`` is the reference; ``eps`` is added to its magnitude before dividing.
    """
    abs_err = torch.abs(a - b)
    scale = torch.abs(b) + eps
    return torch.mean(abs_err / scale).item()



In [10]:
seq_len = 2
DH = hidden_size // num_head  # hidden units per head
NH = num_head

# Input projection with weight and bias set to one, so only x and R carry
# random values.
gate_linear = torch.nn.Linear(input_size, num_gates * hidden_size).to(
    device=device, dtype=dtype
)
with torch.no_grad():
    gate_linear.weight.fill_(1.0)
    if gate_linear.bias is not None:
        gate_linear.bias.fill_(1.0)

x = torch.randn(
    batch, seq_len, input_size, device=device, requires_grad=True, dtype=dtype
)

# Recurrent weights: draw a flat random buffer and view it as
# [num_gates, NH, DH, DH]. requires_grad is switched on only AFTER the view so
# that R itself is a leaf tensor; a view of a tensor that already requires
# grad is a non-leaf and its .grad is not populated by backward().
total_elems = num_gates * NH * DH * DH
R = (
    torch.randn(total_elems, device=device, dtype=dtype)
    .view(num_gates, NH, DH, DH)
    .requires_grad_(requires_grad)
)
print(R.shape)
print(R)

b = torch.ones(
    [num_gates, NH, DH],
    device=device,
    dtype=dtype,
    requires_grad=requires_grad,
)
# Independent leaf copies of R and b (same values, separate gradients).
R_mtr = R.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)
b_mtr = b.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)

# Gate pre-activations from the input: [batch, seq_len, num_gates * hidden_size]
# reshaped to [batch, seq_len, num_gates, NH, DH].
Wx = gate_linear(x)
Wx = Wx.reshape(
    Wx.shape[0], Wx.shape[1], R.shape[0], R.shape[1], R.shape[2]
)
Wx_mtr = Wx.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)


print(Wx.shape)



torch.Size([4, 1, 32, 32])
tensor([[[[ 2.5469, -0.7148, -0.4941,  ..., -0.6289, -0.6602,  2.0781],
          [ 1.4141, -0.3086, -0.2051,  ...,  1.6016,  0.1328,  1.0703],
          [-1.1172, -0.8398, -3.6719,  ..., -1.0234, -0.3301, -0.8633],
          ...,
          [-0.6289, -0.2295, -0.5430,  ...,  1.0391,  0.9609, -0.7891],
          [-1.6016, -0.7617, -1.4531,  ...,  0.7227, -0.1670, -0.4785],
          [-0.2812,  0.7109,  0.3965,  ..., -0.3906,  1.4766, -1.3281]]],


        [[[-1.0000, -0.9141,  0.9219,  ...,  0.9531, -0.2910, -1.3203],
          [ 0.5625,  0.9375, -0.6953,  ...,  0.4863, -1.6406, -2.7500],
          [ 0.4004,  1.4844,  0.9141,  ..., -0.6406, -0.7695, -1.3594],
          ...,
          [-1.3047,  0.7656,  0.7266,  ..., -1.2578, -2.4062,  0.2969],
          [-1.3047, -1.9141,  0.3203,  ..., -1.1406, -1.1328, -0.0272],
          [-1.1875, -0.7266, -0.8789,  ..., -0.5039, -1.1953,  0.2295]]],


        [[[ 1.6016,  0.2891,  1.9844,  ...,  0.0254, -1.5000,  1.1641],

In [11]:

# Models
ref_lstm = nn.LSTM(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=True,
    bidirectional=False,
).to(device=device, dtype=dtype)

initialize_ref_lstm_constant(ref_lstm)

# ========== Sync the recurrent weight R into the reference ==========
# R is [num_gates, NH, DH, DH]; flattening it to the shape of weight_hh_l0
# ([4H, H]) only works for a single head (NH == 1, DH == hidden_size).
# The copy is done in place under no_grad instead of going through `.data`.
with torch.no_grad():
    ref_lstm.weight_hh_l0.copy_(R.reshape(ref_lstm.weight_hh_l0.shape))
print("ref_lstm.weight_hh_l0.data:",ref_lstm.weight_hh_l0.detach())


# Inputs
# NOTE(review): h0 / c0 are never passed to flashrnn below, so h0.grad and
# c0.grad cannot be populated by its backward pass; only the detached copies
# feed the reference. flashrnn presumably starts from its own zero state --
# TODO confirm.
h0 = torch.zeros(
    num_layers, batch, hidden_size, device=device, requires_grad=True, dtype=dtype
)
c0 = torch.zeros(
    num_layers, batch, hidden_size, device=device, requires_grad=True, dtype=dtype
)

# Clone inputs for reference
x_ref = x.detach().clone().requires_grad_()
h0_ref = h0.detach().clone().requires_grad_()
c0_ref = c0.detach().clone().requires_grad_()

# Forward
out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref, (h0_ref, c0_ref))
out_my, (hn_my, cn_my) = flashrnn(Wx,R,b, function="lstm",backend="cuda",dtype=dtype_str)
# Bring the flashrnn results into the shapes returned by nn.LSTM.
out_my = out_my.reshape(batch, seq_len, hidden_size)
hn_my = hn_my.reshape(num_layers, batch, hidden_size)
cn_my = cn_my.reshape(num_layers, batch, hidden_size)

print("out_my shape: ", out_my.shape)    # [B, T, H] after the reshape above
print("out_ref shape: ", out_ref.shape)  # [B, T, H]

# Backward is left disabled in this cell; enable it (or run an equivalent
# backward elsewhere) before reading any .grad attribute.
# loss_my = out_my.sum()
# loss_ref = out_ref.sum()
# loss_my.backward()
# loss_ref.backward()

ref_lstm.weight_hh_l0.data: tensor([[ 2.5469, -0.7148, -0.4941,  ..., -0.6289, -0.6602,  2.0781],
        [ 1.4141, -0.3086, -0.2051,  ...,  1.6016,  0.1328,  1.0703],
        [-1.1172, -0.8398, -3.6719,  ..., -1.0234, -0.3301, -0.8633],
        ...,
        [-0.4863,  0.7070, -0.4023,  ..., -0.1865,  0.7188,  0.0923],
        [-0.5508,  2.4062,  0.4395,  ...,  0.2490, -1.7109, -2.6406],
        [-0.4102,  1.1406, -0.0374,  ..., -1.6250, -0.4121, -0.0630]],
       device='cuda:0', dtype=torch.bfloat16)
function:  lstm
backend:  cuda
Wx:  torch.Size([2, 8, 1, 4, 32])
R:  torch.Size([1, 32, 4, 32])
b:  torch.Size([1, 4, 32])
input states:  torch.Size([2, 1, 8, 1, 32])
recurrent shape:  torch.Size([1, 32, 4, 32])
bias shape:  torch.Size([1, 4, 32])
Wx shape:  torch.Size([2, 8, 1, 4, 32])
R :  tensor([[[[ 2.5469,  1.4141, -1.1172,  ..., -0.6289, -1.6016, -0.2812],
          [-1.0000,  0.5625,  0.4004,  ..., -1.3047, -1.3047, -1.1875],
          [ 1.6016,  0.8359,  0.3516,  ...,  1.4453,  1

In [12]:
# Final cell state of the flashrnn run as [num_layers, batch, hidden_size].
cn_my = cn_my.reshape((num_layers, batch, hidden_size))
print(cn_my.size())

print(cn_my)


torch.Size([1, 8, 32])
tensor([[[ 1.5234e+00,  2.0996e-01,  8.3984e-01,  1.8594e+00,  1.8125e+00,
           8.6426e-02,  1.0469e+00,  1.7891e+00,  1.1484e+00,  1.7656e+00,
           1.3281e+00,  8.7891e-01,  1.0000e+00,  1.8047e+00,  3.6133e-01,
           1.1484e+00,  8.6328e-01,  6.0156e-01,  6.2109e-01,  9.8047e-01,
           1.7344e+00,  1.1875e+00,  1.0469e+00,  1.6641e+00,  1.5234e+00,
          -7.2266e-01,  1.8672e+00, -1.2598e-01,  1.4844e+00,  1.7031e+00,
           1.1406e+00,  7.9297e-01],
         [ 1.5000e+00,  7.9102e-02,  8.9062e-01,  1.9297e+00,  1.8672e+00,
           1.9531e-01,  1.0078e+00,  1.8359e+00,  1.1328e+00,  1.8281e+00,
           1.3125e+00,  9.4141e-01,  9.9219e-01,  1.8594e+00,  2.5781e-01,
           1.0391e+00,  9.3359e-01,  5.9766e-01,  4.9023e-01,  9.7266e-01,
           1.7734e+00,  1.1328e+00,  1.0469e+00,  1.7031e+00,  1.5156e+00,
          -7.1094e-01,  1.9375e+00, -5.5176e-02,  1.4609e+00,  1.7188e+00,
           1.0938e+00,  7.0312e-01],
   

In [13]:
# Final cell state returned by the reference nn.LSTM.
print(cn_ref.size())

print(cn_ref)


torch.Size([1, 8, 32])
tensor([[[ 1.5234e+00,  2.0996e-01,  8.3984e-01,  1.8594e+00,  1.8125e+00,
           8.6426e-02,  1.0469e+00,  1.7891e+00,  1.1484e+00,  1.7656e+00,
           1.3281e+00,  8.7891e-01,  1.0000e+00,  1.8047e+00,  3.6133e-01,
           1.1484e+00,  8.6328e-01,  6.0156e-01,  6.2109e-01,  9.8047e-01,
           1.7344e+00,  1.1875e+00,  1.0469e+00,  1.6641e+00,  1.5234e+00,
          -7.2266e-01,  1.8672e+00, -1.2598e-01,  1.4844e+00,  1.7031e+00,
           1.1406e+00,  7.9297e-01],
         [ 1.5000e+00,  7.9102e-02,  8.9062e-01,  1.9297e+00,  1.8672e+00,
           1.9531e-01,  1.0078e+00,  1.8359e+00,  1.1328e+00,  1.8281e+00,
           1.3125e+00,  9.4141e-01,  9.9219e-01,  1.8594e+00,  2.5781e-01,
           1.0391e+00,  9.3359e-01,  5.9766e-01,  4.9023e-01,  9.7266e-01,
           1.7734e+00,  1.1328e+00,  1.0469e+00,  1.7031e+00,  1.5156e+00,
          -7.1094e-01,  1.9375e+00, -5.5176e-02,  1.4609e+00,  1.7188e+00,
           1.0938e+00,  7.0312e-01],
   

In [7]:
# Final cell state of hidden unit 0 in layer 0: one value per batch element.
result1 = cn_my[0][:, 0].tolist()
result2 = cn_ref[0][:, 0].tolist()

print(result1)
print(result2)

[1.5859375, -0.58984375, 1.828125, 1.703125, 1.515625, 1.1875, 1.890625, 1.9609375]
[1.5859375, -0.58984375, 1.828125, 1.703125, 1.515625, 1.1875, 1.890625, 1.9609375]


In [8]:

# Output comparison
print(f"[Forward] Output max abs diff: {max_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max abs diff: {max_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     max abs diff: {max_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean abs diff: {mean_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean abs diff: {mean_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     mean abs diff: {mean_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output max ref diff: {max_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max ref diff: {max_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     max ref diff: {max_ref_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean ref diff: {mean_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean ref diff: {mean_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] cn     mean ref diff: {mean_ref_diff(cn_my, cn_ref):.2e}")


[Forward] Output max abs diff: 0.00e+00
[Forward] hn     max abs diff: 0.00e+00
[Forward] cn     max abs diff: 0.00e+00
[Forward] Output mean abs diff: 0.00e+00
[Forward] hn     mean abs diff: 0.00e+00
[Forward] cn     mean abs diff: 0.00e+00
[Forward] Output max ref diff: 0.00e+00
[Forward] hn     max ref diff: 0.00e+00
[Forward] cn     max ref diff: 0.00e+00
[Forward] Output mean ref diff: 0.00e+00
[Forward] hn     mean ref diff: 0.00e+00
[Forward] cn     mean ref diff: 0.00e+00


In [9]:

# Gradients
print(f"[Grad] Input x     grad diff: {max_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {max_ref_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_ref_diff(x.grad, x_ref.grad):.2e}")

print(f"[Grad] h0          grad diff: {max_abs_diff(h0.grad, h0_ref.grad):.2e}")
print(f"[Grad] c0          grad diff: {max_abs_diff(c0.grad, c0_ref.grad):.2e}")

for (n1, p1), (n2, p2) in zip(
    my_lstm.named_parameters(), ref_lstm.named_parameters()
):
    if p1.grad is not None and p2.grad is not None:
        diff = max_abs_diff(p1.grad, p2.grad)
        print(f"[Grad] Param {n1:20s} grad diff: {diff:.2e}")

TypeError: unsupported operand type(s) for -: 'NoneType' and 'NoneType'