In [73]:
import os
import torch
from torch import nn
from tqdm import tqdm

import sys

sys.path.append("..")
# sys.path.append(os.path.abspath(os.path.join(__file__, "../../..")))


from flashrnn.frameworks.cuda_alternating.gru import GRUCuda
from flashrnn.frameworks.cuda_fused.gru import GRUFused


os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.manual_seed(0)


device = "cuda"
dtype = torch.bfloat16
dtype_str = "bfloat16"
###
# Config
input_size = 1
hidden_size = 64
batch = 16
num_layers = 1
num_head = 1
num_gate = 3
requires_grad = True
total_elems = 3*hidden_size*hidden_size

R_g = torch.randn(total_elems, device=device,
            dtype=dtype,
            requires_grad=requires_grad,)  # 随机初始化一次
# R_g = torch.cat([torch.ones((total_elems//2,),device=device,dtype=dtype),torch.full((total_elems//2,),2.0,device=device,dtype=dtype)],dim=0)
# R_g = torch.ones(total_elems, device=device,
#             dtype=dtype,
#             requires_grad=requires_grad,)  # 随机初始化一次

In [13]:

def initialize_ref_lstm_constant(ref_lstm, value=1.0):
    with torch.no_grad():
        for name, param in ref_lstm.named_parameters():
            param.fill_(value)


def sync_from_pytorch_lstm(my_lstm, ref_lstm: nn.GRU, fused: bool):

    """
    同步 nn.LSTM 的第一层权重到自定义的 LSTMFused。
    要求：
    - my_lstm.num_heads == 1
    - my_lstm.num_layers == 1
    - ref_lstm.num_layers == 1，单向
    """
    assert my_lstm.num_heads == 1, "只能同步 num_heads == 1 的模型"
    assert my_lstm.num_layers == 1, "只能同步单层模型"
    assert (
        ref_lstm.num_layers == 1 and not ref_lstm.bidirectional
    ), "只支持同步单层单向 LSTM"

    H = my_lstm.hidden_size
    I = my_lstm.linear.in_features  # 输入维度
    # 初始化 ref_lstm 的 recurrent 权重（保持 nn.Parameter 类型）
    with torch.no_grad():


        # R = R=torch.cat([torch.ones((total_elems//2,),device=device,dtype=dtype),torch.full((total_elems//2,),2.0,device=device,dtype=dtype)],dim=0)

        ref_lstm.weight_hh_l0.data = R_g.view(3*H, H)
        # ref_lstm.weight_hh_l0.data = torch.randn(
        #     4 * H, H, device=device, dtype=dtype, requires_grad=requires_grad
        # )
        print("R:",R_g)


    with torch.no_grad():
        # ref_lstm.bias_ih_l0.zero_()
        # ref_lstm.bias_hh_l0.zero_()
        # ========== 1. 同步 Linear 权重 ==========
        # ref: weight_ih_l0: [4H, I]
        my_lstm.linear.weight.copy_(ref_lstm.weight_ih_l0)  # [4H, I]
        my_lstm.linear.bias.copy_(ref_lstm.bias_ih_l0)  # [4H]

        # ========== 2. 同步 Recurrent 权重 R ==========
        weight_hh = ref_lstm.weight_hh_l0  # shape [4H, H]
        gates = torch.split(weight_hh, H, dim=0)  # 4 tensors of shape [H, H]
        stacked = torch.stack(gates, dim=0)  # [4, H, H]
        R = stacked.unsqueeze(0).permute(0, 2, 1, 3).contiguous()  # [1, H, 4, H]
        my_lstm.recurrents[0].copy_(R)
        print(R)

        # ========== 3. 同步 bias ==========
        if fused:
            total_bias = ref_lstm.bias_hh_l0  # shape [4H]
            gates_b = torch.split(total_bias, H, dim=0)  # 4 tensors of shape [H]
            b_stacked = (
                torch.stack(gates_b, dim=0).unsqueeze(0).permute(0, 2, 1)
            )  # [1, H, 4]
            my_lstm.biases[0].copy_(b_stacked)
        else:
            total_bias =  ref_lstm.bias_hh_l0  # shape [4H]
            gates_b = torch.split(total_bias, H, dim=0)  # 4 tensors of shape [H]
            b_stacked = (
                torch.stack(gates_b, dim=0).unsqueeze(0).permute(0, 1, 2)
            )  # [1, H, 4]
            my_lstm.biases[0].copy_(b_stacked)

        # ========== 验证是否同步成功 ==========
        # [4H, I]
        diff_w = (my_lstm.linear.weight - ref_lstm.weight_ih_l0).abs().max()
        print(f"[Check] Linear weight max abs diff: {diff_w:.2e}")

        # [4H]
        expected_bias = ref_lstm.bias_ih_l0
        diff_b = (my_lstm.linear.bias - expected_bias).abs().max()
        print(f"[Check] Linear bias   max abs diff: {diff_b:.2e}")

        # [1, H, 4, H] -> [4, H, H]
        R = my_lstm.recurrents[0].permute(2, 1, 3, 0).squeeze(3)  # [4, H, H]
        R_flat = torch.cat([R[i] for i in range(3)], dim=0)  # [4H, H]
        diff_R = (R_flat - ref_lstm.weight_hh_l0).abs().max()
        print(f"[Check] Recurrent weight max abs diff: {diff_R:.2e}")

        # [1, H, 4] -> [4H]
        b_my = my_lstm.biases[0].permute(2, 1, 0).reshape(-1)  # [4H]
        b_ref = ref_lstm.bias_hh_l0
        diff_bias = (b_my - b_ref).abs().max()
        print(f"[Check] Bias max abs diff: {diff_bias:.2e}")
    print("[✓] LSTMFused 参数成功同步自 nn.LSTM。")


