In [39]:
import torch
import torch.nn as nn

def quantize(x, bits, dev = 'cpu'):
    maxq = torch.tensor(2**bits - 1)

    tmp = torch.zeros(x.shape[0], device=dev)
    xmin = torch.minimum(x.min(1)[0], tmp) #Rowwise minimums
    xmax = torch.maximum(x.max(1)[0], tmp) #Rowwise minimums

    tmp = (xmin == 0) & (xmax == 0)
    xmin[tmp] = -1
    xmax[tmp] = +1

    if maxq < 0:
        scale = xmax
        zero = xmin
    else:
        scale = (xmax - xmin) / maxq
        zero = torch.round(-xmin / scale)

    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)

Greedy column reordering and quantization

In [60]:
#torch.manual_seed(0)
size = 100

#Split quantization without reordering
A = torch.randn(size,size)
B = torch.randn(size,size)
B2 = torch.clone(B)
out = torch.matmul(A,B)

B4 = quantize(B,4)
B8 = quantize(B,8)
 
B[:,size//2:size] = B4[:,size//2:size]
B[:,:size//2] = B8[:,:size//2] 

test_out = torch.matmul(A,B)

#Find largest columns by means
largest_col = torch.argsort(-torch.mean(A,axis=0))

#Reorder columns according to their mean
A = A[:,largest_col] 
B = B2[largest_col,:]

B4 = quantize(B,4)
B8 = quantize(B,8)
 
B[:,size//2:size] = B4[:,size//2:size]
B[:,:size//2] = B8[:,:size//2] 

quantOut = torch.matmul(A,B)

error = torch.norm(out - test_out)
print(error)

error = torch.norm(out - quantOut)
print(error)

#2.8505
#2.8109

tensor(99.7150)
tensor(97.9941)
