<a href="https://colab.research.google.com/github/aquibjaved/Bits_and_Pieces_DL/blob/main/Simple_BitLinear_layer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!python -m pip install lightning

Collecting lightning
  Downloading lightning-2.2.0.post0-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m15.9 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.2.0.post0-py3-none-any.whl (800 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.9/800.9 kB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lightning-utilities, torchmetrics, pytorch-lightning, lightning
Successfully installed lightning-2.2.0.post0 lightning-utilities-0.10.1 pytorch-lightning-2.2.0.post0 t

In [9]:
import lightning as L
import torch.nn.functional as F
from torch.optim import Adam

import torch
import torch.nn as nn
from torch import Tensor, nn

In [36]:
class BitLinear(nn.Linear):
    """
    BitLinear is a custom linear layer that binarizes its weights and rescales
    its outputs in a group-wise manner.

    Forward pass: ``LayerNorm(x)`` -> linear transform with sign-binarized
    weights (straight-through estimator for the gradient) -> group-wise
    abs-max scaling of the linear output (see ``quantize_activations_groupwise``).

    Groups are taken along dimension 0:
    - for binarization, the rows of the weight matrix (output units);
    - for activation scaling, the leading dimension of the layer output
      (the batch dimension for 2-D inputs).
    When that dimension is not divisible by ``num_groups``, ``torch.tensor_split``
    gives the first groups one extra row each, so every row is covered.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If set to False, the layer will not learn an additive bias. Default is True.
        num_groups (int, optional): Number of groups to divide the weights and activations into. Default is 1.

    Raises:
        ValueError: If ``num_groups`` is smaller than 1.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        num_groups: int = 1,
    ):
        # Fail fast: a non-positive group count cannot partition anything
        # (otherwise it only surfaces later, inside forward()).
        if num_groups < 1:
            raise ValueError(f"num_groups must be >= 1, got {num_groups}")
        super().__init__(in_features, out_features, bias)
        self.num_groups = num_groups
        # Small constant guarding the division and the clamp bounds during scaling.
        self.eps = 1e-5
        # Normalizes the input over its last (in_features) dimension before the linear transform.
        self.norm = nn.LayerNorm(in_features)

    def ste(self, x):
        """
        Binarize ``x`` with ``torch.sign`` using a Straight-Through Estimator (STE).

        Forward: the value is exactly ``torch.sign(x)``, so zeros stay 0.
        Backward: the binarization is treated as the identity, so the gradient
        with respect to ``x`` passes through unchanged.

        Args:
            x (Tensor): Input tensor.

        Returns:
            Tensor: Binarized tensor with the same shape as ``x``.
        """
        binarized_x = torch.sign(x)
        # The detached term carries the value offset but no gradient, so the
        # result equals sign(x) while d(result)/dx == 1.
        return (binarized_x - x).detach() + x

    def binarize_weights_groupwise(self):
        """
        Binarize the weight matrix group-wise (groups of output rows) using the STE.

        Each group ``W_g`` is centred on its own mean before taking the sign,
        i.e. the result is ``sign(W_g - mean(W_g))`` stacked back in row order.

        Returns:
            Tensor: Binarized weights with the same shape and dtype as ``self.weight``.
        """

        def binarize_group(group):
            # An empty group (num_groups > out_features) has no mean; keep it empty.
            if group.numel() == 0:
                return group
            return self.ste(group - group.mean())

        # tensor_split covers every row even when out_features % num_groups != 0,
        # unlike fixed-size slicing, which would leave the trailing rows at zero.
        weight_groups = torch.tensor_split(self.weight, self.num_groups, dim=0)
        return torch.cat([binarize_group(g) for g in weight_groups], dim=0)

    def quantize_activations_groupwise(self, x, b=8):
        """
        Scale ``x`` group-wise (along dim 0) into the range ``[-Q_b, Q_b]``, with ``Q_b = 2**(b-1)``.

        Each group ``x_g`` is multiplied by ``Q_b / (max|x_g| + eps)`` and
        clamped to ``[-Q_b + eps, Q_b - eps]``.

        NOTE(review): no rounding is applied, so the result is a rescaled,
        continuous tensor rather than integer quantization levels.

        Args:
            x (Tensor): Input tensor.
            b (int, optional): Number of bits defining the target range. Default is 8.

        Returns:
            Tensor: Scaled tensor with the same shape as ``x``.
        """
        Q_b = 2 ** (b - 1)

        def scale_group(group):
            # abs().max() is undefined on an empty tensor; nothing to scale.
            if group.numel() == 0:
                return group
            gamma_g = group.abs().max()
            return torch.clamp(
                group * Q_b / (gamma_g + self.eps),
                -Q_b + self.eps,
                Q_b - self.eps,
            )

        # tensor_split keeps remainder rows (and inputs with fewer rows than
        # num_groups) instead of leaving them as zeros.
        activation_groups = torch.tensor_split(x, self.num_groups, dim=0)
        return torch.cat([scale_group(g) for g in activation_groups], dim=0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the BitLinear layer.

        Steps:
        1. Apply LayerNorm to ``x``.
        2. Apply a linear transform with the group-wise binarized weights
           (plus ``self.bias``).
        3. Apply group-wise abs-max scaling to the result.

        NOTE(review): the scaling is applied to the layer *output*, not to its
        input, and no dequantization factor is applied afterwards. As a result,
        the largest-magnitude value in each group comes out close to
        ``2**(b-1)`` (128 for the default 8 bits), regardless of the input scale.

        Args:
            x (Tensor): Input tensor whose last dimension is ``in_features``.

        Returns:
            Tensor: Output tensor whose last dimension is ``out_features``.
        """
        x = self.norm(x)
        binarized_weights = self.binarize_weights_groupwise()
        output = torch.nn.functional.linear(x, binarized_weights, self.bias)
        # Group-wise abs-max scaling of the linear output.
        return self.quantize_activations_groupwise(output)

In [42]:
class SimpleNN(nn.Module):
    """
    Full-precision two-layer MLP baseline for 28x28 MNIST images.

    Each image is flattened to 784 features and passed through a 512-unit ReLU
    hidden layer. A final linear layer produces 10 class log-probabilities.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 10)`` for the images in ``x``."""
        flat_pixels = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat_pixels))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [43]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# Seed torch's global RNG so the split, DataLoader shuffling and the model
