In [37]:
import torch
import numpy as np
import struct

In [38]:
# generate a,b
def generate_vectors(n):
    """Draw one random operand pair for a length-``n`` dot product.

    Args:
        n: number of elements in each vector.

    Returns:
        A tuple ``(a, b)`` where ``a`` has shape ``(1, n)`` with values drawn
        uniformly between -1000 and 1000, and ``b`` has shape ``(n, 1)`` with
        every entry picked from ``[0, 1]``.
    """
    # `a` is drawn before `b`; keep this order so runs under a fixed global
    # NumPy seed stay reproducible.
    row_values = np.random.uniform(-1000, 1000, (1, n))
    column_bits = np.random.choice([0, 1], size=(n, 1))
    return row_values, column_bits

In [39]:
# Baseline: reference dot product with `a` quantized to float16
def multiply_vectors(a, b):
    """Reference dot product: quantize ``a`` to half precision, then ``a . b``.

    Note that ``np.float16`` is IEEE half precision, not bfloat16.

    Args:
        a: NumPy array; cast to ``np.float16`` before the product.
        b: second operand handed to ``np.dot`` unchanged.

    Returns:
        ``np.dot`` of the float16 copy of ``a`` with ``b``.
    """
    a_half = a.astype(np.float16)
    return np.dot(a_half, b)

In [40]:
# Convert each element of the tensor to its binary (bit-string) representation
def Modify_tensor(tensor):
    """Return the float32 bit pattern of every element of a 2-D tensor.

    Args:
        tensor: a torch tensor whose ``.tolist()`` form is a list of rows
            (i.e. two-dimensional).

    Returns:
        A flat list, in row-major order, of 32-character strings of '0'/'1'.
        Each string is the big-endian single-precision encoding of one
        element: character 0 is the sign, 1-8 the exponent, 9-31 the fraction.
    """
    # Go through NumPy to obtain plain Python floats, row by row.
    rows = tensor.detach().numpy().tolist()
    return [
        # Pack as big-endian float32, then render the 4 bytes as 32 bits.
        format(int.from_bytes(struct.pack('>f', value), 'big'), '032b')
        for row in rows
        for value in row
    ]

In [41]:
#Compensation
def Compensation_multiply(a, b, k):
    """Approximate ``dot(a, b)`` using exponent-aligned integer accumulation.

    Every element of ``a`` is cast to float32 and split into its sign, its
    8-bit biased exponent and the top six fraction bits. The implicit leading
    one is restored and ``k`` zero guard bits are appended. The signed
    significands are right-shifted so that they all share the largest
    exponent found in ``a``, multiplied element-wise with ``b``, summed, and
    finally rescaled by a power of two.

    Args:
        a: NumPy array accepted by ``torch.from_numpy``; as called in this
            file it has shape ``(1, n)``.
        b: NumPy array accepted by ``torch.from_numpy``; as called in this
            file it has shape ``(n, 1)`` and holds 0/1 integers.
        k: number of zero guard bits appended below each significand before
            alignment (``k = 0`` means no compensation).

    Returns:
        The rescaled accumulated sum. As called in this file every product is
        a one-element integer tensor, so the result is one as well.
    """
    activations = torch.from_numpy(a)
    weights = torch.from_numpy(b)
    bit_strings = Modify_tensor(activations.float())

    # Pre-alignment: find the largest biased exponent (bits 1..8).
    exponents = [int(bits[1:9], 2) for bits in bit_strings]
    max_exponent = max(exponents)

    guard_bits = '0' * k
    aligned_significands = []
    for bits, exponent in zip(bit_strings, exponents):
        # Implicit leading one + six fraction bits + k guard bits.
        # NOTE: the leading one is prepended even when the exponent field is
        # zero (zero / subnormal inputs).
        magnitude = int('1' + bits[9:15] + guard_bits, 2)
        signed_significand = magnitude if bits[0] == '0' else -magnitude
        # Python's >> floors, so negative significands round toward -inf.
        aligned_significands.append(signed_significand >> (max_exponent - exponent))

    accumulated = sum(
        significand * weight
        for significand, weight in zip(aligned_significands, weights)
    )

    # Undo the exponent bias (127) and the 6 + k fractional significand bits.
    shift = max_exponent - 127 - 6 - k
    if shift > 0:
        return accumulated << shift
    return accumulated >> -shift

In [42]:
#compare
def compare(results_a, results_b):
    """Mean relative error of ``results_b`` measured against ``results_a``.

    Args:
        results_a: iterable of reference values; each one is passed to
            ``torch.from_numpy`` and is also used as the denominator.
        results_b: iterable of approximations, paired positionally with
            ``results_a``; each must be subtractable from a tensor.

    Returns:
        ``np.mean`` of ``|reference - approximation| / |reference|`` over all
        pairs. A zero reference is not guarded against.
    """
    relative_errors = []
    for reference, approximation in zip(results_a, results_b):
        absolute_error = torch.abs(torch.from_numpy(reference) - approximation)
        # The outer abs() keeps the ratio non-negative for negative references.
        relative_error = abs(absolute_error / reference)
        relative_errors.append(relative_error.item())

    return np.mean(relative_errors)

def run_calculations_with_actual_implementation(n):
    """Run ``n`` random trials and report the mean relative error per setting.

    Each trial draws a fresh pair of length-512 vectors, computes the
    reference product with ``multiply_vectors`` and the approximate product
    with ``Compensation_multiply`` for 0, 1, 2 and 4 guard bits.

    Args:
        n: number of random trials.

    Returns:
        A 4-tuple of mean relative errors (as produced by ``compare``) for
        guard-bit counts 0 (the uncompensated "ReDCIM" case), 1, 2 and 4,
        in that order.
    """
    guard_bit_settings = (0, 1, 2, 4)
    reference_results = []
    approximations = {bits: [] for bits in guard_bit_settings}

    for _ in range(n):
        a, b = generate_vectors(512)
        reference_results.append(multiply_vectors(a, b))
        # All settings are evaluated on the same operand pair.
        for bits in guard_bit_settings:
            approximations[bits].append(Compensation_multiply(a, b, bits))

    return tuple(
        compare(reference_results, approximations[bits])
        for bits in guard_bit_settings
    )


# Run 1000 random trials; Differences holds the mean relative error for
# 0, 1, 2 and 4 guard bits, in that order.
Differences = run_calculations_with_actual_implementation(1000)

# Labels read: "ReDCIM mean error", "1/2/4 extra bit(s) mean error".
for label, mean_error in zip(
    ("ReDCIM 平均误差", "补充1位 平均误差", "补充2位 平均误差", "补充4位 平均误差"),
    Differences,
):
    print(label + str(mean_error))


ReDCIM 平均误差0.247671618961944
补充1位 平均误差0.06222280636713665
补充2位 平均误差0.01768880693776994
补充4位 平均误差0.014205017499889437
