Functionality to add rounding of filters number for pruning #38
Comments
Hi @Serjio42. That sounds great. Please make a new pull request for it.
Hi @VainF. I've made a pull request. Please take a look and check whether it is OK.
Hello, I am a little confused about the speed-up from rounding. Does it mean that a rounded channel number is more device-friendly than a non-rounded one?
Yes. There is nothing strange about that if you look at the well-known industry architectures: all of them strive to use channel counts such as 32, 64, 128, 144, 192, etc.
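A rough arithmetic sketch of what such rounding looks like (hypothetical; the exact rounding rule used in the pull request may differ): snapping a channel count to the nearest multiple of a base value.

```python
# Hypothetical sketch: snap a channel count to the nearest multiple of `base`.
# The actual rounding rule used by the pull request may differ.
def snap_channels(channels, base=8):
    return max(base, round(channels / base) * base)

print(snap_channels(77))   # 80
print(snap_channels(107))  # 104
```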
Thanks! I will try your code and conduct some experiments to verify the benefits of rounding.
Hi @Serjio42, I have tried the strategy with rounding. Here is the inference time of MobileNetV2 with [16 x 3 x 32 x 32] inputs.

GPU:
- before pruning: inference time=0.014485 s, parameters=3504872
- w/o rounding: inference time=0.007839 s, parameters=1969470
- w/ rounding: inference time=0.008662 s, parameters=1967864

CPU:
- before pruning: inference time=0.267591 s, parameters=3504872
- w/o rounding: inference time=0.154733 s, parameters=1969470
- w/ rounding: inference time=0.149170 s, parameters=1967864

It seems that whether rounding improves the inference time depends on the hardware.

```python
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2
import torch_pruning as tp
import time
def measure_inference_time(net, x, repeat=100):
    # Wait for pending GPU work before starting the timer; skip on CPU.
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        net(x)  # forward pass of the model under test
    if x.is_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / repeat
device = torch.device('cpu')
repeat = 100
# before pruning (baseline)
model = mobilenet_v2(pretrained=True).eval()
fake_input = torch.randn(16,3,224,224)
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d"%(inference_time_before_pruning, tp.utils.count_params(model)))

# w/o rounding
model = mobilenet_v2(pretrained=True).eval()
strategy = tp.strategy.L1Strategy()
DG = tp.DependencyGraph()
fake_input = fake_input.cpu()
DG.build_dependency(model, example_inputs=fake_input)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        pruning_idxs = strategy(m.weight, amount=0.2)
        pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=pruning_idxs)
        pruning_plan.exec()
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d"%(inference_time_without_rounding, tp.utils.count_params(model)))
# w/ rounding
model = mobilenet_v2(pretrained=True).eval()
strategy = tp.strategy.L1Strategy()
DG = tp.DependencyGraph()
fake_input = fake_input.cpu()
DG.build_dependency(model, example_inputs=fake_input)
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        pruning_idxs = strategy(m.weight, amount=0.2, round_to=8)
        pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=pruning_idxs)
        pruning_plan.exec()
model = model.to(device)
fake_input = fake_input.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d"%(inference_time_with_rounding, tp.utils.count_params(model))) |
I think rounding is useful for CPU deployment. Maybe we should implement the rounding operation as a function to make the strategy clearer. Let's merge it first and move the rounding functionality to a new function.
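For reference, a minimal sketch of what such a standalone rounding function could look like (the function name and rounding rule here are illustrative assumptions, not necessarily what torch_pruning implements):

```python
# Illustrative sketch only: adjust the number of filters to prune so that the
# number of remaining filters is a multiple of `round_to`.
def round_n_to_prune(total_filters, n_to_prune, round_to=8):
    if round_to <= 1:
        return n_to_prune
    remaining = total_filters - n_to_prune
    # Keep at least `round_to` filters, rounded down to a multiple of `round_to`.
    remaining = max(round_to, (remaining // round_to) * round_to)
    return total_filters - remaining

# Example: pruning ~20% of a 96-filter layer removes 19 filters (77 left);
# with round_to=8 it removes 24 filters instead (72 left).
print(round_n_to_prune(96, 19, 8))  # 24
```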
Hi. I think it would be useful to add functionality to round the number of pruned channels to a provided value (32 or 16, for example). I've implemented it locally in the prune/strategy.py script, and it really accelerates inference speed!
If it would be useful to others, I can try to make a pull request with this functionality this week. Any thoughts? A minimal usage sketch is shown below.
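A minimal usage sketch, assuming the round_to argument behaves as in the benchmark script earlier in this thread (the layer here is just a stand-in):

```python
import torch.nn as nn
import torch_pruning as tp

strategy = tp.strategy.L1Strategy()
conv = nn.Conv2d(64, 96, kernel_size=3)

idxs_plain   = strategy(conv.weight, amount=0.2)               # no rounding
idxs_rounded = strategy(conv.weight, amount=0.2, round_to=16)  # proposed rounding
print(len(idxs_plain), len(idxs_rounded))
```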