# Following the course at https://github.com/rwitten/HighPerfLLMs2024

In [61]:
# Hyperparameters.
seed = 1337           # random seed (presumably consumed by later training cells)
context_length = 128  # tokens per example: longer encodings are truncated, shorter ones padded
batch_size = 8        # examples per batch yielded by get_encoded_batch

In [19]:
import tensorflow_datasets as tfds
# Load the 'lm1b' training split from TensorFlow Datasets; shuffle_files=False keeps file order fixed across runs.
ds = tfds.load(name='lm1b', split='train', shuffle_files=False)

In [None]:
# Train a SentencePiece tokenizer on a small subset of the LM1B training sentences.
import sentencepiece as spm
def sentence_generator(max_sentences=100_000, examples=None):
    """Yield up to ``max_sentences`` sentences as ``str`` for tokenizer training.

    Args:
        max_sentences: Maximum number of sentences to yield. ``None`` means no
            limit (iterate over every example).
        examples: Optional iterable of example dicts that have a ``'text'``
            field. Defaults to the LM1B ``ds`` loaded above, converted with
            ``tfds.as_numpy``.

    Yields:
        Each example's ``'text'`` value, UTF-8 decoded when it is ``bytes``.
    """
    import itertools

    if examples is None:
        examples = tfds.as_numpy(ds)
    # islice yields at most max_sentences examples and stops without reading
    # further, so the limit is exact (no off-by-one).
    for example in itertools.islice(examples, max_sentences):
        # TFDS text features arrive as bytes, so decode them when needed.
        text = example['text']
        if isinstance(text, bytes):
            text = text.decode('utf-8')
        yield text

# Train a BPE SentencePiece model on sentences streamed from LM1B.
import os

tokenizer_prefix = 'data/lm1b_tokenizer'
# SentencePiece writes <prefix>.model and <prefix>.vocab but does not create the
# parent directory, so make sure it exists (e.g. on a fresh checkout).
os.makedirs(os.path.dirname(tokenizer_prefix), exist_ok=True)

spm.SentencePieceTrainer.train(
    model_prefix=tokenizer_prefix,
    sentence_iterator=sentence_generator(),  # stream sentences instead of reading an input file
    vocab_size=1024,
    character_coverage=0.9995,  # fraction of characters the vocabulary must cover
    model_type='bpe',           # alternatives: 'unigram' (SentencePiece default), 'char', 'word'
)


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input_format: 
  model_prefix: data/lm1b_tokenizer
  model_type: BPE
  vocab_size: 1024
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
  differential_privacy

In [64]:
# Load the tokenizer model produced by the training cell (data/lm1b_tokenizer.model).
sp = spm.SentencePieceProcessor(model_file='data/lm1b_tokenizer.model')

import jax
import jax.numpy as jnp
import numpy as np

def get_encoded_batch(pad_id=0, drop_remainder=False):
    """Yield batches of LM1B sentences as fixed-length arrays of token ids.

    Each batch holds up to ``batch_size`` examples from ``ds``. Every sentence
    is encoded with the SentencePiece model ``sp``, truncated to
    ``context_length`` tokens and right-padded with ``pad_id``.

    Args:
        pad_id: Token id used for right padding. NOTE(review): the tokenizer
            trained above logs ``unk_id: 0`` and ``pad_id: -1``, so the default
            of 0 makes padding indistinguishable from ``<unk>``. TODO: train
            with a dedicated pad id and pass it here (or mask padding).
        drop_remainder: If True, drop the final batch when it has fewer than
            ``batch_size`` examples, so every yielded array has the same shape
            (a new shape would make jitted JAX functions retrace).

    Yields:
        A ``jnp`` array of shape ``(rows, context_length)``. ``rows`` equals
        ``batch_size``, except possibly for the last batch when
        ``drop_remainder`` is False.
    """
    batched_ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    for batch in tfds.as_numpy(batched_ds):
        texts = [raw.decode('utf-8') for raw in batch['text']]
        # sp.encode accepts a list of strings and returns one id list per string.
        token_ids = [ids[:context_length] for ids in sp.encode(texts)]
        padded = [ids + [pad_id] * (context_length - len(ids)) for ids in token_ids]
        yield jnp.asarray(padded)


# Sanity-check the pipeline: the first batch should be a full
# (batch_size, context_length) array of token ids.
sample_batch = next(get_encoded_batch())
assert sample_batch.shape == (batch_size, context_length), sample_batch.shape
print(sample_batch.shape)

(8, 128)


In [None]:
import flax.nnx as nnx

class LangaugeModel()