# LLM Prompt Recovery with Gemma

<div align="center">
    <img src="https://i.ibb.co/8xZNc32/Gemma.png">
</div>

**The challenge:** Recover the LLM prompt used to rewrite a given text.

**KerasNLP Models:** [KerasNLP website](https://keras.io/api/keras_nlp/models/)

**About Gemma Models:** Gemma is a collection of advanced open LLMs developed by `Google DeepMind` and other `Google teams`, derived from the same research and technology behind the `Gemini models`.

| Parameter size  | Tuned versions    | Intended platforms                 | Preset                 |
|-----------------|-------------------|------------------------------------|------------------------|
| 2B              | Pretrained        | Mobile devices and laptops         | `gemma_2b_en`          |
| 2B              | Instruction tuned | Mobile devices and laptops         | `gemma_instruct_2b_en` |
| 7B              | Pretrained        | Desktop computers and small servers| `gemma_7b_en`          |
| 7B              | Instruction tuned | Desktop computers and small servers| `gemma_instruct_7b_en` |

**Datasets:**
1) [Rewritten texts with Gemma 2B](https://www.kaggle.com/datasets/juanmerinobermejo/rewritten-texts-with-gemma-2b) | Credits: [Juan Merino](https://www.kaggle.com/juanmerinobermejo)

2) [gemma-rewrite-nbroad](https://www.kaggle.com/datasets/nbroad/gemma-rewrite-nbroad) | Credits: [Nicholas Broad](https://www.kaggle.com/nbroad)

3) [LLM Prompt Recovery - Synthetic Datastore](https://www.kaggle.com/datasets/dschettler8845/llm-prompt-recovery-synthetic-datastore) | Credits: [Darien Schettler](https://www.kaggle.com/dschettler8845)

4) [llm-prompt-recovery-data](https://www.kaggle.com/datasets/thedrcat/llm-prompt-recovery-data) | Credits: [Darek Kłeczek](https://www.kaggle.com/thedrcat)

5) [3000 Rewritten texts - Prompt recovery Challenge](https://www.kaggle.com/datasets/dipamc77/3000-rewritten-texts-prompt-recovery-challenge) | Credits: [Dipam Chakraborty](https://www.kaggle.com/dipamc77)


**Evaluation:** For each row in the submission and corresponding ground truth, [sentence-t5-base](https://www.kaggle.com/models/google/sentence-t5/frameworks/tensorFlow2/variations/st5-base) is used to calculate corresponding embedding vectors. The score for each predicted / expected pair is calculated using the [Sharpened Cosine Similarity](https://github.com/brohrer/sharpened-cosine-similarity), using an exponent of `3`. The SCS is used to attenuate the generous score given by embedding vectors for incorrect answers. Do not leave any `rewrite_prompt` blank as null answers will throw an error.
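As a rough illustration of the metric described above, the sketch below computes a sharpened cosine similarity for two small vectors in plain Python. It assumes the simplified form `sign(cos) * |cos| ** p` with `p = 3` and leaves out the additive stabilizing term used in the referenced SCS implementation. The function name `sharpened_cosine_similarity` and the toy vectors are hypothetical, and the actual competition scorer may differ in detail.

```python
import math

def sharpened_cosine_similarity(a, b, p=3):
    """Simplified SCS: sign(cos) * |cos| ** p (assumed form, not the official scorer)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    cos = dot / (norm_a * norm_b)  # zero vectors are not handled in this sketch
    return math.copysign(abs(cos) ** p, cos)

# Toy vectors standing in for sentence-t5-base embeddings
pred_vec = [1.0, 0.0]
true_vec = [1.0, 1.0]

cos_score = sharpened_cosine_similarity(pred_vec, true_vec, p=1)
scs_score = sharpened_cosine_similarity(pred_vec, true_vec, p=3)
print(f"cosine: {cos_score:.4f}, sharpened: {scs_score:.4f}")
```

Raising the similarity to the third power shrinks mid-range scores far more than scores near 1, which is why only close matches are rewarded.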

**Requirements:**
* CPU Notebook <= 9 hours run-time
* GPU Notebook <= 9 hours run-time
* Internet access disabled
* Freely & publicly available external data is allowed, including pre-trained models
* Submission file must be named submission.csv
* Submission runtimes have been slightly obfuscated. If you repeat the exact same submission, you will see up to 15 minutes of variance in the time before you receive your score.


*THIS NOTEBOOK IS BASED ON THE "Prompt Recovery with Gemma - KerasNLP Starter" NOTEBOOK! [LINK](https://www.kaggle.com/code/awsaf49/prompt-recovery-with-gemma-kerasnlp-starter)*

**Credits:** Awsaf (Owner); Ashley Chow (Editor); fchollet (Editor); Gusthema (Editor); Martin Görner (Editor); Paul Mooney (Editor); Phil Culliton (Editor).

# 1. Setup Modules & Dependencies

In [None]:
# Ignore Warnings
from warnings import filterwarnings
filterwarnings('ignore')

In [None]:
# Setup Environment [OS]
import os
os.environ["KERAS_BACKEND"] = "jax" # JAX backend for the best performance
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00" # to avoid memory fragmentation on JAX backend

# Natural Language Processing Modules and Machine Learning Modules
import keras
import keras_nlp

# Numerical Data Processing Modules
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
tqdm.pandas() # progress bar for pandas

# Data Visualization Modules
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_theme(style='whitegrid', palette='viridis')

# Markdown Display Modules
from IPython.display import display, Markdown

# 2. Configuration

In [None]:
# Setting up the configuration class
class CFG:
    seed = 42
    dataset_path = "/kaggle/input/llm-prompt-recovery"
    preset = "gemma_instruct_2b_en" # name of the instruction-tuned Gemma 2B preset
    sequence_length = 850 # max length of the input sequence for training
    # We will use a sequence_length chosen for our dataset, guided by the mean of the per-row maximum text lengths in the final dataset
    batch_size = 1 # size of the input batch in training
    epochs = 2 # assuming ~10,000 samples at ~1 s per step, one epoch takes ~10,000 s, so 2 epochs take ~20,000 s (about 5.6 hours)
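
The training-time estimate in the `epochs` comment above can be reproduced with simple arithmetic. The sketch below assumes roughly 10,000 samples and about 1 second per training step at batch size 1, as that comment suggests. The helper name `estimated_training_hours` is hypothetical, and real timings depend on the hardware and the sequence length.

```python
import math

def estimated_training_hours(num_samples, batch_size, seconds_per_step, epochs):
    """Rough wall-clock estimate: steps per epoch * seconds per step * epochs."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_seconds = steps_per_epoch * seconds_per_step * epochs
    return total_seconds / 3600

# Assumed figures taken from the configuration comment, not measured values
hours = estimated_training_hours(
    num_samples=10_000, batch_size=1, seconds_per_step=1.0, epochs=2
)
print(f"Estimated training time: {hours:.2f} hours")
```

Under these assumptions the run fits within the 9-hour notebook limit with some headroom for data loading and inference.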

# 3. Reproducibility 
Sets the random seed so that each run produces similar results.

In [None]:
def random_setup(seed):
    np.random.seed(seed)
    keras.utils.set_random_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    
random_setup(CFG.seed)

# 4. Data Upload

**Data Format:**

These datasets include:
- `original_text`: Input text/essay that needs to be transformed.
- `rewrite_prompt`: Prompt/Instruction that was used in the Gemma LM to transform `original_text`. This is also our **target** for this competition.
- `rewritten_text`: Output text that was generated by the Gemma model.

In [None]:
df1=pd.read_csv("/kaggle/input/rewritten-texts-with-gemma-2b/rewritten_texts_csv.csv",encoding = 'latin-1')
df1=df1[['original_text','prompt','rewritten_text']]
df1=df1.rename(columns={'prompt':'rewrite_prompt'})
df1=df1.head(11000)

df2=pd.read_csv("/kaggle/input/gemma-rewrite-nbroad/nbroad-v2.csv")
df2=df2[['original_text','rewrite_prompt','rewritten_text']]

# `LLM Prompt Recovery - Synthetic Datastore dataset` by @dschettler8845
df3 = pd.read_csv("/kaggle/input/llm-prompt-recovery-synthetic-datastore/gemma1000_w7b.csv")
df3 = df3[["original_text", "rewrite_prompt", "gemma_7b_rewritten_text_temp0"]]
df3 = df3.rename(columns={"gemma_7b_rewritten_text_temp0":"rewritten_text"})

# `3000 Rewritten texts - Prompt recovery Challenge` by @dipamc77
df4 = pd.read_csv("/kaggle/input/3000-rewritten-texts-prompt-recovery-challenge/prompts_0_500_wiki_first_para_3000.csv")

# We will also use a kaggle dataset which is known as llm-prompt-recovery-data and available at https://www.kaggle.com/datasets/thedrcat/llm-prompt-recovery-data
df5=pd.read_csv("/kaggle/input/llm-prompt-recovery-data/gemma10000.csv")
df5=df5[['original_text','rewrite_prompt','rewritten_text']]
df5=df5.head(7500) 

df = pd.concat([df1, df2, df3, df4, df5], axis=0).dropna()
df = df.drop_duplicates().reset_index(drop=True) # remove duplicate entries for training
df
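
One detail worth keeping in mind when deduplicating: `DataFrame.drop_duplicates` returns a new frame by default and leaves the original untouched, so the result has to be assigned. The sketch below shows this on a hypothetical three-row toy frame that uses the same column names as the datasets above.

```python
import pandas as pd

# Hypothetical toy data with one exact duplicate row
toy = pd.DataFrame({
    "original_text": ["a", "a", "b"],
    "rewrite_prompt": ["p", "p", "q"],
    "rewritten_text": ["x", "x", "y"],
})

toy.drop_duplicates()  # result is discarded, so toy still has 3 rows
print(len(toy))

# Assign the result (and rebuild the index) to actually remove duplicates
deduped = toy.drop_duplicates().reset_index(drop=True)
print(len(deduped))
```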

## 4.1. Cleaning the Data

In [None]:
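import re

def clean_text(text):
    # Remove newline characters (adjacent lines are joined without a space)
    text = text.replace("\n", "")
    # Remove markdown bold spans such as **Title:**
    text = re.sub(r'\*\*.*?\*\*', '', text)
    return text

df['original_text'] = df['original_text'].apply(clean_text)
df['rewritten_text'] = df['rewritten_text'].apply(clean_text)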

In [None]:
def remove_symbols(text):
    # Define regular expression pattern to match unnecessary symbols
    pattern = r'[^\w\s]'
    # Use re.sub() to replace matched symbols with an empty string
    cleaned_text = re.sub(pattern, '', text)
    return cleaned_text
df['original_text']= df['original_text'].apply(remove_symbols)
df['rewritten_text']= df['rewritten_text'].apply(remove_symbols)
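
To make the effect of the two cleaning steps concrete, the sketch below re-declares both functions with the same regular expressions and applies them to a hypothetical sample string. It is only an illustration of the behavior, including the side effect that words on either side of a removed newline are joined.

```python
import re

def clean_text(text):
    text = text.replace("\n", "")             # drop newlines
    text = re.sub(r'\*\*.*?\*\*', '', text)   # drop markdown bold spans
    return text

def remove_symbols(text):
    return re.sub(r'[^\w\s]', '', text)       # keep only word characters and whitespace

# Hypothetical sample resembling a Gemma rewrite with a bold header
sample = "**Title:** Hello, world!\nSecond line."
cleaned = remove_symbols(clean_text(sample))
print(repr(cleaned))  # ' Hello worldSecond line'
```

Note that `world` and `Second` end up fused because the newline is replaced with an empty string; replacing it with a space instead would be one possible variation if that matters for your data.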

In [None]:
df

In [None]:
# Add a max_len column holding the longest character length among the three text columns of each row
df['max_len'] = df.apply(lambda row: max(len(row['original_text']), len(row['rewrite_prompt']), len(row['rewritten_text'])), axis=1)

In [None]:
# Keep only the rows whose max_len is at most 850 characters
df = df[df['max_len'] <= 850]

In [None]:
# Set background gradient style for description plot
df.describe().T.style.background_gradient(cmap='Blues')

In [None]:
# Plotting the max length
plt.figure(figsize=(8, 6))
sns.histplot(df['max_len'], bins=20, color='skyblue', edgecolor='black', kde=False)
plt.title('Distribution of Maximum Lengths')
plt.xlabel('Maximum Length')
plt.ylabel('Frequency')
plt.grid(True)
plt.show()

In [None]:
df['max_len'].mean()

In [None]:
shuffled_df = df.sample(frac=1, random_state=42)
df=shuffled_df.reset_index(drop=True)
df

In [None]:
df=df.dropna().reset_index(drop=True)
df

# 5. Prompt Engineering

Here's a *custom prompt template* we'll use to create instruction-response pairs from the `original_text`, `rewritten_text`, and `rewrite_prompt` columns.

In [None]:
# Advanced prompt engineering template

# template = """Instruction:\nBelow, you'll find two texts: `Original Text` and `Rewritten Text`. The `Rewritten Text` has been generated from the `Original Text` using the LLM model Gemma 7b-it. Your task is to carefully analyze the similarities and differences between the two texts to determine the instruction or hint given to the LLM model to rewrite or transform the text.\n
# Please consider the following 8 questions while analyzing the texts:\n
# 1) What linguistic features have changed between the original and rewritten text, including sentence structure, vocabulary, tone, and style?\t
# 2) Are there any recurring patterns or repetitions in the rewritten text that suggest specific transformations applied by the model?\t
# 3) Does the rewritten text demonstrate a preference for particular word choices or syntactic structures?\t
# 4) How do the length and complexity of the rewritten text compare to the original text?\t
# 5) What contextual clues or hints in the original text may have influenced the transformation process?\t
# 6) How coherent and cohesive is the rewritten text compared to the original?\t
# 7) How readable and fluent is the rewritten text in effectively conveying the intended message?\t
# 8) Are there any implicit biases or sociocultural influences evident in the rewriting process?\n
# By carefully examining these aspects, try to deduce the underlying instruction or hint that guided the LLM model in rewriting the text from `Original Text` to `Rewritten Text`.
# \n\nOriginal Text:\n{original_text}
# \n\nRewriten Text:\n{rewritten_text}
# \n\nResponse:\n{rewrite_prompt}"""

template="""Instruction:\nYou are a skilled language model analyst with experience in text transformations. Help me determine the prompt an LLM might have used to rewrite an original text into a rewritten text.\n
When analyzing the two texts, please consider the following factors:\n
- Changes in tone and style.\t
- Shifts in language or vocabulary.\t
- Modifications to sentence structure or syntax.\t
- Additions, omissions, or alterations in content.\t
- Adjustments to the overall message or purpose of the text.\n
Based on your analysis, suggest the potential rewrite prompt that the LLM might have used to guide the transformation from the original text to the rewritten text.
\n\nOriginal Text:\n{original_text}
\n\nRewriten Text:\n{rewritten_text}
\n\nResponse:\n{rewrite_prompt}"""

In [None]:
# Make a new column in the dataframe for storing the prompt template
df["prompt"] = df.progress_apply(lambda row: template.format(original_text=row.original_text,
                                                             rewritten_text=row.rewritten_text,
                                                             rewrite_prompt=row.rewrite_prompt), axis=1)
# Convert the prompt column into a list of strings to feed to the model for training
data = df.prompt.tolist()

Let's examine a sample prompt. Because the texts in our dataset were written in **markdown** format, we render the sample with `Markdown()` to visualize the formatting properly.

In [None]:
data[:10]

## 5.1. Sample

In [None]:
# Text colorization function
def colorize_text(text):
    for word, color in zip(["Instruction", "Original Text", "Rewriten Text", "Response"],
                           ["red", "purple", "blue", "green"]):
        text = text.replace(f"{word}:", f"\n\n**<font color='{color}'>{word}:</font>**")
    return text

In [None]:
# Take a fixed sample (index 10)
sample = data[10]

# Give colors to the Instruction, Original Text, Rewriten Text and Response labels
sample = colorize_text(sample)

# Show sample in markdown
display(Markdown(sample))

In [None]:
# Initialising the model and getting a rough summary
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset(CFG.preset)
gemma_lm.summary()

# 6. Inference before Fine-Tuning

Before we do fine-tuning, let's try to recover the prompt using the Gemma model with some prepared prompts and see how it responds.

> Although `gemma_instruct_2b_en` is instruction tuned, it has not yet been fine-tuned for prompt recovery, so you will notice that the model's responses are inaccurate.

## 6.1. Sample 1

In [None]:
# Take one sample
row = df.iloc[10]

# Generate Prompt using template
prompt = template.format(
    original_text=row.original_text,
    rewritten_text=row.rewritten_text,
    rewrite_prompt="",
)

# Infer
output = gemma_lm.generate(prompt, max_length=CFG.sequence_length)

# Colorize
output = colorize_text(output)

# Display in markdown
display(Markdown(output))


## 6.2. Sample 2

In [None]:
# Take one sample
row = df.iloc[20]

# Generate Prompt using template
prompt = template.format(
    original_text=row.original_text,
    rewritten_text=row.rewritten_text,
    rewrite_prompt="",
)

# Infer
output = gemma_lm.generate(prompt, max_length=CFG.sequence_length)

# Colorize
output = colorize_text(output)

# Display in markdown
display(Markdown(output))


# 7. Fine-tuning with LoRA

In [None]:
# Enable LoRA for the model and set the LoRA rank to 8.
gemma_lm.backbone.enable_lora(rank=8) 
gemma_lm.summary()

**Notice** that the number of trainable parameters drops from ~$2.5$ billion to ~$2.7$ million after enabling LoRA.
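
The reduction follows from how LoRA works: each adapted weight matrix of shape `d_in x d_out` stays frozen, and only two small matrices of shapes `d_in x rank` and `rank x d_out` are trained. The sketch below counts parameters for a single matrix. The dimensions are hypothetical round numbers chosen for illustration, not Gemma's actual layer shapes, and `lora_param_count` is an illustrative helper rather than a KerasNLP function.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters LoRA adds for one d_in x d_out weight matrix."""
    return rank * d_in + rank * d_out

# Hypothetical layer dimensions (illustration only)
d_in, d_out, rank = 2048, 2048, 8

full_params = d_in * d_out
lora_params = lora_param_count(d_in, d_out, rank)

print(f"full matrix: {full_params:,} parameters")
print(f"LoRA update: {lora_params:,} parameters")
print(f"reduction:   {full_params // lora_params}x")
```

Doubling the rank doubles the number of trainable parameters, so the rank is a direct trade-off between adapter capacity and memory use.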

# 8. Training

In [None]:
# Limit the input sequence length (to control memory usage).
gemma_lm.preprocessor.sequence_length = CFG.sequence_length 

optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    )

optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train model
gemma_lm.fit(data, epochs=CFG.epochs, batch_size=CFG.batch_size, shuffle=True)

# 9. Inference after fine-tuning

Let's see how our fine-tuned model responds to the same samples we used before fine-tuning.

## 9.1. Sample 1

In [None]:
# Take one sample
row = df.iloc[10]

# Generate Prompt using template
prompt = template.format(
    original_text=row.original_text,
    rewritten_text=row.rewritten_text,
    rewrite_prompt="",
)

# Infer
output = gemma_lm.generate(prompt, max_length=CFG.sequence_length)

# Colorize
output = colorize_text(output)

# Display in markdown
display(Markdown(output))

## 9.2. Sample 2

In [None]:
# Take one sample
row = df.iloc[20]

# Generate Prompt using template
prompt = template.format(
    original_text=row.original_text,
    rewritten_text=row.rewritten_text,
    rewrite_prompt="",
)

# Infer
output = gemma_lm.generate(prompt, max_length=CFG.sequence_length)

# Colorize
output = colorize_text(output)

# Display in markdown
display(Markdown(output))


# 10. Test Data

In [None]:
# Reading the test data
test_df = pd.read_csv("/kaggle/input/llm-prompt-recovery/test.csv")
test_df['original_text'] = test_df['original_text'].fillna("")
test_df['rewritten_text'] = test_df['rewritten_text'].fillna("")
test_df.head()

## 10.1. Test Sample

Now, let's try out a sample from the test data that the model hasn't seen during training.

In [None]:
# Loading a sample prompt from the test data
row = test_df.iloc[0]

# Generate Prompt using template
prompt = template.format(
    original_text=row.original_text,
    rewritten_text=row.rewritten_text,
    rewrite_prompt="",
)

# Infer
output = gemma_lm.generate(prompt, max_length=CFG.sequence_length)

# Colorize
output = colorize_text(output)

# Display in markdown
display(Markdown(output))

# 11. Submission

In [None]:
# For storing the predictions
preds = []
for i in tqdm(range(len(test_df))):
    row = test_df.iloc[i]

    # Generate Prompt using template
    prompt = template.format(
        original_text=row.original_text,
        rewritten_text=row.rewritten_text,
        rewrite_prompt=""
    )

    # Infer
    output = gemma_lm.generate(prompt, max_length=CFG.sequence_length)
    pred = output.replace(prompt, "") # remove the prompt from output
    
    # Store predictions
    preds.append([row.id, pred])
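
Removing the prompt with `str.replace` relies on the generated output reproducing the prompt exactly. As a more defensive variation, the sketch below strips the prompt by prefix and falls back to the text after the last `Response:` marker. The helper name `extract_response` is hypothetical, and the sample strings are made up for illustration.

```python
def extract_response(output, prompt, marker="Response:\n"):
    """Return only the generated part of a model output (illustrative helper)."""
    if output.startswith(prompt):
        text = output[len(prompt):]
    else:
        # Fall back to everything after the last marker, if the marker is present
        _, sep, tail = output.rpartition(marker)
        text = tail if sep else output
    return text.strip()

# Made-up strings standing in for a real prompt and model output
demo_prompt = "Instruction:\nFind the prompt.\n\nResponse:\n"
demo_output = demo_prompt + "Rewrite this text as a poem.\n"

print(extract_response(demo_output, demo_prompt))
print(extract_response("garbled Response:\nMake it formal", demo_prompt))
```

Stripping surrounding whitespace also makes it easier to detect empty predictions before writing the submission file.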

When preparing the submission file, keep in mind that leaving any `rewrite_prompt` blank will throw an error, because null answers are not accepted.

In [None]:
sub_df = pd.DataFrame(preds, columns=["id", "rewrite_prompt"])
sub_df['rewrite_prompt'] = sub_df['rewrite_prompt'].fillna("").str.strip() # treat whitespace-only predictions as blank
sub_df['rewrite_prompt'] = sub_df['rewrite_prompt'].map(lambda x: "Make this text the best text ever!" if len(x) == 0 else x) # fallback prompt for blank predictions
sub_df.to_csv("submission.csv",index=False)
sub_df.head()

# 12. Save the Custom Model

In [None]:
gemma_lm.save("gemma_prompt_recovery_master.keras")

# 13. References
* [Fine-tune Gemma models in Keras using LoRA](https://www.kaggle.com/code/nilaychauhan/fine-tune-gemma-models-in-keras-using-lora)
* [Parameter-efficient fine-tuning of GPT-2 with LoRA](https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/)
* [Gemma - KerasNLP](https://keras.io/api/keras_nlp/models/gemma/)