In [33]:
import json
import time
from typing import Dict, List
import os
import requests

from data.clean.user_categories import select_random_user_categories
from data.clean.question_categories import select_random_question_categorizations
from urllib.parse import urlparse, unquote


# Root URL of the AI71 API; endpoint names are appended directly to it.
BASE_URL = "https://api.ai71.ai/v1/"

# The API key is read from the environment so that the secret never lives in
# the notebook (or in its version-control history). Export AI71_API_KEY before
# starting the kernel. Any key that was previously committed in this notebook
# should be treated as leaked and revoked.
API_KEY = os.environ.get("AI71_API_KEY", "")
if not API_KEY:
    print("AI71_API_KEY is not set; export it before calling the API.")

def check_budget(timeout: float = 30.0):
    """Query the ``check_budget`` endpoint and pretty-print the JSON reply.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely on a stalled
            connection.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(
        f"{BASE_URL}check_budget",
        headers={"Authorization": f"Bearer {API_KEY}"},
        timeout=timeout,
    )
    resp.raise_for_status()
    print(json.dumps(resp.json(), indent=4))

# Print the current budget information reported by the API.
check_budget()

{
    "remaining_budget": 9665
}


In [31]:
def bulk_generate(
    n_questions: int,
    question_categorizations: List[Dict],
    user_categorizations: List[Dict],
    timeout: float = 30.0,
):
    """Submit a bulk generation request and block until its result is ready.

    The request is POSTed to the ``bulk_generation`` endpoint, the server's
    reply is pretty-printed, and the returned ``request_id`` is handed to
    ``wait_for_generation_to_finish``.

    Args:
        n_questions: Number of questions to request.
        question_categorizations: Question categorizations sent as-is in the
            request body.
        user_categorizations: User categorizations sent as-is in the request
            body.
        timeout: Seconds to wait for the submission request itself (not for
            the generation job) before giving up.

    Returns:
        Whatever ``wait_for_generation_to_finish`` returns for the request.

    Raises:
        requests.HTTPError: If the submission is answered with an error status.
        KeyError: If the reply does not contain ``"request_id"``.
    """
    resp = requests.post(
        f"{BASE_URL}bulk_generation",
        headers={"Authorization": f"Bearer {API_KEY}"},
        json={
            "n_questions": n_questions,
            "question_categorizations": question_categorizations,
            "user_categorizations": user_categorizations,
        },
        timeout=timeout,
    )
    resp.raise_for_status()

    # Decode the body once and reuse it for both the ID lookup and the log.
    payload = resp.json()
    request_id = payload["request_id"]
    print(json.dumps(payload, indent=4))

    return wait_for_generation_to_finish(request_id)


def wait_for_generation_to_finish(
    request_id: str,
    poll_interval: float = 5.0,
    max_wait: float = 3600.0,
    timeout: float = 30.0,
):
    """Poll ``fetch_generation_results`` until the request is completed.

    Args:
        request_id: Identifier of the generation request to poll.
        poll_interval: Seconds to sleep between two polls.
        max_wait: Upper bound, in seconds, on the total time spent polling.
            Bounding the loop keeps a job that never reaches ``"completed"``
            from hanging the kernel forever.
        timeout: Seconds to wait for each individual HTTP request.

    Returns:
        The decoded JSON payload of the first reply whose ``"status"`` is
        ``"completed"``.

    Raises:
        requests.HTTPError: If a poll is answered with an error status code.
        TimeoutError: If the request is still not completed after ``max_wait``
            seconds.
    """
    # monotonic() is immune to wall-clock adjustments during a long wait.
    deadline = time.monotonic() + max_wait
    while True:
        resp = requests.get(
            f"{BASE_URL}fetch_generation_results",
            headers={"Authorization": f"Bearer {API_KEY}"},
            params={"request_id": request_id},
            timeout=timeout,
        )
        resp.raise_for_status()

        # Decode once; the same payload is inspected, logged and returned.
        payload = resp.json()
        if payload["status"] == "completed":
            print(json.dumps(payload, indent=4))
            return payload

        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Generation {request_id} did not complete within "
                f"{max_wait} seconds (last status: {payload['status']!r})"
            )

        print("Waiting for generation to finish...")
        time.sleep(poll_interval)

### Run the following cells with caution: they create training data and consume budget.

In [32]:
total_questions = 50  # total number of questions you want to generate
per_iteration = 10     # maximum number of questions generated per iteration
output_dir = "./generated"  # downloaded result files are written here

# Create the output directory up front so that a missing folder cannot make
# the first download fail after budget has already been spent.
os.makedirs(output_dir, exist_ok=True)

# Ceiling division: a final, smaller batch covers any remainder so that
# exactly `total_questions` questions are requested.
iterations = -(-total_questions // per_iteration)
for i in range(iterations):
    batch_size = min(per_iteration, total_questions - i * per_iteration)
    results = bulk_generate(
        n_questions=batch_size,
        question_categorizations=select_random_question_categorizations(),
        user_categorizations=select_random_user_categories(),
    )

    # Extract the path from the URL and decode it
    parsed_url = urlparse(results["file"])
    path = unquote(parsed_url.path)

    # Extract the filename (basename also strips any directory components)
    filename = os.path.basename(path)

    # Download the file; the timeout keeps a stalled transfer from hanging the loop
    response = requests.get(results["file"], timeout=60)

    # Check if the request was successful
    if response.status_code == 200:
        with open(os.path.join(output_dir, filename), 'wb') as f:
            f.write(response.content)
        print(f"File downloaded and saved as '{filename}'")
    else:
        print(f"Failed to download file. Status code: {response.status_code}")

{
    "request_id": "fba0a69f-7953-4517-870f-30d0c4ccad19",
    "type": "async"
}
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
Waiting for generation to finish...
{
    "status": "completed",
    "file": "https://s3.amazonaws.com/data.aiir/data_morgana/web_api/results_id_6a8f2233-64ea-4af2-b004-0c5419fec000_user_id_bfc67a4c-41ca-4a55-87c0-85fe0e346a90.jsonl?X

In [28]:
# Use this to test if the earlier cell doesn't work.
# result = wait_for_generation_to_finish("0f494ec6-b71b-45b9-a165-420ae071c015")