In [None]:
# %%capture
# If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed
!pip install --no-index /kaggle/input/download-pacakages-for-llm/unsloth/peft-0.9.0-py3-none-any.whl --find-links=/kaggle/input/download-pacakages-for-llm/unsloth
!pip install --no-index /kaggle/input/making-wheels-of-necessary-packages-for-hf-llms/bitsandbytes-0.42.0-py3-none-any.whl --find-links=/kaggle/input/making-wheels-of-necessary-packages-for-hf-llms
!pip install --no-index /kaggle/input/making-wheels-of-necessary-packages-for-hf-llms/accelerate-0.27.2-py3-none-any.whl --find-links=/kaggle/input/making-wheels-of-necessary-packages-for-hf-llms
!pip install --no-index /kaggle/input/download-pacakages-for-llm/unsloth/transformers-4.38.2-py3-none-any.whl --find-links=/kaggle/input/download-pacakages-for-llm/unsloth
!pip install --no-index /kaggle/input/making-wheels-of-necessary-packages-for-hf-llms/optimum-1.17.1-py3-none-any.whl --find-links=/kaggle/input/making-wheels-of-necessary-packages-for-hf-llms

In [None]:
from accelerate.utils import BnbQuantizationConfig
from accelerate import Accelerator
import transformers
import optimum
import bitsandbytes

In [None]:
import datetime
# Record when the notebook started; the inference loop compares the elapsed
# time against this value to decide when to stop generating.
start_time = datetime.datetime.now()

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, StoppingCriteria

# Accelerator instance; used below to prepare the model and later to look up the device.
accelerator = Accelerator()

# Base model selection: leave exactly one MODEL_PATH assignment uncommented.

MODEL_PATH = "/kaggle/input/gemma-7b-instruction"
# MODEL_PATH = "/kaggle/input/gemma/transformers/2b-it/2"
# MODEL_PATH = "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"
# MODEL_PATH = "/kaggle/input/mixtral/pytorch/8x7b-instruct-v0.1-hf/1"
# MODEL_PATH = "/kaggle/input/llama-2/pytorch/7b-chat-hf/1"
# MODEL_PATH = "/kaggle/input/llama-2/pytorch/13b-chat-hf/1"

# 4-bit quantization settings (NF4 quant type, bfloat16 compute dtype, double quantization).
# Background reading on these options:
# https://huggingface.co/blog/4bit-transformers-bitsandbytes
# https://huggingface.co/docs/transformers/v4.38.1/en/quantization#compute-data-type
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Tokenizer is loaded from the same directory as the model weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


# Load the causal LM with the quantization config above.
# NOTE(review): trust_remote_code=True permits running custom code shipped with
# the checkpoint -- confirm it is actually required for these local weights.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=quantization_config,
)

# model = model.to_bettertransformer()
# NOTE(review): the model was already placed via device_map="auto"; confirm that
# passing it through accelerator.prepare() is still needed for inference.
model = accelerator.prepare(model)

In [None]:
from peft import PeftModel
# Directory holding the fine-tuned adapter checkpoint to attach to the base model.
LORA_CHECKPOINT = "/kaggle/input/cot-1k-v0/checkpoint-200"

# Wrap the already-loaded base model with the adapter weights from the checkpoint.
lora_model = PeftModel.from_pretrained(model, LORA_CHECKPOINT)

In [None]:
import pandas as pd
from tqdm import tqdm

# Flip to True to run against the local debug CSV instead of the competition files.
DEBUG = False

# Row limit handed to pd.read_csv; None reads every row (same value in both modes).
NROWS = None

if DEBUG:
    # Debug mode: one CSV serves as both the input data and the id source.
    TEST_DF_FILE = "/kaggle/input/gemma-rewrite-nbroad/nbroad-v1.csv"
    SUB_DF_FILE = TEST_DF_FILE
