In [2]:
import math
import os

# import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import torchvision
import torchvision.transforms as transforms
from torch import Tensor
from torch.autograd import Function
from torch.nn import init, Module
from torch.nn.parameter import Parameter, UninitializedParameter
from torch.utils.data import DataLoader, Dataset, TensorDataset
from typing import Optional, Tuple, List, Union

In [3]:
def hard_sigmoid(x: Tensor):
    return torch.clip((x + 1) / 2, 0, 1)

def binarize(weight: Tensor, H: float, deterministic: bool=True) -> Tensor:
    weight_binary = hard_sigmoid(weight / H)
    
    if deterministic:
        weight_binary = torch.round(weight_binary)
    else:
        weight_binary = torch.bernoulli(weight_binary)
#         print(weight_binary.is_cuda)
        weight_binary = weight_binary.float()
    
    weight_binary = ((2 * weight_binary - 1) * H).float()
#     weight_binary = weight
    return weight_binary


class BinaryLinearFunction(Function):
    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Tensor=None, H: float=1., deterministic: bool=True):
        # Binarize the weights for forward prop
        weight_binary = binarize(weight, H, deterministic)
        
        # The full precision weights are required for backprop
        ctx.save_for_backward(input, weight_binary, bias)
#         print(input.shape, weight_binary.shape)
        
        output = torch.mm(input, weight_binary.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        input, weight_binary, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        
        if ctx.needs_input_grad[0]:
            grad_input = torch.mm(grad_output, weight_binary)
            
        if ctx.needs_input_grad[1]:
            grad_weight = torch.mm(grad_output.t(), input)
            
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
            
        return grad_input, grad_weight, grad_bias, None, None

binary_linear = BinaryLinearFunction.apply
    
class BinaryDense(Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor
    binary_weight: Tensor

    def __init__(self, in_features: int, out_features: int, H: float=1, bias: bool=False, deterministic: bool=True) -> None:
        super(BinaryDense, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(Tensor(out_features, in_features))
        self.H = H
        if bias:
            self.bias = Parameter(Tensor(out_features))
        else:
            self.register_parameter("bias", None)

        self.deterministic = deterministic
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)
            
    def forward(self, input: Tensor) -> Tensor:
        return binary_linear(input, self.weight, self.bias, self.H, self.deterministic)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

In [4]:
device = torch.device("cuda")

In [15]:
batch_size = 128
alpha = .15
epsilon = 1e-4
num_units = 100
n_hidden_layers = 1
num_epochs = 30
dropout_in = 0
dropout_hidden = 0

deterministic = False
H = 1.

learning_rate = .001
min_learning_rate = 3e-6
decay = (learning_rate - min_learning_rate) / num_epochs

t = transforms.Compose(
    [
       transforms.ToTensor(),
       transforms.Normalize(mean=(0), std=(1))
    ]
)

dl_train = DataLoader(
    torchvision.datasets.MNIST(
        "/data/mnist",
        download=True,
        train=True,
        transform=t,
        target_transform=torchvision.transforms.Compose([
            lambda x:torch.LongTensor([x]), # or just torch.tensor
            lambda x:torch.nn.functional.one_hot(x, 10)
        ])
    ),
    batch_size=batch_size,
    drop_last=True,
    shuffle=True
)
dl_valid = DataLoader(
    torchvision.datasets.MNIST(
        "/data/mnist",
        download=True,
        train=False,
        transform=t,
        target_transform=torchvision.transforms.Compose([
            lambda x:torch.LongTensor([x]), # or just torch.tensor
            lambda x:torch.nn.functional.one_hot(x, 10)
        ])
    ),
    batch_size=batch_size,
    drop_last=True,
    shuffle=True
)
num_in = 28 * 28

layers = []
layers.append(torch.nn.Dropout(dropout_in))
for i in range(n_hidden_layers):
    layers.append(BinaryDense(num_in, num_units, H=H, deterministic=deterministic))
#     layers.append(torch.nn.BatchNorm1d(num_units))
#     layers.append(torch.nn.Linear(num_in, num_units))
    layers.append(torch.nn.Sigmoid())
#     layers.append(torch.nn.Dropout(dropout_hidden))
    num_in = num_units
# layers.append(torch.nn.Linear(num_in, 10))
layers.append(BinaryDense(num_in, 10, H=H, deterministic=deterministic))
# layers.append(torch.nn.BatchNorm1d(10, eps=epsilon, momentum=alpha))
layers.append(torch.nn.Softmax(1))
model = torch.nn.Sequential(*layers).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=30)
lossfunction = torch.nn.MSELoss()


losses = [0] * num_epochs
val_losses = [0] * num_epochs
total_steps = len(dl_train) * num_epochs

# from sklearn.metrics import accuracy_score, f1_score

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}")
    model.train()
    ect = 0
    ecv = 0
    ett = 0
    etv = 0
    for i, (input, target) in enumerate(dl_train):
        if i % 10 == 0:
            print(i, end=" ")
        optimizer.zero_grad()
        input = torch.reshape(input, (-1, 28 * 28)).to(device)
        target = torch.reshape(target, (-1, 10)).to(device)
        
        output = model(input)
