Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADD distil-whisper #1306

Closed
kyakuno opened this issue Nov 9, 2023 · 4 comments
Closed

ADD distil-whisper #1306

kyakuno opened this issue Nov 9, 2023 · 4 comments
Assignees

Comments

@kyakuno
Copy link
Collaborator

kyakuno commented Nov 9, 2023

https://github.com/huggingface/distil-whisper

@kyakuno
Copy link
Collaborator Author

kyakuno commented Nov 9, 2023

@ooe1123 連続してwhisperになってしまうのですが、large-v3のあとに、distil-whisper large-v2のエクスポートを検討いただけると嬉しいです。

@kyakuno
Copy link
Collaborator Author

kyakuno commented Nov 9, 2023

(間に他のモデルをエクスポートして頂いても構いません)

@ooe1123
Copy link
Contributor

ooe1123 commented Nov 28, 2023

  • decoder

〇 transformers\models\whisper\modeling_whisper.py

# Excerpt (abridged with "...") of the original Hugging Face
# transformers WhisperAttention.forward, quoted here to show the
# 4-way branch that decides how the key/value states are built.
class WhisperAttention(nn.Module):
    ...
    def forward(
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        # Branch 1: cross-attention where the cached K/V length already
        # matches the encoder output length -> reuse the cache as-is.
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            ...
        # Branch 2: cross-attention without a usable cache.
        elif is_cross_attention:
            ...
        # Branch 3: self-attention with a cache to append to.
        elif past_key_value is not None:
            ...
        # Branch 4: self-attention, first step (no cache yet).
        else:
            ...

# Module-level switch (a 1-element list so it can be mutated from other
# modules without `global`).  False = original transformers behavior;
# True = the ONNX-export-friendly path below.
ctrl_flg = [False]

# Patched version of WhisperAttention.forward (abridged with "..."):
# when ctrl_flg is set, the data-dependent 4-way branch is replaced by a
# branch-free formulation that always concatenates onto the incoming
# cache, so the exported ONNX graph has a fixed structure.
class WhisperAttention(nn.Module):
    ...
    def forward(
        ...
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        ...
        # Select between the original and the export-friendly K/V paths.
        if ctrl_flg[0] is False:
            # Original code (transformers upstream behavior).
            if (
                is_cross_attention
                and past_key_value is not None
                and past_key_value[0].shape[2] == key_value_states.shape[1]
            ):
                ...
            elif is_cross_attention:
                ...
            elif past_key_value is not None:
                ...
            else:
                ...
        else:
            if is_cross_attention:
                # Always concat cache + freshly projected encoder states,
                # then truncate.  1500 is presumably the Whisper encoder
                # sequence length for a 30 s window — TODO confirm.
                key_states = torch.cat([past_key_value[0], self._shape(self.k_proj(key_value_states), -1, bsz)], dim=2)
                value_states = torch.cat([past_key_value[1], self._shape(self.v_proj(key_value_states), -1, bsz)], dim=2)
                key_states = key_states[:,:,:1500,:]
                value_states = value_states[:,:,:1500,:]
            else:
                # Self-attention: project the current step and append to the
                # cache unconditionally (cache may be zero-length on step 1).
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

〇 transformers\generation\utils.py

# Excerpt (abridged with "...") of the original
# transformers GenerationMixin.greedy_search loop, quoted for
# comparison with the patched version below.
class GenerationMixin:
    ...
    def greedy_search(
        ...
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        ...
        this_peer_finished = False  # used by synced_gpus only
        while True:
            ...
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

# Patched greedy_search (abridged with "..."): on the first decode step it
# fabricates a zero-length KV cache, and on the second step it wraps the
# model in a flattening nn.Module and exports it to ONNX, then exits.
class GenerationMixin:
    ...
    def greedy_search(
        ...
    ) -> Union[GreedySearchOutput, torch.LongTensor]:
        ...
        # Add: switch WhisperAttention into the export-friendly path.
        from transformers.models.whisper.modeling_whisper import ctrl_flg
        ctrl_flg[0] = True

        this_peer_finished = False  # used by synced_gpus only
        while True:
            ...
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # Add: on the first step there is no cache yet, so seed it with
            # zero-length tensors (seq dim = 0) so the attention code can
            # torch.cat unconditionally.  Shapes (b, 20, 0, 64): 20 heads x
            # 64 head-dim — presumably Whisper-large geometry, and the outer
            # "* 2" presumably matches distil-whisper's 2 decoder layers —
            # TODO confirm.
            # NOTE(review): `[...] * 2` makes both layer entries alias the
            # SAME inner list; harmless here since the entries are only read.
            if model_inputs["past_key_values"] is None:
                b = model_inputs["encoder_outputs"][0].size(0)
                d = model_inputs["encoder_outputs"][0].device
                model_inputs["past_key_values"] = [
                    [
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                        torch.zeros(b, 20, 0, 64, dtype=torch.float16).to(d),
                    ]
                ] * 2

            # Export once the cache is non-empty (i.e. from step 2 onward).
            # The leading "1 and" is a debug toggle paired with the
            # commented-out "if 0:" line below.
            if 1 and 0 < model_inputs["past_key_values"][0][0].size(2):
            # if 0:
                # Wrapper that flattens the nested past_key_values structure
                # into individual positional tensors, as torch.onnx.export
                # requires flat tensor inputs/outputs.
                class Net(nn.Module):
                    def __init__(self, net):
                        super(Net, self).__init__()
                        self.net = net
                    def forward(
                            self, decoder_input_ids, encoder_hidden_states,
                            past_key_values_0_decoder_key, past_key_values_0_decoder_value, past_key_values_0_encoder_key, past_key_values_0_encoder_value, past_key_values_1_decoder_key, past_key_values_1_decoder_value, past_key_values_1_encoder_key, past_key_values_1_encoder_value,
                        ):
                        # Re-nest the flat inputs into the dict the model expects.
                        model_inputs = {
                            "decoder_input_ids": decoder_input_ids,
                            "encoder_outputs": [encoder_hidden_states],
                            "past_key_values": [
                                [
                                    past_key_values_0_decoder_key,
                                    past_key_values_0_decoder_value,
                                    past_key_values_0_encoder_key,
                                    past_key_values_0_encoder_value,
                                ],
                                [
                                    past_key_values_1_decoder_key,
                                    past_key_values_1_decoder_value,
                                    past_key_values_1_encoder_key,
                                    past_key_values_1_encoder_value,
                                ],
                            ],
                        }
                        outputs = self.net(
                            **model_inputs,
                            return_dict=True,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                        )
                        # Flatten the outputs back out: logits plus the
                        # updated ("present") K/V tensors of both layers.
                        return (
                            outputs["logits"],
                            outputs["past_key_values"][0][0],
                            outputs["past_key_values"][0][1],
                            outputs["past_key_values"][0][2],
                            outputs["past_key_values"][0][3],
                            outputs["past_key_values"][1][0],
                            outputs["past_key_values"][1][1],
                            outputs["past_key_values"][1][2],
                            outputs["past_key_values"][1][3],
                        )

                model = Net(self)
                print("------>")
                from torch.autograd import Variable
                # Trace inputs taken from the live step-2 tensors.
                xx = (
                    Variable(model_inputs["decoder_input_ids"]),
                    Variable(model_inputs["encoder_outputs"].last_hidden_state),
                    Variable(model_inputs["past_key_values"][0][0]),
                    Variable(model_inputs["past_key_values"][0][1]),
                    Variable(model_inputs["past_key_values"][0][2]),
                    Variable(model_inputs["past_key_values"][0][3]),
                    Variable(model_inputs["past_key_values"][1][0]),
                    Variable(model_inputs["past_key_values"][1][1]),
                    Variable(model_inputs["past_key_values"][1][2]),
                    Variable(model_inputs["past_key_values"][1][3]),
                )
                # Export with batch and sequence dims marked dynamic; the
                # input/output names follow the optimum/ONNX Runtime
                # "past_key_values.N.*" / "present.N.*" convention.
                torch.onnx.export(
                    model, xx, 'decoder_model.onnx',
                    input_names=[
                       'input_ids', 'encoder_hidden_states', 'past_key_values.0.decoder.key', 'past_key_values.0.decoder.value', 'past_key_values.0.encoder.key', 'past_key_values.0.encoder.value', 'past_key_values.1.decoder.key', 'past_key_values.1.decoder.value', 'past_key_values.1.encoder.key', 'past_key_values.1.encoder.value', 
                    ],
                    output_names=[
                        'logits',
                        'present.0.decoder.key', 'present.0.decoder.value', 'present.0.encoder.key', 'present.0.encoder.value', 'present.1.decoder.key', 'present.1.decoder.value', 'present.1.encoder.key', 'present.1.encoder.value',
                    ],
                    dynamic_axes={
                        'input_ids': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'encoder_hidden_states': {0: 'batch_size', 1: 'encoder_sequence_length / 2'},
                        'logits': {0: 'batch_size', 1: 'decoder_sequence_length'},
                        'past_key_values.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.0.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'past_key_values.1.encoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length'},
                        'present.0.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.0.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.0.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.decoder.key': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.decoder.value': {0: 'batch_size', 2: 'past_decoder_sequence_length + 1'},
                        'present.1.encoder.key': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                        'present.1.encoder.value': {0: 'batch_size', 2: 'encoder_sequence_length_out'},
                    },
                    verbose=False, opset_version=11
                )
                print("<------")
                # One-shot export tool: stop the process once the file is written.
                exit(0)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

@ooe1123
Copy link
Contributor

ooe1123 commented Nov 28, 2023

encoderはoptimum-cli でエクスポート可能

@kyakuno kyakuno closed this as completed Dec 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants