In [3]:
import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

@tf.function
@keras.saving.register_keras_serializable()
def prepare_sinusoidal_lookup_table(EMBEDDING_SIZE: int = 128, max_seq_len: int = 512):
    """
    Builds a sinusoidal positional encoding lookup table.
    
    Args:
      EMBEDDING_SIZE: dimensionality of each position encoding vector (must be even).
      max_seq_len: maximum sequence length (number of positions).
    
    Returns:
      lookup_table: a tf array of shape (max_seq_len, EMBEDDING_SIZE)
                    where row p gives the positional encoding for position p.
    """
    # Initialize the table
    lookup_table = np.zeros((max_seq_len, EMBEDDING_SIZE), dtype=np.float32)
    
    # Compute the angle rates for each dimension
    # angle_rates[k] = 1 / (10000^(2*(k//2) / EMBEDDING_SIZE))
    dims = np.arange(EMBEDDING_SIZE)[np.newaxis, :]   # shape (1, EMBEDDING_SIZE)
    positions = np.arange(max_seq_len)[:, np.newaxis] # shape (max_seq_len, 1)
    angle_rates = 1 / np.power(10000, (2 * (dims // 2)) / EMBEDDING_SIZE)
    
    # Compute the angle for each position and dimension: position * angle_rate
    angle_rads = positions * angle_rates  # shape (max_seq_len, EMBEDDING_SIZE)
    
    # Apply sin to even indices (0,2,4,...) and cos to odd indices (1,3,5,...)
    lookup_table[:, 0::2] = np.sin(angle_rads[:, 0::2])
    lookup_table[:, 1::2] = np.cos(angle_rads[:, 1::2])
    
    return tf.constant(lookup_table)

In [4]:
import os
from typing import List, Dict, Tuple
import numpy as np
import tensorflow as tf

@keras.saving.register_keras_serializable()
def tokenize_and_build_vocabulary_tf(file_path_list: List[str], existing_vocab: Dict[str, int] | None = None) -> Dict[str, int]:
    """
    Build a character-level vocabulary dictionary from text files.
    
    Args:
        file_path_list: List of file paths containing the text corpus.
        existing_vocab: Optional existing vocabulary to extend.

    Returns:
        token_to_id: dict mapping character to unique integer token ID.
    """
    if isinstance(file_path_list, (str, bytes)):
        file_path_list = [file_path_list] # type: ignore
    if existing_vocab is None:
        existing_vocab = {}
    vocab_set = set(existing_vocab.keys())
    
    for file_name in file_path_list:
        if os.path.isdir(file_name):
            raise IsADirectoryError(f"Expected file path, got directory: {file_name}")
        if not os.path.isfile(file_name):
            raise FileNotFoundError(f"File not found: {file_name}")
        with open(file_name, encoding="utf-8") as f:
            text = f.read()
            vocab_set.update(text)
    
    sorted_tokens = sorted(vocab_set)
    token_to_id = {char: idx for idx, char in enumerate(sorted_tokens)}
    return token_to_id

@keras.saving.register_keras_serializable()
def tokenize_and_build_token_id(token_to_id_dict: Dict[str, int], text_batch: List[str], max_seq_len: int, pad_value: int = 0) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Tokenize a batch of text strings into character token IDs using a token dictionary,
    then pad/truncate to max_seq_len and create attention masks.

    Args:
        token_to_id_dict: dict mapping character to integer token ID.
        text_batch: list of text strings to tokenize.
        max_seq_len: maximum sequence length after padding/truncation.
        pad_value: integer ID used for padding tokens.

    Returns:
        token_ids: tf.Tensor of shape (batch_size, max_seq_len), dtype tf.int32.
        attention_mask: tf.Tensor of shape (batch_size, max_seq_len), dtype tf.int32 (1 for real tokens, 0 for padding).
    """
    batch_token_ids = []
    for text in text_batch:
        ids = [token_to_id_dict.get(c, pad_value) for c in text]
        if len(ids) > max_seq_len:
            ids = ids[:max_seq_len]
        else:
            ids += [pad_value] * (max_seq_len - len(ids))
        batch_token_ids.append(ids)
    
    token_ids = np.array(batch_token_ids, dtype=np.int32)
    attention_mask = (token_ids != pad_value).astype(np.int32)
    
    return tf.constant(token_ids), tf.constant(attention_mask) # type: ignore


In [5]:
import os
from typing import List, Dict, Tuple
import numpy as np
import tensorflow as tf
import sentencepiece as spm


@keras.saving.register_keras_serializable()
def train_sentencepiece_tokenizer(file_path_list: List[str], 
                                vocab_size: int = 2000,
                                model_prefix: str = 'spm_gpt') -> spm.SentencePieceProcessor:
    """
    Train SentencePiece tokenizer from text files (replaces tokenize_and_build_vocabulary_tf).
    
    Args:
        file_path_list: List of file paths containing the text corpus.
        vocab_size: Size of the subword vocabulary (default: 2000).
        model_prefix: Prefix for output model files.
    
    Returns:
        sp: Trained SentencePieceProcessor object.
    """
    if isinstance(file_path_list, (str, bytes)):
        file_path_list = [file_path_list]
    
    # Validate files (same as your original)
    for file_name in file_path_list:
        if os.path.isdir(file_name):
            raise IsADirectoryError(f"Expected file path, got directory: {file_name}")
        if not os.path.isfile(file_name):
            raise FileNotFoundError(f"File not found: {file_name}")
    
    # Combine all files into one input (or use comma-separated list)
    input_files = ','.join(file_path_list)
    
    # Train SentencePiece model
    spm.SentencePieceTrainer.train(
        input=input_files,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        character_coverage=0.9995,
        model_type='bpe',
        pad_id=0,
        unk_id=1,
        bos_id=2,
        eos_id=3,
    )
    
    # Load and return processor
    sp = spm.SentencePieceProcessor()
    sp.load(f'{model_prefix}.model')
    return sp

@keras.saving.register_keras_serializable()
def tokenize_and_build_token_id_sp(sp: spm.SentencePieceProcessor, 
                                 text_batch: List[str], 
                                 max_seq_len: int, 
                                 pad_value: int = 0) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Tokenize batch of text using SentencePiece (replaces tokenize_and_build_token_id).
    
    Args:
        sp: Trained SentencePieceProcessor object.
        text_batch: List of text strings to tokenize.
        max_seq_len: Maximum sequence length after padding/truncation.
        pad_value: Integer ID used for padding tokens (should match sp.pad_id()).
    
    Returns:
        token_ids: tf.Tensor of shape (batch_size, max_seq_len), dtype tf.int32.
        attention_mask: tf.Tensor of shape (batch_size, max_seq_len), dtype tf.int32.
    """
    batch_token_ids = []
    
    for text in text_batch:
        # Encode text to subword IDs
        ids = sp.encode_as_ids(text)
        
        # Truncate if too long
        if len(ids) > max_seq_len:
            ids = ids[-max_seq_len:]  # Keep the end (recent context)
        else:
            # Pad to max_seq_len
            ids += [pad_value] * (max_seq_len - len(ids))
        
        batch_token_ids.append(ids)
    
    token_ids = np.array(batch_token_ids, dtype=np.int32)
    attention_mask = (token_ids != pad_value).astype(np.int32)
    
    return tf.constant(token_ids), tf.constant(attention_mask) # type: ignore

In [6]:
@keras.saving.register_keras_serializable()
class InitializePositionalEmbeddings(keras.layers.Layer):
    def __init__(
        self,
        d_model: int,
        vocab_size : int,
        max_seq_len: int = 512,
        pad_value: int = 0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.d_model = int(d_model)
        self.pad_value = int(pad_value)
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self._pos_table = prepare_sinusoidal_lookup_table(d_model, max_seq_len)

    def build(self, input_shape):
        self.embedding_matrix = self.add_weight(
            name="embedding_matrix",
            shape=(self.vocab_size, self.d_model),
            initializer="random_normal",
            trainable=True,
            dtype=tf.float32
        )
        super().build(input_shape)

    def call(self, text_batch):

        token_ids= text_batch # Unpacking Data Pre-processing inputs Embeddings
        
        # Embeddings lookup: (B, T, D)
        token_emb = tf.nn.embedding_lookup(self.embedding_matrix, token_ids)
        # Positional embeddings: slice and broadcast
        seq_len = tf.shape(token_ids)[1] # type: ignore
        pos_emb = self._pos_table[:seq_len, :]    # type: ignore # (T, D)
        pos_emb = tf.expand_dims(pos_emb, 0)     # (1, T, D)
        embeddings = token_emb + pos_emb         # (B, T, D)
        return embeddings

    # def compute_output_shape(self, input_shape):
    #     # input_shape: (batch_size,)
    #     batch = input_shape
    #     # Sequence length is dynamic: None
    #     return (batch, None, self.d_model), (batch, None)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            "d_model": self.d_model,
            'vocab_size': self.vocab_size,
            'max_seq_len': self.max_seq_len,
            "pad_value": self.pad_value,
        })
        return cfg
    
    def compute_output_shape(self, input_shape):
        # input_shape: (batch_size, seq_len)
        batch_size = input_shape[0]
        seq_len = input_shape[1]
        return (batch_size, seq_len, self.d_model)


In [7]:
@keras.saving.register_keras_serializable()
class SelfAttentionLayer(keras.layers.Layer):
    def __init__(self,attention_heads = 8, **kwargs):
        super().__init__(**kwargs)
        self.attention_heads = attention_heads
        
    def build(self, input_shape): # Two tuples -> first tuple is (Batch Shape , Max_seq_length_in_batch,d_model) , Second tuple is (batch , max_seq_len)
        self.d_model = input_shape[0][-1]
        self.Query_projection = self.add_weight(
            name = 'Query_Vector_for_projection',
            initializer = 'random_normal',
            shape = (self.d_model,self.d_model),
            trainable = True 
        )
        self.Key_projection = self.add_weight(
            name = 'Key_Vector_for_projection',
            initializer = 'random_normal',
            shape = (self.d_model,self.d_model),
            trainable = True 
        )
        self.Value_projection = self.add_weight(
            name = 'Value_Vector_for_projection',
            initializer = 'random_normal',
            shape = (self.d_model,self.d_model),
            trainable = True 
        )

        self.output_projection = self.add_weight(
        name="Output_projection",
        initializer="random_normal",
        shape=(self.d_model, self.d_model),
        trainable=True,
        )

        self.d_head = self.d_model // self.attention_heads # type: ignore
        
        assert self.d_model % self.attention_heads == 0, "d_model must be divisible by attention_heads"

    def call(self,inputs):
        embeddings = inputs[0]
        token_masks = inputs[1]

        batch_size = tf.shape(embeddings)[0] # type: ignore
        seq_len = tf.shape(embeddings)[1] # type: ignore

        Q = embeddings @ self.Query_projection # (seq_len , d_model)
        K = embeddings @ self.Key_projection
        V = embeddings @ self.Value_projection

        # 2. Reshape and transpose for multi-head Attention
        Q = tf.reshape(Q, (batch_size, seq_len, self.attention_heads, self.d_head))
        K = tf.reshape(K, (batch_size, seq_len, self.attention_heads, self.d_head))
        V = tf.reshape(V, (batch_size, seq_len, self.attention_heads, self.d_head))

        Q = tf.transpose(Q, (0, 2, 1, 3))  # (batch, heads, seq_len, d_head)
        K = tf.transpose(K, (0, 2, 1, 3))
        V = tf.transpose(V, (0, 2, 1, 3))

        scores = tf.matmul(Q,K, transpose_b=True) # (batch , heads , seq_len,seq_len)
        scores = scores / tf.sqrt(tf.cast(self.d_head, tf.float32))
        # 5a. Causal mask (L,L) lower triangular
        causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
        causal_mask = tf.reshape(causal_mask, (1, 1, seq_len, seq_len))  # (1,1,L,L)

        # 5b. Token mask (B,L) -> (B,1,1,L)
        token_mask = tf.cast(token_masks[:, tf.newaxis, tf.newaxis, :], tf.float32)

        # 5c. Combine masks
        combined_mask = causal_mask * token_mask  # broadcast -> (B, H, L, L)

        # 6. Apply mask (replace disallowed with -1e9)
        scores = tf.where(combined_mask > 0, scores, tf.constant(-1e9, dtype = scores.dtype))

        attention_weights = tf.nn.softmax(scores, axis=-1)
        context = attention_weights @ V   #(batch, heads, seq_len, seq_len) × (batch, heads, seq_len, d_head) → (batch, heads, seq_len, d_head)
        concat_context = tf.reshape(context, (batch_size,seq_len,self.attention_heads * self.d_head))  # type: ignore

        final_context = concat_context @ self.output_projection 
        return final_context
    
    def get_config(self):
        config = super().get_config()
        config.update({"attention_heads": self.attention_heads,})
        return config
    
    def compute_output_shape(self, input_shape):
        return input_shape[0]

In [8]:
@keras.saving.register_keras_serializable()
class LayerNormalization(keras.layers.Layer):
    def __init__(self,eps=1e-5,**kwargs):
        super().__init__(**kwargs)
        self.eps = eps
    
    def build(self,input_shape): # Near Attention (batch, seq_len, d_model)
        self.alpha = self.add_weight(
            name = 'alpha',
            shape = input_shape[-1:],
            initializer = 'ones',
            dtype = tf.float32,
            trainable = True
        )
        self.beta = self.add_weight(
            name = 'beta',
            shape = input_shape[-1:],
            initializer = 'zeros',
            dtype = tf.float32,
            trainable = True
        )
        super().build(input_shape)
        
    def call(self, inputs):
        mean, var = tf.nn.moments(inputs, axes=[-1], keepdims=True)
        normed = (inputs - mean) / tf.sqrt(var + self.eps) # type: ignore
        return self.alpha * normed + self.beta

    def get_config(self):
        base = super().get_config()
        return {**base, "eps": self.eps}
    
    def compute_output_shape(self, input_shape):
        return input_shape

In [15]:
@keras.saving.register_keras_serializable()
class CosineDecayWithWarmup(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, 
                 warmup_steps: int,
                 total_steps: int,
                 peak_learning_rate: float = 1e-4,
                 min_learning_rate: float = 1e-6,
                 name: str = "cosine_decay_with_warmup"):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.peak_learning_rate = peak_learning_rate
        self.min_learning_rate = min_learning_rate
        self.name = name
        
    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)
        
        # Warmup phase: linear increase from 0 to peak_learning_rate
        warmup_lr = self.peak_learning_rate * step / warmup_steps
        
        # Cosine decay phase
        decay_steps = total_steps - warmup_steps
        cosine_decay_lr = self.min_learning_rate + 0.5 * (
            self.peak_learning_rate - self.min_learning_rate
        ) * (1 + tf.cos(np.pi * (step - warmup_steps) / decay_steps))
        
        return tf.where(step < warmup_steps, warmup_lr, cosine_decay_lr)
    
    def get_config(self):
        return {
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "peak_learning_rate": self.peak_learning_rate,
            "min_learning_rate": self.min_learning_rate,
            "name": self.name,
        }

In [9]:
@keras.saving.register_keras_serializable()
class DecoderBlock(keras.Model):
    '''A single Decoder Block'''
    def __init__(self, d_model, n_heads, dropout_rate=0.1, epsilon=1e-5, name=None):
        super().__init__(name=name)
        self.d_model = d_model
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.epsilon = epsilon
        # norms
        self.ln1 = LayerNormalization(epsilon)   # pre-attn
        self.ln2 = LayerNormalization(epsilon)   # pre-ffn
        # attention (assumes your SelfAttentionLayer accepts (x, attention_mask))
        self.attn = SelfAttentionLayer(n_heads)
        self.dropout1 = keras.layers.Dropout(dropout_rate)
        # FFN
        self.ffn1 = keras.layers.Dense(4 * d_model, activation="gelu")
        self.ffn2 = keras.layers.Dense(d_model)
        self.dropout2 = keras.layers.Dropout(dropout_rate)

    def call(self, x, attention_mask, training=False):
        # Self-attention sublayer
        y = self.ln1(x)
        y = self.attn((y, attention_mask))          # shape: (B, T, d_model)
        y = self.dropout1(y, training=training)
        x = x + y                                    # residual

        # FFN sublayer
        y = self.ln2(x)
        y = self.ffn1(y)
        y = self.ffn2(y)
        y = self.dropout2(y, training=training)
        x = x + y                                    # residual
        return x
    
    def compute_output_shape(self, input_shape):
        # input_shape is typically (batch_size, seq_len, d_model)
        return input_shape
    
    def get_config(self):
        config = super().get_config()
        config.update({
            "d_model": self.d_model,
            "n_heads": self.n_heads,
            "dropout_rate": self.dropout_rate,
            "epsilon": self.epsilon,
        })
        return config

@keras.saving.register_keras_serializable()
class GPT(keras.Model):
    '''
    GPT model with N distinct blocks
      -----------------------------------'''
    def __init__(self,
                 d_model: int = 128,
                 vocab_size: int = 94,
                 context_length: int = 512,
                 attention_heads: int = 8,
                 epsilon: float = 1e-5,
                 decoder_blocks: int = 3,
                 dropout_rate: float = 0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self._d_model = d_model
        self._vocab_size = vocab_size
        self._context_length = context_length
        self._attention_heads = attention_heads
        self._epsilon = epsilon
        self._decoder_blocks = decoder_blocks
        self._dropout_rate = dropout_rate

        # embeddings (yours)
        self.emb = InitializePositionalEmbeddings(
            d_model, vocab_size,name="init_embeddings"
        )

        # stack of distinct decoder blocks
        self.blocks = [
            DecoderBlock(d_model, attention_heads, dropout_rate, epsilon, name=f"decoder_block_{i}")
            for i in range(decoder_blocks)
        ]

        # final norm (GPT-2 style) and LM head
        self.final_ln = LayerNormalization(epsilon)
        self.lm_head = keras.layers.Dense(vocab_size, name="Model_head")

    def call(self, inputs, training=False):
        """
        inputs: (token_ids, attention_mask)
          - token_ids: int32 (B, T)
          - attention_mask: int32/float32 mask broadcasting to attention logits.
            Common shapes: (B, 1, 1, T) or (B, T) if your SelfAttentionLayer handles expansion.
        """
        token_ids, attention_mask = inputs
        x = self.emb(token_ids)                         # (B, T, d_model)

        for block in self.blocks:
            x = block(x, attention_mask, training=training)

        x = self.final_ln(x)
        logits = self.lm_head(x)                        # (B, T, vocab_size)
        return logits                                   # keep softmax outside

    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            "d_model": self._d_model,
            "vocab_size": self._vocab_size,
            "context_length": self._context_length,
            "attention_heads": self._attention_heads,
            "epsilon": self._epsilon,
            "decoder_blocks": self._decoder_blocks,
            "dropout_rate": self._dropout_rate,
        })
        return cfg


