In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# --- Data Preparation ---

# Tiny training corpus: the opening lines of "Alice's Adventures in Wonderland".
# The tokenizer lower-cases this text and splits it on whitespace only, so
# punctuation stays attached to words (e.g. "bank," and "do:" are vocabulary entries).
corpus = """
alice was beginning to get very tired of sitting by her sister on the bank,
and of having nothing to do: once or twice she had peeped into the book her
sister was reading, but it had no pictures or conversations in it, and what
is the use of a book, thought alice without pictures or conversations?
"""

# Tokenization and Vocabulary
class SimpleTokenizer:
    """Whitespace tokenizer that maps lower-cased words to integer IDs.

    ID 0 is reserved for padding and for any word not present in the text the
    tokenizer was built from; real words get IDs 1..len(vocab) in sorted order.
    """

    def __init__(self, text):
        # Sorting the unique words gives a deterministic word -> ID assignment.
        vocab = sorted(set(text.lower().split()))
        self.word_to_idx = {}
        self.idx_to_word = {}
        for idx, word in enumerate(vocab, start=1):
            self.word_to_idx[word] = idx
            self.idx_to_word[idx] = word
        # +1 accounts for the reserved ID 0 (padding/unknown).
        self.vocab_size = len(vocab) + 1

    def encode(self, text, max_len=None):
        """Return the IDs for ``text``; words outside the vocabulary become 0.

        When ``max_len`` is truthy the result is right-padded with 0 or
        truncated so that it has exactly ``max_len`` entries.
        """
        ids = [self.word_to_idx.get(w, 0) for w in text.lower().split()]
        if not max_len:
            return ids
        shortfall = max_len - len(ids)
        return ids + [0] * shortfall if shortfall > 0 else ids[:max_len]

    def decode(self, encoded):
        """Space-join the words for ``encoded``, skipping 0 and unknown IDs."""
        words = (self.idx_to_word[i] for i in encoded if i != 0 and i in self.idx_to_word)
        return ' '.join(words)

# Create sequences for next-word prediction
def create_sequences(tokenizer, corpus, seq_len=4):
    """Build sliding-window next-word-prediction pairs from ``corpus``.

    Args:
        tokenizer: object whose ``encode(text)`` returns a list of int token IDs.
        corpus: raw text to slice into training windows.
        seq_len: number of context tokens in each input window.

    Returns:
        ``(inputs, targets)`` as ``torch.long`` tensors of shape ``(N, seq_len)``
        and ``(N,)``: input row ``i`` holds tokens ``i .. i+seq_len-1`` and
        target ``i`` is token ``i+seq_len``. When the corpus has at most
        ``seq_len`` tokens, ``N`` is 0 but shapes and dtype are still correct
        (``torch.tensor([])`` would have produced a 1-D float tensor instead).
    """
    encoded_tokens = tokenizer.encode(corpus)
    n_windows = len(encoded_tokens) - seq_len  # <= 0 means no complete window

    inputs = [encoded_tokens[i:i + seq_len] for i in range(n_windows)]
    targets = [encoded_tokens[i + seq_len] for i in range(n_windows)]

    if not inputs:
        return (torch.empty((0, seq_len), dtype=torch.long),
                torch.empty((0,), dtype=torch.long))
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)

# --- Transformer Core Component: Scaled Dot-Product Self-Attention ---

class SelfAttention(nn.Module):
    """
    Single-head scaled dot-product self-attention (no mask, no output projection).

    Every position attends to every position of the same sequence, so a whole
    window is processed in parallel instead of step by step as in an RNN/LSTM.
    """
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim

        # Bias-free query / key / value projections, created and registered in
        # q, k, v order.
        self.W_q, self.W_k, self.W_v = (
            nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(3)
        )

    def forward(self, x):
        """Return the attention-weighted mix of value vectors (same shape as ``x``).

        Assumes ``x`` is (batch, seq_len, embed_dim) as passed by the model.
        """
        queries, keys, values = self.W_q(x), self.W_k(x), self.W_v(x)

        # Pairwise query/key similarity, divided by sqrt(d_k) so the softmax does
        # not saturate as the projection width grows.
        scale = math.sqrt(queries.size(-1))
        scores = queries @ keys.transpose(-2, -1) / scale

        # Each row becomes a probability distribution over positions.
        weights = scores.softmax(dim=-1)

        # Context vectors: weighted sums of the value vectors.
        return weights @ values


