In [None]:
# Copyright 2023-2024 Bytedance Ltd. and/or its affiliates 


# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 

#     http://www.apache.org/licenses/LICENSE-2.0 

# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 

In [1]:
import torch
from tool_function import *
from dequant_function import *
import os
from pytorch_memlab import LineProfiler, profile
dev = 'cuda'

In [None]:
# Path string handed to the extension loader below; it is relative, so it is
# resolved against whatever directory the kernel is running in.
pkg_path = "../../tools/jet_quant_cuda"
print('pkg path:', pkg_path)
# `build_and_import_module` is not defined in this notebook; it presumably comes
# from one of the star imports above and builds/loads the extension located at
# `pkg_path` under the given module name -- TODO confirm against its definition.
# The returned object is what later cells pass around as `quant_module`.
quantization_module = build_and_import_module(pkg_path, 'quantization_cuda')

In [3]:
import time

def functionA(func, *args, **kwargs):
    """
    Measure the average host wall-clock time of a GPU workload.

    ``func`` is called 10 times as warm-up and then 20 more times inside a timed
    region. The timed region is bracketed by ``torch.cuda.synchronize`` so that
    asynchronously launched work is included, and the elapsed host time is
    divided by the number of timed calls. CUDA events are NOT used, so the
    figure includes host-side launch overhead.

    Parameters:
    - func: The callable to be measured.
    - *args: Positional arguments forwarded to ``func`` on every call.
    - **kwargs: Keyword arguments forwarded to ``func`` on every call.

    Returns:
    - result: The return value of the last timed call to ``func``.
    - elapsed_time_ms: Average time per timed call, in milliseconds.

    Raises:
    - RuntimeError: If CUDA is not available.
    """
    # Ensure CUDA is available
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available.")

    warmup = 10
    active = 20

    # Warm-up: results are discarded, only the side effects (caches, lazy init) matter.
    for _ in range(warmup):
        func(*args, **kwargs)

    # Drain the warm-up work so it cannot leak into the timed region.
    torch.cuda.synchronize(device=dev)

    # perf_counter is monotonic and high-resolution; time.time() can jump when
    # the system clock is adjusted.
    begin = time.perf_counter()
    for _ in range(active):
        result = func(*args, **kwargs)
    # Wait on the same device as above before reading the clock.
    torch.cuda.synchronize(device=dev)
    elapsed_time_s = time.perf_counter() - begin

    elapsed_time_ms = elapsed_time_s / active * 1000
    return result, elapsed_time_ms

In [4]:
def unittest_deqaunt_cuda_unfused(param_list, param_buffer, groupsize=128, quant_bits=4, quant_module=None):
    """
    Round-trip check of the unfused path: subtract, quantize, dequantize, add back.

    Consecutive slices of ``param_buffer`` (one per tensor in ``param_list``,
    starting at offset 0) have the matching tensor subtracted in place. The
    difference is passed to ``quant_module.stochastic_quantize`` and then to the
    dequantize kernel selected by the buffer dtype, and finally ``param_list``
    is added back in place. ``analysis_diff`` is called twice and its two
    return values are printed each time:
      * dequantized difference vs. the difference before quantization;
      * final buffer vs. the buffer as it was on entry.

    Parameters:
    - param_list: tensors subtracted from / added to the buffer slices.
    - param_buffer: flat tensor that is modified in place by the foreach ops
      (and presumably written by the dequantize kernel -- confirm in the extension).
    - groupsize: number of elements per quantization group.
    - quant_bits: bit width forwarded to the kernels.
    - quant_module: extension module providing the kernels and ``Symmetric``.

    Returns nothing; results are printed.
    """
    pristine_buffer = param_buffer.clone()
    buffer_dtype = param_buffer.dtype

    # One view per parameter, laid out back to back from the start of the buffer.
    views = []
    cursor = 0
    for param in param_list:
        upper = cursor + param.numel()
        views.append(param_buffer[cursor:upper])
        cursor = upper

    torch._foreach_sub_(views, param_list)
    pristine_delta = param_buffer.clone()
    print('after sub', param_buffer)

    # stochastic quantize kernel
    num_groups = param_buffer.nelement() // groupsize
    packed, scales = quant_module.stochastic_quantize(param_buffer, num_groups, quant_bits, quant_module.Symmetric)

    # Map the buffer dtype to the name of the matching dequantize kernel.
    kernel_name = {
        torch.bfloat16: "dequantize_bf16",
        torch.float32: "dequantize_fp32",
        torch.float16: "dequantize_half",
    }.get(buffer_dtype)
    if kernel_name is None:
        assert False, "dequant_type is not supported"
    else:
        getattr(quant_module, kernel_name)(packed, scales, param_buffer, num_groups, quant_bits, quant_module.Symmetric)

    delta_abs, delta_rel = analysis_diff(pristine_delta, param_buffer)
    print(f"unfused dequantization&add, weight_diff absolute error norm: {delta_abs}, relative error norm: {delta_rel}")

    torch._foreach_add_(views, param_list)

    full_abs, full_rel = analysis_diff(pristine_buffer, param_buffer)
    print(f"unfused dequantization&add, weight_diff/weight absolute error norm: {full_abs}, relative error norm: {full_rel}")