{0: '\n',
 1: ' ',
 2: '!',
 3: "'",
 4: '(',
 5: ')',
 6: ',',
 7: '-',
 8: '.',
 9: '0',
 10: '1',
 11: '2',
 12: '3',
 13: '4',
 14: '5',
 15: '6',
 16: '7',
 17: '8',
 18: '9',
 19: ':',
 20: ';',
 21: '?',
 22: 'a',
 23: 'b',
 24: 'c',
 25: 'd',
 26: 'e',
 27: 'f',
 28: 'g',
 29: 'h',
 30: 'i',
 31: 'j',
 32: 'k',
 33: 'l',
 34: 'm',
 35: 'n',
 36: 'o',
 37: 'p',
 38: 'q',
 39: 'r',
 40: 's',
 41: 't',
 42: 'u',
 43: 'v',
 44: 'w',
 45: 'x',
 46: 'y',
 47: 'z'}

In [None]:
import gradio as gr
import numpy as np
import tensorflow as tf
import keras

token_to_id_dict = tokenize_and_build_vocabulary_tf([r'/home/akshat/GPT_from_scratch/text_data/jane_austen_clean.txt'])
id_to_token_dict = {id_val: token for token, id_val in token_to_id_dict.items()}

# Use the latest and best model - try the best_model.keras first, then latest checkpoint
try:
    model = keras.models.load_model(r'/home/akshat/GPT_from_scratch/notebooks/rewrite_char_level_checkpoints/model_epoch_233_val_loss_1.8321.keras')
    print("✅ Loaded best_model.keras")
