In [None]:
import tensorflow as tf
import json

def load_jsonl(path):
    """Build a dataset yielding the ``"text"`` field of each JSONL record.

    Args:
        path: Path to a JSON-lines file, one JSON object per line, each
            carrying a ``"text"`` key.

    Returns:
        A ``tf.data.Dataset`` of scalar ``tf.string`` tensors, one per line.
    """
    def _extract_text(line):
        # Runs eagerly inside tf.py_function, so `.numpy()` is available.
        # A parsed JSON object (dict) cannot be returned as a tf.string
        # tensor, so hand back only the text payload.
        record = json.loads(line.numpy().decode("utf-8"))
        return record["text"]

    def parse_line(line):
        text = tf.py_function(_extract_text, [line], Tout=tf.string)
        # py_function drops static shape information; each element is a scalar.
        text.set_shape([])
        return text

    return tf.data.TextLineDataset(path).map(parse_line)
# Line-by-line dataset over the JSONL corpus, built with load_jsonl.
# NOTE(review): `dataset` is not read by any later cell visible here (training uses `train_ds`).
dataset = load_jsonl("mixed_dataset.jsonl")

In [None]:
import sentencepiece as spm

def _iter_corpus_text(jsonl_path):
    """Yield each non-empty line of every record's ``"text"`` field.

    The tokenizer must be trained on the same text that is later passed to
    ``sp.encode`` (the decoded ``"text"`` value), not on the raw JSON lines,
    which also contain braces, the ``"text"`` key and JSON escape sequences.
    """
    with open(jsonl_path, encoding="utf-8") as corpus:
        for raw_line in corpus:
            if not raw_line.strip():
                continue  # tolerate blank lines between records
            text = json.loads(raw_line)["text"]
            # One training sentence per line of text keeps sentences short.
            for sentence in text.splitlines():
                if sentence.strip():
                    yield sentence


# Train a 16k BPE tokenizer; ids 0-3 are reserved for pad/unk/bos/eos and the
# chat role markers are registered as indivisible user-defined symbols.
spm.SentencePieceTrainer.Train(
    sentence_iterator=_iter_corpus_text('mixed_dataset.jsonl'),
    model_prefix='tinyllm',
    vocab_size=16000,
    character_coverage=1.0,
    model_type='bpe',
    pad_id=0,
    unk_id=1,
    bos_id=2,
    eos_id=3,
    user_defined_symbols=["<user>", "<assistant>"]
)

print("Tokenizer trained: tinyllm.model, tinyllm.vocab")

Tokenizer trained: tinyllm.model, tinyllm.vocab


In [None]:
# Fixed sequence length used for every training example.
MAX_LEN = 256

# Load the SentencePiece model from disk.
sp = spm.SentencePieceProcessor()
sp.load("tinyllm.model")

# A negative pad id means the model has no pad token; fall back to id 0.
PAD_ID = max(sp.pad_id(), 0)

In [None]:
def encode_example(text):
    """Tokenize ``text`` into a right-padded next-token (input, target) pair.

    Uses the module-level ``sp`` processor and the ``MAX_LEN`` / ``PAD_ID``
    constants.

    Args:
        text: The string to tokenize.

    Returns:
        ``(x, y)``: two lists of exactly ``MAX_LEN`` ints, where ``y`` is ``x``
        shifted left by one token and both are right-padded with ``PAD_ID``.
    """
    # Keep MAX_LEN + 1 tokens: the one-token shift below consumes one, so a
    # long example still fills all MAX_LEN positions instead of always ending
    # in a padding slot.
    # NOTE(review): encode() is called without BOS/EOS flags, so unless the
    # corpus text itself carries an end marker the targets never contain
    # sp.eos_id() -- confirm whether that is intended.
    ids = sp.encode(text, out_type=int)[:MAX_LEN + 1]

    x = ids[:-1]
    y = ids[1:]

    # Right-pad both sequences to the fixed length.
    x = x + [PAD_ID] * (MAX_LEN - len(x))
    y = y + [PAD_ID] * (MAX_LEN - len(y))

    return x, y

