In [None]:
import json
import os
from functools import partial
from pathlib import Path
from typing import Annotated

import anthropic
import instructor
import openai
from chromadb import Collection as ChromaCollection
from dotenv import load_dotenv
from pydantic import AfterValidator, BaseModel, Field, ValidationInfo, model_validator
from tenacity import RetryError, Retrying, stop_after_attempt, wait_random_exponential

from dreamai.ai import ModelName, system_message, user_message
from dreamai.chroma import chroma_collection, query_collection, traverse_ids
from dreamai.pdf import pdf_to_collection
from dreamai.utils import deindent

# Load environment variables from a local .env file (presumably the OpenAI /
# Anthropic API keys used by the clients below — confirm).
load_dotenv()

# instructor-wrapped clients: `.create(..., response_model=Model)` returns a
# validated instance of `Model` (see the `.concept` accesses further down).
ask_oai = instructor.from_openai(openai.OpenAI())
ask_cld = instructor.from_anthropic(anthropic.Anthropic())

# Automatically reload imported modules (e.g. `dreamai`) when their source changes.
# NOTE(review): `%reload_ext` right after `%load_ext` looks redundant; kept as-is.
%load_ext autoreload
%autoreload 2
%reload_ext autoreload

In [None]:
# data_dir = Path("/media/hamza/data2/MATH/train/")
# questions_dir = Path("math_102_questions")
# question_id = 0
# for folder in data_dir.iterdir():
#     if folder.is_dir() and folder.name != "counting_and_probability":
#         folder_questions = []
#         for question_file in folder.glob("*.json"):
#             question = json.loads(question_file.read_text())
#             if "5" in question["level"]:
#                 dest = questions_dir / f"{folder.name}/{question_id}.json"
#                 os.makedirs(dest.parent, exist_ok=True)
#                 with open(dest, "w") as f:
#                     json.dump(
#                         {
#                             "id": str(question_id),
#                             "problem": question["problem"],
#                             "solution": question["solution"],
#                         },
#                         f,
#                         indent=2,
#                     )
#                 question_id += 1

In [None]:
# Default model for the pipeline helpers (used as their `model=` default).
MODEL: ModelName = ModelName.GPT_3
# `max_tokens` passed to every structured LLM request.
MAX_TOKENS: int = 1_500
# Minimum number of words a generated concept must have (enforced by `Concept`).
CONCEPT_WORD_COUNT: int = 3
# Maximum number of question files loaded from each topic folder.
QUESTIONS_PER_FOLDER: int = 40
# instructor `max_retries` used by the generation loop.
ATTEMPTS: int = 3

In [None]:
def validate_word_count(text: str, word_count: int = 3, text_name: str = "Text") -> str:
    """Return `text` unchanged when it has at least `word_count` whitespace-separated words.

    Args:
        text: The string to check.
        word_count: Minimum number of words required.
        text_name: Label used in the error message.

    Raises:
        ValueError: If `text` has fewer than `word_count` words.
    """
    n_words = len(text.split())
    if n_words >= word_count:
        return text
    raise ValueError(f"{text_name} should be at least {word_count} words long")


def validate_topic_subtopic(cls, info: ValidationInfo):
    """Check that a model's `topic`/`subtopic` pair exists in the validation context.

    Used as the body of the `mode="after"` model validators, so `cls` is actually the
    model *instance* being validated; it must expose `.topic` and `.subtopic`.

    The allowed topics are read from `info.context["topics"]` (a list of `Topic`),
    which the LLM callers pass via `validation_context`. When no context, or no
    "topics" key, is supplied, the check is skipped. This covers re-loading a saved
    question with a plain `model_validate`, which would otherwise crash with an
    `AttributeError`/`TypeError` on `None`.

    Returns:
        The validated instance, unchanged.

    Raises:
        ValueError: If the topic is unknown, or the subtopic does not belong to it.
            Pydantic reports it as a validation error.
    """
    context = info.context or {}
    topics: list[Topic] | None = context.get("topics")  # type: ignore
    if topics is None:
        # Nothing to validate against.
        return cls
    topic = next((t for t in topics if t.name == cls.topic), None)
    if topic is None:
        raise ValueError(f"Topic {cls.topic} not found in topics")
    if cls.subtopic not in [subtopic.name for subtopic in topic.subtopics]:
        raise ValueError(f"Subtopic {cls.subtopic} not found in Topic: {cls.topic}")
    return cls