except:
    try:
        model = keras.models.load_model(r'/home/akshat/GPT_from_scratch/notebooks/char_level_checkpoints/model_epoch_163_val_loss_0.0303.keras')
        print("✅ Loaded epoch 163 model")
    except:
        model = keras.models.load_model(r'/home/akshat/GPT_from_scratch/notebooks/char_level_checkpoints/model_epoch_161_val_loss_0.0305.keras')
        print("✅ Loaded epoch 161 model")

CONTEXT_LEN = model._context_length  # Use the model's actual context length

# Debug: Print vocabulary info
print(f"Vocabulary size: {len(token_to_id_dict)}")
print(f"Model type: {type(model)}")

# Get model info from your custom GPT model
try:
    print(f"Model vocab size: {model._vocab_size}")
    print(f"Model context length: {model._context_length}")
    print(f"Model d_model: {model._d_model}")
    print(f"Model attention heads: {model._attention_heads}")
    print(f"Model decoder blocks: {model._decoder_blocks}")
    print(f"Vocab size matches model: {model._vocab_size == len(token_to_id_dict)}")
    
    if model._vocab_size != len(token_to_id_dict):
        print(f"⚠️  VOCAB SIZE MISMATCH! Model expects {model._vocab_size}, got {len(token_to_id_dict)}")
    
