In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms
from torch.nn import functional
import math

In [6]:
class Binarize(torch.autograd.Function):
    """Scaled-sign binarization with a straight-through estimator (STE).

    Forward maps a tensor to ``sign(input) * mean(|input|)``; the mean is taken
    over every element of the tensor, so the scale is a single scalar.
    Backward hands the incoming gradient back as the gradient of the input,
    optionally zeroing it where the saved input is >= 1 or <= -1.
    """

    # Class-wide switch read by ``backward``:
    #   True  -> zero the gradient where input >= 1 or input <= -1
    #   False -> return the incoming gradient untouched
    THRESHOLD_STE = True

    @staticmethod
    def forward(ctx, input):
        """
        We approximate the input by the following:

        input ~= sign(input) * l1_norm(input) / input.size

        The input is saved on ``ctx`` because ``backward`` needs it to build
        the threshold mask. Note that ``sign(0) == 0``, so exact zeros stay zero.
        """
        ctx.save_for_backward(input)
        return input.sign() * torch.mean(torch.abs(input))

    @staticmethod
    def backward(ctx, grad_output):
        """
        According to [Do-Re-Fa Networks](https://arxiv.org/pdf/1606.06160.pdf),
        the STE for binary weight networks is completely pass through.

        However, according to [Binary Neural Networks](https://arxiv.org/pdf/1602.02830.pdf),
        and [XNOR-net networks](https://arxiv.org/pdf/1603.05279.pdf),
        the STE must be thresholded. This implementation uses:

        d = d * (-1 < w < 1)

        i.e. the gradient is also zeroed exactly at w == 1 and w == -1.

        Set THRESHOLD_STE to True/False for either behavior. However, it is suggested
        to set it to True because we have seen performance degradations with it = False.

        Returns the gradient with respect to ``input``. ``grad_output`` itself is
        never written to: autograd (or the caller of ``.backward(grad)``) may still
        hold a reference to that tensor, so an in-place edit would leak into them.
        """
        if not Binarize.THRESHOLD_STE:
            return grad_output

        input, = ctx.saved_tensors
        # Same region the in-place version zeroed: input >= 1 or input <= -1.
        saturated = input.ge(1) | input.le(-1)
        # masked_fill (no trailing underscore) returns a new tensor.
        return grad_output.masked_fill(saturated, 0)
    
class BinaryLinear(nn.Module):
    """Linear layer whose input, weight and bias all pass through ``Binarize``."""

    def __init__(self, in_features, out_features, bias=True):
        """
        Creates the learnable full-precision parameters of the layer.

        Shapes:
            input  = (N, in_features)
            weight = (out_features, in_features)
            bias   = (out_features,)
            output = (N, out_features)

        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to False, the layer will not learn an additive bias
            (``self.bias`` is registered as ``None``).
            Default: ``True``
        """
        super(BinaryLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = torch.nn.Parameter(torch.empty(out_features))
        else:
            # Registering None keeps ``self.bias`` defined without adding a parameter.
            self.register_parameter('bias', None)

        # Initializing parameters: uniform in [-stdv, stdv) with
        # stdv = 1 / sqrt(in_features * out_features).
        stdv = 1. / math.sqrt(in_features * out_features)
        with torch.no_grad():
            self.weight.uniform_(-stdv, stdv)
            if self.bias is not None:
                self.bias.uniform_(-stdv, stdv)

    def forward(self, input):
        """
        Binarizes ``input``, ``weight`` and (when present) ``bias`` and applies
        ``functional.linear`` to the binarized tensors.

        The arithmetic is an ordinary floating-point matmul on sign * scale
        values; no bit-level xor / bitcount is performed here.
        """
        binarize = Binarize.apply
        bias = None if self.bias is None else binarize(self.bias)
        return functional.linear(binarize(input), binarize(self.weight), bias)

In [9]:
# Tiny layer for the worked example below: 3 input features -> 1 output feature.
bmul = BinaryLinear(in_features=3, out_features=1)

In [11]:
# One random input row with 3 features; requires_grad=True so autograd tracks x.
x = torch.randn(1, 3, requires_grad=True)
print(x, '\n', bmul.bias, '\n', bmul.weight)
# Call the module instance instead of .forward(): nn.Module.__call__ runs any
# registered hooks around forward, which a direct .forward() call skips.
bmul(x)

tensor([[-0.7516, -1.9929,  0.2617]]) 
 Parameter containing:
tensor([ 0.5027]) 
 Parameter containing:
tensor([[ 0.4418,  0.0777, -0.5062]])


tensor([[-0.5252]])

# Proof of correctness : linear

In [13]:
# >>> import numpy as np
# >>> x = np.array([[-0.7516, -1.9929,  0.2617]])
# >>> w = np.array([[ 0.4418,  0.0777, -0.5062]])
# >>> ax = np.mean(np.absolute(x))
# >>> ax
# 1.0020666666666667
# >>> aw = np.mean(np.absolute(w))
# >>> aw
# 0.34190000000000004
# >>> Bx = np.sign(x)
# >>> Bw = np.sign(w)
# >>> print("w:",w,"bw:",Bw * aw)
# w: [[ 0.4418  0.0777 -0.5062]] bw: [[ 0.3419  0.3419 -0.3419]]
# >>> linear = np.sum(Bx * Bw)
# >>> linear *= ax * aw
# >>> linear
# -1.0278197800000002
# >>> b = 0.5027  # the bias printed above; for a single element sign(b) * mean(|b|) == b
# >>> linear += b
# >>> linear
# -0.5251197800000001

# Proof of correctness : convolution