# --- The Toy LLM using Attention ---

class ToyAttentionModel(nn.Module):
    """
    Minimal next-word model: token embeddings plus learned positional embeddings,
    one self-attention layer, and a linear read-out from the final position.
    """
    def __init__(self, vocab_size, embed_dim, seq_len):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len

        # Submodules keep their original registration order (embedding,
        # positional_encoding, attention, fc) and attribute names.
        self.embedding = nn.Embedding(vocab_size, embed_dim)          # token ID -> vector
        # Attention itself is order-agnostic, so each slot gets a learned vector.
        self.positional_encoding = nn.Embedding(seq_len, embed_dim)
        self.attention = SelfAttention(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)                    # vector -> next-token logits

    def forward(self, x):
        """Map a (batch, seq_len) tensor of token IDs to (batch, vocab_size) logits.

        The input length must not exceed the ``seq_len`` given at construction,
        because positions 0..len-1 are looked up in ``positional_encoding``.
        """
        n_rows, n_cols = x.size()

        # One row of position indices 0..n_cols-1 per sequence in the batch.
        slot_ids = torch.arange(n_cols, dtype=torch.long, device=x.device).unsqueeze(0).repeat(n_rows, 1)
        hidden = self.embedding(x) + self.positional_encoding(slot_ids)

        context = self.attention(hidden)  # (batch, seq_len, embed_dim)

        # Only the last position's context is used to predict the following token.
        return self.fc(context[:, -1, :])

# --- Training and Inference Functions ---

