In [1]:
from torch import Tensor

import torch
import numpy as np

In [2]:
test_x = torch.randn(5)
test_y = torch.randn(5)
test_x, test_y

(tensor([-0.5682,  1.4677, -0.3205,  0.2717,  0.4254]),
 tensor([-2.6464, -0.7340, -2.1760, -1.7312, -0.9295]))

In [3]:
def get_tensor_size(tensor: Tensor, debug=False):
    total_params = tensor.numel()
    total_size = total_params * tensor.element_size()
    if debug:
        print(f"tensor has {round(total_params)} elements")
        print(f"tensor is {round(total_size)} bytes")
    return total_params, total_size

In [4]:
test_c = torch.randint(0, 2, (5,)).to(dtype=torch.bool)
get_tensor_size(test_c, True), test_c.numpy().dtype

tensor has 5 elements
tensor is 5 bytes


((5, 5), dtype('bool'))

In [5]:
test_tensor = torch.arange(5 * 5).view(5, 5)
test_index = torch.arange(2)
test_tensor, torch.index_select(test_tensor, -1, test_index)

(tensor([[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24]]),
 tensor([[ 0,  1],
         [ 5,  6],
         [10, 11],
         [15, 16],
         [20, 21]]))

In [6]:
DEVICE = 'cpu'

In [50]:
FRACT_PV = torch.arange(-1, -52-1, -1, device=DEVICE, dtype=torch.float64)
EXP_PV   = torch.arange(10, 0-1, -1, device=DEVICE, dtype=torch.float64)
print(FRACT_PV)
print(EXP_PV)
FRACT_SEL = (64 - 52) + torch.arange(52, device=DEVICE, dtype=torch.int)
EXP_SEL   = torch.arange(1, 12, device=DEVICE, dtype=torch.int)
print(FRACT_SEL)
print(EXP_SEL)

def float64_frac(binary_tensor: Tensor):
    print(torch.sum(torch.index_select(binary_tensor, -1, FRACT_SEL) * (2 ** FRACT_PV), -1))
    return 1 + torch.sum(torch.index_select(binary_tensor, -1, FRACT_SEL) * (2 ** FRACT_PV), -1)

def float64_exp(binary_tensor: Tensor):
    return 2 ** (torch.sum(torch.index_select(binary_tensor, -1, EXP_SEL) * (2 ** EXP_PV), -1) - 1023)

def to_float64(binary_tensor: Tensor):
    torch.cuda.empty_cache()
    with torch.no_grad():
        return ((-1.0) ** torch.select(binary_tensor, dim=-1, index=0)) * float64_exp(binary_tensor) * float64_frac(binary_tensor)

tensor([ -1.,  -2.,  -3.,  -4.,  -5.,  -6.,  -7.,  -8.,  -9., -10., -11., -12.,
        -13., -14., -15., -16., -17., -18., -19., -20., -21., -22., -23., -24.,
        -25., -26., -27., -28., -29., -30., -31., -32., -33., -34., -35., -36.,
        -37., -38., -39., -40., -41., -42., -43., -44., -45., -46., -47., -48.,
        -49., -50., -51., -52.], dtype=torch.float64)
tensor([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.,  2.,  1.,  0.],
       dtype=torch.float64)
tensor([12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
        30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
       dtype=torch.int32)
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=torch.int32)


In [51]:
test_sign = torch.randint(0, 2, (64, 64, 1), device=DEVICE, dtype=torch.bool)
test_exp  = torch.randint(0, 2, (64, 64, 11), device=DEVICE, dtype=torch.bool)
test_frac = torch.randint(0, 2, (64, 64, 52), device=DEVICE, dtype=torch.bool)
test_num  = torch.cat([test_sign, test_exp, test_frac], -1)
print(test_sign)
print(test_exp)
print(test_frac)
test_f = float64_frac(test_num).cpu()
test_e = float64_exp(test_num).cpu()
test_n = to_float64(test_num).cpu()
torch.cuda.empty_cache()
test_f, test_e, test_n, get_tensor_size(test_num), get_tensor_size(test_n)

tensor([[[ True],
         [ True],
         [False],
         ...,
         [ True],
         [ True],
         [False]],

        [[False],
         [False],
         [False],
         ...,
         [ True],
         [False],
         [False]],

        [[ True],
         [ True],
         [ True],
         ...,
         [False],
         [ True],
         [False]],

        ...,

        [[ True],
         [False],
         [False],
         ...,
         [False],
         [ True],
         [False]],

        [[ True],
         [ True],
         [ True],
         ...,
         [False],
         [False],
         [False]],

        [[ True],
         [ True],
         [False],
         ...,
         [False],
         [False],
         [ True]]])