except Exception as e:
    print(f"Error getting model info: {e}")

print(f"Sample characters in vocab: {list(token_to_id_dict.keys())[:30]}")
print(f"Common characters present: {['a' in token_to_id_dict, 'e' in token_to_id_dict, ' ' in token_to_id_dict, '.' in token_to_id_dict]}")

# Check for problematic characters in the gibberish output
gibberish = "4ff.mtm 64m86rfstmfm?.fmmftms777mtmkf  tm7n7m77m77"
print(f"Checking gibberish characters:")
for char in set(gibberish):
    if char in token_to_id_dict:
        print(f"  '{char}' -> ID {token_to_id_dict[char]} ✓")
    else:
        print(f"  '{char}' -> NOT IN VOCAB ✗")

def encode_text(text, token_to_id_dict):
    """Encode text to token IDs using character-level tokenizer - convert to lowercase since dataset is lowercase"""
    # Convert input to lowercase since your dataset was lowercased
    text = text.lower()
    
    token_ids = []
    for char in text:
        if char in token_to_id_dict:
            token_ids.append(token_to_id_dict[char])
        else:
            print(f"Warning: '{char}' (ord: {ord(char)}) not in vocabulary, skipping")
            continue
    return token_ids

def decode_ids(token_ids, id_to_token_dict):
    """Decode token IDs back to text using character-level tokenizer"""
    text = ""
    for token_id in token_ids:
        if token_id in id_to_token_dict:
            text += id_to_token_dict[token_id]
        else:
            print(f"Warning: token ID {token_id} not in vocabulary")
    return text