else:
    TEST_DF_FILE = "/kaggle/input/llm-prompt-recovery/test.csv"
    SUB_DF_FILE = "/kaggle/input/llm-prompt-recovery/sample_submission.csv"

# Only the columns used later in the notebook are read.
TEST_COLUMNS = ["id", "original_text", "rewritten_text"]
SUB_COLUMNS = ["id", "rewrite_prompt"]

tdf = pd.read_csv(TEST_DF_FILE, nrows=NROWS, usecols=TEST_COLUMNS)
sub = pd.read_csv(SUB_DF_FILE, nrows=NROWS, usecols=SUB_COLUMNS)

In [None]:
# examples = pd.read_csv("/kaggle/input/gemini-dataset-3-8-k/gemini_dataset_v0.csv")
def truncate_txt(text, length):
    """Limit `text` to its first `length` whitespace-separated words.

    :param text: string to shorten
    :param length: maximum number of words to keep
    :return: `text` unchanged when it has at most `length` words; otherwise the
        first `length` words joined by single spaces (original spacing is not
        preserved in that case)
    """
    words = text.split()

    if len(words) > length:
        return " ".join(words[:length])

    return text

# Prompt template for one user turn followed by the opening of the model turn.
# The two `{}` placeholders are filled, in order, with the original text and the
# rewritten text via str.format().
# NOTE(review): the template spells the answer label "**Inferred Promp**" while the
# inference loop splits the output on "**Inferred Prompt**:". Presumably the wording
# must match what the LoRA checkpoint was trained with -- confirm before changing it.
USER_CHAT_TEMPLATE = """<start_of_turn>user\nYou'll be given an original text and a rewritten text generated by an LLM. 
Analyze the changes in the rewritten version and infer the likely prompt that led to those changes. 
Provide a detailed explanation of how you arrived at your inference step by step.

**Original Text**:
{}

**Rewritten Text**
{}

You should response in the following format:
**Inferred Promp**: ...

**Chain of Thoughts**: ...<end_of_turn>\n<start_of_turn>model\n"""


# USER_CHAT_TEMPLATE = """<start_of_turn>user\nYou'll be given an original text and a rewritten text generated by an LLM. 
# Analyze the stylistic changes in the rewritten version and identify the likely prompt that led to those changes. 
# Notice only output the prompt.
# Here's what to focus on:
# Shifts in Style: Look for changes in:
# -Genre: (sci-fi, fantasy, historical fiction, etc.)
# -Tone: (serious, humorous, conversational, etc.)
# -Vocabulary: (formal vs. informal, technical vs. simple)
# -Sentence Structure: (short and direct vs. flowing and complex)
# Literary References: Consider if the rewritten style echoes a specific author or literary period (e.g., Shakespearean, Hemingway-esque).
# Here is an example.
# **Example Original Text**:
# {}

# **Example Rewritten Text**:
# {}

# **Example Output**:
# {}

# **Original Text**:
# {}

# **Rewritten Text**
# {}
# <end_of_turn>\n<start_of_turn>model\n"""

In [None]:
# NOTE(review): stop_words_ids is not referenced anywhere else in the visible notebook.
stop_words_ids = 2
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList, STOPPING_CRITERIA_INPUTS_DOCSTRING, add_start_docstrings

# Container for the custom stopping rule; it is later passed to generate().
stopping_criteria = StoppingCriteriaList()
class StopAtSpecificTokenCriteria(StoppingCriteria):
    """
    Stop generation as soon as one of the specified tokens has been generated.
    ---------------
    ver: 2023-08-02
    by: changhongyu
    """
    def __init__(self, token_id_list):
        """
        :param token_id_list: list of token ids that should end generation
        """
        self.token_id_list = token_id_list

    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Alternative based on scores, kept for reference:
        # return np.argmax(scores[-1].detach().cpu().numpy()) in self.token_id_list
        # Storing scores takes extra resources, so the check uses input_ids directly.
        # Only the last token of the first sequence (index 0) is inspected.
        newest_token = input_ids[0][-1].detach().cpu().numpy()
        return newest_token in self.token_id_list
