# Basic example

In [4]:
import torch
import torchvision.models as models
from unip.core.pruner import BasePruner
from unip.utils.evaluation import cal_flops

# Walk-through: measure a model's cost, prune it, then measure it again.

# Build a torchvision ResNet-18 using the constructor's default arguments.
model = models.resnet18()

# NOTE: the example input must be created with `requires_grad=True` so that
#       it is connected to the graph (presumably the autograd graph that the
#       pruner inspects -- TODO confirm against the unip documentation).
# The tensor is random data of shape (1, 3, 224, 224).
example_input = torch.rand(1, 3, 224, 224, requires_grad=True)

# Calculate the FLOPs and parameter count of the original (unpruned) model.
# `cal_flops` returns a pair, unpacked here as (flops, params); judging by the
# recorded cell output it also prints that pair itself -- TODO confirm.
print("original model:")
flops_ori, params_ori = cal_flops(model, example_input)

# Define a pruner for `model`, driven by the same example input.
# "UniformRatio" selects the pruning algorithm by name, and `algo_args`
# carries its options; here the scoring function is "weight_sum_l1_out".
# NOTE(review): the exact meaning of both strings is defined by unip, not
# visible here -- verify against the library before changing them.
pruner = BasePruner(
    model,
    example_input,
    "UniformRatio",
    algo_args={"score_fn": "weight_sum_l1_out"},
)

# Prune the model. The argument 0.3 is presumably the target pruning ratio
# (TODO confirm). The same `model` object is re-measured below, so the
# pruning is presumably applied to it in place.
pruner.prune(0.3)

# Calculate the FLOPs and parameter count again, now for the pruned model,
# so the two printed results can be compared.
print("pruned model:")
flops_pruned, params_pruned = cal_flops(model, example_input)


original model:
('1.824G', '11.690M')
pruned model:
('927.553M', '5.868M')
