<a href="https://colab.research.google.com/github/AnuragQ/Audio-Signal-Processing/blob/master/PytorchDLMNISTWeightPruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
import torch
import numpy as np
from torchvision import datasets, transforms

In [0]:
# Construct a ReLU-activated neural network with four hidden layers with sizes [1000, 1000,
# 500, 200]. Note: you’ll have a fifth layer for your output logits, which you will have 10 of.
# 3. Train your network on MNIST or Fashion-MNIST (your choice, whatever is easier)
# 4. Prune away (set to zero) the k% of weights using weight and unit pruning for k in [0, 25,
# 50, 60, 70, 80, 90, 95, 97, 99]. Remember not to prune the weights leading to the output
# logits.
# 5. Create a table or plot showing the percent sparsity (number of weights in your network
# that are zero) versus percent accuracy with two curves (one for weight pruning and one
# for unit pruning).
# 6. Make your code clean and readable. Add comments where needed.
# 7. Follow the submission guidelines

class Net(nn.Module):
    """Fully connected ReLU classifier built from a list of layer widths.

    Args:
        layers: Sequence of layer sizes ``[n_inputs, hidden_1, ..., n_outputs]``.
            Every consecutive pair becomes an ``nn.Linear``; a ReLU follows each
            one except the last, so the network returns raw, unactivated outputs.

    Raises:
        ValueError: If fewer than two sizes are given.
    """

    def __init__(self, layers):
        super().__init__()
        if len(layers) < 2:
            raise ValueError('layers must contain at least an input and an output size')

        n_layers = []
        # Hidden layers: Linear followed by an in-place ReLU.
        for n_in, n_out in zip(layers[:-2], layers[1:-1]):
            n_layers.append(nn.Linear(n_in, n_out))
            n_layers.append(nn.ReLU(inplace=True))

        # Output layer: no activation after it.
        n_layers.append(nn.Linear(layers[-2], layers[-1]))

        self.model = nn.ModuleList(n_layers)

    def forward(self, x):
        """Flatten ``x`` to rows of the first layer's input width and run all layers."""
        # Flatten to the width the first Linear layer expects instead of a
        # hard-coded 784, so the class works for any `layers[0]`.
        x = x.reshape(-1, self.model[0].in_features)
        for layer in self.model:
            x = layer(x)
        return x

    # TODO: mask-based pruning hook (set_masks) is not implemented yet.
        


In [260]:
# Build a throwaway network just to display its layer structure.
Net(layers=[784, 1000, 1000, 500, 200, 10])

Net(
  (model): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU(inplace)
    (4): Linear(in_features=1000, out_features=500, bias=True)
    (5): ReLU(inplace)
    (6): Linear(in_features=500, out_features=200, bias=True)
    (7): ReLU(inplace)
    (8): Linear(in_features=200, out_features=10, bias=True)
  )
)

In [261]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Layer widths: 784 inputs, hidden layers of 1000/1000/500/200 units, 10 outputs.
net = Net([784, 1000, 1000, 500, 200, 10])
net = net.to(device)

critic = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(params=net.parameters(), lr=1e-3)

net

Net(
  (model): ModuleList(
    (0): Linear(in_features=784, out_features=1000, bias=True)
    (1): ReLU(inplace)
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU(inplace)
    (4): Linear(in_features=1000, out_features=500, bias=True)
    (5): ReLU(inplace)
    (6): Linear(in_features=500, out_features=200, bias=True)
    (7): ReLU(inplace)
    (8): Linear(in_features=200, out_features=10, bias=True)
  )
)

In [0]:
# Both splits get the same preprocessing: convert to a tensor, then normalise
# with mean 0.1307 and standard deviation 0.3081.
train_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
test_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# MNIST is stored under the current directory and downloaded if it is missing.
train_set = datasets.MNIST('.', train=True, download=True, transform=train_transforms)
test_set = datasets.MNIST('.', train=False, download=True, transform=test_transforms)

# Shuffle only the training batches; keep the test order fixed.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

