In [16]:
import torch
from tool_function import *
from dequant_function import *
import os
from pytorch_memlab import LineProfiler, profile
dev = 'cuda'

In [17]:
def unittest_deqaunt_cuda(param_list, tensor_buffer, tensor_buffer_listview, groupsize=128, quant_bits=4, quant_module=None, hadamard=False):
    tensor_type = tensor_buffer.dtype
    dequant_tensor_cuda = torch.empty_like(tensor_buffer)

    # for i in range(len(param_list)):
    #     tensor_buffer_listview[i].sub_(param_list[i])
    torch._foreach_sub_(tensor_buffer_listview, param_list)
    print('after sub', tensor_buffer)

    if hadamard is True:
        h_tensor = tensor_buffer.clone()
        h_tensor = fast_hadamard_transform(h_tensor, k=5, normalize=True)

    # stochastic quantize kernel
    N = tensor_buffer.nelement()
    groups = N // groupsize
    quant_tensor_cuda, quant_scales_cuda = quant_module.stochastic_quantize(tensor_buffer, groups, quant_bits, quant_module.Symmetric)
    
    if tensor_type is torch.bfloat16:
        quant_module.dequantize_bf16(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)
    elif tensor_type is torch.float32:
        quant_module.dequantize_fp32(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)
    elif tensor_type is torch.float16:
        quant_module.dequantize_half(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)
    else:
        assert(False), "dequant_type is not supported"

    if hadamard is True:
        dequant_tensor_cuda = fast_hadamard_transform(dequant_tensor_cuda, k=5, normalize=True)

    abs_error_norm, rela_error_norm = analysis_diff(tensor_buffer, dequant_tensor_cuda)
    print(f"cuda version quantization, absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}")
    return dequant_tensor_cuda

In [18]:
def unittest_deqaunt_cuda_fused(param_list, tensor_buffer, tensor_buffer_listview, groupsize=128, quant_bits=4, quant_module=None, hadamard=False):
    dp_size = 4
    dp_rank = 2
    assert tensor_buffer.numel() % dp_size == 0
    tensor_buffer_dp_view = []
    for i in range(dp_size):
        param_buffer_size = tensor_buffer.numel() // dp_size
        tensor_buffer_dp_view.append(tensor_buffer[i*param_buffer_size: (i+1)*param_buffer_size])

    tensor_type = tensor_buffer.dtype
    dequant_tensor_cuda = torch.empty_like(tensor_buffer_dp_view[dp_rank])

    if hadamard is True:
        h_tensor = tensor_buffer.clone()
        h_tensor = fast_hadamard_transform(h_tensor, k=5, normalize=True)

    # stochastic quantize kernel
    N = tensor_buffer.nelement()
    groups = N // groupsize
    dp_param_offset = dp_rank * tensor_buffer.numel() // dp_size
    quant_tensor_cuda, quant_scales_cuda = quant_module.sub_quantize(tensor_buffer_dp_view[dp_rank], param_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)
    
    if tensor_type is torch.bfloat16:
        quant_module.dequantize_bf16(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)
    elif tensor_type is torch.float32:
        quant_module.dequantize_fp32(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)
    elif tensor_type is torch.float16:
        quant_module.dequantize_half(quant_tensor_cuda, quant_scales_cuda, dequant_tensor_cuda, groups, quant_bits, quant_module.Symmetric)
    else:
        assert(False), "dequant_type is not supported"

    if hadamard is True:
        dequant_tensor_cuda = fast_hadamard_transform(dequant_tensor_cuda, k=5, normalize=True)

    torch._foreach_sub_(tensor_buffer_listview, param_list)
    abs_error_norm, rela_error_norm = analysis_diff(tensor_buffer_dp_view[dp_rank], dequant_tensor_cuda)
    print(f"cuda version quantization, absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}")
    return dequant_tensor_cuda

In [19]:
pkg_path = "/home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda"
print('pkg path:', pkg_path)
quantization_module = build_and_import_module(pkg_path, 'quantization_cuda')

pkg path: /home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda
running build
running build_ext
building 'quantization_cuda' extension
ninja: no work to do.
g++ -pthread -B /home/jindjia/miniforge3/envs/megatron/compiler_compat -shared -Wl,--allow-shlib-undefined -Wl,-rpath,/home/jindjia/miniforge3/envs/megatron/lib -Wl,-rpath-link,/home/jindjia/miniforge3/envs/megatron/lib -L/home/jindjia/miniforge3/envs/megatron/lib -Wl,--allow-shlib-undefined -Wl,-rpath,/home/jindjia/miniforge3/envs/megatron/lib -Wl,-rpath-link,/home/jindjia/miniforge3/envs/megatron/lib -L/home/jindjia/miniforge3/envs/megatron/lib /home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda/build/temp.linux-x86_64-cpython-311/home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda/hadamard/fast_hadamard_transform_cuda.o /home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda/build/temp.linux-x86_64-cpython-311/home/jindjia/scripts/Me

