# Part 1: Generate Training Data 📝

***Full notebook contents viewable on [Kaggle](https://www.kaggle.com/code/chuhuayang/prompt-recovery-pt-1-generate-training-data).***

First, we need to generate some training data to tune the model with. Three components are necessary:
1. Original Texts - For best results, this should be a large and diverse set of paragraph-length texts
2. Rewrite Prompts
3. Re-written Text - Since the competition targets Gemma-7B, we will feed that model the first two components and obtain the output

For Original Texts, since this is a Kaggle-hosted competition, we can take advantage of high-quality natural language datasets available for free on Kaggle. To achieve diversity, we will use samples from three different well-curated datasets:
- [Wikipedia Movie Plots](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots): This dataset contains descriptions of thousands of movies from around the world. We can use the long-form plot summaries as Original Texts.
- [Emotions](https://www.kaggle.com/datasets/nelgiriyewithana/emotions): This is a collection of English Twitter messages, labeled with corresponding emotions. For our purposes, we can just use the tweets.
- [Wikibooks](https://www.kaggle.com/datasets/dhruvildave/wikibooks-dataset): This is a massive dataset, containing the complete contents of the [Wikibooks](https://www.wikibooks.org/) archives. We can use the `abstract` column from the English language table.

For Rewrite Prompts, there are several ways to obtain them. We can write them entirely by hand, or write [Mad Libs](https://www.madlibs.com/)-style templates and a simple program that fills in the blanks with different words. We can also prompt LLMs to generate them. These notebooks will use a Kaggle dataset containing a [collection of prompts](https://www.kaggle.com/datasets/chuhuayang/llm-prompt-recovery-competition) generated by me and others using these methods.
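As a rough illustration of the template approach, the sketch below fills Mad Libs-style blanks with randomly chosen words. The templates, word lists, and function names are all hypothetical and are not taken from the prompt dataset used in these notebooks.

```python
import random

# Hypothetical templates and fill-in words, for illustration only
TEMPLATES = [
    "Rewrite this text in the style of {style}.",
    "Rewrite this text as if it were written by {persona}.",
    "Make this text sound more {tone}.",
]
FILLERS = {
    "style": ["a sea shanty", "a legal contract", "a news bulletin"],
    "persona": ["a pirate", "a kindergarten teacher", "a detective"],
    "tone": ["formal", "optimistic", "sarcastic"],
}

def fill_template(template, rng):
    """Fill every {blank} in the template with a randomly chosen word."""
    choices = {slot: rng.choice(words) for slot, words in FILLERS.items()}
    # str.format ignores keyword arguments the template does not use
    return template.format(**choices)

def generate_prompts(n, seed=0):
    """Return n rewrite prompts; the same seed always gives the same list."""
    rng = random.Random(seed)
    return [fill_template(rng.choice(TEMPLATES), rng) for _ in range(n)]

for prompt in generate_prompts(3):
    print(prompt)
```

Using a dedicated `random.Random` instance keeps the generated prompts reproducible without touching the global random state.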

### Loading in Data

We will add the aforementioned datasets to our notebook, then read them in. We will perform two basic pre-processing steps to obtain more representative samples.
- For entries from Wikipedia Movie Plots and Wikibooks, because they can be very long, we will slice the text into chunks of at most 256 whitespace-separated words. This also helps save resources when performing training and inference.
- Then, we will filter out entries that are too short or that contain sequences or characters not typical of English texts.

In these notebooks, as an example, we will randomly select only a small subset of the available data to fine-tune on.
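To see what these two steps do, here is a toy walk-through in plain Python. It uses a chunk size of 4 words and a 10-character minimum so the effect is visible on a single sentence; the function names and the sample text are made up for illustration.

```python
CHUNK_WORDS = 4  # small value for illustration; the notebook uses 256

def chunk_words(text, size=CHUNK_WORDS):
    """Split text on whitespace and join the words back into fixed-size chunks."""
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size)]

def looks_usable(chunk, min_chars=10):
    """Same kinds of checks as the notebook, with a shorter minimum length."""
    return "==" not in chunk and len(chunk) >= min_chars and chunk.isascii()

text = "The quick brown fox jumps over the lazy dog == References == café"
chunks = chunk_words(text)
kept = [c for c in chunks if looks_usable(c)]

print(chunks)
print(kept)
```

The third chunk is dropped because it contains `==` (wiki heading markup), and the last one because it is both too short and non-ASCII.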

In [1]:
import pandas as pd
import sqlite3 as sql

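MAX_LENGTH = 256  # maximum number of words per chunk

def create_slices(words):
    # Split a list of words into space-joined chunks of at most MAX_LENGTH words
    return [' '.join(words[i:i + MAX_LENGTH]) for i in range(0, len(words), MAX_LENGTH)]

def is_usable(text):
    # Reject wiki markup ("=="), very short chunks, and non-ASCII text
    # (named is_usable to avoid shadowing the built-in filter)
    filtered_out = ("==" in text) or (len(text) < 50) or (not text.isascii())
    return not filtered_out

def preprocessing(series):
    # Split on whitespace, chunk, flatten to one chunk per row, then filter
    chunks = series.str.split().apply(create_slices)
    chunks = chunks.explode(ignore_index=True).dropna()
    return chunks[chunks.apply(is_usable)]
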

series_1 = preprocessing(pd.read_csv("/kaggle/input/emotions/text.csv")["text"])
series_1 = series_1.sample(n=2000, random_state=0)
df_1 = series_1.to_frame("original_text")
print(df_1.head())

series_2 = preprocessing(pd.read_csv("/kaggle/input/wikipedia-movie-plots/wiki_movie_plots_deduped.csv")["Plot"])
series_2 = series_2.sample(n=1500, random_state=0)
df_2 = series_2.to_frame("original_text")
print(df_2.head())

conn = sql.connect("/kaggle/input/wikibooks-dataset/wikibooks.sqlite")

series_3 = preprocessing(pd.read_sql_query("SELECT abstract FROM en", conn)["abstract"])
series_3 = series_3.sample(n=1500, random_state=0)
df_3 = series_3.to_frame("original_text")
print(df_3.head())

conn.close()

                                            original_text
48979   i feel positive about my k at the end of the m...
176784  i am feeling generous i can put up files fille...
148564  i have a feeling the defense is going to come ...
136809  i woke up in my bed alone for the last time fe...
261174  i am actually liking sofia and i feel this is ...


                                           original_text
8524   the problem now which Playboy shows Bugs a fly...
11031  Singing-and-dancing stage star Julie (Betty Gr...
53471  school for admission where she meets George af...
55971  Dollar (played by Pawan Kumar (of Lucia (2013 ...
67650  When the Japan Coast Guard investigates an aba...


                                           original_text
83536  Punjab was divided between the nations of Indi...
62306  Methodology: A body of practices, procedures, ...
16513  A computer-aided translation tool, developed b...
70892  The cause of mental disorders is usually unkno...
84067  This quest allows players access to a secret p...


In [2]:
df = pd.concat([df_1, df_2, df_3], ignore_index=True)
df = df.sample(frac=1).reset_index(drop=True)
print(len(df))

5000


### Assembling Prompts

Now, we will add the dataset containing example rewrite prompts to the notebook. Each chunk of text will be randomly matched up with one of the rewrite prompts. Then, we will place both components into a simple template, to get more consistent results from Gemma.

In [3]:
import random
random.seed(0)

df.insert(1, "rewrite_prompt", value="")
prompts = pd.read_csv("/kaggle/input/prompt-recovery-sample-rewrite-prompts/prompts.csv").iloc[:,0]
for i in df.index:
    df.at[i, "rewrite_prompt"] = random.choice(prompts)
df.head()

| | original_text | rewrite_prompt |
|---|---|---|
| 0 | i kinda thought that they might be giving some... | Convey the same message as this text but throu... |
| 1 | im being honest ive been feeling quite bitchy ... | Rewrite this text in the style of a formal dip... |
| 2 | i am trying to feel calmer and more relaxed ev... | Imagine this as a conversation between quirky ... |
| 3 | i was called and invited to have a talk about ... | Craft a version of the paragraph suitable for ... |
| 4 | off to do the same. After overcoming many chal... | Elevate this text by introducing a compelling ... |


In [4]:
template = """Instruction:
{rewrite_prompt}

Original Text:
{original_text}

Response:
"""

df["prompt"] = df.apply(lambda row: template.format(rewrite_prompt=row.rewrite_prompt, 
                                                             original_text=row.original_text), axis=1)
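
To make the template concrete, here is what one assembled prompt looks like. The row values are hypothetical; only the template text is copied from the cell above.

```python
template = """Instruction:
{rewrite_prompt}

Original Text:
{original_text}

Response:
"""

# Hypothetical row, for illustration only
example = {
    "rewrite_prompt": "Rewrite this text as a sea shanty.",
    "original_text": "The ship left the harbor at dawn.",
}

prompt = template.format(**example)
print(prompt)
```

The prompt ends right after `Response:` so that the model's continuation is the rewritten text.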

### Inference with Gemma

We are ready to run inference with Gemma and obtain the rewritten texts. Since Kaggle provides access to TPU accelerators, this notebook shows how to use them for a substantial increase in compute power. We will use Keras with the JAX backend. JAX supports [Model Parallelism](https://huggingface.co/docs/transformers/v4.15.0/en/parallelism) techniques: weight and embedding tensors are sharded, or partitioned, across the 8 TPU cores, and the computation is distributed among them. This effectively pools the 8 cores' compute power and memory, speeding up computation for a single batch and allowing larger models to fit into memory.
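
The idea behind sharding can be sketched without any accelerator. The plain-Python toy below splits a weight matrix column-wise into 8 pieces, lets each simulated "device" multiply against its own piece, and checks that the concatenated result matches the unsharded product. It is a conceptual illustration only and does not use the JAX or Keras APIs; all names and sizes are made up.

```python
NUM_DEVICES = 8  # matches the 8 TPU cores mentioned above

def matmul(a, b):
    """Plain-Python matrix product of a (n x k) and b (k x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def shard_columns(w, num_shards):
    """Split w column-wise into num_shards equally sized pieces."""
    cols = len(w[0])
    assert cols % num_shards == 0, "column count must divide evenly"
    width = cols // num_shards
    return [[row[s * width:(s + 1) * width] for row in w] for s in range(num_shards)]

# Toy activations (2 x 3) and weights (3 x 16)
x = [[1, 2, 3], [4, 5, 6]]
w = [[(r + 1) * (c + 1) for c in range(16)] for r in range(3)]

# Each simulated device holds one shard and computes its slice of the output
partial_outputs = [matmul(x, shard) for shard in shard_columns(w, NUM_DEVICES)]

# Concatenating the per-device slices row by row recovers the full result
combined = [sum((p[i] for p in partial_outputs), []) for i in range(len(x))]
assert combined == matmul(x, w)
```

Each device stores only one eighth of the weights, which is what lets a model that is too large for a single core fit across all of them.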

In [5]:
!pip install -q tensorflow-cpu
!pip install -q -U keras-nlp tensorflow-hub
!pip install -q -U "keras>=3"  # quoted so the shell does not treat >=3 as an output redirect
!pip install -q -U tensorflow-text


In [6]:
import jax

jax.devices()


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

We define `device_mesh` and `layout_map` configurations that describe how Gemma's tensors should be sharded. Then, using the **`keras.distribution`** API, we pass these configurations to the JAX backend, which then performs the sharding when the model weights are loaded.

In [7]:
import os
os.environ["KERAS_BACKEND"] = "jax"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9" # Pre-allocate 90% of TPU memory to minimize memory fragmentation and allocation overhead

import keras
import keras_nlp

# Create a device mesh with (1, 8) shape so that the weights are sharded across all 8 TPU cores.
device_mesh = keras.distribution.DeviceMesh(
    (1, 8),
    ["batch", "model"],
    devices=keras.distribution.list_devices())

# Create a layout map and define how each layer's weights should be sharded
layout_map = keras.distribution.LayoutMap(device_mesh)
model_dim = "model"
# Weights that match 'token_embedding/embeddings' will be sharded across the 8 TPU cores
layout_map["token_embedding/embeddings"] = (None, model_dim)
# Use a regex to match against the query, key and value matrices in the decoder attention layers
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (None, model_dim, None)
# Shards the attention output, feed-forward gating, and feed-forward linear layers of the decoder
layout_map["decoder_block.*attention_output.*kernel"] = (None, None, model_dim)
layout_map["decoder_block.*ffw_gating.*kernel"] = (model_dim, None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, model_dim)
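
The layout map keys are regular expressions matched against variable paths. The stand-in below mimics that lookup with the standard `re` module so the matching can be checked without loading the model. It is a simplified sketch: `lookup_layout` is a hypothetical helper in which the first matching pattern wins, and the real `LayoutMap` may handle ambiguous matches differently.

```python
import re

# Patterns and layouts copied from the layout map above
layout_rules = {
    "token_embedding/embeddings": (None, "model"),
    "decoder_block.*attention.*(query|key|value).*kernel": (None, "model", None),
    "decoder_block.*attention_output.*kernel": (None, None, "model"),
    "decoder_block.*ffw_gating.*kernel": ("model", None),
    "decoder_block.*ffw_linear.*kernel": (None, "model"),
}

def lookup_layout(path):
    """Return the layout of the first pattern found in path, or None if nothing matches."""
    for pattern, layout in layout_rules.items():
        if re.search(pattern, path):
            return layout
    return None  # no rule: the variable is left unsharded

for path in [
    "decoder_block_1/attention/query/kernel",
    "decoder_block_1/attention/attention_output/kernel",
    "decoder_block_1/ffw_gating_2/kernel",
    "decoder_block_1/pre_attention_norm/scale",
]:
    print(f"{path:<52} {lookup_layout(path)}")
```

Note that `ffw_gating_2` is picked up by the `ffw_gating` pattern, while the normalization scales match no rule and are therefore not sharded.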


In [8]:
model_parallel = keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")

keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_instruct_7b_en")

# Print out information on one decoder block to verify weights were sharded correctly
decoder_block_1 = gemma_lm.backbone.get_layer('decoder_block_1')
print(type(decoder_block_1))
for variable in decoder_block_1.weights:
    print(f'{variable.path:<58}  {str(variable.shape):<16}  {str(variable.value.sharding.spec)}')

<class 'keras_nlp.src.models.gemma.gemma_decoder_block.GemmaDecoderBlock'>
decoder_block_1/pre_attention_norm/scale                    (3072,)           PartitionSpec(None,)
decoder_block_1/attention/query/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/key/kernel                        (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/value/kernel                      (16, 3072, 256)   PartitionSpec(None, 'model', None)
decoder_block_1/attention/attention_output/kernel           (16, 256, 3072)   PartitionSpec(None, None, 'model')
decoder_block_1/pre_ffw_norm/scale                          (3072,)           PartitionSpec(None,)
decoder_block_1/ffw_gating/kernel                           (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_gating_2/kernel                         (3072, 24576)     PartitionSpec('model', None)
decoder_block_1/ffw_linear/kernel                           (

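We will also cap each generated sequence at 512 tokens to prevent excessively long responses to certain prompts and to conserve compute resources. Note that `max_length` counts the prompt tokens as well as the response, so longer prompts leave less room for the rewrite.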

In [9]:
def generate(row):
    # max_length is the total sequence length, prompt included
    output = gemma_lm.generate(row.prompt, max_length=512)
    # generate() returns the prompt followed by the completion, so we strip the prompt
    return output.replace(row.prompt, "")

df["rewritten_text"] = df.apply(generate, axis=1)
df.to_csv("/kaggle/working/training_data.csv", index=False)
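
One caveat with `str.replace` is that it removes every occurrence of the prompt, not just the leading one. If that is a concern, a prefix-only strip is a possible alternative. The helper below is a hypothetical sketch, not part of the original notebook.

```python
def strip_prompt(output, prompt):
    """Remove the prompt only when it appears at the start of the output."""
    if output.startswith(prompt):
        return output[len(prompt):]
    return output  # leave the text untouched if the prompt is not a prefix

# Hypothetical strings, for illustration only
prompt = "Instruction:\nShorten this.\n\nOriginal Text:\nA long sentence.\n\nResponse:\n"
output = prompt + "Short."

print(repr(strip_prompt(output, prompt)))
```

Slicing by `len(prompt)` also avoids scanning the whole output for matches, though for texts of this size the difference is negligible.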