Add support for pruned models #103
Conversation
Force-pushed from f65625a to 685b046
I am not sure how to fix the mypy errors, but I have tested this branch on pruned VGG networks and it seems to work. If you need any input from my side, please ping me :)
@MajorCarrot Looks good so far. Could you add a test case with a pruned VGG network and the expected output? I also left a couple of comments on the changes.
Force-pushed from 3185e3a to 823962e
FYI, you'll need to run
Force-pushed from 823962e to 9bb29b4
I have fixed almost all of the mypy errors (by rewriting rgetattr and changing its signature), but I am still not able to get the whole thing past mypy xD If you have suggestions on how to fix this (hopefully final) error, we should be good to go!
Force-pushed from 5461e37 to b71e003
You can add a
def rgetattr(obj: torch.nn.Module, attr: str) -> torch.Tensor:
    """Get the tensor submodule called attr from obj."""
    for i in attr.split("."):
Let's use a different variable name, since i is typically reserved for numbers.
Would attr_i be fine (to denote one part of the variable attr)?
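For illustration, a minimal sketch of the loop with that naming (hypothetical: the Optional return type and returning None for a missing attribute path are assumptions, not necessarily the exact code merged here):

```python
from typing import Any, Optional

import torch


def rgetattr(module: torch.nn.Module, attr: str) -> Optional[torch.Tensor]:
    """Resolve a dotted attribute path such as "features.0.weight_mask"."""
    current: Any = module
    for attr_i in attr.split("."):  # attr_i: one component of the dotted path
        if not hasattr(current, attr_i):
            return None
        current = getattr(current, attr_i)
    return current
```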
According to the [pytorch documentation on pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html), the original parameter is replaced with one ending with `_orig` and a new buffer ending with `_mask`. The mask contains 0s and 1s based on which the correct parameters are chosen. All instances of `param.nelements()` have been replaced by a variable `cur_params` whose value is set based on whether it is a masked model or not. To keep consistency with the rest of the code base, the `_orig` is removed from the `name` variable right after the calculation of `cur_params`.
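As a rough sketch of that counting logic (a hypothetical helper, not the exact torchinfo code; the function name is illustrative):

```python
import torch


def effective_param_count(module: torch.nn.Module, name: str, param: torch.Tensor) -> int:
    """Count parameters for one named parameter, respecting pruning masks."""
    if name.endswith("_orig"):
        base = name[: -len("_orig")]
        # Pruning registers a "<base>_mask" buffer of 0s and 1s next to "<base>_orig".
        mask = getattr(module, f"{base}_mask", None)
        if mask is not None:
            # cur_params: only the entries kept by the mask are counted.
            return int(mask.sum().item())
    # Unpruned parameter: count every element, as before.
    return param.nelement()
```

A module pruned with prune.l1_unstructured at amount=0.5 would then report roughly half of its weight elements.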
Force-pushed from b71e003 to f1646b9
Codecov Report
@@            Coverage Diff             @@
##             main     #103      +/-   ##
==========================================
- Coverage   99.30%   99.10%   -0.21%
==========================================
  Files           5        5
  Lines         434      446      +12
==========================================
+ Hits          431      442      +11
- Misses          3        4       +1
Continue to review full report at Codecov.
(I am really sorry for messing up the whole testing thing, so many runs 🤦) I had removed the
model = SingleInputNet()
for module in model.modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.l1_unstructured(module, "weight", 0.5)  # type: ignore[no-untyped-call]
Suggested change:
-         prune.l1_unstructured(module, "weight", 0.5)  # type: ignore[no-untyped-call]
+         prune.l1_unstructured(  # type: ignore[no-untyped-call]
+             module, "weight", 0.5
+         )
I think it worked! Thanks :)
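For context, a hedged end-to-end sketch of exercising the feature (the VGG16 model and the input size here are illustrative choices, not the exact test case added in this PR):

```python
import torch
import torchvision
from torch.nn.utils import prune

from torchinfo import summary

model = torchvision.models.vgg16()
for module in model.modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        # Zero out 50% of each weight tensor; this adds weight_orig and weight_mask.
        prune.l1_unstructured(module, "weight", amount=0.5)

# With this change, summary() is expected to derive parameter counts from the
# pruning masks instead of tripping over the renamed weight_orig parameters.
summary(model, input_size=(1, 3, 224, 224))
```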
Force-pushed from c561cf3 to 6af4b5a
Thanks for the contribution! Please feel free to continue finding improvements :)
This is now available in torchinfo v1.6.1 |