In [1]:
import os
import sys

# Make the directory three levels above the working directory importable, presumably the
# repository root that holds the project packages imported below (utils, databases, models).
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "../../.."))
# Only append once so re-running this cell in the same kernel does not grow sys.path.
if PROJECT_ROOT not in sys.path:
	sys.path.append(PROJECT_ROOT)

In [2]:
import time
import traceback
from openai import OpenAI
from guardrails import Guard
from dotenv import load_dotenv
from utils.logger import get_logger
from utils.helpers import hash_list
from databases.connector import Connector
from models.ExportFormat import ExportFormat
from models.Prompt import Prompt, PromptResult, PromptRun
from guardrails.hub import ValidChoices
import random

In [3]:
# Shared runtime objects used by the cells below: logger, environment, API client, DB connector.
logger = get_logger()
# Load variables from a local .env file into the process environment.
# NOTE(review): this runs after get_logger(); if the logger's setup reads env vars, move it first — confirm.
load_dotenv()
# NOTE(review): `client` is not referenced anywhere else in this notebook (the baseline below
# picks labels at random); OpenAI() presumably still needs credentials at construction — confirm
# whether it can be dropped.
client = OpenAI()
# Project database connector configured with ExportFormat.JSON; all reads/writes below go
# through `connector.client`.
connector = Connector(ExportFormat.JSON)

In [6]:
# --- Run configuration ---------------------------------------------------------
model = "random"  # recorded in the run parameters; no model is actually called here
themes_level = 3  # theme level whose names serve as the candidate labels
question_sample_size = 1500  # NOTE(review): not referenced in this notebook — confirm whether it can be removed
batch_id = "dd5ef5489a898f6df898005bdae76fe395ea02f715a625da41b55f5133efb2fc"

# Candidate labels: the names of every theme at `themes_level`.
themes_list = [theme["name"] for theme in connector.client.get_themes_by_level(themes_level)]
# Fail fast: with no candidates, every per-question random draw would raise.
assert themes_list, f"No themes found at level {themes_level}"

batch = connector.client.get_batch({"identifier": batch_id})
# NOTE(review): assumes get_batch returns None for an unknown identifier — confirm.
assert batch is not None, f"Batch {batch_id} not found"
question_list = connector.client.aggregate_questions(
	[{"$match": {"id": {"$in": batch["question_ids"]}}}]
)

# Guardrail that raises if a label is not one of `themes_list`.
guard = Guard().use(ValidChoices, choices=themes_list, on_fail="exception")

# The random baseline sends no prompts, so the stored prompt wraps an empty list and is
# identified by the hash of that list.
prompts = []
prompt = connector.client.upsert_prompt(
	Prompt(unique_identifier=hash_list(prompts), prompts=prompts)
)

# Register the run that every result below is attached to.
# NOTE(review): "classificiation" is misspelled but left unchanged in case runs are looked up by name.
run_parameters = {"temperature": 0, "model": model, "type": "zero-shot"}
prompt_run = PromptRun(
	prompt_id=prompt["unique_identifier"],
	batch_id=batch_id,
	parameters=run_parameters,
	timestamp=int(time.time()),
	name="Random classificiation",
)
inserted_prompt_run = connector.client.add_prompt_run(prompt_run)

# Seeded RNG so the random baseline produces the same labels on every re-run.
RANDOM_SEED = 42
rng = random.Random(RANDOM_SEED)

failed_question_ids = []
for question in question_list:
	try:
		question_id = question["id"]

		# Draw a label uniformly from the candidate themes.
		random_theme = rng.choice(themes_list)

		# random_theme comes straight from themes_list, so validate it unchanged; lowercasing it
		# first would be rejected by ValidChoices whenever theme names contain capitals.
		guard.validate(random_theme)

		# Legislature is the id prefix before the first "-" (ValueError if the id has none).
		legislature = question_id[: question_id.index("-")]

		prompt_result = PromptResult(
			run_id=str(inserted_prompt_run.inserted_id),
			question_id=question_id,
			batch_id=batch_id,
			prompt_id=prompt["unique_identifier"],
			response=random_theme,
			# No model is called, so there is no token usage or confidence to report.
			response_tokens=0,
			prompt_tokens=0,
			legislature=legislature,
			confidence=0,
		)
		connector.client.add_prompt_result(prompt_result)
	except Exception:
		# Best-effort: log the failure with its question id and continue with the rest.
		failed_question_ids.append(question.get("id"))
		logger.error(
			f"Random classification failed for question {question.get('id')}:\n"
			f"{traceback.format_exc()}"
		)

if failed_question_ids:
	logger.error(
		f"{len(failed_question_ids)} question(s) were not recorded: {failed_question_ids}"
	)