In [29]:
from datasets import load_from_disk

dataset = load_from_disk("tokenized_dataset")
dataset.shuffle(42)
dataset["full"][0]["labels"].shape, dataset["full"][0]["tokens"].shape

(torch.Size([9, 172]), torch.Size([9, 171]))

In [35]:
import torch
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Get max length in batch
    max_input_len = max(x['tokens'].shape[1] for x in batch)
    
    # Prepare empty tensors for batch
    tokens = torch.full((len(batch), 9, max_input_len), fill_value=2, dtype=torch.long)  # 2 for <|im_end|>
    labels = torch.full((len(batch), 9, max_input_len + 1), fill_value=-100, dtype=torch.long)
    
    # Fill in actual values
    for i, item in enumerate(batch):
        seq_len = item['tokens'].shape[1]
        tokens[i, :, :seq_len] = item['tokens'].clone().detach()
        labels[i, :, :seq_len+1] = item['labels'].clone().detach()
    
    return {'tokens': tokens, 'labels': labels}


In [36]:
from transformers import AutoTokenizer

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("../checkpoints/smoltts")

# Setup dataloader
dataloader = DataLoader(dataset['full'], batch_size=8, collate_fn=collate_fn, shuffle=True)

# Get first batch
batch = next(iter(dataloader))

# Look at first item in batch
first_item = batch['tokens'][0]  # Should be [9, seq_len]

# Get semantic tokens (first row)
semantic_tokens = first_item[0]  # Just the first index for semantic tokens

# Remove padding (zeros) if any
semantic_tokens = semantic_tokens[semantic_tokens != 0]

# Decode
decoded = tokenizer.decode(semantic_tokens)
print("Decoded text:")
print(decoded)

# Optional: print shapes to verify
print("\nShapes:")
print(f"Full batch shape: {batch['tokens'].shape}")
print(f"First item shape: {first_item.shape}")
print(f"Semantic tokens shape: {semantic_tokens.shape}")

Decoded text:
<|im_start|>system
Speak out the provided text<|im_end|>
<|im_start|>user
whence it deduced the practice and condition of every prison that replied.<|im_end|>
<|im_start|>assistant
<|semantic:1415|><|semantic:268|><|semantic:561|><|semantic:523|><|semantic:1300|><|semantic:942|><|semantic:170|><|semantic:1309|><|semantic:54|><|semantic:1269|><|semantic:1274|><|semantic:1326|><|semantic:1658|><|semantic:366|><|semantic:366|><|semantic:313|><|semantic:1899|><|semantic:146|><|semantic:238|><|semantic:1228|><|semantic:1534|><|semantic:300|><|semantic:1558|><|semantic:1054|><|semantic:1385|><|semantic:54|><|semantic:1379|><|semantic:1840|><|semantic:1517|><|semantic:410|><|semantic:1781|><|semantic:1508|><|semantic:552|><|semantic:1600|><|semantic:1600|><|semantic:1639|><|semantic:313|><|semantic:1997|><|semantic:1985|><|semantic:819|><|semantic:150|><|semantic:1487|><|semantic:1612|><|semantic:325|><|semantic:1910|><|semantic:858|><|semantic:157|><|semantic:157|><|semantic:83

In [37]:
from model.dual_ar import DualARTransformer

model = DualARTransformer.from_pretrained("../checkpoints/smoltts", load_weights=True)

Loading model from ../checkpoints/smoltts, config: DualARModelArgs(model_type='dual_ar', vocab_size=51200, n_layer=30, n_head=9, dim=576, intermediate_size=1536, n_local_heads=3, head_dim=64, rope_base=100000, norm_eps=1e-05, max_seq_len=8192, dropout=0.0, tie_word_embeddings=True, attention_qkv_bias=False, codebook_size=2048, num_codebooks=8, use_gradient_checkpointing=True, initializer_range=0.041666666666666664, is_reward_model=False, share_codebook_embeddings=True, scale_codebook_embeddings=False, n_fast_layer=4, fast_dim=576, fast_n_head=9, fast_n_local_heads=3, fast_head_dim=64, fast_intermediate_size=1536, fast_attention_qkv_bias=False)
Loaded weights with error: <All keys matched successfully>


In [38]:
BSZ = 8
len(dataset["full"]) / BSZ 

1637.5