In [115]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import struct

In [116]:

def generate_vectors():
    # 均匀分布
    a_uniform_distribution = np.random.uniform(low=-1000000, high=1000000, size=(1, 32)) #9.2E-41~3.38E38
    b_uniform_distribution = np.random.uniform(low=-1000000, high=1000000, size=(32, 1))

    return a_uniform_distribution, b_uniform_distribution

In [117]:
#bf16标准乘法
def standard_bf16_multiply(a, b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_bf16 = a_tensor.to(torch.bfloat16)
    b_bf16 = b_tensor.to(torch.bfloat16)
    result = torch.mm(a_bf16, b_bf16)
    result_bf16 = result.to(torch.bfloat16)
    return result_bf16

In [118]:
def twos_complement(binary_str):
    if binary_str[0] == '1':  # if the binary number is negative
        return -1 * (int(''.join('1' if b == '0' else '0' for b in binary_str), 2) + 1)
    else:
        return int(binary_str, 2)
    
#向量中元素转化为二进制操作
def Modify_tensor(tensor):
    bf_list = tensor.detach().numpy().tolist() # 转成numpy变量再转换成list    
    binary_array = []

    for bf_nums in bf_list:
        for bf_num in bf_nums:
            float_bytes = struct.pack('>f', bf_num)
            byte_array = ''.join(f'{byte:08b}' for byte in float_bytes)
            binary_array.append(byte_array)
    return binary_array

#将有符号二进制数转换为整数
def binary_to_signed_int(binary, bit_length):
    if binary >= 2**(bit_length - 1):
        return binary - 2**bit_length
    else:
        return binary


In [119]:
#ReDCIM
#运算
def ReDCIM_bf16_multiply(a,b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_fp32 = a_tensor.float()
    b_fp32 = b_tensor.float()
    a_modified = Modify_tensor(a_fp32)
    b_modified = Modify_tensor(b_fp32)

    ##预对齐
    a_exp = []
    for number in a_modified:
        extracted_part = number[1:9]
        a_exp.append(extracted_part)
    a_exp_int = [int(binary, 2) for binary in a_exp]
    a_exp_max = max(a_exp_int)
    a_difference = [a_exp_max - value for value in a_exp_int]

    a_mantissa = []
    for number in a_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:15]) #取反码
            complement = format(int(inverted_part, 2) + 1, '07b') #取补码
            extracted_part = number[0] + complement #加符号位
        else:
            extracted_part = number[0] + '1' + number[9:15]
        a_mantissa.append(extracted_part)
    
    b_exp = []
    for number in b_modified:
        extracted_part = number[1:9]
        b_exp.append(extracted_part)
    b_exp_int = [int(binary, 2) for binary in b_exp]
    b_exp_max = max(b_exp_int)
    b_difference = [b_exp_max - value for value in b_exp_int]

    b_mantissa = []
    for number in b_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:15])
            complement = format(int(inverted_part, 2) + 1, '07b') 
            extracted_part = number[0] + complement
        else:
            extracted_part = number[0] + '1' + number[9:15]
        b_mantissa.append(extracted_part)
    
    a_shifted_mantissa_values = []
    for mant, diff in zip(a_mantissa, a_difference):
        sign = int(mant[0], 2)
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant[:-diff]
        else:
            shifted_value_binary = mant
        a_shifted_mantissa_values.append(shifted_value_binary)

    b_shifted_mantissa_values = []
    for mant, diff in zip(b_mantissa, b_difference):
        sign = int(mant[0], 2)
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant[:-diff]
        else:
            shifted_value_binary = mant
        b_shifted_mantissa_values.append(shifted_value_binary)

    ## 尾数相乘
    product_mantissa = [twos_complement(a) * twos_complement(b) for a, b in zip(a_shifted_mantissa_values, b_shifted_mantissa_values)]
    sum_product_mantissa = sum(product_mantissa)
    sum_product_mantissa_binary = format(sum_product_mantissa,'017b')#高位的0被省略，但是浮点数从高位向下取，因此必须补齐
    sign = 1 if sum_product_mantissa < 0 else 0
    if sign == 1:
        inverted_value = ~int(sum_product_mantissa_binary, 2) + 1
        inverted_value_bits = format(inverted_value, '017b')
        final_result = inverted_value_bits[1:]
    else:
        final_result = sum_product_mantissa_binary[1:]

    mantissa_val = 0
    for i, bit in enumerate(final_result):
        mantissa_val += int(bit) * (2 ** (-i + 3)) ##正常乘，小数点在MSB-2，带符号乘，小数点在MSB-4

    combined_result = ((-1) ** int(sign)) * (2 ** (a_exp_max + b_exp_max - 254)) * mantissa_val
    return combined_result

