In [None]:
import pandas as pd
from pathlib import Path
from string import Template
import os

from dotenv import load_dotenv
from openai import OpenAI, AzureOpenAI
from tqdm import tqdm
import ast

tqdm.pandas()


# --- Input -------------------------------------------------------------------
PERSUADE_PATH = "../data/persuade/persuade_full_cleaned.csv"
# Row cap for quick dry runs; None (or any falsy value) processes every row.
LIMIT_ROWS = None

# --- Rewrite request ---------------------------------------------------------
# Rewrite template name, resolved as ../prompts/<PROMPT>.txt by load_prompt_template.
PROMPT = "no_descriptions"
TEMPERATURE = 1      # sampling temperature for the rewrite request
N = 6                # rewrites requested per essay (chat-completion `n`)
MAX_TOKENS = 2048    # token cap per completion

# --- Output ------------------------------------------------------------------
# Output paths are built by plain string concatenation (SAVE_DIR + SAVE_NAME + suffix),
# so SAVE_DIR must keep its trailing slash.
SAVE_DIR = "../data/rewrites/no_desc/"
SAVE_NAME = "no_desc_full"


In [None]:
load_dotenv()
api_key = os.getenv("API_KEY")
if not api_key:
    raise RuntimeError("Missing API_KEY in environment")

# Azure OpenAI connection settings. The literals are the defaults; each one can be
# overridden via the environment (or the .env file) without editing the notebook.
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT", "https://extractionhub.cognitiveservices.azure.com/")
model_name = "gpt-4o"  # not passed to the requests below; they address the model via `deployment`
deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-4o")
api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-10-21")

# Client shared by the LLM helper functions defined below.
client = AzureOpenAI(
    api_version=api_version,
    azure_endpoint=endpoint,
    api_key=api_key,
)

In [None]:
# Load the PERSUADE essays; a truthy LIMIT_ROWS trims to the first rows for a cheap dry run.
df = pd.read_csv(PERSUADE_PATH)
df = df.head(LIMIT_ROWS) if LIMIT_ROWS else df

In [None]:
def load_prompt_template(filename: str, prompts_dir: str = "../prompts") -> Template:
    """Read ``<prompts_dir>/<filename>.txt`` (UTF-8) as a ``string.Template``.

    Args:
        filename: Template name without the ``.txt`` extension; it may contain a
            subdirectory component (e.g. ``"evaluation/foo"``).
        prompts_dir: Directory holding the prompt files. Defaults to the
            relative ``../prompts`` used so far (resolved against the current
            working directory); any path-like value is accepted.

    Returns:
        The unrendered template.

    Raises:
        FileNotFoundError: if the template file does not exist.
    """
    path = Path(prompts_dir) / f"{filename}.txt"
    return Template(path.read_text(encoding="utf-8"))

def build_rewrite_prompt(
    prompt_features: dict,
    filename: str) -> str:
    """Fill the rewrite template ``filename`` with ``prompt_features``.

    Rendering uses ``Template.safe_substitute``: placeholders missing from
    ``prompt_features`` stay in the output verbatim, and stray ``$`` signs do
    not raise.

    Returns:
        The rendered prompt text.
    """
    rendered = load_prompt_template(f"{filename}").safe_substitute(prompt_features)
    return rendered

def call_llm_optimizer(prompt: str, content: str):
    """Request ``N`` rewrites of ``prompt`` from the Azure chat deployment.

    Args:
        prompt: User message (the rendered rewrite prompt).
        content: System message text.

    Returns:
        list[str]: One whitespace-stripped completion per returned choice, in
        choice order. A choice whose ``message.content`` is ``None`` (the SDK
        types it as optional) yields ``""`` instead of raising
        ``AttributeError``, so the other rewrites for the row are kept.
    """
    response = client.chat.completions.create(
        model=deployment,
        messages=[
            {"role": "system", "content": content},
            {"role": "user",   "content": prompt},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        n=N,
    )

    # Guard against a missing message body before stripping.
    return [(choice.message.content or "").strip() for choice in response.choices]


In [None]:
def build_evauation_prompt(original_text: str, rewritten_texts: list[str], file: str,
                           prompts_dir: str = "../prompts/evaluation") -> str:
    """Render the evaluation prompt comparing ``original_text`` with its rewrites.

    The template ``<prompts_dir>/<file>.txt`` is filled with
    ``Template.substitute``, so it must use exactly the ``$original_text`` and
    ``$rewritten_texts`` placeholders. Any other placeholder raises
    ``KeyError``, and a malformed ``$`` raises ``ValueError``.

    Args:
        original_text: The source essay.
        rewritten_texts: Candidate rewrites, listed as "Text 0", "Text 1", ...
            in order, so position i in the listing matches
            ``rewritten_texts[i]``. A single string is treated as a one-item
            list instead of being enumerated character by character.
        file: Template name without the ``.txt`` extension.
        prompts_dir: Directory holding the evaluation templates.

    Returns:
        The filled-in prompt.
    """
    template_path = Path(prompts_dir) / f"{file}.txt"
    template = Template(template_path.read_text(encoding="utf-8"))

    if isinstance(rewritten_texts, str):
        rewritten_texts = [rewritten_texts]
    listing = "\n".join(
        f"- Text {i}: \n\n" + rewrite for i, rewrite in enumerate(rewritten_texts)
    )
    return template.substitute(original_text=original_text, rewritten_texts=listing)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
build_evaluation_prompt = build_evauation_prompt

def _parse_verdicts(raw: str) -> list:
    """Parse the evaluator reply into a list, tolerating a Markdown code fence."""
    import re

    text = raw.strip()
    # Models often wrap literals in ```python ... ``` fences; keep only the inner text.
    fenced = re.fullmatch(r"```\w*\s*(.*?)\s*```", text, flags=re.DOTALL)
    if fenced:
        text = fenced.group(1)
    # literal_eval only evaluates Python literals, so model output cannot execute code.
    verdicts = ast.literal_eval(text)
    if not isinstance(verdicts, (list, tuple)):
        raise ValueError(
            f"Expected a list of verdicts, got {type(verdicts).__name__}: {raw!r}"
        )
    return list(verdicts)


def call_llm_evaluator(prompt: str, temperature: float = 0.0, max_tokens: "int | None" = None):
    """Ask the evaluator model for content-preservation verdicts on ``prompt``.

    Args:
        prompt: Fully rendered evaluation prompt.
        temperature: Sampling temperature for the judge, forwarded to the
            request. The 0.0 default keeps verdicts as repeatable as the API
            allows.
        max_tokens: Completion token cap, forwarded to the request. ``None``
            (default) uses the notebook-level ``MAX_TOKENS``. There is no small
            default because a reply must hold a whole list of verdicts.

    Returns:
        list: The verdicts parsed from the first choice with
        ``ast.literal_eval`` (an optional surrounding code fence is removed).

    Raises:
        ValueError, SyntaxError: if the reply is not a list/tuple literal.
    """
    if max_tokens is None:
        max_tokens = MAX_TOKENS

    response = client.chat.completions.create(
        model=deployment,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are a text evaluation specialist with expertise in "
                    "socioeconomic style transfer."
                )
            },
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        n=1,
    )

    return _parse_verdicts(response.choices[0].message.content or "")

