In [1]:
from datasets import load_from_disk, DatasetDict, concatenate_datasets

# Load each split from the encoded_libritts directory.
train_clean_100 = load_from_disk("encoded_libritts/train.clean.100/")
train_clean_360 = load_from_disk("encoded_libritts/train.clean.360/")
dev_clean = load_from_disk("encoded_libritts/dev.clean")
test_clean = load_from_disk("encoded_libritts/test.clean")

# The two clean training subsets are merged into a single train split.
full_train = concatenate_datasets([train_clean_100, train_clean_360])

# Assemble the splits, return torch tensors on access, drop the columns that
# are not used further down, and rename the text column to "normalized_text".
dataset = (
    DatasetDict({"train": full_train, "val": dev_clean, "test": test_clean})
    .with_format("torch")
    .remove_columns(["path", "chapter_id", "text_original"])
    .rename_column(original_column_name="text_normalized", new_column_name="normalized_text")
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoTokenizer

# Load the tokenizer
def make_tokenizer(
    base_model: str = "HuggingFaceTB/SmolLM2-135M-Instruct",
    num_semantic_tokens: int = 2048,
    save_dir: str = "../checkpoints/smoltts",
):
    """Extend a pretrained tokenizer with semantic tokens and save it.

    Args:
        base_model: Name or path passed to ``AutoTokenizer.from_pretrained``.
        num_semantic_tokens: How many ``<|semantic:i|>`` tokens to register,
            for ``i`` in ``range(num_semantic_tokens)``.
        save_dir: Directory the extended tokenizer is written to.

    Returns:
        The extended tokenizer, so callers can use it directly instead of
        reloading it from ``save_dir``.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    semantic_tokens = [f"<|semantic:{i}|>" for i in range(num_semantic_tokens)]
    tokenizer.add_special_tokens({
        "additional_special_tokens": semantic_tokens
    })
    tokenizer.save_pretrained(save_dir)
    return tokenizer

# make_tokenizer() writes the extended tokenizer to ../checkpoints/smoltts.
# The call is left commented out (presumably it was run once beforehand);
# the saved tokenizer is loaded from that directory instead.
# make_tokenizer()
tokenizer = AutoTokenizer.from_pretrained("../checkpoints/smoltts")
tokenizer.use_default_system_prompt = False

In [3]:
# Compare the tokenizer's total length with its `vocab_size` attribute.
# In the recorded output the two differ by 2048, which is the number of
# <|semantic:i|> tokens that make_tokenizer registers.
len(tokenizer), tokenizer.vocab_size

(51200, 49152)

In [4]:
# Round-trip an example sentence through the tokenizer (encode, then decode).
example_text = "This is a test sentence."
encoded = tokenizer(example_text, return_tensors="pt")
decoded = tokenizer.decode(encoded['input_ids'][0])

# NOTE: `encoded` and `decoded` are computed but never displayed; the value
# this cell actually shows is the first example of the test split.
dataset["test"][0]

{'normalized_text': 'I felt it in my bones when I woke this morning that something splendid was going to turn up.',
 'speaker_id': '4446',
 'id': '4446_2275_000002_000009',
 'codes': tensor([[1049, 1268,  549, 1324,  668, 1538, 1593,   95,  629, 1281, 1281,  680,
           536,  536,  230, 1018, 1117,  244,  507,  997, 1399,  640, 1591, 1967,
          1161,  690,   67, 1772,  830, 1612,  561,  119, 1052,  880, 1029, 1532,
          1161, 1344, 1109,    6, 1001,  382,  596,   99, 1726, 2030,  531,  616,
           367, 1271, 1868,  978,  729,  396, 1544],
         [1470, 1879,  712,  283,  220,  137, 1610,  263,  531, 1845, 1428, 1132,
           359, 1904, 1458, 1876,  895,  149,  190,  116,  603,  786, 1884, 1455,
          1928,  677,  914, 1122,  436,  618,  850, 1766, 2005, 1618,  966,  850,
          1663,  172,  274,  612, 1013, 1928, 1262, 1169, 1006, 1777, 1755, 2026,
          1714,  788,  786, 1520,  811,   91, 1700],
         [ 373,  602, 2016, 1148,   98,  790, 1570,  944

In [20]:
import torch

def encode_text(role: str, content: str, needs_initial_newline=False, num_codebooks: int = 8) -> torch.Tensor:
    """Encode one chat turn and pad it with all-zero codebook rows.

    The turn is rendered as ``<|im_start|>{role}\\n{content}<|im_end|>\\n`` and
    tokenized with the module-level ``tokenizer``.

    Args:
        role: Role name placed after ``<|im_start|>``.
        content: Text of the turn.
        needs_initial_newline: When true, a newline is prepended before
            ``<|im_start|>``.
        num_codebooks: Number of zero rows stacked under the token row.
            Defaults to 8, the value that was previously hard-coded.

    Returns:
        Tensor of shape ``[1 + num_codebooks, T]``: row 0 holds the token ids,
        every other row is zero and shares row 0's dtype.
    """
    leading = "\n" if needs_initial_newline else ""
    token_row = tokenizer.encode(
        f"{leading}<|im_start|>{role}\n{content}<|im_end|>\n", return_tensors="pt"
    )
    # Text positions carry no codebook values, so those rows are filled with zeros.
    zero_rows = torch.zeros(num_codebooks, token_row.size(1), dtype=token_row.dtype)
    return torch.cat([token_row, zero_rows], dim=0)

# System turns for the two tasks, encoded once here and reused by
# tokenize_row as the first segment of every TTS / ASR sequence.
tts_sysprompt = encode_text("system", "Speak out the provided text")
asr_sysprompt = encode_text("system", "Transcribe the provided speech")

In [53]:
# Token id of "<|semantic:0|>"; adding it to a code value gives the id of the
# matching <|semantic:N|> token.
SEMANTIC_OFFSET = tokenizer.encode("<|semantic:0|>")[0]
# Empty chat turns with the trailing newline column removed. encode_vq splices
# the codes between the turn header and the final <|im_end|> column.
# Both wrappers use the same `[:, :-1]` slice: the earlier `-7:-1` on the
# assistant wrapper only kept the leading <|im_start|> because the slice start
# is clamped on a short sequence.
VQ_ASSISTANT_WRAPPER = encode_text(role="assistant", content="")[:, :-1]
VQ_USER_WRAPPER = encode_text(role="user", content="")[:, :-1]

def encode_vq(codes: torch.Tensor, is_assistant=True) -> torch.Tensor:
    """Wrap a block of codes in an assistant or user chat turn.

    Args:
        codes: 2-D tensor of code values; its first row is also mapped to
            token ids by adding ``SEMANTIC_OFFSET``.
        is_assistant: Selects ``VQ_ASSISTANT_WRAPPER`` when true, otherwise
            ``VQ_USER_WRAPPER``.

    Returns:
        The wrapper's leading columns, then the offset first row stacked on
        top of ``codes``, then the wrapper's last column, joined along dim 1.
    """
    if is_assistant:
        frame = VQ_ASSISTANT_WRAPPER
    else:
        frame = VQ_USER_WRAPPER

    # Row 0 of the result carries the first code row shifted into token-id space.
    token_row = (codes[0, :] + SEMANTIC_OFFSET).unsqueeze(0)
    body = torch.cat([token_row, codes])

    opening = frame[:, :-1]
    closing = frame[:, -1].unsqueeze(1)
    return torch.cat([opening, body, closing], dim=1)


# Sanity check: wrap the first test example's codes as an assistant turn and
# decode row 0 back to text to inspect the resulting <|semantic:N|> sequence.
out = encode_vq(dataset["test"][0]["codes"], is_assistant=True)
tokenizer.decode(out[0,:])

'<|im_start|>assistant\n<|semantic:1049|><|semantic:1268|><|semantic:549|><|semantic:1324|><|semantic:668|><|semantic:1538|><|semantic:1593|><|semantic:95|><|semantic:629|><|semantic:1281|><|semantic:1281|><|semantic:680|><|semantic:536|><|semantic:536|><|semantic:230|><|semantic:1018|><|semantic:1117|><|semantic:244|><|semantic:507|><|semantic:997|><|semantic:1399|><|semantic:640|><|semantic:1591|><|semantic:1967|><|semantic:1161|><|semantic:690|><|semantic:67|><|semantic:1772|><|semantic:830|><|semantic:1612|><|semantic:561|><|semantic:119|><|semantic:1052|><|semantic:880|><|semantic:1029|><|semantic:1532|><|semantic:1161|><|semantic:1344|><|semantic:1109|><|semantic:6|><|semantic:1001|><|semantic:382|><|semantic:596|><|semantic:99|><|semantic:1726|><|semantic:2030|><|semantic:531|><|semantic:616|><|semantic:367|><|semantic:1271|><|semantic:1868|><|semantic:978|><|semantic:729|><|semantic:396|><|semantic:1544|><|im_end|>'

In [66]:
from typing import Dict

# Number of tokens in each chat-turn header. tokenize_row adds these to the
# system-prompt length to find where the code region starts when masking labels.
ASSISTANT_PREFIX_LEN = len(tokenizer.tokenize("<|im_start|>assistant\n"))
USER_PREFIX_LEN  = len(tokenizer.tokenize("<|im_start|>user\n"))

def _build_tts_asr_pair(text: str, codes: torch.Tensor):
    """Build shifted (tokens, labels) pairs for the TTS and ASR views of one example."""
    tts_user_line = encode_text(role="user", content=text)
    asr_assistant_line = encode_text(role="assistant", content=text, needs_initial_newline=True)
    tts_assistant_codes = encode_vq(codes)
    asr_user_codes = encode_vq(codes, is_assistant=False)

    # Every segment has the same number of rows, so they are joined along time.
    tts_ground_truth = torch.cat([tts_sysprompt, tts_user_line, tts_assistant_codes], dim=1)
    asr_ground_truth = torch.cat([asr_sysprompt, asr_user_codes, asr_assistant_line], dim=1)

    # Next-token setup: inputs drop the last column, labels drop the first.
    tts_tokens = tts_ground_truth[:, :-1].clone()
    asr_tokens = asr_ground_truth[:, :-1].clone()
    tts_labels = tts_ground_truth[:, 1:].clone()
    asr_labels = asr_ground_truth[:, 1:].clone()

    # TTS: rows 1.. are ignored (-100) over the system turn, the user turn and
    # the assistant header. The -1 accounts for the one-column label shift.
    text_len = tts_sysprompt.size(1) + tts_user_line.size(1) + ASSISTANT_PREFIX_LEN - 1
    tts_labels[1:, :text_len] = -100

    # ASR: rows 1.. are ignored over the system turn and user header, and again
    # over the trailing assistant text turn. Row 0 is left unmasked in both tasks.
    asr_start_len = asr_sysprompt.size(1) + USER_PREFIX_LEN - 1
    asr_labels[1:, :asr_start_len] = -100
    asr_labels[1:, -asr_assistant_line.size(1):] = -100

    return (tts_tokens, tts_labels), (asr_tokens, asr_labels)


def tokenize_row(row: Dict, is_batch=True):
    """Turn dataset examples into paired TTS and ASR training samples.

    Args:
        row: With ``is_batch=True``, a batch dict whose values are per-example
            sequences (as passed by ``dataset.map(..., batched=True)``). With
            ``is_batch=False``, a single example where
            ``row["normalized_text"]`` is a string and ``row["codes"]`` is a
            tensor shaped ``[8, T_vq]`` (``encode_vq`` adds the ninth row).
        is_batch: Whether ``row`` is a batch dict or a single example.

    Returns:
        A dict of equal-length lists with two entries per input example, in
        the order TTS then ASR: ``tokens``, ``labels``, ``task``,
        ``normalized_text``, ``speaker_id`` and ``id``.
    """
    if is_batch:
        texts, code_blocks = row["normalized_text"], row["codes"]
        speaker_ids, example_ids = row["speaker_id"], row["id"]
    else:
        # Wrap the single example so both paths share one loop. This also keeps
        # string ids from being repeated character-wise by `* 2`.
        texts, code_blocks = [row["normalized_text"]], [row["codes"]]
        speaker_ids, example_ids = [row["speaker_id"]], [row["id"]]

    out = {
        "tokens": [],
        "labels": [],
        "task": [],
        "normalized_text": [],
        "speaker_id": [],
        "id": [],
    }
    for text, codes, speaker_id, example_id in zip(texts, code_blocks, speaker_ids, example_ids):
        (tts_tokens, tts_labels), (asr_tokens, asr_labels) = _build_tts_asr_pair(text, codes)
        out["tokens"] += [tts_tokens, asr_tokens]
        out["labels"] += [tts_labels, asr_labels]
        out["task"] += ["tts", "asr"]
        out["normalized_text"] += [text] * 2
        out["speaker_id"] += [speaker_id] * 2
        out["id"] += [example_id] * 2
    return out

# Run tokenize_row on a single (non-batched) test example and decode row 0 of
# the first returned sample (the "tts" entry) to check the assembled prompt.
example_row = tokenize_row(dataset["test"][0], is_batch=False)
tokenizer.decode(example_row["tokens"][0][0,:])

'<|im_start|>system\nSpeak out the provided text<|im_end|>\n<|im_start|>user\nI felt it in my bones when I woke this morning that something splendid was going to turn up.<|im_end|>\n<|im_start|>assistant\n<|semantic:1049|><|semantic:1268|><|semantic:549|><|semantic:1324|><|semantic:668|><|semantic:1538|><|semantic:1593|><|semantic:95|><|semantic:629|><|semantic:1281|><|semantic:1281|><|semantic:680|><|semantic:536|><|semantic:536|><|semantic:230|><|semantic:1018|><|semantic:1117|><|semantic:244|><|semantic:507|><|semantic:997|><|semantic:1399|><|semantic:640|><|semantic:1591|><|semantic:1967|><|semantic:1161|><|semantic:690|><|semantic:67|><|semantic:1772|><|semantic:830|><|semantic:1612|><|semantic:561|><|semantic:119|><|semantic:1052|><|semantic:880|><|semantic:1029|><|semantic:1532|><|semantic:1161|><|semantic:1344|><|semantic:1109|><|semantic:6|><|semantic:1001|><|semantic:382|><|semantic:596|><|semantic:99|><|semantic:1726|><|semantic:2030|><|semantic:531|><|semantic:616|><|semant

In [59]:
# Inspect the first example of the train split.
dataset["train"][0]

{'normalized_text': '[The moon] I gazed with a kind of wonder.',
 'speaker_id': '730',
 'id': '730_358_000003_000002',
 'codes': tensor([[1049, 1102, 1686, 1258, 1258, 1689, 1528, 1987,  978,  312, 2039,  753,
           969,  598, 1084, 1268,  621, 1757,  560, 1734, 1527, 1117,  622,  628,
           510,  623,  623,  918,  689,  997, 1069, 1941,  294,  774,  518, 1987,
           769],
         [ 363, 1427, 1427, 1427, 1525, 1452, 1332,  363,   91,  441,  441,  441,
           363,  363,  212, 1834, 2005,  144, 1507, 1772, 1501,  786, 1891,  543,
          1646, 1181, 1097,  941, 1028,    8, 1097, 1772, 1252, 1731, 1182, 1520,
           243],
         [ 373, 1031,  314, 1916,   17,  681, 1370, 2016, 2016, 1400,  457, 2016,
          1569, 1569, 1796, 1350, 1370,   88,   88,   88,  311,  329, 1738, 1822,
          1017, 2025, 1453,  324,  568, 1787, 1517,  999,  461,   61,  231, 2016,
           265],
         [1477,  176, 1056, 2024, 1269,  349,  104, 1521,  417, 1422,  125, 1229,
 

In [67]:
# DO NOT INCREASE batch size
# With batched=True and batch_size=1, each tokenize_row call receives a
# single example wrapped in length-1 lists and returns two rows for it
# (task "tts" and task "asr"). The original "codes" column is dropped.
dataset = dataset.map(tokenize_row, remove_columns="codes", batched=True, batch_size=1)

Map:   0%|          | 0/149658 [00:00<?, ? examples/s]

Map: 100%|██████████| 149658/149658 [03:06<00:00, 803.39 examples/s]
Map: 100%|██████████| 5736/5736 [00:07<00:00, 756.32 examples/s]
Map: 100%|██████████| 4837/4837 [00:06<00:00, 734.19 examples/s]


In [29]:
# Write every split of the tokenized DatasetDict to "tokenized_libritts_bijection".
dataset.save_to_disk("tokenized_libritts_bijection")

Saving the dataset (6/6 shards): 100%|██████████| 149658/149658 [00:01<00:00, 97881.60 examples/s] 
Saving the dataset (1/1 shards): 100%|██████████| 5736/5736 [00:00<00:00, 101659.04 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4837/4837 [00:00<00:00, 89127.21 examples/s]


In [22]:
# Decode row 0 of the ASR sample's input tokens. tokenize_row returns a dict
# of lists, so pick the column first and then the sample (index 1 is "asr").
tokenizer.decode(example_row["tokens"][1][0,:])

'<|im_start|>system\nTranscribe the provided speech<|im_end|>\n<|im_start|>user\n<|semantic:1049|><|semantic:1268|><|semantic:549|><|semantic:1324|><|semantic:668|><|semantic:1538|><|semantic:1593|><|semantic:95|><|semantic:629|><|semantic:1281|><|semantic:1281|><|semantic:680|><|semantic:536|><|semantic:536|><|semantic:230|><|semantic:1018|><|semantic:1117|><|semantic:244|><|semantic:507|><|semantic:997|><|semantic:1399|><|semantic:640|><|semantic:1591|><|semantic:1967|><|semantic:1161|><|semantic:690|><|semantic:67|><|semantic:1772|><|semantic:830|><|semantic:1612|><|semantic:561|><|semantic:119|><|semantic:1052|><|semantic:880|><|semantic:1029|><|semantic:1532|><|semantic:1161|><|semantic:1344|><|semantic:1109|><|semantic:6|><|semantic:1001|><|semantic:382|><|semantic:596|><|semantic:99|><|semantic:1726|><|semantic:2030|><|semantic:531|><|semantic:616|><|semantic:367|><|semantic:1271|><|semantic:1868|><|semantic:978|><|semantic:729|><|semantic:396|><|semantic:1544|><|im_end|>\n<|im_

In [24]:
# Decode row 0 of the TTS sample's labels. tokenize_row returns a dict of
# lists, so pick the column first and then the sample (index 0 is "tts").
tokenizer.decode(example_row["labels"][0][0,:])

'system\nSpeak out the provided text<|im_end|>\n<|im_start|>user\nI felt it in my bones when I woke this morning that something splendid was going to turn up.<|im_end|>\n<|im_start|>assistant\n<|semantic:1049|><|semantic:1268|><|semantic:549|><|semantic:1324|><|semantic:668|><|semantic:1538|><|semantic:1593|><|semantic:95|><|semantic:629|><|semantic:1281|><|semantic:1281|><|semantic:680|><|semantic:536|><|semantic:536|><|semantic:230|><|semantic:1018|><|semantic:1117|><|semantic:244|><|semantic:507|><|semantic:997|><|semantic:1399|><|semantic:640|><|semantic:1591|><|semantic:1967|><|semantic:1161|><|semantic:690|><|semantic:67|><|semantic:1772|><|semantic:830|><|semantic:1612|><|semantic:561|><|semantic:119|><|semantic:1052|><|semantic:880|><|semantic:1029|><|semantic:1532|><|semantic:1161|><|semantic:1344|><|semantic:1109|><|semantic:6|><|semantic:1001|><|semantic:382|><|semantic:596|><|semantic:99|><|semantic:1726|><|semantic:2030|><|semantic:531|><|semantic:616|><|semantic:367|><|se

In [25]:
# Decode row 0 of the ASR sample's labels. tokenize_row returns a dict of
# lists, so pick the column first and then the sample (index 1 is "asr").
tokenizer.decode(example_row["labels"][1][0,:])

'system\nTranscribe the provided speech<|im_end|>\n<|im_start|>user\n<|semantic:1049|><|semantic:1268|><|semantic:549|><|semantic:1324|><|semantic:668|><|semantic:1538|><|semantic:1593|><|semantic:95|><|semantic:629|><|semantic:1281|><|semantic:1281|><|semantic:680|><|semantic:536|><|semantic:536|><|semantic:230|><|semantic:1018|><|semantic:1117|><|semantic:244|><|semantic:507|><|semantic:997|><|semantic:1399|><|semantic:640|><|semantic:1591|><|semantic:1967|><|semantic:1161|><|semantic:690|><|semantic:67|><|semantic:1772|><|semantic:830|><|semantic:1612|><|semantic:561|><|semantic:119|><|semantic:1052|><|semantic:880|><|semantic:1029|><|semantic:1532|><|semantic:1161|><|semantic:1344|><|semantic:1109|><|semantic:6|><|semantic:1001|><|semantic:382|><|semantic:596|><|semantic:99|><|semantic:1726|><|semantic:2030|><|semantic:531|><|semantic:616|><|semantic:367|><|semantic:1271|><|semantic:1868|><|semantic:978|><|semantic:729|><|semantic:396|><|semantic:1544|><|im_end|>\n<|im_start|>assis

In [21]:
# Row 1 of the TTS sample's input tokens (sample index 0 in example_row["tokens"]).
# NOTE(review): the recorded output below does not match the code values shown
# for dataset["test"][0] elsewhere in this notebook; it appears to come from an
# earlier run. Re-execute to refresh it.
example_row["tokens"][0][1,:]

tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0, 1698, 1719,  204, 1389,  851, 1772,  186, 1307, 1895,  832,
        1633,  771,  648, 1530, 1989, 1574, 1348,  722,  144, 1945,  278, 1109,
          29,  611,   46,  622,  628, 1740,  572,  572,  345, 1989, 1676,  929,
        1776,  749,  313, 1997, 1571,  819, 1238, 1054, 1054, 1135, 1506, 1393,
         616, 1702,  993,  579,  486,  486, 2039,  148,  657,  664,  339,  339,
         588,  212, 1443,   32, 1320, 1549,  440,    8, 1407, 1722, 1650, 1615,
         798,  121,  303,  697,  837,  358, 1882,  440, 1992, 1992,  587,  178,
         178, 1627, 1530,  929, 1610, 1916,  523,  213, 1252, 1480, 1468, 1899,
         773, 2033, 2033,   83, 1146,  7

In [28]:
# Row 1 of the ASR sample's labels, to check which positions are masked
# with -100. tokenize_row returns a dict of lists, so pick the column first
# and then the sample (index 1 is "asr").
example_row["labels"][1][1,:]

tensor([-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, 1049, 1268,  549, 1324,  668, 1538, 1593,   95,  629, 1281, 1281,
         680,  536,  536,  230, 1018, 1117,  244,  507,  997, 1399,  640, 1591,
        1967, 1161,  690,   67, 1772,  830, 1612,  561,  119, 1052,  880, 1029,
        1532, 1161, 1344, 1109,    6, 1001,  382,  596,   99, 1726, 2030,  531,
         616,  367, 1271, 1868,  978,  729,  396, 1544,    0, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100])

In [28]:
# Inspect the first test example after the map step: it now carries a
# "tokens" column, and "codes" has been removed.
dataset["test"][0]

{'normalized_text': 'I felt it in my bones when I woke this morning that something splendid was going to turn up.',
 'speaker_id': '4446',
 'id': '4446_2275_000002_000009',
 'tokens': tensor([[    1,  9690,   198, 15024,   494,   578,   260,  2711,  1694,     2,
            198,     1,  4093,   198,    57,  4592,   357,   281,   957,  6542,
            645,   339, 40652,   451,  5738,   338,  1488, 33494,   436,  2045,
            288,  1607,   614,    30,     2,   198,     1,   520,  9531,   198,
          50201, 50420, 49701, 50476, 49820, 50690, 50745, 49247, 49781, 50433,
          50433, 49832, 49688, 49688, 49382, 50170, 50269, 49396, 49659, 50149,
          50551, 49792, 50743, 51119, 50313, 49842, 49219, 50924, 49982, 50764,
          49713, 49271, 50204, 50032, 50181, 50684, 50313, 50496, 50261, 49158,
          50153, 49534, 49748, 49251, 50878, 51182, 49683, 49768, 49519, 50423,
          51020, 50130, 49881, 49548, 50696],
         [    0,     0,     0,     0,     0,     0,