ADD kotoba-whisper-v1.0 #1455

Open
kyakuno opened this issue Apr 16, 2024 · 4 comments
kyakuno commented Apr 16, 2024

A Whisper model specialized for Japanese.
https://huggingface.co/kotoba-tech/kotoba-whisper-v1.0
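
For reference, a minimal sketch (not part of this issue) of running the model through the Hugging Face transformers ASR pipeline; the audio file name, device, and generate_kwargs below are assumptions based on typical Whisper usage:

import torch
from transformers import pipeline

# Load kotoba-whisper-v1.0 as an automatic-speech-recognition pipeline
# (fp16 on GPU here; adjust torch_dtype/device for CPU-only setups).
pipe = pipeline(
    "automatic-speech-recognition",
    model="kotoba-tech/kotoba-whisper-v1.0",
    torch_dtype=torch.float16,
    device="cuda:0",
)

# "sample.wav" is a placeholder input file.
result = pipe("sample.wav", generate_kwargs={"language": "japanese", "task": "transcribe"})
print(result["text"])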

kyakuno commented Apr 30, 2024

@ooe1123 Once the other models are done, it would be great if you could take this one as well. _ _

ooe1123 commented May 20, 2024

〇 transformers/models/whisper/modeling_whisper.py

Before:

class WhisperSdpaAttention(WhisperAttention):
    ...
    def forward(
        self,
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            ...
        elif is_cross_attention:
            ...
        elif past_key_value is not None:
            ...
        else:
            ...

        ...
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
        )

class WhisperDecoder(WhisperPreTrainedModel): 
    ...
    def forward(
        ...
    ):
        ...
        if self._use_flash_attention_2:
            ...
        elif self._use_sdpa and head_mask is None and not output_attentions:
            # output_attentions=True & head_mask can not be supported when using SDPA.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
        else:
            ...

After:

class WhisperSdpaAttention(WhisperAttention):
    ...
    def forward(
        self,
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        if is_cross_attention:
            key_states = torch.cat([past_key_value[0], self._shape(self.k_proj(key_value_states), -1, bsz)], dim=2)
            value_states = torch.cat([past_key_value[1], self._shape(self.v_proj(key_value_states), -1, bsz)], dim=2)
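            # (Assumption) 1500 is Whisper's encoder output length (30 s of audio
            # after the stride-2 conv), so after concatenating the freshly projected
            # keys/values the cross-attention cache is clipped back to the encoder
            # sequence length instead of growing on every decoding step.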
            key_states = key_states[:,:,:1500,:]
            value_states = value_states[:,:,:1500,:]
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        ...
        if torch.onnx.is_in_onnx_export():
            if self.is_causal:
                attn_output_1 = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
                    is_causal=False
                )
                attn_output_2 = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
                    is_causal=True
                )
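                # (Assumption) torch.onnx.export cannot trace the data-dependent
                # Python bool "tgt_len > 1", so both the non-causal and causal
                # outputs are computed and the correct one is selected at runtime
                # by indexing the stacked results with torch.gt(tgt_len, 1).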
                ind = torch.gt(tgt_len, 1).type(torch.int64)
                sel = torch.stack([attn_output_1, attn_output_2])
                attn_output = sel[ind]
            else:
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_states,
                    key_states,
                    value_states,
                    attn_mask=attention_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
                    is_causal=False,
                )
        else:
            # Original implementation
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.dropout if self.training else 0.0,
                # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
                is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
            )

class WhisperDecoder(WhisperPreTrainedModel): 
    ...
    def forward(
        ...
    ):
        ...
        if self._use_flash_attention_2:
            ...
        elif self._use_sdpa and head_mask is None and not output_attentions:
            # output_attentions=True & head_mask can not be supported when using SDPA.
            # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            #     attention_mask, input_shape, inputs_embeds, past_key_values_length
            # )
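            # (Assumption) during ONNX export the causal masking is handled inside
            # WhisperSdpaAttention (see the is_causal branches above), so the 4D
            # SDPA mask is not built here.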
            attention_mask = None
        else:
            ...

ooe1123 commented May 20, 2024

〇 transformers/generation/utils.py

Before:

class GenerationMixin:
    ...
    def _greedy_search(
        ...
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        ...
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

After:

class GenerationMixin:
    ...
    def _greedy_search(
        ...
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        ...

        if 1:
            class Net(nn.Module):
                def __init__(self, net):
                    super(Net, self).__init__()
                    self.net = net
                def forward(
                        self, decoder_input_ids, encoder_hidden_states,
                        past_key_values_0_decoder_key, past_key_values_0_decoder_value, past_key_values_0_encoder_key, past_key_values_0_encoder_value, past_key_values_1_decoder_key, past_key_values_1_decoder_value, past_key_values_1_encoder_key, past_key_values_1_encoder_value,
                    ):
                    model_inputs = {
                        "decoder_input_ids": decoder_input_ids,
                        "encoder_outputs": [encoder_hidden_states],
                        "past_key_values": [
                            [
                                past_key_values_0_decoder_key,
                                past_key_values_0_decoder_value,
                                past_key_values_0_encoder_key,
                                past_key_values_0_encoder_value,
                            ],
                            [
                                past_key_values_1_decoder_key,
                                past_key_values_1_decoder_value,
                                past_key_values_1_encoder_key,
                                past_key_values_1_encoder_value,
                            ],
                        ],
                    }
                    outputs = self.net(
                        **model_inputs,
                        return_dict=True,
                        output_attentions=output_attentions,
                        output_hidden_states=output_hidden_states,
                    )
                    # return outputs  # Updated
                    return (
                        outputs["logits"],
                        outputs["past_key_values"][0][0].type(torch.float16),
                        outputs["past_key_values"][0][1].type(torch.float16),
                        outputs["past_key_values"][0][2].type(torch.float16),
                        outputs["past_key_values"][0][3].type(torch.float16),
                        outputs["past_key_values"][1][0].type(torch.float16),
                        outputs["past_key_values"][1][1].type(torch.float16),
                        outputs["past_key_values"][1][2].type(torch.float16),
                        outputs["past_key_values"][1][3].type(torch.float16),
                    )

            model = Net(self)
        
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # Add
            if model_inputs["past_key_values"] is None:
                b = model_inputs["encoder_outputs"][0].size(0)
                d = model_inputs["encoder_outputs"][0].device
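                # (Assumption) empty fp16 KV cache of shape (batch, 20 heads,
                # 0 steps, head_dim 64); "* 2" below covers the two decoder layers.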
                model_inputs["past_key_values"] = [
                    [
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                    ]
                ] * 2
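
            # (Assumption) export happens on a later decoding step, once the
            # self-attention cache is non-empty, so every past_key_values input
            # has a real shape when traced.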

            if 1 and 0 < model_inputs["past_key_values"][0][0].size(2):
                print("------>")
                from torch.autograd import Variable
                xx = (
                    Variable(model_inputs["decoder_input_ids"]),
                    Variable(model_inputs["encoder_outputs"].last_hidden_state),
                    Variable(model_inputs["past_key_values"][0][0]),
                    Variable(model_inputs["past_key_values"][0][1]),
                    Variable(model_inputs["past_key_values"][0][2]),
                    Variable(model_inputs["past_key_values"][0][3]),
                    Variable(model_inputs["past_key_values"][1][0]),
                    Variable(model_inputs["past_key_values"][1][1]),
                    Variable(model_inputs["past_key_values"][1][2]),
                    Variable(model_inputs["past_key_values"][1][3]),
                )
                torch.onnx.export(
                    model, xx, 'decoder_model.onnx',
                    input_names=[
                       'input_ids', 'encoder_hidden_states', 'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value', 'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value', 'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value', 'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value', 
                    ],
                    output_names=[
                        'logits',
                        'present.0.decoder.key', 'present.0.decoder.value', 'present.0.encoder.key', 'present.0.encoder.value', 'present.1.decoder.key', 'present.1.decoder.value', 'present.1.encoder.key', 'present.1.encoder.value',
                    ],
                    dynamic_axes={
                        'input_ids': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'encoder_hidden_states': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                        'logits': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'past_key_values.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'present.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.0.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                    },
                    verbose=False, opset_version=14
                )
                print("<------")
                exit(0)

ooe1123 commented May 27, 2024

When exporting with opset=17, the following error occurs, so this change works around it:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from kotoba-whisper-v1.0_decoder.onnx failed:Type Error: Type parameter (T) of Optype (LayerNormalization) bound to different types (tensor(float) and tensor(float16) in node (/net/model/decoder/layers.0/encoder_attn_layer_norm/LayerNormalization).

〇 transformers/models/whisper/modeling_whisper.py

Before:

class WhisperDecoderLayer(nn.Module):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

After:

class WhisperDecoderLayer(nn.Module):
    ...
    def forward(
        ...
    ) -> torch.Tensor:
        ...
        if encoder_hidden_states is not None:
            residual = hidden_states
            hidden_states = hidden_states.type(torch.float16)
            hidden_states = self.encoder_attn_layer_norm(hidden_states)
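
A quick check, assuming the file name from the error message above, that the re-exported decoder now loads; session creation is where the LayerNormalization type error was raised, so a clean load is enough to confirm the dtype mismatch is resolved:

import onnxruntime as ort

# If the float16 cast took effect, this no longer raises the Type Error quoted above.
sess = ort.InferenceSession("kotoba-whisper-v1.0_decoder.onnx")
print("loaded:", [i.name for i in sess.get_inputs()])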
