# LLMPR - Gemma prompt recovery

As a baseline, let the LLM itself recover the prompt.

I have already seen such a baseline in the titles of other notebooks, but I have not yet looked at what prompt they use... so let's see. The model is wrapped in langchain for easier use.

In [2]:

import os
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Init Gemma

In [5]:
# Location of the Gemma 7b-it checkpoint; list its contents once and report
# both the file names and how many there are.
model_path = "/kaggle/input/gemma/transformers/7b-it/2"
model_files = os.listdir(model_path)
print(sorted(model_files))
print(len(model_files))

['.gitattributes', 'config.json', 'generation_config.json', 'model-00003-of-00004.safetensors', 'model-00004-of-00004.safetensors', 'model.safetensors.index.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer.model', 'tokenizer_config.json']
10


In [4]:
# Fail early when the checkpoint directory is incomplete. 13 is the number of
# entries this notebook expects in `model_path`; presumably a smaller count means
# the dataset is only partly available (e.g. missing weight shards) -- TODO confirm.
assert len(os.listdir(model_path)) == 13, "not all models files are present"

AssertionError: not all models files are present

In [31]:
!pip install --no-index --find-links /kaggle/input/llmpr-packages transformers==4.39.0.dev0 accelerate bitsandbytes langchain

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Looking in links: /kaggle/input/llmpr-packages


In [32]:
import transformers
print(transformers.__version__)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

4.39.0.dev0


In [33]:
# Tokenizer and model both come from the same local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load the weights 4-bit quantized through bitsandbytes; every other
# quantization option keeps its library default.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=quantization_config)

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [34]:
# Show the effective quantization settings, including the values that were not set explicitly.
quantization_config