# A math problem and its solution, as loaded from the per-question JSON files.
class Question(BaseModel):
    id: str
    problem: str
    solution: str


# LLM response model for labelling a question with a topic and subtopic.
# Callers pass `validation_context={"topics": ...}`; the validator below rejects
# topic/subtopic names that are not in that list.
class TopicSubtopic(BaseModel):
    topic: str
    subtopic: str

    @model_validator(mode="after")  # type: ignore
    def validate_topic_subtopic(self, info: ValidationInfo) -> "TopicSubtopic":
        # Delegates to the module-level function of the same name.
        return validate_topic_subtopic(self, info)


# A `Question` plus its assigned topic and subtopic. Also the item type of the
# subquestion response model, so each LLM-generated subquestion is checked against
# the topics context. Instances built with `model_construct` skip this validator.
class QuestionWithTopicSubtopic(Question):
    topic: str
    subtopic: str

    @model_validator(mode="after")  # type: ignore
    def validate_topic_subtopic(
        self, info: ValidationInfo
    ) -> "QuestionWithTopicSubtopic":
        # Delegates to the module-level function of the same name.
        return validate_topic_subtopic(self, info)


# LLM response model for subquestion generation. The 3-4 item bound matches the
# "3-4 smaller subquestions" instruction in `subqs_message`.
class QuestionsWithTopicSubtopic(BaseModel):
    questions: list[QuestionWithTopicSubtopic] = Field(..., min_length=3, max_length=4)


# A topic-labelled question with topic-labelled (concept-less) subquestions.
# NOTE(review): not referenced by any of the visible cells.
class QuestionWithTopicSubtopicAndSubquestions(QuestionWithTopicSubtopic):
    subquestions: list[QuestionWithTopicSubtopic]


# LLM response model for assigning one concept to a question.
# The AfterValidator rejects concepts with fewer than CONCEPT_WORD_COUNT words;
# the constant's value is captured by `partial` when this class is defined.
class Concept(BaseModel):
    concept: Annotated[
        str,
        AfterValidator(
            partial(
                validate_word_count, word_count=CONCEPT_WORD_COUNT, text_name="Concept"
            )
        ),
    ]


# A topic/subtopic-labelled question together with its assigned concept.
class QuestionWithConcept(QuestionWithTopicSubtopic):
    concept: str


# Intermediate stage: main question labelled with topic/subtopic only (no concept
# yet), carrying subquestions that already have concepts.
class QuestionWithTopicSubtopicAndConceptSubquestions(QuestionWithTopicSubtopic):
    subquestions: list[QuestionWithConcept]


# Final record written per question: the main question with its own concept
# plus its concept-labelled subquestions.
class QuestionWithConceptAndSubquestions(QuestionWithConcept):
    subquestions: list[QuestionWithConcept]


# A concept recorded under a subtopic, with the ids of the main questions assigned
# to it. Concepts first introduced by subquestions start with an empty id list.
class ConceptWithQuestionIDs(BaseModel):
    concept: str
    question_ids: list[str] = Field(default_factory=list)


# A named subtopic and the concepts accumulated for it during generation.
class Subtopic(BaseModel):
    name: str
    concepts: list[ConceptWithQuestionIDs] = Field(default_factory=list)


# A named topic and its subtopics; a list of these is the "topics" validation context.
class Topic(BaseModel):
    name: str
    subtopics: list[Subtopic]

In [None]:
def topics_str(topics: list[Topic]) -> str:
    """Render `topics` as a "TOPICS:" header followed by their JSON dump (indent=2)."""
    dumped_topics = [t.model_dump() for t in topics]
    topics_json = json.dumps(dumped_topics, indent=2)
    return deindent("\nTOPICS:\n\n" + topics_json + "\n")


def topic_subtopic_message(topics: list[Topic]) -> str:
    """System prompt asking the model to assign a topic and subtopic from `topics`."""
    instructions = (
        "You are a world class math course instructor.\n"
        "You will be given a question with a 'problem' and a 'solution'.\n"
        "Given these topics and subtopics below, assign a 'topic' and a 'subtopic' to the question.\n"
        "The 'subtopic' must be one of the subtopics of the 'topic'.\n"
    )
    # Leading newline, blank line before the listing and trailing "\n    " reproduce
    # the original triple-quoted layout exactly.
    return deindent("\n" + instructions + "\n" + topics_str(topics) + "\n    ")


