we wish to use the MNIST, construct a feed forward neural network for pruning. such that we remove weights, if the mutual information between those neurons were too low



In [2]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from scipy.stats import entropy
import torch.nn.functional as F
import matplotlib.pyplot as plt


In [3]:


# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)
testloader = DataLoader(datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform), batch_size=64, shuffle=False)

# Define the feed-forward neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 10)
        self.input_fc1 = []
        self.output_fc1 = []
        self.output_fc2 = []

    def forward(self, x):
        x = x.view(-1, 28*28)
        self.input_fc1.extend(x.detach().cpu().numpy())
        x = torch.relu(self.fc1(x))
        self.output_fc1.extend(x.detach().cpu().numpy())
        x = self.fc2(x)
        self.output_fc2.extend(F.softmax(x, dim=1).detach().cpu().numpy())
        return x


In [4]:
# Train the network
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [5]:
# Set the number of epochs (full passes through the dataset)
epochs = 5

# Loop over each epoch
for e in range(epochs):
    # Initialize counters for the number of correct predictions and the total number of predictions
    correct = 0
    total = 0

    # Loop over each batch of images and labels in the training data
    for images, labels in trainloader:
        # Zero the gradients of the model parameters
        optimizer.zero_grad()

        # Pass the images through the model to get the output probabilities
        output = model(images)

        # Calculate the loss between the output probabilities and the true labels
        loss = criterion(output, labels)

        # Backpropagate the gradients of the loss with respect to the model parameters
        loss.backward()

        # Update the model parameters
        optimizer.step()

        # Update the total number of predictions
        total += labels.size(0)

        # Update the number of correct predictions
        correct += (torch.argmax(output, dim=1) == labels).sum().item()

    # Print the loss and accuracy for this epoch
    print(f"Epoch {e+1}/{epochs}, Loss: {loss.item()}")
    print(f"Accuracy: {correct/total}")

# Print a message to indicate that training is complete
print("Training complete")

Epoch 1/5, Loss: 0.0623445063829422
Accuracy: 0.8985333333333333
Epoch 2/5, Loss: 0.04524664208292961
Accuracy: 0.9526166666666667
Epoch 3/5, Loss: 0.024592244997620583
Accuracy: 0.9660666666666666
Epoch 4/5, Loss: 0.017265560105443
Accuracy: 0.97265
Epoch 5/5, Loss: 0.016559118404984474
Accuracy: 0.9773833333333334
Training complete


In [6]:
model.input_fc1 = np.array(model.input_fc1)
model.output_fc1 = np.array(model.output_fc1)
model.output_fc2 = np.array(model.output_fc2)

In [7]:

# mutual information between 2 random variables
def mutual_information(x, y):
    return entropy(x) - entropy(x, y)

In [8]:
model.input_fc1.shape,model.output_fc1.shape, model.output_fc2.shape

((300000, 784), (300000, 512), (300000, 10))

In [9]:
model_mi = Net()
model_mi.load_state_dict(model.state_dict())

<All keys matched successfully>

In [14]:
model_mi.fc1.weight.data.shape

torch.Size([512, 784])

In [25]:
for i in range(784):
    print( )

-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-0.9999835416666667
-0.9999385416666666
-0.999971875
-0.9999988541666667
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-0.9999979166666667
-0.9999927083333333
-0.9999270833333334
-0.9998094791666666
-0.9996571875
-0.9994970833333333
-0.9990577083333333
-0.9987392708333334
-0.9986338541666666
-0.9986085416666667
-0.9985153125
-0.9986341666666667
-0.9985339583333334
-0.9987948958333334
-0.9992145833333334
-0.9994413541666667
-0.9995780208333334
-0.9998322916666667
-0.9999209375
-0.9999721875
-1.0
-1.0
-1.0
-1.0
-1.0
-1.0
-0.9999916666666666
-0.9999944791666666
-0.9999454166666667
-0.999956875
-0.9996304166666666
-0.9989144791666666
-0.9979280208333333
-0.9960265625
-0.9932015625
-0.9898817708333333
-0.985330625
-0.9801572916666667
-0.9748892708333333
-0.9715641666666667
-0.9708080208333333
-0.9733916666666667
-0.97801625
-0.9839654166666667
-0.9905716666666666
-0.9950317708333334
-0.9976771875
-

In [30]:
for i in range(784):
    sum_neuron = np.sum(model.input_fc1[:,i])/len( model.input_fc1[:,i])
    if sum_neuron ==-1:
        model_mi.fc1.weight.data[:,i] = 0
        print("neuron",i,"is not used")
    else:
        print(sum_neuron)

neuron 0 is not used
neuron 1 is not used
neuron 2 is not used
neuron 3 is not used
neuron 4 is not used
neuron 5 is not used
neuron 6 is not used
neuron 7 is not used
neuron 8 is not used
neuron 9 is not used
neuron 10 is not used
neuron 11 is not used
-0.9999835416666667
-0.9999385416666666
-0.999971875
-0.9999988541666667
neuron 16 is not used
neuron 17 is not used
neuron 18 is not used
neuron 19 is not used
neuron 20 is not used
neuron 21 is not used
neuron 22 is not used
neuron 23 is not used
neuron 24 is not used
neuron 25 is not used
neuron 26 is not used
neuron 27 is not used
neuron 28 is not used
neuron 29 is not used
neuron 30 is not used
neuron 31 is not used
-0.9999979166666667
-0.9999927083333333
-0.9999270833333334
-0.9998094791666666
-0.9996571875
-0.9994970833333333
-0.9990577083333333
-0.9987392708333334
-0.9986338541666666
-0.9986085416666667
-0.9985153125
-0.9986341666666667
-0.9985339583333334
-0.9987948958333334
-0.9992145833333334
-0.9994413541666667
-0.9995780208

In [31]:
# pruning the weights of 1st layer 
for i in range(117,784):
    edges = 0
    sum_neuron = np.sum(model.input_fc1[:,i])/len( model.input_fc1[:,i])
    if sum_neuron !=-1:
        for j in range(512):
            v = mutual_information( model.output_fc1[:,j] , model.input_fc1[:,i] )
            if v < 0 or np.isnan(v):
                model_mi.fc1.weight.data[j,i] = 0
                edges+=1
    else:
        model_mi.fc1.weight.data[:,i] = 0
        edges+=512
    print(f"edges removed for {i}th input: {edges}")

edges removed for 117th input: 512
edges removed for 118th input: 512
edges removed for 119th input: 512
edges removed for 120th input: 512


KeyboardInterrupt: 

In [16]:
model_mi.fc2.weight.data.shape

torch.Size([10, 512])

In [18]:
mi = []
for i in range(10):
    edges = 0
    for j in range(512):
        v = mutual_information( model.output_fc1[:,j] , model.output_fc2[:,i])
        if v < 0 or np.isnan(v):
            model_mi.fc2.weight.data[i,j] = 0
            edges+=1
    print(i, edges)

0 359
1 406
2 221
3 244
4 372
5 252
6 423
7 324
8 224
9 287


In [88]:
# test accuracy of the model with the pruned weights
correct = 0
total = 0
for images, labels in testloader:
    output = model_mi(images)
    total += labels.size(0)
    correct += (torch.argmax(output, dim=1) == labels).sum().item()
print(f"Accuracy: {correct/total}")

Accuracy: 0.9136
