In [11]:
# Upgrade pip
!pip install --upgrade pip
# Install PyTorch with MPS support
!pip install torch torchvision
# Install thops for metric analysis
!pip install thop torch_pruning

[0m

In [12]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import os
import torch_pruning as tp  # Torch-Pruning library
import thop  # For FLOPs and parameter counting
import utils.dependencies as utils  # Ensure this module exists and contains necessary functions
import utils.metrics as metri

In [13]:
# Base model configuration
BASE_MODEL_NAME = 'efficientnet_b0'  # Change this to switch models (e.g., 'resnet18', 'vgg16')
NUM_CLASSES = 10  # Adjust based on your dataset (e.g., 10 for CIFAR10)

# Pruning configuration
PRUNING_METHOD = 'channel pruning'  # Options: 'l1_unstructured', 'random_unstructured', etc.
PRUNING_RATIO = 0.2  # Percentage of weights to prune (e.g., 0.2 for 20%)

# Paths
STATE_DICT_PATH = './models/finetuned_base_model.pth'  # Path to the saved state dictionary
SAVE_DIR = './pruned_models/'  # Directory to save pruned models and metrics

# Ensure SAVE_DIR exists
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)


In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = models.efficientnet_b0()
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, 10)
model = model.to(device)

# Load the state dictionary
state_dict = torch.load(STATE_DICT_PATH, map_location=device)
# Load state dict into the model
model.load_state_dict(state_dict, strict=False)


Using device: cuda


<All keys matched successfully>

In [None]:
def calculate_sparsity(model):
    """
    Calculates the sparsity (percentage of zero weights) in the model's Conv2d and Linear layers.
    """
    total_weights = 0
    zero_weights = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            weight = module.weight.data.cpu().numpy()
            total_weights += weight.size
            zero_weights += (weight == 0).sum()
    sparsity = 100.0 * zero_weights / total_weights if total_weights > 0 else 0
    return sparsity

def apply_channel_pruning(model, pruning_ratio=0.2):
    """
    Applies channel-level pruning to the given model using L1-norm based importance.

    Args:
        model (nn.Module): The PyTorch model to prune.
        pruning_ratio (float): The fraction of channels to prune globally (e.g., 0.2 for 20%).

    Returns:
        nn.Module: The pruned model.
    """
    # Set model to evaluation mode
    model.eval()  
    # Create a dummy input tensor matching the input shape expected by the model
    dummy_input = torch.randn(1, 3, 224, 224).to(next(model.parameters()).device)  
    # Define the importance metric: L1 norm
    importance = tp.importance.MagnitudeImportance(p=1) 
    # Identify layers to ignore during pruning (e.g., final classifier layer)
    ignored_layers = []
    for m in model.modules():
        # Avoid the final classifier layer
        if isinstance(m, nn.Linear) and m.out_features == 10:
            ignored_layers.append(m)  
    # Initialize the MetaPruner with the desired settings
    pruner = tp.pruner.MetaPruner(
        model=model,
        example_inputs=dummy_input,
        importance=importance,
        pruning_ratio=pruning_ratio,  
        ignored_layers=ignored_layers,
        #For hardware acceleration
        round_to=8,  
    )
    # Perform pruning
    pruner.step()
    print(f"Applied {pruning_ratio * 100:.1f}% channel-level pruning using L1 norm.")
    return model

# Example usage:
initial_sparsity = calculate_sparsity(model)
print(f"Initial Sparsity: {initial_sparsity:.2f}%")

model = apply_channel_pruning(model, pruning_ratio=PRUNING_RATIO)

post_pruning_sparsity = calculate_sparsity(model)
print(f"Sparsity after channel pruning: {post_pruning_sparsity:.2f}%")


Initial Sparsity: 0.00%
Applied 36.0% channel-level pruning using L1 norm.
Sparsity after channel pruning: 0.00%


In [17]:
# Define the path to save the pruned model
PRUNED_STATE_DICT_PATH = os.path.join(SAVE_DIR, 'channel_pruned_model.pth')