#         ect += accuracy_score(output.argmax(dim=-1).cpu(), target.argmax(dim=-1).cpu(), normalize=False)
        
#         print(output.shape, target.shape)
        
        loss = lossfunction(output, target.float())
        losses[epoch] += loss.item()
        loss.backward()
        
        optimizer.step()
#         scheduler.step()
#         ett += target.shape[0]

    model.eval()
    with torch.no_grad():
        for j, (input, target) in enumerate(dl_valid):
            input = torch.reshape(input, (-1, 28*28)).to(device)
            target = target.reshape((-1, 10)).to(device)
            output = model(input)
            loss = lossfunction(output, target.float())
            val_losses[epoch] += loss.item()
#             ecv += accuracy_score(output.argmax(dim=-1).cpu(), target.argmax(dim=-1).cpu(), normalize=False)
#             etv += target.shape[0]
    
    print("")
    print("Epoch training loss" , losses[epoch] / len(dl_train))
    print("Epoch valid loss" , val_losses[epoch] / len(dl_valid))


Epoch 1
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 
Epoch training loss 0.1479529457915033
Epoch valid loss 0.1235754980872839
Epoch 2
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 
Epoch training loss 0.1002253546880988
Epoch valid loss 0.07413571523741269
Epoch 3
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 
Epoch training loss 0.059664154313823096
Epoch valid loss 0.046792120433961734
Epoch 4
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 
Epoch training loss 0.0428

In [20]:
tot_acc = 0
with torch.no_grad():
    for j, (input, target) in enumerate(dl_valid):
        input = torch.reshape(input, (-1, 28*28)).to(device)
        target = target.to(device)
        output = model(input)
        target = target.argmax(-1).reshape(-1)
        output = output.argmax(-1).reshape(-1)
        
        tot_acc = (tot_acc * j + int(sum(target == output)) / len(target)) / (j + 1)

print("Validation Accuracy:", tot_acc)

Validation Accuracy: 0.8928285256410257


In [None]:
class _BinaryConvNd(Module):
    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        ...

    _in_channels: int
    _reversed_padding_repeated_twice: List[int]
    out_channels: int
    kernel_size: Tuple[int, ...]
    stride: Tuple[int, ...]
    padding: Union[str, Tuple[int, ...]]
    dilation: Tuple[int, ...]
    transposed: bool
    output_padding: Tuple[int, ...]
    groups: int
    padding_mode: str
    weight: Tensor
    bias: Optional[Tensor]


    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, ...],
                 stride: Tuple[int, ...],
                 padding: Tuple[int, ...],
                 dilation: Tuple[int, ...],
                 transposed: bool,
                 output_padding: Tuple[int, ...],
                 groups: int,
                 bias: bool,
                 padding_mode: str,
                 H: float=1.,
                 deterministic: bool=True) -> None:
        super(BinaryConvNd, self).__init__()
        
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))

        self.H = H
        self.deterministic = deterministic
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode
        # `_reversed_padding_repeated_twice` is the padding to be passed to
        # `F.pad` if needed (e.g., for non-zero padding types that are
        # implemented as two ops: padding + conv). `F.pad` accepts paddings in
        # reverse order than the dimension.
        self._reversed_padding_repeated_twice = _reverse_repeat_tuple(self.padding, 2)
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.xavier_uniform_(self.bias)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        super(_ConvNd, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

In [None]:

class BinarizeKernel(Function):
    @staticmethod
    def forward(ctx, weight: Tensor, H: float=1., deterministic: bool=True):
        ...
        
    @staticmethod
    def backward(ctx, grad_output: Tensor):
        ...

class BinaryConv2D(_BinaryConvNd):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros'  # TODO: refine this type
    ):

        kernel_size_ = torch.nn.modules.utils._pair(kernel_size)
        stride_ = torch.nn.modules.utils._pair(stride)
        padding_ = torch.nn.modules.utils._pair(padding)
        dilation_ = torch.nn.modules.utils._pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
            False, _pair(0), groups, bias, padding_mode)

    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return torch.nn.functional.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: Tensor) -> Tensor:
        return self._conv_forward(input, self.weight, self.bias)
