-
Notifications
You must be signed in to change notification settings - Fork 36
Open
Description
- Version: python==3.7.9, torch==1.9.0+cu111, torchvision==0.10.0+cu111, calflops==0.2.0
- Problem:
Here is a minimal example that reproduces the error:
import torch
import torch.nn as nn
from calflops import calculate_flops
class Model(nn.Module):
    """Toy multi-head self-attention built on ``torch.einsum``.

    Used to reproduce a calflops failure: the einsum calls in ``forward``
    pass their operands as a single list, which calflops' einsum FLOP
    counter does not handle (it calls ``.shape`` on the list itself).

    Args:
        in_channels: Channel count of the input passed to ``forward``.
        num_heads: Number of attention heads.
        qkv_dim: Per-head size of the query, key and value projections.
    """

    def __init__(self, in_channels=768, num_heads=8, qkv_dim=256):
        super().__init__()
        self.num_heads = num_heads
        self.qkv_dim = qkv_dim
        # A kernel-size-1 Conv1d is a per-position linear projection that
        # produces queries, keys and values for every head in one pass.
        self.to_qkvs = nn.Conv1d(in_channels, 3 * num_heads * qkv_dim, 1)
        self.softmax = nn.Softmax(-1)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (N, in_channels, L).

        Returns:
            Tensor of shape (N, num_heads, qkv_dim, L).
        """
        N, _, L = x.shape
        qkv = self.to_qkvs(x).view(N, 3 * self.num_heads, self.qkv_dim, L)
        # Split the stacked heads into q, k, v, each (N, num_heads, qkv_dim, L).
        q, k, v = torch.chunk(qkv, 3, dim=1)
        # Operands are deliberately passed as a list: this is the call form
        # that triggers the calflops AttributeError being reported.
        # NOTE(review): standard scaled dot-product attention divides by
        # sqrt(qkv_dim); the original divisor is kept because the repro only
        # depends on the einsum call form.
        atten = self.softmax(torch.einsum('nhcp,nhcq->nhpq', [q, k]) / (self.qkv_dim))
        result = torch.einsum('nhcp,nhpq->nhcq', [v, atten])
        return result
# Count FLOPs/MACs/params for the toy model; with the list-form einsum
# operands used in Model.forward, this call raises the AttributeError below.
inputs = {"x": torch.rand((1, 768, 64))}
model = Model()
flops, macs, params = calculate_flops(model=model, kwargs=inputs, print_results=False)
for value in (flops, macs, params):
    print(value)
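This is a toy example of multi-head self-attention implemented using `torch.einsum`. Running this script raises the following `AttributeError` in the function `_einsum_flops_compute()`. Note that `torch.einsum` accepts its operands either as separate arguments or as a single list (`[q, k]` here). The traceback suggests that calflops only handles the former and calls `.shape` on the list object itself: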
Traceback (most recent call last):
File "my_compute_flops.py", line 60, in <module>
flops, macs, params = calculate_flops(model=model, kwargs=inputs, print_results=False)
File "[anomyous]/lib/python3.7/site-packages/calflops/flops_counter.py", line 154, in calculate_flops
_ = model(*args, **kwargs)
File "[anomyous]/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1071, in _call_impl
result = forward_call(*input, **kwargs)
File "my_compute_flops.py", line 51, in forward
atten = self.softmax(torch.einsum('nhcp,nhcq->nhpq', [q, k]) / (self.qkv_dim))
File "[anomyous]/lib/python3.7/site-packages/calflops/pytorch_ops.py", line 360, in newFunc
flops, macs = funcFlopCompute(*args, **kwds)
File "[anomyous]/lib/python3.7/site-packages/calflops/pytorch_ops.py", line 295, in _einsum_flops_compute
input_shapes = [o.shape for o in operands]
File "[anomyous]/lib/python3.7/site-packages/calflops/pytorch_ops.py", line 295, in <listcomp>
input_shapes = [o.shape for o in operands]
AttributeError: 'list' object has no attribute 'shape'
Metadata
Assignees
Labels
No labels