In [1]:
import torch
import math
from torch import nn

In [2]:
inp = torch.randint(2,(100,10))
ans = torch.randint(2,(100,))

  inp = torch.randint(2,(100,10))


In [3]:
def twn_quantizer(clamp_val=2.5):
    class TwnQuantizer(torch.autograd.Function):
        @staticmethod
        def forward(ctx, w : torch.Tensor, dim=None):
            if dim is None:
                dim = tuple(range(len(w.shape)))
            if type(dim) is int:
                dim = (dim,)
            ctx.save_for_backward(w)
            n = math.prod((w.shape[d]) for d in dim)

            w = torch.clamp(w,-clamp_val,clamp_val)

            thres = 0.7 * torch.norm(w, p=1, dim=dim) / n
            for d in dim:
                thres = thres.unsqueeze(d)

            b = (w>thres).type(w.dtype) - (w<-thres).type(w.dtype)
            alpha = torch.norm(b*w,p=1,dim=dim)/torch.norm(b,p=1,dim=dim)
            for d in dim:
                alpha = alpha.unsqueeze(d)

            return alpha*b

        @staticmethod
        def backward(ctx, grad_output):
            """
            Approximate the gradient wrt to the full-precision inputs
            using the gradient wrt to the quantized inputs, 
            zeroing out gradient for clamped values.
            """
            w = ctx.saved_tensors
            grad_output *= (-clamp_val < w < clamp_val)
            return grad_output, None
    
    return TwnQuantizer

In [4]:
def min_max_quantizer(bits=8, clamp_val=2.5):
    class MinMaxQuantizer(torch.autograd.Function):
        @staticmethod
        def forward(ctx, w : torch.Tensor, dim=None):
            if dim is None:
                dim = tuple(range(len(w.shape)))
            if type(dim) is int:
                dim = (dim,)
            ctx.save_for_backward(w)

            w = torch.clamp(w,-clamp_val,clamp_val)
            
            mn = mx = w
            for d in dim:
                mn = torch.min(mn,dim=d).values
                mx = torch.max(mx,dim=d).values
                mn = mn.unsqueeze(d)
                mx = mx.unsqueeze(d)

            round_factor = (2**bits-1)/(mx-mn)
            quant_w = torch.round((w-mn)*round_factor)/round_factor+mn

            return quant_w


        @staticmethod
        def backward(ctx, grad_output):
            """
            Approximate the gradient wrt to the full-precision inputs
            using the gradient wrt to the quantized inputs, 
            zeroing out gradient for clamped values.
            """
            w, = ctx.saved_tensors
            grad_output *= ((-clamp_val < w) & (w < clamp_val))
            return grad_output, None
            
    return MinMaxQuantizer

In [16]:
class QuantizedLinear(nn.Module):
    def __init__(self, size_in, size_out, quantizer=twn_quantizer(), quantize_input=False):
        super().__init__()
        
        self.layer = nn.Linear(size_in, size_out)
        self.quantizer = quantizer
        self.quantize_input = quantize_input

    def forward(self, input):
        if self.quantize_input:
            input = self.quantizer.apply(input,(-2,-1))
        quant_weight = self.quantizer.apply(self.layer.weight,(0,1))
        output = nn.functional.linear(input, quant_weight, self.layer.bias)

        return output

In [17]:
print(torch.clamp(torch.tensor(-1.5),-1,1))
quantizer=min_max_quantizer(4)
w = torch.rand((5,5))*2-1
print(w)
quantizer.apply(w)

tensor(-1.)
tensor([[-0.2520, -0.9986, -0.6978, -0.8116, -0.2498],
        [-0.4534, -0.1445, -0.1343, -0.8844, -0.2444],
        [-0.9627,  0.0402, -0.0448, -0.8209,  0.0904],
        [ 0.7080,  0.5314,  0.0226,  0.9883, -0.2767],
        [-0.5458, -0.2103,  0.6451, -0.8749, -0.4863]])


tensor([[-0.2038, -0.9986, -0.7337, -0.8661, -0.2038],
        [-0.4688, -0.2038, -0.0714, -0.8661, -0.2038],
        [-0.9986,  0.0611, -0.0714, -0.8661,  0.0611],
        [ 0.7234,  0.5909,  0.0611,  0.9883, -0.3363],
        [-0.6012, -0.2038,  0.5909, -0.8661, -0.4688]])

In [18]:
model = torch.nn.Sequential(
    QuantizedLinear(10,10,quantizer=quantizer,quantize_input=True),
    torch.nn.ReLU(),
    QuantizedLinear(10,1,quantizer=quantizer,quantize_input=True),
    torch.nn.Sigmoid()
)

In [20]:
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

for i in range(100000):
    optimizer.zero_grad()
    loss = torch.sum((model(inp.float())[:,0] - (ans))**2)
    loss.backward()
    optimizer.step()

    print(loss)

tensor(25.5288, grad_fn=<SumBackward0>)
tensor(25.5322, grad_fn=<SumBackward0>)
tensor(25.5240, grad_fn=<SumBackward0>)
tensor(25.4190, grad_fn=<SumBackward0>)
tensor(25.4370, grad_fn=<SumBackward0>)
tensor(25.3544, grad_fn=<SumBackward0>)
tensor(25.3307, grad_fn=<SumBackward0>)
tensor(25.3120, grad_fn=<SumBackward0>)
tensor(25.2735, grad_fn=<SumBackward0>)
tensor(25.2706, grad_fn=<SumBackward0>)
tensor(25.2631, grad_fn=<SumBackward0>)
tensor(25.2561, grad_fn=<SumBackward0>)
tensor(25.2178, grad_fn=<SumBackward0>)
tensor(25.2189, grad_fn=<SumBackward0>)
tensor(25.1953, grad_fn=<SumBackward0>)
tensor(25.1175, grad_fn=<SumBackward0>)
tensor(25.1098, grad_fn=<SumBackward0>)
tensor(25.1188, grad_fn=<SumBackward0>)
tensor(25.0960, grad_fn=<SumBackward0>)
tensor(25.0810, grad_fn=<SumBackward0>)
tensor(25.0515, grad_fn=<SumBackward0>)
tensor(25.0411, grad_fn=<SumBackward0>)
tensor(25.0359, grad_fn=<SumBackward0>)
tensor(25.0225, grad_fn=<SumBackward0>)
tensor(24.9977, grad_fn=<SumBackward0>)


KeyboardInterrupt: 

In [None]:
model[0].layer.weight = torch.nn.Parameter(model[0].quantizer.apply(model[0].layer.weight))
model[2].layer.weight = torch.nn.Parameter(model[2].quantizer.apply(model[2].layer.weight))
model[2].layer.weight

Parameter containing:
tensor([[ 0.0000,  0.0000, -5.7003,  5.7003,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000, -5.7003]], requires_grad=True)

In [None]:
model(torch.tensor([[0,1]]).float())

tensor([[0.9991]], grad_fn=<SigmoidBackward0>)