In [1]:
from llm_asr import GPTModel, WavLM, LLMASR
import soundfile as sf
import torch
import numpy as np
from IPython.display import Audio
from peft import LoraConfig, get_peft_model
from transformers import GenerationConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
x, fs = sf.read('/mnt/data/mls_spanish_opus/test/audio/10667/6706/10667_6706_000000.opus')

In [3]:
Audio(x,rate=fs)

In [4]:
lora_rank=16
warmup_steps=100
llm_model='Qwen/Qwen2-1.5B'
wavlm_model='microsoft/wavlm-base-plus'
lr=1e-5

llm = GPTModel(llm_model)
llm_model = llm.get_model()

if lora_rank > 0:
    lora_config = LoraConfig(
                    r=lora_rank,
                    lora_alpha=16,
                    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
                    lora_dropout=0.1,
                    task_type='CAUSAL_LM'
                )
    llm_model = get_peft_model(llm_model, lora_config)
llm.set_model(llm_model)
    
wavlm = WavLM(wavlm_model)
llm_asr = LLMASR(llm, wavlm, lr, warmup_steps=warmup_steps)

if llm.tokenizer.pad_token is None:
    llm.tokenizer.pad_token = llm.tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Some weights of the model checkpoint at microsoft/wavlm-base-plus were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at microsoft/wavlm-base-plus and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametriza

In [5]:
llm_asr.load_state_dict(torch.load('/home/lpepino/LLM-ASR/qwen2-1B_lora16_wavlm_layer10_instruction/logs_metrics/version_6/checkpoints/epoch=4-step=1080.ckpt', map_location='cpu')['state_dict'])

<All keys matched successfully>

In [32]:
x, fs = sf.read('/mnt/data/mls_spanish_opus/test/audio/10667/6706/10667_6706_000000.opus')
x = torch.from_numpy(x)[None,:]
x = x.to(dtype=torch.bfloat16, device='cpu')
llm_asr.to(dtype=torch.bfloat16, device='cpu')
generation_config = GenerationConfig(max_new_tokens=128, do_sample=True, temperature=0.1, min_new_tokens=5, num_beams=4, eos_token_id=llm_asr.llm_model.tokenizer.eos_token_id)
with torch.no_grad():
    prefix = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nGenerate transcription of the given speech input \n\n### Input:\n<|endoftext|>"
    prefix_embeds = llm_asr.llm_model.tokenizer(prefix)
    print(prefix_embeds)
    prefix_embeds = llm_asr.llm_model.get_lut()(torch.tensor(prefix_embeds['input_ids']))
    print(prefix_embeds.shape)
    speech = llm_asr.wav_model(x)
    speech = torch.nn.functional.avg_pool1d(speech.transpose(1,2), kernel_size=4, stride=4).transpose(1,2)
    speech = llm_asr.wav_projector(speech)
    postfix = '### Response:\n'
    postfix_embeds = llm_asr.llm_model.tokenizer(postfix)
    postfix_embeds = llm_asr.llm_model.get_lut()(torch.tensor(postfix_embeds['input_ids']))
    speech = torch.cat([prefix_embeds.unsqueeze(0), speech, postfix_embeds.unsqueeze(0)],axis=1)


    print(speech.shape)
    outs = llm_asr.llm_model.model.model.generate(inputs_embeds=speech, generation_config = generation_config, tokenizer=llm_asr.llm_model.tokenizer)

{'input_ids': [38214, 374, 458, 7600, 429, 16555, 264, 3383, 11, 34426, 448, 458, 1946, 429, 5707, 4623, 2266, 13, 9645, 264, 2033, 429, 34901, 44595, 279, 1681, 382, 14374, 29051, 510, 31115, 45840, 315, 279, 2661, 8806, 1946, 4710, 14374, 5571, 510, 151643], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
torch.Size([42, 1536])


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


torch.Size([1, 232, 1536])


In [34]:
llm_asr.llm_model.tokenizer.decode(outs[0])

'todas las tardes á la hora del te se acordaban siempre del l o y recordaban también un cuán tole gustaba comer pan mojado en té con leche pobre pedrito nunca más lo vería en por qué había muerto<|endoftext|>'