# Environment

In [32]:
# https://www.kaggle.com/code/hotchpotch/llm-detect-pip 
# Offline install: `--no-index` stops pip from contacting PyPI and `--find-links`
# points it at the wheels bundled in the llm-detect-pip input dataset, so the
# notebook can run without internet access.
# NOTE(review): the installed versions are whatever wheels that dataset ships —
# consider pinning them explicitly for reproducibility.
!pip install -q -U accelerate --no-index --find-links ../input/llm-detect-pip/
!pip install -q -U bitsandbytes --no-index --find-links ../input/llm-detect-pip/
!pip install -q -U transformers --no-index --find-links ../input/llm-detect-pip/

In [33]:
import torch

# Disable the flash and memory-efficient scaled-dot-product-attention kernels,
# so attention has to use one of the remaining backends.
# NOTE(review): presumably a workaround for a GPU/kernel incompatibility — confirm.
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [61]:
# Runtime mode switch:
#   SUBMISSION -> predict on the competition test set and write submission.csv
#   DEBUG      -> sample a labelled set and print/display diagnostics
from enum import Enum

Mode = Enum("Mode", [("SUBMISSION", 0), ("DEBUG", 1)])
MODE = Mode.SUBMISSION

# Data Preparation

In [62]:
import numpy as np
import pandas as pd

match MODE:
    case Mode.DEBUG:   
        train = pd.read_csv("/kaggle/input/gemma-rewrite-nbroad/nbroad-v1.csv")
        gross_test = pd.read_csv("/kaggle/input/gemma-rewrite-nbroad/nbroad-v2.csv")
        indexes = np.random.randint(0, gross_test.shape[1], 20)
        gross_test = gross_test.iloc[indexes]
        test = gross_test[['id', 'original_text', 'rewritten_text']]
        target = gross_test[['id', 'rewrite_prompt']]
    case Mode.SUBMISSION:
        train = pd.read_csv("/kaggle/input/llm-prompt-recovery/train.csv")
        test = pd.read_csv("/kaggle/input/llm-prompt-recovery/test.csv")
    case _:
        print("Mode error")
        raise Exception("Mode Error : Please choose a mode")

In [63]:
from IPython.display import display

# Debug-only sanity check: peek at the first rows of the test frame.
if MODE is Mode.DEBUG:
    display(test.head())

# Build instruction

In [None]:
# # Template - chat format version 1
# def template(Ot, Rt, Rp=None):
#     return [
#         {"role": "user", "content": f"Original text:\n{Ot}"},
#         {"role": "assistant", "content": "Provide the rewritten test"}, # new text and I will tell you the prompt that can generate the original text to the new text."},
#         {"role": "user", "content": f"Rewritten text:\n{Rt}"},
#         {"role": "assistant", "content": f"Give the instruction"},
#         {"role": "user", "content": f"Give the prompt that can generate the original text to the rewritten text"},
#     ] + ([{"role": "assistant", "content": Rp}] if Rp else [])

In [64]:
# Template - chat format version 2
def template(Ot, Rt, Rp=None):
    """Wrap one (original, rewritten[, prompt]) example as chat messages.

    Args:
        Ot: original text.
        Rt: rewritten text.
        Rp: the rewrite prompt. When truthy it is appended as the assistant's
            reply, which turns the example into a few-shot demonstration.

    Returns:
        A list of ``{"role", "content"}`` dicts: one user turn, plus an
        assistant turn when ``Rp`` is truthy.
    """
    user_content = (
        f"<original_text>\n{Ot}\n</original_text>\n"
        f"<rewritten_text>\n{Rt}\n</rewritten_text>\n"
        "Write a prompt that was likely given to the LLM to rewrite original_text into rewritten_text."
    )
    conversation = [{"role": "user", "content": user_content}]
    if Rp:
        conversation.append({"role": "assistant", "content": Rp})
    return conversation

In [65]:
import numpy as np
def make_instruction(sentence, num_example=5, train_df=train):
    """Build a few-shot chat for one row to be completed by the model.

    Args:
        sentence: a row with ``original_text`` and ``rewritten_text`` entries.
        num_example: number of distinct training examples to prepend. It is
            capped at the number of rows in ``train_df``.
        train_df: frame with ``original_text``, ``rewritten_text`` and
            ``rewrite_prompt`` columns. It defaults to the ``train`` frame that
            exists when this cell runs.

    Returns:
        A list of chat messages: ``num_example`` complete user/assistant
        examples followed by the unanswered user turn for ``sentence``.
    """
    # Draw distinct row positions so no example is repeated in the prompt
    # (repeats waste context). The cap keeps small frames from raising.
    n_rows = train_df.shape[0]
    indexes = np.random.choice(n_rows, size=min(num_example, n_rows), replace=False)
    instruction = []
    for _, item in train_df.iloc[indexes].iterrows():
        instruction += template(item['original_text'], item['rewritten_text'], item['rewrite_prompt'])
    # then give the test instruction (no answer, so the model completes it)
    return instruction + template(sentence['original_text'], sentence['rewritten_text'])

In [66]:
# test
if MODE == Mode.DEBUG:
    display(make_instruction(test.iloc[0]))

# Load Model