def subqs_message(topics: list[Topic], book_pages: str = "") -> str:
    """System prompt for breaking a labelled question into 3-4 labelled subquestions.

    Layout: instructions, then (only when `book_pages` is non-empty) the book pages
    as reference material, then the TOPICS listing from `topics_str`.
    """
    instructions = deindent(
        """
You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
Based on the main question's problem and solution, break the question down into 3-4 smaller subquestions.
Answering these questions in sequence should lead to the solution of the main question.
Make sure that the answer to the last subquestion is the solution to the main question.
So the subquestions are basically the steps to solve the main question.
And if a student can solve the main question, we can assume that they have learned the underlying concepts of the subquestions.
No 2 subquestions can have the same concept.
For each subquestion:
    1. Define the 'problem'.
    2. Give a detailed 'solution'.
    3. Given these topics below, assign a 'topic' and a 'subtopic' to the subquestion.
       The 'subtopic' must be one of the subtopics of the 'topic'.
"""
    )
    reference_section = (
        f"\nYou can use these book pages for reference:\n\n{book_pages}"
        if book_pages
        else ""
    )
    return instructions + reference_section + f"\n\n{topics_str(topics)}"


def concepts_str(concepts: list[str]) -> str:
    """Prompt section listing known `concepts` (JSON, indent=2) that the model may reuse."""
    listing = json.dumps(concepts, indent=2)
    section = (
        "\nYou can use a concept from the list below or come up with a new one if needed.\n"
        "\nCONCEPTS:\n"
        f"\n{listing}\n"
    )
    return deindent(section)


def question_w_concept_message(concepts: list[str]) -> str:
    """System prompt asking the model to assign a single concept to a labelled question.

    Args:
        concepts: Concepts already recorded for the question's subtopic. When
            non-empty they are appended as a reusable list (see `concepts_str`).

    Returns:
        The prompt text.
    """
    prompt = deindent(
        f"""
You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
Assign a 'concept' to the question. The 'concept' should have at least {CONCEPT_WORD_COUNT} words.
Solving this question should help students understand this concept.
"""
    )
    if concepts:
        # Explicit blank-line separator. `deindent` may trim surrounding newlines
        # (the other prompt builders add "\n" explicitly), which would otherwise glue
        # the concepts section onto the previous sentence.
        prompt += "\n\n" + concepts_str(concepts)
    return prompt


def question_w_subqs_concept_message(concepts: list[str]) -> str:
    """System prompt asking for a concept for a main question, given its subquestions.

    The subquestion count stated here matches `subqs_message` and the 3-4 item bound
    of `QuestionsWithTopicSubtopic`.

    Args:
        concepts: Concepts already recorded for the question's subtopic. When
            non-empty they are appended as a reusable list (see `concepts_str`).

    Returns:
        The prompt text.
    """
    prompt = deindent(
        f"""
You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
It will also have 3-4 subquestions. Each subquestion will have a 'problem', a 'solution', a 'topic', a 'subtopic', and a 'concept'.
Based on the subquestions, assign a 'concept' to the main question. The 'concept' should have at least {CONCEPT_WORD_COUNT} words.
Try not to repeat the concepts of the subquestions. Because the subquestions are the steps to solve the main question.
Solving this question should help students understand this concept.
"""
    )
    if concepts:
        # Explicit blank-line separator. `deindent` may trim surrounding newlines,
        # which would otherwise glue the concepts section onto the previous sentence.
        prompt += "\n\n" + concepts_str(concepts)
    return prompt