def get_special_token_ids():
    """Get special token IDs - adjust these based on your tokenizer setup"""
    # For Jane Austen data, likely no special PAD token, use newline as EOS
    pad_id = token_to_id_dict.get('<PAD>', None)
    eos_id = token_to_id_dict.get('\n', None)  # Use newline as natural stopping point
    print(f"Special tokens - PAD: {pad_id}, EOS (newline): {eos_id}")
    return pad_id, eos_id

def top_k_sampling(logits, k=10):
    """Sample from logits using top-k sampling"""
    # Ensure we don't sample more than available tokens
    k = min(k, len(logits))
    
    values, indices = tf.math.top_k(logits, k=k)
    last_val = values[-1]
    filtered_logits = tf.where(
        logits < last_val,
        tf.fill(tf.shape(logits), float('-inf')),
        logits
    )
    probs = tf.nn.softmax(filtered_logits).numpy()
    
    # Add small epsilon to avoid numerical issues
    probs = probs + 1e-10
    probs = probs / np.sum(probs)
    
    return np.random.choice(len(probs), p=probs)

def generate_response(prompt, max_length=100, temperature=0.7, top_k=10, use_argmax=False):
    if not prompt.strip():
        return ""
    
    print(f"\n--- Generation Debug ---")
    print(f"Input prompt: '{prompt}' (will be lowercased)")
    
    # Tokenize prompt with character-level tokenizer
    input_tokens = encode_text(prompt, token_to_id_dict)
    print(f"Input tokens: {input_tokens}")
    print(f"Input tokens decoded back: '{decode_ids(input_tokens, id_to_token_dict)}'")
    
    if not input_tokens:
        return "Error: Could not tokenize input"
    
    # Truncate if longer than context length
    if len(input_tokens) > CONTEXT_LEN:
        input_tokens = input_tokens[-CONTEXT_LEN:]
    
    generated_tokens = input_tokens.copy()
    pad_id, eos_id = get_special_token_ids()
    
    print(f"Starting generation with {len(input_tokens)} input tokens...")
    
    for step in range(max_length):
        # Prepare inputs - pad from left to maintain most recent context
        input_ids = np.zeros((1, CONTEXT_LEN), dtype=np.int32)
        attention_mask = np.zeros((1, CONTEXT_LEN), dtype=np.int32)
        
        # Place tokens at the end of the context window
        current_len = min(len(generated_tokens), CONTEXT_LEN)
        start_idx = CONTEXT_LEN - current_len
        input_ids[0, start_idx:] = generated_tokens[-current_len:]
        attention_mask[0, start_idx:] = 1
        
        # Model forward pass
        try:
            logits = model((input_ids, attention_mask), training=False)
            next_token_logits = logits[0, -1, :]
            
            # Apply temperature
            if not use_argmax:
                next_token_logits = next_token_logits / temperature
        except Exception as e:
            print(f"Model forward pass error: {e}")
            break
        
        # Sample next token
        try:
            if use_argmax:
                # Use argmax (greedy) sampling for testing
                next_token = int(np.argmax(next_token_logits))
            else:
                # Use top-k sampling
                next_token = top_k_sampling(next_token_logits, k=top_k)
        except Exception as e:
            print(f"Sampling error: {e}")
            break
        
        # Debug: Print first few tokens
        if step < 10:
            sampled_char = id_to_token_dict.get(next_token, f"<UNK:{next_token}>")
            prob = float(tf.nn.softmax(next_token_logits)[next_token])
            print(f"Step {step}: Token {next_token} -> '{sampled_char}' (prob: {prob:.4f})")
        
        # Check if token is valid
        if next_token >= len(id_to_token_dict):
            print(f"Warning: Invalid token {next_token}, vocab size is {len(id_to_token_dict)}")
            break
        
        # Stop on special tokens
        if pad_id is not None and next_token == pad_id:
            print(f"Stopping at step {step}: hit PAD token")
            break
        if eos_id is not None and next_token == eos_id and step > 10:  # Don't stop too early
            print(f"Stopping at step {step}: hit EOS token (newline)")
            break
        
        generated_tokens.append(int(next_token))
        
        # Maintain sliding window
        if len(generated_tokens) > CONTEXT_LEN:
            generated_tokens = generated_tokens[-CONTEXT_LEN:]
    
    # Decode only the newly generated tokens
    new_tokens = generated_tokens[len(input_tokens):]
    response = decode_ids(new_tokens, id_to_token_dict)
    print(f"Generated {len(new_tokens)} new tokens: {new_tokens[:20]}...")  # Show first 20
    print(f"Generated response: '{response}'")
    print(f"--- End Debug ---\n")
    
    return response.strip()

