In [22]:
%%writefile test_conv2d_fp16.cu
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <iostream>
#include <random>

// float32 kernel
__global__ void conv2d_float(const float* input, const float* kernel, float* output,
                            int input_rows, int input_cols,
                            int kernel_rows, int kernel_cols,
                            int output_rows, int output_cols) {
    unsigned int out_col = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int out_row = blockIdx.y * blockDim.y + threadIdx.y;

    float sum = 0.0f;
    if (out_row < output_rows && out_col < output_cols) {
        for (int m = 0; m < kernel_rows; ++m) {
            for (int n = 0; n < kernel_cols; ++n) {
                int input_row = out_row + m;
                int input_col = out_col + n;
                float input_value = input[input_row * input_cols + input_col];
                float kernel_value = kernel[m * kernel_cols + n];
                sum += input_value * kernel_value;
            }
        }
        output[out_row * output_cols + out_col] = sum;
    }
}

// FP16 kernel =
__global__ void conv2d_fp16(const __half* input, const __half* kernel,
                          __half* output,
                          int input_rows, int input_cols,
                          int kernel_rows, int kernel_cols,
                          int output_rows, int output_cols) {
    unsigned int out_col = blockIdx.x * blockDim.x + threadIdx.x;
    unsigned int out_row = blockIdx.y * blockDim.y + threadIdx.y;

    __half sum = __float2half(0.0f);
    if (out_row < output_rows && out_col < output_cols) {
        for (int m = 0; m < kernel_rows; ++m) {
            for (int n = 0; n < kernel_cols; ++n) {
                int input_row = out_row + m;
                int input_col = out_col + n;
                __half input_value = input[input_row * input_cols + input_col];
                __half kernel_value = kernel[m * kernel_cols + n];
                sum = __hadd(sum, __hmul(input_value, kernel_value)); // Native FP16 ops
            }
        }
        output[out_row * output_cols + out_col] = sum;
    }
}

extern "C" float run_float_kernel(float* input, float* kernel, float* output,
                                 int input_rows, int input_cols,
                                 int kernel_rows, int kernel_cols) {
    float *d_input, *d_kernel, *d_output;
    int output_rows = input_rows - (kernel_rows - 1);
    int output_cols = input_cols - (kernel_cols - 1);

    cudaMalloc((void**)&d_input, input_rows * input_cols * sizeof(float));
    cudaMalloc((void**)&d_kernel, kernel_rows * kernel_cols * sizeof(float));
    cudaMalloc((void**)&d_output, output_rows * output_cols * sizeof(float));

    cudaMemcpy(d_input, input, input_rows * input_cols * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_kernel, kernel, kernel_rows * kernel_cols * sizeof(float), cudaMemcpyHostToDevice);

    dim3 threadsPerBlock(16, 16);
    dim3 numBlocks((output_cols + threadsPerBlock.x - 1) / threadsPerBlock.x,
                  (output_rows + threadsPerBlock.y - 1) / threadsPerBlock.y);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    conv2d_float<<<numBlocks, threadsPerBlock>>>(d_input, d_kernel, d_output,
                                                input_rows, input_cols,
                                                kernel_rows, kernel_cols,
                                                output_rows, output_cols);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);

    cudaMemcpy(output, d_output, output_rows * output_cols * sizeof(float), cudaMemcpyDeviceToHost);

    cudaFree(d_input); cudaFree(d_kernel); cudaFree(d_output);
    cudaEventDestroy(start); cudaEventDestroy(stop);

    return milliseconds / 1000.0f; // Convert to seconds
}

extern "C" float run_fp16_kernel(float* input, float* kernel, float* output,
                                int input_rows, int input_cols,
                                int kernel_rows, int kernel_cols) {
    __half *d_input, *d_kernel, *d_output;
    int output_rows = input_rows - (kernel_rows - 1);
    int output_cols = input_cols - (kernel_cols - 1);

    // Convert to FP16 once on host
    __half *h_input_fp16 = new __half[input_rows * input_cols];
    __half *h_kernel_fp16 = new __half[kernel_rows * kernel_cols];
    for (int i = 0; i < input_rows * input_cols; i++)
        h_input_fp16[i] = __float2half(input[i]);
    for (int i = 0; i < kernel_rows * kernel_cols; i++)
        h_kernel_fp16[i] = __float2half(kernel[i]);

    cudaMalloc((void**)&d_input, input_rows * input_cols * sizeof(__half));
    cudaMalloc((void**)&d_kernel, kernel_rows * kernel_cols * sizeof(__half));
    cudaMalloc((void**)&d_output, output_rows * output_cols * sizeof(__half));

    cudaMemcpy(d_input, h_input_fp16, input_rows * input_cols * sizeof(__half), cudaMemcpyHostToDevice);
    cudaMemcpy(d_kernel, h_kernel_fp16, kernel_rows * kernel_cols * sizeof(__half), cudaMemcpyHostToDevice);

    dim3 threadsPerBlock(16, 16);
    dim3 numBlocks((output_cols + threadsPerBlock.x - 1) / threadsPerBlock.x,
                  (output_rows + threadsPerBlock.y - 1) / threadsPerBlock.y);

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    conv2d_fp16<<<numBlocks, threadsPerBlock>>>(d_input, d_kernel, d_output,
                                               input_rows, input_cols,
                                               kernel_rows, kernel_cols,
                                               output_rows, output_cols);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);

    __half *h_output_fp16 = new __half[output_rows * output_cols];
    cudaMemcpy(h_output_fp16, d_output, output_rows * output_cols * sizeof(__half), cudaMemcpyDeviceToHost);

    for (int i = 0; i < output_rows * output_cols; i++)
        output[i] = __half2float(h_output_fp16[i]);

    cudaFree(d_input); cudaFree(d_kernel); cudaFree(d_output);
    delete[] h_input_fp16; delete[] h_kernel_fp16; delete[] h_output_fp16;
    cudaEventDestroy(start); cudaEventDestroy(stop);

    return milliseconds / 1000.0f; // Convert to seconds
}