In [None]:
def create_question_w_topic_subtopic(
    question: Question, topics: list[Topic], model: ModelName = MODEL, attempts: int = 1
) -> QuestionWithTopicSubtopic:
    """Ask the LLM to label `question` with a topic and subtopic taken from `topics`.

    Args:
        question: The raw question to label.
        topics: Allowed topics; also passed as the validation context.
        model: Claude models (HAIKU/SONNET/OPUS) use `ask_cld`, anything else `ask_oai`.
        attempts: instructor `max_retries`.

    Returns:
        The question fields merged with the returned topic/subtopic, built with
        `model_construct` (no re-validation).
    """
    system_prompt = topic_subtopic_message(topics)
    user_prompt = deindent(f"Question:\n\n{question.model_dump_json(indent=2)}")
    request_options = {
        "max_tokens": MAX_TOKENS,
        "model": model,
        "response_model": TopicSubtopic,
        "max_retries": attempts,
        "validation_context": {"topics": topics},
    }
    uses_claude = model in [ModelName.HAIKU, ModelName.SONNET, ModelName.OPUS]
    if uses_claude:
        labels = ask_cld.create(
            system=system_prompt,
            messages=[user_message(content=user_prompt)],  # type: ignore
            **request_options,  # type: ignore
        )
    else:
        labels = ask_oai.create(
            messages=[
                system_message(system_prompt),
                user_message(content=user_prompt),  # type: ignore
            ],
            **request_options,  # type: ignore
        )
    return QuestionWithTopicSubtopic.model_construct(
        **question.model_dump(), **labels.model_dump()
    )


def create_subquestions_w_topic_subtopic(
    question_w_topic_subtopic: QuestionWithTopicSubtopic,
    topics: list[Topic],
    model: ModelName = MODEL,
    attempts: int = 1,
    pdf_collection: ChromaCollection | None = None,
    n_results: int = 3,
    n_next_links: int = 2,
    n_prev_links: int = 2,
) -> QuestionsWithTopicSubtopic:
    """Ask the LLM to break a labelled question into labelled subquestions.

    When `pdf_collection` is given, the question's JSON (without its id) is used as the
    query text for `query_collection`. The retrieved documents are then added to the
    system prompt as reference book pages.

    Args:
        question_w_topic_subtopic: The main question with its topic/subtopic.
        topics: Allowed topics; also passed as the validation context.
        model: Claude models (HAIKU/SONNET/OPUS) use `ask_cld`, anything else `ask_oai`.
        attempts: instructor `max_retries`.
        pdf_collection: Optional Chroma collection to retrieve reference pages from.
        n_results, n_next_links, n_prev_links: Forwarded to `query_collection`.

    Returns:
        The validated `QuestionsWithTopicSubtopic` response.
    """
    reference_pages = ""
    if pdf_collection is not None:
        query = question_w_topic_subtopic.model_dump_json(exclude={"id"}, indent=2)
        hits, _ = query_collection(
            query_text=query,
            collection=pdf_collection,
            n_results=n_results,
            n_next_links=n_next_links,
            n_prev_links=n_prev_links,
        )
        # Each result's documents are newline-joined; results are blank-line separated.
        reference_pages = "\n\n".join("\n".join(hit["documents"]) for hit in hits)  # type: ignore
    system_prompt = subqs_message(topics=topics, book_pages=reference_pages)
    user_prompt = deindent(
        f"Question with Topic and Subtopic:\n\n{question_w_topic_subtopic.model_dump_json(indent=2)}"  # type: ignore
    )
    request_options = {
        "max_tokens": MAX_TOKENS,
        "model": model,
        "response_model": QuestionsWithTopicSubtopic,
        "max_retries": attempts,
        "validation_context": {"topics": topics},
    }
    if model in [ModelName.HAIKU, ModelName.SONNET, ModelName.OPUS]:
        return ask_cld.create(
            system=system_prompt,
            messages=[user_message(content=user_prompt)],  # type: ignore
            **request_options,  # type: ignore
        )
    return ask_oai.create(
        messages=[
            system_message(system_prompt),
            user_message(content=user_prompt),  # type: ignore
        ],
        **request_options,  # type: ignore
    )


def get_topic_subtopic_from_question(
    question: QuestionWithTopicSubtopic, topics: list[Topic]
) -> tuple[Topic, Subtopic, int, int]:
    """Locate the `Topic` and `Subtopic` named by `question`, with their indices.

    Args:
        question: Anything exposing `.topic` and `.subtopic` names.
        topics: The topic tree to search.

    Returns:
        `(topic, subtopic, topic_index, subtopic_index)`. The objects are the live
        entries of `topics`, so mutating them mutates `topics`.

    Raises:
        ValueError: If the topic is not in `topics`, or the subtopic is not under
            that topic. This replaces the bare, message-less `StopIteration` that
            `next()` would otherwise leak.
    """
    topic_match = next(
        ((i, topic) for i, topic in enumerate(topics) if topic.name == question.topic),
        None,
    )
    if topic_match is None:
        raise ValueError(f"Topic {question.topic} not found in topics")
    topic_id, topic = topic_match

    subtopic_match = next(
        (
            (i, subtopic)
            for i, subtopic in enumerate(topic.subtopics)
            if subtopic.name == question.subtopic
        ),
        None,
    )
    if subtopic_match is None:
        raise ValueError(
            f"Subtopic {question.subtopic} not found in Topic: {question.topic}"
        )
    subtopic_id, subtopic = subtopic_match
    return topic, subtopic, topic_id, subtopic_id


