Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make print_summary by default as true #2

Open
wants to merge 6 commits into
base: master
Choose a base branch
from

Conversation

sizhky
Copy link

@sizhky sizhky commented Jun 13, 2020

Hi,
I've been using this module since a month and my experience with it has been largely pleasant. Thank you for your contribution 😄

One thing i observed was that i was always using print_summary as True, so i just feel, using True as default is more user friendly. In fact using False asserts the usage of needing a string instead of a printed report more than expecting a user to print it everytime...

Edit:
I added a few more edits to i/o. I think the changes might have been too aggressive in design change, but we can test it out more and try to break it...

@sizhky
Copy link
Author

sizhky commented Jun 13, 2020

By default, we will have all inputs and outputs including model's i/o

>>> class myNN(nn.Module):
      def __init__(self):
        super().__init__()
        self.model1 = nn.Sequential(
            nn.Linear(100,200),
            nn.ReLU(inplace=True),
            nn.Linear(200,50),
            nn.Softmax(-1)
        )
        self.model2 = nn.Sequential(
            nn.Linear(10,20),
            nn.ReLU(inplace=True),
            nn.Linear(20,5),
            nn.Softmax(-1)
        )
      def forward(self, x1, x2):
        y1 = self.model1(x1)
        y2 = self.model2(x2)
        return y1,y2

>>> mynn = myNN()
>>> summary(mynn, torch.zeros(1,100), torch.zeros(1,10))
-------------------------------------------------------------------------
      Layer (type)           Input Shape    Output Shape         Param #
=========================================================================
             Input     [1, 100], [1, 10]                              -1
          Linear-2              [1, 100]        [1, 200]          20,200
            ReLU-3              [1, 200]        [1, 200]               0
          Linear-4              [1, 200]         [1, 50]          10,050
         Softmax-5               [1, 50]         [1, 50]               0
          Linear-6               [1, 10]         [1, 20]             220
            ReLU-7               [1, 20]         [1, 20]               0
          Linear-8               [1, 20]          [1, 5]             105
         Softmax-9                [1, 5]          [1, 5]               0
            Output                       [1, 50], [1, 5]              -1
=========================================================================
Total params: 30,577
Trainable params: 30,577
Non-trainable params: 0
-------------------------------------------------------------------------

@amarczew
Copy link
Owner

Hi @sizhky! Thank you so much for your PR! After years using summary in keras, we needed a similar version in pytorch 😄

I agree, print as default is better. In general, when we call this method we want to print indeed

When I created this module, I thought about printing both input/output shapes, but thinking in keras behavior I realized that, excluding first and last layers, information is duplicated because output from last layer is the input to the next (in general). Maybe an option to drop one of them when desired could embrace all programmers who [do/do not] want to see so much information.

On the other hand, version with both sounds good, specially in your example. I think for your example, maybe, an version showing parent layer could be better than print both input and output shapes. Keras has an option for that, but I haven't implemented that yet

Did you test your version with lib examples?
I think version with both shapes in those examples will not have a good view due to shape layers. The "problem" showing both shapes get clear with them

What do you think about above points/questions?

@sizhky
Copy link
Author

sizhky commented Jun 16, 2020

I understand your points about keeping the table succinct. I want to add a couple of points -

  • The output shape is only one more column and almost always it is important to know what goes in and what comes out. Even if it is redundant, it is vital. And in many cases (see CNN example below where) I/O don't match between tensors because of using F. functions in forward.
  • Using parent will certainly help, but we must be aware that some modules can get tensors from multiple parents. The least we can do is show all the input tensor shapes and output tensor shapes so someone can at least match, which tensor is going where.

See the outputs for lib examples below.
The output in Transformers is very messed up 😄

The "problem" showing both shapes get clear with them

I didn't undrestand

CNN

-----------------------------------------------------------------------
      Layer (type)         Input Shape    Output Shape         Param #
=======================================================================
             Input      [1, 1, 28, 28]                              -1
          Conv2d-2      [1, 1, 28, 28] [1, 10, 24, 24]             260
          Conv2d-3     [1, 10, 12, 12]   [1, 20, 8, 8]           5,020
       Dropout2d-4       [1, 20, 8, 8]   [1, 20, 8, 8]               0
          Linear-5            [1, 320]         [1, 50]          16,050
          Linear-6             [1, 50]         [1, 10]             510
            Output                             [1, 10]              -1
=======================================================================
Total params: 21,838
Trainable params: 21,838
Non-trainable params: 0
-----------------------------------------------------------------------
=========================== Hierarchical Summary ===========================