In [41]:
# model_name  = "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"
# Local Kaggle copy of (presumably) Mistral-7B-Instruct v0.2.
# Alternative checkpoint: "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"
model_name = "/kaggle/input/mistral-7b-it-v02"

In [42]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 weights with nested (double) quantization; compute runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# device_map="auto" lets accelerate place the layers on the available devices.
# NOTE(review): trust_remote_code=True executes any custom code in the model
# directory — confirm it is needed for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed.


In [43]:
# load tokenizer
# The tokenizer is loaded from the same local checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

# Prediction

In [44]:
# import torch
# def get_inputs_batch(sub_df, tokenizer=tokenizer):
#     tokenizer.pad_token = tokenizer.eos_token
#     return [tokenizer.apply_chat_template(make_instruction(line), return_tensors="pt") for line in sub_df.iloc]

In [45]:
# predict batch
# def predict(df, model=model, tokenizer=tokenizer, batch_size=None, device=device):
# #     if batch_size:
# #         index_loader = DataLoader([i for i in range(df.size)], batch_size=batch_size)
# #         for indexes in tqdm(index_loader):
# #             batch_df = df.iloc[indexes]
#     prompts = []
#     for line in tqdm(df.iloc):   
#         messages = make_instruction(line)
#         tokenizer.pad_token = tokenizer.eos_token
#         inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
#         inputs = inputs.to(device)
#         with torch.no_grad():
#             outputs = model.generate(inputs, max_new_tokens=500, pad_token_id=tokenizer.eos_token_id).cpu()
#         answer = tokenizer.batch_decode(outputs)[0]
#         try:
#             result = answer.split("[/INST]")[-1].replace("</s>","").strip()
#         except:
#             print("<START>\nThis context is SB.\n"+answer+"\n<END>")
# #             raise Exception("Sorry, this sentence error")
#             result = answer
#         prompts.append(result)
#     return prompts


In [70]:
def predict(line, model=model, tokenizer=tokenizer, device=device, num_example=5, max_new_tokens=30):
    """Generate a guessed rewrite prompt for a single row.

    Args:
        line: row with ``original_text`` and ``rewritten_text``.
        model, tokenizer, device: generation objects (defaults are the ones
            loaded earlier in the notebook).
        num_example: number of few-shot training examples in the chat.
        max_new_tokens: generation budget for the answer.

    Returns:
        The decoded continuation (special tokens removed, whitespace stripped).
        In DEBUG mode the input-id shape is printed first.
    """
    messages = make_instruction(line, num_example=num_example)
    # Reuse EOS as the pad token (presumably the tokenizer defines none).
    tokenizer.pad_token = tokenizer.eos_token
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt")
    if MODE == Mode.DEBUG:
        print(inputs.shape)
    inputs = inputs.to(device)
    # One unpadded sequence: every position is real, so attend to all of them
    # (passing it explicitly also silences generate()'s missing-mask warning).
    attention_mask = torch.ones_like(inputs)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        ).cpu()
    # generate() returns prompt + continuation. Decode only the newly generated
    # tokens, so the answer never depends on splitting the echoed prompt on
    # "[/INST]" or on stripping "</s>" by hand.
    new_tokens = outputs[0, inputs.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

In [71]:
# test one
# Debug-only smoke test: run the full pipeline on one sampled row.
if MODE is Mode.DEBUG:
    prompts = predict(line=test.iloc[12])
    print(prompts)

In [72]:
# apply predict function to each line
from tqdm import tqdm

# Number of random training examples shown to the model before each test row.
NUM_FEW_SHOT = 10

# Register .progress_apply on pandas objects so the loop reports progress.
tqdm.pandas()
# .assign builds a new frame instead of assigning into one that may be a slice
# of another DataFrame (the DEBUG branch), which would warn with
# SettingWithCopyWarning.
test = test.assign(
    rewrite_prompt=test.progress_apply(
        lambda s: predict(s, model, tokenizer, num_example=NUM_FEW_SHOT), axis=1
    )
)

100%|██████████| 1/1 [00:14<00:00, 14.79s/it]


# Evaluation

In [73]:
# import swifter
# tic = time.time()
# test['rewrite_prompt'] = test.swifter.apply(lambda s : predict(s, model, tokenizer), axis=1)
# toc = time.time()
# print((toc-tic)/60)

In [74]:
# preview
# Debug-only: show the first predictions alongside their inputs.
if MODE is Mode.DEBUG:
    display(test.head(n=5))

In [75]:
# TODO: compute a local score
# 1. Load sentence-t5-base to embed the predicted and target prompts.
# 2. Define the sharpened cosine similarity between the embeddings.
# 3. Apply it to the predictions.

In [76]:
# save the result as csv file 
# Submission mode only: write one row per test id with its predicted prompt.
if MODE is Mode.SUBMISSION:
    test.loc[:, ['id', 'rewrite_prompt']].to_csv('submission.csv', index=False, header=True)

# Verification

In [77]:
# Read the written submission file back and preview its first rows.
if MODE is Mode.SUBMISSION:
    written = pd.read_csv("/kaggle/working/submission.csv")
    display(written.head())

Unnamed: 0,id,rewrite_prompt
0,-1,"Convert this into a sea shanty: """"""The competi..."