def chat_fn(message, history, temperature, max_length, top_k, use_argmax):
    if not message.strip():
        return "", history
    
    bot_response = generate_response(message, max_length=max_length, temperature=temperature, top_k=top_k, use_argmax=use_argmax)
    history.append((message, bot_response))
    return "", history

# Quick test with the better model
print("Testing improved model:")
test_cases = ["the", "elizabeth", "it is a"]

for prompt in test_cases:
    print(f"\nTesting with: '{prompt}'")
    tokens = encode_text(prompt, token_to_id_dict)
    
    # Create model input
    input_ids = np.zeros((1, 256), dtype=np.int32)
    attention_mask = np.zeros((1, 256), dtype=np.int32)
    input_ids[0, -len(tokens):] = tokens
    attention_mask[0, -len(tokens):] = 1
    
    # Get model predictions
    logits = model((input_ids, attention_mask), training=False)
    next_token_logits = logits[0, -1, :]
    
    # Show top 5 predictions
    top_probs, top_indices = tf.nn.top_k(tf.nn.softmax(next_token_logits), k=5)
    print("Top 5 predictions:")
    for i in range(5):
        token_id = int(top_indices[i])
        prob = float(top_probs[i])
        char = id_to_token_dict.get(token_id, f"UNK_{token_id}")
        print(f"  {i+1}. '{char}' (ID: {token_id}) - {prob:.4f}")