def max_abs_diff(a, b):
    return (a - b).abs().max().item()


def mean_abs_diff(a, b):
    return (a - b).abs().mean().item()


def max_ref_diff(a, b, eps=1e-8):
    return ((a - b).abs() / (b.abs() + eps)).max().item()


def mean_ref_diff(a, b, eps=1e-8):
    return ((a - b).abs() / (b.abs() + eps)).mean().item()

def print_lstm_all_params(lstm: nn.LSTM):
    with torch.no_grad():
        for name, param in lstm.named_parameters():
            print("name :",name)
            print("param:",param)


In [3]:
seq_len = 2


# Models
ref_lstm = nn.GRU(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=True,
    bidirectional=False,
).to(device=device, dtype=dtype)
my_lstm = GRUCuda(input_size, hidden_size, num_layers).to(
    device=device, dtype=dtype
)
fused = False
initialize_ref_lstm_constant(ref_lstm)
# print_lstm_all_params(ref_lstm)
sync_from_pytorch_lstm(my_lstm, ref_lstm, fused)  # 同步权重

# Inputs
# x = torch.randn(batch, seq_len, input_size, device="cuda", requires_grad=False)
x = torch.randn(
    batch, seq_len, input_size, device="cuda", requires_grad=True, dtype=dtype
)
h0 = torch.zeros(
    num_layers, batch, hidden_size, device="cuda", requires_grad=True, dtype=dtype
)


# Clone inputs for reference
x_ref = x.detach().clone().requires_grad_()
h0_ref = h0.detach().clone().requires_grad_()

# Forward
# out_ref: [B,T,H]
# hn_ref \ cn_ref: [num_layers,B,H]
out_ref, hn_ref = ref_lstm(x_ref, (h0_ref))
out_my, hn_my = my_lstm(x, (h0))
# out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref)
# out_my, (hn_my, cn_my) = my_lstm(x)
# print("out my: ", out_my)
# print("out ref:", out_ref)
print("out_my shape: ", out_my.shape)
print("out_ref shape: ", out_ref.shape)
# Backward
loss_my = out_my.sum()
loss_ref = out_ref.sum()
loss_my.backward()
loss_ref.backward()



