# Positive matrix product

### Imports

In [29]:
import time
import torch
import torch.nn.functional as F
import torch.nn as nn

### Matrix-vector product vs Matrix element-wise multiplication by vector (cuda float32)

In [30]:
device = torch.device('cuda')
n_in = 9216
n_out = 128
batch_size = 256
W = torch.empty(n_out, n_in).normal_().to(device=device)
x = torch.linspace(-10, 10, n_in).to(device=device)

In [32]:
start = time.clock()
result_1 = W @ x
elapsed_time = time.clock() - start
print("Elapsed time: " + "{0:0.5f}".format(elapsed_time * 1000) + "ms")

Elapsed time: 0.70267ms


In [33]:
start = time.clock()
result_2 = W * x
elapsed_time = time.clock() - start
print("Elapsed time: " + "{0:0.5}".format(elapsed_time * 1000) + "ms")

Elapsed time: 0.62806ms


### Matrix-vector product vs Matrix element-wise multiplication by vector (cpu int8)

In [54]:
device = torch.device('cpu')
n_in = 9216
n_out = 128
batch_size = 256
W = torch.empty(n_out, n_in, dtype=torch.int8).random_(-128, 127).to(device=device)
x = torch.empty(n_in, dtype=torch.int8).random_(-128, 127).to(device=device)

In [55]:
start = time.clock()
result_1 = W @ x
elapsed_time = time.clock() - start
print("Elapsed time: " + "{0:0.5f}".format(elapsed_time * 1000) + "ms")

Elapsed time: 1.07611ms


In [56]:
start = time.clock()
result_2 = (W * x)
elapsed_time = time.clock() - start
print("Elapsed time: " + "{0:0.5}".format(elapsed_time * 1000) + "ms")

Elapsed time: 0.86689ms


### FC Layer VS manual matrix vector product VS element-wise multiplication + sum

In [57]:
device = torch.device("cuda")
fc = nn.Linear(n_in, n_out, bias=False).to(device=device)
for parameter in fc.parameters():
    parameter.requires_grad = False  # This increases the speed of result 4 like CRAZY

In [58]:
inputs = torch.Tensor(batch_size, n_in).normal_().to(device=device)

In [59]:
n_iter = 100
start = time.clock()
for i in range(n_iter):
    results = fc(inputs)
elapsed_time = time.clock() - start
print("Elapsed time: " + "{0:0.3f}".format(elapsed_time) + "s - Time per iteration: " + "{0:0.3f}".format(elapsed_time/n_iter*1000) + "ms")

Elapsed time: 0.014s - Time per iteration: 0.141ms


In [60]:
n_iter = 100
W = fc.weight.data
start = time.time()
for i in range(n_iter):
    result = W @ inputs.view(n_in, -1)
    torch.cuda.synchronize()
elapsed_time = time.time() - start
print("Elapsed time: " + "{0:0.3f}".format(elapsed_time) + "s - Time per iteration: " + "{0:0.3f}".format(elapsed_time/n_iter*1000) + "ms")

Elapsed time: 0.067s - Time per iteration: 0.670ms


In [61]:
n_iter = 100
W = fc.weight.data
start = time.time()
for i in range(n_iter):
    inputs_expanded = inputs.unsqueeze(1).expand(batch_size, n_out, n_in)
    result = (W * inputs_expanded).sum(axis=2)
    torch.cuda.synchronize()
elapsed_time = time.time() - start
print("Elapsed time: " + "{0:0.3f}".format(elapsed_time) + "s - Time per iteration: " + "{0:0.3f}".format(elapsed_time/n_iter*1000) + "ms")

Elapsed time: 3.675s - Time per iteration: 36.748ms


### Computation of the signs matrix

In [13]:
device = torch.device("cuda")
W = torch.empty(n_out, n_in, dtype=torch.int8).random_(-128, 127).to(device)
input_ = torch.empty(n_in, dtype=torch.int8).random_(-128, 127).to(device)

In [14]:
W

tensor([[ -51,  112, -110,  ..., -103,   12,   -1],
        [ -80,  119,   43,  ...,  119,  109,  -25],
        [-124,  112,  -52,  ...,  -14,  -46,   10],
        ...,
        [ -12,  -29,   66,  ...,   90,  -92,   49],
        [  -8,  -35,   58,  ...,  107,   83,  118],
        [ 117, -100,   66,  ...,   42,   55,  111]], device='cuda:0',
       dtype=torch.int8)

In [15]:
input_

tensor([-118,  -36,   37,  ...,   97, -111,   71], device='cuda:0',
       dtype=torch.int8)

In [16]:
w_signs = torch.sign(W)  # The weight signs are precomputed (out of time benchmark)
input_signs = torch.sign(input_)
w_signs * input_signs

tensor([[ 1, -1, -1,  ..., -1, -1, -1],
        [ 1, -1,  1,  ...,  1, -1, -1],
        [ 1, -1, -1,  ..., -1,  1,  1],
        ...,
        [ 1,  1,  1,  ...,  1,  1,  1],
        [ 1,  1,  1,  ...,  1, -1,  1],
        [-1,  1,  1,  ...,  1, -1,  1]], device='cuda:0', dtype=torch.int8)