# initialisations in later cells are reproducible under "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# Load the MNIST training images. MNIST(...) defaults to train=True, so the
# held-out "test" split below is carved out of the training set; the official
# MNIST test set (train=False) is not used in this notebook.
dataset = datasets.MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())

# Split into 80% training / 20% held-out testing with a dedicated, seeded generator
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Create DataLoader for training data (shuffled every epoch)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# DataLoader for the held-out split (fixed order for evaluation)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [44]:
import torch

# Baseline: train the full-precision SimpleNN on the training split.
model = SimpleNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.NLLLoss()
num_epochs = 1

model.train()  # put the model into training mode before the update loop
for epoch in range(num_epochs):
    for data, target in train_loader:
        # One optimisation step per mini-batch: clear grads, forward, backprop, update.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


In [45]:
def predict(model, data_loader):
    """
    Run ``model`` over every batch of ``data_loader`` and collect class predictions.

    The model is switched to evaluation mode (and left there), and gradients
    are disabled. Each input batch is flattened to ``(batch, -1)`` before the
    forward pass. The predicted class is the arg-max over dimension 1 of the
    model output.

    Args:
        model: Classifier expected to return per-class scores of shape ``(batch, num_classes)``.
        data_loader: Iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        tuple[list, list]: ``(true_labels, predictions)``, in the order the loader yields them.
    """
    model.eval()
    true_labels = []
    predictions = []

    with torch.no_grad():
        # Iterate the loader supplied by the caller, not a notebook global.
        for data, labels in data_loader:
            output = model(data.view(data.size(0), -1))
            predictions.extend(output.argmax(dim=1).tolist())
            true_labels.extend(labels.tolist())

    return true_labels, predictions

# Evaluate the baseline SimpleNN on the held-out test split built earlier (test_loader)
true_labels, predictions  = predict(model, test_loader)


In [46]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def calculate_performance_metrics(true_labels, predictions, average='macro'):
    """
    Print accuracy plus precision, recall and F1 score (averaged with ``average``).

    Args:
        true_labels: Ground-truth class labels.
        predictions: Predicted class labels, aligned with ``true_labels``.
        average (str, optional): Averaging strategy forwarded to the sklearn
            precision/recall/F1 scorers. Default is 'macro'.
    """
    scores = {
        'Accuracy': accuracy_score(true_labels, predictions),
        'Precision': precision_score(true_labels, predictions, average=average),
        'Recall': recall_score(true_labels, predictions, average=average),
        'F1 Score': f1_score(true_labels, predictions, average=average),
    }
    for metric_name, value in scores.items():
        print(f'{metric_name}: {value:.4f}')


calculate_performance_metrics(true_labels, predictions)

Accuracy: 0.9553
Precision: 0.9554
Recall: 0.9548
F1 Score: 0.9549


In [51]:
## With BitLinear Layer

class SimpleBitNN(nn.Module):
    """
    Variant of the MNIST MLP whose output layer is a ``BitLinear`` layer.

    Each image is flattened to 784 features and passed through a full-precision
    512-unit ReLU hidden layer. A ``BitLinear(512, 10)`` layer then produces
    scores, which are turned into log-probabilities.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = BitLinear(512, 10)

    def forward(self, x):
        """Return log-probabilities of shape ``(batch, 10)`` for the images in ``x``."""
        flat_pixels = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat_pixels))
        scores = self.fc2(hidden)
        return F.log_softmax(scores, dim=1)

In [52]:
# Train the BitLinear variant with the same optimiser, loss and epoch budget as the baseline.
model = SimpleBitNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.NLLLoss()
num_epochs = 1

model.train()  # put the model into training mode before the update loop
for epoch in range(num_epochs):
    for data, target in train_loader:
        # One optimisation step per mini-batch: clear grads, forward, backprop, update.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


In [53]:
# Evaluate the SimpleBitNN model on the same held-out split (overwrites the baseline's true_labels/predictions)
true_labels, predictions  = predict(model, test_loader)

In [54]:
# Metrics for the SimpleBitNN model; compare with the baseline SimpleNN metrics printed earlier
calculate_performance_metrics(true_labels, predictions )

Accuracy: 0.9148
Precision: 0.9272
Recall: 0.9148
F1 Score: 0.9147