In [5]:
def unittest_deqaunt_cuda_fused(param_list, param_buffer, dp_param_offset, groupsize=128, quant_bits=4, quant_module=None):
    """
    Round-trip check of the fused path: ``sub_quantize`` then ``dequantize_add_bf16``.

    ``param_buffer`` is passed, together with ``param_list`` and
    ``dp_param_offset``, to ``quant_module.sub_quantize`` and the result to
    ``quant_module.dequantize_add_bf16``. Afterwards ``analysis_diff`` compares
    the buffer with a clone taken on entry and the two returned norms are printed.
    (Presumably the two kernels are the fused equivalents of subtract+quantize and
    dequantize+add -- inferred from their names; confirm in the extension.)

    Parameters:
    - param_list: tensors forwarded to both fused kernels.
    - param_buffer: flat bf16 tensor under test.
    - dp_param_offset: offset value forwarded to both fused kernels; the caller
      in this notebook passes the start index of the data-parallel shard.
    - groupsize: number of elements per quantization group.
    - quant_bits: bit width forwarded to the kernels.
    - quant_module: extension module providing the kernels and ``Symmetric``.

    Raises:
    - AssertionError: if ``param_buffer`` is not ``torch.bfloat16``; no other
      dtype has a fused dequantize+add call here.
    """
    tensor_type = param_buffer.dtype
    # Fail before cloning or launching any kernel: only bf16 is handled below.
    if tensor_type is not torch.bfloat16:
        raise AssertionError("dequant_type is not supported")

    original_param_buffer = param_buffer.clone()
    # No subtraction has happened yet on this path, so label the dump accordingly.
    print('before sub_quantize', param_buffer)

    # fused subtract + stochastic quantize kernel
    N = param_buffer.nelement()
    groups = N // groupsize
    quant_tensor_cuda, quant_scales_cuda = quant_module.sub_quantize(param_buffer, param_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)

    # fused dequantize + add kernel
    quant_module.dequantize_add_bf16(quant_tensor_cuda, quant_scales_cuda, param_buffer, param_list, dp_param_offset, groups, quant_bits, quant_module.Symmetric)

    abs_error_norm, rela_error_norm = analysis_diff(original_param_buffer, param_buffer)
    print(f"fused dequantization&add, weight_diff/weight absolute error norm: {abs_error_norm}, relative error norm: {rela_error_norm}")

In [None]:
# Three flat bf16 tensors of 8K, 16K and 32K random elements.
tensor1, tensor2, tensor3 = (
    torch.randn((1024 * factor,), dtype=torch.bfloat16, device=dev)
    for factor in (8, 16, 32)
)

param_list = [tensor1, tensor2, tensor3]
total_len = sum(tensor.numel() for tensor in param_list)
print(f"total len: {total_len}")

