In [None]:
import os
import json
import numpy as np
from tqdm import tqdm
from openai import OpenAI
import logging
logging.getLogger("httpx").setLevel(logging.ERROR)

from misc import HUF_TOKEN, OAI_TOKEN
from misc import ASSIGNMENTS_PATH, TOPICS_PATH, TOPIC_MODEL, MAX_CHARS

from dataset import load_comparison_dataset

In [None]:
# The OpenAI SDK picks up its credentials from the OPENAI_API_KEY environment
# variable, so export the token there before constructing the client.
os.environ.update(OPENAI_API_KEY=OAI_TOKEN)
client = OpenAI()

In [None]:
assignments = np.load(ASSIGNMENTS_PATH)

In [None]:
# load_comparison_dataset returns five values; only the second one (the prompts) is kept.
[_, prompts, _, _, _] = load_comparison_dataset(token=HUF_TOKEN)

In [None]:
# Bucket prompts by cluster label: {label -> [prompt, ...]}.
# `assignments[i]` is treated as the label of `prompts[i]`, so the two must line up.
assert len(assignments) == len(prompts), (
    f"{len(assignments)} assignments vs {len(prompts)} prompts"
)
groups = {}
for i, label in enumerate(assignments):
    groups.setdefault(label, []).append(prompts[i])

# Records are appended and flushed one at a time, so an interrupted run can be
# resumed: collect the gids already written and skip them instead of re-querying
# the API and appending duplicate records.
done_gids = set()
if os.path.exists(TOPICS_PATH):
    with open(TOPICS_PATH, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            try:
                done_gids.add(json.loads(line)["gid"])
            except json.JSONDecodeError:
                # Best effort: a line truncated by an interrupted write is ignored,
                # so that group is simply regenerated.
                continue

pending = [(gid, ps) for gid, ps in groups.items() if int(gid) not in done_gids]

with open(TOPICS_PATH, "a", encoding="utf-8") as fout:
    # `group_prompts` (not `prompts`) so the global prompt list is not clobbered,
    # which previously broke re-running this cell.
    for gid, group_prompts in tqdm(pending):
        text = "\n".join(f"- {p[:MAX_CHARS]}" for p in group_prompts)
        msg = [
            {"role": "system", "content": "You output ONLY a 1-3 word topic in English."},
            {"role": "user", "content": f"Topic for these prompts:\n{text}"},
        ]
        resp = client.responses.create(model=TOPIC_MODEL, input=msg)
        # `output_text` aggregates the text parts of the response. Indexing
        # `resp.output[0].content[0]` fails when the first output item is not a
        # message (e.g. a reasoning item).
        topic = resp.output_text.strip()

        fout.write(json.dumps({"gid": int(gid), "topic": topic}) + "\n")
        fout.flush()  # persist each record immediately so progress survives interruptions
