In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader
from utils import MNIST_partial, accuracy, plot_training_metrics


In [2]:

# Hyperparameters
learning_rate = 1e-3
num_epochs = 50
batch_size = 128

# Fixed seed so DataLoader shuffling and weight initialisation are reproducible
# under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Dataset from the challenge CSV files. (A full torchvision `datasets.MNIST`
# pipeline could be swapped in here instead.)
train_dataset = MNIST_partial(split='train')
val_dataset = MNIST_partial(split='val')

# DataLoaders that feed the CNN in mini-batches of `batch_size`.
# NOTE(review): batch size 1 is only required by the boson-sampler pipeline;
# this classical CNN is trained with mini-batches.
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size, shuffle=False)

# Define the CNN model
class CNNModel(nn.Module):
    """Small two-conv / two-linear classifier producing 10 logits.

    The layer sizes fix the input to 1x28x28:
    28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, so ``fc1`` receives
    12 * 4 * 4 = 192 features.

    Args:
        activation: optional module applied after each pooled conv block and
            after ``fc1``. The same instance is reused at every position.
            The default ``None`` uses ``nn.Identity``, which reproduces the
            activation-free architecture exactly. Without a non-linearity,
            ``fc2(fc1(x))`` collapses to a single linear map of rank <= 20.
            Pass e.g. ``nn.ReLU()`` to make the hidden layer effective.
    """

    def __init__(self, activation=None):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(8, 12, kernel_size=5)
        self.fc1 = nn.Linear(12 * 4 * 4, 20)
        self.fc2 = nn.Linear(20, 10)
        # Registered last and parameter-free by default, so parameter
        # creation order (and thus seeded initialisation) is unchanged.
        self.act = activation if activation is not None else nn.Identity()

    def forward(self, x):
        x = self.act(self.pool(self.conv1(x)))
        x = self.act(self.pool(self.conv2(x)))
        x = torch.flatten(x, 1)  # (batch, 192); also safe for non-contiguous tensors
        x = self.act(self.fc1(x))
        return self.fc2(x)


In [3]:

# Loss over the 10 logits from CNNModel.fc2; the optimizer is created after pruning.
criterion = nn.CrossEntropyLoss()
model = CNNModel()


In [4]:

# Apply structured pruning to the convolutional and fully connected layers
def apply_pruning(module, amount=0.3, linear_dim=1):
    """Apply L2-norm structured pruning to every Conv2d and Linear layer in ``module``.

    Uses ``prune.ln_structured`` on each layer's ``weight``. Biases are not
    pruned. Call ``prune.remove(layer, 'weight')`` afterwards to make the
    zeros permanent.

    Args:
        module: model whose sub-modules are pruned in place.
        amount: a float in [0, 1] is the fraction of slices to prune. An int
            is an absolute number of slices (``torch.nn.utils.prune``
            semantics). The log message reflects which form was given.
        linear_dim: dimension of ``nn.Linear.weight`` (shape ``(out, in)``)
            along which whole slices are zeroed. The default 1 zeroes input
            columns; 0 would zero output rows.
    """
    import numbers

    def _describe(unit):
        # ``prune`` treats integers as an absolute count, floats as a fraction.
        if isinstance(amount, numbers.Integral):
            return f"{amount} {unit}"
        return f"{amount * 100:.1f}% {unit}"

    for name, layer in module.named_modules():
        if isinstance(layer, nn.Conv2d):
            # dim=0 of a Conv2d weight indexes whole output filters.
            prune.ln_structured(layer, name='weight', amount=amount, n=2, dim=0)
            print(f"Applied structured pruning to Conv2d layer {name} with {_describe('filters')} pruned.")
        elif isinstance(layer, nn.Linear):
            # With the default linear_dim=1 this zeroes whole input columns
            # (all connections from one input unit), not output rows.
            prune.ln_structured(layer, name='weight', amount=amount, n=2, dim=linear_dim)
            print(f"Applied structured pruning to Linear layer {name} with {_describe('neurons')} pruned.")