# Test with longer context
print(f"\nTesting with longer Jane Austen context:")
long_prompt = "it is a truth universally acknowledged that a single man in possession of a good fortune must be in want of a"
tokens = encode_text(long_prompt, token_to_id_dict)
print(f"Context: '{long_prompt}'")
print(f"Context length: {len(tokens)} tokens")

# Use reasonable context length
context_tokens = tokens[-100:] if len(tokens) > 100 else tokens

input_ids = np.zeros((1, 256), dtype=np.int32)
attention_mask = np.zeros((1, 256), dtype=np.int32)
input_ids[0, -len(context_tokens):] = context_tokens
attention_mask[0, -len(context_tokens):] = 1

logits = model((input_ids, attention_mask), training=False)
next_token_logits = logits[0, -1, :]

top_probs, top_indices = tf.nn.top_k(tf.nn.softmax(next_token_logits), k=5)
print("Top 5 predictions after long context:")
for i in range(5):
    token_id = int(top_indices[i])
    prob = float(top_probs[i])
    char = id_to_token_dict.get(token_id, f"UNK_{token_id}")
    print(f"  {i+1}. '{char}' (ID: {token_id}) - {prob:.4f}")

with gr.Blocks(title="My Character-Level GPT Bot Trained in Tensorflow", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Chat with Akshat's Character-Level GPT Model")
    gr.Markdown("Ask me anything! I'm a GPT model trained from scratch with character-level tokenization on Jane Austen data. Be gentle with me :)")
    
    # Add vocab info
    gr.Markdown(f"**Model Info:** Vocabulary size: {len(token_to_id_dict)} characters")
    
    chatbot = gr.Chatbot(label="Conversation", height=400, show_copy_button=True)
    
    with gr.Row():
        msg = gr.Textbox(label="Your message", placeholder="Type your message here...", scale=4)
        send_btn = gr.Button("Send", scale=1, variant="primary")
    
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.3, step=0.05, label="Temperature")
        max_length = gr.Slider(minimum=10, maximum=200, value=30, step=10, label="Max Length")
        top_k = gr.Slider(minimum=1, maximum=50, value=5, step=1, label="Top-K Sampling")
        use_argmax = gr.Checkbox(label="Use Argmax (Greedy) - for testing", value=True)
    
    clear_btn = gr.Button("Clear Chat", variant="secondary")
    
    # Event handlers
    msg.submit(chat_fn, [msg, chatbot, temperature, max_length, top_k, use_argmax], [msg, chatbot])
    send_btn.click(chat_fn, [msg, chatbot, temperature, max_length, top_k, use_argmax], [msg, chatbot])
    clear_btn.click(lambda: [], None, chatbot)

if __name__ == "__main__":
    demo.launch(
        share=True,          # Generate public share link
        server_name="127.0.0.1",
        server_port=6020,
        show_error=True
    )

