Description
This issue describes a bug in torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py.
We found it while importing an exported model into MLIR. It occurs for an exported MultiheadAttention layer with need_weights=False, which means the layer does not return attention weights. In this case, the second output, attn_output_weights, is None.
The following error is raised:
Python Error: NotImplementedError: OutputKind.USER_OUTPUT for <class 'torch.export.graph_signature.ConstantArgument'>: ConstantArgument(name='', value=None)
[Additionally, I couldn't visualize the exported model (.pt2) with a tool like https://netron.app/.
However, I can import the exported model and visualize it when need_weights=True, i.e., when attn_output_weights is not None.]
doc: https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
parameters:
need_weights: [bool] If specified, returns attn_output_weights
outputs:
attn_output_weights: Only returned when need_weights=True.
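A quick eager-mode check, independent of torch.export, illustrates the documented behavior: with need_weights=False, nn.MultiheadAttention returns None as its second output. The sizes below are arbitrary values chosen only for illustration.

```python
import torch
import torch.nn as nn

# Arbitrary small sizes, chosen only for illustration.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
q = torch.rand(1, 5, 8)   # (batch, target_seq_len, embed_dim)
kv = torch.rand(1, 3, 8)  # (batch, source_seq_len, embed_dim)

# need_weights=False: only the attention output is computed; weights are None.
attn_output, attn_weights = mha(q, kv, kv, need_weights=False)
print(attn_output.shape)  # torch.Size([1, 5, 8])
print(attn_weights)       # None

# need_weights=True (the default): weights are returned, averaged over heads
# because average_attn_weights defaults to True.
_, weights = mha(q, kv, kv, need_weights=True)
print(weights.shape)      # torch.Size([1, 5, 3])
```

This None second output is what later shows up in the exported graph signature.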
Source code to reproduce the exported model with attn_output_weights = None:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomModel(nn.Module):
    def __init__(self, kwargs):
        super(CustomModel, self).__init__()
        self.kwargs = kwargs
        self.attn = nn.MultiheadAttention(
            embed_dim=kwargs['embedding_dim'],
            num_heads=kwargs['num_heads'],
            dropout=kwargs['dropout'],
            add_bias_kv=kwargs['add_bias_kv'],
            add_zero_attn=kwargs['add_zero_attn'],
            kdim=kwargs['kdim'],
            vdim=kwargs['vdim'],
            batch_first=kwargs['batch_first'],
        )

    def forward(self, *args):
        query, key, value, attn_mask, kp_mask = args[0], args[1], args[2], args[3], args[4]
        # With need_weights=False, the second element of the returned tuple is None.
        return self.attn(
            query, key, value,
            attn_mask=attn_mask,
            key_padding_mask=kp_mask,
            need_weights=self.kwargs['need_weights'],
            average_attn_weights=self.kwargs['average_attn_weights'],
            is_causal=self.kwargs['is_causal'],
        )

# Create model instance
model = CustomModel(kwargs={
    'embedding_dim': 64,
    'num_heads': 1,
    'dropout': 0.1,
    'add_bias_kv': True,
    'add_zero_attn': False,
    'kdim': 16,
    'vdim': None,  # used None instead of the string "missing"; vdim=None makes MultiheadAttention use vdim = embed_dim (64)
    'batch_first': True,
    'need_weights': False,
    'average_attn_weights': True,
    'is_causal': False,
})

# Dummy input tensors
query = torch.rand(1, 50, 64)  # (batch, target_seq_len, embedding_dim)
key = torch.rand(1, 10, 16)  # (batch, source_seq_len, kdim)
value = torch.rand(1, 10, 64)  # (batch, source_seq_len, vdim)
attn_mask = torch.zeros(50, 10)  # (target_seq_len, source_seq_len)
key_padding_mask = torch.zeros(1, 10)  # (batch, source_seq_len)

# Export the model
exported_model = torch.export.export(
    model, args=(query, key, value, attn_mask, key_padding_mask))

# Print the ExportedProgram (FX graph plus graph signature);
# exported_model.graph gives the underlying torch.fx.Graph
print(exported_model)
The error is caused by a missing case at lines #661 and #662 of the source file below: torch.export.graph_signature.ConstantArgument is not handled.
torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py
Before proposing code changes to fix this issue, we wanted to check the expected behavior and confirm whether this OutputSpec is intentionally left unhandled in the source code or whether it is an actual bug that needs to be fixed.
This is a snippet from the exported program:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_attn_q_proj_weight: "f32[64, 64]", p_attn_k_proj_weight: "f32[64, 16]", p_attn_v_proj_weight: "f32[64, 64]", p_attn_in_proj_bias: "f32[192]", p_attn_bias_k: "f32[1, 1, 64]", p_attn_bias_v: "f32[1, 1, 64]", p_attn_out_proj_weight: "f32[64, 64]", p_attn_out_proj_bias: "f32[64]", args_0: "f32[1, 50, 64]", args_1: "f32[1, 10, 16]", args_2: "f32[1, 10, 64]", args_3: "f32[50, 10]", args_4: "f32[1, 10]"):
            #
            transpose: "f32[50, 1, 64]" = torch.ops.aten.transpose.int(args_0, 1, 0); args_0 = None
            ....
            view_8: "f32[50, 1, 64]" = torch.ops.aten.view.default(linear_3, [50, 1, 64]); linear_3 = None
            transpose_6: "f32[1, 50, 64]" = torch.ops.aten.transpose.int(view_8, 1, 0); view_8 = None
            return (transpose_6, None)
Graph signature: ExportGraphSignature(
    input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_attn_q_proj_weight'), target='attn.q_proj_weight', persistent=None), ...],
    output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='transpose_6'), target=None),
                  OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=ConstantArgument(name='', value=None), target=None)])
We noticed that OutputSpec.arg can hold several argument types (see the documentation below), while the source code handles only two of them: TensorArgument and SymIntArgument.
https://pytorch.org/docs/stable/export.html#torch.export.graph_signature.OutputSpec
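To frame the question, here is a minimal, self-contained sketch of one possible behavior: an output-collection loop that accepts a ConstantArgument whose value is None as a user output instead of raising. It uses hypothetical stand-in dataclasses (not the real torch.export classes) and a hypothetical function collect_user_outputs; it is not torch-mlir's actual code, and how the importer should materialize such a None output is exactly what we would like to confirm.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, List, Union


# Hypothetical stand-ins mirroring the shape of torch.export's OutputSpec and
# argument classes; they are NOT the real torch types.
class OutputKind(Enum):
    USER_OUTPUT = auto()
    BUFFER_MUTATION = auto()


@dataclass
class TensorArgument:
    name: str


@dataclass
class SymIntArgument:
    name: str


@dataclass
class ConstantArgument:
    name: str
    value: Any


Argument = Union[TensorArgument, SymIntArgument, ConstantArgument]


@dataclass
class OutputSpec:
    kind: OutputKind
    arg: Argument


def collect_user_outputs(output_specs: List[OutputSpec]) -> List[str]:
    """Hypothetical importer loop: describe each user output.

    Tensor and SymInt outputs refer to a producing graph node by name.
    A ConstantArgument holding None has no producing node, so this sketch
    records it as a "none" output instead of raising NotImplementedError.
    """
    results = []
    for spec in output_specs:
        if spec.kind != OutputKind.USER_OUTPUT:
            continue  # other output kinds are out of scope for this sketch
        arg = spec.arg
        if isinstance(arg, (TensorArgument, SymIntArgument)):
            results.append(f"node:{arg.name}")
        elif isinstance(arg, ConstantArgument) and arg.value is None:
            results.append("none")
        else:
            raise NotImplementedError(
                f"OutputKind.USER_OUTPUT for {type(arg)}: {arg}")
    return results


specs = [
    OutputSpec(OutputKind.USER_OUTPUT, TensorArgument(name="transpose_6")),
    OutputSpec(OutputKind.USER_OUTPUT, ConstantArgument(name="", value=None)),
]
print(collect_user_outputs(specs))  # ['node:transpose_6', 'none']
```

Non-None constants still raise in this sketch, so only the case reported here changes behavior.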