tensor([[[ True, False, False,  ...,  True, False, False],
         [False,  True, False,  ..., False, False,  True],
         [False,  True, False,  ...,  True, False, False],
         ...,
         [ True,  True,  True,  ...,  True,  True, 

(tensor([[1.3267, 1.1376, 1.1734,  ..., 1.4573, 1.0125, 1.6060],
         [1.2028, 1.2446, 1.9373,  ..., 1.6800, 1.7533, 1.5854],
         [1.1531, 1.2806, 1.2760,  ..., 1.7213, 1.3730, 1.3127],
         ...,
         [1.6656, 1.9748, 1.0721,  ..., 1.9420, 1.4660, 1.5710],
         [1.1014, 1.2717, 1.2805,  ..., 1.6473, 1.3313, 1.2003],
         [1.2242, 1.1346, 1.0548,  ..., 1.8735, 1.4238, 1.6130]],
        dtype=torch.float64),
 tensor([[ 2.3058e+18, 1.2813e-144, 1.7198e-136,  ..., 1.0175e+236,
           3.4028e+38,  5.9863e+51],
         [ 2.3408e-97, 1.6418e-288, 4.1675e+239,  ...,  1.1259e+15,
          1.5391e+113,  1.5625e-02],
         [ 5.0000e-01, 2.5436e+235, 3.9869e-205,  ...,  2.3970e-94,
          7.2911e-304, 5.8582e-244],
         ...,
         [1.2994e-113, 2.3134e+223, 1.1845e+226,  ..., 2.8639e+250,
          4.4555e+189, 1.4742e+166],
         [6.2101e+231, 4.1095e+208, 5.8147e+135,  ...,  5.7646e+17,
          5.2766e-228, 2.0658e+121],
         [6.2978e+262,  2.

In [52]:
x = torch.arange(2036, 2046, dtype=torch.float64) - 1023
a = 2 ** x
y = torch.log2(a)
y_ = torch.log2(a * (1 + torch.abs(torch.rand_like(a))))
z = (2 ** (y_ - y) - 1)
x, a, y + 1023, y_, z, y / x

(tensor([1013., 1014., 1015., 1016., 1017., 1018., 1019., 1020., 1021., 1022.],
        dtype=torch.float64),
 tensor([8.7778e+304, 1.7556e+305, 3.5111e+305, 7.0222e+305, 1.4044e+306,
         2.8089e+306, 5.6178e+306, 1.1236e+307, 2.2471e+307, 4.4942e+307],
        dtype=torch.float64),
 tensor([2036., 2037., 2038., 2039., 2040., 2041., 2042., 2043., 2044., 2045.],
        dtype=torch.float64),
 tensor([1013.7807, 1014.2719, 1015.3618, 1016.1188, 1017.8038, 1018.6142,
         1019.3919, 1020.7339, 1021.9458, 1022.7578], dtype=torch.float64),
 tensor([0.7180, 0.2074, 0.2850, 0.0859, 0.7456, 0.5307, 0.3122, 0.6631, 0.9263,
         0.6909], dtype=torch.float64),
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=torch.float64))

In [64]:
def to_binary(tensor: Tensor):
    sign = (tensor < 0.0).unsqueeze(-1)
    log = torch.log2(torch.abs(tensor)) + 1023
    exponent = torch.floor(log)
    fraction = (2 ** (log - exponent)) - 1
    print(log, exponent, fraction)

    # Get exponent in binary
    exponent_bin = []
    for _ in range(11):
        exponent_bin.append((exponent % 2 != 0).unsqueeze(-1))
        exponent = torch.floor(exponent / 2)
    exponent_bin = torch.cat(list(reversed(exponent_bin)), -1)

    # Get fraction in binary
    fraction_bin = []
    for _ in range(52):
        value = fraction * 2
        integer = np.floor(value)
        fraction_bin.append(integer.unsqueeze(-1))
        fraction = value - integer
    fraction_bin = torch.cat(fraction_bin, -1)

    print(sign.shape, exponent_bin.shape, fraction_bin.shape)
    return torch.cat([sign, exponent_bin, fraction_bin], -1)

test_x = torch.randn(1) * 10000
test_x_bin = to_binary(test_x)
print(test_x_bin)
test_x_rev = to_float64(test_x_bin).clone()

test_x, torch.select(test_x_bin, -1, 0), test_x_bin, test_x_rev

tensor([1034.8939]) tensor([1034.]) tensor([0.8582])
torch.Size([1, 1]) torch.Size([1, 11]) torch.Size([1, 52])
tensor([[0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0.,
         1., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([0.8582], dtype=torch.float64)


(tensor([3805.7893]),
 tensor([0.]),
 tensor([[0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0.,
          1., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([3805.6335], dtype=torch.float64))