In [1]:
import tensorflow as tf
import json

def load_jsonl(path):
    ds = tf.data.TextLineDataset(path)

    def parse_line(line):
        obj = tf.py_function(lambda s: json.loads(s.numpy()), [line], Tout=tf.string)
        return obj

    return ds

dataset = load_jsonl("mixed_dataset.jsonl")

2025-11-21 17:15:46.948563: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763725545.959142    4958 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763725545.992891    4958 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1763725546.288058    4958 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763725546.288217    4958 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1763725546.288220    4958 computation_placer.cc:177] computation placer alr

In [30]:
import sentencepiece as spm

spm.SentencePieceTrainer.Train(
    input='mixed_dataset.jsonl',
    model_prefix='tinyllm',
    vocab_size=16000,
    character_coverage=1.0,
    model_type='bpe',
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    user_defined_symbols=["<user>:", "<assistant>:"]
)

print("Tokenizer trained: tinyllm.model, tinyllm.vocab")

Tokenizer trained: tinyllm.model, tinyllm.vocab


sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: mixed_dataset.jsonl
  input_format: 
  model_prefix: tinyllm
  model_type: BPE
  vocab_size: 16000
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  user_defined_symbols: <user>:
  user_defined_symbols: <assistant>:
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 1
  bos_id: 2
  eos_id: 3
  pad_id: 0
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece:

In [None]:
sp = spm.SentencePieceProcessor()
sp.load("tinyllm.model")

MAX_LEN = 384
PAD_ID = 0 if sp.pad_id() < 0 else sp.pad_id()

In [32]:
def encode_example(text):
    ids = sp.encode(text, out_type=int)

    if len(ids) > MAX_LEN:
        ids = ids[:MAX_LEN]

    x = ids[:-1]
    y = ids[1:]

    # pad
    x = x + [PAD_ID] * (MAX_LEN - len(x))
    y = y + [PAD_ID] * (MAX_LEN - len(y))

    return x, y

In [33]:
def tf_load_dataset(path, batch_size=8):
    ds = tf.data.TextLineDataset(path)

    def parse_json(line):
        obj = json.loads(line.numpy().decode("utf-8"))
        return obj["text"]

    def tf_parse_json(line):
        text = tf.py_function(parse_json, [line], Tout=tf.string)
        text.set_shape([])  # VERY IMPORTANT
        return text

    def tf_encode(text):
        x, y = tf.py_function(
            lambda t: encode_example(t.numpy().decode("utf-8")),
            [text],
            [tf.int32, tf.int32]
        )
        # Set shapes MANUALLY (the fix for your error)
        x.set_shape([MAX_LEN])
        y.set_shape([MAX_LEN])
        return x, y

    ds = ds.map(tf_parse_json, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(tf_encode, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(20000)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds


In [34]:
train_ds = tf_load_dataset("mixed_dataset.jsonl")

In [None]:
import tensorflow as tf

class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads, key_dim=d_model)
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.ff = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation="gelu"),
            tf.keras.layers.Dense(d_model)
        ])
        self.ln2 = tf.keras.layers.LayerNormalization()

    def call(self, x, causal_mask):
        attn = self.att(x, x, attention_mask=causal_mask)
        x = self.ln1(x + attn)
        ff_out = self.ff(x)
        return self.ln2(x + ff_out)

def build_model(vocab_size=16000, max_len=256, d_model=384, layers=4, heads=6):
    inputs = tf.keras.Input(shape=(max_len,), dtype=tf.int32)

    embed = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)


    pos = tf.range(max_len)[tf.newaxis, :]
    pos_embed = tf.keras.layers.Embedding(max_len, d_model)(pos)
    x = embed + pos_embed

    mask = tf.linalg.band_part(tf.ones((max_len, max_len)), -1, 0)
    mask = mask[tf.newaxis, tf.newaxis, :, :]

    for _ in range(layers):
        x = DecoderBlock(d_model, heads, d_ff=4 * d_model)(x, mask)

    logits = tf.keras.layers.Dense(vocab_size)(x)

    model = tf.keras.Model(inputs, logits)
    return model

In [36]:
model = build_model()
model.summary()

In [None]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True
)

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

model.compile(
    optimizer=optimizer,
    loss=loss_fn
)

model.fit(
    train_ds,
    epochs=10
)

Epoch 1/3


2025-11-21 20:05:27.903155: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:35: Filling up shuffle buffer (this may take a while): 5488 of 20000
2025-11-21 20:05:47.900474: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:35: Filling up shuffle buffer (this may take a while): 16784 of 20000
2025-11-21 20:05:52.139439: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.



[1m6934/6934[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1129s[0m 155ms/step - loss: 5.2475
Epoch 2/3


2025-11-21 20:23:53.738025: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7346126858096581346
2025-11-21 20:24:03.787420: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:35: Filling up shuffle buffer (this may take a while): 5615 of 20000
2025-11-21 20:24:23.786360: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:35: Filling up shuffle buffer (this may take a while): 17169 of 20000
2025-11-21 20:24:29.259591: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


[1m6934/6934[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1103s[0m 154ms/step - loss: 5.6235
Epoch 3/3


2025-11-21 20:42:16.724775: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
	 [[{{node IteratorGetNext}}]]
	 [[IteratorGetNext/_4]]
2025-11-21 20:42:16.724835: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7346126858096581346
2025-11-21 20:42:16.724846: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 11863766883536092728
2025-11-21 20:42:26.762997: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:35: Filling up shuffle buffer (this may take a while): 5668 of 20000
2025-11-21 20:42:46.759304: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:452] ShuffleDatasetV3:35: Filling up shuffle buffer (this may take a while): 17787 of 20000
2025-11-21 20:42:51.177063: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:482] Shuffle buffer filled.


[1m6934/6934[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1100s[0m 154ms/step - loss: 5.6190


2025-11-21 21:00:37.043210: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 7346126858096581346
2025-11-21 21:00:37.043264: I tensorflow/core/framework/local_rendezvous.cc:426] Local rendezvous recv item cancelled. Key hash: 11863766883536092728


<keras.src.callbacks.history.History at 0x7fe1592d63e0>

In [28]:
model.save('checkpoints/model_fixed_3.keras')

In [39]:
def generate(prompt, max_new=50):
    ids = sp.encode(prompt, out_type=int)
    
    for _ in range(max_new):
        x = ids[-256:]
        pad = [sp.pad_id()] * (256 - len(x))
        x = pad + x
        x = tf.constant([x])

        seq_len = len(ids)
        logits = model(x)[0, x.shape[1] - 1]  
        next_id = int(tf.argmax(logits))
        ids.append(next_id)

        if next_id == sp.eos_id():
            break

    return sp.decode(ids)

print(generate("<user>: hey\n<assistant>:"))

<user>: hey <assistant>:
