In [27]:
import torch
from torchvision import datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

In [29]:
# This function handles quantizing a tensor to bfp
def fp32_to_bfp(tensor, group_size, mantissa_bits):
    flat_tensor = tensor.flatten()
    original_size = flat_tensor.size(0)
    padded_size = (original_size + group_size - 1) // group_size * group_size
    padded_tensor = torch.nn.functional.pad(
        flat_tensor, (0, padded_size - original_size)
    )
    padded_tensor = padded_tensor.view(-1, group_size)
    max_exponents = torch.max(padded_tensor.abs().log2().ceil(), dim=1, keepdim=True)[0]
    aligned_mantissas = padded_tensor * 2 ** (
        max_exponents - padded_tensor.abs().log2().ceil()
    )
    scale = 2**mantissa_bits
    truncated_mantissas = torch.floor(aligned_mantissas * scale) / scale
    bfp_values = truncated_mantissas * 2 ** (
        -max_exponents + padded_tensor.abs().log2().ceil()
    )
    bfp_values = bfp_values.view(-1)[:original_size]
    return bfp_values.view(tensor.shape)

In [30]:
# Quant the whole model
def quantize_model_bfp(model, group_size, mantissa_bits):
    for param in model.parameters():
        with torch.no_grad():
            param.copy_(fp32_to_bfp(param, group_size, mantissa_bits))

In [31]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [32]:
# # We have our own Linear layer
# class BFPLinear(nn.Linear):
#     def __init__(
#         self, in_features, out_features, bias=True, group_size=16, mantissa_bits=4
#     ):
#         super(BFPLinear, self).__init__(in_features, out_features, bias)
#         self.group_size = group_size
#         self.mantissa_bits = mantissa_bits

#     def forward(self, input):
#         bfp_weight = fp32_to_bfp(self.weight, self.group_size, self.mantissa_bits)
#         if self.bias is not None:
#             bfp_bias = fp32_to_bfp(self.bias, self.group_size, self.mantissa_bits)
#         else:
#             bfp_bias = None
#         return nn.functional.linear(input, bfp_weight, bfp_bias)

In [33]:
# This is for modyfing models we create (base, quant, and restored)
class BFPModelWrapper(nn.Module):
    def __init__(self, model, group_size, mantissa_bits):
        super(BFPModelWrapper, self).__init__()
        self.model = model
        self.group_size = group_size
        self.mantissa_bits = mantissa_bits
        self.full_precision_params = {
            name: param.clone().detach() for name, param in model.named_parameters()
        }

    def quantize(self):
        quantize_model_bfp(self.model, self.group_size, self.mantissa_bits)

    def restore_full_precision(self):
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.copy_(self.full_precision_params[name].detach().clone())

    def forward(self, x):
        return self.model(x)

In [42]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [43]:
def train_model(model, train_loader, criterion, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 100 == 99:  # print every 100 mini-batches
                print(
                    f"[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}"
                )
                running_loss = 0.0

In [44]:
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [45]:
model = SimpleNN()
model.eval()

SimpleNN(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
)

In [46]:
group_size = 16
mantissa_bits = 4

# Wrap the model with BFPModelWrapper
bfp_model = BFPModelWrapper(model, group_size=group_size, mantissa_bits=mantissa_bits)

In [47]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(bfp_model, train_loader, criterion, optimizer, epochs=5)

[Epoch 1, Batch 100] loss: 0.663
[Epoch 1, Batch 200] loss: 0.288
[Epoch 1, Batch 300] loss: 0.240
[Epoch 1, Batch 400] loss: 0.220
[Epoch 1, Batch 500] loss: 0.198
[Epoch 1, Batch 600] loss: 0.167
[Epoch 1, Batch 700] loss: 0.153
[Epoch 1, Batch 800] loss: 0.147
[Epoch 1, Batch 900] loss: 0.137
[Epoch 2, Batch 100] loss: 0.096
[Epoch 2, Batch 200] loss: 0.098
[Epoch 2, Batch 300] loss: 0.098
[Epoch 2, Batch 400] loss: 0.102
[Epoch 2, Batch 500] loss: 0.091
[Epoch 2, Batch 600] loss: 0.102
[Epoch 2, Batch 700] loss: 0.094
[Epoch 2, Batch 800] loss: 0.100
[Epoch 2, Batch 900] loss: 0.090
[Epoch 3, Batch 100] loss: 0.068
[Epoch 3, Batch 200] loss: 0.065
[Epoch 3, Batch 300] loss: 0.058
[Epoch 3, Batch 400] loss: 0.063
[Epoch 3, Batch 500] loss: 0.052
[Epoch 3, Batch 600] loss: 0.068
[Epoch 3, Batch 700] loss: 0.061
[Epoch 3, Batch 800] loss: 0.072
[Epoch 3, Batch 900] loss: 0.081
[Epoch 4, Batch 100] loss: 0.043
[Epoch 4, Batch 200] loss: 0.046
[Epoch 4, Batch 300] loss: 0.053
[Epoch 4, 

In [48]:
# Evaluate the full precision model
full_precision_accuracy = evaluate_model(bfp_model, test_loader)
print(f"Accuracy of the full precision model: {full_precision_accuracy}%")

Accuracy of the full precision model: 97.54%


In [49]:
# Quantize the model and evaluate
bfp_model.quantize()
bfp_accuracy = evaluate_model(bfp_model, test_loader)
print(f"Accuracy of the BFP quantized model: {bfp_accuracy}%")

Accuracy of the BFP quantized model: 96.87%


In [52]:
# Restore the full precision model and evaluate
bfp_model.restore_full_precision()
restored_full_precision_accuracy = evaluate_model(bfp_model, test_loader)
print(
    f"Accuracy of the restored full precision model: {restored_full_precision_accuracy}%"
)

Accuracy of the restored full precision model: 11.74%