Net(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1)), 260 params
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1)), 5,020 params
  (conv2_drop): Dropout2d(p=0.5, inplace=False), 0 params
  (fc1): Linear(in_features=320, out_features=50, bias=True), 16,050 params
  (fc2): Linear(in_features=50, out_features=10, bias=True), 510 params
), 21,840 params


============================================================================

Transformer

-----------------------------------------------------------------------------------
      Layer (type)                     Input Shape    Output Shape         Param #
===================================================================================
             Input                  [1, 5], [1, 5]                              -1
         Encoder-2                          [1, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]      17,332,224
         Decoder-3     [1, 5], [1, 5], [1, 5, 512] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]      22,060,544
          Linear-4                     [1, 5, 512]       [1, 5, 7]           3,584
            Output                                 [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]              -1
===================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
-----------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Layer (type)                                                                                                                                                                                                                                                            Input Shape    Output Shape         Param #
==========================================================================================================================================================================================================================================================================================================================
             Input                                                                                                                                                                                                                                                         [1, 5], [1, 5]                              -1
         Encoder-2                                                                                                                                                                                                                                                                 [1, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]      17,332,224
         Decoder-3                                                                                                                                                                                                                                            [1, 5], [1, 5], [1, 5, 512] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]      22,060,544
          Linear-4                                                                                                                                                                                                                                                            [1, 5, 512]       [1, 5, 7]           3,584
            Output                                                                                                                                                                                                                                                                        [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]              -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
Batch size: 1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


================================ Hierarchical Summary ================================

Transformer(
  (encoder): Encoder(
    (src_emb): Embedding(6, 512), 3,072 params
    (pos_emb): Embedding(6, 512), 3,072 params
    (layers): ModuleList(
      (0): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (1): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (2): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (3): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (4): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
      (5): EncoderLayer(
        (enc_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 2,887,680 params
    ), 17,326,080 params
  ), 17,332,224 params
  (decoder): Decoder(
    (tgt_emb): Embedding(7, 512), 3,584 params
    (pos_emb): Embedding(6, 512), 3,072 params
    (layers): ModuleList(
      (0): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (1): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (2): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (3): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (4): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
      (5): DecoderLayer(
        (dec_self_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (dec_enc_attn): MultiHeadAttention(
          (W_Q): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_K): Linear(in_features=512, out_features=512, bias=True), 262,656 params
          (W_V): Linear(in_features=512, out_features=512, bias=True), 262,656 params
        ), 787,968 params
        (pos_ffn): PoswiseFeedForwardNet(
          (conv1): Conv1d(512, 2048, kernel_size=(1,), stride=(1,)), 1,050,624 params
          (conv2): Conv1d(2048, 512, kernel_size=(1,), stride=(1,)), 1,049,088 params
        ), 2,099,712 params
      ), 3,675,648 params
    ), 22,053,888 params
  ), 22,060,544 params
  (projection): Linear(in_features=512, out_features=7, bias=False), 3,584 params
), 39,396,352 params


======================================================================================

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Layer (type)                                                                                                                                                                                                                                                            Input Shape    Output Shape         Param #
==========================================================================================================================================================================================================================================================================================================================
             Input                                                                                                                                                                                                                                                         [1, 5], [1, 5]                              -1
       Embedding-2                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
       Embedding-3                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
    EncoderLayer-4                                                                                                                                                                                                                                                 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]       2,887,680
    EncoderLayer-5                                                                                                                                                                                                                                                 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]       2,887,680
    EncoderLayer-6                                                                                                                                                                                                                                                 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]       2,887,680
    EncoderLayer-7                                                                                                                                                                                                                                                 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]       2,887,680
    EncoderLayer-8                                                                                                                                                                                                                                                 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]       2,887,680
    EncoderLayer-9                                                                                                                                                                                                                                                 [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]       2,887,680
      Embedding-10                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,584
      Embedding-11                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
   DecoderLayer-12                                                                                                                                                                                                                         [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5]       3,675,648
   DecoderLayer-13                                                                                                                                                                                                                         [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5]       3,675,648
   DecoderLayer-14                                                                                                                                                                                                                         [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5]       3,675,648
   DecoderLayer-15                                                                                                                                                                                                                         [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5]       3,675,648
   DecoderLayer-16                                                                                                                                                                                                                         [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5]       3,675,648
   DecoderLayer-17                                                                                                                                                                                                                         [1, 5, 512], [1, 5, 512], [1, 5, 5], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5], [1, 8, 5, 5]       3,675,648
         Linear-18                                                                                                                                                                                                                                                            [1, 5, 512]       [1, 5, 7]           3,584
            Output                                                                                                                                                                                                                                                                        [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]              -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Layer (type)                                                                                                                                                                                                                                                            Input Shape    Output Shape         Param #
