# Imports

In [18]:
import torch
import torch.nn.utils.prune as prune 
import numpy as np
import plotly.graph_objects as go 
import torchvision 
import timeit


# Retrieving MNIST Dataset 

In [3]:

train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('../data', 
    train=True, 
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),])),
    batch_size=1, shuffle=True)

test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('../data', 
    train=False,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),])),
    batch_size=1, shuffle=True)

    

In [4]:
# check the size of the dataset
for i, (images, labels) in enumerate(train_loader):
    print(images.shape)
    print(labels.shape)
    break

torch.Size([1, 1, 28, 28])
torch.Size([1])


# Defining the Neural Network Architecture

In [30]:

"""Referenced from https://www.geeksforgeeks.org/training-neural-networks-with-validation-using-pytorch/"""
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.fc1 = torch.nn.Linear(784, 1000) #input_size (784 becaues MNIST images are of dimension 28*28), hidden size
        self.fc2 = torch.nn.Linear(1000, 1000) #hidden size, output size
        self.fc3 = torch.nn.Linear(1000, 500) #hidden size, output size
        self.fc4 = torch.nn.Linear(500, 200) #hidden size, output size
        self.fc5 = torch.nn.Linear(200, 10) #hidden size, output size
     
    def forward(self,x):
        x = x.view(1, -1)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.fc2(x)
        x = torch.nn.functional.relu(x)
        x = self.fc3(x)
        x = torch.nn.functional.relu(x)
        x = self.fc4(x)
        x = torch.nn.functional.relu(x)
        x = self.fc5(x)
        #print(x.shape)
        return x
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
mnist_model = NeuralNetwork()
if torch.cuda.is_available():
    mnist_model.cuda()
mnist_model.to(device)

loss_function = torch.nn.CrossEntropyLoss() # Criterion
optimizer  = torch.optim.Adam(mnist_model.parameters(), lr=0.001)


cpu


# Model Training and Validation

In [5]:
epochs = 2
min_validation_loss = np.inf
mnist_model.train() # set model to training mode
# if a GPU is available, use it
# if a model is in the directory, load it instead of training
"""Also referenced from https://www.geeksforgeeks.org/training-neural-networks-with-validation-using-pytorch/"""
for epoch in range(epochs):
    training_loss = 0.0    
    for data, labels in train_loader:
        if torch.cuda.is_available():
            data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        target = mnist_model(data)
        loss = loss_function(target, labels)
        loss.backward()
        optimizer.step() # update the weights
        training_loss += loss.item()
    validation_loss = 0.0
    mnist_model.eval()
    for data, labels in test_loader:
        if torch.cuda.is_available():
            data, labels = data.to(device), labels.to(device)
        target = mnist_model(data)
        loss = loss_function(target, labels)
        validation_loss += loss.item()*data.size(0)
    print(f'Epoch {epoch+1}/{epochs} Loss: {training_loss/len(train_loader)}')
    if min_validation_loss > validation_loss:
        print(f'Validation Loss Decreased({min_validation_loss:.6f}--->{validation_loss:.6f})\t Saving The Model')



Epoch 1/3 Loss: 0.3876839260937323
Validation LOss Decreased(inf--->2543.591023)	 Savaing The Model
Epoch 2/3 Loss: 0.2497183335412024
Validation LOss Decreased(inf--->1864.175267)	 Savaing The Model
Epoch 3/3 Loss: 0.22650559566471065
Validation LOss Decreased(inf--->3439.695278)	 Savaing The Model


# Saving the Model

In [6]:
torch.save(mnist_model.state_dict(), 'model_weights.pth')

# Loading the Model

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mnist_model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device(device)))

<All keys matched successfully>

# Neural Network Pruning
- Prune away (set to zero) the k% of weights using weight and unit pruning for k in [0, 25, 50, 60, 70, 80, 90, 95, 97, 99]. Remember not to prune the weights leading to the output logits.

In [46]:
mnist_model = NeuralNetwork() # initializing the model again in case I want to prune again
mnist_model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device(device)))


"""Global pruning logic (https://pytorch.org/tutorials/intermediate/pruning_tutorial.html, Global Pruning Section)"""
def prune_weights(model, pruning_rate):
    parameters_to_prune = (
        (model.fc1, 'weight'),
        (model.fc2, 'weight'),
        (model.fc3, 'weight'),
        (model.fc4, 'weight'),
    )
    prune.global_unstructured(parameters_to_prune, pruning_method = prune.L2Unstructured, amount = pruning_rate)
    return None

"""Accuracy calculation referenced from https://www.geeksforgeeks.org/training-neural-networks-with-validation-using-pytorch/"""
def evaluate_model_accuracy(model, loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data, labels in loader:
            if torch.cuda.is_available():
                data, labels = data.to(device), labels.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1) # get the index of the max log-probability, ignore the first output
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct/total
k_percentages = [0, 25, 50, 60, 70, 80, 90, 95, 97, 99] #k% of weights to prune
accuracy_list_weights = []
weight_pruned_model = mnist_model
for k in k_percentages:
    prune_weights(mnist_model, k/100)
    print(f'{k}% of the weights have been pruned')
    #print(list(pruned_model.named_parameters()))
    accuracy_list_weights.append(evaluate_model_accuracy(weight_pruned_model, test_loader))

