In [1]:
# Upgrade pip
!pip install --upgrade pip
# Install PyTorch with MPS support
!pip install torch torchvision
# Install thops for metric analysis
!pip install thop

[0m

In [2]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.utils.prune as prune
import torch.optim as optim
import os
import utils.dependencies as utils
import utils.metrics as metrics

In [3]:
# Base model configuration
BASE_MODEL_NAME = 'efficientnet_b0'  # Change this to switch models (e.g., 'resnet18', 'vgg16')
NUM_CLASSES = 10  # Adjust based on your dataset (e.g., 10 for CIFAR10)

# Pruning configuration
PRUNING_METHOD = 'l1_unstructured'  # Options: 'l1_unstructured', 'random_unstructured', etc.
PRUNING_RATIO = 0.8  # Percentage of weights to prune (e.g., 0.2 for 20%)

# Paths
STATE_DICT_PATH = './models/finetuned_base_model.pth'  # Path to the saved state dictionary
SAVE_DIR = './pruned_models/'  # Directory to save pruned models and metrics

# Ensure SAVE_DIR exists
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
model = models.efficientnet_b0()
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, 10)
model = model.to(device)

In [6]:
# Load the state dictionary
state_dict = torch.load(STATE_DICT_PATH, map_location=device)
# Load state dict into the model
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

In [7]:
# Function to apply pruning to all Conv2d and Linear layers
def apply_fine_grained_pruning(model, method='l1_unstructured', amount=0.2):
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            parameters_to_prune.append((module, 'weight'))
    
    # Apply pruning
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured if method == 'l1_unstructured' else prune.RandomUnstructured,
        amount=amount,
    )
    print(f"Applied {method} pruning with amount={amount*100}%.")
    return model

def calculate_sparsity(model):
    total_weights = 0
    zero_weights = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            weight = module.weight.data.cpu().numpy()
            total_weights += weight.size
            zero_weights += (weight == 0).sum()
    sparsity = 100. * zero_weights / total_weights
    return sparsity

# Calculate and print sparsity before pruning
initial_sparsity = calculate_sparsity(model)
print(f"Initial Sparsity: {initial_sparsity:.2f}%")

# Apply pruning
model = apply_fine_grained_pruning(model, method=PRUNING_METHOD, amount=PRUNING_RATIO)


Initial Sparsity: 0.00%
Applied l1_unstructured pruning with amount=80.0%.


In [8]:
def remove_pruning(model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            try:
                prune.remove(module, 'weight')
                print(f"Removed pruning reparameterization from {name}.")
            except ValueError:
                print(f"No pruning to remove for {name}.")
    return model

# Make pruning permanent
model = remove_pruning(model)
# Calculate and print sparsity after pruning
post_pruning_sparsity = calculate_sparsity(model)
print(f"Sparsity after pruning: {post_pruning_sparsity:.2f}%")

Removed pruning reparameterization from features.0.0.
Removed pruning reparameterization from features.1.0.block.0.0.
Removed pruning reparameterization from features.1.0.block.1.fc1.
Removed pruning reparameterization from features.1.0.block.1.fc2.
Removed pruning reparameterization from features.1.0.block.2.0.
Removed pruning reparameterization from features.2.0.block.0.0.
Removed pruning reparameterization from features.2.0.block.1.0.
Removed pruning reparameterization from features.2.0.block.2.fc1.
Removed pruning reparameterization from features.2.0.block.2.fc2.
Removed pruning reparameterization from features.2.0.block.3.0.
Removed pruning reparameterization from features.2.1.block.0.0.
Removed pruning reparameterization from features.2.1.block.1.0.
Removed pruning reparameterization from features.2.1.block.2.fc1.
Removed pruning reparameterization from features.2.1.block.2.fc2.
Removed pruning reparameterization from features.2.1.block.3.0.
Removed pruning reparameterization fro

In [9]:
PRUNED_STATE_DICT_PATH = os.path.join(SAVE_DIR, 'finegrained_prune_0.8.pth')
torch.save(model.state_dict(), PRUNED_STATE_DICT_PATH)
print(f"Pruned model saved at {PRUNED_STATE_DICT_PATH}.")

Pruned model saved at ./pruned_models/finegrained_prune_0.8.pth.


In [10]:
# Freeze all layers except the classifier
for name, param in model.named_parameters():
    if "classifier" not in name:
        param.requires_grad = False
    
# Define the layers to unfreeze (last two blocks)
layers_to_unfreeze = ['features.5', 'features.6', 'features.7']

# Unfreeze the specified layers
utils.unfreeze_layers(model, layers_to_unfreeze)

# Define optimizer to include only trainable parameters
optimizer = optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.01,  
    momentum=0.9,
    weight_decay=5e-4
)

# Define a learning rate scheduler for fine-tuning
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
# Training loop
num_epochs = 10
best_val_acc = 0.0
save_path = f'./models/finegrained_prune_ratio_{PRUNING_RATIO}'

# Initialize lists to store metrics
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

# Get the data loaders
train_loader, val_loader, test_loader = utils.get_data_loaders()

for epoch in range(1, num_epochs + 1):
    print(f"--- Epoch {epoch} ---")    
    # Train
    train_loss, train_acc = utils.train_epoch(model, device, train_loader, optimizer)  
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    # Validate
    val_loss, val_acc = utils.validate_epoch(model, device, val_loader)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    # Step the scheduler
    scheduler.step()    
    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print(f"Best model saved with Val Acc: {best_val_acc:.2f}%\n")
    else:
        print("No improvement this epoch.\n")

Files already downloaded and verified
Files already downloaded and verified
--- Epoch 1 ---
Train Loss: 0.4252 | Train Acc: 85.88% | Time: 63.94s
Val Loss: 0.2360 | Val Acc: 92.02%
Best model saved with Val Acc: 92.02%

--- Epoch 2 ---
Train Loss: 0.2095 | Train Acc: 92.77% | Time: 61.73s
Val Loss: 0.2075 | Val Acc: 93.02%
Best model saved with Val Acc: 93.02%

--- Epoch 3 ---
Train Loss: 0.1447 | Train Acc: 95.07% | Time: 59.48s
Val Loss: 0.2091 | Val Acc: 92.96%
No improvement this epoch.

--- Epoch 4 ---
Train Loss: 0.1083 | Train Acc: 96.16% | Time: 60.43s
Val Loss: 0.1810 | Val Acc: 94.04%
Best model saved with Val Acc: 94.04%

--- Epoch 5 ---
Train Loss: 0.0846 | Train Acc: 97.06% | Time: 58.65s
Val Loss: 0.1867 | Val Acc: 94.06%
Best model saved with Val Acc: 94.06%

--- Epoch 6 ---


In [None]:
# Generate and save metrics plots and table
metrics.generate_and_save_metrics(
    model=model, 
    device=device, 
    test_loader=test_loader, 
    criterion=utils.criterion, 
    model_name=f'fine_grained_prune_{PRUNING_RATIO}', 
    pruning_ratio=PRUNING_RATIO, 
    description= f'Model with fine grained pruning',
    train_losses=train_losses,
    train_accuracies=train_accuracies,
    val_losses=val_losses,
    val_accuracies=val_accuracies,
    save_dir='metrics_plots'
)