BitsAndBytesConfig {
  "_load_in_4bit": true,
  "_load_in_8bit": false,
  "bnb_4bit_compute_dtype": "float32",
  "bnb_4bit_quant_type": "fp4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

In [35]:
# Show the tokenizer configuration (special tokens, padding side, ...).
tokenizer

GemmaTokenizerFast(name_or_path='/kaggle/input/gemma/transformers/7b-it/2', vocab_size=256000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<eos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("<bos>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	106: AddedToken("<start_of_turn>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	107: AddedToken("<end_of_turn

In [36]:
# Smoke test of the raw model without any wrapper: tokenize, generate, decode.
# Special tokens are kept in the decoded text on purpose (no skip_special_tokens).
smoke_prompt = "Write a haiku about lions."
encoded = tokenizer(smoke_prompt, return_tensors="pt").to("cuda")
generated = model.generate(**encoded, max_new_tokens=256)
print(tokenizer.decode(generated[0]))

<bos>Write a haiku about lions.

Golden mane ablaze,
Roaming through the savanna,
King's grace in stride.<eos>


# langchain
- See https://python.langchain.com/docs/modules/model_io/llms/custom_llm
- Effectively, only the ``_call`` method is required.

In [40]:
import langchain
print(langchain.__version__)
# The tokenizers library looks at the TOKENIZERS_PARALLELISM *environment
# variable* (see its fork warning); a plain Python variable has no effect.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

0.1.11


In [41]:
from typing import Any, List, Mapping, Optional
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM

In [43]:
class GemmaLLM(LLM):
    """LangChain LLM wrapper around the notebook-level Gemma ``model`` and ``tokenizer``.

    Generation settings are kept in ``cfg``, a generation config that is loaded
    afresh from ``model_path`` for every instance and then overridden with the
    constructor arguments.

    Note: ``_call`` returns the decoded *whole* sequence, i.e. the prompt followed
    by the model's continuation. Callers that only want the continuation have to
    cut the prompt off themselves.
    """

    # Generation config handed to ``model.generate``; filled in by ``__init__``.
    cfg: Any

    def __init__(self, max_new_tokens=256, **kwargs):
        """Create the wrapper and its generation config.

        Args:
            max_new_tokens: Upper bound on the number of generated tokens.
            **kwargs: Forwarded to the ``LLM`` base class. In addition, every
                keyword that names an existing attribute of the generation
                config (e.g. ``do_sample``, ``temperature``) is set on it.
        """
        super().__init__(**kwargs)
        # Reload the config from disk so earlier overrides do not leak into this instance.
        cfg = model.generation_config.from_pretrained(model_path)
        cfg.max_new_tokens = max_new_tokens
        for key, value in kwargs.items():
            if hasattr(cfg, key):
                setattr(cfg, key, value)
        self.cfg = cfg

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM implementation."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Generate text for ``prompt`` and return the decoded sequence.

        ``stop``, ``run_manager`` and ``**kwargs`` are accepted for interface
        compatibility but are not used.

        Returns:
            The prompt followed by the generated continuation, with special
            tokens removed.
        """
        # Move the inputs to the device the model actually lives on instead of
        # assuming a device named "cuda".
        encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
        generated = model.generate(**encoded, generation_config=self.cfg)
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Generation settings reported by ``to_diff_dict``, minus bookkeeping keys."""
        remove_keys = ["pad_token_id", "bos_token_id", "eos_token_id", '_from_model_config', "transformers_version"]
        return {k: v for k, v in self.cfg.to_diff_dict().items() if k not in remove_keys}

    def __repr__(self):
        """Compact representation listing the identifying generation settings."""
        params = ", ".join(f"{k}={v}" for k, v in self._identifying_params.items())
        return f"Gemma({params})"

In [44]:
# Sampling-based generation. Keywords that match an attribute of the generation
# config are copied onto it by GemmaLLM.__init__. The repr only lists what
# `to_diff_dict` reports, so a setting may be absent from it -- presumably
# because it equals the default.
llm = GemmaLLM(temperature=1.0, do_sample=True)
llm

Gemma(max_new_tokens=256, do_sample=True)

In [45]:
# Quick check of the wrapper. The printed text starts with the prompt itself,
# because GemmaLLM._call decodes the whole generated sequence and not only the
# continuation.
prompt = "Write a haiku about the Alps."
out = llm.invoke(prompt)
print(out)

Write a haiku about the Alps.

Snowcapped peaks reach high,
Granite walls soar to the sky,
Winter's icy grip.


# Ask for the rewrite prompt

In [46]:
# Competition test data with the columns id, original_text and rewritten_text.
data = pd.read_csv("/kaggle/input/llm-prompt-recovery/test.csv")
data

Unnamed: 0,id,original_text,rewritten_text
0,-1,The competition dataset comprises text passage...,Here is your shanty: (Verse 1) The text is rew...


In [47]:
# First row as a plain dict, used below as the example input for the template and the chain.
rec = data.to_dict(orient="records")[0]
rec

{'id': -1,
 'original_text': 'The competition dataset comprises text passages that have been rewritten by the Gemma LLM according to some rewrite_prompt instruction. The goal of the competition is to determine what prompt was used to rewrite each original text.  Please note that this is a Code Competition. When your submission is scored, this example test data will be replaced with the full test set. Expect roughly 2,000 original texts in the test set.',
 'rewritten_text': "Here is your shanty: (Verse 1) The text is rewritten, the LLM has spun, With prompts so clever, they've been outrun. The goal is to find, the prompt so bright, To crack the code, and shine the light. (Chorus) Oh, this is a code competition, my dear, With text and prompts, we'll compete. Two thousand texts, a challenge grand, To guess the prompts, hand over hand.(Verse 2) The original text, a treasure lost, The rewrite prompt, a secret to be"}

In [48]:
from langchain.prompts import PromptTemplate

In [49]:
# Instruction prompt for the recovery task, written line by line. The empty
# first and last entries reproduce the leading and trailing newline of the
# template; the trailing blank after "long." is part of the text.
# "Rewrite prompt:::" is the marker after which the model's answer is expected.
_template_lines = (
    "",
    "You are an assistant to generate prompt templates for LLM.",
    "Here is an original text and a rewritten text.",
    "Think about a short and concise rewrite prompt that would make a LLM together with the original text provide the rewritten text.",
    "This prompt is at most 3 sentences long. ",
    "Only return this rewrite prompt.",
    "",
    "Original text:::",
    "{original_text}",
    "",
    "Rewritten text:::",
    "{rewritten_text}",
    "",
    "Rewrite prompt:::",
    "",
)
temp = PromptTemplate.from_template("\n".join(_template_lines))
temp

PromptTemplate(input_variables=['original_text', 'rewritten_text'], template='\nYou are an assistant to generate prompt templates for LLM.\nHere is an original text and a rewritten text.\nThink about a short and concise rewrite prompt that would make a LLM together with the original text provide the rewritten text.\nThis prompt is at most 3 sentences long. \nOnly return this rewrite prompt.\n\nOriginal text:::\n{original_text}\n\nRewritten text:::\n{rewritten_text}\n\nRewrite prompt:::\n')

In [50]:
# Render the template with the example record to inspect the exact text sent to the model.
print(temp.format(**rec))


You are an assistant to generate prompt templates for LLM.
Here is an original text and a rewritten text.
Think about a short and concise rewrite prompt that would make a LLM together with the original text provide the rewritten text.
This prompt is at most 3 sentences long. 
Only return this rewrite prompt.

Original text:::
The competition dataset comprises text passages that have been rewritten by the Gemma LLM according to some rewrite_prompt instruction. The goal of the competition is to determine what prompt was used to rewrite each original text.  Please note that this is a Code Competition. When your submission is scored, this example test data will be replaced with the full test set. Expect roughly 2,000 original texts in the test set.

Rewritten text:::
Here is your shanty: (Verse 1) The text is rewritten, the LLM has spun, With prompts so clever, they've been outrun. The goal is to find, the prompt so bright, To crack the code, and shine the light. (Chorus) Oh, this is a co

In [51]:
# Pipe the prompt template into the LLM: the chain takes a record dict and returns the LLM's text.
chain = temp | llm

In [52]:
# Run the chain on the example record. The result contains the filled-in prompt
# followed by the model's answer, since the wrapper returns the whole sequence.
out = chain.invoke(rec)
print(out)


You are an assistant to generate prompt templates for LLM.
Here is an original text and a rewritten text.
Think about a short and concise rewrite prompt that would make a LLM together with the original text provide the rewritten text.
This prompt is at most 3 sentences long. 
Only return this rewrite prompt.

Original text:::
The competition dataset comprises text passages that have been rewritten by the Gemma LLM according to some rewrite_prompt instruction. The goal of the competition is to determine what prompt was used to rewrite each original text.  Please note that this is a Code Competition. When your submission is scored, this example test data will be replaced with the full test set. Expect roughly 2,000 original texts in the test set.

Rewritten text:::
Here is your shanty: (Verse 1) The text is rewritten, the LLM has spun, With prompts so clever, they've been outrun. The goal is to find, the prompt so bright, To crack the code, and shine the light. (Chorus) Oh, this is a co

In [53]:
# Keep only the text after the last "Rewrite prompt:::" marker, i.e. the model's answer.
out.split("Rewrite prompt:::")[-1].strip()

'Sure, here is the prompt:\n\n**Rewrite the original text provided below into a creative and engaging story.**\n\nPlease provide the original text below:\n\n(Original text)\n\nThe competition dataset comprises text passages that have been rewritten by the Gemma LLM according to some rewrite_prompt instruction. The goal of the competition is to determine what prompt was used to rewrite each original text.\n\n\n**Note:** This is a code competition, so please expect the test data to be replaced with the full test set once your submission is scored.'

# Invoke

In [54]:
# Sample submission, shown to check the expected columns (id, rewrite_prompt).
sub = pd.read_csv("/kaggle/input/llm-prompt-recovery/sample_submission.csv")
sub

Unnamed: 0,id,rewrite_prompt
0,9559194,Improve that text.


In [55]:
# Fallback prediction for records where the model's answer cannot be isolated
# from the generated text; same wording as the sample submission.
default_rewrite = "Improve that text."

In [56]:
def _extract_rewrite_prompt(generated, fallback, marker="Rewrite prompt:::"):
    """Return the model's answer from ``generated``, or ``fallback``.

    ``generated`` is expected to be the filled-in prompt followed by the answer,
    so the marker should occur exactly once (at the end of the prompt).

    Args:
        generated: Full text returned by the chain.
        fallback: Value to return when no usable answer can be extracted.
        marker: Text that separates the prompt from the answer.

    Returns:
        The stripped text after the marker. ``fallback`` is returned when the
        marker does not occur exactly once, or when the answer is empty after
        stripping (a blank prediction would otherwise end up in the submission).
    """
    parts = generated.split(marker)
    if len(parts) != 2:
        return fallback
    return parts[-1].strip() or fallback


# One prediction per test record. `row` is used as the loop variable so the
# example record `rec` inspected earlier is not overwritten.
results = []
for row in data.to_dict(orient="records"):
    results.append({
        "id": row["id"],
        "rewrite_prompt": _extract_rewrite_prompt(chain.invoke(row), default_rewrite),
    })

In [57]:
# Build the submission frame from the collected records (columns: id, rewrite_prompt).
sub = pd.DataFrame(results)
sub

Unnamed: 0,id,rewrite_prompt
0,-1,"Sure, here is the prompt:\n\n**Rewrite the abo..."


In [58]:
# Write the predictions to submission.csv in the working directory, without the index column.
sub.to_csv("submission.csv", index=False)