In [8]:
import ctypes
import os
import numpy as np

# NOTE: LD_LIBRARY_PATH is consulted by the dynamic loader when the process
# starts, so assigning it through os.environ at this point would generally not
# change how the CDLL below is resolved; export it in the shell before
# launching Python/Jupyter if the ROCm libraries are not found.

# Must be set before the HIP runtime used by lin.so initializes (i.e. before
# the library is loaded/called below). Presumably this makes ROCm treat the
# GPU as gfx1100 ("11.0.0") -- confirm this matches your hardware.
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"

# Load the shared library that exports the `linear` launcher.
hip_lib = ctypes.CDLL('/home/qin/rocm_test/MoE/lin.so')

# Declare the argument types for `linear` so ctypes converts them correctly:
# three pointers to the float16 host buffers, then the four problem sizes.
hip_lib.linear.argtypes = [
    ctypes.c_void_p,  # a: input activations
    ctypes.c_void_p,  # b: weight matrix
    ctypes.c_void_p,  # d: output buffer
    ctypes.c_int,     # batch_size
    ctypes.c_int,     # seq_len
    ctypes.c_int,     # d_hidden
    ctypes.c_int,     # d_expert
]
# NOTE(review): restype is left at the ctypes default (c_int). If `linear`
# returns void or a hipError_t, set restype accordingly and check the result.

# Example problem sizes:
#   a: (batch_size, seq_len, d_hidden), stored row-major
#   b: (d_hidden, d_expert), stored column-major
#   d: (batch_size, seq_len, d_expert), stored row-major
batch_size = 2
seq_len = 3
d_hidden = 4
d_expert = 5

# Fixed seed so a reported mismatch can be reproduced on the next run.
rng = np.random.default_rng(0)

# Flat host buffers; float16 matches the kernel's __half element type.
a = rng.random(batch_size * seq_len * d_hidden).astype(np.float16)
b = rng.random(d_hidden * d_expert).astype(np.float16)
d = np.zeros(batch_size * seq_len * d_expert, dtype=np.float16)

# Launch the kernel; the result is read back from `d` afterwards.
# NOTE(review): the return value is discarded -- if `linear` reports errors
# through it, check it here.
hip_lib.linear(
    a.ctypes.data, b.ctypes.data, d.ctypes.data,
    batch_size, seq_len, d_hidden, d_expert,
)

# CPU reference result.
# a: (batch_size, seq_len, d_hidden), row-major
# b: (d_hidden, d_expert), column-major
# d = a @ b, flattened row-major as (batch_size * seq_len, d_expert)
#
# Computed as one NumPy matmul instead of a 4-deep Python loop. Operands are
# upcast to float64 (the precision a scalar Python-float loop accumulates in),
# and the result is stored as float32 in d's flat layout.
# order="F" reads b's flat buffer column-major: b[k + j * d_hidden] -> [k, j].
ref_d = (
    a.astype(np.float64).reshape(batch_size * seq_len, d_hidden)
    @ b.astype(np.float64).reshape((d_hidden, d_expert), order="F")
).astype(np.float32).ravel()

# Verification: compare the GPU output element-wise against the CPU reference.
atol = 1e-2  # absolute tolerance per element
# NOTE(review): an absolute 1e-2 is comfortable for d_hidden=4; longer
# reductions accumulate more float16 rounding error and may need a relative
# tolerance instead.

abs_err = np.abs(d.astype(np.float64) - ref_d.astype(np.float64))
# Written as "not (err <= atol)" rather than "err > atol" so that NaN output
# from the GPU counts as a mismatch (any comparison with NaN is False, which
# let NaN results pass the original check).
bad = np.flatnonzero(~(abs_err <= atol))
for i in bad:
    v1 = float(d[i])
    v2 = ref_d[i]
    print(f"Mismatch at {i}: GPU={v1} CPU={v2}")
ok = bad.size == 0
if ok:
    print("Linear result matches CPU reference!")
else:
    print("Linear result does NOT match CPU reference!")

# Print the full GPU output (upcast from float16 for readable formatting),
# the full CPU reference, and a short preview of the first five GPU values.
for label, values in (
    ("GPU output (all values):", d.astype(np.float32)),
    ("CPU reference output (all values):", ref_d),
    ("First 5 output values (GPU):", d[:5].astype(np.float32)),
):
    print(label, values)

Linear result matches CPU reference!
GPU output (all values): [0.5439453  0.95410156 0.52490234 0.5830078  0.578125   0.5214844
 1.1601562  1.1591797  0.9111328  0.5878906  0.7944336  1.5126953
 1.2333984  1.1181641  0.828125   0.953125   1.6650391  1.3037109
 1.3642578  0.7475586  0.6616211  1.4482422  1.3203125  1.0341797
 0.8203125  1.015625   1.8759766  1.3730469  1.3076172  1.0751953 ]
CPU reference output (all values): [0.54416555 0.95363384 0.52483964 0.58296454 0.5779909  0.52145296
 1.1604624  1.1591532  0.91094786 0.58769035 0.7943734  1.5124687
 1.2331318  1.1180426  0.82813233 0.9528602  1.6646446  1.304158
 1.364641   0.7474949  0.66152084 1.4476941  1.3202235  1.0342244
 0.8203589  1.0151855  1.8768799  1.3726771  1.307194   1.0748775 ]
First 5 output values (GPU): [0.5439453  0.95410156 0.52490234 0.5830078  0.578125  ]
