In [43]:
from snac import SNAC
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().cuda()
import numpy as np
from tqdm import tqdm

In [44]:
file_path = "speech.wav"
import os
if not os.path.exists(file_path):
    raise FileNotFoundError(f"The file '{file_path}' does not exist. Please make sure the audio file is in the correct location.")

In [45]:
from scipy.io import wavfile
sample_rate, audio_data = wavfile.read(file_path)
print(f"Sample rate: {sample_rate} Hz")
print(f"Some points of audio data: {audio_data[900:1000]}")


Sample rate: 16000 Hz
Some points of audio data: [   1    9  -19    1    7   -5   -5    4  -17    3   12   17  -14   -8
   12   14   -3   -2    7   12   -3   -1    4  -13   13   17  -14  -18
   56    6  -17   59   26    9   16   75  -17  -71  -16   92  127 -147
   84  133 -196  264   54   36 -527 -182  237 -359  373  115 -741  760
 -375 -335  576 -114  119  141 -311  -63  122 -173  223  -91 -131   65
    5  141  145  -26  -43    2 -144 -188  -18 -175   60   70 -158  268
   81   56  412  -45  202    9 -233  -66 -182 -272 -195 -144 -103  -98
 -181   -3]


In [46]:
import torchaudio
print("Available audio backends:", torchaudio.list_audio_backends())
print("Current audio backend:", torchaudio.get_audio_backend())

Available audio backends: ['soundfile']
Current audio backend: None


  print("Current audio backend:", torchaudio.get_audio_backend())


In [47]:

waveform, sr = torchaudio.load(file_path)
waveform = torchaudio.functional.resample(waveform, sr, 24000)
waveform.shape


torch.Size([1, 94668])

In [48]:
from IPython.display import Audio
Audio(waveform.numpy(), rate=24000)

In [65]:
import torch
waveform = waveform.cuda()
with torch.inference_mode():
    codes = model.encode(waveform.unsqueeze(0))
    audio_hat = model.decode(codes)

In [66]:
codes

