In [1]:
import torch
from torch import nn
from src.utils import load_model, save_model
from src.data_loader import get_cifar10_loader
from src.train import train_model
from src.model import ResNet, BasicBlock, resnet110
from src.evaluate import evaluate, count_total_parameters
import torch_pruning as tp
from src.utils import iterative_pruner


In [2]:
# Parameters
device = torch.device("mps")
model_path = "models/resnet110_baseline_120_mps.pth"

batch_size = 128
learning_rate = 0.001
num_epochs = 30

# base parameters: 1730714
# ch_sparsity = 0.15 -> 1228878 ca. 30%
# ch_sparsity = 0.29 -> 848388  ca. 50%
# ch_sparsity = 0.34 -> 7?????  ca. 60%
# ch_sparsity = 0.45 -> 509972  ca. 70%
# ch_sparsity = 0.95 -> 6765    ca. 0.39%

ch_sparsity = 0.45
iterative_pruning_steps = 5

In [3]:
# Load pretrained model
model = load_model(model_path, device)
pruned_model = model

In [4]:
count_total_parameters(model)
#evaluate(model, val_loader, device)


Total number of parameters in the model: 1730714


1730714

In [5]:
pruned_model.to("cpu")
example_inputs = torch.randn(1, 3, 32, 32)
calculate_gradient = True
imp = tp.importance.TaylorImportance()

pruner = tp.pruner.MagnitudePruner(
    pruned_model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_pruning_steps,
    ch_sparsity=ch_sparsity,
)

iterative_pruner(pruner, calculate_gradient, iterative_pruning_steps)




In [6]:
count_total_parameters(pruned_model)

Total number of parameters in the model: 509972


509972

In [7]:
print(pruned_model)


ResNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(8, 8, kernel_

In [8]:
print(model)

ResNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (ff): FloatFunctional(
        (activation_post_process): Identity()
      )
      (conv1): Conv2d(8, 8, kernel_

In [9]:
# Define optimizer and criterion for training
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Load data
train_loader = get_cifar10_loader('train', batch_size=batch_size)
val_loader = get_cifar10_loader('val', batch_size=batch_size)

In [10]:
train_model(pruned_model, train_loader, optimizer, criterion, device, num_epochs=num_epochs)

                                                                                       

In [11]:
evaluate(pruned_model, val_loader, device)

Validation Accuracy: 87.49%, Avg Loss: 0.4480, Time: 3.10s


(87.49, 0.4479996481895447, 3.0999460220336914)

In [12]:

save_model(pruned_model, "pruned_45-30_resnet110_mps.pth")