In [1]:
!pip install -U /kaggle/input/bitsandbytes-0-42-0-py3-none-any-whl/bitsandbytes-0.42.0-py3-none-any.whl -qq

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig


# MODEL_PATH = "/kaggle/input/mixtral/pytorch/8x7b-instruct-v0.1-hf/1"
MODEL_PATH = "/kaggle/input/mixtral/pytorch/8x7b-instruct-v0.1-hf/1"

quantization_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# To prevent GPU memory overflow in Mixtral8x7b
config = AutoConfig.from_pretrained(MODEL_PATH)
config.gradient_checkpointing = True


tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map = "auto",
    trust_remote_code = True,
    quantization_config=quantization_config,
    config=config
)

Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

In [3]:
import pandas as pd
from tqdm import tqdm

tdf = pd.read_csv('/kaggle/input/llm-prompt-recovery/test.csv')
display(tdf)

Unnamed: 0,id,original_text,rewritten_text
0,-1,The competition dataset comprises text passage...,Here is your shanty: (Verse 1) The text is rew...


In [4]:
def truncate_txt(text, length):
    text_list = text.split()
    if len(text_list) <= length:
        return text    
    return " ".join(text_list[:length])


def gen_prompt_sample(og_text, rewritten_text):
    og_text = truncate_txt(og_text, 256)
    rewritten_text = truncate_txt(rewritten_text, 256)
    
    return f"""
    Original Essay:
    \"""{og_text}\"""

    Rewritten Essay:
    \"""{rewritten_text}\"""

    Given are 2 essays, the Rewritten essay was created from the Original essay using the google Gemma model.
    You are trying to understand how the original essay was transformed into a new version.
    Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten essay.
    Keep your output concise, to the point(only the prompt), and less than a 100 words.
    """


SAMPLE_OUTPUT = """Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."""

SAMPLE_OUTPUT_1 = """Please improve this text using the writing style with maintaining the original meaning but altering the tone."""

SAMPLE_OUTPUT_2 = """Refine the following passage by emulating the writing style of, with a focus on enhancing its clarity, elegance, and overall impact. Preserve the essence and original meaning of the text, while meticulously adjusting its tone, vocabulary, and stylistic elements to resonate with the chosen style.Please improve the following text using the writing style of, maintaining the original meaning but altering the tone, diction, and stylistic elements to match the new style.Enhance the clarity, elegance, and impact of the following text by adopting the writing style of , ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."""

SAMPLE_OUTPUT_3 = """Please improve this text using the writing style with maintaining the original meaning but altering the tone, ensuring the core message remains intact while transforming the tone, word choice, and stylistic features to align with the specified style."""


def gen_prompt(og_text, rewritten_text):
    
    # Truncate the texts to first 200 words for now
    # As we are having memory issues on Mixtral8x7b
    og_text = truncate_txt(og_text, 256)
    rewritten_text = truncate_txt(rewritten_text, 256)
    
    return f"""    
    Original Essay:
    \"""{og_text}\"""
    
    Rewritten Essay:
    \"""{rewritten_text}\"""
    
    Given are 2 essays, the Rewritten essay was created from the Original essay using the google Gemma model.
    You are trying to understand how the original essay was transformed into a new version.
    Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten essay.
    Keep your output concise, to the point(only the prompt), and less than a 100 words.
    Make sure that the generated prompt is in the format of "Please improve this text by [adding a magician].".
    
    Sample output:
    \"""{SAMPLE_OUTPUT}\"""
    """

def gen_prompt_1(og_text, rewritten_text):
    
    # Truncate the texts to first 200 words for now
    # As we are having memory issues on Mixtral8x7b
    og_text = truncate_txt(og_text, 256)
    rewritten_text = truncate_txt(rewritten_text, 256)
    
    return f"""    
    Original Essay:
    \"""{og_text}\"""
    
    Rewritten Essay:
    \"""{rewritten_text}\"""
    
    Given are 2 essays, the Rewritten essay was created from the Original essay using the google Gemma model.
    You are trying to understand how the original essay was transformed into a new version.
    Analyzing the changes in style, theme, etc., please come up with a prompt that must have been used to guide the transformation from the original to the rewritten essay.
    Keep your output concise, to the point(only the prompt), and less than a 100 words.

    Sample output:
    {SAMPLE_OUTPUT_1}"""

In [5]:
import gc

device = 'cuda'
sub = pd.read_csv('/kaggle/input/llm-prompt-recovery/sample_submission.csv')
sub.head()

Unnamed: 0,id,rewrite_prompt
0,9559194,Improve that text.


In [6]:
sub["rewrite_prompt"] = str(SAMPLE_OUTPUT_1)

In [7]:
# tdf.loc[0,'id'] = 9559194
# tdf.head()

In [8]:
for row in tqdm(tdf.itertuples()):
    try:
    
        query_prompt = gen_prompt(row[2], row[3])
#         query_prompt = gen_prompt_sample(row[2], row[3])
        
        messages = [
            {
                "role": "user",
                "content": query_prompt
            }
        ]
#         messages = [
#             {
#                 "role": "user",
#                 "content": query_prompt
#             },
#             {
#                 "role": "assistant",
#                 "content": f"""Sample prompt:
#                 \"""{SAMPLE_OUTPUT}\"""
#                 """
#             }
#         ]

#         encoded_input = tokenizer(query_prompt, return_tensors="pt").to(device)
        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")

        with torch.no_grad():
            encoded_output = model.generate(inputs, max_new_tokens=80, do_sample=True, pad_token_id=tokenizer.eos_token_id)

        decoded_output = tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(query_prompt, '').replace("[INST]", "").replace("[/INST]", "").strip()
    
        sub.loc[sub['id'] == row[1], 'rewrite_prompt'] = decoded_output.replace('Prediction:','').replace('prediction:','').replace('Sample Output:', '').replace('output:', '')

        print("FINAL: ", sub)

        torch.cuda.empty_cache()
        gc.collect()

    except Exception as e:
        print(e)
        sub.loc[sub['id'] == row[1], 'rewrite_prompt'] = str(SAMPLE_OUTPUT_1)
#     finally:
#         if not (sub['id'] == row[1]).any():
#             sub.loc[sub['id'] == row[1], 'rewrite_prompt'] = str(SAMPLE_OUTPUT_1)

0it [00:00, ?it/s]2024-04-16 04:02:48.400315: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-16 04:02:48.400447: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-16 04:02:48.651645: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


FINAL:          id                                     rewrite_prompt
0  9559194  Please improve this text using the writing sty...


1it [00:40, 40.39s/it]


In [9]:
sub.to_csv("submission.csv", header=True, index=False)

In [10]:
print(sub.iloc[0]['rewrite_prompt'])

Please improve this text using the writing style with maintaining the original meaning but altering the tone.
