In [27]:
# Upgrade pip
!pip install --upgrade pip
# Install PyTorch with MPS support
!pip install torch torchvision
# Install thops for metric analysis
!pip install thop torch_pruning tqdm

[0m

In [28]:
import os
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch_pruning as tp  # Ensure you have Torch-Pruning installed
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils.dependencies as utils
import utils.metrics as metrics

In [29]:
# Base model configuration
BASE_MODEL_NAME = 'efficientnet_b0'  # torchvision model constructor name (e.g., 'resnet18', 'vgg16')
NUM_CLASSES = 10  # Number of output classes (e.g., 10 for CIFAR10)

# Pruning configuration
PRUNING_METHOD = 'Activation-driven pruning'  # Descriptive label for this experiment
PRUNING_RATIO = 0.36  # Fraction of output channels to prune per Conv2d layer (e.g., 0.2 for 20%)

# Paths
STATE_DICT_PATH = './models/finetuned_base_model.pth'  # Path to the saved state dictionary
SAVE_DIR = './pruned_models/'  # Directory to save pruned models and metrics

# Ensure SAVE_DIR exists; exist_ok avoids the race between an existence check
# and the creation call, and makes this cell safe to re-run.
os.makedirs(SAVE_DIR, exist_ok=True)


In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _replace_classifier_head(net, num_classes):
    """Replace the final ``nn.Linear`` of a torchvision classifier with a fresh head.

    Handles models exposing the head as ``fc`` (e.g. ResNet) or as an
    ``nn.Sequential`` ``classifier`` ending in ``nn.Linear`` (e.g. EfficientNet,
    VGG). For EfficientNet-B0 the last classifier entry is ``classifier[1]``, so
    state-dict keys (``classifier.1.*``) are the same as before.

    Args:
        net: A freshly constructed torchvision model.
        num_classes: Output size of the new head.

    Returns:
        ``net``, modified in place.

    Raises:
        ValueError: If no supported head layout is found.
    """
    if isinstance(getattr(net, "fc", None), nn.Linear):
        net.fc = nn.Linear(net.fc.in_features, num_classes)
    elif (isinstance(getattr(net, "classifier", None), nn.Sequential)
          and isinstance(net.classifier[-1], nn.Linear)):
        net.classifier[-1] = nn.Linear(net.classifier[-1].in_features, num_classes)
    else:
        raise ValueError(f"Cannot locate a Linear classifier head on {type(net).__name__}")
    return net


# Build the configured architecture (randomly initialised) with a NUM_CLASSES-way head.
model = _replace_classifier_head(getattr(models, BASE_MODEL_NAME)(), NUM_CLASSES)
model = model.to(device)

# Load the fine-tuned weights. weights_only=True restricts unpickling to tensors and
# plain containers, which is all a state dict needs. Strict loading (the default)
# raises on missing/unexpected keys instead of silently keeping random init.
state_dict = torch.load(STATE_DICT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

Using device: cuda


<All keys matched successfully>

In [32]:
# Training setup: epoch budget, best-validation-accuracy tracker and checkpoint path.
num_epochs = 20
best_val_acc = 0.0
save_path = f'./models/finegrained_prune_ratio_{PRUNING_RATIO}'

# Empty per-epoch metric histories (losses and accuracies for train / validation).
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

# Data loaders provided by the project's utils.dependencies module.
loaders = utils.get_data_loaders()
train_loader, val_loader, test_loader = loaders


Files already downloaded and verified
Files already downloaded and verified


In [33]:
# Activation statistics: layer name -> list of per-batch, per-channel counts.
activation_stats = defaultdict(list)


def get_activation(name):
    """Build a forward hook that records per-channel positive-activation counts.

    For each call, the hook counts how many elements of the layer's output are
    strictly greater than zero, summed over dims 0, 2 and 3 (leaving one count
    per channel), moves that 1-D tensor to CPU and appends it to
    ``activation_stats[name]``.

    Args:
        name: Key under which the counts are stored (the module's qualified name).

    Returns:
        A callable with the ``(module, inputs, output)`` signature used by
        ``nn.Module.register_forward_hook``.
    """
    # Parameter names avoid shadowing the global `model` and the builtin `input`.
    def hook(module, inputs, output):
        # Assumes a 4-D (N, C, H, W) output, as nn.Conv2d yields for batched input.
        binary_activation = (output > 0).sum(dim=(0, 2, 3)).cpu()
        activation_stats[name].append(binary_activation)
    return hook

# Start from empty statistics so re-running this cell does not accumulate
# counts left over from a previous (possibly interrupted) run.
activation_stats.clear()

# Register a forward hook on every Conv2d layer, keyed by its qualified name.
hooks = [
    module.register_forward_hook(get_activation(name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Conv2d)
]

# Run the dataset through the model to collect activation data with a progress bar.
# try/finally detaches the hooks even if the loop is interrupted (e.g. by
# KeyboardInterrupt). Hooks left attached would record activations on every
# later forward pass, and a re-run of this cell would stack a second set.
model.eval()
try:
    with torch.no_grad():
        for inputs, _ in tqdm(train_loader, desc="Profiling activations"):
            model(inputs.to(device, non_blocking=True))
finally:
    for hook in hooks:
        hook.remove()

# Sum the per-batch counts into one per-channel total (NumPy array) per layer.
channel_activation_counts = {
    name: torch.stack(batches).sum(dim=0).numpy()
    for name, batches in activation_stats.items()
}

print("Activation profiling completed.")

# Optional: Print some activation statistics
for name, counts in channel_activation_counts.items():
    print(f"Layer: {name}, Total Activations: {counts.sum()}, Channels: {len(counts)}")

Profiling activations:  88%|████████▊ | 310/352 [00:43<00:05,  7.16it/s]


KeyboardInterrupt: 

In [None]:
# Select, for every profiled Conv2d layer, the output channels to prune.
# PRUNING_RATIO comes from the configuration cell and is intentionally not
# redefined here, so the ratio actually applied matches the one used to build
# `save_path`.

# Dictionary to store indices of channels to prune for each layer
channels_to_prune = {}

for name, counts in channel_activation_counts.items():
    num_channels = counts.shape[0]
    num_prune = int(num_channels * PRUNING_RATIO)
    if num_prune < 1:
        continue  # Skip pruning if not enough channels to prune
    # Least-active channels first; a stable sort resolves ties deterministically
    # (the lower channel index is pruned first).
    prune_indices = np.argsort(counts, kind="stable")[:num_prune]
    channels_to_prune[name] = prune_indices.tolist()

# Build the dependency graph once, after profiling and before any pruning, so that
# layers coupled to a pruned layer are pruned together with it.
# NOTE(review): the 224x224 example input is only used for tracing; confirm it is
# compatible with the resolution the model is actually run at.
dummy_input = torch.randn(1, 3, 224, 224).to(device)
DG = tp.DependencyGraph().build_dependency(model, example_inputs=dummy_input)

# Debug: Verify if all target modules are in the dependency graph.
# The name -> module mapping is built once rather than once per layer.
module_lookup = dict(model.named_modules())
missing_modules = [
    name for name in channels_to_prune
    if module_lookup.get(name) not in DG.module2node
]

if missing_modules:
    print("The following modules are not in the dependency graph and will be skipped:")
    for name in missing_modules:
        print(f" - {name}")
else:
    print("All target modules are present in the dependency graph.")

# Prune each layer's selected output channels through its dependency group.
# Pruning one group also shrinks every layer coupled to it. A later layer whose
# output width has already changed holds stale indices (they were chosen against
# its profiled width), so it is skipped rather than pruned a second time.
module_lookup = dict(model.named_modules())  # pruning edits modules in place, so this stays valid
not_in_graph = set(missing_modules)

for name, prune_idxs in channels_to_prune.items():
    module = module_lookup.get(name)
    if module is None:
        print(f"Module {name} does not exist in the model. Skipping pruning for this layer.")
        continue
    if name in not_in_graph:
        print(f"Module {name} is not in the dependency graph. Skipping pruning for this layer.")
        continue
    if not isinstance(module, nn.Conv2d):
        continue
    if module.out_channels != len(channel_activation_counts[name]):
        print(f"Layer '{name}' was already resized by an earlier coupled prune. Skipping.")
        continue
    try:
        # Prune output channels (plus all coupled channels in dependent layers)
        group = DG.get_pruning_group(module, tp.prune_conv_out_channels, prune_idxs)
        if DG.check_pruning_group(group):
            group.prune()
            print(f"Pruned {len(prune_idxs)} channels from layer '{name}'.")
        else:
            print(f"Pruning group for layer '{name}' is not valid. Skipping.")
    except Exception as e:
        # Best-effort: report and continue so one problematic layer does not abort the pass.
        print(f"Error pruning layer '{name}': {e}")

print("Pruning completed.")