==========================================================================================================================================================================================================================================================================================================================
             Input                                                                                                                                                                                                                                                         [1, 5], [1, 5]                              -1
       Embedding-2                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
       Embedding-3                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
          Linear-4                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
          Linear-5                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
          Linear-6                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
          Conv1d-7                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
          Conv1d-8                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
          Linear-9                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-10                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-11                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-12                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-13                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-14                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-15                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-16                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-17                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-18                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-19                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-20                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-21                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-22                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-23                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-24                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-25                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-26                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-27                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-28                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-29                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-30                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-31                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-32                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-33                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
      Embedding-34                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,584
      Embedding-35                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
         Linear-36                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-37                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-38                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-39                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-40                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-41                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-42                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-43                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-44                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-45                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-46                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-47                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-48                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-49                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-50                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-51                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-52                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-53                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-54                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-55                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-56                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-57                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-58                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-59                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-60                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-61                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-62                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-63                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-64                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-65                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-66                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-67                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-68                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-69                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-70                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-71                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-72                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-73                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-74                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-75                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-76                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-77                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-78                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-79                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-80                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Linear-81                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]         262,656
         Conv1d-82                                                                                                                                                                                                                                                            [1, 512, 5]    [1, 2048, 5]       1,050,624
         Conv1d-83                                                                                                                                                                                                                                                           [1, 2048, 5]     [1, 512, 5]       1,049,088
         Linear-84                                                                                                                                                                                                                                                            [1, 5, 512]       [1, 5, 7]           3,584
            Output                                                                                                                                                                                                                                                                        [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]              -1
==========================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
                      Parent Layers                Layer (type)                                                                                                                                                                                                                                                            Input Shape    Output Shape         Param #
======================================================================================================================================================================================================================================================================================================================================================================
                                                          Input                                                                                                                                                                                                                                                         [1, 5], [1, 5]                              -1
                Transformer/Encoder                 Embedding-2                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
                Transformer/Encoder                 Embedding-3                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
   Transformer/Encoder/EncoderLayer        MultiHeadAttention-4                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Encoder/EncoderLayer     PoswiseFeedForwardNet-5                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Encoder/EncoderLayer        MultiHeadAttention-6                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Encoder/EncoderLayer     PoswiseFeedForwardNet-7                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Encoder/EncoderLayer        MultiHeadAttention-8                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Encoder/EncoderLayer     PoswiseFeedForwardNet-9                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Encoder/EncoderLayer       MultiHeadAttention-10                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Encoder/EncoderLayer    PoswiseFeedForwardNet-11                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Encoder/EncoderLayer       MultiHeadAttention-12                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Encoder/EncoderLayer    PoswiseFeedForwardNet-13                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Encoder/EncoderLayer       MultiHeadAttention-14                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Encoder/EncoderLayer    PoswiseFeedForwardNet-15                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
                Transformer/Decoder                Embedding-16                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,584
                Transformer/Decoder                Embedding-17                                                                                                                                                                                                                                                                 [1, 5]     [1, 5, 512]           3,072
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-18                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-19                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer    PoswiseFeedForwardNet-20                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-21                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-22                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer    PoswiseFeedForwardNet-23                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-24                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-25                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer    PoswiseFeedForwardNet-26                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-27                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-28                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer    PoswiseFeedForwardNet-29                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-30                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-31                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer    PoswiseFeedForwardNet-32                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-33                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer       MultiHeadAttention-34                                                                                                                                                                                                                       [1, 5, 512], [1, 5, 512], [1, 5, 512], [1, 5, 5] [1, 5, 512], [1, 8, 5, 5]         787,968
   Transformer/Decoder/DecoderLayer    PoswiseFeedForwardNet-35                                                                                                                                                                                                                                                            [1, 5, 512]     [1, 5, 512]       2,099,712
                        Transformer                   Linear-36                                                                                                                                                                                                                                                            [1, 5, 512]       [1, 5, 7]           3,584
                                                         Output                                                                                                                                                                                                                                                                        [5, 7], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5], [1, 8, 5, 5]              -1
======================================================================================================================================================================================================================================================================================================================================================================
Total params: 39,396,350
Trainable params: 39,390,206
Non-trainable params: 6,144
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

@amarczew
Copy link
Owner

@sizhky table width is breaking in transformer example?

Im going to merge your PR when table width is validated and add a parameter to drop some column as optional parameter

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants