In [1]:
from datasets import load_from_disk

# Pre-encoded corpus on local disk; each row carries the transcript and its audio codes.
ENCODED_DATASET_DIR = "encoded_dataset"
dataset = load_from_disk(ENCODED_DATASET_DIR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from pathlib import Path

from transformers import AutoTokenizer

BASE_TOKENIZER = "HuggingFaceTB/SmolLM2-135M-Instruct"
TOKENIZER_DIR = "../checkpoints/smoltts"
N_SEMANTIC_TOKENS = 2048


def make_tokenizer(
    base_model: str = BASE_TOKENIZER,
    n_semantic: int = N_SEMANTIC_TOKENS,
    out_dir: str = TOKENIZER_DIR,
):
    """Build and save the TTS tokenizer.

    Loads the tokenizer of ``base_model`` and adds one special token per semantic
    code, ``<|semantic:0|>`` ... ``<|semantic:{n_semantic - 1}|>``.

    Args:
        base_model: Hugging Face id or path of the base tokenizer.
        n_semantic: number of semantic tokens to add.
        out_dir: directory the extended tokenizer is saved to.

    Returns:
        The extended tokenizer (also written to ``out_dir``).
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    semantic_tokens = [f"<|semantic:{i}|>" for i in range(n_semantic)]
    # Append to, rather than overwrite, whatever additional special tokens the
    # base tokenizer already declares.
    tokenizer.add_special_tokens(
        {"additional_special_tokens": semantic_tokens},
        replace_additional_special_tokens=False,
    )
    tokenizer.save_pretrained(out_dir)
    return tokenizer


# Build the tokenizer only on the first run; afterwards reuse the saved copy so
# the notebook still works from a fresh kernel / fresh checkout.
if not Path(TOKENIZER_DIR).exists():
    make_tokenizer()
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
tokenizer.use_default_system_prompt = False

In [3]:
# Size including added tokens vs. the base vocabulary size
n_total, n_base = len(tokenizer), tokenizer.vocab_size
n_total, n_base

(51200, 49152)

In [3]:
# Round-trip a sample sentence through the tokenizer as a quick sanity check.
example_text = "This is a test sentence."
encoded = tokenizer(example_text, return_tensors="pt")
decoded = tokenizer.decode(encoded["input_ids"][0])
display(decoded)  # previously computed but never shown

# Then look at one raw row of the encoded dataset.
dataset["full"][0]

{'file': 'lj-001-0001.wav',
 'spoken_text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'normalized_text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'codes': tensor([[1698, 1719,  204, 1389,  851, 1772,  186, 1307, 1895,  832, 1633,  771,
           648, 1530, 1989, 1574, 1348,  722,  144, 1945,  278, 1109,   29,  611,
            46,  622,  628, 1740,  572,  572,  345, 1989, 1676,  929, 1776,  749,
           313, 1997, 1571,  819, 1238, 1054, 1054, 1135, 1506, 1393,  616, 1702,
           993,  579,  486,  486, 2039,  148,  657,  664,  339,  339,  588,  212,
          1443,   32, 1320, 1549,  440,    8, 1407, 1722, 1650, 1615,  798,  121,
           303,  697,  837,  358, 1882,  440, 1992, 1992,  587,  178,  178, 1627,
          1530,  929, 1610, 1916,  523,  21

In [4]:
import torch


def encode_text(role: str, content: str, num_codebooks: int = 8) -> torch.Tensor:
    """Encode one ChatML turn as a multi-row token grid.

    Row 0 holds the token ids of
    ``"<|im_start|>{role}" + newline + content + "<|im_end|>" + newline``.
    Below it are ``num_codebooks`` all-zero rows, so the turn can be
    concatenated along time with the audio-code grids built by ``encode_vq``.

    Args:
        role: ChatML role, e.g. "system", "user" or "assistant".
        content: message text.
        num_codebooks: number of zero rows to add (defaults to 8, the number of
            code rows used in this notebook).

    Returns:
        Tensor of shape ``[1 + num_codebooks, T]``, same dtype as the token ids.
    """
    text_ids = tokenizer.encode(
        f"<|im_start|>{role}\n{content}<|im_end|>\n", return_tensors="pt"
    )  # [1, T]
    codebook_pad = torch.zeros(num_codebooks, text_ids.size(1), dtype=text_ids.dtype)
    return torch.cat([text_ids, codebook_pad])


sysprompt = encode_text("system", "Speak out the provided text")

In [5]:
# Token id of <|semantic:0|>; semantic code c is emitted as token id SEMANTIC_OFFSET + c.
# convert_tokens_to_ids looks the token up directly, so a BOS token can never
# be picked up by mistake.
SEMANTIC_OFFSET = tokenizer.convert_tokens_to_ids("<|semantic:0|>")

# Empty assistant turn with its trailing "\n" dropped: every column except the
# last is the "<|im_start|>assistant\n" header, the last column is <|im_end|>.
# Codebook rows are all zero (they come from encode_text).
VQ_WRAPPER = encode_text(role="assistant", content="")[:, :-1]


def encode_vq(codes: torch.Tensor) -> torch.Tensor:
    """Wrap a grid of audio codes in an assistant turn.

    Args:
        codes: integer codes shaped ``[n_codebooks, T]``. ``n_codebooks`` must
            equal ``VQ_WRAPPER.size(0) - 1`` (8 with the default ``encode_text``).

    Returns:
        Tensor of shape ``[1 + n_codebooks, n_header + T + 1]``:
        row 0 is the text stream (assistant header, the first codebook shifted
        into the ``<|semantic:i|>`` token range, then ``<|im_end|>``); the
        remaining rows are the raw codes under the semantic tokens and zeros
        under the wrapper tokens.
    """
    # First codebook doubles as the text-row semantic tokens.
    semantic_row = (codes[0, :] + SEMANTIC_OFFSET).unsqueeze(0)
    vq_block = torch.cat([semantic_row, codes])
    header, im_end = VQ_WRAPPER[:, :-1], VQ_WRAPPER[:, -1:]
    return torch.cat([header, vq_block, im_end], dim=1)


out = encode_vq(dataset["full"][0]["codes"])
tokenizer.decode(out[0,:])

'<|im_start|>assistant\n<|semantic:1698|><|semantic:1719|><|semantic:204|><|semantic:1389|><|semantic:851|><|semantic:1772|><|semantic:186|><|semantic:1307|><|semantic:1895|><|semantic:832|><|semantic:1633|><|semantic:771|><|semantic:648|><|semantic:1530|><|semantic:1989|><|semantic:1574|><|semantic:1348|><|semantic:722|><|semantic:144|><|semantic:1945|><|semantic:278|><|semantic:1109|><|semantic:29|><|semantic:611|><|semantic:46|><|semantic:622|><|semantic:628|><|semantic:1740|><|semantic:572|><|semantic:572|><|semantic:345|><|semantic:1989|><|semantic:1676|><|semantic:929|><|semantic:1776|><|semantic:749|><|semantic:313|><|semantic:1997|><|semantic:1571|><|semantic:819|><|semantic:1238|><|semantic:1054|><|semantic:1054|><|semantic:1135|><|semantic:1506|><|semantic:1393|><|semantic:616|><|semantic:1702|><|semantic:993|><|semantic:579|><|semantic:486|><|semantic:486|><|semantic:2039|><|semantic:148|><|semantic:657|><|semantic:664|><|semantic:339|><|semantic:339|><|semantic:588|><|seman

In [7]:
# Number of time steps the assistant wrapper contributes (header + <|im_end|>)
VQ_WRAPPER.shape[1]

5

In [6]:
from typing import Dict


def tokenize_row(row: Dict) -> Dict[str, torch.Tensor]:
    """Turn one dataset row into a training example.

    ``row["normalized_text"]`` is the transcript string. ``row["codes"]`` is an
    integer tensor of audio codes shaped ``[8, T_vq]``: ``encode_vq`` stacks the
    semantic text row on top of it, giving the same 9 rows that ``encode_text``
    produces for the system and user turns.

    The full sequence is ``sysprompt | user turn | assistant VQ turn``
    concatenated along time, shape ``[9, L]``. Row 0 is the text/semantic token
    stream and rows 1: are the codebooks.

    Returns:
        dict with
          - ``"tokens"``: ``[9, L - 1]``, the sequence without its last step;
          - ``"labels"``: ``[9, L]``, a copy of the sequence whose codebook rows
            are -100 everywhere except under the semantic (audio) frames.
    """
    user_line = encode_text(role="user", content=row["normalized_text"])  # [9, T_user]
    vq_turn = encode_vq(row["codes"])  # [9, n_header + T_vq + 1]
    tokens = torch.cat([sysprompt, user_line, vq_turn], dim=1)  # [9, L]

    # The audio frames are the last T_vq columns before the closing <|im_end|>.
    n_frames = row["codes"].shape[-1]
    audio_start = tokens.size(1) - n_frames - 1

    labels = tokens.clone()
    # Outside the audio frames (system/user turns, assistant header, final
    # <|im_end|>) the codebook rows are only zero padding, so exclude them from
    # the loss. Row 0 stays unmasked, so the text stream keeps its labels.
    labels[1:, :audio_start] = -100
    labels[1:, -1] = -100

    # NOTE(review): tokens drops the last step but labels keeps the full length L;
    # presumably the training collate shifts labels by one (labels[:, 1:]) — confirm.
    return {
        "tokens": tokens[:, :-1],
        "labels": labels,
    }


tokenizer.decode(tokenize_row(dataset["full"][0])["tokens"][0,:])

'<|im_start|>system\nSpeak out the provided text<|im_end|>\n<|im_start|>user\nPrinting, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition<|im_end|>\n<|im_start|>assistant\n<|semantic:1698|><|semantic:1719|><|semantic:204|><|semantic:1389|><|semantic:851|><|semantic:1772|><|semantic:186|><|semantic:1307|><|semantic:1895|><|semantic:832|><|semantic:1633|><|semantic:771|><|semantic:648|><|semantic:1530|><|semantic:1989|><|semantic:1574|><|semantic:1348|><|semantic:722|><|semantic:144|><|semantic:1945|><|semantic:278|><|semantic:1109|><|semantic:29|><|semantic:611|><|semantic:46|><|semantic:622|><|semantic:628|><|semantic:1740|><|semantic:572|><|semantic:572|><|semantic:345|><|semantic:1989|><|semantic:1676|><|semantic:929|><|semantic:1776|><|semantic:749|><|semantic:313|><|semantic:1997|><|semantic:1571|><|semantic:819|><|semantic:1238|><|semantic:1054|><|semantic:1054|><|semantic:1135|><|semantic:

In [7]:
# Apply tokenize_row to every split; its "tokens" and "labels" keys become new columns.
dataset = dataset.map(function=tokenize_row)

Map: 100%|██████████| 13100/13100 [00:07<00:00, 1856.01 examples/s]


In [24]:
# Inspect the first row of the "full" split after tokenization
first_row = dataset["full"][0]
first_row

{'file': 'lj-001-0001.wav',
 'spoken_text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'normalized_text': 'Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition',
 'codes': tensor([[1698, 1719,  204, 1389,  851, 1772,  186, 1307, 1895,  832, 1633,  771,
           648, 1530, 1989, 1574, 1348,  722,  144, 1945,  278, 1109,   29,  611,
            46,  622,  628, 1740,  572,  572,  345, 1989, 1676,  929, 1776,  749,
           313, 1997, 1571,  819, 1238, 1054, 1054, 1135, 1506, 1393,  616, 1702,
           993,  579,  486,  486, 2039,  148,  657,  664,  339,  339,  588,  212,
          1443,   32, 1320, 1549,  440,    8, 1407, 1722, 1650, 1615,  798,  121,
           303,  697,  837,  358, 1882,  440, 1992, 1992,  587,  178,  178, 1627,
          1530,  929, 1610, 1916,  523,  21

In [8]:
# Persist the tokenized splits for the training step
TOKENIZED_DATASET_DIR = "tokenized_dataset"
dataset.save_to_disk(TOKENIZED_DATASET_DIR)

Saving the dataset (1/1 shards): 100%|██████████| 13100/13100 [00:00<00:00, 17090.23 examples/s]