In [120]:
#Hybrid_domain
k = 2
#运算
def Hybrid_bf16_multiply(a,b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_fp32 = a_tensor.float()
    b_fp32 = b_tensor.float()
    a_modified = Modify_tensor(a_fp32)
    b_modified = Modify_tensor(b_fp32)

    ##预对齐
    a_exp = []
    for number in a_modified:
        extracted_part = number[1:9]
        a_exp.append(extracted_part)
    a_exp_int = [int(binary, 2) for binary in a_exp]
    a_exp_max = max(a_exp_int)
    a_difference = [a_exp_max - value for value in a_exp_int]

    a_mantissa = []
    for number in a_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:16])
            complement = format(int(inverted_part, 2) + 1, '08b') 
            extracted_part = number[0] + complement
        else:
            extracted_part = number[0] + '1' + number[9:16]
        a_mantissa.append(extracted_part)
    
    b_exp = []
    for number in b_modified:
        extracted_part = number[1:9]
        b_exp.append(extracted_part)
    b_exp_int = [int(binary, 2) for binary in b_exp]
    b_exp_max = max(b_exp_int)
    b_difference = [b_exp_max - value for value in b_exp_int]

    b_mantissa = []
    for number in b_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:16])
            complement = format(int(inverted_part, 2) + 1, '08b') 
            extracted_part = number[0] + complement
        else:
            extracted_part = number[0] + '1' + number[9:16]
        b_mantissa.append(extracted_part)
    
    a_shifted_mantissa_values = []
    for mant, diff in zip(a_mantissa, a_difference):
        sign = int(mant[0], 2)
        mant_padded = mant + '0' * k
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant_padded[:-diff]
        else:
            shifted_value_binary = mant_padded
        a_shifted_mantissa_values.append(shifted_value_binary)

    b_shifted_mantissa_values = []
    for mant, diff in zip(b_mantissa, b_difference):
        sign = int(mant[0], 2)
        mant_padded = mant + '0' * k
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant_padded[:-diff]
        else:
            shifted_value_binary = mant_padded
        b_shifted_mantissa_values.append(shifted_value_binary)


    ## 尾数相乘
    product_mantissa = [twos_complement(a) * twos_complement(b) for a, b in zip(a_shifted_mantissa_values, b_shifted_mantissa_values)]
    formatted_product_mantissa_values = [format(value, '022b') for value in product_mantissa]  #高位的0被省略，但是浮点数从高位向下取，因此必须补齐

    sum_product_mantissa = sum(product_mantissa)
    sum_product_mantissa_binary = format(sum_product_mantissa,'023b')
    sign = 1 if sum_product_mantissa < 0 else 0
    if sign == 1:
        inverted_value = ~int(sum_product_mantissa_binary, 2) + 1
        inverted_value_bits = format(inverted_value, '023b')
        final_result = inverted_value_bits[1:]
    else:
        final_result = sum_product_mantissa_binary[1:]

    mantissa_val = 0
    for i, bit in enumerate(final_result):
        mantissa_val += int(bit) * (2 ** (-i + 3)) ##正常乘，小数点在MSB-2，带符号乘，小数点在MSB-4

    combined_result = ((-1) ** int(sign)) * (2 ** (a_exp_max + b_exp_max - 254)) * mantissa_val
    return combined_result

In [121]:
def compare_and_calculate_y(results_standard, results_ReDCIM):
    differences = []
    for result_standard, ReDCIM in zip(results_standard, results_ReDCIM):
        difference = abs(torch.abs(result_standard - ReDCIM)/result_standard)
        differences.append(difference.item())

    mean_difference = np.mean(differences)
    return mean_difference

def run_calculations_with_actual_implementation(n):
    results_standard = []
    results_ReDCIM = []
    results_Hybrid = []
    
    for _ in range(500):
        a, b = generate_vectors()
        result_standard = standard_bf16_multiply(a, b)
        result_ReDCIM = ReDCIM_bf16_multiply(a, b)
        result_Hybrid = Hybrid_bf16_multiply(a,b)

        results_standard.append(result_standard)
        results_ReDCIM.append(result_ReDCIM)
        results_Hybrid.append(result_Hybrid)

    ReDCIM_main_difference = compare_and_calculate_y(results_standard, results_ReDCIM)
    Hybrid_main_difference = compare_and_calculate_y(results_standard, results_Hybrid)
    return [ReDCIM_main_difference,Hybrid_main_difference]

ReDCIM_main_difference = run_calculations_with_actual_implementation(5000)[0]
Hybrid_main_difference = run_calculations_with_actual_implementation(5000)[1]

print("ReDCIM 平均误差 (normalized) " + str(ReDCIM_main_difference))
print("Hybrid 平均误差 (normalized) " + str(Hybrid_main_difference))

ReDCIM 平均误差 (normalized) 0.05588189697265625
Hybrid 平均误差 (normalized) 0.0396053466796875
