In [1]:
import torch
from torch import nn
from src.utils import load_model, save_model
from src.data_loader import get_cifar10_loader
from src.train import train_model
from src.model import ResNet, BasicBlock, resnet110
from src.evaluate import evaluate
from src.utils import count_total_parameters
import torch
import torch_pruning as tp
from src.utils import iterative_pruner



In [2]:
# Parameters
device = torch.device("mps")
model_path = "models/resnet110_baseline_30_mps.pth"

batch_size = 128
learning_rate = 0.001
num_epochs = 1

ch_sparsity = 0.95 # 34% of all channels are supposed to be pruned; note, that all connected layers also get pruned
iterative_pruning_steps = 1

In [3]:
# Load pretrained model
#teacher_model = load_model(model_path, device=device)
model = load_model(model_path, device)# Until I pretrained the model aha
pruned_model = model

In [4]:
# Define optimizer and criterion for training
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Load data
train_loader = get_cifar10_loader('train', batch_size=batch_size)
val_loader = get_cifar10_loader('val', batch_size=batch_size)

In [5]:
count_total_parameters(model)
#evaluate(model, val_loader, device)


In [6]:
pruned_model.to("cpu")
example_inputs = torch.randn(1, 3, 32, 32)
imp = tp.importance.TaylorImportance()

pruner = tp.pruner.MagnitudePruner(
    pruned_model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_pruning_steps,
    ch_sparsity=ch_sparsity,
)

iterative_pruner(pruner, iterative_pruning_steps)




In [8]:
train_model(pruned_model, train_loader, optimizer, criterion, device, num_epochs=10)

                                                                                       

In [10]:
evaluate(pruned_model, val_loader, device)

Validation Accuracy: 9.29%, Avg Loss: 2.7144, Time: 3.70s


(9.29, 2.714369675064087)

In [7]:
count_total_parameters(pruned_model)

Total number of parameters in the model: 88236


88236

In [10]:
from src.utils import save_model
save_model(pruned_model, "models/pruned_95-10_resnet110_mps.pth")