In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from encodec import EncodecModel
from encodec.utils import convert_audio
import torch
import torchaudio
import os
import re
from tqdm import tqdm

In [2]:
encodec_model = EncodecModel.encodec_model_24khz()
encodec_model.set_target_bandwidth(1.5)
model_name = "anforsm/distilgpt2-finetuned-common-voice"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading (…)lve/main/config.json: 100%|██████████| 1.01k/1.01k [00:00<00:00, 337kB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading (…)"pytorch_model.bin";: 100%|██████████| 337M/337M [00:08<00:00, 39.3MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 119/119 [00:00<00:00, 33.2kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 261/261 [00:00<00:00, 86.9kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 798k/798k [00:00<00:00, 1.25MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:01<00:00, 418kB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.30M/2.30M [00:02<00:00, 817kB/s]
Downloading (…)in/added_tokens.json: 100%|██████████| 28.6k/28.6k [00:00<00:00, 307kB/s]
Downloading (…)cial_tokens_map.j

In [9]:
text = 'text: The track appears on the compilation album "Kraftworks".\nsound:'
tokenized = tokenizer(text, return_tensors="pt")
tokens = model.generate(tokenized["input_ids"], do_sample=True, max_length=1000, top_k=50, top_p=0.95, temperature=0.3, num_return_sequences=1)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [7]:
def decode(tokens):
    decoded = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Get all audio_token_
    pattern = r'audio_token_(\d+)'
    audio_tokens = re.findall(pattern, decoded)
    audio_tokens = [int(token) for token in audio_tokens]

    print(audio_tokens)

    number_of_codebooks = 2
    number_of_samples = len(audio_tokens) // number_of_codebooks
    frame = torch.zeros(1, number_of_codebooks, number_of_samples, dtype=torch.long)
    for sample in range(number_of_samples):
        for codebook in range(number_of_codebooks):
            frame[0, codebook, sample] = audio_tokens[sample * number_of_codebooks + codebook]
    
    frames = [(frame, None)]

    with torch.no_grad():
        wav = encodec_model.decode(frames)

    torchaudio.save("output.wav", wav[0, :, :], encodec_model.sample_rate)

In [10]:
decode(tokens)

[62, 913, 62, 424, 62, 424, 408, 544, 408, 544, 835, 913, 835, 544, 835, 913, 835, 518, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 408, 424, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 408, 424, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 518, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 835, 913, 408, 424, 408, 424, 408, 424, 408, 518, 408, 544, 408, 544, 408, 518, 408, 544, 408, 544, 408, 544, 408, 544, 408, 544, 62, 424, 62, 424, 62, 424, 62, 424, 62, 424, 62, 424, 62, 424, 62, 424, 62, 424, 408, 544, 408, 544, 408, 544, 408, 544, 408, 544, 408, 913, 408, 913, 408, 913, 408, 913, 408, 913, 408, 913, 408, 913, 62, 424, 62, 424, 62, 424, 62, 424, 62, 424, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518, 62, 518,

In [6]:
tokens

tensor([[ 5239,    25, 18435,   616,  1438,   318,   198, 23661,    25,   220,
         51037, 51170, 50319, 50681, 50665, 50801, 50665, 50801, 50665, 50801,
         50665, 50801, 50665, 50775, 50665, 50775, 50665, 50775, 50665, 50775,
         50665, 50775, 50665, 50775, 50665, 50775, 50665, 50775, 50665, 50775,
         50665, 50775, 50665, 50775, 50665, 50775, 50665, 50775, 50665, 51170,
         50665, 51170, 50665, 50775, 50665, 50801, 50665, 50775, 50665, 50775,
         50665, 51170, 50665, 50775, 50665, 50775, 50665, 50775, 50665, 51170,
         50665, 51170, 50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681,
         50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681,
         50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681,
         50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681,
         50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681, 50319, 50681,
         50319, 50681, 50319, 50681, 50319, 50681, 5