
Add support for pruned models #103

Merged
TylerYep merged 2 commits into TylerYep:main from MajorCarrot:account_for_mask
Dec 21, 2021

Conversation

@MajorCarrot
Contributor

According to the [PyTorch documentation on pruning](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html), the original parameter is replaced with one ending in `_orig`, and a new buffer ending in `_mask` is added. The mask contains 0s and 1s, based on which the correct parameters are chosen.

All instances of `param.nelement()` have been replaced by a variable `cur_params`, whose value is set based on whether the model is masked or not. For consistency with the rest of the code base, the `_orig` suffix is removed from the `name` variable right after `cur_params` is calculated.
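As a minimal sketch of the bookkeeping described above (this is illustrative, not the torchinfo implementation; the variable names are mine):

```python
import torch
from torch.nn.utils import prune

layer = torch.nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Pruning reparametrizes the module: the parameter is now "weight_orig",
# and a "weight_mask" buffer of 0s and 1s selects the surviving entries.
names = [name for name, _ in layer.named_parameters()]
masks = dict(layer.named_buffers())

# Count only the unpruned weights by summing the 1s in the mask.
cur_params = int(masks["weight_mask"].sum())
print(names, cur_params)
```

With `amount=0.5`, half of the 16 weight entries are masked out, so the mask sums to 8; the unpruned bias remains an ordinary parameter and is counted separately.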

@MajorCarrot
Contributor Author

I am not sure how to fix the mypy errors, but I have tested this branch on pruned VGG networks and it seems to work. If you need any input from my side, please ping me :)

Comment thread torchinfo/layer_info.py Outdated
Comment thread torchinfo/layer_info.py Outdated
@TylerYep
Owner

@MajorCarrot Looks good so far. Could you add a test case with a pruned VGG network and the expected output?

Also left a couple of comments on the changes.

@MajorCarrot force-pushed the account_for_mask branch 2 times, most recently from 3185e3a to 823962e on December 21, 2021 at 06:27
@TylerYep
Owner

FYI, you'll need to run `pytest --overwrite` in order to create the expected output file for the test case you added.

@MajorCarrot
Contributor Author

I have fixed almost all the mypy errors (by rewriting `rgetattr` and changing its signature), but I am still not able to get the whole thing to pass mypy xD

If you have suggestions on how to fix this (hopefully final) error, we should be good to go!

@TylerYep
Owner

You can add a `# type: ignore` on the last error. The issue is that the `l1_unstructured` function isn't type-annotated.
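For illustration, the scoped ignore might look like this (a sketch; the real test targets a full model in tests/torchinfo_test.py):

```python
import torch
from torch.nn.utils import prune

layer = torch.nn.Conv2d(1, 10, kernel_size=5)
# prune.l1_unstructured carries no type annotations, so mypy reports
# [no-untyped-call]; the error-code-scoped comment silences only that check.
prune.l1_unstructured(layer, "weight", 0.5)  # type: ignore[no-untyped-call]
```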

Comment thread tests/torchinfo_test.py Outdated
Owner

@TylerYep left a comment


Couple more changes

Comment thread torchinfo/layer_info.py Outdated

def rgetattr(obj: torch.nn.Module, attr: str) -> torch.Tensor:
    """Get the tensor submodule called attr from obj."""
    for i in attr.split("."):
Owner


Let's use a different variable, since `i` is typically reserved for numbers.

Contributor Author


Would `attr_i` be fine (to indicate some part of the variable `attr`)?
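For context, a completed version of the getter might look like this (the loop body, return, and usage are my assumptions; only the signature and loop header appear in the diff):

```python
import torch

def rgetattr(obj: torch.nn.Module, attr: str) -> torch.Tensor:
    """Get the tensor submodule named by a dotted path, e.g. "0.weight"."""
    target = obj
    for attr_i in attr.split("."):
        target = getattr(target, attr_i)
    return target

# Hypothetical usage: fetch a nested weight tensor by its dotted name.
model = torch.nn.Sequential(torch.nn.Linear(2, 3))
weight = rgetattr(model, "0.weight")
```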

Comment thread torchinfo/layer_info.py Outdated
@codecov

codecov Bot commented Dec 21, 2021

Codecov Report

Merging #103 (c561cf3) into main (d9f4857) will decrease coverage by 0.20%.
The diff coverage is 93.75%.

❗ Current head c561cf3 differs from pull request most recent head 6af4b5a. Consider uploading reports for the commit 6af4b5a to get more accurate results

@@            Coverage Diff             @@
##             main     #103      +/-   ##
==========================================
- Coverage   99.30%   99.10%   -0.21%     
==========================================
  Files           5        5              
  Lines         434      446      +12     
==========================================
+ Hits          431      442      +11     
- Misses          3        4       +1     
Impacted Files Coverage Δ
torchinfo/layer_info.py 97.67% <93.75%> (-0.62%) ⬇️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update d9f4857...6af4b5a. Read the comment docs.

@MajorCarrot
Contributor Author

(I am really sorry for messing up the whole testing thing, so many runs 🤦)

I had removed the leading 0 in 0.5 because flake8 doesn't like the longer line, and I can't shift the `# type: ignore` anywhere else.

Comment thread tests/torchinfo_test.py Outdated
model = SingleInputNet()
for module in model.modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        prune.l1_unstructured(module, "weight", 0.5)  # type: ignore[no-untyped-call]
Owner


Suggested change
prune.l1_unstructured(module, "weight", 0.5)  # type: ignore[no-untyped-call]
prune.l1_unstructured(  # type: ignore[no-untyped-call]
    module, "weight", 0.5
)

Owner


Try this

Contributor Author


I think it worked! Thanks :)

@TylerYep merged commit 2141b78 into TylerYep:main on Dec 21, 2021
@TylerYep
Owner

Thanks for the contribution! Please feel free to continue finding improvements :)

@TylerYep
Owner

This is now available in torchinfo v1.6.1.