Emitting ninja build file /home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda/build/temp.linux-x86_64-cpython-311/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


Successfully complied quantization module
build_lib_path /home/jindjia/scripts/Megatron-LM-jinda-final_speed_test/tools/jet_quant_cuda/build/lib.linux-x86_64-cpython-311


In [20]:
tensor1 = torch.randn((1024 * 8,), dtype=torch.bfloat16, device=dev)
tensor2 = torch.randn((1024 * 16,), dtype=torch.bfloat16, device=dev)
tensor3 = torch.randn((1024 * 32,), dtype=torch.bfloat16, device=dev)

param_list = [tensor1, tensor2, tensor3]
total_len = sum([tensor.numel() for tensor in param_list])
print(f"total len: {total_len}")

param_buffer = torch.empty(size=(total_len,), dtype=torch.bfloat16, device=dev)
param_buffer_list_view = []

offset = 0
for i in range(len(param_list)):
    start_idx = offset
    offset += param_list[i].numel()
    end_idx = offset
    param_buffer_list_view.append(param_buffer[start_idx:end_idx])

    param_buffer_list_view[-1].copy_(param_list[i]*1.1)

print(param_list[0])
print(param_buffer_list_view[0])
print(param_buffer)


total len: 57344
tensor([-0.0688, -0.9922,  0.5352,  ...,  0.5898,  0.1621,  0.6016],
       device='cuda:0', dtype=torch.bfloat16)
tensor([-0.0757, -1.0938,  0.5898,  ...,  0.6484,  0.1787,  0.6602],
       device='cuda:0', dtype=torch.bfloat16)
tensor([-0.0757, -1.0938,  0.5898,  ...,  0.7188,  1.6328,  1.3828],
       device='cuda:0', dtype=torch.bfloat16)


In [21]:
output_buffer = param_buffer.clone()
output_buffer_list_view = []

offset = 0
for i in range(len(param_list)):
    start_idx = offset
    offset += param_list[i].numel()
    end_idx = offset
    output_buffer_list_view.append(output_buffer[start_idx:end_idx])

dequant_tensor = unittest_deqaunt_cuda(param_list=param_list, tensor_buffer=output_buffer, tensor_buffer_listview=output_buffer_list_view, groupsize = 2048, quant_bits = 4, quant_module=quantization_module, hadamard=False )
print(dequant_tensor, torch.norm(dequant_tensor))

after sub tensor([-0.0068, -0.1016,  0.0547,  ...,  0.0664,  0.1484,  0.1250],
       device='cuda:0', dtype=torch.bfloat16)
cuda version quantization, absolute error norm: 5.09375, relative error norm: 0.212890625
tensor([ 0.0000, -0.0938,  0.0938,  ...,  0.0513,  0.1543,  0.1025],
       device='cuda:0', dtype=torch.bfloat16) tensor(24.5000, device='cuda:0', dtype=torch.bfloat16)


In [22]:
output_buffer = param_buffer.clone()
output_buffer_list_view = []

offset = 0
for i in range(len(param_list)):
    start_idx = offset
    offset += param_list[i].numel()
    end_idx = offset
    output_buffer_list_view.append(output_buffer[start_idx:end_idx])

dequant_tensor = unittest_deqaunt_cuda_fused(param_list=param_list, tensor_buffer=output_buffer, tensor_buffer_listview=output_buffer_list_view, groupsize = 2048, quant_bits = 4, quant_module=quantization_module, hadamard=False )
print(dequant_tensor, torch.norm(dequant_tensor))

cuda version quantization, absolute error norm: 2.3125, relative error norm: 0.1923828125
tensor([ 0.1270, -0.0425, -0.0425,  ...,  0.0469,  0.1406, -0.0938],
       device='cuda:0', dtype=torch.bfloat16) tensor(12.1875, device='cuda:0', dtype=torch.bfloat16)


In [23]:
import time

def functionA(func, *args, **kwargs):
    """
    Measures the GPU running time of the given function using CUDA events.
    
    Parameters:
    - func: The function to be measured.
    - *args: Positional arguments to pass to the function.
    - **kwargs: Keyword arguments to pass to the function.
    
    Returns:
    - result: The result of the function execution.
    - elapsed_time_ms: The time taken to execute the function on the GPU in milliseconds.
    """
    # Ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")
    
    active = 20

    # Create CUDA events for timing
    # start_event = torch.cuda.Event(enable_timing=True)
    # end_event = torch.cuda.Event(enable_timing=True)
    
    # warmup
    for _ in range(10):
        _ = func(*args, **kwargs)

    # Synchronize and empty the cache before starting
    torch.cuda.synchronize(device=dev)
    
    # Record the start event
    # start_event.record()
    begin = time.time()
    # Call the function with provided arguments
    for _ in range(active):
        result = func(*args, **kwargs)
    torch.cuda.synchronize()
    elapsed_time_s = (time.time()-begin)
    elapsed_time_s = elapsed_time_s / active
    elapsed_time_ms = elapsed_time_s * 1000
    # # Record the end event
    # end_event.record()
    
    # # Wait for the events to be recorded
    # torch.cuda.synchronize(device=dev)
    
    # Calculate the elapsed time
    # elapsed_time_ms = start_event.elapsed_time(end_event)
    # elapsed_time_ms = elapsed_time_ms / active

    return result, elapsed_time_ms

