# Train a 1-bit model

This notebook is an example of a building a 1-bit model which can be used to classify images.

One of the reasons why this model is intersting is all of the information is stored in the network/graph structure and there are only weights of +/-1 in the model.

This means the model can be more easily adapted to situations which have low computing power but high throughput.

In [1]:
# Importing the classes we need

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
# Device configuration
device = torch.device('cpu')

# Hyperparameters for our training
batch_size = 64
learning_rate = 0.01
epochs = 10

In [3]:
# We're looking to build an integer based model after training so load our dataset into ints vs floats

# Ensure the input is a tensor before applying the integer conversion
def preprocess_mnist_to_uint8(images):
    images = transforms.ToTensor()(images)  # Convert PIL to Tensor
    images = (images * 255).clamp(0, 255).to(torch.int32)  # Convert to uint8
    return images

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.Lambda(preprocess_mnist_to_uint8)  # Apply uint8 conversion
])

In [4]:
# As in previous weeks we need to define a data-loader for our datasets

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Binary Quantization Function with STE (+1/-1)
class BinaryQuantizeSTE(torch.autograd.Function):

    # Moving forwards in training we want to pass only the sign of the weight
    # This is something which can be quantized into 1-bit of information
    @staticmethod
    def forward(ctx, input):
        return input.sign()  # Binary quantization to +1/-1

    # Moving back through the model we want to pass the full gradient
    # This makes it easier when training to get a stable result
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # Straight-Through Estimator (STE)

# Binary Neural Network with Sign Activation
class BinaryNet(nn.Module):
    def __init__(self):
        super(BinaryNet, self).__init__()

        # We want to construct 3 FUlly-Interconnected layers without biases
        # These should go from the mnist input (28x28) to a classifier decision (10)
        # I would recommend input -> 1024 -> 512 -> 10
        self.fc1 = nn.Linear(28*28, ## FINISH_ME ##
        self.fc2 = nn.Linear( ## FINISH_ME ##
        self.fc3 = nn.Linear(512, 10, bias=False)

    def forward(self, x):
        # First we 'flatten' our input into 1D so DNN can use this
        x = x.view(-1, 28*28).to(torch.float)

        # Now we need to use our custom layer to quantize the outputs as we train
        x = BinaryQuantizeSTE.apply(self.fc1(x))
        x = BinaryQuantizeSTE.apply( ## FINISH_ME ## )

        # Careful not to binary-quantize the output layer for stability
        x = self.fc3(x)
        return x

In [6]:
# Model, Loss, Optimizer
model = BinaryNet().to(device)
# We want binary cross entropy for our loss function
criterion = nn.CrossEntropyLoss()
# We will use SGD with the learning rate given above
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [7]:
# Training and Evaluation
for epoch in range(epochs):

    # First put the model into training mode and set some observables
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Iterate through all of our training data
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Perform the training
        # Forward evaluate
        # Calculate loss
        # Back-Propagate
        # Apply Optimization
        optimizer.zero_grad()
        outputs = ## FINISH_ME ##
        loss = ## FINISH_ME ##
        loss.backward()
        optimizer.step()

        # Take our losses and outputs and work out
        # How many times did we evaluate correctly? 
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # Print training stats per epoch
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')

Epoch [1/10], Loss: 0.3724, Accuracy: 89.29%
Epoch [2/10], Loss: 0.2789, Accuracy: 91.88%
Epoch [3/10], Loss: 0.2644, Accuracy: 92.30%
Epoch [4/10], Loss: 0.2564, Accuracy: 92.59%
Epoch [5/10], Loss: 0.2505, Accuracy: 92.65%
Epoch [6/10], Loss: 0.2502, Accuracy: 92.69%
Epoch [7/10], Loss: 0.2467, Accuracy: 92.81%
Epoch [8/10], Loss: 0.2454, Accuracy: 92.82%
Epoch [9/10], Loss: 0.2449, Accuracy: 92.83%
Epoch [10/10], Loss: 0.2458, Accuracy: 92.73%


In [8]:
# This method will capture the quantized weights from our trained model as ints

# Quantize weights to ±1
def quantize_weights(model):
    quantized_state_dict = {}
    for name, param in model.named_parameters():
        quantized_state_dict[name] = param.detach().sign().to(torch.int32)
    return quantized_state_dict

# Lets store the quantized weight values
quantized_weights = quantize_weights(model)

In [9]:
# Now we want to build a model which only uses integer weights

# Optimized 1-bit Integer Model
class OneBitIntegerNet(nn.Module):
    def __init__(self, quantized_weights):
        super(OneBitIntegerNet, self).__init__()
        self.fc1_weight = quantized_weights['fc1.weight'].to(device).to(torch.int32)
        self.fc2_weight = quantized_weights['fc2.weight'].to(device).to(torch.int32)
        self.fc3_weight = quantized_weights['fc3.weight'].to(device).to(torch.int32)

    def forward(self, x):
        # Again we need to 'flatten' our input
        x = x.view(-1, 28*28)  # Scale to 0-1

        # Now we want to explicity evaluate forward through our 1-bit model
        x = torch.matmul(x, self.fc1_weight.t())
        x = torch.matmul(x, self.fc2_weight.t())
        x = torch.matmul(x, self.fc3_weight.t())
        return x

In [10]:
# Full Precision Model Inference Accuracy
model.eval()
correct = 0
total = 0

with torch.no_grad():
    # Iterate through out test data
    for images, labels in test_loader:

        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        # Track how many times we evaluated and how correct we are
        total += labels.size(0)
        correct += ## FINISH_ME ##

print(f'Test Accuracy with Full Precision Model: {100 * correct / total:.2f}%')

Test Accuracy with Full Precision Model: 92.36%


In [11]:
# Testing with 1-bit Integer Model Accuracy
one_bit_model = OneBitIntegerNet(quantized_weights)
one_bit_model = one_bit_model.to(device)

one_bit_model.eval()
correct = 0
total = 0

with torch.no_grad():
    # Iterate through out test data
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = one_bit_model(images)
        _, predicted = outputs.max(1)
        # Track how many times we evaluated and how correct we are
        total += labels.size(0)
        correct += ## FINISH_ME ##

print(f'Test Accuracy with 1-bit Integer Model: {100 * correct / total:.2f}%')

Test Accuracy with 1-bit Integer Model: 78.80%