def create_subquestions_concepts(
    subquestions: QuestionsWithTopicSubtopic,
    topics: list[Topic],
    model: ModelName = MODEL,
    attempts: int = 1,
) -> tuple[list[QuestionWithConcept], list[Topic]]:
    """Assign a concept to each subquestion and register new concepts on its subtopic.

    For each subquestion the model sees the concepts already recorded on that
    subquestion's subtopic and returns a `Concept`. Each request is wrapped in a
    tenacity loop: up to 3 attempts with a random exponential wait bounded to
    30-60 s. If every attempt fails, the subquestion is reported and left out of
    the result.

    Args:
        subquestions: Subquestions whose topic/subtopic are already assigned.
        topics: Topic tree; new concepts are appended in place to the matching subtopic.
        model: Claude models (HAIKU/SONNET/OPUS) use `ask_cld`, anything else `ask_oai`.
        attempts: instructor `max_retries` for each request.

    Returns:
        `(subquestions_with_concepts, topics)`. Each subquestion that got a concept is
        paired with its own concept, in the original order.
    """
    uses_claude = model in [ModelName.HAIKU, ModelName.SONNET, ModelName.OPUS]
    # Identical for every request, so build it once.
    ask_kwargs = dict(
        max_tokens=MAX_TOKENS,
        model=model,
        response_model=Concept,  # type: ignore
        max_retries=attempts,
        validation_context={"topics": topics},
    )
    subquestions_w_concepts: list[QuestionWithConcept] = []
    for subquestion_idx, subquestion in enumerate(subquestions.questions):
        _, subtopic, topic_id, subtopic_id = get_topic_subtopic_from_question(
            subquestion, topics
        )
        # Recomputed per subquestion: an earlier iteration may have added a concept
        # to this same subtopic.
        subtopic_concepts = [c.concept for c in subtopic.concepts]
        concept_prompt = question_w_concept_message(concepts=subtopic_concepts)
        subquestion_message = deindent(
            f"Question:\n\n{subquestion.model_dump_json(indent=2)}"
        )
        try:
            for attempt in Retrying(
                wait=wait_random_exponential(min=30, max=60), stop=stop_after_attempt(3)
            ):
                with attempt:
                    if uses_claude:
                        subquestion_concept = ask_cld.create(
                            system=concept_prompt,
                            messages=[user_message(content=subquestion_message)],  # type: ignore
                            **ask_kwargs,  # type: ignore
                        ).concept
                    else:
                        subquestion_concept = ask_oai.create(
                            messages=[
                                system_message(concept_prompt),
                                user_message(content=subquestion_message),  # type: ignore
                            ],
                            **ask_kwargs,  # type: ignore
                        ).concept
        except RetryError as e:
            print(f"Failed to generate concept for subquestion: {subquestion_idx}: {e}")
            continue
        # Pair the concept with its subquestion immediately. Collecting concepts
        # separately and zipping afterwards mis-aligns every later subquestion as soon
        # as one of them is skipped above.
        subquestions_w_concepts.append(
            QuestionWithConcept.model_construct(
                **subquestion.model_dump(), concept=subquestion_concept
            )
        )
        if subquestion_concept not in subtopic_concepts:
            subtopic.concepts.append(
                ConceptWithQuestionIDs(concept=subquestion_concept)
            )
        topics[topic_id].subtopics[subtopic_id] = subtopic
    return subquestions_w_concepts, topics


def construct_question_w_topic_subtopic_and_subquestions_w_concepts(
    question_w_topic_subtopic: QuestionWithTopicSubtopic,
    subquestions_w_concepts: list[QuestionWithConcept],
) -> QuestionWithTopicSubtopicAndConceptSubquestions:
    """Attach concept-labelled subquestions to a topic-labelled question.

    Built with `model_construct`, so no validation (and no topics context) is needed.
    """
    base_fields = question_w_topic_subtopic.model_dump()
    return QuestionWithTopicSubtopicAndConceptSubquestions.model_construct(
        subquestions=subquestions_w_concepts, **base_fields
    )