# Save the pruned model's state dictionary
torch.save(model.state_dict(), PRUNED_STATE_DICT_PATH)
print(f"Pruned model saved at {PRUNED_STATE_DICT_PATH}.")


Pruned model saved at ./pruned_models/channel_pruned_model.pth.


In [18]:
# Freeze all layers except the classifier
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False
    
# Define the layers to unfreeze (last two blocks)
layers_to_unfreeze = ['features.5', 'features.6', 'features.7']

# Unfreeze the specified layers
utils.unfreeze_layers(model, layers_to_unfreeze)

# Define optimizer to include only trainable parameters
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.09,  
    momentum=0.9,
    weight_decay=5e-4
)

# Define a learning rate scheduler for fine-tuning
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

In [19]:
# Training loop
num_epochs = 20
best_val_acc = 0.0
save_path = f'./models/finegrained_prune_ratio_{PRUNING_RATIO}'

# Initialize lists to store metrics
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# Get the data loaders
train_loader, val_loader, test_loader = utils.get_data_loaders()

for epoch in range(1, num_epochs + 1):
    print(f"--- Epoch {epoch} ---")    
    # Train
    train_loss, train_acc = utils.train_epoch(model, device, train_loader, optimizer)  
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    # Validate
    val_loss, val_acc = utils.validate_epoch(model, device, val_loader)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    # Step the scheduler
    scheduler.step()    
    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print(f"Best model saved with Val Acc: {best_val_acc:.2f}%\n")
    else:
        print("No improvement this epoch.\n")

Files already downloaded and verified
Files already downloaded and verified
--- Epoch 1 ---
Train Loss: 0.9023 | Train Acc: 68.58% | Time: 44.40s
Val Loss: 0.7390 | Val Acc: 73.30%
Best model saved with Val Acc: 73.30%

--- Epoch 2 ---
Train Loss: 0.6358 | Train Acc: 77.92% | Time: 44.72s
Val Loss: 0.5865 | Val Acc: 79.50%
Best model saved with Val Acc: 79.50%

--- Epoch 3 ---
Train Loss: 0.5619 | Train Acc: 80.55% | Time: 44.38s
Val Loss: 0.6582 | Val Acc: 77.68%
No improvement this epoch.

--- Epoch 4 ---
Train Loss: 0.5363 | Train Acc: 81.29% | Time: 44.19s
Val Loss: 0.5521 | Val Acc: 80.44%
Best model saved with Val Acc: 80.44%

--- Epoch 5 ---
Train Loss: 0.5196 | Train Acc: 82.05% | Time: 44.61s
Val Loss: 0.5532 | Val Acc: 81.30%
Best model saved with Val Acc: 81.30%

--- Epoch 6 ---
Train Loss: 0.5174 | Train Acc: 82.14% | Time: 44.71s
Val Loss: 0.6249 | Val Acc: 78.86%
No improvement this epoch.

--- Epoch 7 ---
Train Loss: 0.3564 | Train Acc: 87.69% | Time: 45.18s
Val Loss: 0.

In [20]:
import utils.metrics as metrics

# Generate and save metrics plots and table
metrics.generate_and_save_metrics(
    model=model, 
    device=device, 
    test_loader=test_loader, 
    criterion=utils.criterion, 
    model_name=f'channel_pruning_{PRUNING_RATIO}', 
    pruning_ratio=PRUNING_RATIO, 
    description= f'Model with channel pruning',
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    val_losses=val_losses,
    val_accuracies=val_accuracies,
    save_dir='metrics_plots'
)

--- Generating Metrics for channel_pruning_0.36 ---
Calculating Model Size...
Measuring Inference Time...
Computing FLOPs...
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
Measuring Memory Usage...
Evaluating Model on Test Set...
Test Loss: 0.3720 | Test Accuracy: 87.65%
Generating Metrics Table...
Metrics table saved in 'metrics_plots' as 'channel_pruning_0.36_metrics_table.png'.
Generating Training and Validation Metrics Plots...
Training and validation metrics plots saved in 'metrics_plots' as 'channel_pruning_0.36_training_validation_plots.png