The size of tensor a (100) must match the size of tensor b (17) at non-singleton dimension 3 #8

Closed
YihaoChan opened this issue Jun 26, 2022 · 6 comments


YihaoChan commented Jun 26, 2022

Hello, thanks for your excellent library!

When I try to prune a pre-trained BERT for 17-class text classification, my code is:

# -*- coding: UTF-8 -*-
import os
from transformers import BertTokenizer, BertModel
from transformers import BertForSequenceClassification
from textpruner import summary, TransformerPruner, TransformerPruningConfig, inference_time
import directory
from torch.utils.data import DataLoader
from helper.dataset import TextDataset
from run import RunConfig
import multiprocessing
from evaluate import test_pro
import numpy as np
import torch

model = BertForSequenceClassification.from_pretrained(directory.PRETRAIN_DIR, num_labels=17)

model.load_state_dict(torch.load('model/fold_1_best.pth'))

tokenizer = BertTokenizer.from_pretrained(directory.PRETRAIN_DIR)

test_df = test_pro()

test_dataset = TextDataset(test_df, np.arange(test_df.shape[0]))

run_config = RunConfig()  # instantiation assumed; RunConfig provides batch_size

test_loader = DataLoader(
    test_dataset, batch_size=run_config.batch_size, shuffle=True, num_workers=multiprocessing.cpu_count()
)

print(summary(model))

transformer_pruning_config = TransformerPruningConfig(
    target_ffn_size=1536, target_num_of_heads=6,
    pruning_method='iterative', n_iters=1)

pruner = TransformerPruner(model, transformer_pruning_config=transformer_pruning_config)

pruner.prune(dataloader=test_loader, save_model=True)

tokenizer.save_pretrained(pruner.save_dir)

print(summary(model))

But it fails with:

Calculating IS with loss:   0%|                                                                                                                                | 0/125 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/home/dell/programme/BERT-pruning/prune.py", line 57, in <module>
    pruner.prune(dataloader=test_loader, save_model=True)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/textpruner/pruners/transformer_pruner.py", line 86, in prune
    save_dir = self.iterative_pruning(dataloader, adaptor, batch_postprocessor, keep_shape, save_model=save_model, rewrite_cache=rewrite_cache)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/textpruner/pruners/transformer_pruner.py", line 149, in iterative_pruning
    head_importance, ffn_importance = self.get_importance_score(dataloader, adaptor, batch_postprocessor)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/textpruner/pruners/transformer_pruner.py", line 397, in get_importance_score
    outputs = model(*batch)
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1556, in forward
    outputs = self.bert(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1018, in forward
    encoder_outputs = self.encoder(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 607, in forward
    layer_outputs = layer_module(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 493, in forward
    self_attention_outputs = self.attention(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 423, in forward
    self_outputs = self.self(
  File "/home/dell/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/dell/anaconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 348, in forward
    attention_scores = attention_scores + attention_mask
RuntimeError: The size of tensor a (100) must match the size of tensor b (17) at non-singleton dimension 3

I found few materials or tutorials about TextPruner; maybe that is because the library is quite new.

Please have a look at this bug when you are free. Thanks in advance!


airaria commented Jun 27, 2022

Would you please describe each element of the batch and print their shapes, like this?

for batch in test_loader:
    print([t.shape for t in batch])
    break

It is weird to see 17 (the number of labels) appear in the attention calculation.

YihaoChan (Author) replied:

Thanks for your reply. I printed the shapes of the batch elements in test_loader and train_loader; they both show:

torch.Size([16, 100])
torch.Size([16, 17])

But it works well in training and evaluation... I am also confused why the same batches work for training and evaluation but fail in pruning.

airaria commented Jun 27, 2022

It looks like the model has wrongly treated the second tensor ([16, 17]) as the attention mask (see the check after this list), because:

  1. In BertForSequenceClassification, the first argument is input_ids and the second argument is attention_mask by default
  2. The type of the batch is a tuple or a list, not a dict
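
A quick, illustrative way to confirm the positional binding (this snippet is not part of TextPruner; it just inspects the transformers signature and assumes a recent transformers version):

import inspect
from transformers import BertForSequenceClassification

# TextPruner calls `outputs = model(*batch)` (see the traceback above), so with
# batch = (input_ids, labels) the label tensor is bound positionally to
# attention_mask. The leading parameters of forward() are:
print(list(inspect.signature(BertForSequenceClassification.forward).parameters)[:3])
# expected: ['self', 'input_ids', 'attention_mask']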

To solve this, you can either

  1. modify your TextDataset code to make it return a dict like {'input_ids': TensorA, 'labels': TensorB} (a sketch follows below),

or

  2. define a function that takes a batch and returns a new batch as a dict with the argument names as keys:

def batch_postprocessor(batch):
    return {'input_ids': batch[0], 'labels': batch[1]}

and then call the pruner as:

pruner.prune(dataloader=test_loader, save_model=True, batch_postprocessor=batch_postprocessor)
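
For option 1, a minimal sketch of a dict-returning Dataset (the field names and shapes here are assumptions; adapt them to your actual TextDataset):

import torch
from torch.utils.data import Dataset

class DictTextDataset(Dataset):
    # Sketch only: your tokenization and label handling will differ. The key
    # point is that __getitem__ returns a dict whose keys match the keyword
    # arguments of BertForSequenceClassification.forward, so the default
    # collate_fn yields a dict batch and nothing is bound positionally.
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids  # e.g. LongTensor of shape [N, seq_len]
        self.labels = labels        # e.g. label tensor of shape [N, num_labels]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx], 'labels': self.labels[idx]}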

P.S.: Also make sure the first element of the model output is the loss (otherwise you have to define another adaptor function; see ?pruner.prune and the sketch below).
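
An illustrative sketch of such an adaptor, assuming (check ?pruner.prune for the exact contract) that it receives the model outputs and should return the scalar loss:

def adaptor(model_outputs):
    # For a HuggingFace SequenceClassifierOutput with labels supplied,
    # the loss is model_outputs.loss (equivalently model_outputs[0]).
    return model_outputs.loss

pruner.prune(dataloader=test_loader, save_model=True,
             batch_postprocessor=batch_postprocessor, adaptor=adaptor)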

albert-jin commented:

Great~ I think you'd better mention this in your official docs.

YihaoChan (Author) replied:

Thanks for your suggestion. I modified my Dataset as you suggested, and it works. Debugging also confirms that my label tensor was being taken as the attention_mask because I didn't return the batch as a dict.


stale bot commented Jul 12, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the stale label on Jul 12, 2022
airaria closed this as completed on Jul 18, 2022