# Register the rule so generation halts once token id 108 or 109 is emitted.
# NOTE(review): presumably these are the tokenizer's newline ids -- verify with
# tokenizer.convert_ids_to_tokens([108, 109]) for the selected MODEL_PATH.
stopping_criteria.append(StopAtSpecificTokenCriteria(token_id_list=[108,109]))

In [None]:
import gc
import re
import random

device = accelerator.device
# Take the ids from the submission frame so the output rows carry those ids.
tdf["id"] = sub["id"].copy()

# Fallback answer used when the time budget is exhausted or a row fails.
DEFAULT_TEXT = "Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."

# Wall-clock budget measured from `start_time`; once exceeded, the remaining rows
# receive DEFAULT_TEXT without running the model.
TIME_BUDGET = datetime.timedelta(hours=8, minutes=30)

# Label that precedes the answer in the model output.
# NOTE(review): USER_CHAT_TEMPLATE spells this label "**Inferred Promp**" -- confirm
# that the fine-tuned model really emits the spelling used here.
ANSWER_MARKER = "**Inferred Prompt**:"

# Word limit applied to each input text and cap on generated tokens.
MAX_INPUT_WORDS = 400
MAX_NEW_TOKENS = 50

# One [id, rewrite_prompt] pair per row of tdf, in row order.
res = []

pbar = tqdm(tdf.iterrows(), total=tdf.shape[0])
for idx, row in pbar:

    if (datetime.datetime.now() - start_time) > TIME_BUDGET:
        res.append([row["id"], DEFAULT_TEXT])
        continue

    # Release cached GPU memory and collect garbage before each generation.
    torch.cuda.empty_cache()
    gc.collect()

    try:
        prompt = USER_CHAT_TEMPLATE.format(
            truncate_txt(row["original_text"], MAX_INPUT_WORDS),
            truncate_txt(row["rewritten_text"], MAX_INPUT_WORDS),
        )
        # Use the accelerator's device instead of a hard-coded "cuda" string.
        prompt_tokenized = tokenizer(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            output_tokenized = lora_model.generate(
                **prompt_tokenized,
                max_new_tokens=MAX_NEW_TOKENS,
                use_cache=True,
                stopping_criteria=stopping_criteria,
            )[0]

        # Drop the echoed prompt so only newly generated tokens remain.
        prompt_length = prompt_tokenized["input_ids"].shape[1]
        output_tokenized = output_tokenized[prompt_length:]

        # skip_special_tokens keeps markers such as an end-of-sequence token out of
        # the answer when generation ends without reaching a newline.
        decoded_output = tokenizer.decode(output_tokenized, skip_special_tokens=True)

        # Keep only the first line that follows the answer label.
        out = decoded_output.split(ANSWER_MARKER)[1].split("\n")[0].strip()
        if not out:
            raise ValueError("model produced an empty inferred prompt")

        # Finish the sentence with a period when it ends in a lowercase letter.
        if "a" <= out[-1] <= "z":
            out = out + "."
        res.append([row["id"], out])

    except Exception as e:
        # Deliberate best-effort: any failure on a row falls back to the default
        # answer so that every id still receives a prediction.
        print(f"ERROR: {e}")
        res.append([row["id"], DEFAULT_TEXT])

pbar.close()

In [None]:
res

In [None]:
# Turn the collected [id, rewrite_prompt] pairs into the submission frame.
SUBMISSION_COLUMNS = ["id", "rewrite_prompt"]
sub = pd.DataFrame(res, columns=SUBMISSION_COLUMNS)

# Write the same frame under both file names, without the index column.
for submission_name in ("sample_submission.csv", "submission.csv"):
    sub.to_csv(submission_name, index=False)

In [None]:
res