[tensor([[3149, 1843, 1843, 1843, 1843, 2399, 1843, 1843,  752, 2793,  335, 3647,
          2372, 2896, 2063, 3725, 1823, 2634, 3855, 3040,  910, 2517,  174, 2634,
          3656, 2083, 2262,  910,  541, 3250, 3250, 1823, 1823, 1856, 2262, 1823,
          2005, 1718, 3879, 2262, 2372, 1669,  335,  335,  335, 2005, 3391, 2086,
          1422, 3276,  335, 2634,  528,  147, 2793,  174, 1183, 3593, 2554, 1686,
          1843, 1843, 1843, 2877, 1069, 2768, 1843, 2614, 1937, 2634, 3508,  335,
           174, 3845,  335, 2262, 2262, 1793,  335, 3891, 2988, 1823, 1718, 2579,
          3745, 1823,  551, 2197, 2383, 1728, 2262, 2554, 1843, 1843, 3149,  752,
          1234, 3838, 1234, 1843, 1793, 2086,  335,  402,  335, 1388,  335, 1793,
          2086, 3238, 2634, 2634, 1758, 1793, 3813,  335, 1793, 2692,  335, 2896,
          2692, 2349, 1823, 1823, 2262, 3960,  697, 3593, 3446, 3838, 1234,  752,
           564,  533, 2086,  335, 3766, 1006, 1728,  335, 1793, 3326, 3250, 3326,
          2692, 

In [50]:
from IPython.display import Audio
Audio(audio_hat.cpu().squeeze().numpy(), rate=24000)

In [51]:
def convert_waveform_to_tokens(waveform, verbose=False):
    # waveform is 24k and the output of torchaudio.functional.resample(waveform, sr, 24000)
    waveform = waveform.cuda()
    with torch.inference_mode():
        codes = model.encode(waveform.unsqueeze(0))
        audio_hat = model.decode(codes)

    codes = [codes[i].cpu().numpy().tolist()[0] for i in range(len(codes))]

    all_frames_tokens_list = []
    for i in range(len(codes[0])):
        # Note the order
        frame_tokens = [codes[0][i],]
        frame_tokens.append(codes[1][2*i])
        frame_tokens.extend(codes[2][4*i:4*i+2])
        frame_tokens.append(codes[1][2*i+1])
        frame_tokens.extend(codes[2][4*i+2:4*i+4])

        all_frames_tokens_list.append(frame_tokens)

    if verbose:
        print("First 3 frames of tokens, without offset: ", all_frames_tokens_list[:3])

    offsetted_tokens_list = [ [128256+l[i]+4096*i for i in range(len(l)) ] for l in all_frames_tokens_list]
    offsetted_tokens_list = np.array(offsetted_tokens_list).flatten().tolist()

    return offsetted_tokens_list, audio_hat

def get_hf_formatted_data(speech_tokens, transcript_tokens, tts_token_id, eos_token_id):
    input_ids = transcript_tokens['input_ids'] + [tts_token_id] + speech_tokens + [eos_token_id]
    labels = [-100] * (len(transcript_tokens['input_ids']) + 1) + speech_tokens + [eos_token_id]
    attention_mask = [1] * len(input_ids)
    return input_ids, labels, attention_mask
    

In [52]:
from datasets import load_dataset
import torch
import torchaudio

# Load first 10 examples from the dataset
dataset = load_dataset("parler-tts/mls_eng_10k", split="train")

# Get first example
first_example = dataset[2]

# Print transcript
transcript = first_example['transcript']
print("Transcript:", transcript)

audio_data =first_example['audio']
waveform, sr = audio_data['array'], audio_data['sampling_rate']
print(f"Sample rate: {sr} Hz, waveform shape: {waveform.shape}, length: {waveform.shape[0] / sr}s")

Audio(waveform, rate=sr)







Transcript: old mr toad filled out his queer music bag under his chin and began to sing again peter watched him now it just happened that old mr toad was facing him
Sample rate: 16000 Hz, waveform shape: (193600,), length: 12.1s


In [53]:
print(dataset)
for i in range(5):
    print(dataset[i])

Dataset({
    features: ['audio', 'original_path', 'begin_time', 'end_time', 'transcript', 'audio_duration', 'speaker_id', 'book_id'],
    num_rows: 2420047
})
{'audio': {'path': '204_2274_001678.opus', 'array': array([-0.00820003, -0.00787047, -0.00798318, ..., -0.00168399,
       -0.0013583 , -0.00177855]), 'sampling_rate': 16000}, 'original_path': 'http://www.archive.org/download/outline_science_mfs_0807_librivox/outlinesciencev1_03_thomson_64kb.mp3', 'begin_time': 2058.29, 'end_time': 2074.22, 'transcript': 'here the physiological difference does not affect the body as a whole but the reproductive organs or gonads only though more intimate physiology would doubtless discover differences in the blood or in the chemical routine metabolism', 'audio_duration': 15.929999999999836, 'speaker_id': '204', 'book_id': '2274'}
{'audio': {'path': '2297_2376_000085.opus', 'array': array([-3.85698280e-04, -2.67470576e-04, -4.62332566e-04, ...,
       -5.37955202e-04,  8.03123330e-05, -7.27814913e

In [54]:
# Main tokenization loop and dataset creation
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
import numpy as np

model_name = "meta-llama/Llama-3.2-3B-Instruct"
llama_tokenizer = AutoTokenizer.from_pretrained(model_name)

tts_token_id = 128256 + 4096*7 
eos_token_id = 128256 + 4096*7 + 1

# Lists to store processed data
all_input_ids = []
all_labels = []
all_attention_masks = []

for i in tqdm(range(5)):
    row = dataset[i]
    
    transcript = row['transcript']
    transcript_tokens = llama_tokenizer(transcript)
    
    audio_data = row['audio']
    waveform, sr = audio_data['array'], audio_data['sampling_rate']

    # Resample to 24kHz
    waveform = torch.from_numpy(waveform).float()
    waveform = waveform.unsqueeze(0)  # Add channel dimension [1, samples]
    waveform = torchaudio.functional.resample(waveform, sr, 24000)

    if i==2: 
        verbose = True 
    else: 
        verbose = False
    speech_tokens, audio_hat = convert_waveform_to_tokens(waveform, verbose=verbose)

    input_ids, labels, attention_mask = get_hf_formatted_data(speech_tokens, transcript_tokens, tts_token_id, eos_token_id)
    
    all_input_ids.append(input_ids)
    all_labels.append(labels)
    all_attention_masks.append(attention_mask)

# Create dataset dictionary
dataset_dict = {
    'input_ids': all_input_ids,
    'labels': all_labels,
    'attention_mask': all_attention_masks
}

# Convert to HuggingFace Dataset
hf_dataset = Dataset.from_dict(dataset_dict)

# Push to hub (uncomment and modify repo_id as needed)
hf_dataset.push_to_hub("UjjD/tts_mini_data")

    

100%|██████████| 5/5 [00:00<00:00, 15.30it/s]


First 3 frames of tokens, without offset:  [[1843, 2918, 3909, 898, 3822, 399, 1472], [752, 2644, 1152, 1527, 663, 681, 2604], [3149, 198, 2442, 3567, 31, 2788, 1353]]


Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 863.91ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00, 14.00it/s]
No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/datasets/UjjD/tts_mini_data/commit/b2c9f192f6b9052ab7c10effc81cc79df8332f3b', commit_message='Upload dataset', commit_description='', oid='b2c9f192f6b9052ab7c10effc81cc79df8332f3b', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/UjjD/tts_mini_data', endpoint='https://huggingface.co', repo_type='dataset', repo_id='UjjD/tts_mini_data'), pr_revision=None, pr_num=None)

In [55]:
Audio(audio_hat.cpu().squeeze().numpy(), rate=24000)

In [69]:
def decode_speech_tokens(speech_tokens):
    # Convert speech tokens into chunks of 7 elements
    speech_tokens_frame_by_frame =[speech_tokens[i:i+7] for i in range(0, len(speech_tokens), 7)]
    speech_tokens_frame_by_frame_deoffsetted = [[l[i]-128256-4096*i for i in range(len(l))] for l in speech_tokens_frame_by_frame]

    l1 = [l[0] for l in speech_tokens_frame_by_frame_deoffsetted]
    l2 = [[l[1],l[4]] for l in speech_tokens_frame_by_frame_deoffsetted]
    l3 = [[l[2],l[3],l[5],l[6]] for l in speech_tokens_frame_by_frame_deoffsetted]
    # Flatten l2 and l3 lists
    l2 = [item for sublist in l2 for item in sublist]
    l3 = [item for sublist in l3 for item in sublist]

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Convert each list to a tensor and reshape to 1 x N
    tensor_l1 = torch.tensor([l1], device=device)
    tensor_l2 = torch.tensor([l2], device=device) 
    tensor_l3 = torch.tensor([l3], device=device)

    codes = [tensor_l1, tensor_l2, tensor_l3]

    return codes

from datasets import load_dataset

dataset = load_dataset("UjjD/tts_mini_data")
example_data_point = dataset['train'][2]

input_ids, labels, attention_mask = example_data_point['input_ids'], example_data_point['labels'], example_data_point['attention_mask']



tts_token_position = input_ids.index(tts_token_id)
transcript_tokens = input_ids[:tts_token_position]
speech_tokens = input_ids[tts_token_position+1:-1]

# Decode and print
print("Transcript tokens:", transcript_tokens)
print("\nDecoded transcript:", llama_tokenizer.decode(transcript_tokens))

output = decode_speech_tokens(speech_tokens)
# Convert lists to tensors and move to CUDA if available

# Combine into list of tensors
codes = 

audio_hat = model.decode(codes)

from IPython.display import Audio
Audio(audio_hat.detach().cpu().squeeze().numpy(), rate=24000)

Transcript tokens: [128000, 820, 17767, 311, 329, 10409, 704, 813, 55641, 4731, 9145, 1234, 813, 46175, 323, 6137, 311, 7936, 1578, 95087, 15746, 1461, 1457, 433, 1120, 7077, 430, 2362, 17767, 311, 329, 574, 13176, 1461]

Decoded transcript: <|begin_of_text|>old mr toad filled out his queer music bag under his chin and began to sing again peter watched him now it just happened that old mr toad was facing him