def simulate_pretraining(model, inputs, targets, epochs=3000, device=None):  # Increased epochs
    """
    Simulate the foundational pre-training phase (full-batch next-word prediction).

    Args:
        model: module mapping ``inputs`` to ``(N, vocab_size)`` logits.
        inputs: ``(N, seq_len)`` tensor of token-ID windows.
        targets: ``(N,)`` tensor of next-token IDs.
        epochs: number of full-batch optimisation steps (Adam, lr=0.005).
        device: optional device; when given, ``inputs`` and ``targets`` are
            moved there (placing ``model`` is the caller's job).

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    # generate_text() switches the model to eval mode; make sure (re)training
    # always runs in train mode.
    model.train()
    if device is not None:
        inputs, targets = inputs.to(device), targets.to(device)

    print(f"Starting pre-training with {epochs} epochs...")
    for epoch in range(epochs):
        # 1. Forward pass
        outputs = model(inputs)

        # 2. Calculate loss (error)
        loss = criterion(outputs, targets)

        # 3. Backpropagation (find blame)
        optimizer.zero_grad()
        loss.backward()

        # 4. Optimization (adjust weights)
        optimizer.step()

        if (epoch + 1) % 500 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    print("Pre-training complete.")
    return model

def generate_text(model, tokenizer, start_phrase, seq_len=4, max_tokens=10, temperature=0.8):
    """
    Generate a continuation of ``start_phrase`` with temperature sampling.

    The full start phrase is kept in the output. At each step the model sees a
    sliding window of the most recent ``seq_len`` token IDs.

    Args:
        model: module mapping a ``(1, seq_len)`` ID tensor to ``(1, vocab_size)`` logits.
        tokenizer: object with ``encode(text)`` -> list of IDs and ``decode(ids)`` -> str.
        start_phrase: prompt text; must contain at least ``seq_len`` words.
        seq_len: context-window length the model was trained with.
        max_tokens: maximum number of new tokens to sample.
        temperature: logits are divided by this before the softmax; lower is greedier.

    Returns:
        The decoded prompt plus generated tokens, or ``start_phrase`` unchanged
        (after printing an error) when it has fewer than ``seq_len`` words.
    """
    model.eval()  # Set model to evaluation mode
    # Encode WITHOUT max_len. Padding/truncating to seq_len used to zero-pad short
    # prompts, which made the length check below dead code. It also dropped the
    # last words of long prompts.
    output_tokens = list(tokenizer.encode(start_phrase))

    if len(output_tokens) < seq_len:
        print("Error: Start phrase is too short for the required sequence length.")
        return start_phrase

    # Feed inputs on the model's device (CPU for a parameter-less model).
    device = next((p.device for p in model.parameters()), torch.device("cpu"))

    print(f"\n--- Generating Text (Max {max_tokens} tokens) ---")
    print(f"Start: '{start_phrase}'")

    with torch.no_grad():  # Disable gradient tracking during inference
        for _ in range(max_tokens):
            # 1. Context window: the last seq_len tokens generated so far.
            current_sequence = torch.tensor([output_tokens[-seq_len:]], device=device)

            # 2-3. Predict and apply temperature (divide the logits by T).
            logits = model(current_sequence) / temperature

            # 4. Sample from the distribution instead of taking the argmax.
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.multinomial(probabilities, num_samples=1).item()

            if predicted_token_id == 0:  # Stop on padding/unknown
                break

            # 5. Extend the sequence; the window slides on the next step.
            output_tokens.append(predicted_token_id)

            # Crude loop guard: stop once the recent window is one repeated token.
            if len(output_tokens) > 2 * seq_len and len(set(output_tokens[-seq_len:])) < 2:
                break

    return tokenizer.decode(output_tokens)

# --- Main Execution ---

if __name__ == '__main__':
    # Hyper-parameters
    EMBED_DIM = 32    # word-vector size (previously 16)
    SEQ_LEN = 4       # context window: words seen before predicting the next one
    EPOCHS = 3000     # full-batch training steps (previously 1500)

    # 1. Data: vocabulary, then sliding-window (input, target) pairs.
    tokenizer = SimpleTokenizer(corpus)
    VOCAB_SIZE = tokenizer.vocab_size
    inputs, targets = create_sequences(tokenizer, corpus, seq_len=SEQ_LEN)

    print(f"Vocabulary Size: {VOCAB_SIZE}")
    print(f"Sequence Length (context window): {SEQ_LEN}")
    print(f"Total training examples: {len(inputs)}")

    # 2-3. Build the attention model and train it (trained in place, same object returned).
    model = simulate_pretraining(
        ToyAttentionModel(VOCAB_SIZE, EMBED_DIM, SEQ_LEN),
        inputs,
        targets,
        epochs=EPOCHS,
    )

    # 4. Inference: sample a continuation of the opening words.
    start_phrase = "alice was beginning to get"
    generated_text = generate_text(model, tokenizer, start_phrase, seq_len=SEQ_LEN, max_tokens=20)

    print("\nResulting Text:")
    print(generated_text)

    # Recap of the attention computation, printed one step per line.
    print("\n".join((
        "\n--- ATTENTION MECHANISM SUMMARY ---",
        "In the SelfAttention class, the key operation is:",
        "1. Q, K, V Projections: Maps input vector (x) into three roles.",
        "2. Scoring: torch.matmul(Q, K.transpose) calculates pairwise similarity.",
        "3. Weighting: Softmax turns these scores into probabilistic attention weights.",
        "4. Context: torch.matmul(Weights, V) sums the V vectors based on the computed weights.",
    )))

Vocabulary Size: 44
Sequence Length (context window): 4
Total training examples: 53
Starting pre-training with 3000 epochs...
Epoch [500/3000], Loss: 0.1051
Epoch [1000/3000], Loss: 0.0786
Epoch [1500/3000], Loss: 0.0785
Epoch [2000/3000], Loss: 0.0785
Epoch [2500/3000], Loss: 0.0785
Epoch [3000/3000], Loss: 0.0785
Pre-training complete.

--- Generating Text (Max 20 tokens) ---
Start: 'alice was beginning to get'

Resulting Text:
alice was beginning to get very tired of sitting nothing to do: once or twice she had peeped into the book her sister on

--- ATTENTION MECHANISM SUMMARY ---
In the SelfAttention class, the key operation is:
1. Q, K, V Projections: Maps input vector (x) into three roles.
2. Scoring: torch.matmul(Q, K.transpose) calculates pairwise similarity.
3. Weighting: Softmax turns these scores into probabilistic attention weights.
4. Context: torch.matmul(Weights, V) sums the V vectors based on the computed weights.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

### --- CELL 1: SETUP & HYPERPARAMETERS --- ###
print("### --- CELL 1: SETUP & HYPERPARAMETERS --- ###")

# Model / training hyper-parameters
EMBED_DIM = 32      # dimensionality of each token embedding
SEQ_LEN = 4         # context window: predict word 5 from words 1-4
EPOCHS = 3000       # default number of training iterations
TEMPERATURE = 0.8   # sampling temperature; lower = more deterministic generation

print(f"EMBED_DIM: {EMBED_DIM}, SEQ_LEN: {SEQ_LEN}, EPOCHS: {EPOCHS}, TEMP: {TEMPERATURE}")

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

### --- CELL 1: SETUP & HYPERPARAMETERS --- ###
EMBED_DIM: 32, SEQ_LEN: 4, EPOCHS: 3000, TEMP: 0.8
Using device: cuda


In [2]:

### --- CELL 2: DATA UTILITIES (TOKENIZER & SEQUENCE CREATION) --- ###
print("\n### --- CELL 2: DATA UTILITIES (TOKENIZER & SEQUENCE CREATION) --- ###")

class SimpleTokenizer:
    """Lower-casing whitespace tokenizer with a word <-> integer ID table.

    IDs start at 1 (words in sorted order); 0 is reserved for padding and for
    words absent from the text the tokenizer was built from.
    """
    def __init__(self, text):
        vocab = sorted(set(text.lower().split()))
        self.idx_to_word = dict(enumerate(vocab, start=1))
        self.word_to_idx = {word: idx for idx, word in self.idx_to_word.items()}
        self.vocab_size = len(vocab) + 1  # +1 for the reserved ID 0

    def encode(self, text, max_len=None):
        """Map ``text`` to IDs (unknown -> 0).

        If ``max_len`` is truthy, right-pad with 0 or truncate to exactly that length.
        """
        ids = []
        for word in text.lower().split():
            ids.append(self.word_to_idx.get(word, 0))
        if max_len:
            ids = ids[:max_len] + [0] * (max_len - len(ids))
        return ids

    def decode(self, encoded):
        """Space-join the words for ``encoded``, skipping ID 0 and unknown IDs."""
        words = []
        for idx in encoded:
            if idx == 0 or idx not in self.idx_to_word:
                continue
            words.append(self.idx_to_word[idx])
        return ' '.join(words)

def create_sequences(encoded_tokens, seq_len=4):
    """Generate sliding-window input (X) and next-token target (Y) pairs.

    Args:
        encoded_tokens: sequence of int token IDs.
        seq_len: number of context tokens per input window.

    Returns:
        ``(inputs, targets)`` as ``torch.long`` tensors of shape ``(N, seq_len)``
        and ``(N,)``. When there are at most ``seq_len`` tokens, ``N`` is 0 but the
        shapes/dtype stay correct, so callers can still ``torch.cat`` them with
        other batches (``torch.tensor([])`` would be a 1-D float tensor).
    """
    n_windows = len(encoded_tokens) - seq_len  # <= 0 means no complete window
    inputs = [encoded_tokens[i:i + seq_len] for i in range(n_windows)]
    targets = [encoded_tokens[i + seq_len] for i in range(n_windows)]

    if not inputs:
        return (torch.empty((0, seq_len), dtype=torch.long),
                torch.empty((0,), dtype=torch.long))
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)

print("Tokenizer and sequence creation utilities defined.")



### --- CELL 2: DATA UTILITIES (TOKENIZER & SEQUENCE CREATION) --- ###
Tokenizer and sequence creation utilities defined.


In [3]:

### --- CELL 3: MODEL COMPONENTS (ATTENTION & TOY LLM) --- ###
print("\n### --- CELL 3: MODEL COMPONENTS (ATTENTION & TOY LLM) --- ###")

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention (no mask, no output projection)."""
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim

        # Bias-free linear maps producing queries, keys and values of the input width.
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        """Return attention-weighted sums of the value vectors, shaped like ``x``.

        Assumes ``x`` is (batch, seq_len, embed_dim).
        """
        q, k, v = (proj(x) for proj in (self.W_q, self.W_k, self.W_v))
        # Similarity of every query with every key, scaled by sqrt(d_k).
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        # Normalise per query position, then mix the values with those weights.
        return torch.matmul(F.softmax(logits, dim=-1), v)