def evaluate_rewrites(original_text: str, rewritten_texts: list[str], file: str) -> list:
    """Judge whether each rewrite preserves the content of ``original_text``.

    Builds the evaluation prompt from the template named ``file`` and returns
    whatever ``call_llm_evaluator`` parses from the model reply. In this
    notebook the caller treats that as one verdict string per rewrite, in the
    same order as ``rewritten_texts``, and compares each against ``'YES'``.
    """
    prompt = build_evauation_prompt(original_text, rewritten_texts, file)
    return call_llm_evaluator(prompt)

In [None]:
# System message for every rewrite request (loop-invariant, so built once).
content = "You are an expert text rewriter. Your sole job is to rewrite the given text according to the style instructions in the user message."
ERROR_MARKER = 'error occurred'  # later cells select failed rows by this exact string

# Later cells read `rewritten_text`; create it up front so they also work when no row fails.
if 'rewritten_text' not in df.columns:
    df['rewritten_text'] = None

for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    full_prompt = build_rewrite_prompt({"text": row['text']}, PROMPT)

    # 1) Generate the rewrites. On failure, flag the row and move on.
    try:
        outputs = call_llm_optimizer(full_prompt, content)
    except Exception as e:
        print(f"Error in rewriting for index {index}: {e}")
        df.at[index, 'rewritten_text'] = ERROR_MARKER
        # Nothing was evaluated, so content preservation is unknown — not True.
        df.at[index, 'content_preserved'] = None
        continue

    for i, rewrite in enumerate(outputs):
        df.at[index, f'rewritten_text_{i}'] = rewrite

    # 2) Judge the rewrites; verdict i is taken to refer to rewrite i. If evaluation
    #    fails, the rewrites are kept and content_preserved_0 is set to None.
    try:
        content_preserved = evaluate_rewrites(original_text=row['text'], rewritten_texts=outputs, file="evaluation_function")
        for i, verdict in enumerate(content_preserved):
            # Tolerate casing/whitespace variants such as 'yes' or ' YES'.
            df.at[index, f'content_preserved_{i}'] = str(verdict).strip().upper() == 'YES'
    except Exception as e:
        print(f"Error in evaluation for index {index}: {e}")
        df.at[index, 'content_preserved_0'] = None

In [None]:
# Save the full results (all rows, including any flagged 'error occurred') as CSV,
# using the same ".csv" naming as the _cleaned / _errors exports.
Path(SAVE_DIR).mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories
df.to_csv(f'{SAVE_DIR}{SAVE_NAME}.csv', index=False)

In [None]:
# Rows per rewrite status: 'error occurred' marks failed rows, NaN/None marks the rest.
# reindex tolerates a missing `rewritten_text` column (it may not exist when no row failed).
df.reindex(columns=['rewritten_text'])['rewritten_text'].value_counts(dropna=False)


In [None]:
# Keep only rows whose rewrite request succeeded. The marker must match the string
# the processing loop writes ('error occurred') exactly; reindex tolerates a missing column.
failed_rows = df.reindex(columns=['rewritten_text'])['rewritten_text'].eq('error occurred')
df_cleaned = df[~failed_rows]
df_cleaned.to_csv(f'{SAVE_DIR}{SAVE_NAME}_cleaned.csv', index=False)

In [None]:
# Rows whose rewrite request failed; reindex yields an empty selection (not a KeyError)
# when the `rewritten_text` column was never created.
failed_rows = df.reindex(columns=['rewritten_text'])['rewritten_text'].eq('error occurred')
df_error = df[failed_rows]
df_error.to_csv(f'{SAVE_DIR}{SAVE_NAME}_errors.csv', index=False)

In [None]:
# Tally of verdicts across all N rewrite slots of the cleaned rows. reindex adds any
# slot column that never received a verdict (all NaN) instead of raising KeyError;
# value_counts excludes NaN entries.
verdict_cols = [f'content_preserved_{i}' for i in range(N)]
df_cleaned.reindex(columns=verdict_cols).stack().value_counts()