In [1]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import numpy as np

import matplotlib.pyplot as plt

from model import BinaryLinear

# Taking binarized input data for MNIST

In [2]:
# NOTE: despite the variable name, this loads the MNIST *training* split
# (train=True). The name is kept because later cells refer to it.
test_dataset = torchvision.datasets.MNIST(
    root='torch_dataset',
    train=True,
    # Only transform applied: convert each image to a PyTorch tensor.
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
)

In [3]:
# Materialize every (image, label) sample in memory so the dataset can be
# traversed more than once without re-applying the transform.
dataset = [sample for sample in test_dataset]

In [4]:
# Split the samples into one array of images (element 0 of each sample)
# and one array of labels (element 1 of each sample).
X = np.stack([sample[0].numpy() for sample in dataset])
y = np.array([sample[1] for sample in dataset])

In [5]:
# Binarize: values strictly above 0.5 become 1, everything else becomes 0,
# stored as unsigned 8-bit integers.
X = (X > 0.5).astype(np.uint8)

In [6]:
# Persist the binarized images before they are flattened.
# NOTE(review): np.save writes the .npy format and appends ".npy" to names
# that lack it, so this likely lands on disk as "bin_mnist_3d_tensor.npz.npy"
# (not a real .npz archive). Name kept unchanged for downstream consumers.
np.save(file='bin_mnist_3d_tensor.npz', arr=X)

In [7]:
# Flatten every sample into a single row. The sample count is taken from
# the array itself rather than hard-coded, so the cell also works on a
# subset or a different split.
X = X.reshape(X.shape[0], -1)

In [8]:
# Sanity check: show the shape of X after flattening (samples, features).
print(X.shape)

(60000, 784)


In [9]:
# Persist the flattened binarized images.
# NOTE(review): np.save writes the .npy format and appends ".npy" to names
# that lack it, so this likely lands on disk as "bin_mnist_flat.npz.npy"
# (not a real .npz archive). Name kept unchanged for downstream consumers.
np.save(file='bin_mnist_flat.npz', arr=X)

In [10]:
# Also export the flattened data as comma-separated integers, one sample
# per row.
np.savetxt(
    fname='bin_mnist_flat.csv',
    X=X,
    fmt='%i',
    delimiter=',',
)

In [11]:
# Wrap the NumPy arrays as tensors and serve them in shuffled mini-batches.
train_data = torch.utils.data.TensorDataset(
    torch.from_numpy(X.astype(np.float32)),  # inputs cast to float32
    torch.from_numpy(y),                     # labels, dtype unchanged
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
)

# Try to train a torch model on it

In [12]:
# Device configuration: use the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# 2 layer neural network
class Net(nn.Module):
    """Two-layer classifier built from ``BinaryLinear`` layers with a ReLU between.

    ``BinaryLinear`` is imported from the project's ``model`` module; its
    internals are not visible here. It is constructed as
    ``BinaryLinear(in_features, out_features)``.

    Args:
        num_classes: Size of the output layer (number of class scores).
        in_features: Length of each flattened input vector. Defaults to
            28*28, matching the flattened images used in this notebook.
        hidden_features: Width of the hidden layer. Defaults to 32.
    """

    def __init__(self, num_classes=10, in_features=28*28, hidden_features=32):
        super().__init__()
        # The attribute name `fc` is kept so that state_dict keys stay stable.
        self.fc = nn.Sequential(
            BinaryLinear(in_features, hidden_features),
            nn.ReLU(),
            BinaryLinear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return the raw class scores produced by the layer stack for ``x``."""
        return self.fc(x)

In [14]:
model = Net().to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train the model
num_epochs = 1   # single pass over the training data
log_every = 50   # report (and record) the loss every 50 mini-batches
total_step = len(train_loader)
losses = []      # loss values sampled at each logging step only
for epoch in range(num_epochs):
    for step, (images, labels) in enumerate(train_loader, start=1):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % log_every != 0:
            continue

        # Read the scalar loss once and reuse it for both print and history.
        loss_value = loss.item()
        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
              .format(epoch + 1, num_epochs, step, total_step, loss_value))
        losses.append(loss_value)

Epoch [1/1], Step [50/1875], Loss: 2.2968
Epoch [1/1], Step [100/1875], Loss: 2.2528
Epoch [1/1], Step [150/1875], Loss: 2.2204
Epoch [1/1], Step [200/1875], Loss: 2.0545
Epoch [1/1], Step [250/1875], Loss: 1.8130
Epoch [1/1], Step [300/1875], Loss: 1.7147
Epoch [1/1], Step [350/1875], Loss: 1.5275
Epoch [1/1], Step [400/1875], Loss: 1.4128
Epoch [1/1], Step [450/1875], Loss: 1.2401
Epoch [1/1], Step [500/1875], Loss: 1.2449
Epoch [1/1], Step [550/1875], Loss: 1.3538
Epoch [1/1], Step [600/1875], Loss: 0.9067
Epoch [1/1], Step [650/1875], Loss: 1.0802
Epoch [1/1], Step [700/1875], Loss: 0.7355
Epoch [1/1], Step [750/1875], Loss: 0.8332
Epoch [1/1], Step [800/1875], Loss: 0.6719
Epoch [1/1], Step [850/1875], Loss: 0.7919
Epoch [1/1], Step [900/1875], Loss: 0.8578
Epoch [1/1], Step [950/1875], Loss: 0.5641
Epoch [1/1], Step [1000/1875], Loss: 0.8055
Epoch [1/1], Step [1050/1875], Loss: 0.6227
Epoch [1/1], Step [1100/1875], Loss: 0.5661
Epoch [1/1], Step [1150/1875], Loss: 0.5361
Epoch [1

In [17]:
# Measure classification accuracy over everything served by train_loader.
# NOTE: train_loader holds the same data the model was just trained on, so
# this is *training* accuracy, not held-out test accuracy.
model.eval()  # switch the model's modules to evaluation mode
with torch.no_grad():  # gradients are not needed for evaluation
    correct = 0
    total = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the highest score for each sample.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    # Report the actual number of evaluated samples instead of a hard-coded count.
    print('Train accuracy of the model on the {} training images: {} %'.format(total, acc))

Test Accuracy of the model on the 10000 test images: 84.28666666666666 %
