In [1]:
import torch
from torch import nn
from src.utils import load_model
from src.data_loader import get_cifar10_loader
from src.train import train_model
from src.model import ResNet, BasicBlock, resnet110
from src.evaluate import evaluate
from src.utils import count_total_parameters
import torch
import torch_pruning as tp


In [2]:
# Parameters
device = torch.device("mps")
model_path = "resnet110_pretrained.pth"

batch_size = 128
learning_rate = 0.001
num_epochs = 1

ch_sparsity = 0.34 # 34% of all channels are supposed to be pruned; note, that all connected layers also get pruned
iterative_pruning_steps = 5

In [3]:
# Load pretrained model
#teacher_model = load_model(model_path, device=device)
model = resnet110() # Until I pretrained the model aha
pruned_model = model

In [4]:
# Define optimizer and criterion for training
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Load data
train_loader = get_cifar10_loader('train', batch_size=batch_size)
val_loader = get_cifar10_loader('val', batch_size=batch_size)

In [7]:
count_total_parameters(model)
evaluate(model, val_loader, device)
pruned_model.to("cpu")

Total number of parameters in the model: 1730714
Validation Accuracy: 10.00%, Avg Loss: 1433238891.7248, Time: 4.73s


ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_

In [8]:
# maybe put it in a function, but then you just see "magic" happening like in model distillation

# Importance criteria
example_inputs = torch.randn(1, 3, 32, 32)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in pruned_model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 10:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.MagnitudePruner(
    pruned_model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_pruning_steps,
    ch_sparsity=ch_sparsity,
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)
for i in range(iterative_pruning_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = pruned_model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(pruned_model, example_inputs)

    # Calculate the difference in parameters between the original and pruned model
    #print(f"After pruning step {i + 1}:")
    #print(f"Number of parameters: {nparams}")
    #print(f"Difference in parameters: {base_nparams - nparams}\n")

In [9]:
train_model(pruned_model, train_loader, optimizer, criterion, device, num_epochs=num_epochs)

                                                                                    

In [10]:
evaluate(pruned_model, val_loader, device)

Validation Accuracy: 9.29%, Avg Loss: 2.7144, Time: 3.70s


(9.29, 2.714369675064087)