In [0]:
def train(epochs, model, optimizer, critic, train_dl):
    """Train ``model`` in place for ``epochs`` passes over ``train_dl``.

    Args:
        epochs: Number of full passes over ``train_dl``.
        model: Module mapping an input batch to per-class scores.
        optimizer: Optimizer already bound to ``model``'s parameters.
        critic: Loss callable taking ``(predictions, targets)``.
        train_dl: Iterable yielding ``(inputs, targets)`` batches.

    Prints the loss and batch accuracy of every 100th batch. Returns nothing.
    """
    model.train()
    # Use the device the model lives on rather than a notebook-level global, so
    # the function works no matter which cells were run before it.
    device = next(model.parameters()).device

    for ep in range(epochs):
        print(f'Epoch: {ep+1}')
        for batch_i, (x_batch, y_batch) in enumerate(train_dl):
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            preds = model(x_batch)
            loss = critic(preds, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_i % 100 == 0:
                # Accuracy of the predictions made before this optimizer step.
                pred_labels = torch.argmax(preds, dim=1)
                acc = (pred_labels == y_batch).sum().float()
                print(f'\tLoss: {loss.item():.4f} \t Accuracy: {acc/x_batch.shape[0]:.2f}')

In [264]:
# Train for two epochs over the MNIST training batches.
train(epochs=2, model=net, optimizer=optimizer, critic=critic, train_dl=train_loader)

Epoch: 1
	Loss: 2.3018 	 Accuracy: 0.14
	Loss: 0.2167 	 Accuracy: 0.94
	Loss: 0.2040 	 Accuracy: 0.97
	Loss: 0.2563 	 Accuracy: 0.94
	Loss: 0.2305 	 Accuracy: 0.95
	Loss: 0.2187 	 Accuracy: 0.94
	Loss: 0.1065 	 Accuracy: 0.95
	Loss: 0.1241 	 Accuracy: 0.95
	Loss: 0.1499 	 Accuracy: 0.97
	Loss: 0.1676 	 Accuracy: 0.94
Epoch: 2
	Loss: 0.3414 	 Accuracy: 0.94
	Loss: 0.1224 	 Accuracy: 0.97
	Loss: 0.1628 	 Accuracy: 0.95
	Loss: 0.1254 	 Accuracy: 0.97
	Loss: 0.0307 	 Accuracy: 0.98
	Loss: 0.0273 	 Accuracy: 1.00
	Loss: 0.2840 	 Accuracy: 0.97
	Loss: 0.0857 	 Accuracy: 0.98
	Loss: 0.0675 	 Accuracy: 0.98
	Loss: 0.0241 	 Accuracy: 0.98


In [0]:
# def weight_prune(model, pruning_perc):
#     '''
#     Prune pruning_perc% weights globally (not layer-wise)
#     arXiv: 1606.09274
#     '''    
#     all_weights = []
#     for p in model.parameters():
#         if len(p.data.size()) != 1:
#             all_weights += list(p.cpu().data.abs().numpy().flatten())
#     threshold = np.percentile(np.array(all_weights), pruning_perc)

#     # generate mask
#     masks = []
#     for p in model.parameters():
#         if len(p.data.size()) != 1:
#             pruned_inds = p.data.abs() > threshold
#             masks.append(pruned_inds.float())
#     return masks

In [266]:
# Measure top-1 accuracy of `net` on the test set.
net.eval()  # evaluation mode; train() switches back to training mode itself
correct = 0
total = 0
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for images, labels in test_loader:
        # Keep inputs AND labels on the model's device: comparing a CUDA
        # prediction tensor with CPU labels raises a RuntimeError.
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on test images: 97 %


In [0]:
# NOTE(review): this loop has no effect on `net`. `param = torch.zeros(...)`
# only rebinds the loop variable to a brand-new tensor; the network's own
# parameter is neither zeroed nor frozen, and the `requires_grad` assignment
# below lands on that temporary tensor. Confirm the intent (freeze the weights?
# zero them for pruning?) before relying on this cell.
for child in net.children():
    for param in child.parameters():
        param =  torch.zeros(param.size())  # rebinds the local name only
        param.requires_grad = False  # applies to the temporary tensor above
# param = {
#     'pruning_perc': 90.,
#     'batch_size': 128, 
#     'test_batch_size': 100,
#     'num_epochs': 5,
#     'learning_rate': 0.001,
#     'weight_decay': 5e-4,
# }



# masks = weight_prune(net, param['pruning_perc'])
# net.set_masks(masks)
# print("--- {}% parameters pruned ---".format(param['pruning_perc']))

In [0]:
# Global magnitude threshold: the given percentile of |w| over every
# multi-dimensional parameter (the weight matrices). 1-D parameters such as
# the bias vectors are left out.
PRUNE_PERCENTILE = 90

# Concatenate per-layer arrays directly instead of growing a Python list of
# millions of scalars; the values and their order are the same.
all_weights = np.concatenate([
    p.detach().abs().cpu().numpy().ravel()
    for p in net.parameters()
    if p.dim() != 1
])
threshold = np.percentile(all_weights, PRUNE_PERCENTILE)

In [269]:
# Show the magnitude cut-off computed in the previous cell.
print(threshold)

0.003994928859174252


In [0]:
# Magnitude-based weight pruning: zero every weight whose absolute value is
# below `threshold`. The Linear layers sit at the even indices of `net.model`;
# the last one produces the output logits and is deliberately left unpruned.
with torch.no_grad():
    for layer in net.model[0:-1:2]:
        # One vectorised mask per layer replaces the per-element Python loop.
        # The comparison is done in float64, as the element-wise version did.
        too_small = layer.weight.abs().double() < float(threshold)
        layer.weight.masked_fill_(too_small, 0)

        

In [271]:
# Inspect the weight matrix of the Linear layer at index 2 of net.model.
net.model[2].weight.data

tensor([[-0.0218, -0.0189, -0.0090,  ...,  0.0456, -0.0187, -0.0462],
        [ 0.0073, -0.0241, -0.0579,  ..., -0.0179, -0.0117, -0.0088],
        [ 0.0152,  0.0079,  0.0160,  ..., -0.0425,  0.0000,  0.0366],
        ...,
        [-0.0140, -0.0063, -0.0113,  ...,  0.0152, -0.0134,  0.0073],
        [ 0.0408, -0.0349, -0.0079,  ...,  0.0286, -0.0308, -0.0211],
        [ 0.0155,  0.0167, -0.0434,  ..., -0.0292,  0.0081,  0.0376]])

In [275]:
# Re-measure top-1 accuracy of `net` on the test set after pruning.
net.eval()  # evaluation mode; train() switches back to training mode itself
correct = 0
total = 0
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for images, labels in test_loader:
        # Keep inputs AND labels on the model's device: comparing a CUDA
        # prediction tensor with CPU labels raises a RuntimeError.
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on test images: 96 %
