In [None]:
import pandas as pd
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Pin PyTorch's global RNG seed so repeated runs of the notebook start from the same state.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# CSV holding the evaluation rows; later cells read its `text`, `rewritten_text`
# and `prompt` columns.
TEST_DATA_PATH = "test_set.csv"
test_data = pd.read_csv(filepath_or_buffer=TEST_DATA_PATH)

In [None]:
# `model_id` is the base checkpoint handed to the tokenizer and model loaders;
# `new_model` is the adapter location handed to PeftModel.from_pretrained.
model_id, new_model = "google/gemma-2b-it", "gemma-Finetune-test"

In [None]:
# Load the base checkpoint's tokenizer, configured to pad on the right.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")

In [None]:
# Load the base checkpoint in half precision, placing every module on device 0.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": 0},
    torch_dtype=torch.float16,
    return_dict=True,
    low_cpu_mem_usage=True,
)

# Attach the adapter stored at `new_model`, then fold it into the base model so
# that `merged_model` is whatever merge_and_unload() returns.
merged_model = PeftModel.from_pretrained(base_model, new_model).merge_and_unload()

# Switch to evaluation mode for inference.
merged_model.eval()

In [None]:
def prepare_prompt_tempalte_test(input_text, output_text):
    """Build the single-turn chat prompt asking which prompt produced a rewrite.

    Args:
        input_text: The original text, inserted verbatim into the user turn.
        output_text: The rewritten text, inserted verbatim into the user turn.

    Returns:
        The formatted prompt string: a complete user turn followed by an opened
        model turn for the model to continue.
    """
    fields = {"input_text": input_text, "output_text": output_text}
    return (
        """<start_of_turn>user\nThis is the original text: {input_text}, this is the rewritten text: {output_text}. Which prompt was used to rewrite the original text to the rewritten text?<end_of_turn>\n<start_of_turn>model\n"""
    ).format(**fields)

def get_completion(input_text: str, output_text: str, model, tokenizer) -> str:
    """Ask the model which prompt turned ``input_text`` into ``output_text``.

    Args:
        input_text: The original text.
        output_text: The rewritten text.
        model: Object exposing ``generate``. Its ``device`` attribute, when
            present, decides where the encoded inputs are placed; otherwise
            ``"cuda:0"`` is used.
        tokenizer: Tokenizer used to encode the prompt and decode the completion.

    Returns:
        Only the newly generated text, with ``<end_of_turn>`` markers and
        surrounding whitespace (including the newline that stops generation)
        removed.
    """
    prompt = prepare_prompt_tempalte_test(input_text=input_text, output_text=output_text)

    # Follow the model's own placement instead of assuming a fixed GPU.
    device = getattr(model, "device", "cuda:0")
    model_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
    prompt_length = model_inputs["input_ids"].shape[-1]

    # Stop at the first newline. add_special_tokens=False keeps encode() from
    # prepending the BOS id, which would otherwise be registered as a stop token.
    stop_token_ids = list(tokenizer.encode("\n", add_special_tokens=False))
    # Overriding eos_token_id replaces the default stop token, so add it back.
    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in stop_token_ids:
        stop_token_ids.append(tokenizer.eos_token_id)

    # Greedy decoding: what top_k=1 sampling amounted to, without touching the RNG.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=100,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=stop_token_ids,
    )

    # Drop the prompt by token count; matching the decoded prompt as a string
    # fails whenever decoding does not reproduce the prompt text exactly.
    completion_ids = generated_ids[0][prompt_length:]
    decoded = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return decoded.replace("<end_of_turn>", "").strip()

# Smoke test on the row labelled 1: print the reference prompt next to the
# model's guess.
sample_row = 1
result = get_completion(
    input_text=test_data.loc[sample_row, 'text'],
    output_text=test_data.loc[sample_row, 'rewritten_text'],
    model=merged_model,
    tokenizer=tokenizer,
)
print("original Prompt:", test_data.loc[sample_row, 'prompt'])
print("Generated Prompt:", result)

In [None]:
# Keep only the first 10 rows.
# NOTE(review): this looks like a debugging subset - confirm before a full run.
# The explicit copy makes the subset an independent frame, so the column
# assignments in later cells do not raise SettingWithCopyWarning.
test_data = test_data.iloc[:10].copy()

In [None]:
# Generate one completion per row. Pairing the two columns with zip() walks the
# rows in order instead of looking values up by the labels 0..n-1, so the loop
# does not depend on the frame carrying a default integer index.
result = []
for text, rewritten_text in zip(test_data['text'], test_data['rewritten_text']):
    result.append(
        get_completion(input_text=text, output_text=rewritten_text, model=merged_model, tokenizer=tokenizer)
    )

In [None]:
# `result` is also bound to a single string by the earlier smoke-test call; a
# bare string would be broadcast to every row, so insist on one completion per row.
if not isinstance(result, list) or len(result) != len(test_data):
    raise ValueError(
        "`result` must be a list with one completion per row of `test_data`; "
        "re-run the generation loop cell before this one."
    )
# assign() returns a new frame, so nothing is written into a slice of another frame.
test_data = test_data.assign(rewrite_prompt=result)

In [None]:
# Expose the frame's index as an explicit `id` column for the submission file.
row_ids = test_data.index
test_data['id'] = row_ids

In [None]:
# Keep only the two submission columns and write them out without the index.
submission_columns = ['id', 'rewrite_prompt']
sub_df = test_data[submission_columns]
sub_df.to_csv(path_or_buf='submission.csv', index=False)