In [93]:
# libraries
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [94]:
# Device configuration: prefer a CUDA GPU, then Apple's MPS (Metal) backend, else fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [95]:
# Hyper-parameters
# Seed the global PyTorch RNG first so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All (GPU/MPS kernels may still add nondeterminism).
seed = 0
torch.manual_seed(seed)

input_size = 28 * 28  # MNIST images are 28x28, flattened to 784 features
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

In [96]:
# MNIST dataset (fetched into data_root if not already present)
data_root = './data/'
to_tensor = transforms.ToTensor()  # PIL image -> float tensor scaled to [0, 1]

train_dataset = torchvision.datasets.MNIST(root=data_root,
                                           train=True,
                                           transform=to_tensor,
                                           download=True)

# download=True here too, so the test split does not depend on the train call having fetched its files.
test_dataset = torchvision.datasets.MNIST(root=data_root,
                                          train=False,
                                          transform=to_tensor,
                                          download=True)


In [97]:
# Data loaders: the training split is reshuffled every epoch, the test split keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


In [98]:
def binarize(tensor):
    """Deterministically binarize ``tensor`` element-wise to {-1, +1}.

    Entries >= 0 (including 0 and -0.0) map to +1 and negative entries map to -1
    (NaN compares False against 0 and therefore maps to -1). Unlike
    ``tensor.sign()``, which returns 0 for 0, the result is always strictly binary.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.
    """
    ones = torch.ones_like(tensor)
    return torch.where(tensor >= 0, ones, -ones)

class BinarizeLinear(nn.Linear):
    """Linear layer that binarizes both its inputs and its weights in the forward pass.

    On the first forward call the full-precision ("latent") weights are stashed in
    ``self.weight.org``; every forward then overwrites ``self.weight.data`` with
    ``binarize(self.weight.org)``. In this notebook the training loop copies
    ``org`` back into ``weight.data`` before ``optimizer.step()`` and writes the
    clamped result back into ``org`` afterwards, so the optimizer updates the
    latent weights while forward/backward see the binarized ones.

    The bias (if any) is added in full precision and is not binarized.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether to create and apply an additive bias (default True, as in
            ``nn.Linear``).
    """

    def __init__(self, in_features, out_features, bias=True):
        super(BinarizeLinear, self).__init__(in_features, out_features, bias=bias)

    def forward(self, input):
        # Straight-through estimator for the activations: the forward value is exactly
        # binarize(input) (``input - input.detach()`` is exactly zero), while the gradient
        # w.r.t. ``input`` is the identity. Unlike assigning to ``input.data``, this leaves
        # the caller's tensor untouched.
        detached = input.detach()
        input_b = binarize(detached) + (input - detached)

        # Keep the latent full-precision weights the first time through, then run the
        # layer with their binarized copy (see class docstring for the training contract).
        if not hasattr(self.weight, 'org'):
            self.weight.org = self.weight.data.clone()
        self.weight.data = binarize(self.weight.org)

        return nn.functional.linear(input_b, self.weight, self.bias)

