In [264]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import struct

In [265]:
#向量大小
n = 2

# 均匀分布
a_uniform_distribution = np.random.uniform(low=-100, high=100, size=(1,n)) #9.2E-41~3.38E38
b_uniform_distribution = np.random.uniform(low=-100, high=100, size=(n,1))
print(a_uniform_distribution)
print(b_uniform_distribution)

[[ 77.00578965 -53.42176389]]
[[46.03220474]
 [84.72803782]]


In [266]:
#bf16标准乘法
def standard_bf16_multiply(a, b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_bf16 = a_tensor.to(torch.bfloat16)
    b_bf16 = b_tensor.to(torch.bfloat16)
    result = torch.mm(a_bf16, b_bf16)
    result_bf16 = result.to(torch.bfloat16)
    return result_bf16

In [267]:
#ReDCIM

#向量中元素转化为二进制操作
def Modify_tensor(tensor):
    bf_list = tensor.detach().numpy().tolist() # 转成numpy变量再转换成list    
    binary_array = []

    for bf_nums in bf_list:
        for bf_num in bf_nums:
            float_bytes = struct.pack('>f', bf_num)
            byte_array = ''.join(f'{byte:08b}' for byte in float_bytes)
            binary_array.append(byte_array)
    return binary_array

#将有符号二进制数转换为整数
def binary_to_signed_int(binary, bit_length):
    if binary >= 2**(bit_length - 1):
        return binary - 2**bit_length
    else:
        return binary

#运算
def ReDCIM_bf16_multiply(a,b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_fp32 = a_tensor.float()
    b_fp32 = b_tensor.float()
    a_modified = Modify_tensor(a_fp32)
    b_modified = Modify_tensor(b_fp32)
    print("a二进制表示: " + str(a_modified))
    print("b二进制表示: " + str(b_modified))

    ##预对齐
    a_exp = []
    for number in a_modified:
        extracted_part = number[1:9]
        a_exp.append(extracted_part)
    a_exp_int = [int(binary, 2) for binary in a_exp]
    a_exp_max = max(a_exp_int)
    a_difference = [a_exp_max - value for value in a_exp_int]
    print("a差值：" + str(a_difference))
    print("a的指数: " + str(a_exp))

    a_mantissa = []
    for number in a_modified:
        extracted_part = number[0] + '1' + number [9:15] #输入根据booth编码，舍去尾数最后一位
        a_mantissa.append(extracted_part)
    print("a的尾数 Mantissa+: " + str(a_mantissa))
    
    b_exp = []
    for number in b_modified:
        extracted_part = number[1:9]
        b_exp.append(extracted_part)
    b_exp_int = [int(binary, 2) for binary in b_exp]
    b_exp_max = max(b_exp_int)
    b_difference = [b_exp_max - value for value in b_exp_int]
    print("b差值：" + str(b_difference))
    print("b的指数: " + str(b_exp))

    b_mantissa = []
    for number in b_modified:
        extracted_part = number[0] + '1' + number [9:15] # 权重舍去尾数最后一位，和INT8 匹配
        b_mantissa.append(extracted_part)
    print("b的尾数 Mantissa+: " + str(b_mantissa))
    
    a_shifted_mantissa_values = []
    a_sign_bits = []
    for mant, diff in zip(a_mantissa, a_difference):
        sign = int(mant[0], 2)  # 提取符号位
        a_sign_bits.append(sign)
        value = int(mant[1:], 2) >> diff  #只处理数值部分
        a_shifted_mantissa_values.append(value)
    print("a的尾数 移位后: " + str(a_shifted_mantissa_values))

    b_shifted_mantissa_values = []
    b_sign_bits = []
    for mant, diff in zip(b_mantissa, b_difference):
        sign = int(mant[0], 2)
        b_sign_bits.append(sign)
        value = int(mant[1:], 2) >> diff
        b_shifted_mantissa_values.append(value)
    print("b的尾数 移位后: " + str(b_shifted_mantissa_values))

    ## 尾数相乘
    product_mantissa_values = [a * b for a, b in zip(a_shifted_mantissa_values, b_shifted_mantissa_values)]
    formatted_product_mantissa_values = [format(value, '014b') for value in product_mantissa_values]  #高位的0被省略，但是浮点数从高位向下取，因此必须补齐

    print("formatted_product_mantissa: " + ', '.join(formatted_product_mantissa_values))

    product_mantissa = []
    for i in range(len(formatted_product_mantissa_values)):
        sign = a_sign_bits[i] ^ b_sign_bits[i]
        # 取补码
        if sign == 1:
            inverted_product_mantissa = ''.join('1' if bit == '0' else '0' for bit in formatted_product_mantissa_values[i])
            product = format(int(inverted_product_mantissa, 2) + 1, '014b')
            product = '1' + product  # 补符号位
        else:
            product = '0' + formatted_product_mantissa_values[i] # 补符号位
        product_mantissa.append(product)

    sum_product_mantissa = sum(int(product, 2) for product in product_mantissa)
    sign_bit = (sum_product_mantissa >> (len(product_mantissa[0]) - 1)) & 1
    value_bits = sum_product_mantissa & ((1 << (len(product_mantissa[0]) - 1)) - 1)

    #转换回原码
    if sign_bit == 1:
        # 取补码（取反加一）
        inverted_value_bits = (~value_bits + 1) & ((1 << (len(product_mantissa[0]) - 1)) - 1)
        final_result = inverted_value_bits
    else:
        final_result = value_bits

    final_result_binary = format(final_result, '0{}b'.format(len(product_mantissa[0]) - 1))
    print("新符号位：" + str(sign_bit))
    print("新尾数：" + str(final_result))

    mantissa_val = 0
    for i, bit in enumerate(final_result_binary[:9]):
        mantissa_val += int(bit) * (2 ** (-i + 1))

    print(mantissa_val)
    combined_result = ((-1) ** sign_bit) * (2 ** (a_exp_max + b_exp_max - 254)) * mantissa_val
    return combined_result

In [268]:
#Hybrid_domain
def twos_complement(binary_str):
    """Convert binary string in two's complement to integer."""
    if binary_str[0] == '1':  # if the binary number is negative
        return -1 * (int(''.join('1' if b == '0' else '0' for b in binary_str), 2) + 1)
    else:
        return int(binary_str, 2)

k = 2
#运算
def Hybrid_bf16_multiply(a,b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_fp32 = a_tensor.float()
    b_fp32 = b_tensor.float()
    a_modified = Modify_tensor(a_fp32)
    b_modified = Modify_tensor(b_fp32)

    ##预对齐
    a_exp = []
    for number in a_modified:
        extracted_part = number[1:9]
        a_exp.append(extracted_part)
    a_exp_int = [int(binary, 2) for binary in a_exp]
    a_exp_max = max(a_exp_int)
    a_difference = [a_exp_max - value for value in a_exp_int]

    a_mantissa = []
    for number in a_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:16]) #取反码
            complement = format(int(inverted_part, 2) + 1, '08b') #取补码
            extracted_part = number[0] + complement #加符号位
        else:
            extracted_part = number[0] + '1' + number[9:16]
        a_mantissa.append(extracted_part)
    print("a的尾数9b补码: " + str(a_mantissa))
    
    b_exp = []
    for number in b_modified:
        extracted_part = number[1:9]
        b_exp.append(extracted_part)
    b_exp_int = [int(binary, 2) for binary in b_exp]
    b_exp_max = max(b_exp_int)
    b_difference = [b_exp_max - value for value in b_exp_int]

    b_mantissa = []
    for number in b_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:16])
            complement = format(int(inverted_part, 2) + 1, '08b') 
            extracted_part = number[0] + complement
        else:
            extracted_part = number[0] + '1' + number[9:16]
        b_mantissa.append(extracted_part)
    print("b的尾数9b补码: " + str(b_mantissa))
    
    a_shifted_mantissa_values = []
    for mant, diff in zip(a_mantissa, a_difference):
        sign = int(mant[0], 2)
        mant_padded = mant + '0' * k
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant_padded[:-diff]
        else:
            shifted_value_binary = mant_padded
        a_shifted_mantissa_values.append(shifted_value_binary)
    print("a的尾数 移位扩展后: " + str(a_shifted_mantissa_values))

    b_shifted_mantissa_values = []
    for mant, diff in zip(b_mantissa, b_difference):
        sign = int(mant[0], 2)
        mant_padded = mant + '0' * k
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant_padded[:-diff]
        else:
            shifted_value_binary = mant_padded
        b_shifted_mantissa_values.append(shifted_value_binary)
    print("b的尾数 移位扩展后: " + str(b_shifted_mantissa_values))

    ## 尾数相乘
    product_mantissa = [twos_complement(a) * twos_complement(b) for a, b in zip(a_shifted_mantissa_values, b_shifted_mantissa_values)]
    sum_product_mantissa = sum(product_mantissa)
    sum_product_mantissa_binary = format(sum_product_mantissa,'023b')#高位的0被省略，但是浮点数从高位向下取，因此必须补齐
    sign = 1 if sum_product_mantissa < 0 else 0
    if sign == 1:
        inverted_value = ~int(sum_product_mantissa_binary, 2) + 1
        inverted_value_bits = format(inverted_value, '023b')
        final_result = inverted_value_bits[1:]
    else:
        final_result = sum_product_mantissa_binary[1:]

    print("新符号位：" + str(sign))
    print("新尾数：" + str(final_result))

    mantissa_val = 0
    for i, bit in enumerate(final_result):
        mantissa_val += int(bit) * (2 ** (-i + 3)) ##正常乘，小数点在MSB-2，带符号乘，小数点在MSB-4

    print(mantissa_val)
    combined_result = ((-1) ** int(sign)) * (2 ** (a_exp_max + b_exp_max - 254)) * mantissa_val
    return combined_result

