In [1]:
import torch
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine (the cells below run inference with numpy on the CPU anyway).
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted
# sources (newer torch versions also accept weights_only=True to restrict this).
model_weights = torch.load("mnist_bnn_mlp.pt", map_location="cpu")
print(type(model_weights))

  from .autonotebook import tqdm as notebook_tqdm


<class 'collections.OrderedDict'>


In [2]:
# List the state-dict entry names in their stored order (fc*/bn* parameters and BatchNorm buffers).
model_weights.keys()

odict_keys(['fc1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'fc2.weight', 'bn2.weight', 'bn2.bias', 'bn2.running_mean', 'bn2.running_var', 'bn2.num_batches_tracked', 'fc3.weight', 'bn3.weight', 'bn3.bias', 'bn3.running_mean', 'bn3.running_var', 'bn3.num_batches_tracked', 'fc4.weight', 'bn4.weight', 'bn4.bias', 'bn4.running_mean', 'bn4.running_var', 'bn4.num_batches_tracked'])

In [3]:

# Convert every state-dict entry to a CPU numpy array. Entries that are already
# numpy arrays are left alone, so re-running this cell no longer fails on `.cpu()`.
for name, value in model_weights.items():
    if torch.is_tensor(value):
        model_weights[name] = value.detach().cpu().numpy()

for layer in model_weights.items():
    print(layer)

('fc1.weight', array([[ 0.01796702, -0.01619875, -0.00735881, ...,  0.00634418,
        -0.00452901,  0.00206897],
       [-0.02773761, -0.02280657, -0.0312873 , ..., -0.00328715,
         0.00089431, -0.01704787],
       [ 0.02805158, -0.00959798, -0.03416752, ..., -0.01420229,
         0.00550288, -0.03365139],
       ...,
       [ 0.00570961,  0.0349777 , -0.03417083, ...,  0.03363043,
         0.02686312, -0.02106128],
       [ 0.01226201,  0.02846391, -0.00309211, ...,  0.03344198,
         0.033974  ,  0.01770352],
       [ 0.03216547, -0.02721824,  0.02646583, ...,  0.00504588,
         0.01041278,  0.02281732]], dtype=float32))
('bn1.weight', array([1.4778259, 2.5675004, 1.9303343, ..., 2.3447855, 1.886667 ,
       2.1504722], dtype=float32))
('bn1.bias', array([ 1.8485988 , -0.62126714, -0.0982881 , ..., -0.568802  ,
       -0.9123156 , -0.19557329], dtype=float32))
('bn1.running_mean', array([ 51.274117,  62.234627, -26.783463, ..., -45.81591 ,  34.40998 ,
        23.768795],

In [4]:
# Shape of the first state-dict entry (fc1.weight), laid out as (out_features, in_features).
next(iter(model_weights.values())).shape

(1024, 784)

In [5]:
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class BinaryLinear:
    """Bias-free binarized linear layer evaluated with numpy.

    ``weight`` holds real-valued weights of shape (out_features, in_features).
    ``binarize`` quantises them into ``binary_weight`` as bits: 1.0 where the
    weight is >= 0, 0.0 where it is < 0. Bits are interpreted as signs
    (0 -> -1, 1 -> +1) when the layer is applied.
    """

    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = np.zeros([out_features, in_features])
        self.binary_weight = np.zeros([out_features, in_features])

    def binarize(self):
        """Quantise ``weight`` into ``binary_weight`` (1.0 where >= 0, else 0.0).

        ``weight`` may be a numpy array or a torch tensor / ``nn.Parameter``;
        tensors are detached and moved to the CPU before thresholding.
        """
        weight = self.weight
        if torch.is_tensor(weight):
            weight = weight.detach().cpu().numpy()
        # "< 0" is False for NaN, so a NaN weight maps to bit 1.
        self.binary_weight = np.where(np.asarray(weight) < 0, 0.0, 1.0)

    def __call__(self, x):
        """Binary (XNOR-popcount) product of each input row with each weight row.

        With every bit mapped to a sign (zero -> -1, any nonzero value -> +1),
        ``output[b, i] = sum_j s(x[b, j]) * s(binary_weight[i, j])``, which equals
        ``2 * (number of matching bits) - in_features``.

        Args:
            x: bits of shape (in_features,) or (batch, in_features).

        Returns:
            float64 array of shape (batch, out_features); a 1-D input yields a
            single row, i.e. shape (1, out_features).
        """
        x_sign = np.where(np.atleast_2d(np.asarray(x)) != 0, 1.0, -1.0)
        w_sign = np.where(np.asarray(self.binary_weight) != 0, 1.0, -1.0)
        # Sums of +/-1 stay exact in float64, so this matches the bit-count formula.
        return x_sign @ w_sign.T

class BinaryHardTanH:
    """Sign activation emitting bits: 1.0 where x >= 0, 0.0 where x < 0.

    Clamping to [-1, 1] (hard tanh) before taking the sign would produce the
    same bits, so no explicit clamp is performed.
    """

    def __init__(self, in_features):
        self.in_features = in_features

    def __call__(self, x):
        """Return a new float64 array of bits; the input ``x`` is not modified.

        NaN entries (neither >= 0 nor < 0) are passed through as NaN.
        """
        x = np.asarray(x, dtype=np.float64)
        return np.where(x >= 0, 1.0, np.where(x < 0, 0.0, x))
        
        

class BinaryNet:
    """Inference-only binarized MLP: 784 -> 1024 -> 1024 -> 1024 -> 10.

    ``features`` mixes numpy-based BinaryLinear / BinaryHardTanH layers with
    torch ``nn.BatchNorm1d`` modules; ``__call__`` converts activations between
    numpy arrays and torch tensors at every BatchNorm boundary.
    """

    def __init__(self):
        self.features = [
            BinaryLinear(784, 1024),
            nn.BatchNorm1d(1024),
            BinaryHardTanH(1024),

            BinaryLinear(1024, 1024),
            nn.BatchNorm1d(1024),
            BinaryHardTanH(1024),

            BinaryLinear(1024, 1024),
            nn.BatchNorm1d(1024),
            BinaryHardTanH(1024),

            BinaryLinear(1024, 10),
            nn.BatchNorm1d(10),
        ]

    @staticmethod
    def _to_numpy(value):
        """Return ``value`` as a numpy array; accepts numpy arrays or torch tensors."""
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def load_weight(self, weight):
        """Copy parameters from a state dict into ``self.features``, positionally.

        Entries are consumed in the mapping's iteration order: one entry
        (``weight``) per BinaryLinear, five per BatchNorm1d (weight, bias,
        running_mean, running_var, num_batches_tracked -- the last is skipped),
        and none per BinaryHardTanH. Values may be numpy arrays or torch tensors.

        Args:
            weight: ordered mapping of name -> array, e.g. a ``state_dict()``.

        Raises:
            ValueError: if ``weight`` has too few entries, or an entry's shape
                does not match the layer it is loaded into.
        """
        entries = list(weight.items())
        cnt = 0
        for layer in self.features:
            if isinstance(layer, nn.BatchNorm1d):
                n_entries = 5
            elif isinstance(layer, BinaryHardTanH):
                continue  # parameter-free activation
            else:
                n_entries = 1

            if cnt + n_entries > len(entries):
                raise ValueError(
                    "State dict too short: {} needs {} entries starting at index {}, "
                    "but only {} entries exist".format(
                        type(layer).__name__, n_entries, cnt, len(entries)))

            key, value = entries[cnt]
            value = self._to_numpy(value)
            expected = tuple(layer.weight.shape)
            if expected != value.shape:
                raise ValueError(
                    "Incompatible shape for '{}', expected: {}, loading: {}".format(
                        key, expected, value.shape))

            if isinstance(layer, nn.BatchNorm1d):
                layer.weight = nn.Parameter(torch.tensor(value))
                layer.bias = nn.Parameter(torch.tensor(self._to_numpy(entries[cnt + 1][1])))
                layer.running_mean = torch.tensor(self._to_numpy(entries[cnt + 2][1]))
                layer.running_var = torch.tensor(self._to_numpy(entries[cnt + 3][1]))
                # entries[cnt + 4] (num_batches_tracked) is not needed for inference.
            else:
                # BinaryLinear computes with numpy, so keep its weight as a numpy array.
                layer.weight = value
            cnt += n_entries

    def binarize(self):
        """Quantise the real-valued weights of every BinaryLinear layer."""
        for layer in self.features:
            if isinstance(layer, BinaryLinear):
                layer.binarize()

    def __call__(self, x):
        """Run inference and return log-probabilities.

        Args:
            x: flattened images as a numpy array, e.g. shape (1, 784). Entries
                >= 0 become bit 1, all others (including NaN) bit 0. ``x`` itself
                is not modified.

        Returns:
            float32 torch tensor of log-softmax scores along dim 1.
        """
        # Binarize the input into a new array instead of overwriting the caller's.
        h = np.where(np.asarray(x) >= 0, 1.0, 0.0)

        # Pure inference: no_grad skips building graphs through BatchNorm parameters.
        with torch.no_grad():
            for layer in self.features:
                if isinstance(layer, nn.BatchNorm1d):
                    layer.eval()  # use running statistics, not batch statistics
                    h = layer(torch.as_tensor(h, dtype=torch.float32)).numpy()
                else:
                    h = layer(h)
            return F.log_softmax(torch.as_tensor(h, dtype=torch.float32), dim=1)
        


In [6]:
# Build the network, copy the trained parameters in (consumed in the state
# dict's iteration order), then quantise every BinaryLinear's weights to bits.
bnn = BinaryNet()
bnn.load_weight(model_weights)
bnn.binarize()

In [7]:
# Smoke test on a random (1, 784) input in [-1, 1). Seeded so the printed
# log-probabilities are reproducible across kernel restarts.
s = np.random.default_rng(0).uniform(-1, 1, size=(1, 784)).astype(np.float32)

print(s.shape)
bnn(s)

(1, 784)


tensor([[ -6.7260,  -3.0064,  -0.0905,  -4.8599, -11.4987,  -4.2112,  -6.5674,
          -4.7578, -10.1762,  -5.7181]])

In [8]:
from torchvision import datasets, transforms

# Tensor conversion followed by standardisation. (0.1307, 0.3081) are presumably
# the MNIST training-set mean/std -- they must match what the model was trained with.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# MNIST test split, one image per batch.
dataset2 = datasets.MNIST(root='../data', train=False, transform=transform)
test_kwargs = {'batch_size': 1}
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=test_kwargs['batch_size'])

In [9]:
from tqdm import tqdm

# Top-1 accuracy of the binarized net on the MNIST test split, one image at a
# time. no_grad avoids building autograd graphs through the BatchNorm
# parameters, which inference never needs.
correct = 0
with torch.no_grad():
    for data, target in tqdm(test_loader):
        data = data.flatten(start_dim=1).numpy()
        output = bnn(data)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

print("Accuracy = {}".format(100. * correct / len(test_loader.dataset)))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [03:38<00:00, 45.68it/s]

Accuracy = 97.78





In [111]:
!pip install tqdm



In [57]:
# Seeded random test vector with 784 entries drawn uniformly from [-1, 1)
# (bounds given in (low, high) order).
s = np.random.default_rng(0).uniform(-1, 1, 784)

In [58]:
# View the random vector as a single-row float32 batch of shape (1, 784).
torch.tensor(s[np.newaxis, :], dtype=torch.float32)

tensor([[ 0.5282, -0.4630, -0.3560,  0.3242,  0.2986,  0.9084, -0.6866, -0.0322,
         -0.6632,  0.3347, -0.8971,  0.9783,  0.1746,  0.2339,  0.9170, -0.8043,
          0.8858,  0.5832,  0.8288,  0.6002,  0.0873,  0.1066, -0.6711,  0.7424,
          0.0589, -0.9383,  0.1378, -0.1173,  0.5028, -0.7691,  0.2411, -0.0311,
          0.5974, -0.7782, -0.3050, -0.2316, -0.8062, -0.2975, -0.1959, -0.1657,
          0.8577, -0.4954, -0.1729, -0.3457,  0.9202, -0.5183,  0.5251,  0.7135,
         -0.5186,  0.0313,  0.3776, -0.5130,  0.9162,  0.3030,  0.0122,  0.0909,
          0.1205, -0.0144,  0.9688, -0.7357, -0.8532, -0.5544,  0.2906, -0.7739,
         -0.5668, -0.2397, -0.9000, -0.0629,  0.0100, -0.7840, -0.6264,  0.2224,
         -0.5520, -0.6755, -0.1389, -0.7568, -0.7120, -0.9943,  0.7059, -0.2137,
          0.3697, -0.9765, -0.2583,  0.2644,  0.7865,  0.3486,  0.2491,  0.5264,
          0.1822, -0.9926,  0.7163,  0.2418,  0.9793, -0.0218,  0.1499, -0.9348,
          0.3937,  0.1368,  

In [59]:
# Freshly constructed BatchNorm1d switched to inference mode; the output below shows
# it leaves the input almost unchanged (e.g. 0.7163 -> 0.7162).
bn = nn.BatchNorm1d(784).eval()
bn(torch.tensor(s[np.newaxis, :], dtype=torch.float32))

tensor([[ 0.5282, -0.4630, -0.3560,  0.3242,  0.2986,  0.9084, -0.6866, -0.0322,
         -0.6632,  0.3347, -0.8971,  0.9783,  0.1746,  0.2339,  0.9170, -0.8043,
          0.8858,  0.5832,  0.8288,  0.6002,  0.0873,  0.1066, -0.6711,  0.7424,
          0.0589, -0.9383,  0.1378, -0.1173,  0.5028, -0.7691,  0.2411, -0.0311,
          0.5974, -0.7782, -0.3050, -0.2316, -0.8062, -0.2975, -0.1959, -0.1657,
          0.8577, -0.4954, -0.1729, -0.3457,  0.9202, -0.5183,  0.5251,  0.7135,
         -0.5186,  0.0313,  0.3776, -0.5130,  0.9162,  0.3030,  0.0122,  0.0909,
          0.1205, -0.0144,  0.9688, -0.7357, -0.8532, -0.5544,  0.2906, -0.7739,
         -0.5668, -0.2397, -0.9000, -0.0629,  0.0100, -0.7840, -0.6264,  0.2224,
         -0.5520, -0.6755, -0.1389, -0.7568, -0.7120, -0.9943,  0.7059, -0.2137,
          0.3697, -0.9765, -0.2583,  0.2644,  0.7865,  0.3486,  0.2491,  0.5264,
          0.1822, -0.9926,  0.7162,  0.2418,  0.9793, -0.0218,  0.1499, -0.9348,
          0.3937,  0.1368,  

In [61]:
# Sanity check: the BatchNorm1d(784) above keeps one running variance per feature.
bn.running_var.shape

torch.Size([784])