# Zero-filled flat buffer, 2048 elements longer than the parameters combined.
param_buffer = torch.zeros(size=(total_len + 2048,), dtype=torch.bfloat16, device=dev)
param_buffer_list_view = []

# Copy each tensor, scaled by 1.1, into consecutive slices of the buffer.
offset = 0
for i, param in enumerate(param_list):
    start_idx, end_idx = offset, offset + param.numel()
    param_buffer[start_idx:end_idx].copy_(param * 1.1)
    offset = end_idx

print(param_list[0])
print(param_buffer)


In [7]:
# Sharding values for the (disabled) call below: the buffer is divided into
# dp_size equal parts of N elements each and part number dp_rank is selected.
dp_size = 4
dp_rank = 2
N = param_buffer.numel() // dp_size
# Unfused reference check, currently disabled.
# NOTE(review): unittest_deqaunt_cuda_unfused slices its per-parameter views from
# offset 0 of the buffer it receives, so it presumably cannot line up with a
# single shard of the buffer -- likely why this call is commented out; confirm.
# unittest_deqaunt_cuda_unfused(param_list=param_list, param_buffer=param_buffer[N*dp_rank: N*(dp_rank+1)], groupsize=2048, quant_bits=4,quant_module=quantization_module)

In [None]:
# Split the buffer into dp_size equal shards of N elements and test shard dp_rank.
dp_size, dp_rank = 4, 2
N = param_buffer.numel() // dp_size

# The shard's start index is also passed as dp_param_offset.
unittest_deqaunt_cuda_fused(
    param_list=param_list,
    param_buffer=param_buffer[N * dp_rank : N * (dp_rank + 1)],
    dp_param_offset=N * dp_rank,
    groupsize=512,
    quant_bits=4,
    quant_module=quantization_module,
)

In [None]:
# Synthetic "model": 24 repetitions of the same eleven flat bf16 tensors.
hidden_size = 2048
# Element counts of the tensors created in one repetition, in creation order.
block_numels = (
    hidden_size,
    hidden_size,
    hidden_size * 3 * hidden_size,
    hidden_size * hidden_size,
    hidden_size,
    hidden_size,
    hidden_size,
    hidden_size * 4 * hidden_size,
    hidden_size,
    hidden_size * hidden_size * 4,
    hidden_size,
)
param_model_list = [
    torch.randn((numel,), dtype=torch.bfloat16, device=dev)
    for _ in range(24)
    for numel in block_numels
]

total_len = sum(tensor.numel() for tensor in param_model_list)
print(f"total len: {total_len}")
param_buffer = torch.zeros(size=(total_len,), dtype=torch.bfloat16, device=dev)
param_buffer_list_view = []

# Fill the flat buffer with 1.1 * parameter and keep one view per parameter.
offset = 0
for param in param_model_list:
    start_idx, end_idx = offset, offset + param.numel()
    param_buffer[start_idx:end_idx].copy_(param * 1.1)
    param_buffer_list_view.append(param_buffer[start_idx:end_idx])
    offset = end_idx

dp_size = 1
dp_rank = 0

# Equal-sized shards of the flat buffer, one per data-parallel rank.
param_buffer_size = param_buffer.numel() // dp_size
tensor_buffer_dp_view = [
    param_buffer[i * param_buffer_size : (i + 1) * param_buffer_size]
    for i in range(dp_size)
]

groupsize = 512
N = tensor_buffer_dp_view[dp_rank].nelement()
groups = N // groupsize
# Pass param_model_list -- the tensors this buffer was built from. The previous
# code passed `param_list`, a leftover global from the small test above whose
# sizes do not correspond to this buffer.
quant_tensor_cuda, quant_scales_cuda = quantization_module.sub_quantize(
    tensor_buffer_dp_view[dp_rank],
    param_model_list,
    dp_rank * tensor_buffer_dp_view[0].numel(),
    groups,
    4,
    quantization_module.Symmetric,
)


