we wish to use the MNIST, construct a feed forward neural network for pruning. such that we remove weights, if the mutual information between those neurons were too low



In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from scipy.stats import entropy
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

In [19]:
# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)
testloader = DataLoader(datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform), batch_size=64, shuffle=False)

# Define the feed-forward neural network
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 10)
        self.input_fc1 = []
        self.output_fc1 = []
        self.output_fc2 = []

    def forward(self, x):
        x = x.view(-1, 28*28)
        self.input_fc1.extend(x.detach().cpu().numpy())
        x = torch.relu(self.fc1(x))
        self.output_fc1.extend(x.detach().cpu().numpy())
        x = self.fc2(x)
        self.output_fc2.extend(F.softmax(x, dim=1).detach().cpu().numpy())
        return x


In [24]:
def mutual_information(x, y):
    joint_distribution = np.histogram2d(x, y)[0]
    marginal_x = np.sum(joint_distribution, axis=0)
    marginal_y = np.sum(joint_distribution, axis=1)

    joint_distribution /= np.sum(joint_distribution)
    marginal_x /= np.sum(marginal_x)
    marginal_y /= np.sum(marginal_y)

    joint_distribution = joint_distribution.flatten()
    marginal_distribution = np.outer(marginal_y, marginal_x).flatten()

    # Remove zero entries to avoid division by zero
    mask = np.logical_and(joint_distribution > 0, marginal_distribution > 0)

    return entropy(joint_distribution[mask], marginal_distribution[mask])

In [25]:
def prune_model(model):
    
    assert (len(model.input_fc1) == len(model.output_fc1) == len(model.output_fc2))
    
    print("Pruning the model")
    model.input_fc1 = np.array(model.input_fc1)
    model.output_fc1 = np.array(model.output_fc1)
    model.output_fc2 = np.array(model.output_fc2)
    
    model_mi = Net()
    model_mi.load_state_dict(model.state_dict())
    model_mi.input_fc1 = model_mi.output_fc1 = model_mi.output_fc2 = [] # Clear the lists
    
    # Compute the mutual information between the output of the output layer and the input of the output layer
    layer2 = torch.zeros((10,512)) # store the mutual information between the output of the output layer and the input of the output layer
    
    for i in range(10):
        for j in range(512):
            v = mutual_information( model.output_fc2[:,i] ,  model.output_fc1[:,j] )
            layer2[i][j] = v
            
    print(layer2.shape)
    print(model_mi.fc2.weight.data.shape)
    
    model_mi.fc2.weight.data[layer2<=0.01]  = 0 # prune the weights of fc2
    # make the weights no longer trainable
    model_mi.fc2.weight.data[layer2<=0.01].requires_grad = False
    
    print("Finished pruning fc2")
    # find the neurons which are dead
    neurons_not_dead = []
    count = 0
    for i in range(512):
        weights_out_neuron = np.abs(model_mi.fc2.weight.data[:,i]).sum()
        if weights_out_neuron==0:
            count+=1
        else:
            neurons_not_dead.append(i)
    
    print("Number of neurons dead in fc2: ", count)
    
    # prune the neurons which are dead
    
    layer1 = torch.zeros((784,512))
    
    # pruning the weights of 1st layer
    for i in range(784):
        if model.input_fc1[:,i].mean()!=-1:
            for j in neurons_not_dead:
                v = mutual_information( model.output_fc1[:,j] , model.input_fc1[:,i] )
                layer1[i][j] = v
                
                
    model_mi.fc1.weight.data[layer1.T<=0.01] = 0 # prune the weights of fc1
    model_mi.fc1.weight.data[layer1.T<=0.01].requires_grad = False # make the weights no longer trainable
    
    print("Finished pruning fc1")
    
    return model_mi

In [26]:
# Train the network
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [27]:
# Set the number of epochs (full passes through the dataset)
epochs = 10

# Loop over each epoch
for e in range(epochs):
    # Initialize counters for the number of correct predictions and the total number of predictions
    correct = 0
    total = 0

    # Loop over each batch of images and labels in the training data
    for images, labels in trainloader:
        # Zero the gradients of the model parameters
        optimizer.zero_grad()

        # Pass the images through the model to get the output probabilities
        output = model(images)

        # Calculate the loss between the output probabilities and the true labels
        loss = criterion(output, labels)

        # Backpropagate the gradients of the loss with respect to the model parameters
        loss.backward()

        # Update the model parameters
        optimizer.step()

        # Update the total number of predictions
        total += labels.size(0)

        # Update the number of correct predictions
        correct += (torch.argmax(output, dim=1) == labels).sum().item()
    
    # Prune the model
    model = prune_model(model)

    # Print the loss and accuracy for this epoch
    print(f"Epoch {e+1}/{epochs}, Loss: {loss.item()}")
    print(f"Accuracy: {correct/total}")

# Print a message to indicate that training is complete
print("Training complete")

Pruning the model
torch.Size([10, 512])
torch.Size([10, 512])
Finished pruning fc2
Number of neurons dead in fc2:  207


KeyboardInterrupt: 

In [None]:
# test accuracy of the model with the pruned weights
correct = 0
total = 0
for images, labels in testloader:
    output = model(images)
    total += labels.size(0)
    correct += (torch.argmax(output, dim=1) == labels).sum().item()
print(f"Accuracy: {correct/total}")

Accuracy: 0.6762