In [309]:
#ReDCIM_new
#运算
def ReDCIM_bf16_multiply_new(a,b):
    a_tensor = torch.from_numpy(a)
    b_tensor = torch.from_numpy(b)
    a_fp32 = a_tensor.float()
    b_fp32 = b_tensor.float()
    a_modified = Modify_tensor(a_fp32)
    b_modified = Modify_tensor(b_fp32)

    ##预对齐
    a_exp = []
    for number in a_modified:
        extracted_part = number[1:9]
        a_exp.append(extracted_part)
    a_exp_int = [int(binary, 2) for binary in a_exp]
    a_exp_max = max(a_exp_int)
    a_difference = [a_exp_max - value for value in a_exp_int]

    a_mantissa = []
    for number in a_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:15]) #取反码
            complement = format(int(inverted_part, 2) + 1, '07b') #取补码
            extracted_part = number[0] + complement #加符号位
        else:
            extracted_part = number[0] + '1' + number[9:15]
        a_mantissa.append(extracted_part)
    print("a的尾数8b补码: " + str(a_mantissa))
    
    b_exp = []
    for number in b_modified:
        extracted_part = number[1:9]
        b_exp.append(extracted_part)
    b_exp_int = [int(binary, 2) for binary in b_exp]
    b_exp_max = max(b_exp_int)
    b_difference = [b_exp_max - value for value in b_exp_int]

    b_mantissa = []
    for number in b_modified:
        if number[0] == '1':
            inverted_part ='0' + ''.join('1' if bit == '0' else '0' for bit in number[9:15])
            complement = format(int(inverted_part, 2) + 1, '07b') 
            extracted_part = number[0] + complement
        else:
            extracted_part = number[0] + '1' + number[9:15]
        b_mantissa.append(extracted_part)
    print("b的尾数8b补码: " + str(b_mantissa))
    
    a_shifted_mantissa_values = []
    for mant, diff in zip(a_mantissa, a_difference):
        sign = int(mant[0], 2)
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant[:-diff]
        else:
            shifted_value_binary = mant
        a_shifted_mantissa_values.append(shifted_value_binary)
    print("a的尾数 移位扩展后: " + str(a_shifted_mantissa_values))

    b_shifted_mantissa_values = []
    for mant, diff in zip(b_mantissa, b_difference):
        sign = int(mant[0], 2)
        if diff > 0:
            shifted_value_binary = (sign * '1' if sign == 1 else '0') * diff + mant[:-diff]
        else:
            shifted_value_binary = mant
        b_shifted_mantissa_values.append(shifted_value_binary)
    print("b的尾数 移位扩展后: " + str(b_shifted_mantissa_values))

    ## 尾数相乘
    # 由于是补码相乘，先进行带有符号位的扩展，即扩展到16位
    sum_product_mantissa = 0
    for a,b in zip(a_shifted_mantissa_values, b_shifted_mantissa_values):
        a_extended = a if a[0] == '0' else '1'*8 + a
        b_extended = b if b[0] == '0' else '1'*8 + b
        product = bin(int(a_extended, 2) * int(b_extended, 2))
        product_17_bits = product[-17:]
        print("部分积: " + str(product_17_bits))
        sum_product_mantissa += int(product_17_bits, 2)
    sum_product_mantissa_binary = format(sum_product_mantissa,'017b')
    print("尾数：" + str(sum_product_mantissa_binary))
    if sum_product_mantissa_binary[0] == 1 :
        inverted_value = ~int(sum_product_mantissa_binary, 2) + 1
        inverted_value_bits = format(inverted_value, '017b')
        final_result = inverted_value_bits[1:]
    else:
        final_result = sum_product_mantissa_binary[1:]

    print("新符号位：" + str(sum_product_mantissa_binary[0]))
    print("新尾数：" + str(final_result))

    mantissa_val = 0
    for i, bit in enumerate(final_result):
        mantissa_val += int(bit) * (2 ** (-i + 3)) ##正常乘，小数点在MSB-2，带符号乘，小数点在MSB-4

    print(mantissa_val)
    combined_result = ((-1) ** int(sum_product_mantissa_binary[0])) * (2 ** (a_exp_max + b_exp_max - 254)) * mantissa_val
    return combined_result

