
Lora finetune #5779

Open · ily6 opened this issue May 9, 2024 · 14 comments
Labels: Bug (bug should be fixed)

ily6 commented May 9, 2024

Hi, I want to use Whisper LoRA fine-tuning for my ASR task.
So I added asr_config=conf/tuning/train_asr_whisper_medium_lora_finetune.yaml and inference_config=conf/tuning/decode_asr_whisper_noctc_beam10.yaml to my run.sh. However, I got:
asr_train.py: error: unrecognized arguments: use_lora (from conf/tuning/train_asr_whisper_medium_lora_finetune.yaml)

Next I replaced

use_lora: true
lora_conf: ...

with:

use_adapter: true
adapter: lora
adapter_conf: ...
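
For anyone hitting the same point, an expanded adapter block might look roughly like the sketch below; the adapter_conf keys (rank, alpha, dropout_rate, target_modules) are illustrative guesses from the adapter code and may differ across ESPnet versions:

use_adapter: true
adapter: lora
adapter_conf:
    rank: 8              # low-rank dimension of the LoRA update
    alpha: 8             # scaling applied to the LoRA update
    dropout_rate: 0.0    # dropout on the LoRA branch
    target_modules:      # submodules to inject LoRA into (guessed names)
        - "query"
        - "key"
        - "value"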

But train.log shows:
Model summary:
Class Name: ESPnetASRModel
Total Number of model parameters: 767.04 M
Number of trainable parameters: 767.04 M (100.0%) (it seems all parameters are being trained)
Size: 3.07 GB
Type: torch.float32

Is there a solution for this? Thanks very much!

espnet version: 202402

ily6 added the Bug label on May 9, 2024
sw005320 (Contributor) commented May 9, 2024

Thanks for your report.
@simpleoier, can you answer it for me?

simpleoier (Collaborator) commented:

Hi @ily6, did you specify freeze_param in your config? It is needed to freeze the original parameters under the new adapter interface added by @Stanwang1210.

ily6 (Author) commented May 9, 2024

Thanks for your reply! I didn't specify freeze_param; I just followed the original conf/tuning/train_asr_whisper_medium_lora_finetune.yaml. What should I set freeze_param to?

simpleoier (Collaborator) commented:

Can you try the following?

unused_parameters: true
freeze_param: [
"encoder",
"decoder",
]
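
Once training restarts, you can also double-check programmatically that only the adapter parameters remain trainable. A minimal, generic PyTorch sketch (not ESPnet-specific; pass in whatever built model object you have at hand):

import torch

def count_trainable(model: torch.nn.Module) -> None:
    # Separate trainable from frozen parameters and report the ratio.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"trainable: {trainable / 1e6:.2f} M / total: {total / 1e6:.2f} M "
          f"({100.0 * trainable / total:.1f}%)")

# count_trainable(asr_model)  # with LoRA + freeze_param, expect a small fraction, not 100%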

ily6 (Author) commented May 9, 2024

It works! Thanks very much.

ily6 (Author) commented May 16, 2024

Hi, I've encountered a new issue. When using Whisper for zero-shot evaluation, I directly modified the config.yaml file with:

max_epoch: 1
optim_conf:
    lr: 0.0

But the CER (%) on the Aishell-1 test set was 174, and the hypotheses are as follows:
[screenshot of decoded hypotheses]

I would like to know if there are any solutions to this issue. Alternatively, is there any run.sh file available that I can use directly for zero-shot evaluation?

Stanwang1210 (Contributor) commented:

Did you initialize the LoRA parameters without fine-tuning them?

ily6 (Author) commented May 16, 2024

No, I used asr_config=conf/tuning/train_asr_whisper_medium_full_finetune.yaml and inference_config=conf/tuning/decode_asr_whisper_noctc_beam10.yaml.

Stanwang1210 (Contributor) commented:

Did you correctly set --token_type to whisper_multilingual? See here.
From the figure you provided, the Chinese transcriptions look reasonable.
Therefore, I suspect something is wrong with the tokenizer, which may lead to the high CER.
Also, from my own experience, the version of the transformers module is quite important; using the wrong version of transformers may also lead to the same issue.
Please try transformers==4.28.1 or a version close to it.
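
If it helps, pinning the version is just the usual pip pin, in whichever environment the recipe uses:

pip install transformers==4.28.1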

If the issue can't be fixed, please provide more information.
It will help a lot.

ily6 (Author) commented May 16, 2024

Thank you for your response. My training config is as follows:

normalize: null

encoder: whisper
encoder_conf:
    whisper_model: medium
    dropout_rate: 0.0
    use_specaug: true
    specaug_conf:
        apply_time_warp: true
        time_warp_window: 5
        time_warp_mode: bicubic
        apply_freq_mask: true
        freq_mask_width_range:
        - 0
        - 40
        num_freq_mask: 2
        apply_time_mask: true
        time_mask_width_ratio_range:
        - 0.
        - 0.12
        num_time_mask: 5


decoder: whisper
decoder_conf:
    whisper_model: medium
    dropout_rate: 0.0

preprocessor: default
preprocessor_conf:
    whisper_language: "zh"
    whisper_task: "transcribe"

model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    extract_feats_in_collect_stats: false
    sym_sos: "<|startoftranscript|>"
    sym_eos: "<|endoftext|>"
    # do_pad_trim: true         # should be set when doing zero-shot inference

frontend: null
input_size: 1                   # to prevent build_model() from complaining

seed: 2022
log_interval: 100
num_att_plot: 0
num_workers: 4
sort_in_batch: descending       # how to sort data in making batch
sort_batch: descending          # how to sort created batches
batch_type: numel
batch_bins: 3000000            # good for 8 * RTX 3090 24G
accum_grad: 16
max_epoch: 1
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 1

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

optim: adamw
grad_clip: 1.0
optim_conf:
    lr: 0.0
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 1500

And my run.sh is:

#!/usr/bin/env bash
# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="test"

asr_config=conf/tuning/train_asr_whisper_medium_finetune.yaml
inference_config=conf/tuning/decode_asr_whisper_noctc_beam10.yaml

lm_config=conf/train_lm_transformer.yaml
use_lm=false
use_wordlm=false

# speed perturbation related
# (train_set will be "${train_set}_sp" if speed_perturb_factors is specified)
speed_perturb_factors="0.9 1.0 1.1"

./asr_test.sh \
    --nj 32 \
    --gpu_inference true \
    --inference_nj 1 \
    --lang zh \
    --token_type whisper_multilingual \
    --feats_normalize "" \
    --audio_format "wav" \
    --feats_type raw \
    --use_lm ${use_lm}                                 \
    --use_word_lm ${use_wordlm}                        \
    --lm_config "${lm_config}"                         \
    --cleaner whisper_basic                            \
    --asr_config "${asr_config}"                       \
    --inference_config "${inference_config}"           \
    --train_set "${train_set}"                         \
    --valid_set "${valid_set}"                         \
    --test_sets "${test_sets}"                         \
    --speed_perturb_factors "${speed_perturb_factors}" \
    --asr_speech_fold_length 512 \
    --asr_text_fold_length 150 \
    --lm_fold_length 150 \
    --lm_train_text "data/${train_set}/text" "$@"

I found that the decoding results on Aishell are repetitive, for example:

BAC009S0764W0121 
甚至出现交易几乎停滞的情况甚至出现交易几乎停滞的情况甚至出现交易几乎停滞的情况甚至出现交易几乎停滞的情况甚至出
BAC009S0764W0123 但因為聚集了過多公共事源,但因為聚集了過多公共事源,但因為聚集了過多公共事源,但因為聚集了過多公共事源,但因為聚集了過多

This leads to many insertion errors, and the final result is as follows:

|     SPKR        |      # Snt           # Wrd      |     Corr              Sub             Del             Ins              Err           S.Err      |
|     Sum/Avg     |      7176           104765      |     80.3             15.6             4.2           154.5            174.2            90.6      |

My transformers version is 4.40.2, and with this version I obtained results consistent with the official ESPnet results when running LoRA fine-tuning and full fine-tuning experiments on the Aishell-1 dataset. So could it be that the issue is not related to the transformers version?
Could setting the learning rate to 0 and training for one epoch cause any problems? Or could it be the parameter "do_pad_trim: true"?

Stanwang1210 (Contributor) commented:

Sorry for the late reply.
Could you please check whether the parameters in your checkpoint align with the original Whisper checkpoint? That way we can rule out

learning rate to 0 and training for one epoch
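
A rough way to do that check (the checkpoint path is a placeholder and the suffix-based key matching is an assumption about how ESPnet prefixes the Whisper parameter names; adjust both to your setup):

import torch
import whisper  # openai-whisper

# Placeholder path -- point this at your averaged/fine-tuned ESPnet checkpoint.
ft_state = torch.load("exp/asr_whisper_medium/valid.acc.ave.pth", map_location="cpu")
orig_state = whisper.load_model("medium", device="cpu").state_dict()

# Match tensors by name suffix and shape, ignoring whatever wrapper prefix
# ESPnet adds in front of the Whisper parameter names, then compare values.
matched, changed = 0, 0
for ft_name, ft_param in ft_state.items():
    for orig_name, orig_param in orig_state.items():
        if ft_name.endswith(orig_name) and ft_param.shape == orig_param.shape:
            matched += 1
            if not torch.allclose(ft_param.float(), orig_param.float(), atol=1e-5):
                changed += 1
            break
print(f"{changed} of {matched} matched tensors differ from the original Whisper weights")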

And given the information you provided, it may not be related to the transformers version.

From your inference samples, it looks like you are hitting the hallucination problem described here.
If that's the case, the problem will not be easy to solve. The official Whisper code implements some post-processing to deal with that issue; you can take a look at it.
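
For reference, running an utterance through the official package with its repetition safeguards looks roughly like this (standalone, outside the ESPnet decoding path; the wav filename is just a placeholder):

import whisper

model = whisper.load_model("medium")
result = model.transcribe(
    "BAC009S0764W0121.wav",               # placeholder utterance
    language="zh",
    task="transcribe",
    # fall back to higher temperatures when decoding looks degenerate
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    # reject segments whose text compresses too well (a sign of repetition)
    compression_ratio_threshold=2.4,
    # reject segments with very low average log-probability
    logprob_threshold=-1.0,
    # do not feed the previous (possibly hallucinated) text back as a prompt
    condition_on_previous_text=False,
)
print(result["text"])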

Yuanyuan-888 commented:

Hi, I am also using this recipe, but at asr.sh stage 5 I get a Whisper error, "tokenizer object has no attribute tokenizer". I am using espnet 202402 and openai-whisper 202311.

simpleoier (Collaborator) commented:

@Yuanyuan-888 The quick fix is to try an earlier version of whisper, 20230308. Whisper has changed their tokenizer API.
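
For example, pinning it in the recipe's environment:

pip install openai-whisper==20230308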

Yuanyuan-888 commented:

@Yuanyuan-888 The quick fix is to try an earlier version of whisper, 20230308. Whisper has changed their tokenizer API.

Hi! Thank you for your answer. But then it will no longer be possible to use Whisper large-v3 for decoding.
