In [None]:
import ctypes
import numpy as np
import torch
import os

# 假设 hip_kernel.so 是编译后的共享库
hip_kernel = ctypes.cdll.LoadLibrary("./hip_kernel.so")

# 假设 kernel 函数名为 run_fp8_gemm，并且是用 extern "C" 暴露的
hip_kernel.run_fp8_gemm.argtypes = [
    ctypes.POINTER(ctypes.c_uint8),   # a: float8_e4m3fnuz
    ctypes.POINTER(ctypes.c_uint8),   # b: float8_e4m3fnuz
    ctypes.POINTER(ctypes.c_float),   # a_scale
    ctypes.POINTER(ctypes.c_float),   # b_scale
    ctypes.POINTER(ctypes.c_uint16),  # c: bf16
    ctypes.c_int, ctypes.c_int, ctypes.c_int  # m, n, k
]

def custom_kernel(data: input_t) -> output_t:
    a, b, a_scale, b_scale, c = data
    m, k = a.shape
    n = b.shape[0]

    # 转换 float8 → uint8（将 view 为 uint8）
    a_np = a.contiguous().view(torch.uint8).cpu().numpy()
    b_np = b.contiguous().view(torch.uint8).cpu().numpy()

    # scaling 因子为 float32
    a_scale_np = a_scale.contiguous().cpu().numpy().astype(np.float32)
    b_scale_np = b_scale.contiguous().cpu().numpy().astype(np.float32)

    # bf16 → uint16（bf16 是 PyTorch 中 16bit）
    c_np = c.contiguous().view(torch.uint16).cpu().numpy()

    # 调用 HIP kernel
    hip_kernel.run_fp8_gemm(
        a_np.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)),
        b_np.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)),
        a_scale_np.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        b_scale_np.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        c_np.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)),
        ctypes.c_int(m), ctypes.c_int(n), ctypes.c_int(k)
    )

    # 将结果写回 PyTorch Tensor（注意数据类型）
    c.copy_(torch.from_numpy(c_np).view(torch.bfloat16).to("cuda"))

    return c


In [1]:
# This script provides a template for using load_inline to run a HIP kernel for
from torch.utils.cpp_extension import load_inline
from task import input_t, output_t
CPP_WRAPPER = """
void fp8_mm(torch::Tensor a, torch::Tensor b, torch::Tensor as, torch::Tensor bs, torch::Tensor c);
"""

CUDA_SRC = """
#include <hip/amd_detail/amd_hip_fp8.h>
#include <hip/amd_detail/amd_hip_bf16.h>

constexpr const int BLOCK = 128;

__global__ void custom_kernel(const __hip_fp8_e4m3_fnuz* a, const __hip_fp8_e4m3_fnuz* b, const float* as, const float* bs, 
                   __hip_bfloat16* c, int m, int n, int k) {
                   
    // Your implementation here
    int cx = threadIdx.x + blockDim.x * blockIdx.x;
    int cy = threadIdx.y + blockDim.y * blockIdx.y;
    if(cx >= m || cy >= n) return;
    
    int sn = (n + BLOCK - 1) / BLOCK;
    
    float result = 0;
    // split loop into an outer loop over different blocks, and an inner loop within one block.
    // we can assume k % BLOCK == 0.
    for(int i = 0; i < k; i += BLOCK) {
        // block results accumulates the inner product across a single block.
        // within each block, scales are constant, so we can lift the scaling 
        // outside of the inner loop.
        float block_result = 0;
        for(int ii = 0; ii < BLOCK; ++ii) {
            // load input matrix elements and convert to float for computations
            float av = (float)a[cx + (i + ii) * m];
            float bv = (float)b[cy + (i + ii) * n];
            block_result += av * bv; 
        }
        
        // before we can go to the next block, scale the result of the current block
        // and accumulate to final result
        // note the different indexing into as and bs
        result += block_result * as[cx + i/BLOCK * m] * bs[cy/BLOCK + i/BLOCK * sn];
    }
    
    // finally, write the result as bf16
    c[cx * n + cy] = (__hip_bfloat16)result;
}

void fp8_mm(torch::Tensor a, torch::Tensor b, torch::Tensor as, torch::Tensor bs, torch::Tensor c) {
    int m = a.size(0);
    int n = b.size(0);
    int k = a.size(1);
    custom_kernel<<<dim3((m+15)/16, (n+15)/16), dim3(16, 16), 0, 0>>> ((__hip_fp8_e4m3_fnuz*)a.data_ptr(), (__hip_fp8_e4m3_fnuz*)b.data_ptr(), 
    as.data_ptr<float>(), bs.data_ptr<float>(), (__hip_bfloat16*)c.data_ptr(), m, n, k);
    //C10_CUDA_CHECK(cudaGetLastError());
}
"""

import os
os.environ["CXX"] = "/opt/rocm-6.3.3/bin/hipcc"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "11.0.0"

# os.environ["RANK"] = "0"          # 单GPU需设置 RANK=0
# os.environ["WORLD_SIZE"] = "1"    # 总进程数为1
# os.environ["MASTER_ADDR"] = "localhost"  # 设置主节点地址
# os.environ["MASTER_PORT"] = "12355"      # 设置通信端口

module = load_inline(
    name='fp8_mm',
    cpp_sources=[CPP_WRAPPER],
    cuda_sources=[CUDA_SRC],
    functions=['fp8_mm'],
    verbose=True,
    extra_cuda_cflags=["--offload-arch=gfx1100", "-std=c++20"],
)


def custom_kernel(data: input_t) -> output_t:
    a, b, a_scale, b_scale, c = data
    module.fp8_mm(a, b, a_scale, b_scale, c)
    return c



/home/qin/.cache/torch_extensions/py312_cpu/fp8_mm/main.cpp -> /home/qin/.cache/torch_extensions/py312_cpu/fp8_mm/main.cpp [skipped, no changes]
/home/qin/.cache/torch_extensions/py312_cpu/fp8_mm/cuda.cu -> /home/qin/.cache/torch_extensions/py312_cpu/fp8_mm/hip.hip [skipped, already hipified]


Using /home/qin/.cache/torch_extensions/py312_cpu as PyTorch extensions root...
[92mSuccessfully preprocessed all matching files.[0m


Total number of unsupported CUDA function calls: 0


Total number of replaced kernel launches: 1


CalledProcessError: Command '['/opt/rocm-6.3.3/bin/hipcc', '-v']' returned non-zero exit status 1.