## 🧮 BitOps: The Concept

> **BitOps = Number of operations × Bit-width per operation**

In quantized networks, instead of counting full-precision operations (32-bit floating-point FLOPs), we count operations at the lower precision the network actually uses (e.g., 2-bit or 4-bit). Each multiply-accumulate (MAC) between a quantized weight and an input is therefore weighted by the weight bit-width $B$: a 4-bit layer costs one eighth of the BitOps of the same layer computed at 32 bits.
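As a rough sketch of this definition, assuming (like the formulas below) that only the weight bit-width $B$ scales the count; `bitops` is a hypothetical helper, not a library function:

```python
def bitops(num_macs: int, weight_bits: int) -> int:
    """BitOps = number of multiply-accumulates (MACs) x weight bit-width."""
    return num_macs * weight_bits


# The same layer (one million MACs) counted at decreasing weight precision
NUM_MACS = 1_000_000
for bits in (32, 8, 4, 2):
    reduction = bitops(NUM_MACS, 32) // bitops(NUM_MACS, bits)
    print(f"{bits:>2}-bit: {bitops(NUM_MACS, bits):>10,} BitOps "
          f"(reduction vs 32-bit: {reduction}x)")
```

The number of MACs never changes; only the precision factor does, so BitOps can separate two models that have identical FLOPs but different bit-widths.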


## 🔍 1. **Depthwise Convolution BitOps**

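In a **depthwise convolution**, each input channel is convolved separately with its own $K \times K$ filter, so there is no mixing across channels and the layer outputs $C_{\text{in}}$ channels: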

### **Formula:**

$$
\text{BitOps}_{\text{depthwise}} = H_{\text{out}} \times W_{\text{out}} \times C_{\text{in}} \times K \times K \times B
$$

Where:  
- $H_{\text{out}}, W_{\text{out}}$: height and width of the output feature map  
- $C_{\text{in}}$: number of input channels  
- $K \times K$: kernel size (usually $3 \times 3$)  
- $B$: bit-width used for quantized weights
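The formula translates directly into a plain-Python helper (`depthwise_bitops` is a hypothetical name; the arguments mirror the symbols above):

```python
def depthwise_bitops(h_out: int, w_out: int, c_in: int, k: int, bits: int) -> int:
    """BitOps of a K x K depthwise conv: one K x K filter per input channel."""
    macs = h_out * w_out * c_in * k * k  # K*K MACs per channel per output pixel
    return macs * bits


# Example: 16x16 output, 32 channels, 3x3 kernel, 4-bit weights
print(depthwise_bitops(16, 16, 32, 3, 4))  # 294912
```

The cost grows only linearly in $C_{\text{in}}$: there is no $C_{\text{in}} \times C_{\text{out}}$ term, because this layer does no channel mixing.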
 


## 🔍 2. **Pointwise Convolution BitOps**

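A **pointwise convolution** is a $1 \times 1$ convolution across channels. At every spatial position it is a matrix multiplication: a $C_{\text{out}} \times C_{\text{in}}$ weight matrix applied to the $C_{\text{in}}$-dimensional input vector.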

### **Formula:**

$$
\text{BitOps}_{\text{pointwise}} = H_{\text{out}} \times W_{\text{out}} \times C_{\text{in}} \times C_{\text{out}} \times B
$$

**Where:**

- $H_{\text{out}},\ W_{\text{out}}$: output height and width (equal to the input size of the pointwise conv, since a $1 \times 1$ kernel with stride 1 and no padding preserves spatial size)  
- $C_{\text{in}}$: number of input channels  
- $C_{\text{out}}$: number of output channels  
- $B$: bit-width used for quantized weights
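Putting the two formulas side by side shows where a depthwise-separable block spends its BitOps. A plain-Python sketch with hypothetical helper names and arbitrary example shapes (8x8 output, 128 to 128 channels, 3x3 depthwise kernel, 4-bit weights):

```python
def depthwise_bitops(h_out, w_out, c_in, k, bits):
    """K x K depthwise conv: K*K MACs per channel per output pixel."""
    return h_out * w_out * c_in * k * k * bits


def pointwise_bitops(h_out, w_out, c_in, c_out, bits):
    """1 x 1 conv: a C_out x C_in matrix applied at every output pixel."""
    return h_out * w_out * c_in * c_out * bits


dw = depthwise_bitops(8, 8, 128, 3, 4)
pw = pointwise_bitops(8, 8, 128, 128, 4)
print(dw, pw, f"pointwise share: {pw / (dw + pw):.1%}")
```

When both layers share the same output size and $C_{\text{in}}$, the pointwise share simplifies to $C_{\text{out}} / (K^2 + C_{\text{out}})$, which approaches 1 as $C_{\text{out}}$ grows.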


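**Output size** (applies to any convolution; for $W_{\text{out}}$, use $W_{\text{in}}$ in place of $H_{\text{in}}$):

$$
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2P - K}{S} \right\rfloor + 1
$$

**Where:**

- $H_{\text{in}}$: input height  
- $P$: padding  
- $K$: kernel size  
- $S$: stride  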
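The same floor division in plain Python (`conv_output_size` is a hypothetical helper), checked on the kernel, stride, and padding settings used later in this notebook:

```python
def conv_output_size(size_in: int, kernel: int, stride: int, padding: int) -> int:
    """floor((H_in + 2P - K) / S) + 1; identical for height and width."""
    return (size_in + 2 * padding - kernel) // stride + 1


print(conv_output_size(32, kernel=3, stride=1, padding=1))  # 32: size preserved
print(conv_output_size(32, kernel=3, stride=2, padding=1))  # 16: map halved
print(conv_output_size(8, kernel=3, stride=2, padding=1))   # 4
print(conv_output_size(32, kernel=1, stride=1, padding=0))  # 32: 1x1 pointwise
```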


## ✅ Example (Depthwise Layer)

For a depthwise conv with:  
- Input: $32 \times 32$, 64 channels  
- Kernel: $3 \times 3$  
- Stride: 1, padding: 1 (so the output stays $32 \times 32$)  
- Precision: 4-bit  

$$
\text{BitOps}_{\text{depthwise}} = 32 \times 32 \times 64 \times 3 \times 3 \times 4 = 2{,}359{,}296 \text{ BitOps}
$$
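The example can be checked end to end, from output size to BitOps, in a few lines of plain Python; the padding of 1 is the assumption that keeps the output at $32 \times 32$, and the 32-bit line is added only for comparison:

```python
K, S, P = 3, 1, 1   # 3x3 kernel, stride 1, padding 1 (keeps the map 32x32)
H_in = W_in = 32
C_in, B = 64, 4     # 64 channels, 4-bit weights

H_out = (H_in + 2 * P - K) // S + 1
W_out = (W_in + 2 * P - K) // S + 1
bitops_4bit = H_out * W_out * C_in * K * K * B
bitops_fp32 = H_out * W_out * C_in * K * K * 32  # same layer at 32 bits

print(H_out, W_out)                # 32 32
print(f"{bitops_4bit:,}")          # 2,359,296
print(bitops_fp32 // bitops_4bit)  # 8
```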


```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
```

```python
# Helper for symmetric, per-tensor weight quantization ("fake quantization":
# values are rounded onto a low-bit integer grid, then scaled back to float)
def quantize_weights(weights, num_bits):
    assert num_bits >= 2, "symmetric quantization needs at least 2 bits"
    qmin = -2 ** (num_bits - 1)
    qmax = 2 ** (num_bits - 1) - 1
    # Guard against an all-zero tensor, which would otherwise divide by zero
    scale = weights.abs().max().clamp(min=1e-8) / qmax
    quantized = torch.round(weights / scale).clamp(qmin, qmax) * scale
    return quantized

# Calculate BitOps for a conv layer: MACs x weight bit-width
def calculate_bitops(conv_layer, input_size, weight_precision):
    out_channels = conv_layer.out_channels
    # Each output channel only sees in_channels // groups inputs
    # (groups == in_channels for a depthwise conv, 1 for a pointwise conv)
    in_channels_per_group = conv_layer.in_channels // conv_layer.groups
    kernel_h, kernel_w = conv_layer.kernel_size
    stride_h, stride_w = conv_layer.stride
    pad_h, pad_w = conv_layer.padding

    H_in, W_in = input_size
    H_out = (H_in + 2 * pad_h - kernel_h) // stride_h + 1
    W_out = (W_in + 2 * pad_w - kernel_w) // stride_w + 1

    macs_per_position = in_channels_per_group * kernel_h * kernel_w
    total_macs = macs_per_position * H_out * W_out * out_channels

    return total_macs * weight_precision
```
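To see which integer levels `quantize_weights` actually uses, here is a plain-Python mirror of the same symmetric, per-tensor scheme. `quantize_codes` is a hypothetical helper that returns the integer codes instead of the rescaled floats; Python's built-in `round` rounds halves to even, as `torch.round` does:

```python
def quantize_codes(weights, num_bits):
    """Integer codes and scale for symmetric per-tensor quantization."""
    qmin, qmax = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1
    scale = max(max(abs(w) for w in weights), 1e-8) / qmax
    codes = [min(max(round(w / scale), qmin), qmax) for w in weights]
    return codes, scale


weights = [-0.9, -0.31, -0.05, 0.0, 0.12, 0.44, 0.7, 0.9]
for bits in (2, 4):
    codes, scale = quantize_codes(weights, bits)
    dequantized = [c * scale for c in codes]  # what quantize_weights returns
    print(bits, codes, sorted(set(codes)))
```

Because the scale maps the largest $|w|$ to $q_{\max} = 2^{B-1} - 1$, the code $q_{\min} = -2^{B-1}$ is never produced, so a $B$-bit tensor uses at most $2^B - 1$ levels: only three ($-1$, $0$, $1$) at 2 bits.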

```python
# Depthwise-separable conv block with quantized weights.
# The full-precision weights stay the trainable parameters; forward() uses a
# quantized copy with a straight-through estimator (STE), so gradient updates
# accumulate in full precision instead of being rounded away at low bit-widths.
class QuantizedMobileNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, weight_precision):
        super().__init__()
        self.weight_precision = weight_precision
        self.stride = stride

        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def _quantized_conv(self, conv, x):
        w = conv.weight
        # STE: the forward pass sees quantized values, the backward pass
        # sends gradients straight through to the full-precision weights
        w_q = w + (quantize_weights(w, self.weight_precision) - w).detach()
        return F.conv2d(x, w_q, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups)

    def forward(self, x):
        x = F.relu6(self.bn1(self._quantized_conv(self.depthwise, x)))
        x = F.relu6(self.bn2(self._quantized_conv(self.pointwise, x)))
        return x

    def bitops_dict(self, input_size):
        dw_bitops = calculate_bitops(self.depthwise, input_size, self.weight_precision)
        # Spatial size after the depthwise conv (the 1x1 pointwise conv preserves it)
        k = self.depthwise.kernel_size[0]
        p = self.depthwise.padding[0]
        H_out = (input_size[0] + 2 * p - k) // self.stride + 1
        W_out = (input_size[1] + 2 * p - k) // self.stride + 1
        pw_bitops = calculate_bitops(self.pointwise, (H_out, W_out), self.weight_precision)

        return {
            'depthwise': dw_bitops,
            'pointwise': pw_bitops,
            'total': dw_bitops + pw_bitops,
            'output_size': (H_out, W_out)
        }

    def bitops(self, input_size):
        ops = self.bitops_dict(input_size)
        return ops['total'], ops['output_size']
```
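In PyTorch, the depthwise layer above is an `nn.Conv2d` with `groups=in_channels`: each output channel sees only `in_channels // groups` inputs. A MAC count that ignores `groups` treats the layer as a dense convolution and over-counts it by a factor of $C_{\text{in}}$. A plain-Python sketch of the grouped count (`conv_macs` is a hypothetical helper):

```python
def conv_macs(c_in, c_out, k, h_out, w_out, groups=1):
    """MACs of a (possibly grouped) K x K conv."""
    assert c_in % groups == 0 and c_out % groups == 0
    # Each output channel convolves only c_in // groups input channels
    return (c_in // groups) * k * k * c_out * h_out * w_out


C, H, W = 64, 16, 16
dense = conv_macs(C, C, 3, H, W, groups=1)      # ordinary 3x3 conv, C -> C
depthwise = conv_macs(C, C, 3, H, W, groups=C)  # one 3x3 filter per channel
print(dense // depthwise)  # 64, i.e. C
```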


```python
# Build the full model
class QuantizedMobileNet(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.blocks = nn.ModuleList()
        input_channels = 3
        # Each config entry is (out_channels, stride, weight bit-width)
        for out_channels, stride, bits in config:
            block = QuantizedMobileNetBlock(input_channels, out_channels, stride, bits)
            self.blocks.append(block)
            input_channels = out_channels
        self.classifier = nn.Linear(input_channels, 10)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    # Calculate BitOps for each block (the linear classifier is not counted)
    def bitops_per_layer(self, input_size):
        current_size = input_size
        all_ops = []
        for idx, block in enumerate(self.blocks):
            ops = block.bitops_dict(current_size)
            all_ops.append({
                'block': idx,
                'depthwise_bitops': ops['depthwise'],
                'pointwise_bitops': ops['pointwise'],
                'total_block_bitops': ops['total']
            })
            current_size = ops['output_size']
        return all_ops

    # Calculate total BitOps for the entire model
    def total_bitops(self, input_size):
        total = 0
        current_size = input_size
        for block in self.blocks:
            ops, current_size = block.bitops(current_size)
            total += ops
        return total
```
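Because every term depends only on layer shapes, the model-level count can be cross-checked without PyTorch. A plain-Python sketch (`model_bitops` is hypothetical) that walks the same `(out_channels, stride, bits)` config, assuming every depthwise layer is $3 \times 3$ with padding 1 as in `QuantizedMobileNetBlock`:

```python
def model_bitops(config, input_size, in_channels=3, kernel=3, padding=1):
    """Per-block (depthwise, pointwise) BitOps from the formulas above.
    The final linear classifier is not counted."""
    h, w = input_size
    c_in = in_channels
    rows = []
    for c_out, stride, bits in config:
        h = (h + 2 * padding - kernel) // stride + 1  # depthwise output size
        w = (w + 2 * padding - kernel) // stride + 1  # (the 1x1 pointwise keeps it)
        dw = h * w * c_in * kernel * kernel * bits
        pw = h * w * c_in * c_out * bits
        rows.append((dw, pw))
        c_in = c_out
    return rows


config = [(32, 1, 8), (64, 2, 4), (128, 2, 2), (128, 1, 4), (256, 2, 8)]
rows = model_bitops(config, (32, 32))
for i, (dw, pw) in enumerate(rows):
    print(f"Block {i}: depthwise {dw:.2e}, pointwise {pw:.2e}")
print(f"Total: {sum(dw + pw for dw, pw in rows):.2e}")
```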

```python
# Training loop for CIFAR-10
def train(model, train_loader, val_loader, device, epochs=10):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    train_losses, val_losses, train_accs, val_accs = [], [], [], []

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(1) == targets).sum().item()
            total += targets.size(0)

        train_losses.append(total_loss / len(train_loader))
        train_accs.append(correct / total)

        model.eval()
        total_loss, correct, total = 0, 0, 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                total_loss += loss.item()
                correct += (outputs.argmax(1) == targets).sum().item()
                total += targets.size(0)

        val_losses.append(total_loss / len(val_loader))
        val_accs.append(correct / total)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accs[-1]:.4f} | Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accs[-1]:.4f}")

    return {
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accs,
        'val_acc': val_accs
    }
```

```python
# Example usage: forward pass on a dummy 32x32 input, then a BitOps report
if __name__ == "__main__":
    mobilenet_config = [
        (32, 1, 8),    # (out_channels, stride, weight bits)
        (64, 2, 4),
        (128, 2, 2),
        (128, 1, 4),
        (256, 2, 8),
    ]

    model = QuantizedMobileNet(mobilenet_config)
    dummy_input = torch.randn(1, 3, 32, 32)
    output = model(dummy_input)
    print(output.shape)

    total_bitops = model.total_bitops((32, 32))
    print(f"Total BitOps: {total_bitops:.2e}")
    layer_ops = model.bitops_per_layer((32, 32))
    for op in layer_ops:
        print(f"Block {op['block']} -> Depthwise: {op['depthwise_bitops']:.2e}, "
              f"Pointwise: {op['pointwise_bitops']:.2e}, "
              f"Total: {op['total_block_bitops']:.2e}")
```

```text
torch.Size([1, 10])
Total BitOps: 1.34e+07
Block 0 -> Depthwise: 2.21e+05, Pointwise: 7.86e+05, Total: 1.01e+06
Block 1 -> Depthwise: 2.95e+05, Pointwise: 2.10e+06, Total: 2.39e+06
Block 2 -> Depthwise: 7.37e+04, Pointwise: 1.05e+06, Total: 1.12e+06
Block 3 -> Depthwise: 2.95e+05, Pointwise: 4.19e+06, Total: 4.49e+06
Block 4 -> Depthwise: 1.47e+05, Pointwise: 4.19e+06, Total: 4.34e+06
```


```python
# # Example usage
# if __name__ == "__main__":
#     mobilenet_config = [
#         (32, 1, 8),
#         (64, 2, 4),
#         (128, 2, 2),
#         (128, 1, 4),
#         (256, 2, 8),
#     ]

#     model = QuantizedMobileNet(mobilenet_config)

#     transform = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#     ])

#     train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
#     val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

#     train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#     val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

#     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#     log = train(model, train_loader, val_loader, device, epochs=1)

#     total_bitops = model.total_bitops((32, 32))
#     print(f"Total BitOps: {total_bitops:.2e}")
```


```python
# Example usage: the same model on a 224x224 input
if __name__ == "__main__":
    mobilenet_config = [
        (32, 1, 8),    # (out_channels, stride, weight bits)
        (64, 2, 4),
        (128, 2, 2),
        (128, 1, 4),
        (256, 2, 8),
    ]

    model = QuantizedMobileNet(mobilenet_config)
    dummy_input = torch.randn(1, 3, 224, 224)
    output = model(dummy_input)
    print(output.shape)

    total_bitops = model.total_bitops((224, 224))
    print(f"Total BitOps: {total_bitops:.2e}")
```

```text
torch.Size([1, 10])
Total BitOps: 6.54e+08
```
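Every BitOps term above scales with $H_{\text{out}} \times W_{\text{out}}$. With this config, each feature map at a $224 \times 224$ input is exactly 7 times taller and wider than at $32 \times 32$, so the model's total BitOps should be $7^2 = 49$ times larger (the uncounted classifier aside). A plain-Python check of the per-block sizes, where `spatial_sizes` is a hypothetical helper assuming the $3 \times 3$, padding-1 depthwise layers of `QuantizedMobileNetBlock`:

```python
def spatial_sizes(size, strides, kernel=3, padding=1):
    """Feature-map side length after each block's depthwise conv."""
    sizes = []
    for s in strides:
        size = (size + 2 * padding - kernel) // s + 1
        sizes.append(size)
    return sizes


strides = [1, 2, 2, 1, 2]  # strides from mobilenet_config
small, large = spatial_sizes(32, strides), spatial_sizes(224, strides)
print(small)  # [32, 16, 8, 8, 4]
print(large)  # [224, 112, 56, 56, 28]
print([(b * b) / (a * a) for a, b in zip(small, large)])  # area ratio per block
```

The exact 7x ratio holds here only because every size divides evenly; with other input sizes the floor in the output-size formula makes the ratio approximate.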