class ToyAttentionModel(nn.Module):
    """Token + learned positional embeddings -> one self-attention layer -> next-token logits."""
    def __init__(self, vocab_size, embed_dim, seq_len):
        super().__init__()
        self.seq_len = seq_len

        # Registration order and names (embedding, positional_encoding, attention, fc)
        # are unchanged.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = nn.Embedding(seq_len, embed_dim)  # one learned vector per slot
        self.attention = SelfAttention(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map (batch, seq_len) token IDs to (batch, vocab_size) next-token logits.

        The input length must not exceed the ``seq_len`` used at construction
        (it indexes ``positional_encoding``).
        """
        n_rows, n_cols = x.size()

        # Position indices 0..n_cols-1, repeated for every row of the batch.
        slot_ids = torch.arange(n_cols, dtype=torch.long, device=x.device).unsqueeze(0).repeat(n_rows, 1)
        hidden = self.embedding(x) + self.positional_encoding(slot_ids)

        # The last position's context vector is the basis for the prediction.
        final_state = self.attention(hidden)[:, -1, :]
        return self.fc(final_state)

print("Attention model components defined.")


def simulate_pretraining(model, inputs, targets, epochs, device, batch_size=None, log_every=500):
    """
    Train ``model`` for next-token prediction with Adam (lr=0.005) and cross-entropy.

    Args:
        model: module (already placed on ``device`` by the caller) mapping
            ``(B, seq_len)`` token IDs to ``(B, vocab_size)`` logits.
        inputs: ``(N, seq_len)`` tensor of context windows.
        targets: ``(N,)`` tensor of next-token IDs.
        epochs: number of passes over the data.
        device: device the data is moved to before training.
        batch_size: ``None`` (default) keeps full-batch training: one optimisation
            step per epoch on all N examples. A positive int shuffles each epoch
            and steps once per mini-batch. That bounds the ``(batch, vocab_size)``
            logits tensor, whereas a full-batch ``(N, vocab_size)`` tensor over a
            large corpus can exhaust GPU memory.
        log_every: print the loss every this many epochs; 0/None disables logging.

    Returns:
        The same ``model`` object, trained in place and left in train mode.

    Raises:
        ValueError: if ``batch_size`` is given but not positive.
    """
    if batch_size is not None and batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

    model.train()
    print(f"Starting pre-training with {epochs} epochs...")

    # Move tensors to the correct device
    inputs = inputs.to(device)
    targets = targets.to(device)

    for epoch in range(epochs):
        if batch_size is None:
            batches = ((inputs, targets),)
        else:
            order = torch.randperm(inputs.size(0), device=inputs.device)
            batches = ((inputs[idx], targets[idx]) for idx in order.split(batch_size))

        batch_losses = []
        for batch_inputs, batch_targets in batches:
            outputs = model(batch_inputs)
            loss = criterion(outputs, batch_targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Detached so no graph is retained; .item() (a device sync) only runs when logging.
            batch_losses.append(loss.detach())

        if log_every and (epoch + 1) % log_every == 0:
            # Mean of the per-batch losses; with full-batch training this is exactly the loss.
            epoch_loss = torch.stack(batch_losses).mean().item()
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')

    print("Pre-training complete.")
    return model

def generate_text(model, tokenizer, start_phrase, seq_len, max_tokens, temperature, device):
    """
    Sample up to ``max_tokens`` new tokens after ``start_phrase`` using temperature.

    Args:
        model: module mapping a ``(1, seq_len)`` ID tensor to ``(1, vocab_size)`` logits.
        tokenizer: object with ``encode(text)`` and ``decode(ids)``.
        start_phrase: prompt text, or an already-encoded sequence of token IDs.
        seq_len: context-window length the model was trained with.
        max_tokens: maximum number of tokens to generate.
        temperature: logits are divided by this before the softmax.
        device: device the model lives on.

    Returns:
        The decoded prompt plus the generated tokens (``decode`` drops ID 0).
    """
    model.eval()

    if isinstance(start_phrase, str):
        prompt_ids = tokenizer.encode(start_phrase)
    else:
        prompt_ids = [int(t) for t in start_phrase]

    # Prompts shorter than the window are LEFT-padded with ID 0, so the most recent
    # real word sits in the last slot, which the model predicts from. Right-padding
    # via encode(max_len=...) made the model predict from a padding token. It also
    # truncated long prompts from the end.
    output_tokens = [0] * max(seq_len - len(prompt_ids), 0) + prompt_ids

    with torch.no_grad():
        for _ in range(max_tokens):
            # Use the last seq_len tokens as input
            current_sequence = torch.tensor([output_tokens[-seq_len:]], device=device)

            # Apply temperature, then softmax, then sample the next token.
            logits = model(current_sequence) / temperature
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.multinomial(probabilities, num_samples=1).item()

            if predicted_token_id == 0:  # Stop on padding/unknown
                break

            output_tokens.append(predicted_token_id)

            # Simple check to prevent repetitive loops
            if len(output_tokens) > 2 * seq_len and len(set(output_tokens[-seq_len:])) < 2:
                break

    return tokenizer.decode(output_tokens)



### --- CELL 3: MODEL COMPONENTS (ATTENTION & TOY LLM) --- ###
Attention model components defined.


In [5]:
!pip install pypdf
### --- CELL 4: SIMULATED PDF TEXT EXTRACTION --- ###
print("\n### --- CELL 4: SIMULATED PDF TEXT EXTRACTION --- ###")
    
    # NOTE: In a real Jupyter environment, you would use a library like pypdf 
    # to extract text from a file path:
    # 
import pypdf
pdf_path = "the-prince.pdf"
reader = pypdf.PdfReader(pdf_path)
text_content = "".join(page.extract_text() for page in reader.pages)
    
    # For this simulation, we use a large sample text representing the extracted content:
#    text_content = """
#    the quick brown fox jumps over the lazy dog the dog was sleeping
#    and the fox was quick and brown the quick fox needed to jump to
#    get food because the dog was very lazy
#    """
    
    # Clean and flatten the text
corpus = " ".join(text_content.lower().split())

print("Corpus loaded (simulated PDF text).")
print(f"Corpus size (tokens): {len(corpus.split())}")


Collecting pypdf
### --- CELL 4: SIMULATED PDF TEXT EXTRACTION --- ###

  Downloading pypdf-6.1.3-py3-none-any.whl.metadata (7.1 kB)
Downloading pypdf-6.1.3-py3-none-any.whl (323 kB)
Installing collected packages: pypdf
Successfully installed pypdf-6.1.3



[notice] A new release of pip is available: 25.1.1 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


Corpus loaded (simulated PDF text).
Corpus size (tokens): 51743


In [6]:
### --- CELL 5: DATA PREPARATION & MODEL INITIALIZATION --- ###
print("\n### --- CELL 5: DATA PREPARATION & MODEL INITIALIZATION --- ###")

# Vocabulary built from the extracted corpus (ID 0 = padding/unknown).
tokenizer = SimpleTokenizer(corpus)
VOCAB_SIZE = tokenizer.vocab_size

# Encode the whole corpus, then cut it into (SEQ_LEN-word window, next word) pairs.
encoded_tokens = tokenizer.encode(corpus)
inputs, targets = create_sequences(encoded_tokens, seq_len=SEQ_LEN)

# Dataset summary plus the first training pair.
print(f"Vocabulary Size: {VOCAB_SIZE}")
print(f"Total training examples: {len(inputs)}")
print(f"Sample Input (IDs): {inputs[0].tolist()}")
print(f"Sample Target (ID): {targets[0].item()}")

# Fresh, untrained model placed on the selected device (Module.to returns the module).
model = ToyAttentionModel(VOCAB_SIZE, EMBED_DIM, SEQ_LEN)
model = model.to(DEVICE)
print("\nModel instantiated and moved to device.")



### --- CELL 5: DATA PREPARATION & MODEL INITIALIZATION --- ###
Vocabulary Size: 8283
Total training examples: 51739
Sample Input (IDs): [2339, 3209, 2427, 5102]
Sample Target (ID): 1450

Model instantiated and moved to device.


In [None]:
### --- CELL 6: TRAINING THE MODEL --- ###
print("\n### --- CELL 6: TRAINING THE MODEL --- ###")

# Overrides the EPOCHS value from the setup cell for this (much larger) corpus.
# NOTE: simulate_pretraining logs every 500 epochs by default, so at 100 epochs
# only the start and completion lines are printed.
EPOCHS = 100
model = simulate_pretraining(model, inputs, targets, EPOCHS, DEVICE)  # trains in place, returns the model



### --- CELL 6: TRAINING THE MODEL --- ###
Starting pre-training with 100 epochs...


In [None]:
### --- CELL 7: INFERENCE AND GENERATION --- ###
print("\n### --- CELL 7: INFERENCE AND GENERATION --- ###")

# Prompt words should exist in the corpus vocabulary (unknown words encode to ID 0).
# NOTE: this prompt has 3 words while SEQ_LEN is 4, so generate_text has to pad
# the first context window.
start_phrase = "live in freedom"
print(f"Starting phrase: '{start_phrase}'")

generated_text = generate_text(model, tokenizer, start_phrase,
                               seq_len=SEQ_LEN, max_tokens=15,
                               temperature=TEMPERATURE, device=DEVICE)

print("\n--- GENERATED TEXT ---")
print(generated_text)

In [None]:
### --- CELL 8: SUPERVISED FINE-TUNING (SFT) --- ###
# Uses names from earlier cells: model, tokenizer, create_sequences,
# simulate_pretraining, generate_text, SEQ_LEN, TEMPERATURE, DEVICE.
from copy import deepcopy

print("\n--- --- --- CELL 8: SUPERVISED FINE-TUNING (SFT) --- --- ---")

# Marker between instruction and response. It is not a word of the PDF-derived
# vocabulary, so tokenizer.encode maps it to the padding/unknown ID 0.
SEP_TOKEN = "<sep>"
SFT_EPOCHS = 200      # full-batch fine-tuning steps on the tiny SFT set
base_model = model    # pre-trained model from the training cell


def train_model(model, inputs, targets, epochs, device, title="Training"):
    """Label a training phase in the log and delegate to simulate_pretraining."""
    print(f"\n=== {title} ===")
    return simulate_pretraining(model, inputs, targets, epochs=epochs, device=device)


def prompt_window(prompt):
    """Return the last SEQ_LEN words of ``prompt``.

    generate_text may keep only the FIRST SEQ_LEN words of a longer start phrase
    (encode with max_len); passing exactly the last SEQ_LEN words keeps the
    separator inside the model's context window.
    """
    return " ".join(prompt.split()[-SEQ_LEN:])


def completion_ids(generated_text, prompt_text):
    """Token IDs that generate_text appended after ``prompt_text``.

    The generated text is decode(prompt IDs + new IDs), and decode drops ID 0, so
    skip as many re-encoded IDs as the prompt has known (non-zero) tokens.
    """
    n_known = sum(1 for t in tokenizer.encode(prompt_text) if t != 0)
    return tokenizer.encode(generated_text)[n_known:]


# 1. Prepare SFT data (instruction + desired response).
# NOTE(review): these fox/dog examples match the small inline sample text, not the
# PDF corpus; words missing from the vocabulary encode to ID 0.
sft_data = [
    ("Tell me about the fox", f"{SEP_TOKEN} the fox was quick and brown and needed to jump"),
    ("Why did the fox jump", f"{SEP_TOKEN} to get food because the dog was lazy"),
    ("Describe the dog", f"{SEP_TOKEN} the lazy dog was sleeping")
]

# 2. Convert each "instruction <sep> response" string into next-token training pairs.
sft_inputs_list, sft_targets_list = [], []
n_unknown = 0
for prompt, completion in sft_data:
    encoded = tokenizer.encode(prompt + " " + completion)
    n_unknown += encoded.count(0)
    sft_inputs_temp, sft_targets_temp = create_sequences(encoded, seq_len=SEQ_LEN)
    sft_inputs_list.append(sft_inputs_temp)
    sft_targets_list.append(sft_targets_temp)

sft_inputs = torch.cat(sft_inputs_list)
sft_targets = torch.cat(sft_targets_list)

print(f"SFT Examples prepared: {len(sft_data)}. Total SFT steps: {len(sft_inputs)}")
print(f"SFT tokens mapped to ID 0 (separator + out-of-vocabulary words): {n_unknown}")

# 3. Fine-tune a copy so base_model keeps its pre-trained weights.
sft_model = deepcopy(base_model)
sft_model = train_model(sft_model, sft_inputs, sft_targets, epochs=SFT_EPOCHS, device=DEVICE, title="SFT")

# 4. SFT inference test: instruction followed by the separator.
sft_prompt = "Why did the fox jump"
sft_start_text = prompt_window(sft_prompt + " " + SEP_TOKEN)

sft_generation = generate_text(
    sft_model,
    tokenizer,
    sft_start_text,
    seq_len=SEQ_LEN,
    max_tokens=15,
    temperature=TEMPERATURE,
    device=DEVICE,
)

print("\n--- SFT Model Test ---")
print(f"Instruction: '{sft_prompt}'")
print(f"SFT Model Response: {sft_generation}")


### --- CELL 9: REINFORCEMENT LEARNING WITH HUMAN FEEDBACK (RLHF) --- ###
print("\n--- --- --- CELL 9: REINFORCEMENT LEARNING WITH HUMAN FEEDBACK (RLHF) --- --- ---")


# 1. Mock reward model (RM): keyword scoring standing in for a learned reward model.
def reward_model_score(text):
    """Score ``text``: +5 for the word 'food', +2 for 'quick', -1 for 'lazy'."""
    words = set(text.split())  # whole words, so e.g. "quickly" does not count as "quick"
    score = 0
    if "food" in words:
        score += 5.0  # High reward for mentioning the key concept
    if "quick" in words:
        score += 2.0
    if "lazy" in words:
        score -= 1.0  # Penalize mentioning 'lazy'
    return score


print("Mock Reward Model defined: rewards responses about 'food'.")

# 2. RLHF loop (simulated PPO step)

# a. Sample a response from the current SFT policy, at a higher temperature to explore.
rlhf_prompt = "Tell me about the fox"
rlhf_start_text = prompt_window(rlhf_prompt + " " + SEP_TOKEN)

rlhf_generation = generate_text(
    sft_model,
    tokenizer,
    rlhf_start_text,
    seq_len=SEQ_LEN,
    max_tokens=15,
    temperature=1.2,  # Higher temp for diverse exploration
    device=DEVICE,
)
response_ids = completion_ids(rlhf_generation, rlhf_start_text)
generated_response = tokenizer.decode(response_ids)

# b. Score only the generated continuation, not the instruction.
reward = reward_model_score(generated_response)
print(f"\nInstruction: '{rlhf_prompt}'")
print(f"Generated Response: '{generated_response}'")
print(f"Reward Model Score (RM): {reward:.2f}")

# c. Policy update. Real RLHF would use PPO with the reward; here a few supervised
#    steps on the highly-rewarded sequence stand in for that update.
if reward > 4.5:
    print("\nReward is high (>4.5)! Simulating Policy (PPO) Update...")

    # Instruction + separator + only the newly generated tokens.
    reinforced_encoded = tokenizer.encode(rlhf_prompt + " " + SEP_TOKEN) + response_ids
    rlhf_inputs, rlhf_targets = create_sequences(reinforced_encoded, seq_len=SEQ_LEN)

    # Update a copy so sft_model remains the unmodified SFT policy.
    final_model = train_model(deepcopy(sft_model), rlhf_inputs, rlhf_targets,
                              epochs=10, device=DEVICE, title="RLHF Update")
else:
    print("\nReward is too low. Policy Update skipped.")
    final_model = sft_model

# 3. Final inference with the (possibly) RLHF-updated model.
print("\n--- RLHF-Aligned Model Test ---")

rlhf_final_generation = generate_text(
    final_model,
    tokenizer,
    rlhf_start_text,
    seq_len=SEQ_LEN,
    max_tokens=15,
    temperature=TEMPERATURE,
    device=DEVICE,
)

print(f"Instruction: '{rlhf_prompt}'")
print(f"RLHF Model Response: {rlhf_final_generation}")

# 4. Cleanup: release cached, unused GPU memory held by PyTorch's allocator.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()