In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import os # Import os for directory creation
import math # Import math for ceil for fixed point conversion

# --- 1. Define the BNN Model Components ---

class BinaryActivation(torch.autograd.Function):
    """
    Binary Activation function (Sign function) with Straight-Through Estimator (STE).
    For the forward pass, the input is binarized to -1 or 1.
    For the backward pass, the gradient is passed through unchanged (STE).
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        return grad_input

class BinarizeWeights(torch.autograd.Function):
    """
    Binarizes weights to -1 or 1 using Straight-Through Estimator (STE).
    For the forward pass, weights are binarized. If a weight is 0 due to pruning,
    it remains 0 in the binarized output (as torch.sign(0) is 0).
    For the backward pass, gradients are computed with respect to the full-precision weights.
    """
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Binarize. torch.sign(0) is 0, which correctly handles pruned weights
        # by making them effectively 'no connection' in the forward pass.
        return input.sign()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # For binarized weights, the gradient is passed straight through (STE).
        grad_input = grad_output.clone()
        return grad_input

class IntegerBiasSTE(torch.autograd.Function):
    """
    Applies rounding to the bias in the forward pass for integer bias simulation,
    and uses Straight-Through Estimator for backward pass.
    """
    @staticmethod
    def forward(ctx, input_bias):
        # Round the bias to the nearest integer
        return input_bias.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through for backpropagation to the full-precision bias
        return grad_output.clone()

class BNNLinear(nn.Linear):
    """
    Custom Linear layer for BNNs that binarizes its weights during the forward pass
    and quantizes its bias to an integer using STE.
    It implicitly handles pruned weights (set to 0) as `torch.sign(0)` is 0.
    """
    def __init__(self, in_features, out_features, bias=True):
        super(BNNLinear, self).__init__(in_features, out_features, bias)
        self.binarize = BinarizeWeights.apply
        self.quantize_bias = IntegerBiasSTE.apply # Apply integer quantization to bias

    def forward(self, input):
        # Binarize weights. If a weight is 0 due to pruning, its sign is 0,
        # effectively making it a 'no connection' in the linear operation.
        binarized_weight = self.binarize(self.weight)

        # Apply integer quantization to the bias before using it in the linear operation
        quantized_bias = self.quantize_bias(self.bias) if self.bias is not None else None

        output = F.linear(input, binarized_weight, quantized_bias)
        return output

class FullyConnectedBNN(nn.Module):
    """
    Fully Connected Binary Neural Network for MNIST classification.
    Uses custom BNNLinear layers, BatchNorm1d, and BinaryActivation.
    """
    def __init__(self, input_size, num_classes):
        super(FullyConnectedBNN, self).__init__()
        self.input_size = input_size
        self.num_classes = num_classes

        self.fc1 = BNNLinear(input_size, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.act1 = BinaryActivation.apply

        self.fc2 = BNNLinear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.act2 = BinaryActivation.apply

        self.fc3 = BNNLinear(256, num_classes)

    def forward(self, x):
        x = x.view(-1, self.input_size)

        x = self.fc1(x)
        x = self.bn1(x)
        x = self.act1(x)

        x = self.fc2(x)
        x = self.bn2(x)
        x = self.act2(x)

        x = self.fc3(x)
        return x

# --- Helper Functions for Pruning and Export ---

def apply_pruning_to_weights_only(model, sparsity_target):
    """
    Applies magnitude pruning to the weights of BNNLinear layers (excluding the final layer).
    Weights below the calculated threshold are set to 0.

    Args:
        model (nn.Module): The BNN model.
        sparsity_target (float): The desired sparsity (fraction of weights to prune, 0.0 to 1.0).
    """
    print(f"  Applying pruning with target sparsity: {sparsity_target*100:.2f}%")
    with torch.no_grad(): # Ensure no gradient computation during pruning
        for name, module in model.named_modules():
            # Apply pruning only to weights of BNNLinear layers, and exclude the final fc3 layer
            if isinstance(module, BNNLinear) and "fc3" not in name:
                weight = module.weight.data
                num_weights = weight.numel()
                num_prune = int(sparsity_target * num_weights)

                if num_prune <= 0 or num_prune >= num_weights: # Avoid pruning all or none if target is extreme
                    threshold = 0.0 if num_prune <= 0 else weight.abs().max() + 1.0 # Set threshold to prune none/all
                else:
                    abs_weights = weight.abs().flatten()
                    # Find the threshold: the (num_prune)-th smallest absolute weight value
                    # torch.topk with largest=False gives smallest values. We need the last one.
                    threshold = torch.topk(abs_weights, num_prune, largest=False).values[-1]

                # Create a mask: True for weights to keep (abs value > threshold), False for weights to prune
                pruning_mask = (weight.abs() > threshold).float()

                # Apply the mask to the weight data. This sets pruned weights to 0.
                weight.mul_(pruning_mask)

                current_sparsity = (weight == 0).sum().item() / num_weights
                print(f"    Layer {name}: Pruned {num_prune} weights. Actual sparsity: {current_sparsity*100:.2f}%")

def export_pruned_2bit_weights(weight_tensor, layer_name, output_dir):
    """
    Exports pruned and binarized weights of a single layer as a matrix of 2-bit codes.
    Mapping: -1 -> "00", 1 -> "01", 0 (pruned) -> "10".
    Each row in the file represents a row of the weight matrix, with 2-bit codes
    separated by spaces.

    Args:
        weight_tensor (torch.Tensor): The full-precision weight tensor of the layer
                                      which may contain 0s due to pruning.
                                      Shape is (out_features, in_features).
        layer_name (str): The name of the layer (e.g., 'fc1_folded_weight').
        output_dir (str): Directory to save the .txt files.
    """
    # Move to CPU for processing
    weights_cpu = weight_tensor.cpu()

    # Replace '.' in layer names with '_' for valid filenames
    output_filename = os.path.join(output_dir, f"{layer_name.replace('.', '_')}_2bit.txt")

    # Open the file in write mode
    with open(output_filename, 'w') as f:
        # Iterate over rows (output features)
        for row in weights_cpu:
            row_str_parts = []
            for val in row:
                if val == 0: # Pruned weight (full-precision value is 0)
                    row_str_parts.append("10") # Represents 'no connection'
                elif val > 0: # Binarized to 1
                    row_str_parts.append("01")
                else: # Binarized to -1
                    row_str_parts.append("00")
            f.write(' '.join(row_str_parts) + '\n')

    print(f"  Exported {layer_name} to {output_filename} (Matrix size: {weights_cpu.shape[0]}x{weights_cpu.shape[1]}, 2-bit per weight)")

def float_to_fixed_point(value, total_bits, frac_bits):
    """Converts a float to a fixed-point integer representation."""
    scaling_factor = 2**frac_bits
    fixed_val = round(value * scaling_factor)

    # Calculate min/max representable values for a signed fixed-point number
    max_val = (1 << (total_bits - 1)) - 1
    min_val = -(1 << (total_bits - 1))

    # Clip to prevent overflow
    clipped_val = max(min_val, min(max_val, fixed_val))
    return int(clipped_val)

def export_fixed_point_to_mem(float_tensor, param_name, output_dir, total_bits=16, frac_bits=8):
    """
    Exports a float tensor (e.g., BatchNorm parameters, biases) to a .mem file
    in fixed-point hexadecimal format.
    """
    # Ensure tensor is on CPU and flatten
    float_tensor = float_tensor.cpu().flatten()

    mem_content = []
    # Calculate the number of hexadecimal characters needed for the total_bits.
    hex_chars = math.ceil(total_bits / 4)

    for val in float_tensor:
        # Convert the floating-point value to its fixed-point integer representation.
        fixed_val = float_to_fixed_point(val.item(), total_bits, frac_bits)

        # Convert the fixed-point integer to a hexadecimal string.
        # For negative numbers, ensure proper two's complement representation in hex.
        if fixed_val < 0:
            # Mask with 2^total_bits to get two's complement representation
            hex_string = f'{(1 << total_bits) + fixed_val:0{hex_chars}X}'
        else:
            hex_string = f'{fixed_val:0{hex_chars}X}'
        mem_content.append(hex_string)

    output_filename = os.path.join(output_dir, f"{param_name.replace('.', '_')}_fixed.mem")
    with open(output_filename, 'w') as f:
        for hex_val in mem_content:
            f.write(hex_val + '\n')
    print(f"  Exported {param_name} to {output_filename} ({len(mem_content)} words of {total_bits} bits fixed-point)")


def fold_batchnorm(linear_layer, bn_layer):
    """
    Folds Batch Normalization parameters into the preceding linear layer's weights and biases
    for inference only. The linear_layer.weight.data should already reflect any pruning.

    Args:
        linear_layer (nn.Linear): The preceding linear layer (e.g., self.fc1)
        bn_layer (nn.BatchNorm1d): The BatchNorm layer (e.g., self.bn1)

    Returns:
        tuple: (folded_weight, folded_bias) as PyTorch tensors.
    """
    # Get parameters from PyTorch model
    # Note: linear_layer.bias.data here will already be integer-quantized due to the BNNLinear modification
    weight = linear_layer.weight.data # This weight will already have 0s for pruned connections
    bias = linear_layer.bias.data if linear_layer.bias is not None else torch.zeros(weight.shape[0], device=weight.device)
    gamma = bn_layer.weight.data
    beta = bn_layer.bias.data
    running_mean = bn_layer.running_mean
    running_var = bn_layer.running_var
    eps = bn_layer.eps

    # Calculate scale factor: gamma / sqrt(variance + epsilon)
    scale_factor = gamma / torch.sqrt(running_var + eps)

    # Folded Weight: W' = W * scale_factor
    # If W has 0s due to pruning, W' will also have 0s in those positions.
    folded_weight = weight * scale_factor.unsqueeze(1)

    # Folded Bias: B' = beta + (B - mean) * scale_factor
    # This bias (B) is already coming from the quantized_bias in BNNLinear forward pass
    folded_bias = beta + (bias - running_mean) * scale_factor

    return folded_weight, folded_bias


# --- 2. Data Loading and Preprocessing ---
# Define transformations for the MNIST dataset:
# 1. Convert PIL Image to PyTorch Tensor.
# 2. BINARIZE the image: pixels > 0.5 (after ToTensor, pixel values are 0.0-1.0) become 1.0, else 0.0.
#    Then map to -1.0 and 1.0 to align with BNN activations.
transform = transforms.Compose([
    transforms.ToTensor(), # Converts PIL Image to FloatTensor [0.0, 1.0]
    # Binarize: Pixels above 0.5 become 1.0, otherwise 0.0.
    # Then map 0.0 -> -1.0 and 1.0 -> 1.0.
    transforms.Lambda(lambda x: (x > 0.5).float() * 2.0 - 1.0)
])

# Load the MNIST training dataset.
train_dataset = torchvision.datasets.MNIST(
    root='./data',       # Directory where data will be downloaded
    train=True,          # Specify this is the training set
    download=True,       # Download the dataset if not already present
    transform=transform  # Apply the defined transformations
)

# Load the MNIST testing dataset.
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,         # Specify this is the test set
    download=True,
    transform=transform
)

# Define batch size for data loaders.
batch_size = 64

# Create data loaders for training and testing.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# --- 3. Model Initialization, Loss Function, and Optimizer ---

input_size = 28 * 28
num_classes = 10
model = FullyConnectedBNN(input_size, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# --- 4. Training and Evaluation Functions ---

def train(model, device, train_loader, optimizer, epoch, pruning_active=False):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()

        # Apply gradient masking for pruned weights BEFORE optimizer.step()
        if pruning_active:
            with torch.no_grad(): # Don't track operations for gradient masking
                for name, param in model.named_parameters():
                    # Only mask gradients of weights in BNNLinear layers (not biases, not BatchNorm params)
                    # Check if the module is BNNLinear and it's a weight parameter.
                    # We also explicitly exclude fc3 from this masking/pruning, as it's the output layer.
                    if 'weight' in name and isinstance(model._modules.get(name.split('.')[0]), BNNLinear) and "fc3" not in name and param.grad is not None:
                        # Create mask from current (potentially pruned) weights.
                        # This mask will be 0 for weights that were set to 0 by apply_pruning_to_weights_only.
                        pruning_mask = (param.data != 0).float()
                        param.grad.mul_(pruning_mask) # Zero out gradients for pruned weights

        optimizer.step()

        running_loss += loss.item()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    print(f"Epoch {epoch} Training Loss: {running_loss / len(train_loader):.4f}")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n')
    return accuracy

# --- 5. Main Training Loop ---

num_epochs = 10
best_accuracy = 0.0
best_model_path = "" # To store the path of the best saved .pth model

# Directory for saving .mem and .txt files
output_dir = "bnn_pruned_weights_export" # New directory name for this version
os.makedirs(output_dir, exist_ok=True) # Create directory if it doesn't exist

# Fixed-point parameters for exporting folded biases
# Set FP_FRAC_BITS to 0 to ensure biases are integers
FP_TOTAL_BITS = 16
FP_FRAC_BITS = 0 # Forces biases to be integers during export

# Pruning parameters
pruning_sparsity = 0.1
pruning_start_epoch = 3 # Start pruning from this epoch onwards

for epoch in range(1, num_epochs + 1):
    # Determine if pruning should be active for this epoch's training and weight update
    pruning_active_this_epoch = (epoch >= pruning_start_epoch)
    train(model, device, train_loader, optimizer, epoch, pruning_active=pruning_active_this_epoch)

    if pruning_active_this_epoch:
        # Apply pruning to the weights after each training epoch (if pruning is active)
        apply_pruning_to_weights_only(model, pruning_sparsity)

        accuracy = test(model, device, test_loader) # This accuracy now reflects integer biases and pruned weights
    else:
      accuracy = 0

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        # Save the best model's full state_dict
        best_model_path = f"best_bnn_mnist_pruned_accuracy_{best_accuracy:.2f}%.pth"
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved new best model .pth: {best_model_path}")

        print(f"Exporting weights and biases for best model (Accuracy: {best_accuracy:.2f}%)...")

        # Ensure model is in eval mode for correct BN behavior during folding
        model.eval()
        # Fold batchnorm: The linear_layer.weight.data should already be pruned.
        fc1_folded_w, fc1_folded_b = fold_batchnorm(model.fc1, model.bn1)
        fc2_folded_w, fc2_folded_b = fold_batchnorm(model.fc2, model.bn2)
        model.train() # Set back to train mode

        # Export pruned and binarized folded weights as 2-bit plain text matrix
        export_pruned_2bit_weights(fc1_folded_w, 'fc1_folded_weight', output_dir)
        export_pruned_2bit_weights(fc2_folded_w, 'fc2_folded_weight', output_dir)

        # Export fixed-point folded biases to .mem files (now effectively integers)
        export_fixed_point_to_mem(fc1_folded_b, 'fc1_folded_bias', output_dir,
                                  total_bits=FP_TOTAL_BITS, frac_bits=FP_FRAC_BITS)
        export_fixed_point_to_mem(fc2_folded_b, 'fc2_folded_bias', output_dir,
                                  total_bits=FP_TOTAL_BITS, frac_bits=FP_FRAC_BITS)

        # For the final layer (fc3), no BN is applied, so export original weights and biases.
        # The fc3 weights are NOT pruned by `apply_pruning_to_weights_only` by design.
        export_pruned_2bit_weights(model.fc3.weight.data, 'fc3_weight', output_dir)
        export_fixed_point_to_mem(model.fc3.bias.data, 'fc3_bias', output_dir,
                                  total_bits=FP_TOTAL_BITS, frac_bits=FP_FRAC_BITS)

        print("Finished exporting weights and biases files.")

print(f"\nTraining finished. Best Test Accuracy: {best_accuracy:.2f}%")
if best_model_path:
    print(f"Best model .pth saved at: {best_model_path}")
    print(f"Weights and biases files saved in: {output_dir}")


Epoch 1 Training Loss: 3.6230
Epoch 2 Training Loss: 2.9793
Epoch 3 Training Loss: 2.8803
  Applying pruning with target sparsity: 10.00%
    Layer fc1: Pruned 40140 weights. Actual sparsity: 10.00%
    Layer fc2: Pruned 13107 weights. Actual sparsity: 10.00%

Test set: Average loss: 0.0406, Accuracy: 8575/10000 (86%)

Saved new best model .pth: best_bnn_mnist_pruned_accuracy_85.75%.pth
Exporting weights and biases for best model (Accuracy: 85.75%)...
  Exported fc1_folded_weight to bnn_pruned_weights_export/fc1_folded_weight_2bit.txt (Matrix size: 512x784, 2-bit per weight)
  Exported fc2_folded_weight to bnn_pruned_weights_export/fc2_folded_weight_2bit.txt (Matrix size: 256x512, 2-bit per weight)
  Exported fc1_folded_bias to bnn_pruned_weights_export/fc1_folded_bias_fixed.mem (512 words of 16 bits fixed-point)
  Exported fc2_folded_bias to bnn_pruned_weights_export/fc2_folded_bias_fixed.mem (256 words of 16 bits fixed-point)
  Exported fc3_weight to bnn_pruned_weights_export/fc3_wei