# Check how many non-zero weights remain after pruning
def print_nonzero_weights(model):
    """Print the count and percentage of non-zero weight/bias entries in ``model``.

    Parameters whose qualified name contains ``'weight'`` or ``'bias'`` are
    counted.

    While a ``torch.nn.utils.prune`` reparametrization is active (before
    ``prune.remove``), the registered parameter is ``<name>_orig``, which
    holds the unmasked values. The effective, masked tensor is the module
    attribute ``<name>``, and that is what gets counted. As a result, the
    reported sparsity is correct both before and after ``prune.remove``.
    """
    total_params, nonzero_params = 0, 0
    seen = set()  # count a parameter shared by several modules only once
    for prefix, module in model.named_modules():
        for local_name, param in module.named_parameters(recurse=False):
            full_name = f"{prefix}.{local_name}" if prefix else local_name
            if id(param) in seen or ('weight' not in full_name and 'bias' not in full_name):
                continue
            seen.add(id(param))
            effective = param
            if local_name.endswith('_orig'):
                masked = getattr(module, local_name[:-len('_orig')], None)
                if isinstance(masked, torch.Tensor):
                    effective = masked
            total_params += effective.numel()
            nonzero_params += int(torch.count_nonzero(effective))
    if total_params == 0:
        print("Non-zero parameters: 0/0 (no weight/bias parameters found)")
        return
    print(f"Non-zero parameters: {nonzero_params}/{total_params} ({100 * nonzero_params / total_params:.2f}%)")




In [5]:
apply_pruning(model, amount=0.5)

# Make the pruning permanent: fold each weight_mask into its weight and drop
# the weight_orig / weight_mask reparametrization.
for layer in model.modules():
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        prune.remove(layer, 'weight')

print_nonzero_weights(model)

# Total trainable parameter count (zeroed entries still count).
trainable_param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable_param_count)

Applied structured pruning to Conv2d layer conv1 with 50.0% filters pruned.
Applied structured pruning to Conv2d layer conv2 with 50.0% filters pruned.
Applied structured pruning to Linear layer fc1 with 50.0% neurons pruned.
Applied structured pruning to Linear layer fc2 with 50.0% neurons pruned.
Non-zero parameters: 3370/6690 (50.37%)
6690


In [6]:
# for name, param in model.named_parameters():
#     print(param)

In [7]:
def get_non_pruned_params(model):
    """Return the parameters of ``model`` that are not pruning reparametrizations.

    Any parameter whose qualified name contains ``'weight_orig'`` is
    excluded; that is the unmasked copy ``torch.nn.utils.prune`` registers
    while pruning is active. This drops the *whole* tensor from the result,
    not just its pruned entries. All other parameters are returned in
    ``named_parameters()`` order. After ``prune.remove`` there are no
    ``weight_orig`` parameters, so every parameter is returned.
    """
    return [param for name, param in model.named_parameters() if 'weight_orig' not in name]

