In [1]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import numpy as np

import matplotlib.pyplot as plt

from model import BinaryLinear

# Taking binarized input data for MNIST

In [2]:
# Download MNIST (if it is not already under 'torch_dataset') and convert
# every image to a PyTorch tensor.
# NOTE(review): train=True is passed, so despite the variable name this is the
# training split; the name is kept because the next cell reads `test_dataset`.
test_dataset = torchvision.datasets.MNIST(
    root='torch_dataset',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

In [3]:
# Materialise every (image, label) sample into an in-memory list.
dataset = [sample for sample in test_dataset]

In [4]:
# Split the samples into one array of images and one array of labels.
X = np.array([image.numpy() for image, _label in dataset])
y = np.array([label for _image, label in dataset])

In [5]:
# Binarize: pixels strictly above 0.5 become 1, everything else 0 (uint8).
X = (X > 0.5).astype(np.uint8)

In [6]:
# Persist the binarized image array before it is flattened.
# NOTE: np.save writes the .npy format and appends ".npy" when the name does
# not already end with it, so the file on disk is
# 'bin_mnist_3d_tensor.npz.npy' -- it is not an .npz archive.
np.save('bin_mnist_3d_tensor.npz', X)

In [7]:
# Flatten each image into a single row. The row count is taken from the array
# itself rather than hard-coded, so this also works on a subset of the data.
X = X.reshape(X.shape[0], -1)

In [8]:
# Check the flattened layout: one row per image, one column per pixel.
print(X.shape)

(60000, 784)


In [9]:
# Persist the flattened binarized matrix.
# NOTE: np.save appends ".npy" when the name does not already end with it, so
# the file on disk is 'bin_mnist_flat.npz.npy' -- it is not an .npz archive.
np.save('bin_mnist_flat.npz', X)

In [10]:
# Also export the flattened matrix as text: integer-formatted, comma-separated.
np.savetxt('bin_mnist_flat.csv', X, fmt='%i', delimiter=',')

In [11]:
# Pair the flattened images (cast to float32) with their labels and serve them
# in shuffled mini-batches of 32.
train_data = torch.utils.data.TensorDataset(
    torch.from_numpy(X.astype(np.float32)),
    torch.from_numpy(y),
)
train_loader = torch.utils.data.DataLoader(
    train_data,
    shuffle=True,
    batch_size=32,
)

# Try to train a torch model on it

In [12]:
# Device configuration: use the first CUDA device when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Fully connected classifier built from three BinaryLinear layers
class Net(nn.Module):
    """Feed-forward network for flattened 28x28 inputs.

    The layer stack is BinaryLinear(784, 128) -> ReLU -> BinaryLinear(128, 64)
    -> ReLU -> BinaryLinear(64, num_classes), stored as ``self.fc``.

    Args:
        num_classes: width of the final layer (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        layers = [
            BinaryLinear(28*28, 128),
            nn.ReLU(),
            BinaryLinear(128, 64),
            nn.ReLU(),
            BinaryLinear(64, num_classes),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.fc(x)

In [13]:
# Build a fresh model and train it with Adam on cross-entropy loss.
model = Net().to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train the model
num_epochs = 10
total_step = len(train_loader)
losses = []  # loss value recorded at every logging step

# Despite its name, this holds the LOWEST logged loss seen so far; it decides
# when the checkpoint is overwritten. The name is kept for compatibility.
max_loss = float('inf')

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 50 == 0:
            # Read the scalar once and reuse it for logging and checkpointing.
            loss_value = loss.item()
            losses.append(loss_value)
            # Report the real epoch total (the message used to hard-code 1).
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss_value))
            # Checkpoint whenever this single mini-batch loss is a new minimum.
            if loss_value < max_loss:
                max_loss = loss_value
                torch.save(model.state_dict(), 'model.ckpt')

Epoch [1/1], Step [50/1875], Loss: 2.3025
Epoch [1/1], Step [100/1875], Loss: 2.3050
Epoch [1/1], Step [150/1875], Loss: 2.2385
Epoch [1/1], Step [200/1875], Loss: 2.1703
Epoch [1/1], Step [250/1875], Loss: 1.7583
Epoch [1/1], Step [300/1875], Loss: 1.3345
Epoch [1/1], Step [350/1875], Loss: 1.2034
Epoch [1/1], Step [400/1875], Loss: 1.2178
Epoch [1/1], Step [450/1875], Loss: 1.0049
Epoch [1/1], Step [500/1875], Loss: 1.0994
Epoch [1/1], Step [550/1875], Loss: 0.7945
Epoch [1/1], Step [600/1875], Loss: 0.5994
Epoch [1/1], Step [650/1875], Loss: 0.3887
Epoch [1/1], Step [700/1875], Loss: 0.6215
Epoch [1/1], Step [750/1875], Loss: 0.4411
Epoch [1/1], Step [800/1875], Loss: 0.3505
Epoch [1/1], Step [850/1875], Loss: 0.3562
Epoch [1/1], Step [900/1875], Loss: 0.3850
Epoch [1/1], Step [950/1875], Loss: 0.4271
Epoch [1/1], Step [1000/1875], Loss: 0.4200
Epoch [1/1], Step [1050/1875], Loss: 0.6697
Epoch [1/1], Step [1100/1875], Loss: 0.5931
Epoch [1/1], Step [1150/1875], Loss: 0.4249
Epoch [1

Epoch [6/1], Step [250/1875], Loss: 0.3734
Epoch [6/1], Step [300/1875], Loss: 0.1757
Epoch [6/1], Step [350/1875], Loss: 0.1315
Epoch [6/1], Step [400/1875], Loss: 0.2029
Epoch [6/1], Step [450/1875], Loss: 0.3972
Epoch [6/1], Step [500/1875], Loss: 0.1979
Epoch [6/1], Step [550/1875], Loss: 0.0960
Epoch [6/1], Step [600/1875], Loss: 0.3042
Epoch [6/1], Step [650/1875], Loss: 0.5213
Epoch [6/1], Step [700/1875], Loss: 0.0786
Epoch [6/1], Step [750/1875], Loss: 0.4435
Epoch [6/1], Step [800/1875], Loss: 0.2376
Epoch [6/1], Step [850/1875], Loss: 0.2973
Epoch [6/1], Step [900/1875], Loss: 0.2291
Epoch [6/1], Step [950/1875], Loss: 0.2222
Epoch [6/1], Step [1000/1875], Loss: 0.3481
Epoch [6/1], Step [1050/1875], Loss: 0.1654
Epoch [6/1], Step [1100/1875], Loss: 0.1456
Epoch [6/1], Step [1150/1875], Loss: 0.2093
Epoch [6/1], Step [1200/1875], Loss: 0.3843
Epoch [6/1], Step [1250/1875], Loss: 0.2160
Epoch [6/1], Step [1300/1875], Loss: 0.4836
Epoch [6/1], Step [1350/1875], Loss: 0.1395
Epo

In [14]:
# Restore the checkpointed weights. map_location remaps the stored tensors onto
# the current device, so a checkpoint written on a GPU still loads on a
# CPU-only machine.
model.load_state_dict(torch.load('model.ckpt', map_location=device))

In [46]:
# Measure classification accuracy over every batch in train_loader.
# This is accuracy on the TRAINING data: no held-out loader exists in this
# notebook, so the message says "training" and reports the real sample count.
model.eval()  # switch the model to evaluation mode
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the largest score per row is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    print('Accuracy of the model on the {} training images: {} %'.format(total, acc))

Test Accuracy of the model on the 60000 test images: 92.76333333333334 %


In [42]:
# Map each parameter name of model.fc (e.g. '0.weight') to its Parameter.
params = {name: tensor for name, tensor in model.fc.named_parameters()}

In [44]:
# NOTE(review): the parameter dump and the exported files further down show a
# single magnitude per tensor (e.g. 777, 188). Presumably one of the disabled
# steps below, or equivalent logic elsewhere, was applied in an earlier kernel
# state -- confirm before relying on Restart & Run All.

# Disabled: scale every parameter up by 100.
# for pname in params:
#     params[pname].data.copy_(params[pname].data * 100)

# Active step: round every parameter to integer values, in place.
for tensor in params.values():
    tensor.data.copy_(tensor.round())

# Disabled: replace every parameter by mean(|w|) * sign(w).
# for pname in params:
#     w = params[pname]
#     avg = torch.mean(torch.abs(w))
#     sign = w.sign()
#     params[pname].data.copy_(avg*sign)

In [45]:
# Display the parameter dict to inspect the values after the in-place update.
params

{'0.weight': Parameter containing:
 tensor([[ 777., -777.,  777.,  ..., -777., -777.,  777.],
         [ 777., -777.,  777.,  ...,  777.,  777., -777.],
         [-777.,  777., -777.,  ..., -777., -777.,  777.],
         ...,
         [ 777.,  777.,  777.,  ..., -777.,  777.,  777.],
         [-777.,  777., -777.,  ..., -777.,  777.,  777.],
         [ 777., -777., -777.,  ..., -777., -777.,  777.]],
        device='cuda:0', requires_grad=True), '0.bias': Parameter containing:
 tensor([-188.,  188., -188.,  188., -188., -188., -188., -188., -188.,    0.,
          188., -188.,  188., -188.,  188., -188.,  188., -188., -188., -188.,
         -188., -188., -188., -188., -188.,  188., -188., -188.,  188.,  188.,
          188.,  188., -188., -188., -188.,  188., -188.,  188., -188., -188.,
         -188., -188.,  188., -188., -188., -188., -188., -188., -188., -188.,
         -188., -188.,  188.,  188., -188.,  188.,  188., -188., -188.,  188.,
         -188., -188.,  188., -188., -188., 

In [54]:
# To actually retrieve the values, we take the magnitude and the signs:
# each tensor is written as a 0/1 sign array, with its integer magnitude
# carried in the file name ('<param name>__<magnitude>').
for pname in params:
    w = params[pname].data.cpu().numpy()
    # Largest ABSOLUTE value, so the magnitude stays correct even when a
    # tensor has no positive entry (a plain max would give 0 or a negative).
    mag = np.amax(np.abs(w))
    # 1 for strictly positive entries, 0 otherwise (zeros and negatives alike).
    sign = (w > 0).astype(np.int8)
    print(mag, sign)
    # np.save appends ".npy" to the name.
    np.save('{}__{}'.format(pname, int(mag)), sign)

777.0 [[1 0 1 ... 0 0 1]
 [1 0 1 ... 1 1 0]
 [0 1 0 ... 0 0 1]
 ...
 [1 1 1 ... 0 1 1]
 [0 1 0 ... 0 1 1]
 [1 0 0 ... 0 0 1]]
188.0 [0 1 0 1 0 0 0 0 0 0 1 0 1 0 1 0 1 0 0 0 0 0 0 0 0 1 0 0 1 1 1 1 0 0 0 1 0
 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 1 1 0 0 1 0 0 1 0 0 0 0 0 1 1 1 1 1 1
 0 0 0 1 1 0 0 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 0 0 0 1 0 1 0 1 0 0
 0 0 1 0 1 1 0 0 1 1 1 0 1 1 1 0 1]
506.0 [[0 0 0 ... 1 0 1]
 [1 0 0 ... 1 0 1]
 [0 0 1 ... 0 1 0]
 ...
 [1 1 1 ... 1 0 1]
 [0 1 0 ... 1 0 0]
 [1 1 0 ... 1 0 0]]
11.0 [0 0 1 0 1 1 0 1 0 0 0 1 0 1 0 1 0 1 1 0 1 0 1 0 1 0 1 0 0 0 1 1 0 1 1 1 1
 0 1 1 0 1 1 0 0 0 0 1 1 1 0 0 0 0 0 1 0 0 1 1 1 1 0 0]
970.0 [[1 1 0 0 0 1 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 1 1 1 0 1 1 1 1 1 0 1 1 1 1 0
  1 0 0 0 1 1 0 0 0 1 0 0 1 1 0 1 0 0 0 0 1 0 0 0 1 0 1 1]
 [1 1 0 0 1 0 0 1 1 1 0 0 0 1 1 0 1 1 0 1 1 1 0 1 0 0 0 1 0 0 0 1 1 0 1 1
  1 0 0 0 0 0 1 0 0 0 1 0 0 0 1 1 0 0 1 1 1 1 0 0 0 1 1 0]
 [0 0 1 0 1 1 1 0 1 0 1 1 1 0 1 0 0 1 1 1 1 1 0 0 0 0 0 0 1 1 0 1 0 1 0 0

In [70]:
# sanity check
# Reload the exported sign bits. Everything is cast to int so that the later
# scaling (e.g. bias * 188) is done in a wide integer type instead of the int8
# the files were stored in.
b0 = np.load('0.bias__188.npy').astype(int)
w0 = np.load('0.weight__777.npy').astype(int)
b2 = np.load('2.bias__11.npy').astype(int)
w2 = np.load('2.weight__506.npy').astype(int)
b4 = np.load('4.bias__98.npy').astype(int)
w4 = np.load('4.weight__970.npy').astype(int)

# Decode the stored bits back into signs: 1 stays +1, 0 becomes -1.
# Biases are decoded too; otherwise every negative bias would be applied as 0.
for bits in (w0, b0, w2, b2, w4, b4):
    bits[bits == 0] = -1

In [73]:
# Integer-only re-computation of the network output from the exported signs.
# Each entry is (sign matrix, weight magnitude, bias signs, bias magnitude);
# the magnitudes are the ones embedded in the exported file names.
layers = [
    (w0, 777, b0, 188),
    (w2, 506, b2, 11),
    (w4, 970, b4, 98),
]

x = X.astype(int)
for idx, (weight, w_mag, bias, b_mag) in enumerate(layers):
    # Affine step with the scaled sign matrices.
    x = (x @ weight.T) * w_mag + bias * b_mag
    if idx < len(layers) - 1:
        # Hidden layers only: clamp negatives to zero, then replace every
        # activation by its sign times the truncated mean magnitude.
        # NOTE(review): presumably mirrors what BinaryLinear does to its
        # input -- confirm against model.py.
        x = np.maximum(x, 0)
        x = np.sign(x) * int(np.mean(np.abs(x)))

In [74]:
# Check the dtype of the integer forward-pass result.
x.dtype

dtype('int64')

In [75]:
# Display the label array that the predictions are compared against.
y

array([5, 0, 4, ..., 5, 6, 8])

In [None]:
# Fraction of samples whose highest-scoring output column equals the label.
(x.argmax(axis=1) == y).mean()