<a href="https://colab.research.google.com/github/Hamza-Faarooq/Neural_Network_Pruning_with_SNNs/blob/main/Basic_SNN_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install snntorch

Collecting snntorch
  Downloading snntorch-0.9.4-py2.py3-none-any.whl.metadata (15 kB)
Downloading snntorch-0.9.4-py2.py3-none-any.whl (125 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/125.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m125.6/125.6 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: snntorch
Successfully installed snntorch-0.9.4


In [2]:
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import utils
from snntorch import functional as SF
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [3]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load MNIST
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 514kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.89MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.20MB/s]


In [4]:
# SNN model
class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 100)
        self.lif1 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())
        self.fc2 = nn.Linear(100, 10)
        self.lif2 = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x, num_steps=25):
        mem1, mem2 = self.lif1.init_leaky(), self.lif2.init_leaky()
        spk2_rec = []
        for step in range(num_steps):
            cur1 = self.fc1(x.view(x.size(0), -1))
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)

        return torch.stack(spk2_rec).sum(0)

model = SNN().to(device)
loss_fn = SF.mse_count_loss(correct_rate=0.9)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [5]:
for epoch in range(5):
    model.train()
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        # Remove the one-hot encoding line
        # targets_onehot = torch.nn.functional.one_hot(targets, 10).float()
        optimizer.zero_grad()
        output = model(data)
        # Pass the original targets to the loss function
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1, Loss: 2.3227
Epoch 2, Loss: 2.1719
Epoch 3, Loss: 2.2289
Epoch 4, Loss: 2.2602
Epoch 5, Loss: 2.2273


In [6]:
def prune_weights(model, prune_ratio=0.5):
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name:
                flat = param.view(-1)
                k = int(len(flat) * prune_ratio)
                threshold = torch.topk(torch.abs(flat), k, largest=False).values.max()
                mask = torch.abs(param) > threshold
                param *= mask  # Zero out pruned weights

prune_weights(model, prune_ratio=0.5)

In [7]:
def compute_criticality(model, data_sample):
    model.eval()
    data_sample = data_sample.to(device)
    data_sample.requires_grad = True

    output = model(data_sample)
    score = output.sum()
    score.backward()

    criticality = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            criticality[name] = param.grad.abs().mean().item()
    return criticality

# Example: Compute criticality from one sample
data_sample, _ = next(iter(train_loader))
criticality_scores = compute_criticality(model, data_sample[:1])
print(criticality_scores)

{'fc1.weight': 0.000143510740599595, 'fc1.bias': 0.001367890159599483, 'fc2.weight': 0.3957519233226776, 'fc2.bias': 9.710383415222168}


In [8]:
!pip install cvxpy



In [9]:
import cvxpy as cp
import numpy as np

np.random.seed(0)
W = np.random.randn(10)
X = np.random.randn(100, 10)
Y = X @ W + np.random.randn(100) * 0.1

m = cp.Variable(10)  # Continuous mask
objective = cp.Minimize(0.5 * cp.sum_squares(cp.multiply(W, m) @ X.T - Y) + 0.1 * cp.norm(m, 1))
constraints = [m >= 0, m <= 1, cp.sum(m) <= 6]
problem = cp.Problem(objective, constraints)

problem.solve(solver=cp.OSQP)
print("Relaxed pruning mask:", m.value)

Relaxed pruning mask: [ 9.81291267e-01  2.61332984e-01  8.40310495e-01  9.57598391e-01
  1.00000000e+00  8.54759885e-01  8.79327296e-01 -1.17099840e-16
 -1.12465879e-16  2.25379681e-01]


In [10]:
binary_mask = (m.value > 0.5).astype(int)
print("Binary pruning mask:", binary_mask)

Binary pruning mask: [1 0 1 1 1 1 1 0 0 0]


In [11]:
import numpy as np

soft_mask = np.array([
    0.981, 0.261, 0.840, 0.957, 1.0,
    0.855, 0.879, -1.17e-16, -1.12e-16, 0.225
])

threshold = 0.5
binary_mask = (soft_mask > threshold).astype(int)
print("Binary pruning mask:", binary_mask)

Binary pruning mask: [1 0 1 1 1 1 1 0 0 0]


In [12]:
import torch

# dummy weight vector
weights = torch.randn(10)

# apply pruning
weights *= torch.tensor(binary_mask, dtype=torch.float32)

In [13]:
layer = torch.nn.Linear(10, 1)
with torch.no_grad():
    layer.weight *= torch.tensor(binary_mask, dtype=torch.float32).unsqueeze(0)

In [14]:
def compute_sparsity(layer):
    total = layer.weight.numel()
    zeros = torch.sum(layer.weight == 0).item()
    return zeros / total

print("Sparsity:", compute_sparsity(layer))


Sparsity: 0.4


In [16]:
for epoch in range(2):  # You can increase epochs later
    model.train()
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        # Forward pass
        output = model(data)
        loss = loss_fn(output, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Remove the incorrect pruning application:
        # Apply binary pruning mask AFTER weight update
        # for name, param in model.named_parameters():
        #     if "fc1.weight" in name:
        #         param.data *= torch.tensor(binary_mask).float().to(param.device)

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)
