In [115]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import struct

In [116]:

# Generate test operands: a is a (1, n) float row vector, b an (n, 1) 0/1 column vector
def generate_vectors(n, low=-10000, high=10000, rng=None):
    """Draw a random row vector ``a`` and a random binary column vector ``b``.

    Parameters
    ----------
    n : int
        Vector length.
    low, high : float, default -10000, 10000
        Bounds of the uniform distribution used for ``a`` (half-open
        interval [low, high)).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source. When omitted, the global ``np.random`` functions are
        called exactly as before, so ``np.random.seed`` still controls the draw.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``a`` of shape (1, n) with floats in [low, high), and ``b`` of shape
        (n, 1) whose entries are drawn uniformly from {0, 1}.
    """
    # Both Generator and the legacy np.random module expose uniform/choice
    # with the same keyword arguments.
    source = np.random if rng is None else rng
    a_uniform_distribution = source.uniform(low=low, high=high, size=(1, n))
    b_uniform_distribution = source.choice([0, 1], size=(n, 1))

    return a_uniform_distribution, b_uniform_distribution

In [117]:
#base: reference dot product with reduced-precision inputs
def multiply_vectors(a, b, dtype=np.float16):
    """Reference dot product of ``a`` and ``b`` after quantizing ``a`` to ``dtype``.

    Only ``a`` is cast; ``b`` is passed to ``np.dot`` unchanged, so NumPy's type
    promotion decides the dtype the products are accumulated in.

    NumPy has no bfloat16 dtype. The default float16 keeps 10 fraction bits and
    a 5-bit exponent (largest finite value 65504), whereas bfloat16 keeps 7
    fraction bits and an 8-bit exponent. This baseline is therefore more precise
    than a true bfloat16 one.

    Parameters
    ----------
    a : np.ndarray
        Left operand (the notebook passes shape (1, n)).
    b : np.ndarray
        Right operand (the notebook passes shape (n, 1)).
    dtype : numpy dtype, default np.float16
        Precision ``a`` is rounded to before the dot product.

    Returns
    -------
    np.ndarray
        ``np.dot(a.astype(dtype), b)``.
    """
    a_quantized = a.astype(dtype)
    return np.dot(a_quantized, b)

In [118]:
# Convert every element of a tensor to its 32-bit IEEE 754 binary string
def Modify_tensor(tensor):
    """Return the big-endian float32 bit pattern of every element of ``tensor``.

    The tensor is detached, moved to the CPU and flattened in row-major order.
    Tensors of any rank are therefore accepted; the previous nested loop only
    handled 2-D input. Each element is packed with ``struct.pack('>f', ...)``,
    so values that are not already float32 are rounded to float32 by the packing.

    Parameters
    ----------
    tensor : torch.Tensor
        Numeric tensor to encode.

    Returns
    -------
    list[str]
        One 32-character '0'/'1' string per element: bit 0 is the sign,
        bits 1-8 the exponent field, and bits 9-31 the fraction.

    Raises
    ------
    OverflowError
        From ``struct.pack`` if a value is too large for float32.
    """
    flat_values = tensor.detach().cpu().reshape(-1).tolist()
    return [
        ''.join(f'{byte:08b}' for byte in struct.pack('>f', value))
        for value in flat_values
    ]