In [24]:
param_model_list = []
hidden_size = 2048
for i in range(24):
    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))

    param_model_list.append(torch.randn((hidden_size * 3 * hidden_size,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size * hidden_size,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))

    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))

    param_model_list.append(torch.randn((hidden_size * 4 * hidden_size,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size * hidden_size * 4,), dtype=torch.bfloat16, device=dev))
    param_model_list.append(torch.randn((hidden_size,), dtype=torch.bfloat16, device=dev))

# for i in range(1):
#     param_model_list.append(torch.randn((hidden_size * 4 * hidden_size,), dtype=torch.bfloat16, device=dev))
#     param_model_list.append(torch.randn((hidden_size * hidden_size,), dtype=torch.bfloat16, device=dev))
#     param_model_list.append(torch.randn((hidden_size * hidden_size * 4,), dtype=torch.bfloat16, device=dev))


In [25]:
# Speed Test for Stoquantize
def unfused_stoquantize(param_buffer_tensor, param_buffer_list_view, param_model_list, groups, quant_bits, quant_mode):
    for i in range(len(param_model_list)):
        param_buffer_list_view[i].sub_(param_model_list[i])
    # torch._foreach_sub_(param_buffer_list_view, param_model_list)
    quantization_module.stochastic_quantize(param_buffer_tensor, groups, quant_bits, quant_mode)

dp_size = 4
dp_rank = 0

total_len = sum([tensor.numel() for tensor in param_model_list]) // dp_size
print(f"total len: {total_len}")
print(f"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}")


param_buffer_tensor = torch.empty(size=(total_len * dp_size,), dtype=torch.bfloat16, device=dev)
tensor_buffer_dp_view = []
for i in range(dp_size):
    param_buffer_size = param_buffer_tensor.numel() // dp_size
    tensor_buffer_dp_view.append(param_buffer_tensor[i*param_buffer_size: (i+1)*param_buffer_size])

param_buffer_list_view = []
offset = 0
for i in range(len(param_model_list)):
    start_idx = offset
    offset += param_model_list[i].numel()
    end_idx = offset
    param_buffer_list_view.append(param_buffer_tensor[start_idx:end_idx])

    param_buffer_list_view[-1].copy_(param_model_list[i]*1.1)


N = total_len
quant_bits = 4
groupsize = 1024
groups = N // groupsize
_, avg_time = functionA(unfused_stoquantize, tensor_buffer_dp_view[dp_rank], param_buffer_list_view, param_model_list, groups, quant_bits, quantization_module.Symmetric)

num_bytes = param_buffer_tensor.numel() * param_buffer_tensor.element_size()
print('Sto Quantize')
print(f'time: {avg_time}ms')
print(f'numbytes: {num_bytes}Bytes')
print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')

total len: 453113856
tensor size: 864.24609375 MB, dtype: torch.bfloat16
Sto Quantize
time: 29.323363304138184ms
numbytes: 3624910848Bytes
throughput: 123.61852255496366GB/s


In [26]:
# Speed Test for Stoquantize

dp_size = 4
dp_rank = 0

total_len = sum([tensor.numel() for tensor in param_model_list]) // dp_size
print(f"total len: {total_len}")
print(f"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}")


param_buffer_tensor = torch.empty(size=(total_len * dp_size,), dtype=torch.bfloat16, device=dev)
tensor_buffer_dp_view = []
for i in range(dp_size):
    param_buffer_size = param_buffer_tensor.numel() // dp_size
    tensor_buffer_dp_view.append(param_buffer_tensor[i*param_buffer_size: (i+1)*param_buffer_size])

param_buffer_list_view = []
offset = 0
for i in range(len(param_model_list)):
    start_idx = offset
    offset += param_model_list[i].numel()
    end_idx = offset
    param_buffer_list_view.append(param_buffer_tensor[start_idx:end_idx])

    param_buffer_list_view[-1].copy_(param_model_list[i]*1.1)

N = param_buffer_tensor.nelement()
quant_bits = 4
groupsize = 1024
groups = N // groupsize
_, avg_time = functionA(quantization_module.sub_quantize, tensor_buffer_dp_view[dp_rank], param_model_list, 0, groups, quant_bits, quantization_module.Symmetric)

num_bytes = param_buffer_tensor.numel() * param_buffer_tensor.element_size()
print('Sto Quantize')
print(f'time: {avg_time}ms')
print(f'numbytes: {num_bytes}Bytes')
print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')

total len: 453113856
tensor size: 864.24609375 MB, dtype: torch.bfloat16
Sto Quantize
time: 17.266249656677246ms
numbytes: 3624910848Bytes
throughput: 209.94199204099692GB/s