✅ Loaded best_model.keras
Vocabulary size: 38
Model type: <class '__main__.GPT'>
Model vocab size: 38
Model context length: 128
Model d_model: 64
Model attention heads: 2
Model decoder blocks: 1
Vocab size matches model: True
Sample characters in vocab: ['\n', ' ', '!', "'", '(', ')', ',', '-', '.', ':', ';', '?', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r']
Common characters present: [True, True, True, True]
Checking gibberish characters:
  '.' -> ID 8 ✓
  'f' -> ID 17 ✓
  '6' -> NOT IN VOCAB ✗
  's' -> ID 30 ✓
  'r' -> ID 29 ✓
  'm' -> ID 24 ✓
  'n' -> ID 25 ✓
  '?' -> ID 11 ✓
  '7' -> NOT IN VOCAB ✗
  'k' -> ID 22 ✓
  '4' -> NOT IN VOCAB ✗
  't' -> ID 31 ✓
  '8' -> NOT IN VOCAB ✗
  ' ' -> ID 1 ✓
Testing improved model:

Testing with: 'the'
Top 5 predictions:
  1. 'a' (ID: 12) - 0.1819
  2. 'j' (ID: 21) - 0.0798
  3. ')' (ID: 5) - 0.0611
  4. 's' (ID: 30) - 0.0552
  5. 'k' (ID: 22) - 0.0526

Testing with: 'elizabeth'
Top 5 predictions:
  1

  chatbot = gr.Chatbot(label="Conversation", height=400, show_copy_button=True)


* Running on local URL:  http://127.0.0.1:6020
* Running on public URL: https://4a5e37927586e1893f.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)



--- Generation Debug ---
Input prompt: 'a' (will be lowercased)
Input tokens: [12]
Input tokens decoded back: 'a'
Special tokens - PAD: None, EOS (newline): 0
Starting generation with 1 input tokens...
Step 0: Token 9 -> ':' (prob: 0.1236)
Step 1: Token 34 -> 'w' (prob: 0.1939)
Step 2: Token 34 -> 'w' (prob: 0.4320)
Step 3: Token 34 -> 'w' (prob: 0.4292)
Step 4: Token 34 -> 'w' (prob: 0.4280)
Step 5: Token 34 -> 'w' (prob: 0.4284)
Step 6: Token 34 -> 'w' (prob: 0.4303)
Step 7: Token 34 -> 'w' (prob: 0.4323)
Step 8: Token 34 -> 'w' (prob: 0.4322)
Step 9: Token 34 -> 'w' (prob: 0.4294)
Generated 30 new tokens: [9, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34]...
Generated response: ':wwwwwwwwwwwwwwwwwwwwwwwwwwwww'
--- End Debug ---


--- Generation Debug ---
Input prompt: 'a' (will be lowercased)
Input tokens: [12]
Input tokens decoded back: 'a'
Special tokens - PAD: None, EOS (newline): 0
Starting generation with 1 input tokens...
Step 0: Token 16 -> 'e' (

In [None]:
# Test with a very simple prompt
test_prompt = "it is a"
print(f"Testing with: '{test_prompt}'")

# Tokenize
tokens = encode_text(test_prompt, token_to_id_dict)
print(f"Tokens: {tokens}")

# Create model input
input_ids = np.zeros((1, 256), dtype=np.int32)
attention_mask = np.zeros((1, 256), dtype=np.int32)
input_ids[0, -len(tokens):] = tokens
attention_mask[0, -len(tokens):] = 1

# Get model predictions
logits = model((input_ids, attention_mask), training=False)
next_token_logits = logits[0, -1, :]

# Show top 10 predictions
top_probs, top_indices = tf.nn.top_k(tf.nn.softmax(next_token_logits), k=10)
print("Top 10 predictions:")
for i in range(10):
    token_id = int(top_indices[i])
    prob = float(top_probs[i])
    char = id_to_token_dict.get(token_id, f"UNK_{token_id}")
    print(f"  {i+1}. '{char}' (ID: {token_id}) - {prob:.4f}")

Testing with: 'it is a'
Tokens: [30, 41, 1, 30, 40, 1, 22]
Top 10 predictions:
  1. 'f' (ID: 27) - 0.3225
  2. 'c' (ID: 24) - 0.0669
  3. ',' (ID: 6) - 0.0521
  4. ' ' (ID: 1) - 0.0432
  5. '-' (ID: 7) - 0.0400
  6. 't' (ID: 41) - 0.0397
  7. '9' (ID: 18) - 0.0386
  8. 'p' (ID: 37) - 0.0377
  9. 'j' (ID: 31) - 0.0376
  10. 's' (ID: 40) - 0.0365


In [24]:
# Check character frequency in your data
with open('/home/akshat/GPT_from_scratch/text_data/jane_austen_clean.txt', 'r') as f:
    text = f.read().lower()

char_counts = {}
for char in text:
    char_counts[char] = char_counts.get(char, 0) + 1

# Sort by frequency
sorted_chars = sorted(char_counts.items(), key=lambda x: x[1], reverse=True)
print("Top 10 most frequent characters:")
for char, count in sorted_chars[:10]:
    print(f"'{char}': {count}")

Top 10 most frequent characters:
' ': 716875
'e': 433350
't': 296790
'a': 268929
'o': 264312
'n': 244375
'i': 234267
's': 212583
'h': 212370
'r': 209492


In [17]:
# Test your tokenizer
test_text = "Hello"
tokens = encode_text(test_text, token_to_id_dict)
decoded = decode_ids(tokens, id_to_token_dict)
print(f"Original: '{test_text}' -> Tokens: {tokens} -> Decoded: '{decoded}'")

Original: 'Hello' -> Tokens: [26, 33, 33, 36] -> Decoded: 'ello'
