In [21]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [22]:
import torch
import torchaudio
from datasets import load_dataset
from transformers import AutoTokenizer
from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal

# Параметры устройства
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Специальные токены
start_audio_token = "<soa>"
end_audio_token = "<eoa>"
end_sequence_token = "<eos>"

# Константы
n_codebooks = 3
max_seq_length = 1024
top_k = 20

from safetensors.torch import load_file

import transformers
import torch
from torch.utils.data import Dataset
from datasets import load_dataset, Audio, Value

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
)

from speechtokenizer import SpeechTokenizer
from audiotools import AudioSignal
device = "cuda:0"
n_special_tokens = 3
model_path = "Vikhrmodels/llama_asr_tts_24000"

tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=".")
model = AutoModelForCausalLM.from_pretrained(model_path, cache_dir=".", attn_implementation="eager", device_map={"":0})
# Загрузка токенизатора и модели

# Загрузка квантизатора
config_path = "../../audiotokenizer/speechtokenizer_hubert_avg_config.json"
ckpt_path = "../../audiotokenizer/SpeechTokenizer.pt"
quantizer = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
quantizer.eval()

# Перемещение всех слоев квантизатора на устройство и их заморозка
def freeze_entire_model(model):
    for n, p in model.named_parameters():
        p.requires_grad = False
    return model

for n, child in quantizer.named_children():
    child.to(device)
    child = freeze_entire_model(child)

# Функция для создания токенов заполнения для аудио
def get_audio_padding_tokens(quantizer):
    audio = torch.zeros((1, 1, 1)).to(device)
    codes = quantizer.encode(audio)
    del audio
    torch.cuda.empty_cache()
    return {"audio_tokens": codes.squeeze(1)}

# Функция для декодирования аудио из токенов
def decode_audio(tokens, quantizer, pad_tokens, n_original_tokens):
    start = torch.nonzero(tokens == tokenizer(start_audio_token)["input_ids"][-1])
    end = torch.nonzero(tokens == tokenizer(end_audio_token)["input_ids"][-1])
    start = start[0, -1] + 1 if len(start) else 0
    end = end[0, -1] if len(end) else tokens.shape[-1]
    
    audio_tokens = tokens[start:end] % n_original_tokens
    reminder = audio_tokens.shape[-1] % n_codebooks
    
    if reminder:
        audio_tokens = torch.cat([audio_tokens, pad_tokens[reminder:n_codebooks]], dim=0)

    transposed = audio_tokens.view(-1, n_codebooks).t()
    codes = transposed.view(n_codebooks, 1, -1).to(device)
    
    audio = quantizer.decode(codes).squeeze(0)
    torch.cuda.empty_cache()
    return AudioSignal(audio.detach().cpu().numpy(), quantizer.sample_rate)


# Пример использования



In [23]:
# Функция инференса для текста на входе и аудио на выходе
def infer_text_to_audio(text, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    text_tokenized = tokenizer(text, return_tensors="pt")
    text_input_tokens = text_tokenized["input_ids"].to(device)
    
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    
    text_tokens = torch.cat([text_input_tokens, soa], dim=1)
    attention_mask = torch.ones(text_tokens.size(), device=device)
    
    output_audio_tokens = model.generate(text_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)
    
    padding_tokens = get_audio_padding_tokens(quantizer)["audio_tokens"].to(device)
    audio_signal = decode_audio(output_audio_tokens[0], quantizer, padding_tokens.t()[0], len(tokenizer) - 1024)
    
    return audio_signal





In [24]:
# Функция инференса для аудио на входе и текста на выходе
def infer_audio_to_text(audio_path, model, tokenizer, quantizer, max_seq_length=1024, top_k=20):
    audio_data, sample_rate = torchaudio.load(audio_path)
    
    audio = audio_data.view(1, 1, -1).float().to(device)
    codes = quantizer.encode(audio)
    n_codebooks_a = 1
    raw_audio_tokens = codes[:, :n_codebooks_a] + len(tokenizer) - 1024
    
    soa = tokenizer(start_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    eoa = tokenizer(end_audio_token, return_tensors="pt")["input_ids"][:, -1:].to(device)
    audio_tokens = torch.cat([soa, raw_audio_tokens.view(1, -1), eoa], dim=1)
    
    attention_mask = torch.ones(audio_tokens.size(), device=device)
    
    output_text_tokens = model.generate(audio_tokens, attention_mask=attention_mask, max_new_tokens=max_seq_length, top_k=top_k, do_sample=True)
    
    output_text_tokens = output_text_tokens.cpu()[0]
    output_text_tokens = output_text_tokens[output_text_tokens < tokenizer(start_audio_token)["input_ids"][-1]]
    decoded_text = tokenizer.decode(output_text_tokens, skip_special_tokens=True)
    
    return decoded_text

In [25]:
# Инференс текста в аудио
text = "Я хожу пешком.".upper()
audio_signal = infer_text_to_audio(text, model, tokenizer, quantizer)
audio_signal.write("generated_audio_.wav")



<audiotools.core.audio_signal.AudioSignal at 0x7fe04f104a10>

In [26]:
# Инференс аудио в текст
audio_path = "generated_audio.wav"
generated_text = infer_audio_to_text(audio_path, model, tokenizer, quantizer)
print(generated_text)

> THAN THIS NOTHAN SOLITUDE ALBERT YOU'LL BE IN COURSE WE'RE IN HERE AND THERE'S BUT FOR WITH A PHILIP OF WILL THERE WON'T FORGET WHAT'S THE NORTH RULE AH HOW CAN WE'LL BE BORN I'VE EVER THOUGHT IF YOU'RE BUT WHAT'S IT BUT A MATTER OF NONSENEW IF I'D BET THAT IS IT ANOTHER MAN NORA NORWAY TO IT IT'S RONICKY DOONE HE'S SOH NOW I'VE A ROSTOV MANAGED TO TRY HIM BACK ARE A RESOLUTION OF THIS STONE THAT'S SOME I'LL LET'S SEE HERE BETTER NOW
