### **Lab - Notebook Objective**

The objective of this notebook is to introduce pre-trained Transformer models for language modeling. It showcases how to use these models for text generation by providing prompts and observing the generated output. Additionally, it encourages comparison between the Transformer models' performance and that of previously discussed n-gram models. The goal is to demonstrate the enhanced capabilities of Transformer models in capturing context and generating more natural and fluent text.

<!--### Dataset:

The section utilizes the same TinyStories dataset as before. This dataset comprises a large number of synthetically generated short stories that are suitable for basic language modeling tasks due to their simplified language and controlled vocabulary.-->

### Libraries

The following libraries are essential for this notebook:

<!---   **datasets:** Used for loading and managing the TinyStories dataset. Install it using `!pip install datasets`. -->
-   **gemma:** Provides tools for working with Gemma language models, including loading and prompting. Install it using `!pip install gemma==3.0.0`.
-   **re:** Used for regular expressions in text processing.
-   **random:** Used for generating random numbers, particularly in text generation tasks.
-   **pandas:** Used for data manipulation and analysis, especially for creating and working with dataframes.
-   **collections:** Provides specialized data structures like `Counter` and `defaultdict`, useful for counting n-grams.
-   **IPython.display:** Used for displaying elements in the notebook, like clearing the output.
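-   **keras:** A high-level neural network API.
-   **jax** and **numpy:** Used for array operations on the model's inputs and output logits, such as applying softmax.
-   **plotly:** Used to plot the probability distribution of the next token.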

### Models

In this notebook, we will use Transformer models of two different sizes: `Gemma-1B` and `Gemma-4B`.
<!--
- TinyStories-3M, and 33M are Transformer models trained on the TinyStories English Dataset detailed in [this article](https://arxiv.org/abs/2305.07759). The suffixes 3M, and 33M represent the model sizes, with 3, and 33 million parameters, respectively.   -->

- Gemma-1B and Gemma-4B are Transformer models trained on a large corpus of English text comprising web documents, code, and mathematical texts. The suffixes 1B and 4B represent the model sizes: 1 billion and 4 billion parameters, respectively. More details about the models are provided in [this report](https://storage.googleapis.com/deepmind-media/gemma/Gemma3Report.pdf).


The number of parameters determines how 'big' or complex a model is. The bigger the model, the more memory and computing power it needs to run.



NOTE: For the purposes of this notebook, you will be able to load all the models on the T4 GPU provided for free on Google Colab. If you're using a CPU runtime, you will only be able to run this activity with Gemma-1B.
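
As a rough back-of-the-envelope check of the note above, the sketch below estimates how much memory the model weights alone occupy. It assumes 2 bytes per parameter (16-bit weights such as bfloat16) and the nominal sizes of 1 and 4 billion parameters; the real checkpoints differ slightly, and activations, caches, and framework overhead need additional memory. An NVIDIA T4 has 16 GB of GPU memory.

```python
# Rough estimate of the memory needed just to hold the model weights.
# Assumptions: nominal parameter counts and 2 bytes per parameter (16-bit weights).
BYTES_PER_PARAM = 2
T4_MEMORY_GB = 16  # an NVIDIA T4 GPU has 16 GB of memory


def weight_memory_gb(num_params: float, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Return the approximate size of the weights in gigabytes (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


for name, num_params in [("Gemma-1B", 1e9), ("Gemma-4B", 4e9)]:
    gb = weight_memory_gb(num_params)
    print(f"{name}: ~{gb:.0f} GB of weights, {gb / T4_MEMORY_GB:.0%} of a T4's memory")
```

Under these assumptions, both models fit on a T4, and the 4B model uses roughly four times as much memory as the 1B model.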

### Lab (part of the above) 1.7 Let's play with a Transformer Model now! - To comment - **Jonathan**




We will prompt a Transformer model (i.e., give it a sequence of text) and observe its output.

**Things to think about for this activity**

- As you prompt the Transformer models, compare the quality of their output with that of the n-gram models. Does the text look better, worse, or just different? Take a close look and see what you think.

- Does the text generated by the Transformers look more natural and fluent compared to the n-gram model?

- Is the grammar and structure better than what you got from the n-gram model?

- Did any part of the output seem unexpected or surprising to you?

- Is there a difference in quality of output between the Transformer models?


**Please select the type of device you are using. If unsure, please go to Runtime -> Change runtime type and select GPU, if available.**


In [None]:
# device = 'gpu'  #@param ['cpu','gpu']

In [None]:
!pip install gemma==3.0.0
!pip install datasets

from IPython.display import clear_output
from datasets import load_dataset
from gemma import gm  # the gemma package exposes its API through the `gm` submodule
clear_output()  # Clears the output

In [None]:
# @title keeping the code visible for now but will be hidden away
import os
import jax
import jax.numpy as jnp
import numpy as np
import plotly.express as px
from gemma import gm
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Any


def prompt_transformer_model(input_text: str, max_new_tokens: int = 10, model_name: str = 'Gemma-1B', do_sample: bool = True) -> tuple[str, np.ndarray, Any]:
    """
    Generate text from a Transformer model (Gemma) based on the input text.

    Args:
        input_text (str): The input prompt for the model.
        max_new_tokens (int): The maximum number of new tokens to generate.
        model_name (str): The name of the model to load. Supported options are 'Gemma-1B' and 'Gemma-4B'.
        do_sample (bool): Whether to use sampling for text generation (True for random sampling, False for greedy).

    Returns:
        output_text (str): The generated text, including the input text and the model's output.
        next_token_logits (np.ndarray): Logits for the next token (probability distribution).
        tokenizer: The tokenizer used for encoding/decoding the text.

    Raises:
        NotImplementedError: If the model_name is not recognized or supported.
    """

    assert isinstance(do_sample, bool), "do_sample must be a boolean value."

    # Process for Gemma-based models
    if 'Gemma' in model_name:
        tokenizer, model, params = load_gemma(model_name)
        sampler = gm.text.Sampler(
            model=model,
            params=params,
            tokenizer=tokenizer,
        )

        if not do_sample:
            sampler_output_text = sampler.sample(input_text, max_new_tokens=max_new_tokens, sampling=gm.text.Greedy())
        else:
            sampler_output_text = sampler.sample(input_text, max_new_tokens=max_new_tokens, sampling=gm.text.RandomSampling())

        # Convert the input text to tokens and apply the model to generate predictions
        prompt = tokenizer.encode(input_text, add_bos=True)
        prompt = jnp.asarray(prompt)
        out = model.apply(
            {'params': params},
            tokens=prompt,
            return_last_only=True,  # Only predict the last token
        )

        next_token_logits = out.logits
        output_text = input_text + sampler_output_text

    # # Process for TinyStories-based models - this will a
    # elif 'TinyStories' in model_name:
    #     tokenizer, model = load_HF(model_name)
    #     input_ids = tokenizer.encode(input_text, return_tensors="pt")
    #     output = model.generate(input_ids, max_new_tokens=max_new_tokens, num_beams=1, do_sample=do_sample, output_scores=True, return_dict_in_generate=True)

    #     output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    #     next_token_logits = output.scores[0].flatten().numpy()

    # If model_name is not recognized, raise an error
    else:
        raise NotImplementedError(f"Model '{model_name}' is not recognized! Supported models: Gemma-1B, Gemma-4B")

    return output_text, next_token_logits, tokenizer


def load_gemma(model_name: str = "Gemma-1B") -> tuple:
    """
    Loads a Gemma model and its associated tokenizer and parameters.

    Args:
        model_name (str): The name of the Gemma model to load. Options are:
                          "Gemma-1B" and "Gemma-4B".

    Returns:
        tokenizer: Tokenizer for the specified Gemma model.
        model: The Gemma model.
        params: The parameters for the specified Gemma model.

    Raises:
        ValueError: If an unsupported model name is provided.
    """
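    # Let JAX use all available GPU memory. This only takes effect if it is set
    # before JAX initializes its GPU backend (i.e., before the first JAX computation).
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"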

    # Initialize variables
    tokenizer = None
    model = None
    params = None

    # Model loading based on model_name
    if model_name == "Gemma-1B":
        tokenizer = gm.text.Gemma3Tokenizer()
        model = gm.nn.Gemma3_1B()
        params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_1B_PT)
    elif model_name == "Gemma-4B":
        tokenizer = gm.text.Gemma3Tokenizer()
        model = gm.nn.Gemma3_4B()
        params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_PT)
    else:
        raise ValueError(f"Unsupported model name: {model_name}. Please use 'Gemma-1B' or 'Gemma-4B'.")

    return tokenizer, model, params


# def load_HF(model_name='TinyStories-1M'):
#     """
#     Loads a Hugging Face model and its associated tokenizer.

#     Args:
#         model_name (str): The name of the Hugging Face model to load.
#                           By default, it loads 'TinyStories-1M'.

#     Returns:
#         tokenizer: Tokenizer for the specified Hugging Face model.
#         model: The Hugging Face model.

#     Raises:
#         ValueError: If the model name does not contain 'TinyStories'.
#     """
#     # Initialize variables
#     model = None
#     tokenizer = None

#     # Load TinyStories models
#     if 'TinyStories' in model_name:
#         model = AutoModelForCausalLM.from_pretrained(f"roneneldan/{model_name}")
#         tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
#     else:
#         raise ValueError(f"Unsupported model name: {model_name}. Please ensure the model name contains 'TinyStories'.")

#     return tokenizer, model



def plot_next_token(logits: jax.numpy.ndarray, tokenizer: Any, prompt: str, keep_top: int = 30):
    """
    Plots the probability distribution of the next tokens given the model logits and prompt.

    This function generates a bar plot showing the top `keep_top` tokens by probability
    after applying the softmax to the logits, based on the given input prompt.

    Args:
        logits (jax.numpy.ndarray): The raw logits output by the model for the next token prediction.
        tokenizer: The tokenizer used to decode token IDs to human-readable text.
        prompt (str): The input prompt used to generate the next token predictions.
        keep_top (int): The number of top tokens to display in the plot. Default is 30.

    Returns:
        None: Displays a plot showing the probability distribution of the top tokens.

    # Function from gemma https://github.com/google-deepmind/gemma/blob/ee0d55674ecd0f921d39d22615e4e79bd49fce94/gemma/gm/text/_tokenizer.py#L249-L284
    """

    # Apply softmax to logits to get probabilities
    probs = jax.nn.softmax(logits)

    # Select the top `keep_top` tokens by probability
    indices = jnp.argsort(probs)
    indices = indices[-keep_top:][::-1]  # Reverse to get highest probabilities first

    # Get the probabilities and corresponding tokens
    probs = probs[indices].astype(np.float32)
    tokens = [repr(tokenizer.decode(i.item())) for i in indices]

    # Create the bar plot using Plotly
    fig = px.bar(x=tokens, y=probs)

    # Customize the plot layout
    fig.update_layout(
        title=f'Probability Distribution of Next Tokens given the prompt="{prompt}"',
        xaxis_title='Tokens',
        yaxis_title='Probability',
    )

    # Display the plot
    fig.show()


**Your Task:**
1. Select the Transformer model from the `model_name` dropdown menu.
2. Enter a prompt of your choice using the `prompt` text field.
3. Click the Play button to run the cell.
4. Inspect the model's prediction for the next word.

For example, if you start with the prompt `Jide was hungry so she went looking for`, the Transformer model will predict the next token. A token can be a single character (like `T`), a full word (like `The`), or a sub-word (such as `Th`).
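
To make the idea of sub-word tokens concrete, here is a toy greedy longest-match tokenizer over a small, hand-picked vocabulary. It is only an illustration: the vocabulary is hypothetical, and Gemma's real tokenizer is a learned SentencePiece model with a much larger vocabulary that splits text differently.

```python
# Toy greedy longest-match tokenizer (illustration only; not Gemma's tokenizer).
TOY_VOCAB = {"Jide", " was", " hungry", " so", " she", " went", " look", "ing", " for", "The", "Th"}


def toy_tokenize(text: str, vocab: set = TOY_VOCAB) -> list:
    """Split text into the longest vocabulary pieces, falling back to single characters."""
    tokens = []
    i = 0
    while i < len(text):
        # Try the longest candidate piece first, then progressively shorter ones.
        for end in range(len(text), i, -1):
            piece = text[i:end]
            if piece in vocab or end == i + 1:  # a single character is always allowed
                tokens.append(piece)
                i = end
                break
    return tokens


print(toy_tokenize("Jide was hungry so she went looking for"))
print(toy_tokenize("The Thing"))
```

The first call splits "looking" into the sub-words `" look"` and `"ing"`; the second returns a full word (`"The"`), a single character (`" "`), and a sub-word (`"Th"`).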

*Try running the cell several times to observe how the model responds to different prompts!*

[Take a small pause here]

What do you think the next word after `Jide was hungry so she went looking for` should be?

Write down your answer.

[make it free text, we can collect this and have huge dataset from student's answers]

> Is the cell below running too slow? 🤔 Click on `model_name` to try a different model size. Remember, a model with fewer parameters runs faster!

In [None]:
model_name = "Gemma-4B" #@param [ "Gemma-1B", "Gemma-4B"]

prompt = "Jide was hungry so she went looking for" #@param {type:"string"}
prompt = str(prompt)


output_text, next_token_logits, tokenizer = prompt_transformer_model(prompt, max_new_tokens=1, model_name=model_name)
clear_output() # clears the output

print(output_text)


**Visualize the probability distribution of the predicted next token**

Now that you've seen the model's prediction, let's look at the probability distribution behind the next token. The Transformer model doesn't pick a token out of thin air; it first computes the probability of every possible next token, based on the context (the preceding words) of the prompt you provided.

The plot below visualizes the probability distribution of the next token predicted by the language model given the prompt.  Each bar represents a different token and its height corresponds to the probability assigned to that token by the model.  

Visualizing the probability distribution allows us to analyze the model's preferences for different token choices given the prompt.  A highly peaked distribution suggests high confidence in a single prediction, while a flatter distribution indicates greater uncertainty and a broader range of plausible next tokens.  Examining this distribution provides insights into the model's internal workings and helps us understand how it generates text, highlighting both its strengths (confident predictions) and weaknesses (uncertainty or biases).
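
The computation behind `plot_next_token` can be sketched in plain Python: softmax turns logits into probabilities, and sorting picks the top tokens. The sketch also computes entropy (in bits), one common way to put a number on "peaked" versus "flat"; the notebook's code does not use entropy. The logits and candidate tokens below are made up for illustration.

```python
import math


def softmax(logits: list) -> list:
    """Convert raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(tokens: list, probs: list, k: int) -> list:
    """Return the k most probable (token, probability) pairs, highest first."""
    return sorted(zip(tokens, probs), key=lambda pair: pair[1], reverse=True)[:k]


def entropy_bits(probs: list) -> float:
    """Shannon entropy in bits: low for peaked distributions, high for flat ones."""
    return -sum(p * math.log2(p) for p in probs if p > 0)


# Hypothetical next-token logits after "Jide was hungry so she went looking for"
tokens = [" food", " something", " a", " work", " help"]
peaked = softmax([6.0, 3.0, 2.5, 0.5, 0.2])  # one clear favourite
flat = softmax([1.0, 1.0, 1.0, 1.0, 1.0])    # every token equally likely

print(top_k(tokens, peaked, k=3))
print(f"peaked: {entropy_bits(peaked):.2f} bits, flat: {entropy_bits(flat):.2f} bits")
```

A confident model concentrates probability on a few tokens (low entropy), while an uncertain one spreads it out; with five equally likely tokens the entropy reaches its maximum of log2(5), about 2.32 bits.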

In [None]:
plot_next_token(next_token_logits, tokenizer, prompt=prompt)


When you run the cell above, the model generates a probability distribution for the next token. Some tokens will have higher probabilities than others, meaning they are more likely to be chosen as the next word.

**[Write out your observations]**

Here are a few likely observations using the Gemma-1B model:

1. The most probable token will usually be a common word that fits the context of the sentence (e.g., "food" after the prompt "Jide was hungry so she went looking for").
2. The model might suggest words that seem plausible but aren't always the most expected, like "a" or "something".
3. Some tokens, such as "work" or "help", may have low probabilities: the model considers them less likely to fit but doesn't completely rule them out.
4. Changing the Transformer model may result in slight variations in the predicted next token, as the prediction is influenced by the model's learned weights, which are in turn determined by the dataset used for training.

Try out different prompts and observe the probability distribution of the next token prediction!

**Changing the context slightly**

What happens to the probability distribution if we change the context slightly? Let's try `Jide was thirsty so she went looking for`

*Click the Play button to run the cell below*

In [None]:
model_name = "Gemma-1B" #@param ["Gemma-1B", "Gemma-4B"]

prompt = "Jide was thirsty so she went looking for" #@param {type:"string"}
prompt = str(prompt)


output_text, next_token_logits, tokenizer = prompt_transformer_model(prompt, max_new_tokens=1, model_name=model_name)
clear_output() # clears the output

plot_next_token(next_token_logits, tokenizer, prompt=prompt)

**What did you observe?**

When running the Transformer model with prompts like "Jide was thirsty so she went looking for", you might notice certain patterns in the predicted next tokens. For instance, drink-related words like "water" may receive higher probabilities. This is because the Transformer model is **context-aware**: from its training data, it has learned that words about hunger and thirst tend to co-occur with related words like “food” or “water”, and it uses the context provided by the prompt accordingly.


**Comparison Between Transformer Models**

Different Transformer models can sometimes generate different next tokens, even for the same prompt. You might see variations in the suggestions depending on the size and training of the model you're using. Larger models, with more data and parameters, tend to generate more accurate and contextually appropriate predictions. Smaller models might be more limited in their understanding, occasionally offering less relevant or more generic predictions.

**Transformer Models vs. N-gram Models**

When comparing the Transformer models to traditional n-gram models, you likely noticed some key differences. N-gram models predict the next token based on a fixed window of the preceding tokens (e.g., the last two or three words). These models often struggle with longer-range dependencies or more complex sentence structures, as they only consider a limited context.

In contrast, Transformer models consider the entire sequence of text and focus on the relationships between all tokens, not just the immediate neighbors. This makes them more flexible and accurate, especially in situations where the context stretches beyond just a few words.

For example, when comparing outputs for the same prompt, you may see that n-gram models often fail to predict more specific words (like "food" after "hungry" or "water" after "thirsty"), because the word that signals the need can fall outside their short context window. Transformer models, on the other hand, are more likely to generate contextually appropriate words, like "food" when the prompt mentions hunger, or "water" when thirst is implied.
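
A quick way to see this limitation: a trigram model conditions only on the last two words. For both prompts used in this notebook those two words are identical, so any trigram model, whatever its counts, must assign the same next-word distribution to both. The helper below is a hypothetical illustration, not code from the earlier n-gram notebook.

```python
def ngram_context(prompt: str, n: int = 3) -> tuple:
    """Return the (n - 1) preceding words that an n-gram model conditions on."""
    words = prompt.split()
    start = max(0, len(words) - (n - 1))  # guard against n larger than the prompt
    return tuple(words[start:])


hungry = "Jide was hungry so she went looking for"
thirsty = "Jide was thirsty so she went looking for"

print(ngram_context(hungry))   # the words a trigram model actually sees
print(ngram_context(thirsty))
print(ngram_context(hungry) == ngram_context(thirsty))

# Smallest n whose context window still reaches the deciding word "hungry".
words = hungry.split()
n_needed = len(words) - words.index("hungry") + 1
print(f"An n-gram model needs n >= {n_needed} to see 'hungry'.")
```

A trigram model sees only `('looking', 'for')` for both prompts; it would need a 7-gram to reach "hungry", and such long n-grams are rarely seen often enough in training data to estimate reliably.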

**Generating more samples**

Now, try increasing `num_next_tokens` to generate longer text and observe how the model responds.


*Click the Play button to run the cell below*

In [None]:
model_name = "Gemma-1B" #@param [ "Gemma-1B", "Gemma-4B"]

prompt = "Jide was thirsty so she went looking for" #@param {type:"string"}
prompt = str(prompt)

num_next_tokens = 100 #@param {type: "number"}

output_text, next_word_logits, tokenizer = prompt_transformer_model(prompt, max_new_tokens=num_next_tokens, model_name=model_name)
clear_output() # clears the output

print(output_text)

Jide was thirsty so she went looking for a cool water which is located near her house. She went and took a good drink before she started her search. After she had taken her drink she went back home with her new found water. As she was coming her way another woman was also thirsty so she too went in search of a cool and refreshing water. While she was looking for water she saw a man trying to stop a woman who had fallen in the pool. The woman was bleeding from her arm. She went over to the man and


**Language Models as Stochastic Parrots**

When you ran the cell above multiple times, what did you notice?

- Did you observe stereotypical outputs like Jide carrying a pot full of mud?
- Or perhaps the model switched from a female to a male pronoun mid-sentence?
[Revise this]
- Or did it default to common names seen frequently in its training data (like 'Lily', 'Jack', 'Jill', etc., in models trained only on datasets like TinyStories)?

Ultimately, language models are adept at predicting the next token, but they closely follow the distribution of their training data. If the model is trained on biased data, it will produce biased outputs. Similarly, if it's trained on data scraped from the entire internet, it will reflect the dominant texts and perspectives found there, often sidelining less common viewpoints and cultures. This can lead to the model reinforcing stereotypes, such as certain professions being associated with specific genders.
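
The way a model mirrors its training distribution is easiest to see in the simplest count-based model. The tiny corpus below is invented and deliberately skewed; a model estimated from it reproduces the skew exactly. Real Transformers are far more complex, but they are also trained to match the statistics of their training data.

```python
from collections import Counter

# A deliberately skewed, invented corpus (hypothetical data for illustration).
corpus = [
    "the nurse said she", "the nurse said she", "the nurse said she",
    "the nurse said he",
    "the engineer said he", "the engineer said he", "the engineer said he",
    "the engineer said she",
]


def next_word_probs(corpus: list, context: str) -> dict:
    """Estimate P(next word | context) by counting, as an n-gram model would."""
    counts = Counter(
        line[len(context):].split()[0]
        for line in corpus
        if line.startswith(context + " ")
    )
    total = sum(counts.values())
    return {word: c / total for word, c in counts.items()}


print(next_word_probs(corpus, "the nurse said"))     # mirrors the 3:1 skew
print(next_word_probs(corpus, "the engineer said"))  # the opposite skew
```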

In upcoming modules, we'll revisit these issues and explore ways to better align language model outputs with human values and preferences.

**The Output Above Changes Every Time You Run the Cell, Right?**

You likely noticed that the output of the Transformer model changes each time you run the cell above, even with the same prompt. This is because the model uses a probability distribution to pick the next token, which introduces a level of stochasticity (randomness) into the prediction. This is similar to what you saw in the n-gram models, where the next word isn't always the same due to the model sampling from a set of possibilities.

This variability helps the Transformer model generate more diverse and creative outputs.
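
This randomness can be reproduced with Python's `random` module: drawing from the same distribution several times gives different tokens, while always taking the most probable token gives the same answer every time. The probabilities below are invented for illustration.

```python
import random

# Hypothetical next-token probabilities after "Jide was thirsty so she went looking for"
next_token_probs = {" water": 0.55, " a": 0.20, " something": 0.15, " her": 0.10}


def sample_token(probs: dict, rng: random.Random) -> str:
    """Draw one token at random, weighted by its probability."""
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]


def most_likely_token(probs: dict) -> str:
    """Pick the highest-probability token (no randomness involved)."""
    return max(probs, key=probs.get)


rng = random.Random()  # unseeded: each run of the cell can give different draws
print("sampled:", [sample_token(next_token_probs, rng) for _ in range(8)])
print("argmax: ", [most_likely_token(next_token_probs) for _ in range(8)])
```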

**Controlling the Model's Output**

If you want the model to always pick the token with the **highest probability** (i.e., the most likely next token), you can set the argument `do_sample=False`. This switches to greedy decoding, so the model returns the most probable token at every step.

Here's how you can do that:

```python
prompt_transformer_model(prompt, max_new_tokens=num_next_tokens, model_name=model_name, do_sample=False)
```

With this setting, the output will be consistent across multiple runs for the same prompt, as it always selects the most probable token.

**Sampling Mode (Default: `do_sample=True`)**

By default, when `do_sample=True`, the model samples from the probability distribution, which introduces randomness and results in more varied and creative outputs. This is helpful when you want the model to explore a range of possible continuations for a prompt, rather than sticking strictly to the most likely outcome.
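
Generating several tokens simply repeats the next-token step, appending each choice to the text. The toy generator below mimics the role of the notebook's `do_sample` flag on a small, hypothetical bigram table (a real Transformer conditions on the whole prompt, not just the last word). With `do_sample=False`, every run returns the same continuation.

```python
import random
from typing import Optional

# Hypothetical bigram table: P(next word | previous word). Words without an entry end generation.
BIGRAMS = {
    "for": {"water": 0.6, "a": 0.4},
    "water": {"near": 0.5, "and": 0.3, "to": 0.2},
    "near": {"her": 1.0},
    "her": {"house": 1.0},
    "and": {"a": 1.0},
    "a": {"cold": 0.7, "cup": 0.3},
    "cold": {"drink": 1.0},
    "cup": {"of": 1.0},
    "of": {"water": 1.0},
    "to": {"drink": 1.0},
}


def generate(prompt: str, max_new_tokens: int, do_sample: bool = True, seed: Optional[int] = None) -> str:
    """Extend the prompt one word at a time, greedily (do_sample=False) or by sampling."""
    rng = random.Random(seed)
    words = prompt.split()
    for _ in range(max_new_tokens):
        options = BIGRAMS.get(words[-1])
        if not options:  # no known continuation: stop early
            break
        if do_sample:
            # Sampling: draw the next word at random, weighted by probability.
            next_word = rng.choices(list(options), weights=list(options.values()), k=1)[0]
        else:
            # Greedy: always take the most probable next word.
            next_word = max(options, key=options.get)
        words.append(next_word)
    return " ".join(words)


prompt = "Jide was thirsty so she went looking for"
print(generate(prompt, max_new_tokens=10, do_sample=False))  # identical on every run
for _ in range(3):
    print(generate(prompt, max_new_tokens=10, do_sample=True))  # may differ on every run
```

In greedy mode this always prints `Jide was thirsty so she went looking for water near her house`, while the sampled lines can wander into other continuations such as "a cold drink".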

*Run the cell below multiple times and observe the result.*

In [None]:
model_name = "Gemma-1B" #@param ["Gemma-1B", "Gemma-4B"]

prompt = "Jide was thirsty so she went looking for" #@param {type:"string"}

num_next_tokens = 100 #@param {type: "number"}

output_text, next_word_logits, tokenizer = prompt_transformer_model(prompt, max_new_tokens=num_next_tokens, model_name=model_name, do_sample=False)
clear_output() # clears the output

print(output_text)



*Running the cell above multiple times should return the same output!*

**Balancing creativity and consistency**

Sampling from a probability distribution allows the Transformer model to explore a range of possible next tokens, fostering creativity and generating diverse outputs. This approach contrasts with always picking the token with the highest probability, which focuses on the most likely next token, as you have seen above.

Different applications require different settings for this balance. For creative tasks, such as generating stories, sampling from the probability distribution is ideal because it allows the model to explore various possibilities and produce more imaginative results. On the other hand, in sensitive domains like healthcare, where accuracy, consistency, and reliability are critical, it's often better to choose the token with the highest probability: this removes sampling randomness and reflects the model's most confident prediction, although the most probable output can still be wrong.

 `End of second notebook `

