
Summary does not show sizes for all elements of a module that returns List[torch.Tensor] #234

Open
xavierjimenezp opened this issue Feb 22, 2023 · 0 comments
Labels: help wanted (Extra attention is needed)


xavierjimenezp commented Feb 22, 2023

Describe the bug
When I use summary on a model whose forward() returns a list of tensors, it only prints the output shape of the first element in the list.

To Reproduce
This issue can be reproduced with a very simple model:

import torch
from torch import nn 
from torchinfo import summary

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        # Returning a list of tensors is what triggers the bug.
        return [logits, logits, logits]

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Put the model on the same device as the input tensor.
    model = NeuralNetwork().to(device)

    print(summary(model, input_data=torch.rand(1, 28, 28, device=device), col_names=["output_size"]))

Expected behavior
Models that return lists usually return lists of tensors. For instance, a Feature Pyramid Network (FPN) often returns a list with one output per pyramid level. It would therefore be nice to have a clean way to print all of those shapes.

The calculate_size function should be modified accordingly so that it handles all of these output types.
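As a rough illustration of what "handling all types" could mean, here is a minimal sketch of a recursive shape collector. It is not torchinfo's calculate_size; collect_shapes and FakeTensor are hypothetical names, and the sketch assumes that any leaf exposing a .shape attribute (torch.Tensor, numpy.ndarray) should be reported, while lists, tuples and dicts are walked so that every element is reported rather than only the first.

```python
from typing import Any


def collect_shapes(obj: Any) -> Any:
    """Mirror the nesting of `obj`, replacing array-like leaves with their shapes.

    Hypothetical helper for illustration only; not part of torchinfo's API.
    """
    # Leaves: anything exposing `.shape` (e.g. torch.Tensor, numpy.ndarray).
    # Checked first so that tensors are never treated as plain sequences.
    if hasattr(obj, "shape"):
        return list(obj.shape)
    # Containers: recurse so every element is reported, not just obj[0].
    if isinstance(obj, (list, tuple)):
        return [collect_shapes(item) for item in obj]
    if isinstance(obj, dict):
        return {key: collect_shapes(value) for key, value in obj.items()}
    # Anything else (None, Python scalars, strings, ...) has no tensor shape.
    return None


class FakeTensor:
    """Stand-in for torch.Tensor so the example runs without PyTorch."""

    def __init__(self, *shape: int) -> None:
        self.shape = tuple(shape)


if __name__ == "__main__":
    logits = FakeTensor(1, 10)
    print(collect_shapes([logits, logits, logits]))
```

Run as a script, this prints [[1, 10], [1, 10], [1, 10]]. Checking for .shape before the container branches is a deliberate duck-typing choice: it lets tensors and NumPy arrays share one code path without importing either library.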

Ideally, the summary would look something like this:

=================================================================
Layer (type:depth-idx)                   Output Shape
=================================================================
NeuralNetwork                            [[1, 10],
                                          [1, 10],
                                          [1, 10]]
├─Flatten: 1-1                           [1, 784]
├─Sequential: 1-2                        [1, 10]
│    └─Linear: 2-1                       [1, 512]
│    └─ReLU: 2-2                         [1, 512]
│    └─Linear: 2-3                       [1, 512]
│    └─ReLU: 2-4                         [1, 512]
│    └─Linear: 2-5                       [1, 10]
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
Total mult-adds (M): 0.67
=================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.68
Estimated Total Size (MB): 2.69
=================================================================
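The multi-line cell in the proposed table could be produced with a small formatter like the one below. This is only a sketch: format_shape_lines and render_row are hypothetical helpers, and the 41-character layer-name column width is an assumption chosen to roughly match the 1.7.2 layout shown here, not a value taken from torchinfo.

```python
from typing import Any, List


def format_shape_lines(shapes: List[Any], indent: int = 1) -> List[str]:
    """Render a list of shapes one per line, wrapped in a single outer bracket.

    Hypothetical helper for illustration; not torchinfo code.
    """
    if not shapes:
        return ["[]"]
    pad = " " * indent
    lines = []
    for i, shape in enumerate(shapes):
        # The first line opens the outer bracket; later lines are indented to align.
        prefix = "[" if i == 0 else pad
        # Every line but the last ends with a comma; the last closes the bracket.
        suffix = "]" if i == len(shapes) - 1 else ","
        lines.append(f"{prefix}{shape}{suffix}")
    return lines


def render_row(layer: str, shapes: List[Any], width: int = 41) -> str:
    """Place the layer name in a fixed-width column and the shapes beside it.

    `width` is an assumed column width, not torchinfo's actual setting.
    """
    lines = format_shape_lines(shapes)
    first = layer.ljust(width) + lines[0]
    rest = [" " * width + line for line in lines[1:]]
    return "\n".join([first, *rest])


if __name__ == "__main__":
    print(render_row("NeuralNetwork", [[1, 10], [1, 10], [1, 10]]))
```

Keeping the line-splitting separate from the row layout means the same formatter could be reused for any column that might hold a list of shapes, such as input sizes.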

For comparison, version 1.7.2 only shows:

=================================================================
Layer (type:depth-idx)                   Output Shape
=================================================================
NeuralNetwork                            [1, 10]
├─Flatten: 1-1                           [1, 784]
├─Sequential: 1-2                        [1, 10]
│    └─Linear: 2-1                       [1, 512]
│    └─ReLU: 2-2                         [1, 512]
│    └─Linear: 2-3                       [1, 512]
│    └─ReLU: 2-4                         [1, 512]
│    └─Linear: 2-5                       [1, 10]
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
Total mult-adds (M): 0.67
=================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.68
Estimated Total Size (MB): 2.69
=================================================================

I remember seeing a fix for #152 that handles modules returning numpy.ndarray objects. Perhaps a unified approach covering the different output types would be worthwhile.
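One possible shape for such a unified approach is a type-dispatch registry built on the standard library's functools.singledispatch: a default handler covers anything with a .shape attribute, containers get their own handlers, and new output types can be plugged in by registration instead of editing one growing isinstance chain. The names output_shape and SparseOutput below are hypothetical and only illustrate the pattern.

```python
from functools import singledispatch
from typing import Any


@singledispatch
def output_shape(obj: Any) -> Any:
    """Default handler: report `.shape` if the object has one, else None.

    torch.Tensor and numpy.ndarray both expose `.shape`, so they share this path.
    Hypothetical helper for illustration; not torchinfo's API.
    """
    shape = getattr(obj, "shape", None)
    return None if shape is None else list(shape)


@output_shape.register(list)
@output_shape.register(tuple)
def _(obj: Any) -> list:
    # Report every element of the sequence, not only obj[0].
    return [output_shape(item) for item in obj]


@output_shape.register(dict)
def _(obj: dict) -> dict:
    return {key: output_shape(value) for key, value in obj.items()}


class SparseOutput:
    """Hypothetical output type that stores its shape under a different name."""

    def __init__(self, dense_shape: tuple) -> None:
        self.dense_shape = dense_shape


# Supporting a new type only requires registering one more handler.
@output_shape.register(SparseOutput)
def _(obj: SparseOutput) -> list:
    return list(obj.dense_shape)
```

singledispatch picks a handler from the runtime type of the first argument, so a list of tensors, a dict of arrays, and a custom type registered later all go through the same output_shape entry point.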

Desktop:

  • OS: Ubuntu 22.04
  • PyTorch 1.12.0
  • Torchinfo 1.7.2
TylerYep added the help wanted label on Apr 6, 2023.