In [None]:
# Speed test workload: dequantize, then add the parameters in a separate step
def unfused_dequantize(param_buffer_list_view, param_model_list, quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module):
    """
    Run ``quant_module.dequantize_bf16`` followed by an in-place foreach add.

    The dequantize call receives the quantized tensor, its scales,
    ``param_buffer``, ``groups``, ``quant_bits`` and ``quant_module.Symmetric``.
    Afterwards every tensor in ``param_model_list`` is added in place to the
    tensor at the same position in ``param_buffer_list_view``. Returns nothing.
    """
    dequant_args = (
        quant_tensor_cuda,
        quant_scales_cuda,
        param_buffer,
        groups,
        quant_bits,
        quant_module.Symmetric,
    )
    quant_module.dequantize_bf16(*dequant_args)
    torch._foreach_add_(param_buffer_list_view, param_model_list)



# Number of elements held by one data-parallel rank.
total_len = sum(tensor.numel() for tensor in param_model_list) // dp_size
print(f"total len: {total_len}")
print(f"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}")


N = total_len
quant_bits = 4
# Must equal the group size used when quant_tensor_cuda / quant_scales_cuda were
# produced (512 in the quantization cell); the previous value of 1024 yielded a
# `groups` count that did not match the quantized data.
groupsize = 512
groups = N // groupsize
_, avg_time = functionA(unfused_dequantize, param_buffer_list_view, param_model_list, quant_tensor_cuda, quant_scales_cuda, tensor_buffer_dp_view[dp_rank], groups, quant_bits, quantization_module)

num_bytes = tensor_buffer_dp_view[dp_rank].numel() * tensor_buffer_dp_view[dp_rank].element_size()
print('unfused Dequantization')
print(f'time: {avg_time}ms')
print(f'numbytes: {num_bytes}Bytes')
# bytes per millisecond divided by 1e6 equals GB/s (1 GB = 1e9 bytes).
print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')

In [None]:
# Speed test workload: single fused dequantize + add call
def fused_dequantize(dp_param_offset, param_model_list, quant_tensor_cuda, quant_scales_cuda, param_buffer, groups, quant_bits, quant_module):
    """
    Invoke ``quant_module.dequantize_add_bf16`` once.

    Arguments are forwarded in the order the kernel is called with here:
    quantized tensor, scales, ``param_buffer``, ``param_model_list``,
    ``dp_param_offset``, ``groups``, ``quant_bits`` and
    ``quant_module.Symmetric``. Returns nothing.
    """
    fused_args = (
        quant_tensor_cuda,
        quant_scales_cuda,
        param_buffer,
        param_model_list,
        dp_param_offset,
        groups,
        quant_bits,
        quant_module.Symmetric,
    )
    quant_module.dequantize_add_bf16(*fused_args)



# Number of elements held by one data-parallel rank.
total_len = sum(tensor.numel() for tensor in param_model_list) // dp_size
print(f"total len: {total_len}")
print(f"tensor size: {total_len * param_model_list[0].element_size() / 1024 / 1024} MB, dtype: {param_model_list[0].dtype}")


N = total_len
quant_bits = 4
# Must equal the group size used when quant_tensor_cuda / quant_scales_cuda were
# produced (512 in the quantization cell); the previous value of 1024 yielded a
# `groups` count that did not match the quantized data.
groupsize = 512
groups = N // groupsize
_, avg_time = functionA(fused_dequantize, dp_rank * tensor_buffer_dp_view[0].numel(), param_model_list, quant_tensor_cuda, quant_scales_cuda, tensor_buffer_dp_view[dp_rank], groups, quant_bits, quantization_module)

num_bytes = tensor_buffer_dp_view[dp_rank].numel() * tensor_buffer_dp_view[dp_rank].element_size()
# This cell times the fused kernel; the label previously said "unfused".
print('fused Dequantization')
print(f'time: {avg_time}ms')
print(f'numbytes: {num_bytes}Bytes')
# bytes per millisecond divided by 1e6 equals GB/s (1 GB = 1e9 bytes).
print(f'throughput: {num_bytes / avg_time / 10**6}GB/s')