class MaskedAdam(torch.optim.Adam):
    """Adam whose gradients are zeroed wherever the parameter value is exactly 0.

    Before every update, each parameter's gradient is multiplied by the
    boolean mask ``param != 0``. Weights zeroed by pruning (and made
    permanent with ``prune.remove``) therefore receive no gradient.
    Constructor arguments are forwarded unchanged to ``torch.optim.Adam``.

    Caveat: the mask is recomputed from the current values on every step, so
    a weight that lands exactly on 0 during training is masked from then on.
    A masked entry whose Adam moment estimates are already non-zero may
    still move.
    """

    def step(self, closure=None):
        """Mask gradients of zero-valued entries, then take an Adam step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss (standard ``torch.optim`` contract). It is
                evaluated *before* masking, so the gradients it recomputes
                are masked as well.

        Returns:
            The loss returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        with torch.no_grad():
            for group in self.param_groups:
                for param in group['params']:
                    if param.grad is None:
                        continue
                    # Boolean mask (True where the weight is non-zero) zeroes
                    # the gradient of every currently-zero entry.
                    param.grad.mul_(param != 0)

        # Gradients are already final, so the closure must not be re-run here.
        super().step()
        return loss

# Adam with gradients masked wherever a weight is currently zero (keeps pruned entries at 0).
optimizer = MaskedAdam(params=model.parameters(), lr=learning_rate)


In [8]:
# Prefer the first CUDA GPU when present; the last line displays the model repr.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
model.to(device)


CNNModel(
  (conv1): Conv2d(1, 8, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(8, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=192, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=10, bias=True)
)

In [9]:

# Training loop: one optimizer step per mini-batch, logging the loss every 10 steps.
steps_per_epoch = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    for step, (images, labels) in enumerate(train_loader, start=1):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if step % 10 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step}/{steps_per_epoch}], Loss: {loss.item():.4f}")


Epoch [1/50], Step [10/47], Loss: 2.2917
Epoch [1/50], Step [20/47], Loss: 2.2718
Epoch [1/50], Step [30/47], Loss: 2.2301
Epoch [1/50], Step [40/47], Loss: 2.1596
Epoch [2/50], Step [10/47], Loss: 1.8770
Epoch [2/50], Step [20/47], Loss: 1.9020
Epoch [2/50], Step [30/47], Loss: 1.4417
Epoch [2/50], Step [40/47], Loss: 1.4444
Epoch [3/50], Step [10/47], Loss: 1.0035
Epoch [3/50], Step [20/47], Loss: 0.9131
Epoch [3/50], Step [30/47], Loss: 0.8264
Epoch [3/50], Step [40/47], Loss: 0.6834
Epoch [4/50], Step [10/47], Loss: 0.6987
Epoch [4/50], Step [20/47], Loss: 0.6967
Epoch [4/50], Step [30/47], Loss: 0.5329
Epoch [4/50], Step [40/47], Loss: 0.6754
Epoch [5/50], Step [10/47], Loss: 0.4796
Epoch [5/50], Step [20/47], Loss: 0.4219
Epoch [5/50], Step [30/47], Loss: 0.3760
Epoch [5/50], Step [40/47], Loss: 0.4218
Epoch [6/50], Step [10/47], Loss: 0.4758
Epoch [6/50], Step [20/47], Loss: 0.3900
Epoch [6/50], Step [30/47], Loss: 0.5188
Epoch [6/50], Step [40/47], Loss: 0.3403
Epoch [7/50], St

In [10]:

# Evaluate the pruned model
def evaluate(model, test_loader, device=None):
    """Print and return the top-1 accuracy (in %) of ``model`` on ``test_loader``.

    Args:
        model: classifier whose output holds class scores along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to. It defaults to the device of
            the model's first parameter, or CPU if the model has none, so no
            notebook-global ``device`` is required.

    Returns:
        The accuracy as a percentage, or ``None`` if ``test_loader`` yielded
        no samples.

    Side effect: leaves ``model`` in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        print('Test Accuracy: n/a (empty loader)')
        return None
    acc_pct = 100 * correct / total
    print(f'Test Accuracy: {acc_pct:.2f}%')
    return acc_pct

# Accuracy of the pruned, fine-tuned model on the validation split
# (trailing ';' keeps Jupyter from echoing a return value).
evaluate(model, test_loader=val_loader);

Test Accuracy: 93.17%


In [11]:
# Re-check sparsity after training; MaskedAdam masks gradients of zero-valued weights.
print_nonzero_weights(model=model)


Non-zero parameters: 3370/6690 (50.37%)


In [12]:
# Results on MNIST partial.
# Budgets 0.1, 0.2, 0.3, 0.4, 0.5 — presumably the fraction of parameters kept
# (3370 is ~50% of 6690, matching amount=0.5 above). TODO confirm.
#
# Weight-sharing reference (same budgets; 3 rows, presumably repeated runs):
# weight_sharing_parameter_num_list = [675, 1300, 2120, 2745, 3370]
# weight_sharing_test_acc_list = [
#     [26.61, 84.55, 91.08, 95.25, 96.65],
#     [50.07, 79.97, 91.15, 95.29, 96.51],
#     [32.13, 67.58, 90.97, 94.20, 96.93],
# ]

pruning_parameter_num_list = [675, 1300, 2120, 2745, 3370]

# One row per run, one column per budget above; missing entries are still pending.
pruning_test_acc_list = [
    [29.00],
    [21.50],
    [],
]