In [None]:
def tf_load_dataset(path, batch_size=8):
    """Build the batched (input, target) training pipeline from a JSONL file.

    Each line is parsed as JSON, its ``"text"`` field is tokenized with
    ``encode_example``, and the resulting pairs are shuffled, batched
    (dropping the final partial batch) and prefetched.

    Args:
        path: Path to the JSON-lines file.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` of ``(x, y)`` int32 batches, each of static
        shape ``[batch_size, MAX_LEN]``.
    """
    def _text_from_line(line):
        # Executed eagerly by tf.py_function, so the bytes can be read out.
        record = json.loads(line.numpy().decode("utf-8"))
        return record["text"]

    def _line_to_text(line):
        text = tf.py_function(_text_from_line, [line], Tout=tf.string)
        # py_function loses the static shape; declare the scalar explicitly.
        text.set_shape([])
        return text

    def _encode_eagerly(text):
        return encode_example(text.numpy().decode("utf-8"))

    def _text_to_pair(text):
        x, y = tf.py_function(_encode_eagerly, [text], [tf.int32, tf.int32])
        # Restore the static shapes that batching and the model rely on.
        x.set_shape([MAX_LEN])
        y.set_shape([MAX_LEN])
        return x, y

    return (
        tf.data.TextLineDataset(path)
        .map(_line_to_text, num_parallel_calls=tf.data.AUTOTUNE)
        .map(_text_to_pair, num_parallel_calls=tf.data.AUTOTUNE)
        .shuffle(20000)
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.AUTOTUNE)
    )


In [None]:
# Training pipeline over the JSONL corpus; batch_size is left at tf_load_dataset's default of 8.
train_ds = tf_load_dataset("mixed_dataset.jsonl")

In [None]:
import tensorflow as tf

@tf.keras.utils.register_keras_serializable(package="tinyllm")
class DecoderBlock(tf.keras.layers.Layer):
    """Transformer decoder block: self-attention and feed-forward sub-layers.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization (normalization is applied after the residual add).

    Args:
        d_model: Width of the residual stream; also the output size of the
            feed-forward sub-layer.
        num_heads: Number of attention heads; each head uses
            ``d_model // num_heads`` key dimensions.
        d_ff: Hidden width of the feed-forward sub-layer.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, d_model, num_heads, d_ff, **kwargs):
        super().__init__(**kwargs)
        # Kept so get_config() can reproduce the constructor call on reload.
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff

        self.att = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=d_model // num_heads
        )
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.ff = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation="gelu"),
            tf.keras.layers.Dense(d_model)
        ])
        self.ln2 = tf.keras.layers.LayerNormalization()

    def call(self, x, causal_mask):
        """Apply the block to ``x`` using ``causal_mask`` as the attention mask."""
        # Self-attention: query and value are both x.
        attn = self.att(x, x, attention_mask=causal_mask)
        x = self.ln1(x + attn)
        ff_out = self.ff(x)
        return self.ln2(x + ff_out)

    def get_config(self):
        """Return the constructor arguments so a saved model can be reloaded."""
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "d_ff": self.d_ff,
        })
        return config

@tf.keras.utils.register_keras_serializable(package="tinyllm")
class _AddPositionEmbedding(tf.keras.layers.Layer):
    """Add a learned absolute position embedding to a (batch, seq, d_model) input."""

    def __init__(self, max_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.d_model = d_model
        self.pos_emb = tf.keras.layers.Embedding(max_len, d_model)

    def call(self, x):
        # Prefer the static sequence length; fall back to the dynamic one.
        seq_len = x.shape[1]
        if seq_len is None:
            seq_len = tf.shape(x)[1]
        positions = tf.range(seq_len)
        # (seq, d_model) broadcasts over the batch dimension.
        return x + self.pos_emb(positions)

    def get_config(self):
        config = super().get_config()
        config.update({"max_len": self.max_len, "d_model": self.d_model})
        return config


def build_model(vocab_size=16000, max_len=256, d_model=384, layers=6, heads=6):
    """Build the decoder-only language model as a Keras functional model.

    Args:
        vocab_size: Size of the token vocabulary (embedding rows and number
            of output logits per position).
        max_len: Fixed input sequence length and number of position slots.
        d_model: Embedding / residual width.
        layers: Number of stacked ``DecoderBlock`` layers.
        heads: Number of attention heads per block.

    Returns:
        A ``tf.keras.Model`` mapping int32 ids of shape ``(batch, max_len)``
        to logits of shape ``(batch, max_len, vocab_size)``.
    """
    inputs = tf.keras.Input(shape=(max_len,), dtype=tf.int32)

    x = tf.keras.layers.Embedding(vocab_size, d_model)(inputs)
    # The position embedding must be applied inside a layer that is part of
    # the graph. Calling an Embedding on an eager tf.range(...) tensor here
    # would freeze its randomly initialized output into the model as a
    # constant, leaving the position weights untracked and untrained.
    x = _AddPositionEmbedding(max_len, d_model)(x)

    # Lower-triangular boolean mask: position i may attend to positions <= i.
    mask = tf.linalg.band_part(tf.ones((max_len, max_len)), -1, 0)
    mask = tf.cast(mask, tf.bool)

    for _ in range(layers):
        x = DecoderBlock(d_model, heads, d_ff=4 * d_model)(x, mask)

    # Unnormalized per-position scores over the vocabulary.
    logits = tf.keras.layers.Dense(vocab_size)(x)

    return tf.keras.Model(inputs, logits)

In [None]:
# Instantiate the model with build_model's default hyperparameters and print its layer summary.
model = build_model()
model.summary()

In [None]:
# Next-token cross-entropy on raw logits. encode_example right-pads the
# targets with PAD_ID, so those positions are excluded from the loss;
# otherwise the average is dominated by trivially predictable padding.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
    ignore_class=PAD_ID,
)

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)

model.compile(
    optimizer=optimizer,
    loss=loss_fn
)

# train_ds is already batched inside tf_load_dataset (.batch(batch_size)),
# so no batch_size is passed here: it does not apply to dataset inputs and
# would misreport the real batch size.
model.fit(
    train_ds,
    epochs=3
)
model.save('model_9.keras')

Epoch 1/3
[1m6934/6934[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m660s[0m 88ms/step - loss: 1.7427
Epoch 2/3
[1m6934/6934[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m638s[0m 88ms/step - loss: 1.6632
Epoch 3/3
[1m6934/6934[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m639s[0m 88ms/step - loss: 1.6006


In [None]:
"""
def generate(prompt, max_new=50):
    ids = sp.encode(prompt, out_type=int)

    for _ in range(max_new):
        x = ids[-256:]
        pad = [sp.pad_id()] * (256 - len(x))
        x = pad + x
        x = tf.constant([x])

        seq_len = len(ids)
        logits = model(x)[0, -1]
        next_id = int(tf.argmax(logits))
        ids.append(next_id)

        if next_id == sp.eos_id():
            break

    return sp.decode(ids)