In [99]:
# Neural Network Model
class NeuralNet(nn.Module):
    """MLP with full-precision first and last layers around two binarized hidden layers.

    Each hidden stage is Linear -> BatchNorm1d -> Hardtanh. ``l1`` and ``l4`` are
    ordinary ``nn.Linear`` layers; ``l2`` and ``l3`` are ``BinarizeLinear``.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the first (full-precision) hidden layer.
        num_classes: number of output logits.
        hidden2_size: width of the first binarized hidden layer (default 400).
        hidden3_size: width of the second binarized hidden layer (default 300).
    """

    def __init__(self, input_size, hidden_size, num_classes, hidden2_size=400, hidden3_size=300):
        super(NeuralNet, self).__init__()
        # Stage 1: full-precision projection of the raw input.
        self.l1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.htanh1 = nn.Hardtanh()

        # Stages 2 and 3: binarized layers.
        self.l2 = BinarizeLinear(hidden_size, hidden2_size)
        self.bn2 = nn.BatchNorm1d(hidden2_size)
        self.htanh2 = nn.Hardtanh()

        self.l3 = BinarizeLinear(hidden2_size, hidden3_size)
        self.bn3 = nn.BatchNorm1d(hidden3_size)
        self.htanh3 = nn.Hardtanh()

        # Full-precision classifier head; outputs raw logits (no softmax).
        self.l4 = nn.Linear(hidden3_size, num_classes)

    def forward(self, x):
        """Map a batch of flattened inputs to class logits.

        Args:
            x: float tensor of shape (batch, input_size); the training and test
                loops in this notebook flatten images to this shape before calling.

        Returns:
            Logits of shape (batch, num_classes).
        """
        out = self.htanh1(self.bn1(self.l1(x)))
        out = self.htanh2(self.bn2(self.l2(out)))
        out = self.htanh3(self.bn3(self.l3(out)))
        return self.l4(out)
    
    

In [100]:
# Model, loss function and optimizer
model = NeuralNet(input_size, hidden_size, num_classes)
model = model.to(device)  # nn.Module.to moves parameters in place and returns the same module

criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [101]:
# Train the model
model.train()

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each image into a vector of input_size features.
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)

        # Forward pass: BinarizeLinear layers run with their binarized weights.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass: gradients are computed with respect to the binarized weights.
        optimizer.zero_grad()
        loss.backward()

        # Straight-through estimator: put the latent full-precision weights (saved in
        # `p.org`) back into the parameters, so the optimizer applies the gradients
        # obtained with the binarized weights to the full-precision copy instead.
        for p in model.parameters():
            if hasattr(p, 'org'):
                p.data.copy_(p.org)

        optimizer.step()

        # Clip the updated full-precision weights to [-1, 1] and store them in `p.org`;
        # the next forward pass binarizes from there.
        for p in model.parameters():
            if hasattr(p, 'org'):
                p.org.copy_(p.data.clamp_(-1, 1))

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))



Epoch [1/5], Step [100/600], Loss: 0.1832
Epoch [1/5], Step [200/600], Loss: 0.3248
Epoch [1/5], Step [300/600], Loss: 0.1494
Epoch [1/5], Step [400/600], Loss: 0.1810
Epoch [1/5], Step [500/600], Loss: 0.0477
Epoch [1/5], Step [600/600], Loss: 0.2220
Epoch [2/5], Step [100/600], Loss: 0.1475
Epoch [2/5], Step [200/600], Loss: 0.0303
Epoch [2/5], Step [300/600], Loss: 0.2174
Epoch [2/5], Step [400/600], Loss: 0.0974
Epoch [2/5], Step [500/600], Loss: 0.2284
Epoch [2/5], Step [600/600], Loss: 0.1868
Epoch [3/5], Step [100/600], Loss: 0.1199
Epoch [3/5], Step [200/600], Loss: 0.1014
Epoch [3/5], Step [300/600], Loss: 0.1771
Epoch [3/5], Step [400/600], Loss: 0.1383
Epoch [3/5], Step [500/600], Loss: 0.0918
Epoch [3/5], Step [600/600], Loss: 0.0848
Epoch [4/5], Step [100/600], Loss: 0.0864
Epoch [4/5], Step [200/600], Loss: 0.0569
Epoch [4/5], Step [300/600], Loss: 0.0880
Epoch [4/5], Step [400/600], Loss: 0.1625
Epoch [4/5], Step [500/600], Loss: 0.0347
Epoch [4/5], Step [600/600], Loss:

In [102]:
# Test the model
model.eval()  # BatchNorm layers switch to their running statistics

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)  # index of the largest logit per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Report the number of images actually evaluated rather than a hard-coded count.
    print('Accuracy of the network on the {} test images: {} %'.format(total, 100 * correct / total))

Accuracy of the network on the 10000 test images: 97.14 %
