Import necessary libraries

In [1]:
import torch
from torch import nn
from torch.optim import *
from torchvision.datasets import *
from torchvision.transforms import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


  from .autonotebook import tqdm as notebook_tqdm


Define the VGG16 model

In [43]:
class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 10),
        )
        self._initialize_weights()
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = VGG16().cuda()  # Move model to GPU


 Define the pruning function

In [44]:
def prune_by_magnitude(weight, sparsity):
    weight = weight.clone()
    threshold = torch.quantile(weight.abs(), sparsity)
    mask = weight.abs() >= threshold
    weight *= mask
    return weight


Define the weight distribution plotting function

In [45]:
def plot_weight_distribution(weight):
    plt.hist(weight.detach().cpu().numpy().flatten(), bins=100)
    plt.show()


Define the sparsity and model size calculation functions

In [46]:
def get_sparsity(tensor: torch.Tensor) -> float:
    """
    calculate the sparsity of the given tensor
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    return 1 - float(tensor.count_nonzero()) / tensor.numel()

def get_model_sparsity(model: nn.Module) -> float:
    """
    calculate the sparsity of the given model
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_elements += param.numel()
    return 1 - float(num_nonzeros) / num_elements

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements

def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width


Load the CIFAR-10 dataset

In [47]:
# Data augmentation
transforms = {
    "train": Compose([
        RandomCrop(32, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    "test": Compose([
        ToTensor(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
}

dataset = {}
for split in ["train", "test"]:
  dataset[split] = CIFAR10(
    root="data/cifar10",
    train=(split == "train"),
    download=True,
    transform=transforms[split],
  )
dataloader = {}
for split in ['train', 'test']:
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=256,
    shuffle=(split == 'train'),
    num_workers=0,
    pin_memory=True,
  )


Files already downloaded and verified
Files already downloaded and verified


Train the model

In [48]:
from torch.utils.data import random_split

# Split the training dataset into training and validation sets
train_dataset = dataset['train']
train_size = int(0.8 * len(train_dataset))  # 80% for training
val_size = len(train_dataset) - train_size  # 20% for validation
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create dataloaders for the training and validation sets
batch_size = 512
dataloader = {
    'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True),
    'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True),
    'test': dataloader['test']  # Use the existing test dataloader
}

def calculate_validation_loss(model, dataloader, criterion):
    model.eval()  # Set the model to evaluation mode
    running_loss = 0.0
    for inputs, targets in dataloader:
        inputs = inputs.cuda()
        targets = targets.cuda()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(dataloader.dataset)

def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
) -> None:
  model.train()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Reset the gradients (from the last iteration)
    optimizer.zero_grad()

    # Forward inference
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward propagation
    loss.backward()

    # Update optimizer
    optimizer.step()

criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.0005)


# Define your learning rate scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

# Early stopping parameters
min_val_loss = float('inf')
patience = 20
patience_counter = 0

# Training loop with early stopping
num_epochs = 30
for epoch in range(num_epochs):
    train(model, dataloader['train'], criterion, optimizer)
    
    # Calculate validation loss here
    val_loss = calculate_validation_loss(model, dataloader['val'], criterion)
    
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

    # Check if the validation loss improved
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        patience_counter = 0
    else:
        patience_counter += 1

    # If the validation loss didn't improve for 'patience' epochs, stop training
    if patience_counter >= patience:
        print('Early stopping')
        break

    scheduler.step(val_loss)

train:   0%|          | 0/79 [00:00<?, ?it/s]

                                                      

Epoch 1, Validation Loss: 2.3025472324371337


                                                      

Epoch 2, Validation Loss: 2.3023638061523437


                                                      

Epoch 3, Validation Loss: 2.3022195819854736


                                                      

Epoch 4, Validation Loss: 2.3020576244354247


                                                      

Epoch 5, Validation Loss: 2.3018782012939454


                                                      

Epoch 6, Validation Loss: 2.3017520721435547


                                                      

Epoch 7, Validation Loss: 2.3016021350860596


                                                      

Epoch 8, Validation Loss: 2.3014849369049073


                                                      

Epoch 9, Validation Loss: 2.3013191066741943


                                                      

Epoch 10, Validation Loss: 2.301159693145752


                                                      

Epoch 11, Validation Loss: 2.301060609817505


                                                      

Epoch 12, Validation Loss: 2.300906921005249


                                                      

Epoch 13, Validation Loss: 2.300792721557617


                                                      

Epoch 14, Validation Loss: 2.3006124473571776


                                                      

Epoch 15, Validation Loss: 2.300486852645874


                                                      

Epoch 16, Validation Loss: 2.300344374847412


                                                      

Epoch 17, Validation Loss: 2.300208726501465


                                                      

Epoch 18, Validation Loss: 2.300008084869385


                                                      

Epoch 19, Validation Loss: 2.2998540279388426


                                                      

Epoch 20, Validation Loss: 2.2997462493896483


                                                      

Epoch 21, Validation Loss: 2.299599536514282


                                                      

Epoch 22, Validation Loss: 2.299373710632324


                                                      

Epoch 23, Validation Loss: 2.2991888378143313


                                                      

Epoch 24, Validation Loss: 2.299067244720459


                                                      

Epoch 25, Validation Loss: 2.2988147640228274


                                                      

Epoch 26, Validation Loss: 2.2987237995147707


                                                      

Epoch 27, Validation Loss: 2.298449454498291


                                                      

Epoch 28, Validation Loss: 2.2982151229858396


                                                      

Epoch 29, Validation Loss: 2.2979885540008547


                                                      

Epoch 30, Validation Loss: 2.2977690177917482


Test the model

In [49]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
) -> float:
  model.eval()

  num_samples = 0
  num_correct = 0

  for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                              disable=not verbose):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Inference
    outputs = model(inputs)

    # Convert logits to class indices
    outputs = outputs.argmax(dim=1)

    # Update metrics
    num_samples += targets.size(0)
    num_correct += (outputs == targets).sum()

  return (num_correct / num_samples * 100).item()

accuracy = evaluate(model, dataloader['test'])
print(f'Accuracy of the network on the 10000 test images: {accuracy}%')


eval:   0%|          | 0/40 [00:00<?, ?it/s]

                                                     

Accuracy of the network on the 10000 test images: 9.999999046325684%




Prune the model

In [12]:
def prune_by_magnitude(weight, sparsity):
    weight = weight.clone()
    # Sort the tensor
    sorted_weights = torch.sort(weight.abs().flatten())[0]
    # Compute the index of the threshold
    index = int(sparsity * sorted_weights.numel())
    # Select the value at the computed index
    threshold = sorted_weights[index]
    mask = weight.abs() >= threshold
    weight *= mask
    return weight

for name, param in model.named_parameters():
    if 'weight' in name:
        param.data = prune_by_magnitude(param.data, 0.2)  # Prune 20% of the weights


Fine-tune the pruned model

In [13]:
for epoch in range(10):  # loop over the dataset multiple times
    train(model, dataloader['train'], criterion, optimizer)
    print(f'Finished Fine-Tuning Epoch {epoch+1}')


                                                      

Finished Fine-Tuning Epoch 1


                                                      

Finished Fine-Tuning Epoch 2


                                                      

Finished Fine-Tuning Epoch 3


train:   8%|▊         | 8/98 [01:34<18:06, 12.07s/it]

Test the pruned model

In [None]:
accuracy = evaluate(model, dataloader['test'])
print(f'Accuracy of the pruned network on the 10000 test images: {accuracy}%')


Plot the weight distribution

In [None]:
for name, param in model.named_parameters():
    if 'weight' in name:
        plot_weight_distribution(param.data)


Calculate and print the sparsity and size of the model

In [None]:
sparsity = get_model_sparsity(model)
size = get_model_size(model)
print(f'Sparsity of the model: {sparsity}')
print(f'Size of the model: {size} bits')