"""

def generate(prompt, sp, model, max_new_tokens=60, temperature=0.6, top_k=30,
             max_len=256):
    """Sample a continuation of ``prompt`` from the language model.

    Args:
        prompt: Text to condition on.
        sp: SentencePiece processor used to encode the prompt and decode the
            result.
        model: Model mapping ``(1, max_len)`` int32 ids to per-position logits.
        max_new_tokens: Maximum number of tokens to append.
        temperature: Softmax temperature; values ``<= 0`` select greedy
            (argmax) decoding.
        top_k: Number of highest-scoring tokens to sample from (clamped to
            the vocabulary size).
        max_len: Fixed input length the model expects.

    Returns:
        The decoded text of the prompt ids plus all generated ids.
    """
    ids = list(sp.encode(prompt))
    # A negative pad id means "no pad token"; fall back to id 0.
    pad_id = max(sp.pad_id(), 0)
    eos_id = sp.eos_id()

    for _ in range(max_new_tokens):
        # Keep only the most recent max_len tokens as context.
        x = ids[-max_len:]

        # Right-pad, exactly as the training examples are padded, so the
        # tokens sit at the positions the model was trained on and the causal
        # mask keeps the padding out of their context. The next-token
        # prediction is then read at the last *real* position.
        last_pos = max(len(x) - 1, 0)
        x_padded = x + [pad_id] * (max_len - len(x))
        x_tensor = tf.constant([x_padded], dtype=tf.int32)

        logits = model(x_tensor, training=False)[0, last_pos]

        # Repetition blocking: forbid emitting the previous token again.
        if ids:
            logits = tf.tensor_scatter_nd_update(
                logits,
                indices=[[ids[-1]]],
                updates=[-1e9]
            )

        if temperature <= 0:
            # Greedy decoding avoids dividing by a zero temperature.
            next_token = int(tf.argmax(logits))
        else:
            k = min(top_k, int(logits.shape[-1]))
            values, indices = tf.nn.top_k(logits / temperature, k=k)
            # tf.random.categorical takes unnormalized log-probabilities, so
            # the scaled top-k logits can be sampled from directly.
            choice = int(tf.random.categorical(values[tf.newaxis, :], 1)[0, 0])
            next_token = int(indices[choice])

        ids.append(next_token)

        # Stop on EOS
        if next_token == eos_id:
            break

    return sp.decode(ids)


# Smoke test: sample a reply to a chat-style prompt built from the <user>/<assistant> symbols registered with the tokenizer.
print(generate("<user>: Hey bro\n<assistant>:", sp, model))

<user>: Hey bro <assistant>: hermos偿 hermosua hermosuah terminal ob hermos。令 hermos年 hermos。令。eceh sir hermos。令 hermosuaheceh lyrics ob hermos。架 hermos。令 ob hermossl hermosas hermossl hermosamenteh llegó ob hermosamente hermosamente hermosamente hermossl令 faster