0% of the weights have been pruned
25% of the weights have been pruned
50% of the weights have been pruned
60% of the weights have been pruned
70% of the weights have been pruned
80% of the weights have been pruned
90% of the weights have been pruned
95% of the weights have been pruned
97% of the weights have been pruned
99% of the weights have been pruned


In [47]:
accuracy_list_weights

[0.0941, 0.0942, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098, 0.098]

# Neural Network Unit Pruning

In [44]:
def prune_neurons(model, pruning_rate):
    parameters_to_prune = (
        (model.fc1, 'bias'),
        (model.fc2, 'bias'),
        (model.fc3, 'bias'),
        (model.fc4, 'bias'),
    )
    prune.global_unstructured(parameters_to_prune, pruning_method = prune.L1Unstructured, amount = pruning_rate)
    return None
    
mnist_model = NeuralNetwork() # initializing the model again in case I want to prune again
mnist_model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device(device)))
accuracy_list_bias = []
bias_pruned_model = mnist_model
for k in k_percentages:
    prune_neurons(mnist_model, k/100)
    print(f'{k}% of the units have been pruned')
    #print(list(pruned_model.named_parameters()))
    accuracy_list_bias.append(evaluate_model_accuracy(bias_pruned_model, test_loader))

0% of the units have been pruned
25% of the units have been pruned
50% of the units have been pruned
60% of the units have been pruned
70% of the units have been pruned
80% of the units have been pruned
90% of the units have been pruned
95% of the units have been pruned
97% of the units have been pruned
99% of the units have been pruned


In [48]:
accuracy_list_bias

[0.0941, 0.0928, 0.098, 0.0918, 0.0979, 0.0983, 0.098, 0.098, 0.098, 0.098]

# Plot the Percent Sparcity vs. Percent Accuracy (weight pruning)
- Prune away (set to zero) the k% of weights using weight and unit pruning for k in [0, 25, 50, 60, 70, 80, 90, 95, 97, 99]. Remember not to prune the weights leading to the output logits.

In [49]:
trace = go.Scatter(x=k_percentages, y=np.array(accuracy_list_weights)*1000, mode='lines+markers',)
fig = go.Figure(data=trace, layout = go.Layout(title='Accuracy vs. Pruning Rate'))

fig.update_yaxes(title_text='Accuracy (%)')
fig.update_xaxes(title_text='Percentage of Weights Pruned (%)')
fig.show()

# Plot the Percent Sparcity vs. Percent Accuracy (unit pruning)

In [50]:
trace = go.Scatter(x=k_percentages, y=np.array(accuracy_list_bias)*1000, mode='lines+markers',)
fig = go.Figure(data=trace, layout = go.Layout(title='Accuracy vs. Pruning Rate'))

fig.update_yaxes(title_text='Accuracy (%)')
fig.update_xaxes(title_text='Percentage of Weights Pruned (%)')
fig.show()

# Bonus Attempt: Speed Comparison between Weight and Unit Pruning and the original network

In [53]:
"""Referenced from https://pytorch.org/tutorials/recipes/recipes/timer_quick_start.html"""
from torch.utils.benchmark import Timer


mnist_model = NeuralNetwork() # initializing the model again in case I want to prune again
mnist_model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device(device)))
timer1 = Timer(
    '''lambda: evaluate_model_accuracy(mnist_model, test_loader),'''
)
timer2 = Timer(
    '''lambda: evaluate_model_accuracy(weight_pruned_model, test_loader),'''
)
timer3 = Timer(
    '''lambda: evaluate_model_accuracy(bias_pruned_model, test_loader),'''
)

m1 = timer1.timeit(number=1)
m2 = timer2.timeit(number=1)
m3 = timer3.timeit(number=1)
print(f'Time for 1 run of the model: {m1}')
print(f'Time for 1 run of the 99% weight pruned model: {m2}')
print(f'Time for 1 run of the 99% unit pruned model: {m3}')

Time for 1 run of the model: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe02006d280>
lambda: evaluate_model_accuracy(mnist_model, test_loader),
  150.00 ns
  1 measurement, 1 runs , 1 thread
Time for 1 run of the pruned model: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe0200da3d0>
lambda: evaluate_model_accuracy(weight_pruned_model, test_loader),
  120.02 ns
  1 measurement, 1 runs , 1 thread
Time for 1 run of the pruned model: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe02006d520>
lambda: evaluate_model_accuracy(bias_pruned_model, test_loader),
  120.02 ns
  1 measurement, 1 runs , 1 thread


# 