In [3]:
import torch
import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import importnb
with importnb.Notebook():
    from matrix_based_entropy_estimators import IT_calculator

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG library the notebook imports so re-runs are reproducible.
# torch.manual_seed seeds the CPU generator and all CUDA generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the Generator repr in Jupyter

<torch._C.Generator at 0x1dab0bbc790>

In [5]:
class SimpleCNN(nn.Module):
    """Small image classifier: two conv -> ReLU -> 2x2 max-pool stages, then two linear layers.

    The flattened feature size 7 * 7 * 32 assumes single-channel 28x28 inputs.
    Each 3x3 conv with padding=1 keeps the spatial size, and each pool halves it: 28 -> 14 -> 7.
    forward returns raw (un-normalized) scores for 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Stage 2: 16 -> 32 channels.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: flattened features -> 64 hidden units -> 10 scores.
        self.fc1 = nn.Linear(7 * 7 * 32, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        stages = (
            (self.conv1, self.relu1, self.pool1),
            (self.conv2, self.relu2, self.pool2),
        )
        for conv, act, pool in stages:
            x = pool(act(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = self.relu3(self.fc1(flat))
        return self.fc2(hidden)

# ToTensor then Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5 per pixel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

DATA_ROOT = './data'
BATCH_SIZE = 60

# Training split is shuffled by the loader; the test split keeps a fixed order.
trainset = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.MNIST(
    root=DATA_ROOT, train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

# Rebuild the architecture and restore the trained weights from the checkpoint.
# map_location lets a checkpoint saved on one device (e.g. GPU) load on another (e.g. CPU-only).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
model = SimpleCNN()
model.load_state_dict(torch.load('SimpleCNN.pth', map_location=device))
# The evaluation loops move every batch to `device`, so the model must live there too.
model = model.to(device)
model.eval()

# Loss and optimizer are bound to the parameters of the model that is actually used.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def evaluate_accuracy(net, loader, device):
    """Return the top-1 accuracy of `net` on `loader`, as a percentage.

    Each (inputs, labels) batch is moved to `device` before the forward pass.
    Gradients are disabled. The model's train/eval mode is not changed,
    so call `net.eval()` beforehand.

    Raises:
        ValueError: if `loader` yields no samples (accuracy would be undefined).
    """
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Predicted class = index of the highest score per row.
            predicted = net(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate_accuracy: loader yielded no samples")
    return 100 * correct / total


# Evaluation on test set
accuracy = evaluate_accuracy(model, testloader, device)
print(f'Test Accuracy: {accuracy:.2f}%')

Test Accuracy: 98.90%


In [35]:
# Restore the trained weights first so this ablation starts from the original model
# instead of compounding with zeros written by an earlier pruning run.
# map_location keeps this working when the checkpoint was saved on a different device.
model.load_state_dict(torch.load('SimpleCNN.pth', map_location=device))
# Get the state dictionary of the model
state_dict = model.state_dict()

# conv1 output-filter indices to ablate. Another set tried was [14, 0, 1, 10].
filters_to_discard = [12, 2, 5, 8]

# Reject out-of-range indices; negative ones would otherwise silently wrap around
# and zero a different filter than intended.
n_conv1_filters = state_dict['conv1.weight'].size(0)
invalid_filters = [i for i in filters_to_discard if not 0 <= i < n_conv1_filters]
if invalid_filters:
    raise ValueError(
        f"conv1 has {n_conv1_filters} filters; invalid indices: {invalid_filters}"
    )

# Zero each selected filter's kernel and bias so that channel's conv1 output is all zeros.
state_dict['conv1.weight'][filters_to_discard, :, :, :] = 0
state_dict['conv1.bias'][filters_to_discard] = 0

# Write the edited tensors back into the first convolutional layer.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [36]:
# Top-1 accuracy of the current (pruned) model on the test set.
# Each entry of per_batch is (number of correct predictions, batch size).
with torch.no_grad():
    per_batch = [
        (
            (model(inputs.to(device)).max(dim=1).indices == labels.to(device)).sum().item(),
            labels.size(0),
        )
        for inputs, labels in testloader
    ]
correct = sum(hits for hits, _ in per_batch)
total = sum(count for _, count in per_batch)

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

Test Accuracy: 97.87%


In [37]:
# Restore the trained weights first so this ablation starts from the original model
# instead of compounding with zeros written by an earlier pruning run.
# map_location keeps this working when the checkpoint was saved on a different device.
model.load_state_dict(torch.load('SimpleCNN.pth', map_location=device))
# Get the state dictionary of the model
state_dict = model.state_dict()

# conv1 output-filter indices to ablate. Another set tried was [12, 2, 5, 8].
filters_to_discard = [14, 0, 1, 10]

# Reject out-of-range indices; negative ones would otherwise silently wrap around
# and zero a different filter than intended.
n_conv1_filters = state_dict['conv1.weight'].size(0)
invalid_filters = [i for i in filters_to_discard if not 0 <= i < n_conv1_filters]
if invalid_filters:
    raise ValueError(
        f"conv1 has {n_conv1_filters} filters; invalid indices: {invalid_filters}"
    )

# Zero each selected filter's kernel and bias so that channel's conv1 output is all zeros.
state_dict['conv1.weight'][filters_to_discard, :, :, :] = 0
state_dict['conv1.bias'][filters_to_discard] = 0

# Write the edited tensors back into the first convolutional layer.
model.load_state_dict(state_dict)

# Top-1 accuracy of the current (pruned) model on the test set.
# Each entry of per_batch is (number of correct predictions, batch size).
with torch.no_grad():
    per_batch = [
        (
            (model(inputs.to(device)).max(dim=1).indices == labels.to(device)).sum().item(),
            labels.size(0),
        )
        for inputs, labels in testloader
    ]
correct = sum(hits for hits, _ in per_batch)
total = sum(count for _, count in per_batch)

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

Test Accuracy: 85.92%