R: tensor([-0.9258, -0.4258, -2.6406,  ..., -0.6602,  2.3125,  0.4824],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
tensor([[[[-0.9258, -0.4258, -2.6406,  ..., -0.8359,  1.3516, -0.2871],
          [-1.0312, -0.8906, -0.1914,  ..., -1.3125,  0.5312, -0.7266],
          [ 1.4922, -1.5703,  0.3320,  ...,  0.0315,  0.1030, -1.0156]],

         [[-0.5977, -0.3281, -0.9102,  ...,  0.1211,  0.4727, -1.0859],
          [ 0.0554, -0.6758,  1.2969,  ...,  0.8711, -0.4844,  0.1631],
          [-1.1562,  1.7109,  0.4316,  ...,  1.0703,  1.0625,  0.9023]],

         [[-0.0334, -0.9727,  0.9570,  ...,  0.5234,  1.1719, -0.9570],
          [ 0.1157, -0.8398, -1.0078,  ...,  0.0693, -0.3086, -0.0967],
          [ 0.7930,  0.2129,  0.8047,  ...,  0.0074, -0.9102, -0.0879]],

         ...,

         [[ 1.5859,  0.7383, -0.2578,  ..., -2.1719, -1.0391, -0.9648],
          [ 0.5156, -0.2520,  0.5117,  ..., -0.6406,  0.5039,  0.0457],
          [ 1.8750,  0.2891, -0.7227,  ..., -2.82

Using /home/qinhaoping/.cache/torch_extensions/py311_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/qinhaoping/.cache/torch_extensions/py311_cu118/gru_HS32BS8NH1NS1DbDBbDRbDWbDGbDSbDAbNGR3NGW3NGI4NGT4SA0UDB1GRCV0GR/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module gru_HS32BS8NH1NS1DbDBbDRbDWbDGbDSbDAbNGR3NGW3NGI4NGT4SA0UDB1GRCV0GR...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.
out_my shape:  torch.Size([8, 2, 32])
out_ref shape:  torch.Size([8, 2, 32])


Loading extension module gru_HS32BS8NH1NS1DbDBbDRbDWbDGbDSbDAbNGR3NGW3NGI4NGT4SA0UDB1GRCV0GR...


In [5]:
print(hn_ref)
print(hn_ref.shape)

tensor([[[0.2188, 0.2832, 0.2852, 0.1494, 0.2988, 0.2031, 0.2070, 0.1582,
          0.1777, 0.3145, 0.3203, 0.3926, 0.2988, 0.1807, 0.2324, 0.1836,
          0.2949, 0.3379, 0.1631, 0.4766, 0.1592, 0.2275, 0.2559, 0.2773,
          0.1660, 0.2617, 0.1836, 0.2236, 0.2217, 0.3438, 0.2598, 0.2402],
         [0.1953, 0.2441, 0.2314, 0.1484, 0.2471, 0.1846, 0.1904, 0.1650,
          0.1738, 0.2461, 0.2598, 0.2910, 0.2402, 0.1768, 0.2051, 0.1719,
          0.2490, 0.2793, 0.1562, 0.3242, 0.1543, 0.2109, 0.2178, 0.2314,
          0.1729, 0.2236, 0.1768, 0.2002, 0.2051, 0.2637, 0.2285, 0.2188],
         [0.0840, 0.0869, 0.0938, 0.0752, 0.0898, 0.0840, 0.0801, 0.0854,
          0.0737, 0.0967, 0.0918, 0.1011, 0.1001, 0.0737, 0.0869, 0.0791,
          0.0894, 0.0928, 0.0786, 0.1177, 0.0742, 0.0806, 0.0874, 0.0884,
          0.0869, 0.0869, 0.0840, 0.0825, 0.0903, 0.0967, 0.0850, 0.0820],
         [0.1953, 0.3301, 0.2471, 0.1157, 0.3281, 0.1572, 0.2021, 0.0718,
          0.1914, 0.2793, 0.3672, 0

In [6]:
print(hn_my)
print(hn_my.shape)

tensor([[[0.2793, 0.1689, 0.2715, 0.3477, 0.1797, 0.2812, 0.1787, 0.3887,
          0.2168, 0.2305, 0.2041, 0.1475, 0.2891, 0.1963, 0.3320, 0.2559,
          0.2227, 0.1562, 0.2227, 0.2432, 0.3203, 0.1816, 0.2217, 0.2305,
          0.4141, 0.2080, 0.3730, 0.2217, 0.3887, 0.2090, 0.1934, 0.1934],
         [0.2305, 0.1650, 0.2246, 0.2656, 0.1699, 0.2295, 0.1729, 0.2812,
          0.1982, 0.2002, 0.1934, 0.1436, 0.2324, 0.1846, 0.2656, 0.2188,
          0.2129, 0.1572, 0.2100, 0.2100, 0.2539, 0.1738, 0.1963, 0.2051,
          0.2988, 0.1885, 0.2793, 0.1973, 0.2988, 0.1885, 0.1836, 0.1846],
         [0.0908, 0.0728, 0.0903, 0.0991, 0.0771, 0.0942, 0.0825, 0.1045,
          0.0806, 0.0874, 0.0776, 0.0757, 0.0947, 0.0781, 0.0933, 0.0918,
          0.0791, 0.0688, 0.0967, 0.0884, 0.0942, 0.0752, 0.0850, 0.0830,
          0.1040, 0.0820, 0.0991, 0.0845, 0.0986, 0.0820, 0.0767, 0.0762],
         [0.2715, 0.1738, 0.2500, 0.3418, 0.1660, 0.2314, 0.1260, 0.3574,
          0.2285, 0.1865, 0.2227, 0

In [7]:
result1 = hn_my[0, :, 0].tolist()  # list of 16 elements
result2 = hn_ref[0, :, 0].tolist()  # list of 16 elements

print(result1)
print(result2)

[0.279296875, 0.23046875, 0.0908203125, 0.271484375, 0.1162109375, 0.349609375, 0.16796875, 0.333984375]
[0.21875, 0.1953125, 0.083984375, 0.1953125, 0.10498046875, 0.2578125, 0.1455078125, 0.255859375]


In [9]:

# Output comparison
print(f"[Forward] Output max abs diff: {max_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max abs diff: {max_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] Output mean abs diff: {mean_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean abs diff: {mean_abs_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] Output max ref diff: {max_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max ref diff: {max_ref_diff(hn_my, hn_ref):.2e}")
print(f"[Forward] Output mean ref diff: {mean_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean ref diff: {mean_ref_diff(hn_my, hn_ref):.2e}")


[Forward] Output max abs diff: 4.49e-01
[Forward] hn     max abs diff: 4.49e-01
[Forward] Output mean abs diff: 4.03e-02
[Forward] hn     mean abs diff: 8.06e-02
[Forward] Output max ref diff: 4.69e+00
[Forward] hn     max ref diff: 4.69e+00
[Forward] Output mean ref diff: 2.01e-01
[Forward] hn     mean ref diff: 4.00e-01


In [None]:

# Gradients
print(f"[Grad] Input x     grad diff: {max_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {max_ref_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_ref_diff(x.grad, x_ref.grad):.2e}")

print(f"[Grad] h0          grad diff: {max_abs_diff(h0.grad, h0_ref.grad):.2e}")

for (n1, p1), (n2, p2) in zip(
    my_lstm.named_parameters(), ref_lstm.named_parameters()
):
    if p1.grad is not None and p2.grad is not None:
        diff = max_abs_diff(p1.grad, p2.grad)
        print(f"[Grad] Param {n1:20s} grad diff: {diff:.2e}")

In [3]:
from flashrnn.flashrnn import flashrnn


In [82]:
seq_len = 1024
DH = hidden_size // num_head
NH=num_head

gate_linear = torch.nn.Linear(input_size, num_gate * hidden_size).to(
    device=device, dtype=dtype
)
with torch.no_grad():
    gate_linear.weight.fill_(1.0)
    if gate_linear.bias is not None:
        gate_linear.bias.fill_(1.0)
    # gate_linear.weight.fill_(0.0)
    # if gate_linear.bias is not None:
    #     gate_linear.bias.fill_(0.0)

x = torch.randn(
    batch, seq_len, input_size, device="cuda", requires_grad=True, dtype=dtype
)
# R=torch.cat([torch.ones((total_elems//2,),device=device,dtype=dtype),torch.full((total_elems//2,),2.0,device=device,dtype=dtype)],dim=0)
# R.requires_grad_(requires_grad)
print("R_g:",R_g)
R=R_g.view(num_gate, NH, DH, DH)
print(R.shape)
print(R)
# R = torch.randn(
#     [num_gate, NH, DH, DH],
#     # [NH,DH,num_gate,DH],

#     device=device,
#     dtype=dtype,
#     requires_grad=requires_grad,
# ) 

b = torch.ones(
    [num_gate, NH, DH],
    device=device,
    dtype=dtype,
    requires_grad=requires_grad,
)
R_mtr = R.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)
b_mtr = b.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)

Wx = gate_linear(x)
Wx = Wx.reshape(
    Wx.shape[0], Wx.shape[1], R.shape[0], R.shape[1], R.shape[2]
)
# Wx = Wx.reshape(
#     seq_len, batch, num_head, DH, num_gate
# )
Wx_mtr = Wx.clone().to(dtype=dtype).detach().requires_grad_(requires_grad)


print(Wx.shape)
r=R
r=r.permute(1,3,0,2)
# print(Wx)
print(r.shape)

R_g: tensor([-0.9258, -0.4258, -2.6406,  ...,  0.2334, -0.6875,  1.0859],
       device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
torch.Size([3, 1, 64, 64])
tensor([[[[-0.9258, -0.4258, -2.6406,  ...,  0.1211,  0.4727, -1.0859],
          [-0.0334, -0.9727,  0.9570,  ..., -0.2129, -0.3320, -0.2021],
          [-1.1484, -0.5703, -0.6523,  ...,  0.2148, -0.7383, -0.4512],
          ...,
          [ 0.1680, -0.6953, -0.8633,  ...,  0.6797,  0.6445, -0.0197],
          [-0.0153,  0.6875,  0.1709,  ..., -0.5312,  0.0552, -0.0664],
          [ 0.6289,  0.9922,  0.9062,  ...,  1.6016, -1.0703,  1.6016]]],


        [[[ 2.2969,  0.1514,  2.0625,  ..., -0.8906,  1.1094,  0.5977],
          [-0.4355,  0.4609,  0.3574,  ...,  0.1748, -1.0547, -1.2266],
          [ 1.0000,  0.7227, -1.6250,  ..., -0.8008, -0.1719,  0.0952],
          ...,
          [-1.4219,  1.5703, -0.2891,  ...,  0.1226, -1.1875,  1.8359],
          [-0.7891,  0.8711,  0.1660,  ...,  0.5742, -0.6367,  1.2969],
       

In [88]:
# Models
ref_lstm = nn.GRU(
    input_size,
    hidden_size,
    num_layers,
    bias=True,
    batch_first=True,
    bidirectional=False,
).to(device=device, dtype=dtype)

initialize_ref_lstm_constant(ref_lstm)

# ========== 2. 同步 Recurrent 权重 R ==========
# 转换成 [4H, H] 形式，用于赋值给 ref_lstm.weight_hh_l0
# 步骤：
R_perm = R         # [4, NH, D, D]
R_reshaped = R_perm.reshape(3, hidden_size, hidden_size)  # [3, H, D]
weight_hh = R_reshaped.reshape(3 * hidden_size, hidden_size)  # [3H, D]
# 赋值到 ref_lstm 的 recurrent 权重
ref_lstm.weight_hh_l0.data.copy_(weight_hh)
print("ref_lstm.weight_hh_l0.data:",ref_lstm.weight_hh_l0.data)


# Inputs
# x = torch.randn(batch, seq_len, input_size, device="cuda", requires_grad=False)
# h0 = torch.ones(
#     num_layers, batch, hidden_size, device="cuda", requires_grad=True, dtype=dtype
# )
h0=torch.full((num_layers, batch, hidden_size),0,device="cuda", requires_grad=True, dtype=dtype)


# Clone inputs for reference
x_ref = x.detach().clone().requires_grad_()
h0_ref = h0.detach().clone().requires_grad_()
h0=h0.reshape(1,batch,num_head,hidden_size)
# Forward
out_ref, hn_ref = ref_lstm(x_ref, h0_ref)      #
# out_my, last_h = flashrnn(Wx,R,b, function="lstm",backend="cuda_fused",dtype=dtype_str)
out_my, hn_my = flashrnn(Wx,R,b, states=h0[None, :],function="gru",backend="cuda",dtype=dtype_str)
out_my=out_my.reshape(batch, seq_len, hidden_size)
hn_my=hn_my.reshape(num_layers, batch, hidden_size)
# cn_my=cn_my.reshape(num_layers, batch, hidden_size)

# out_ref, (hn_ref, cn_ref) = ref_lstm(x_ref)
# out_my, (hn_my, cn_my) = my_lstm(x)
# print("out my: ", out_my)
# print("out ref:", out_ref)
print("out_my shape: ", out_my.shape)    #  [NS, B,T,NH,D]
print("out_ref shape: ", out_ref.shape)  # [B,T,H]
# # Backward
loss_my = out_my.sum()
loss_ref = out_ref.sum()
loss_my.backward()
loss_ref.backward()

ref_lstm.weight_hh_l0.data: tensor([[-0.9258, -0.4258, -2.6406,  ...,  0.1211,  0.4727, -1.0859],
        [-0.0334, -0.9727,  0.9570,  ..., -0.2129, -0.3320, -0.2021],
        [-1.1484, -0.5703, -0.6523,  ...,  0.2148, -0.7383, -0.4512],
        ...,
        [-0.4902, -1.7734, -0.1094,  ..., -0.7617,  1.0000,  1.2500],
        [-1.4844,  1.0312, -1.0234,  ...,  0.1895, -1.3516, -2.2969],
        [-1.2422, -1.7422,  0.4453,  ...,  0.2334, -0.6875,  1.0859]],
       device='cuda:0', dtype=torch.bfloat16)
function:  gru
backend:  cuda
Wx:  torch.Size([1024, 16, 1, 3, 64])
R:  torch.Size([1, 64, 3, 64])
b:  torch.Size([1, 3, 64])
input states:  torch.Size([1, 16, 1, 1, 64])
recurrent shape:  torch.Size([1, 64, 3, 64])
bias shape:  torch.Size([1, 3, 64])
Wx shape:  torch.Size([1024, 16, 1, 3, 64])
R :  tensor([[[[-0.9258, -0.0334, -1.1484,  ...,  0.1680, -0.0153,  0.6289],
          [ 2.2969, -0.4355,  1.0000,  ..., -1.4219, -0.7891,  0.8359],
          [-0.8789,  1.2500,  1.3906,  ..., -0.

In [84]:
print(out_my[:,0,:])
print(out_ref[:,0,:])

tensor([[0.0105, 0.0105, 0.0105,  ..., 0.0105, 0.0105, 0.0105],
        [0.1680, 0.1680, 0.1680,  ..., 0.1680, 0.2695, 0.2656],
        [0.0267, 0.0267, 0.0267,  ..., 0.0267, 0.0267, 0.0267],
        ...,
        [0.0596, 0.0596, 0.0596,  ..., 0.0596, 0.0601, 0.0601],
        [0.1035, 0.1035, 0.1035,  ..., 0.1035, 0.1074, 0.1074],
        [0.1719, 0.1719, 0.1719,  ..., 0.1719, 0.2539, 0.2520]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
tensor([[0.0105, 0.0105, 0.0105,  ..., 0.0105, 0.0105, 0.0105],
        [0.1680, 0.1680, 0.1680,  ..., 0.1680, 0.1680, 0.1680],
        [0.0265, 0.0265, 0.0265,  ..., 0.0265, 0.0265, 0.0265],
        ...,
        [0.0596, 0.0596, 0.0596,  ..., 0.0596, 0.0596, 0.0596],
        [0.1035, 0.1035, 0.1035,  ..., 0.1035, 0.1035, 0.1035],
        [0.1719, 0.1719, 0.1719,  ..., 0.1719, 0.1719, 0.1719]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SliceBackward0>)


In [85]:
hn_my=hn_my.reshape(num_layers, batch, hidden_size)
print(hn_my.shape)

print(hn_my)
print(hn_ref)

torch.Size([1, 16, 64])
tensor([[[-0.5078,  0.1138,  1.0000,  ...,  0.4961,  1.0000, -0.7188],
         [-0.4043,  0.7500,  0.9023,  ..., -0.4551,  1.0000,  0.3516],
         [-0.7070, -0.5039,  1.0000,  ...,  0.4531,  1.0000, -0.2656],
         ...,
         [-0.4844, -0.2656,  0.9844,  ...,  0.6992,  1.0000, -0.1245],
         [-0.8555, -0.1211,  0.7734,  ...,  0.7422,  1.0000, -0.7461],
         [-0.6953,  0.4336,  0.9805,  ...,  0.7227,  1.0000, -0.0457]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>)
tensor([[[ 0.9727,  0.4199,  0.2227,  ..., -0.7422, -0.8086,  0.9570],
         [ 0.8281,  0.1367,  1.0000,  ..., -0.0771, -0.7930,  0.3398],
         [ 0.8555,  1.0000,  0.8516,  ...,  0.9727, -0.3027,  0.9922],
         ...,
         [-0.9180, -0.9961, -0.9961,  ..., -0.1216,  0.6641, -0.0640],
         [-0.9062, -1.0000,  0.9688,  ...,  0.1904, -0.5703,  0.9531],
         [ 0.1895, -1.0000,  0.5508,  ..., -0.9023, -0.8438,  0.3828]]],
       device='cuda:0

In [86]:
result1 = hn_my[0, :, 0].tolist()  # list of 16 elements
result2 = hn_ref[0, :, 0].tolist()  # list of 16 elements

print(result1)
print(result2)

[-0.5078125, -0.404296875, -0.70703125, 0.099609375, -0.451171875, 0.05029296875, -0.5546875, 0.11572265625, -0.15234375, -0.96484375, -0.6953125, -0.2421875, -0.46875, -0.484375, -0.85546875, -0.6953125]
[0.97265625, 0.828125, 0.85546875, 0.59375, 0.4921875, 0.7265625, 0.94140625, 0.0791015625, 0.7109375, 0.8203125, -0.7265625, 0.455078125, 0.81640625, -0.91796875, -0.90625, 0.189453125]


In [87]:
# Output comparison
print(f"[Forward] Output max abs diff: {max_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max abs diff: {max_abs_diff(hn_my, hn_ref):.2e}")
# print(f"[Forward] cn     max abs diff: {max_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean abs diff: {mean_abs_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean abs diff: {mean_abs_diff(hn_my, hn_ref):.2e}")
# print(f"[Forward] cn     mean abs diff: {mean_abs_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output max ref diff: {max_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     max ref diff: {max_ref_diff(hn_my, hn_ref):.2e}")
# print(f"[Forward] cn     max ref diff: {max_ref_diff(cn_my, cn_ref):.2e}")
print(f"[Forward] Output mean ref diff: {mean_ref_diff(out_my, out_ref):.2e}")
print(f"[Forward] hn     mean ref diff: {mean_ref_diff(hn_my, hn_ref):.2e}")
# print(f"[Forward] cn     mean ref diff: {mean_ref_diff(cn_my, cn_ref):.2e}")

[Forward] Output max abs diff: 2.00e+00
[Forward] hn     max abs diff: 2.00e+00
[Forward] Output mean abs diff: 8.40e-01
[Forward] hn     mean abs diff: 8.48e-01
[Forward] Output max ref diff: 5.90e+05
[Forward] hn     max ref diff: 3.72e+02
[Forward] Output mean ref diff: 6.91e+00
[Forward] hn     mean ref diff: 3.08e+00


In [90]:

# Gradients
print(f"[Grad] Input x     grad diff: {max_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_abs_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {max_ref_diff(x.grad, x_ref.grad):.2e}")
print(f"[Grad] Input x     grad diff: {mean_ref_diff(x.grad, x_ref.grad):.2e}")

print(f"[Grad] h0          grad diff: {max_abs_diff(h0[None, :].grad, h0_ref.grad):.2e}")

for (n1, p1), (n2, p2) in zip(
    my_lstm.named_parameters(), ref_lstm.named_parameters()
):
    if p1.grad is not None and p2.grad is not None:
        diff = max_abs_diff(p1.grad, p2.grad)
        print(f"[Grad] Param {n1:20s} grad diff: {diff:.2e}")

[Grad] Input x     grad diff: nan
[Grad] Input x     grad diff: nan
[Grad] Input x     grad diff: nan
[Grad] Input x     grad diff: nan


  print(f"[Grad] h0          grad diff: {max_abs_diff(h0[None, :].grad, h0_ref.grad):.2e}")


TypeError: unsupported operand type(s) for -: 'NoneType' and 'Tensor'