In [1]:
import pandas as pd
from pathlib import Path
from string import Template
import os

from dotenv import load_dotenv
from openai import AzureOpenAI
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
import ast

tqdm.pandas()


LIMIT_ROWS = None
TEMPERATURE = 1
N = 6
MAX_TOKENS = 2048

PERSUADE_PATH = "../data/persuade/persuade_full_cleaned.csv"

PROMPT_DIR = "../prompts/sat/"
USER_PROMPTS_PREFIX = "sat_score_{}.txt"

EVALUATION_PROMPT_PATH = "../prompts/evaluation/evaluation_function.txt"

SAVE_DIR = "../data/rewrites/sat_2/"
SAVE_NAME = "rew_sat_{}.csv"


In [2]:
SYSTEM_PROMPT_PATH = PROMPT_DIR + "system_prompt.txt"

# Load system prompt once
with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()

In [3]:
load_dotenv()
api_key = os.getenv("API_KEY")
if not api_key:
    raise RuntimeError("Missing API_KEY in environment")

endpoint = "https://extractionhub.cognitiveservices.azure.com/"
model_name = "gpt-4o"
deployment = "gpt-4o"   
api_version = "2024-10-21"

client = AzureOpenAI(
    api_version=api_version,
    azure_endpoint=endpoint,
    api_key=api_key,
)

In [4]:
df = pd.read_csv(PERSUADE_PATH)
df.head()

Unnamed: 0,text,holistic_essay_score,race_ethnicity,gender,grade_level,economically_disadvantaged,prompt_name
0,Some schools require students to complete summ...,3,White,M,11.0,0,Summer projects
1,Letting teachers design the project is the mos...,4,Black/African American,M,11.0,0,Summer projects
2,Some schools implement a summer project to con...,4,White,M,11.0,0,Summer projects
3,Would you want to waste your summer on useless...,4,Asian/Pacific Islander,F,11.0,0,Summer projects
4,"During summer break, it's a time in which the ...",3,Hispanic/Latino,M,11.0,1,Summer projects


In [5]:
def stratified_fixed_k(df, label_col, k):
    return (
        df.groupby(label_col, group_keys=False)
          .apply(lambda x: x.sample(min(k, len(x)), replace=False))
    )

K = 50

df_high_ses = df[df["economically_disadvantaged"] == 0]
df_low_ses  = df[df["economically_disadvantaged"] == 1]

sampled_high_ses = stratified_fixed_k(df_high_ses, "holistic_essay_score", K)
sampled_low_ses  = stratified_fixed_k(df_low_ses, "holistic_essay_score", K)

df = pd.concat([sampled_high_ses, sampled_low_ses], ignore_index=True)
df["rewritten_text"] = None
df["content_preserved"] = None

  .apply(lambda x: x.sample(min(k, len(x)), replace=False))
  .apply(lambda x: x.sample(min(k, len(x)), replace=False))


In [6]:
dfs_by_score = {
    i: df.copy().reset_index(drop=True)
    for i in range(1, 7)
}

In [7]:
def load_prompt_template(file: str) -> Template:
    return Template(Path(file).read_text(encoding="utf-8"))

def build_prompt(prompt_features: dict, file: str) -> str:
    tpl = load_prompt_template(file)
    tpl = tpl.safe_substitute(prompt_features)
    return tpl

def call_llm_optimizer(system_prompt: str, user_prompt: str):

    response = client.chat.completions.create(
        model=deployment,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        n=N,
    )

    outputs = [choice.message.content.strip() for choice in response.choices]

    return outputs


In [8]:
def call_llm_evaluator(user_prompt: str):

    response = client.chat.completions.create(
        model=deployment,
        messages=[
            {"role": "system", "content": "You are a text evaluation specialist with expertise in socioeconomic style transfer."},
            {"role": "user", "content": user_prompt},
        ],
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        n=1,
    )

    verdicts = ast.literal_eval(response.choices[0].message.content)
    return verdicts 