In [17]:
n_iter = 1000
start = time.time()
if device == torch.device("cuda"):
    for i in range(n_iter):
        input_signs = torch.sign(input_)  # A bit slower with cpu, a bit faster with cuda
        signs_matrix = w_signs * input_signs
else:
    for i in range(n_iter):
        input_signs = ((input_ >> 7) * -2 + 1)
        signs_matrix = w_signs * input_signs

elapsed_time = time.time() - start
print("Elapsed time: " + "{0:0.3f}".format(elapsed_time) + "s - Time per iteration: " + "{0:0.4f}".format(elapsed_time/n_iter * 1000) + "ms")

Elapsed time: 0.326s - Time per iteration: 0.3262ms


### Computation of the absolute value of the tensors

As we can see the absolute value of the smallest number is bugged because it doesnt exist in that dtype

In [18]:
value1 = torch.abs(torch.tensor(-127, dtype=torch.int8)).item()
value2 = torch.abs(torch.tensor(-128, dtype=torch.int8)).item()
print("Absolute value of int8 -127 is " + repr(value1))
print("Absolute value of int8 -128 is " + repr(value2))

Absolute value of int8 -127 is 127
Absolute value of int8 -128 is -128


In [19]:
device = torch.device("cpu")
W = torch.empty(n_out, n_in, dtype=torch.int8).random_(-127, 127).to(device)  # We only range from -127 to avoid bug of abs
input_ = torch.empty(n_in, dtype=torch.int8).random_(-127, 127).to(device)  # Same

In [20]:
abs_W = torch.abs(W)  # This can be precomputed

In [21]:
n_iter = 10000
start = time.time()
for i in range(n_iter):
    torch.abs(input_)
elapsed_time = time.time() - start
print("Elapsed time: " + "{0:0.3f}".format(elapsed_time) + "s - Time per iteration: " + "{0:0.4f}".format(elapsed_time/n_iter*1000) + "ms")

Elapsed time: 0.188s - Time per iteration: 0.0188ms


### Full pipeline

In [62]:
batch_size = 256
n_in = 9216
n_out = 128
device = torch.device("cpu")

# Generate W and precompute it's element-wise signs and absolute values
W = torch.empty(n_out, n_in, dtype=torch.int8, device=device).random_(-127, 127)
abs_W = torch.abs(W)
signs_W = torch.sign(W)

In [63]:
# Generate batch_size inputs
inputs = torch.empty(batch_size, n_in, dtype=torch.int8, device=device).random_(-127, 127)

In [64]:
abs_inputs_expanded = torch.abs(inputs).unsqueeze(1).expand(batch_size, n_out, n_in)
signs_inputs_expanded = torch.sign(inputs).unsqueeze(1).expand(batch_size, n_out, n_in)
signs_matrices = signs_W * signs_inputs_expanded

# The first multiplication (abs_W * abs_inputs_expanded) can use only fib encoded numbers
# The second multiplication (... * signs_matrices) is always an unsigned int * a sign
result = torch.zeros_like(abs_W, dtype=torch.int32)
torch.mul(abs_W, abs_inputs_expanded, out=result)
torch.mul(result, signs_matrices, out=result)
result = result.sum(dim=2)

In [65]:
result

tensor([[   55689,  -225734,  -195977,  ...,   680322,  -100576,   571251],
        [  546280,   455686,  -596431,  ...,  -187095,  -641463,   722258],
        [ -177177,  -105483,  -774910,  ...,  -303105,   721467,  -356380],
        ...,
        [ -648762, -1156661,   570426,  ...,  -660435,   658413,  1071726],
        [  399636,  -708527, -1157746,  ...,  -721120,  -101464,   495962],
        [     303,   474557,   201979,  ...,  -708200,   173130,  -564593]])

In [66]:
device = torch.device("cuda")
inputs_cuda = inputs.to(dtype=torch.float32, device=device)
fc = nn.Linear(n_in, n_out, bias=False).to(device=device)
with torch.no_grad():
    fc.weight.data = W.to(dtype=torch.float32, device=device)
    result_baseline = fc(inputs_cuda)

In [67]:
result_baseline

tensor([[ 5.5689e+04, -2.2573e+05, -1.9598e+05,  ...,  6.8032e+05,
         -1.0058e+05,  5.7125e+05],
        [ 5.4628e+05,  4.5569e+05, -5.9643e+05,  ..., -1.8710e+05,
         -6.4146e+05,  7.2226e+05],
        [-1.7718e+05, -1.0548e+05, -7.7491e+05,  ..., -3.0310e+05,
          7.2147e+05, -3.5638e+05],
        ...,
        [-6.4876e+05, -1.1567e+06,  5.7043e+05,  ..., -6.6044e+05,
          6.5841e+05,  1.0717e+06],
        [ 3.9964e+05, -7.0853e+05, -1.1577e+06,  ..., -7.2112e+05,
         -1.0146e+05,  4.9596e+05],
        [ 3.0300e+02,  4.7456e+05,  2.0198e+05,  ..., -7.0820e+05,
          1.7313e+05, -5.6459e+05]], device='cuda:0')

In [68]:
# Check equality between baseline and positive matrix product trick
torch.all(torch.eq(result.to(dtype=torch.float32, device=device), result_baseline)).item()

True