In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from encodec import EncodecModel
from encodec.utils import convert_audio
import torch
import torchaudio
import os
import re
from tqdm import tqdm

In [2]:
import torch
torch.version.__version__

'1.13.1+cu117'

In [3]:
encodec_model = EncodecModel.encodec_model_24khz()
encodec_model.set_target_bandwidth(1.5)
tokenizer = AutoTokenizer.from_pretrained("anforsm/distilgpt2-finetuned-common-voice")
#model_name = "Ekgren/text-audio-distilgpt2"
#model = AutoModelForCausalLM.from_pretrained(model_name)

In [4]:
model = AutoModelForCausalLM.from_pretrained("Ekgren/distilgpt2-finetuned-common-voice")

In [28]:
text = 'text: i am taking out the trash\nsound:'
tokenized = tokenizer(text, return_tensors="pt")
tokens = model.generate(tokenized["input_ids"], do_sample=True, max_length=1024, temperature=1, top_k=50, top_p=0.95)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [6]:
def decode(tokens):
    decoded = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Get all audio_token_
    pattern = r'audio_token_(\d+)'
    audio_tokens = re.findall(pattern, decoded)
    audio_tokens = [int(token) for token in audio_tokens]

    print(audio_tokens)

    number_of_codebooks = 2
    number_of_samples = len(audio_tokens) // number_of_codebooks
    frame = torch.zeros(1, number_of_codebooks, number_of_samples, dtype=torch.long)
    for sample in range(number_of_samples):
        for codebook in range(number_of_codebooks):
            frame[0, codebook, sample] = audio_tokens[sample * number_of_codebooks + codebook]
    
    frames = [(frame, None)]

    with torch.no_grad():
        wav = encodec_model.decode(frames)

    torchaudio.save("output.wav", wav[0, :, :], encodec_model.sample_rate)

In [29]:
decode(tokens)

[62, 913, 62, 424, 62, 424, 62, 424, 408, 544, 408, 544, 62, 424, 408, 913, 408, 913, 465, 860, 592, 60, 278, 725, 158, 86, 544, 123, 12, 22, 476, 437, 321, 349, 224, 390, 491, 717, 276, 991, 59, 336, 790, 702, 986, 702, 563, 349, 931, 317, 796, 728, 866, 702, 766, 847, 766, 22, 766, 881, 74, 307, 18, 368, 609, 728, 766, 881, 12, 368, 487, 361, 766, 16, 311, 349, 86, 454, 906, 459, 944, 468, 303, 808, 303, 716, 303, 728, 303, 219, 303, 336, 53, 266, 53, 219, 373, 920, 373, 898, 103, 857, 373, 898, 533, 645, 476, 195, 952, 677, 920, 311, 43, 578, 858, 931, 696, 116, 875, 770, 347, 622, 971, 73, 875, 752, 777, 872, 323, 864, 91, 350, 523, 436, 656, 797, 605, 745, 312, 619, 605, 621, 648, 359, 626, 53, 86, 233, 636, 656, 597, 433, 638, 195, 280, 410, 551, 655, 605, 406, 464, 451, 626, 260, 86, 451, 312, 405, 778, 227, 312, 295, 651, 815, 491, 942, 953, 6, 59, 179, 505, 351, 3, 96, 86, 107, 912, 192, 760, 437, 929, 390, 467, 944, 306, 429, 704, 269, 186, 317, 306, 317, 530, 35, 311, 951, 8

In [23]:
decode([tokenizer("My name is Teven and I am not a woman.\nsound: audio_token_63audio_token_937audio_token_835audio_token_740audio_token_835audio_token_302audio_token_339audio_token_913audio_token_1019audio_token_747audio_token_228audio_token_601audio_token_875audio_token_646audio_token_670audio_token_815audio_token_672audio_token_777audio_token_570audio_token_815audio_token_86audio_token_986audio_token_767audio_token_572audio_token_644audio_token_711audio_token_621audio_token_881audio_token_996audio_token_881audio_token_278audio_token_881audio_token_39audio_token_940")["input_ids"]])

[63, 937, 835, 740, 835, 302, 339, 913, 1019, 747, 228, 601, 875, 646, 670, 815, 672, 777, 570, 815, 86, 986, 767, 572, 644, 711, 621, 881, 996, 881, 278, 881, 39, 940]