def create_question_w_concept_and_subquestions(
    question: QuestionWithTopicSubtopicAndConceptSubquestions,
    topics: list[Topic],
    model: ModelName = MODEL,
    attempts: int = 1,
) -> QuestionWithConceptAndSubquestions:
    """Ask the LLM for the main question's concept, given its concept-labelled subquestions.

    The concepts already recorded on the question's subtopic are offered for reuse.

    Args:
        question: Main question (topic/subtopic) with concept-labelled subquestions.
        topics: Topic tree used to look up the subtopic; also the validation context.
        model: Claude models (HAIKU/SONNET/OPUS) use `ask_cld`, anything else `ask_oai`.
        attempts: instructor `max_retries`.

    Returns:
        The question with its new `concept` and the same subquestions, built with
        `model_construct` (no re-validation).
    """
    _, subtopic, _, _ = get_topic_subtopic_from_question(question, topics)
    known_concepts = [entry.concept for entry in subtopic.concepts]
    system_prompt = question_w_subqs_concept_message(concepts=known_concepts)
    user_prompt = deindent(
        f"Question with Subquestions:\n\n{question.model_dump_json(indent=2)}"
    )
    request_options = {
        "max_tokens": MAX_TOKENS,
        "model": model,
        "response_model": Concept,
        "max_retries": attempts,
        "validation_context": {"topics": topics},
    }
    if model in [ModelName.HAIKU, ModelName.SONNET, ModelName.OPUS]:
        response = ask_cld.create(
            system=system_prompt,
            messages=[user_message(content=user_prompt)],  # type: ignore
            **request_options,  # type: ignore
        )
    else:
        response = ask_oai.create(
            messages=[
                system_message(system_prompt),
                user_message(content=user_prompt),  # type: ignore
            ],
            **request_options,  # type: ignore
        )
    question_concept = response.concept
    base_fields = question.model_dump(exclude={"subquestions"})
    return QuestionWithConceptAndSubquestions.model_construct(
        **base_fields,
        concept=question_concept,
        subquestions=question.subquestions,
    )


def update_topic_subtopic_concepts(
    question: QuestionWithConceptAndSubquestions, topics: list[Topic]
) -> list[Topic]:
    """Record `question.id` under `question.concept` on the question's subtopic.

    A new `ConceptWithQuestionIDs` entry is created when the concept is not yet listed.
    Otherwise the id is appended to the first existing entry with that concept.
    `topics` is mutated in place and returned.
    """
    _, subtopic, topic_id, subtopic_id = get_topic_subtopic_from_question(
        question, topics
    )
    existing = next(
        (entry for entry in subtopic.concepts if entry.concept == question.concept),
        None,
    )
    if existing is None:
        subtopic.concepts.append(
            ConceptWithQuestionIDs(
                concept=question.concept,
                question_ids=[question.id],
            )
        )
    else:
        existing.question_ids.append(question.id)
    topics[topic_id].subtopics[subtopic_id] = subtopic
    return topics

In [None]:
# Source PDF for the reference-page collection (consumed by the commented-out
# `pdf_to_collection` cell). The default is machine-specific: set ALGEBRA_PDF_FILE
# (e.g. in .env, which load_dotenv() reads) to point elsewhere.
pdf_file = os.environ.get("ALGEBRA_PDF_FILE", "/media/hamza/data2/algebra.pdf")
collection_name = "algebra_collection"

In [None]:
# pdf_collection = pdf_to_collection(
#     pdf_file,
#     collection_name=collection_name,
#     chunk_size=4000,
#     chunk_overlap=200,
#     device="cuda",
#     delete_existing=True,
# )

In [None]:
# Open the reference collection without deleting it first, and display its count.
pdf_collection = chroma_collection(delete_existing=False, name=collection_name)
pdf_collection.count()

In [None]:
# Load the topic tree. If a previous run already saved enriched topics (concepts
# and question ids) to math_102_final_topics.json, resume from that file. The
# generation loop skips questions that are already saved, so restarting from the
# bare topic list would drop their concepts the next time the final topics file
# is overwritten. Delete both the final topics file and the final questions
# folder to start from scratch.
resume_topics_file = Path("math_102_final_topics.json")
if resume_topics_file.exists():
    topics = [
        Topic.model_validate(topic)
        for topic in json.loads(resume_topics_file.read_text())
    ]
