# Starter Notebook: Generating More Data With Gemma
Our ultimate goal in this competition is to take an original sample of text and a new version of that text rewritten by Gemma, and to figure out what prompt was used to get the new version. A helpful first step is to be able to generate a bunch of examples of what that looks like, so we can then learn the relationships between the original text, rewrite prompt and rewritten text.

To generate examples, we'll need a few things:
1. A corpus of original texts
2. A set of rewrite prompts
3. Our model (Gemma!), which takes the original text and rewrite prompt and generates a rewritten text

Let's tackle them one by one.

## Generating `original_text`
While we don't know much about the original texts used in the competition test set,
the Meta Kaggle dataset provides a corpus of Kaggle forum messages that we can
use as a simple example.


In [1]:
import pandas as pd

forum_messsages_df = pd.read_csv('/kaggle/input/meta-kaggle/ForumMessages.csv')
forum_messsages_df.head()

| Unnamed: 0 | Id | ForumTopicId | PostUserId | PostDate | ReplyToForumMessageId | Message | Medal | MedalAwardDate |
|---|---|---|---|---|---|---|---|---|
| 0 | 667077 | 115913 | 1788308 | 11/06/2019 19:38:55 | 666668.0 | `<p><a href="/cdeotte">@cdeotte</a> </p>\n\n<p>...` | 3.0 | 11/06/2019 |
| 1 | 667076 | 74968 | 3961461 | 11/06/2019 19:38:19 |  | `<p>A very detailed and helpful notebook, \nTha...` |  |  |
| 2 | 667075 | 115817 | 1666986 | 11/06/2019 19:37:59 |  | `<p>You don't say. You might just got your wish...` |  |  |
| 3 | 667074 | 113468 | 1073620 | 11/06/2019 19:34:36 | 666591.0 | `<p>Hi <a href="/mobassir">@mobassir</a> If I ...` | 3.0 | 11/07/2019 |
| 4 | 667073 | 116025 | 1666986 | 11/06/2019 19:33:54 |  | `<p>This like betting your life savings on a ga...` | 3.0 | 11/06/2019 |


In [2]:
# Let's grab the first 5 messages to test our generation pipeline:

original_texts = forum_messsages_df['Message'][:5]
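The preview above shows that `Message` holds raw HTML (`<p>`, `<a href=...>`), which Gemma would otherwise see verbatim. Below is a minimal sketch of one way to convert messages to plain text before prompting, using only the standard library's `html.parser`; as a precaution it also drops missing messages. `html_to_text` and `clean_messages` are hypothetical helper names, and stripping tags this way is a simplification (for example, a tag in the middle of a word leaves an extra space).

```python
from html.parser import HTMLParser

import pandas as pd


class _TextExtractor(HTMLParser):
    """Collects the text content of an HTML fragment, dropping every tag."""

    def __init__(self):
        # convert_charrefs=True turns entities such as &amp; into plain characters.
        super().__init__(convert_charrefs=True)
        self.parts = []

    def handle_data(self, data):
        self.parts.append(data)


def html_to_text(fragment: str) -> str:
    """Strip HTML tags from one forum message and collapse runs of whitespace."""
    parser = _TextExtractor()
    parser.feed(fragment)
    parser.close()  # flush any buffered text
    # Join text pieces with spaces so adjacent paragraphs don't run together,
    # then normalize all whitespace (including newlines) to single spaces.
    return " ".join(" ".join(parser.parts).split())


def clean_messages(messages: pd.Series) -> pd.Series:
    """Drop missing messages, convert the rest to plain text, and drop empty results."""
    cleaned = messages.dropna().astype(str).map(html_to_text)
    return cleaned[cleaned.str.len() > 0].reset_index(drop=True)
```

With these helpers, the selection above could become `original_texts = clean_messages(forum_messsages_df['Message'])[:5]`.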

## Generating `rewrite_prompt`
There are lots of ways to come up with rewrite prompts; for simplicity, here are a few hand-picked examples we can use.

In [3]:
rewrite_prompts = [
    'Explain this to me like I\'m five.',
    'Convert this into a sea shanty.',
    'Make this rhyme.',
]

## Generating `rewritten_text` with Gemma
Now for the fun part! We can use Gemma to rewrite our original text samples
using the rewrite prompts we created.
The code in the next cell is borrowed from [the model card](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch/variations/7b-it-quant).
The important thing to know is that we're using the 7B-parameter, instruction-tuned, quantized variant (`7b-it-quant`), which means:

- 7B parameters: this is the larger of the two Gemma models (the other has 2 billion parameters).
    In general we expect the larger model to perform better on complex tasks, but
    it's more resource-intensive. You can see exactly how Gemma 7B compares to Gemma 2B [here](https://ai.google.dev/gemma).
- Instruction-tuned: instruction tuning is an extra training step that produces a model that
    follows user instructions better. Our rewrite prompt is a kind of instruction, so this is what we want!
- Quantized: quantization shrinks a model by reducing the numerical precision of each
    parameter, so while our model still has 7 billion parameters, it needs less memory and is easier to run on limited
    hardware.
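To put the quantization trade-off in numbers, a back-of-the-envelope estimate is: weight memory ≈ parameter count × bytes per parameter. This is a rough sketch under stated assumptions: a nominal 7 billion parameters (the real checkpoint's count differs somewhat) and about one byte per weight for an 8-bit quantized checkpoint. It ignores activations, the KV cache, and any quantization scale factors.

```python
# Rough weight-memory estimate: parameters x bytes per parameter.
# Assumptions: nominal parameter count, 1 GB = 1e9 bytes, weights only.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate memory (GB) needed just to hold the model weights."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>9}: ~{weight_memory_gb(7e9, dtype):.0f} GB")
```

This prints roughly 28 GB for float32, 14 GB for bfloat16 and 7 GB for int8, which is why the quantized variant is the easier one to fit on limited hardware.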

By the end of the next cell, we'll have a `model` whose `generate` method we can call with a specially formatted prompt.

In [4]:
!pip install -q -U immutabledict sentencepiece
!git clone https://github.com/google/gemma_pytorch.git
!mkdir -p /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/

import sys
sys.path.append("/kaggle/working/gemma_pytorch/")
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

# Load the model
VARIANT = "7b-it-quant"
MACHINE_TYPE = "cuda"  # needs a GPU accelerator enabled in the notebook settings
# Directory of the attached Gemma PyTorch model input (the trailing "2" is the model version).
# If tokenizer.model is missing here, model construction fails with an AssertionError naming that path.
weights_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/2'

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype, restoring float32 on exit."""
  torch.set_default_dtype(dtype)
  try:
    yield
  finally:
    # Restore the default even if model construction or weight loading fails.
    torch.set_default_dtype(torch.float)

# Model config: pick the 2B or 7B architecture and point it at the tokenizer.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

# Model: build it in the checkpoint's dtype, load the weights, move it to the device.
device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  model.load_weights(ckpt_path)
  model = model.to(device).eval()


Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 102, done.
remote: Counting objects: 100% (47/47), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 102 (delta 26), reused 23 (delta 14), pack-reused 55
Receiving objects: 100% (102/102), 2.15 MiB | 22.92 MiB/s, done.
Resolving deltas: 100% (48/48), done.


AssertionError: /kaggle/input/gemma/pytorch/7b-it-quant/2/tokenizer.model
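The `AssertionError` above names `tokenizer.model`: the loader could not find that file under `weights_dir`, which usually means the Gemma model input isn't attached to the notebook or is mounted under a different version number. Here is a small sketch that searches `/kaggle/input` for a directory holding both the tokenizer and the `gemma-{VARIANT}.ckpt` checkpoint the loading code expects. `find_weights_dir` is a hypothetical helper, and the directory layout it looks for is an assumption based on the path used above.

```python
from pathlib import Path
from typing import Optional


def find_weights_dir(root: str = "/kaggle/input", variant: str = "7b-it-quant") -> Optional[str]:
    """Return the first directory under `root` containing both tokenizer.model
    and gemma-{variant}.ckpt, or None if there is no such directory."""
    root_path = Path(root)
    if not root_path.is_dir():
        return None
    # sorted() makes the choice deterministic if several versions are attached.
    for tokenizer_path in sorted(root_path.rglob("tokenizer.model")):
        candidate = tokenizer_path.parent
        if (candidate / f"gemma-{variant}.ckpt").is_file():
            return str(candidate)
    return None
```

For example, `weights_dir = find_weights_dir(variant=VARIANT) or weights_dir` before building the config; if it returns `None`, the model files aren't available to the notebook at all.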

In [None]:
# Now we can loop through our input texts, randomly select a rewrite prompt, and see Gemma in action:

import random
random.seed(0)
# This is the prompt format the model expects
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

rewrite_data = []

for original_text in original_texts:
    rewrite_prompt = random.choice(rewrite_prompts)
    prompt = f'{rewrite_prompt}\n{original_text}'
    rewritten_text = model.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        device=device,
        output_len=100,
    )
    rewrite_data.append({
        'original_text': original_text,
        'rewrite_prompt': rewrite_prompt,
        'rewritten_text': rewritten_text,
    })
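The chat template wraps the rewrite prompt and original text in a user turn and then opens a model turn, so whatever Gemma generates is its reply. It can help to print exactly what the model receives; this sketch rebuilds the same string as the loop above. `format_rewrite_prompt` is a hypothetical helper name, not part of `gemma_pytorch`.

```python
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"


def format_rewrite_prompt(rewrite_prompt: str, original_text: str) -> str:
    """Build the string the generation loop passes to model.generate."""
    prompt = f"{rewrite_prompt}\n{original_text}"
    # str.format only interprets braces in the template itself, so braces
    # inside original_text (e.g. code snippets) are passed through unchanged.
    return USER_CHAT_TEMPLATE.format(prompt=prompt)


print(format_rewrite_prompt("Make this rhyme.", "The cat sat on the mat."))
```

The printed prompt has the instruction on the first line of the user turn, the original text below it, and ends with an open `<start_of_turn>model` line for Gemma to continue.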
    

In [None]:
# Let's turn our generated data into a dataframe, and spot check the first rewrite to see if it makes sense.
rewrite_data_df = pd.DataFrame(rewrite_data)
rewrite_data_df[:1].values

## Next Steps

Huzzah! We have a dataset with original texts, rewrite prompts, and rewritten texts. Here are a few suggested next steps for generating a larger, more diverse dataset:
1. Add more sources of original text. Besides using all of the forum messages (instead of just the first 5), you can draw on Kaggle's many datasets that would make reasonable input text. Here are a few you could use:
    - The `Plot` column from the [Wikipedia Movie Plots dataset](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots).
    - The `text` column from the [Emotions dataset](https://www.kaggle.com/datasets/nelgiriyewithana/emotions).
    - The `body_text` and `abstract` columns of the [Wikibooks Dataset](https://www.kaggle.com/datasets/dhruvildave/wikibooks-dataset).
    
    Note that each of these may need its own preprocessing. For example, Gemma has a context length of 8192 tokens, shared between the prompt and the generated output, so long texts need to be truncated.
2. Use Gemma to generate original text.
3. Expand the list of rewrite prompts. You can come up with them manually, or explore having Gemma write rewrite prompts.
4. Play around with the generation of `rewritten_text`:
   - How does changing `output_len` affect the length and quality of rewrites?
   - Do rewrites with the 2B parameter model differ substantially from the 7B model?
   - Can you use [few shot prompting](https://www.promptingguide.ai/techniques/fewshot) to get higher quality rewrites?
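One way to try few-shot prompting with Gemma's chat format is to place worked examples as earlier user/model turns before the real request. This is a minimal sketch: `build_few_shot_prompt` is a hypothetical helper, and the example rewrite below is hand-written for illustration, not real Gemma output. Keep in mind that every example consumes part of the 8192-token context.

```python
USER_TURN = "<start_of_turn>user\n{content}<end_of_turn>\n"
MODEL_TURN = "<start_of_turn>model\n{content}<end_of_turn>\n"


def build_few_shot_prompt(examples, rewrite_prompt, original_text):
    """Build a multi-turn Gemma prompt.

    examples: list of (rewrite_prompt, original_text, rewritten_text) triples,
    rendered as earlier user/model exchanges before the real request.
    """
    parts = []
    for ex_prompt, ex_original, ex_rewritten in examples:
        parts.append(USER_TURN.format(content=f"{ex_prompt}\n{ex_original}"))
        parts.append(MODEL_TURN.format(content=ex_rewritten))
    # The real request, followed by an open model turn for Gemma to complete.
    parts.append(USER_TURN.format(content=f"{rewrite_prompt}\n{original_text}"))
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


# Hypothetical hand-written example; not real Gemma output.
examples = [("Make this rhyme.", "I like tea.", "A cup of tea is the drink for me!")]
prompt = build_few_shot_prompt(examples, "Make this rhyme.", "The model is training.")
print(prompt)
```

The resulting string can be passed to `model.generate(prompt, device=device, output_len=100)` in place of the single-turn template. With an empty `examples` list it reduces to the same prompt the generation loop builds.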