def evaluate_rewrites(original_text: str, rewritten_texts: list[str], file: str) -> pd.DataFrame:
    prompt_features = {"original_text" : original_text}
    
    rew_txt = "Texts: \n"
    for i,text in enumerate(rewritten_texts):
        rew_txt += f"- Text {i}: \n {text} \n"
    prompt_features["rewritten_texts"] = rew_txt

    return call_llm_evaluator(build_prompt(prompt_features=prompt_features, file=file))

In [9]:
def rewrite_one_style(user_prompt: str, system_prompt: str):
    """
    Calls the optimizer for a single style prompt and returns a single rewritten text.
    """
    outputs = call_llm_optimizer(system_prompt=system_prompt, user_prompt=user_prompt)
    # call_llm_optimizer always returns a list
    if not outputs:
        return None
    return outputs[0]

In [10]:
for index, row in tqdm(df.iterrows(), total=df.shape[0]):
    original_text = row["text"]
    
    # Build the 6 user prompts for this essay
    user_prompts = [
        build_prompt(prompt_features={"ESSAY_TEXT": original_text}, file=PROMPT_DIR + USER_PROMPTS_PREFIX.format(i))
        for i in range(1, 7)
    ]
    
    # Call the LLM once per style level (simplest & most modular)
    try:
        with ThreadPoolExecutor(max_workers=6) as executor:
            futures = [
                executor.submit(rewrite_one_style, prompt, SYSTEM_PROMPT)
                for prompt in user_prompts
            ]
            # preserve order: futures[0] -> style 1, futures[1] -> style 2, etc.
            outputs = [f.result() for f in futures]
    except Exception as e:
        print(f"Error in LLM call for index {index}: {e}")
        # Mark error in all 6 dfs for this row and continue
        for i in range(1, 7):
            dfs_by_score[i].at[index, "rewritten_text"] = "error occurred"
            dfs_by_score[i].at[index, "content_preserved"] = None
        continue

    # --- Evaluate content preservation across the 6 rewrites ---
    try:
        content_preserved_list = evaluate_rewrites(
            original_text=original_text,
            rewritten_texts=outputs,
            file=EVALUATION_PROMPT_PATH
        )
    except Exception as e:
        print(f"Error in evaluation for index {index}: {e}")
        # If evaluation fails, set all content_preserved to None and still keep the rewrites
        content_preserved_list = [None] * 6

    # --- Store results: for each style i, put its rewrite in the corresponding df ---
    for i, (rewritten_text, cp) in enumerate(zip(outputs, content_preserved_list), start=1):
        dfs_by_score[i].at[index, "rewritten_text"] = rewritten_text

        if cp is None:
            dfs_by_score[i].at[index, "content_preserved"] = None
        else:
            dfs_by_score[i].at[index, "content_preserved"] = (cp == "YES")


  1%|          | 7/600 [01:19<1:47:26, 10.87s/it]

Error in LLM call for index 6: 'NoneType' object has no attribute 'strip'


 55%|█████▌    | 330/600 [1:37:24<36:16,  8.06s/it]  

Error in LLM call for index 329: 'NoneType' object has no attribute 'strip'


 60%|██████    | 363/600 [1:42:56<38:05,  9.64s/it]

Error in LLM call for index 362: 'NoneType' object has no attribute 'strip'


100%|██████████| 600/600 [2:51:09<00:00, 17.12s/it]   


In [11]:
# save df_combined to csv
for i, df_x in dfs_by_score.items():
    df_x.to_csv(f'{SAVE_DIR}raw_{SAVE_NAME.format(i)}', index=False)

In [None]:
bad_indices = set()

# 1) Collect bad indices across all dfs
for i, df in dfs_by_score.items():
    col = df["rewritten_text"].astype(str)

    mask_error = col.str.contains("error occurred", case=False, na=False)
    mask_nan   = df["rewritten_text"].isna()
    mask_empty = col.str.strip().eq("")
    mask_short = col.str.len() < 20

    mask_bad = mask_error | mask_nan | mask_empty | mask_short

    bad_indices.update(df.index[mask_bad])

# 2) Drop those rows from ALL dfs and save cleaned versions
for i, df in dfs_by_score.items():
    df_cleaned = df.loc[~df.index.isin(bad_indices)].copy()
    df_cleaned.to_csv(f"{SAVE_DIR}{SAVE_NAME.format(i)}", index=False)
