In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def to_mx_format(tensor, emax_elem=127):
    max_val = torch.max(torch.abs(tensor))
    shared_exp = torch.floor(torch.log2(max_val)) - emax_elem
    X = torch.pow(2, shared_exp)
    P = (tensor / X).round()
    return X, P

def from_mx_format(X, P):
    return X * P

In [3]:
class MXNeuralNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MXNeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        # feed forward pass with using MX format
        X1, P1 = to_mx_format(self.fc1.weight) # convert weight tensor into MX format for layer 1
        x = F.linear(x, from_mx_format(X1, P1))  # apply weight tensor with MX format
        x = F.relu(x) 
        X2, P2 = to_mx_format(self.fc2.weight)  # convert weight tensor into MX format for layer 2
        x = F.linear(x, from_mx_format(X2, P2))  # apply weight tensor with MX format
        return x

input_size = 3
hidden_size = 5
output_size = 2
model = MXNeuralNetwork(input_size, hidden_size, output_size)

x = torch.randn(4, input_size)

# feed forward pass
output = model(x)
print("Result:", output)

Result: tensor([[-0.1015,  0.0346],
        [ 0.0810, -0.1206],
        [ 0.6072, -0.1646],
        [ 0.0963, -0.0379]], grad_fn=<MmBackward0>)


In [4]:
criterion = nn.MSELoss()

# Target values
targets = torch.randn(4, output_size)

# backward pass
def backward_pass(output, targets):
    loss = criterion(output.to(torch.float32), targets.to(torch.float32))
    loss.backward()
    return loss.item()

loss_value = backward_pass(output, targets)
print(f"loss value : {loss_value}")

loss value : 0.7194342613220215


# cuda realization

In [5]:
FLOAT32_EXP_BIAS = 127
FLOAT32_IMPLIED1 = 1 << 23
FLOAT32_FULL_MBITS = 23

def get_biased_exponent(float_val):
    bits = torch.FloatTensor([float_val]).to(torch.int32)
    return ((bits >> 23) & 0xFF).item()

def get_sign(float_val):
    bits = torch.FloatTensor([float_val]).to(torch.int32)
    return (bits >> 31).item()

def get_trailing_mantissa(float_val):
    bits = torch.FloatTensor([float_val]).to(torch.int32)
    return (bits & (FLOAT32_IMPLIED1 - 1)).item()

def construct_float(sign, biased_exp, mantissa):
    bits = (sign << 31) | (biased_exp << 23) | (mantissa & (FLOAT32_IMPLIED1 - 1))
    return torch.tensor([bits], dtype=torch.int32).view(torch.float32).item()

def shift_right_round_mantissa(
    mantissa,    
    is_subnorm,  
    mbits,       
    exp_diff,    
    rounding_mode='away',
    allow_overflow=False
):
    print(f"Initial mantissa: {mantissa}, exp_diff: {exp_diff}")

    mantissa = mantissa if is_subnorm else mantissa + FLOAT32_IMPLIED1
    fp32_sig_bits = 23 if is_subnorm else 24

    tie = False
    even = False
    if rounding_mode == 'even':
        tbits = exp_diff + (fp32_sig_bits - mbits)
        mask = (1 << (tbits - 1)) - 1
        tie = not (mantissa & mask)
        mask = (1 << tbits)
        even = not (mantissa & mask)

    mantissa >>= exp_diff
    mantissa >>= (fp32_sig_bits - mbits - 1)

    if mantissa == 0:  # Если после сдвига мантисса становится 0, не сдвигать дальше
        print("Mantissa became zero after shifting. Exiting early.")
        return 0

    if (rounding_mode == 'away' or rounding_mode == 'even') and \
       (allow_overflow or mantissa != ((1 << (mbits + 1)) - 1)):
        if not (tie and even):
            mantissa += 1

    mantissa >>= 1
    print(f"Mantissa after shift and rounding: {mantissa}")
    return mantissa

def shift_left_mantissa(
    mantissa,    
    is_subnorm,  
    mbits,       
    exp_diff     
):
    print(f"Shift left mantissa: {mantissa}")

    fp32_sig_bits = 23 if is_subnorm else 24
    mantissa <<= (fp32_sig_bits - mbits + exp_diff)

    overflow = mantissa >= (1 << fp32_sig_bits)
    mantissa = mantissa >> 1 if overflow and not is_subnorm else mantissa

    mantissa &= (FLOAT32_IMPLIED1 - 1)
    print(f"Mantissa after shift left: {mantissa}, overflow: {overflow}")
    return overflow

In [6]:
def quantize_elemwise(
    input,
    bits,    
    exp_bits,
    max_norm,
    rounding_mode='away',
    saturate_normals=False,
    allow_denorm=True
):
    if input == 0.0:
        return 0.0

    biased_exp = get_biased_exponent(input)
    sign = get_sign(input)
    mantissa = get_trailing_mantissa(input)

    print(f"Input: {input}, Biased exp: {biased_exp}, Sign: {sign}, Mantissa: {mantissa}")

    mbits = bits - 1
    is_int = exp_bits == 0

    new_bias = 1 if is_int else (1 << (exp_bits - 1)) - 1
    new_biased_exp = biased_exp - FLOAT32_EXP_BIAS + new_bias

    print(f"New biased exp: {new_biased_exp}, Max norm: {max_norm}")

    if not is_int and not allow_denorm and new_biased_exp < 1:
        return 0.0

    exp_diff = max(0, 1 - new_biased_exp)
    exp_diff = min(FLOAT32_FULL_MBITS, exp_diff)

    print(f"Exp diff: {exp_diff}")

    mantissa = shift_right_round_mantissa(mantissa, biased_exp == 0, mbits, exp_diff, rounding_mode)

    if mantissa == 0:
        return 0.0

    overflow = shift_left_mantissa(mantissa, biased_exp == 0, mbits, exp_diff)
    biased_exp = biased_exp + 1 if overflow else biased_exp

    output = construct_float(sign, biased_exp, mantissa)

    print(f"Output after constructing float: {output}")

    if abs(output) > max_norm:
        if is_int or saturate_normals:
            return -max_norm if sign else max_norm
        else:
            return float('inf')

    return output

def quantize_mx_elem(input, scale, flush_tile, elem_ebits, elem_mbits, elem_max_norm, rounding_mode='away'):
    scaled_in = 0 if flush_tile else input / scale

    print(f"Scaled input: {scaled_in}")

    scaled_out = quantize_elemwise(
        scaled_in, elem_mbits, elem_ebits, elem_max_norm,
        rounding_mode, True, True
    )

    return scaled_out * scale

In [15]:
input_value = 1.2345
scale = 0.5
elem_ebits = 8
elem_mbits = 7
elem_max_norm = 2.0

# quantized_value = quantize_mx_elem(
#     input_value, scale, False, elem_ebits, elem_mbits, elem_max_norm)


input = torch.rand(1,100)
out = torch.zeros(1,100)
for i in range(1):
    out[0,i] = quantize_mx_elem(input[0,i].item(), scale, False, elem_ebits, elem_mbits, elem_max_norm)
# print(f"Quantized values: {out}")
# print(f"Квантованное значение: {quantized_value}")


Scaled input: 1.6672261953353882
Input: 1.6672261953353882, Biased exp: 0, Sign: 0, Mantissa: 1
New biased exp: 0, Max norm: 2.0
Exp diff: 1
Initial mantissa: 1, exp_diff: 1
Mantissa became zero after shifting. Exiting early.