In [119]:
#ReDCIM
def ReDCIM_multiply(a, b, mantissa_bits=6):
    """Emulate a pre-aligned, integer-mantissa dot product of ``a`` and ``b``.

    Every element of ``a`` is converted to float32 and reduced to:
    - its sign,
    - its 8-bit exponent field,
    - the top ``mantissa_bits`` fraction bits (truncated, not rounded).

    All significands are then right-shifted to the largest exponent
    (pre-alignment), multiplied by the matching entries of ``b``, and summed as
    integers. Finally the sum is scaled by ``2 ** (max_exponent - 127 - mantissa_bits)``
    using a bit shift. Inf/NaN inputs are not special-cased.

    Parameters
    ----------
    a : np.ndarray
        Floating-point operand; its elements are taken in row-major order
        (the notebook passes shape (1, n)).
    b : np.ndarray
        Operand whose first dimension has one row per element of ``a``
        (the notebook passes shape (n, 1) holding 0/1 integers).
    mantissa_bits : int, default 6
        Number of float32 fraction bits kept per element. The default 6
        reproduces the original behaviour; 7 would keep every fraction bit
        of a bfloat16 value.

    Returns
    -------
    torch.Tensor
        The scaled accumulator. When the final shift is to the right, the
        fractional part is discarded.

    Raises
    ------
    ValueError
        If ``mantissa_bits`` is outside [0, 23], ``a`` is empty, or the
        number of elements in ``a`` differs from the number of rows of ``b``.
    """
    if not 0 <= mantissa_bits <= 23:
        raise ValueError(f"mantissa_bits must be in [0, 23], got {mantissa_bits}")

    b_tensor = torch.from_numpy(b)
    a_bits = Modify_tensor(torch.from_numpy(a).float())  # float32 bit strings of a

    if not a_bits:
        raise ValueError("ReDCIM_multiply: 'a' must contain at least one element")
    if len(a_bits) != b_tensor.shape[0]:
        # zip() would otherwise silently drop the unmatched tail.
        raise ValueError(
            f"ReDCIM_multiply: 'a' has {len(a_bits)} elements but 'b' has "
            f"{b_tensor.shape[0]} rows"
        )

    # Pre-alignment. binary32 layout: bit 0 sign, bits 1-8 exponent field,
    # bits 9-31 fraction.
    raw_exponents = [int(bits[1:9], 2) for bits in a_bits]
    # Exponent field 0 (zero / subnormal) shares the scale of field 1.
    exponents = [max(e, 1) for e in raw_exponents]
    exp_max = max(exponents)

    aligned = []
    for bits, raw_exp, exp in zip(a_bits, raw_exponents, exponents):
        # Only normal numbers carry the hidden leading 1. Without this check,
        # -0.0 would decode as -2**mantissa_bits and survive alignment as -1.
        hidden_bit = '1' if raw_exp != 0 else '0'
        magnitude = int(hidden_bit + bits[9:9 + mantissa_bits], 2)
        significand = -magnitude if bits[0] == '1' else magnitude
        # Python's >> floors, so negative significands round toward -inf
        # (like an arithmetic shift on two's-complement hardware).
        aligned.append(significand >> (exp_max - exp))

    # Integer multiply-accumulate: each row of b_tensor pairs with one element of a.
    acc = sum(sig * weight for sig, weight in zip(aligned, b_tensor))

    # acc is expressed in units of 2**(exp_max - 127 - mantissa_bits).
    left_shift = exp_max - 127 - mantissa_bits
    if left_shift > 0:
        return acc << left_shift
    return acc >> -left_shift

In [120]:
#vectors: compare the reference dot product with the ReDCIM emulation
SEED = 0  # fixed seed so Restart & Run All reproduces the printed vectors
np.random.seed(SEED)

n = 4
a, b = generate_vectors(n)
print(a)
print(b)

result_standard = multiply_vectors(a, b)
print("Result standard:\n", result_standard)

result_ReDCIM = ReDCIM_multiply(a, b)
print("Result ReDCIM:\n", result_ReDCIM)

# (1, n) @ (n, 1) gives a single value in each result; compare them directly.
abs_error = abs(result_standard.item() - float(result_ReDCIM))
print("Absolute error |standard - ReDCIM|:", abs_error)

[[-4689.95492684 -9889.05418537 -4083.93124047  6658.90451672
  -3929.67320532  7902.67644098 -5266.86168393  8209.07133682
   6599.08127189 -4733.07688859 -1057.48868158 -2748.69407037
   5489.41922024 -7015.36087803  1353.87839026 -9619.29677291
  -8836.26431443  6415.91409874 -7127.89336813  4607.15427072
  -3162.19542479  4678.33989535 -5851.45511059 -7302.16096847
   5723.14336249 -8323.53058082  2363.52764916  8192.31991503
   4983.82610263 -8100.01095352 -6979.23000485  -515.10716037
    310.35246923 -2347.28730532 -5300.51881954  2777.36901513
  -9457.81373317   918.87514332 -1416.7715998    725.61554023
     92.9792336   1608.93978755  -722.54469102  3886.48023585
  -7790.31338191  8501.93074663  3033.30532979  1706.0013306
   8371.61035483   785.5651568   2740.53137738  4020.73664715
  -6107.58262404 -5310.49609163   -37.06539123 -6190.21046404
   7306.92933121  4946.39503953  2040.3868678  -7398.59832246
   3375.7878728    274.78348867 -5869.87557091  4552.75697409
   5934.2