In [270]:
result_standard = standard_bf16_multiply(a_uniform_distribution, b_uniform_distribution)
print("Result standard:\n", result_standard)

Result standard:
 tensor([[-980.]], dtype=torch.bfloat16)


In [310]:
result_ReDCIM = ReDCIM_bf16_multiply_new(a_uniform_distribution, b_uniform_distribution)
print("Result ReDCIM:\n", result_ReDCIM)



a的尾数8b补码: ['01001101', '10010110']
b的尾数8b补码: ['01011100', '01010100']
a的尾数 移位扩展后: ['01001101', '11001011']
b的尾数 移位扩展后: ['00101110', '01010100']
部分积: 0b110111010110
部分积: 11110111010011100
尾数：11111110001110010
新符号位：1
新尾数：1111110001110010
15.77783203125
Result ReDCIM:
 -64626.0


In [272]:
result_Hybrid = Hybrid_bf16_multiply(a_uniform_distribution, b_uniform_distribution)
print("result_Hybrid:\n", result_Hybrid)

a的尾数9b补码: ['010011010', '100101011']
b的尾数9b补码: ['010111000', '010101001']
a的尾数 移位扩展后: ['01001101000', '11001010110']
b的尾数 移位扩展后: ['00101110000', '01010100100']
新符号位：1
新尾数：0000001110111101101000
0.233795166015625
result_Hybrid:
 -957.625