Overwriting test_conv2d_fp16.cu


In [23]:
# compile the CUDA code into a shared library
!nvcc -shared -o test_conv2d_fp16.so test_conv2d_fp16.cu -arch=sm_75 -Xcompiler -fPIC

In [26]:
import numpy as np
from ctypes import *
import time

lib = cdll.LoadLibrary('./test_conv2d_fp16.so')

# update function signatures to return float
lib.run_float_kernel.argtypes = [POINTER(c_float), POINTER(c_float), POINTER(c_float),
                                c_int, c_int, c_int, c_int]
lib.run_float_kernel.restype = c_float
lib.run_fp16_kernel.argtypes = [POINTER(c_float), POINTER(c_float), POINTER(c_float),
                               c_int, c_int, c_int, c_int]
lib.run_fp16_kernel.restype = c_float

# test data
input_rows, input_cols = 1024 * 3, 1024 * 3
kernel_rows, kernel_cols = 128, 128
output_rows = input_rows - (kernel_rows - 1)
output_cols = input_cols - (kernel_cols - 1)

input_data = np.random.randn(input_rows, input_cols).astype(np.float32)
kernel_data = np.random.randn(kernel_rows, kernel_cols).astype(np.float32)
output_float = np.zeros((output_rows, output_cols), dtype=np.float32)
output_fp16 = np.zeros((output_rows, output_cols), dtype=np.float32)

input_ptr = input_data.ctypes.data_as(POINTER(c_float))
kernel_ptr = kernel_data.ctypes.data_as(POINTER(c_float))
output_float_ptr = output_float.ctypes.data_as(POINTER(c_float))
output_fp16_ptr = output_fp16.ctypes.data_as(POINTER(c_float))

num_runs = 100
warmup_runs = 10

float_times = []
fp16_times = []

# warm-up
for _ in range(warmup_runs):
    lib.run_float_kernel(input_ptr, kernel_ptr, output_float_ptr,
                        input_rows, input_cols, kernel_rows, kernel_cols)
    lib.run_fp16_kernel(input_ptr, kernel_ptr, output_fp16_ptr,
                       input_rows, input_cols, kernel_rows, kernel_cols)

# benchmark
for _ in range(num_runs):
    float_time = lib.run_float_kernel(input_ptr, kernel_ptr, output_float_ptr,
                                     input_rows, input_cols, kernel_rows, kernel_cols)
    float_times.append(float_time)

    fp16_time = lib.run_fp16_kernel(input_ptr, kernel_ptr, output_fp16_ptr,
                                   input_rows, input_cols, kernel_rows, kernel_cols)
    fp16_times.append(fp16_time)

# stats
float_mean = np.mean(float_times)
float_std = np.std(float_times)
fp16_mean = np.mean(fp16_times)
fp16_std = np.std(fp16_times)

diff = np.abs(output_float - output_fp16)
max_error = np.max(diff)
mean_error = np.mean(diff)
mse = np.mean(diff**2)

print(f"Float32 kernel time: {float_mean:.6f} ± {float_std:.6f} seconds ({num_runs} runs)")
print(f"FP16 kernel time: {fp16_mean:.6f} ± {fp16_std:.6f} seconds ({num_runs} runs)")
print(f"Speedup (Float32/FP16): {float_mean/fp16_mean:.2f}x")
print(f"Maximum absolute error: {max_error:.6f}")
print(f"Mean absolute error: {mean_error:.6f}")
print(f"Mean squared error: {mse:.6f}")

print("\nFirst 5x5 of Float32 output:")
print(output_float[:5, :5])
print("\nFirst 5x5 of FP16 output:")
print(output_fp16[:5, :5])

Float32 kernel time: 0.000000 ± 0.000000 seconds (100 runs)
FP16 kernel time: 0.000000 ± 0.000000 seconds (100 runs)
Speedup (Float32/FP16): 1.46x
Maximum absolute error: 0.289307
Mean absolute error: 0.036147
Mean squared error: 0.002109

First 5x5 of Float32 output:
[[ -55.84884     25.658815  -413.50046    259.5033    -109.114334 ]
 [ -58.657173    59.219414   218.30032     57.322186    -8.62942  ]
 [-166.65118    214.80194    314.93027   -192.7418    -111.4903   ]
 [ 186.23802   -160.4699    -162.89262    -34.438816    17.689949 ]
 [ 302.5674      -6.6662626  217.40765     16.765682   249.54005  ]]

First 5x5 of FP16 output:
[[ -55.84375     25.671875  -413.5        259.5       -109.0625   ]
 [ -58.6875      59.21875    218.25        57.25        -8.6171875]
 [-166.625      214.75       315.        -192.75      -111.5      ]
 [ 186.25      -160.5       -162.875      -34.4375      17.6875   ]
 [ 302.5         -6.703125   217.375       16.8125     249.5      ]]