else:
    # Bare file format: [{"name": ..., "subtopics": ["<name>", ...]}, ...]
    raw_topics = json.loads(Path("math_102_topics.json").read_text())
    topics = [
        Topic(
            name=topic["name"],
            subtopics=[Subtopic(name=subtopic) for subtopic in topic["subtopics"]],
        )
        for topic in raw_topics
    ]
topics[0].model_dump()

In [None]:
# Load up to QUESTIONS_PER_FOLDER questions from each topic folder. Folders and
# files are sorted because iterdir()/glob() order is filesystem-dependent, so the
# truncated subset would otherwise differ between machines and runs.
questions_dir = Path("math_102_questions")


def _question_sort_key(path: Path) -> tuple[bool, int, str]:
    """Order question files by numeric stem (2.json before 10.json); non-numeric names last."""
    is_numeric = path.stem.isdecimal()
    return (not is_numeric, int(path.stem) if is_numeric else 0, path.name)


questions = [
    Question(**json.loads(question_file.read_text()))
    for folder in sorted(questions_dir.iterdir())
    if folder.is_dir()
    for question_file in sorted(folder.glob("*.json"), key=_question_sort_key)[
        :QUESTIONS_PER_FOLDER
    ]
]

In [None]:
final_dir = Path("math_102_final_questions")
os.makedirs(final_dir, exist_ok=True)
final_topics = Path("math_102_final_topics.json")
missed_questions = []


def _process_question(
    question: Question, topics: list[Topic]
) -> tuple[QuestionWithConceptAndSubquestions, list[Topic]]:
    """Run every generation stage for one question.

    The stages are:
    1. topic/subtopic (ModelName.GPT_3);
    2. subquestions with book-page retrieval from `pdf_collection` (ModelName.GPT_4);
    3. subquestion concepts (ModelName.GPT_3);
    4. the main-question concept (ModelName.GPT_3);
    5. registering that concept on the question's subtopic.

    `topics` is mutated in place and also returned.
    """
    question_w_topic_subtopic = create_question_w_topic_subtopic(
        question=question,
        topics=topics,
        attempts=ATTEMPTS,
        model=ModelName.GPT_3,
    )
    subquestions = create_subquestions_w_topic_subtopic(
        question_w_topic_subtopic=question_w_topic_subtopic,
        topics=topics,
        attempts=ATTEMPTS,
        pdf_collection=pdf_collection,
        model=ModelName.GPT_4,
    )
    subquestions_w_concepts, topics = create_subquestions_concepts(
        subquestions=subquestions,
        topics=topics,
        attempts=ATTEMPTS,
        model=ModelName.GPT_3,
    )
    question_w_subquestions = construct_question_w_topic_subtopic_and_subquestions_w_concepts(
        question_w_topic_subtopic=question_w_topic_subtopic,
        subquestions_w_concepts=subquestions_w_concepts,
    )
    final_question = create_question_w_concept_and_subquestions(
        question=question_w_subquestions,
        topics=topics,
        attempts=ATTEMPTS,
        model=ModelName.GPT_3,
    )
    topics = update_topic_subtopic_concepts(question=final_question, topics=topics)
    return final_question, topics


# Ids of questions already written by an earlier run; scanned once instead of on
# every iteration.
done_ids = {int(path.stem) for path in final_dir.glob("*.json")}
for i, question in enumerate(questions):
    if int(question.id) in done_ids:
        continue
    # Run the pipeline on a deep copy of the topic tree and adopt it only once the
    # question is saved. A failure part-way through (e.g. after subquestion concepts
    # were registered) then cannot leave orphan concepts in `topics`.
    working_topics = [topic.model_copy(deep=True) for topic in topics]
    try:
        final_question, working_topics = _process_question(question, working_topics)
        with open(final_dir / f"{final_question.id}.json", "w") as f:
            json.dump(final_question.model_dump(), f, indent=2)
        topics = working_topics
        done_ids.add(int(final_question.id))
        with open(final_topics, "w") as f:
            json.dump(
                [topic.model_dump() for topic in topics],
                f,
                indent=2,
            )
    except Exception as e:
        # Deliberate best effort: record the failure and continue with the next question.
        missed_questions.append([i, e])
        print(f"Failed to generate question: {i}: {e}")