In [1]:
import json
from functools import partial
from pathlib import Path
from typing import Annotated

import anthropic
import instructor
import openai
from dotenv import load_dotenv
from pydantic import AfterValidator, BaseModel, Field, ValidationInfo, model_validator
from tenacity import RetryError, Retrying, stop_after_attempt, wait_random_exponential

from dreamai.ai import ModelName, count_gpt_tokens, user_message
from dreamai.utils import deindent

load_dotenv()

ask_oai = instructor.from_openai(openai.OpenAI())
ask_cld = instructor.from_anthropic(anthropic.Anthropic())

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
# Default model for every LLM helper below (each helper also accepts `model=`).
MODEL: ModelName = ModelName.HAIKU
# Minimum number of words a generated concept must have (enforced by `Concept`).
CONCEPT_WORD_COUNT: int = 3
# Maximum number of level-5 problems sampled from each MATH subject folder.
QUESTIONS_PER_FOLDER: int = 30

In [34]:
def validate_word_count(text: str, word_count: int = 3, text_name: str = "Text") -> str:
    """Return `text` unchanged if it has at least `word_count` whitespace-separated words.

    Args:
        text: The string to check.
        word_count: Minimum number of words required.
        text_name: Label used in the error message.

    Raises:
        ValueError: if `text` has fewer than `word_count` words.
    """
    n_words = len(text.split())
    if n_words >= word_count:
        return text
    raise ValueError(f"{text_name} should be at least {word_count} words long")


def validate_topic_subtopic(cls, info: ValidationInfo):
    """Check that `cls.topic` / `cls.subtopic` name a known topic and one of its subtopics.

    Shared by the `model_validator(mode="after")` hooks of the models below, so
    `cls` is actually the model *instance* being validated, not a class.

    The allowed `Topic` objects are read from the validation context under the
    "topics" key (instructor's `validation_context=` or pydantic's
    `model_validate(..., context=...)`).

    Returns:
        The instance unchanged, as an "after" model validator must.

    Raises:
        ValueError: if no "topics" context was supplied, the topic is unknown, or
            the subtopic does not belong to the topic. Pydantic wraps these in a
            ValidationError, which instructor can feed back to the LLM on retry.
    """
    # `info.context` is None when validation runs without a context; previously
    # this crashed with an AttributeError that pydantic does not convert.
    topics: list[Topic] | None = (info.context or {}).get("topics")
    if topics is None:
        raise ValueError("No 'topics' found in the validation context")
    # Single pass: find the first topic with the requested name.
    topic = next((t for t in topics if t.name == cls.topic), None)
    if topic is None:
        raise ValueError(f"Topic {cls.topic} not found in topics")
    if all(subtopic.name != cls.subtopic for subtopic in topic.subtopics):
        raise ValueError(f"Subtopic {cls.subtopic} not found in Topic: {cls.topic}")
    return cls


class Question(BaseModel):
    # A math problem with its worked solution.
    # NOTE: subclasses are nested in instructor `response_model`s, so changing field
    # names/order/types changes what the LLM is asked to produce. Keep them stable.
    id: str
    problem: str
    solution: str


class TopicSubtopic(BaseModel):
    # Response model for assigning a (topic, subtopic) pair to a question.
    topic: str
    subtopic: str

    # Requires a validation context {"topics": [Topic, ...]}; rejects unknown topics
    # and subtopics that do not belong to the chosen topic.
    @model_validator(mode="after")  # type: ignore
    def validate_topic_subtopic(self, info: ValidationInfo) -> "TopicSubtopic":
        return validate_topic_subtopic(self, info)


class QuestionWithTopicSubtopic(Question):
    # A question tagged with a topic and one of that topic's subtopics.
    topic: str
    subtopic: str

    # Same catalogue check as `TopicSubtopic`: needs {"topics": [...]} in the
    # validation context. Bypassed when instances are built with `model_construct`.
    @model_validator(mode="after")  # type: ignore
    def validate_topic_subtopic(
        self, info: ValidationInfo
    ) -> "QuestionWithTopicSubtopic":
        return validate_topic_subtopic(self, info)


class QuestionsWithTopicSubtopic(BaseModel):
    # Response model: the list of topic-tagged subquestions generated for one question.
    questions: list[QuestionWithTopicSubtopic]


class QuestionWithTopicSubtopicAndSubquestions(QuestionWithTopicSubtopic):
    # A topic-tagged question plus topic-tagged (concept-less) subquestions.
    # NOTE(review): not used in the cells shown here.
    subquestions: list[QuestionWithTopicSubtopic]


class Concept(BaseModel):
    # Response model for concept assignment. The AfterValidator rejects values with
    # fewer than CONCEPT_WORD_COUNT whitespace-separated words. The count is bound
    # when this class is defined, so re-run this cell after changing the constant.
    concept: Annotated[
        str,
        AfterValidator(
            partial(
                validate_word_count, word_count=CONCEPT_WORD_COUNT, text_name="Concept"
            )
        ),
    ]


class QuestionWithConcept(QuestionWithTopicSubtopic):
    # A (sub)question tagged with topic, subtopic and concept.
    # Unlike `Concept`, no minimum word count is enforced on `concept` here.
    concept: str


class QuestionWithTopicSubtopicAndConceptSubquestions(QuestionWithTopicSubtopic):
    # Main question (topic/subtopic, no concept yet) with concept-tagged subquestions.
    subquestions: list[QuestionWithConcept]


class QuestionWithConceptAndSubquestions(QuestionWithConcept):
    # Main question with its own concept plus its concept-tagged subquestions.
    subquestions: list[QuestionWithConcept]


class ConceptWithQuestionIDs(BaseModel):
    # A concept recorded under a subtopic, with IDs of the questions tagged with it.
    # `question_ids` may be empty (e.g. concepts that so far come only from subquestions).
    concept: str
    question_ids: list[str] = Field(default_factory=list)


class Subtopic(BaseModel):
    # A subtopic and the concepts collected for it so far (starts empty).
    name: str
    concepts: list[ConceptWithQuestionIDs] = Field(default_factory=list)


class Topic(BaseModel):
    # A course topic and its subtopics.
    name: str
    subtopics: list[Subtopic]

In [25]:
# Load the course outline (each topic has a name and a list of subtopic names) and
# wrap it in `Topic`/`Subtopic` models; every subtopic starts with no concepts.
# Read via Path so the file handle is closed, and decode explicitly as UTF-8
# (the JSON standard encoding) rather than the platform locale default.
topics_raw = json.loads(Path("math_102_topics.json").read_text(encoding="utf-8"))
topics = [
    Topic(
        name=topic["name"],
        subtopics=[Subtopic(name=subtopic) for subtopic in topic["subtopics"]],
    )
    for topic in topics_raw
]
topics[0].model_dump()

{'name': 'Set Theory',
 'subtopics': [{'name': 'Representations of Sets', 'concepts': []},
  {'name': 'Classifying Sets', 'concepts': []},
  {'name': 'Set Operations: Union and Intersection', 'concepts': []},
  {'name': 'Understanding Set Difference', 'concepts': []},
  {'name': 'Exploring Set Complement', 'concepts': []}]}

In [30]:
# Sample up to QUESTIONS_PER_FOLDER level-5 problems from each MATH subject folder
# (excluding counting_and_probability). IDs are sequential across all folders.
# NOTE(review): machine-specific absolute path; consider a DATA_DIR config constant.
data_dir = Path("/media/hamza/data2/MATH/train/")
# Sorted so the sample (and therefore the IDs) does not depend on filesystem
# listing order, which `iterdir`/`glob` do not guarantee.
folders = sorted(
    folder
    for folder in data_dir.iterdir()
    if folder.is_dir() and folder.name != "counting_and_probability"
)
questions: list[Question] = []
for folder in folders:
    folder_questions: list[Question] = []
    for question_file in sorted(folder.glob("**/*.json")):
        if len(folder_questions) >= QUESTIONS_PER_FOLDER:
            break  # quota filled: no need to read the remaining files
        record = json.loads(question_file.read_text(encoding="utf-8"))
        # Substring match on the level label (presumably "Level 5"; TODO confirm format).
        if "5" in record["level"]:
            folder_questions.append(
                Question(
                    id=str(len(questions) + len(folder_questions)),
                    problem=record["problem"],
                    solution=record["solution"],
                )
            )
    questions += folder_questions

In [36]:
# Sanity check: a hand-picked (topic, subtopic) pair passes the catalogue validator.
question: Question = questions[1]
topic = "Mapping & Functions Overview"
subtopic = "Function Operations"
QuestionWithTopicSubtopic.model_validate(
    {**question.model_dump(), "topic": topic, "subtopic": subtopic},
    context={"topics": topics},
).model_dump()

{'id': '1',
 'problem': 'What is the least positive integer value of $x$ such that $(2x)^2 + 2\\cdot 37\\cdot 2x + 37^2$ is a multiple of 47?',
 'solution': 'We note that $(2x)^2 + 2\\cdot 37 \\cdot 2x + 37^2 = (2x + 37)^2$. In order for this expression to be a multiple of 47, $2x + 37$ must be a multiple of 47. Since we want the least positive value of $x$, we will want $2x + 37 = 47$. It follows that $x = \\boxed{5}$.',
 'topic': 'Mapping & Functions Overview',
 'subtopic': 'Function Operations'}

In [37]:
# JSON dump of the whole topic/subtopic catalogue (including any concepts already
# attached), embedded in the system prompts below.
# NOTE: evaluated once, when this cell runs; concepts added to `topics` later are
# not reflected here unless the cell is re-run.
topics_str = deindent(
    f"""
TOPICS:

{json.dumps([topic.model_dump() for topic in topics], indent=2)}
"""
)

# System prompt: assign a (topic, subtopic) pair to a whole question.
topic_subtopic_message = deindent(
    f"""
You are a world class math course instructor.
You will be given a question with a 'problem' and a 'solution'.
Given these topics and subtopics below, assign a 'topic' and a 'subtopic' to the question.
The 'subtopic' must be one of the subtopics of the 'topic'.

{topics_str}
    """
)

# System prompt: break a topic-tagged question into 3-5 subquestions, each tagged
# with a topic/subtopic from the catalogue.
subqs_message = deindent(
    f"""
You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
Based on the main question's problem and solution, break the question down into 3-5 smaller subquestions.
Answering these questions in sequence should lead to the solution of the main question.
Make sure that the answer to the last subquestion is the solution to the main question.
So the subquestions are basically the steps to solve the main question.
And if a student can solve the main question, we can assume that they have learned the underlying concepts of the subquestions.
No 2 subquestions can have the same concept.
For each subquestion:
    1. Define the 'problem'.
    2. Give a detailed 'solution'.
    3. Given these topics below, assign a 'topic' and a 'subtopic' to the subquestion.
       The 'subtopic' must be one of the subtopics of the 'topic'.

{topics_str}
"""
)


def concepts_str(concepts: list[str]) -> str:
    """Render known concepts as a prompt section the model may reuse or extend.

    Args:
        concepts: Concept strings, emitted as an indented JSON array.

    Returns:
        The deindented prompt section.
    """
    concepts_json = json.dumps(concepts, indent=2)
    section = f"""
You can use a concept from the list below or come up with a new one if needed.

CONCEPTS:

{concepts_json}
"""
    return deindent(section)


def question_w_concept_message(concepts: list[str]) -> str:
    """Build the system prompt asking the model to assign a `concept` to one question.

    Args:
        concepts: Concepts already recorded for the question's subtopic. When
            non-empty they are appended so the model can reuse one.

    Returns:
        The prompt text.
    """
    prompt = deindent(
        f"""
You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
Assign a 'concept' to the question. The 'concept' should have at least {CONCEPT_WORD_COUNT} words.
Solving this question should help students understand this concept.
"""
    )
    if concepts:
        # The deindented text ends without a trailing newline (the rendered prompt
        # used to read "...this concept.You can use..."), so insert an explicit
        # paragraph break before the concepts section.
        prompt += "\n\n" + concepts_str(concepts)
    return prompt


def question_w_subqs_concept_message(concepts: list[str]) -> str:
    """Build the system prompt asking the model to assign a `concept` to a main
    question, given its already concept-tagged subquestions.

    Args:
        concepts: Concepts already recorded for the question's subtopic. When
            non-empty they are appended so the model can reuse one.

    Returns:
        The prompt text.
    """
    prompt = deindent(
        f"""
You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
It will also have 3-5 subquestions. Each subquestion will have a 'problem', a 'solution', a 'topic', a 'subtopic', and a 'concept'.
Based on the subquestions, assign a 'concept' to the main question. The 'concept' should have at least {CONCEPT_WORD_COUNT} words.
Try not to repeat the concepts of the subquestions. Because the subquestions are the steps to solve the main question.
Solving this question should help students understand this concept.
"""
    )
    if concepts:
        # The deindented text ends without a trailing newline (the rendered prompt
        # used to read "...this concept.You can use..."), so insert an explicit
        # paragraph break before the concepts section.
        prompt += "\n\n" + concepts_str(concepts)
    return prompt

In [38]:
# Preview the main-question concept prompt with a toy concept list
# (print, so the multi-line prompt renders with real newlines).
print(
    question_w_subqs_concept_message(
        concepts=["addition", "subtraction", "multiplication", "division"],
    )
)

You are a world class math course instructor.
You will be given a question with a 'problem', a 'solution', a 'topic', and a 'subtopic'.
It will also have 3-5 subquestions. Each subquestion will have a 'problem', a 'solution', a 'topic', a 'subtopic', and a 'concept'.
Based on the subquestions, assign a 'concept' to the main question. The 'concept' should have at least 3 words.
Try not to repeat the concepts of the subquestions. Because the subquestions are the steps to solve the main question.
Solving this question should help students understand this concept.You can use a concept from the list below or come up with a new one if needed.

CONCEPTS:

[
  "addition",
  "subtraction",
  "multiplication",
  "division"
]


In [9]:
tuple(map(count_gpt_tokens, (topic_subtopic_message, subqs_message)))

(1869, 2012)

In [None]:
def create_question_w_topic_subtopic(
    question: Question, topics: list[Topic], model: ModelName = MODEL, attempts: int = 1
) -> QuestionWithTopicSubtopic:
    """Have the LLM tag `question` with a topic and subtopic from `topics`.

    The reply is parsed into `TopicSubtopic`, whose validator (given `topics`
    through the validation context) rejects pairs not in the catalogue.
    `attempts` is forwarded as instructor's `max_retries`.

    Returns:
        A `QuestionWithTopicSubtopic` built with `model_construct` (no
        re-validation) from the question's fields plus the assigned pair.
    """
    user_prompt = deindent(f"Question:\n\n{question.model_dump_json(indent=2)}")
    assignment = ask_cld.create(
        model=model,
        system=topic_subtopic_message,
        messages=[user_message(content=user_prompt)],  # type: ignore
        max_tokens=2048,
        response_model=TopicSubtopic,
        max_retries=attempts,
        validation_context={"topics": topics},
    )
    question_fields = question.model_dump()
    assignment_fields = assignment.model_dump()
    return QuestionWithTopicSubtopic.model_construct(
        **question_fields, **assignment_fields
    )

In [None]:
def create_subquestions_w_topic_subtopic(
    question_w_topic_subtopic: QuestionWithTopicSubtopic,
    topics: list[Topic],
    model: ModelName = MODEL,
    attempts: int = 1,
) -> QuestionsWithTopicSubtopic:
    """Have the LLM break a topic-tagged question into subquestions.

    Every returned subquestion is a `QuestionWithTopicSubtopic`, so its
    topic/subtopic is checked against `topics` via the validation context.
    `attempts` is forwarded as instructor's `max_retries`.
    """
    serialized_question = question_w_topic_subtopic.model_dump_json(indent=2)
    user_prompt = deindent(f"Question with Topic and Subtopic:\n\n{serialized_question}")
    subquestions = ask_cld.create(
        model=model,
        system=subqs_message,
        messages=[user_message(content=user_prompt)],  # type: ignore
        max_tokens=2048,
        response_model=QuestionsWithTopicSubtopic,
        max_retries=attempts,
        validation_context={"topics": topics},
    )
    return subquestions

In [None]:
def get_topic_subtopic_from_question(
    question: QuestionWithTopicSubtopic, topics: list[Topic]
) -> tuple[Topic, Subtopic, int, int]:
    """Find the `Topic`/`Subtopic` objects named by `question.topic` / `question.subtopic`.

    Returns:
        (topic, subtopic, topic_index, subtopic_index). The returned objects are
        the ones stored in `topics`, so mutating them updates `topics` in place.

    Raises:
        ValueError: if the topic is not in `topics` or the subtopic is not under it.
            (Previously a bare `next()` leaked StopIteration, which turns into a
            confusing RuntimeError when raised inside a generator.)
    """
    topic_match = next(
        ((i, t) for i, t in enumerate(topics) if t.name == question.topic), None
    )
    if topic_match is None:
        raise ValueError(f"Topic {question.topic} not found in topics")
    topic_id, topic = topic_match

    subtopic_match = next(
        (
            (i, s)
            for i, s in enumerate(topic.subtopics)
            if s.name == question.subtopic
        ),
        None,
    )
    if subtopic_match is None:
        raise ValueError(
            f"Subtopic {question.subtopic} not found in Topic: {question.topic}"
        )
    subtopic_id, subtopic = subtopic_match
    return topic, subtopic, topic_id, subtopic_id


def create_subquestions_concepts(
    subquestions: QuestionsWithTopicSubtopic,
    topics: list[Topic],
    model: ModelName = MODEL,
    attempts: int = 1,
) -> tuple[list[QuestionWithConcept], list[Topic]]:
    """Ask the LLM for a `concept` for each subquestion and record new concepts on `topics`.

    Each subquestion's prompt lists the concepts already attached to its
    subtopic, so later subquestions can reuse concepts introduced by earlier ones.

    Args:
        subquestions: Subquestions already tagged with topic/subtopic.
        topics: Topic catalogue; subtopic concept lists are mutated in place.
        model: Model name passed to the instructor client.
        attempts: Forwarded as instructor's `max_retries`. Each call is also
            wrapped in tenacity `Retrying` (3 attempts,
            `wait_random_exponential(min=30, max=60)`).

    Returns:
        (subquestions_with_concepts, topics). A subquestion whose concept could
        not be generated is skipped (a message is printed); every returned
        subquestion is paired with its own concept.
    """
    subquestions_w_concepts: list[QuestionWithConcept] = []
    for subquestion_idx, subquestion in enumerate(subquestions.questions):
        _, subtopic, _, _ = get_topic_subtopic_from_question(subquestion, topics)
        subtopic_concepts = [c.concept for c in subtopic.concepts]
        print(f"{subtopic.model_dump_json(indent=2)}\n\n")
        concept_prompt = question_w_concept_message(concepts=subtopic_concepts)
        subquestion_message = deindent(
            f"Question:\n\n{subquestion.model_dump_json(indent=2)}"
        )
        try:
            for attempt in Retrying(
                wait=wait_random_exponential(min=30, max=60), stop=stop_after_attempt(3)
            ):
                with attempt:
                    subquestion_concept = ask_cld.create(
                        system=concept_prompt,
                        messages=[user_message(content=subquestion_message)],  # type: ignore
                        max_tokens=2048,
                        model=model,
                        response_model=Concept,  # type: ignore
                        max_retries=attempts,
                        validation_context={"topics": topics},
                    ).concept
        except RetryError as e:
            print(f"Failed to generate concept for subquestion: {subquestion_idx}: {e}")
            continue
        # Pair the subquestion with its concept right here. The old
        # zip(questions, concepts) shifted every later concept onto the wrong
        # subquestion (and dropped the last one) once any subquestion failed.
        subquestions_w_concepts.append(
            QuestionWithConcept.model_construct(
                **subquestion.model_dump(), concept=subquestion_concept
            )
        )
        # `subtopic` is the very object stored inside `topics`, so appending here
        # updates the catalogue in place (no reassignment needed).
        if subquestion_concept not in subtopic_concepts:
            subtopic.concepts.append(ConceptWithQuestionIDs(concept=subquestion_concept))
    return subquestions_w_concepts, topics

In [None]:
def create_question_w_topic_subtopic_and_subquestions_w_concepts(
    question_w_topic_subtopic: QuestionWithTopicSubtopic,
    subquestions_w_concepts: list[QuestionWithConcept],
) -> QuestionWithTopicSubtopicAndConceptSubquestions:
    """Attach concept-tagged subquestions to a topic-tagged main question.

    Built with `model_construct`, so nothing is re-validated.
    """
    main_fields = question_w_topic_subtopic.model_dump()
    return QuestionWithTopicSubtopicAndConceptSubquestions.model_construct(
        **main_fields, subquestions=subquestions_w_concepts
    )

In [43]:
def create_question_w_concept_and_subquestions(
    question: QuestionWithTopicSubtopicAndConceptSubquestions,
    topics: list[Topic],
    model: ModelName = MODEL,
    attempts: int = 1,
) -> QuestionWithConceptAndSubquestions:
    """Have the LLM assign a `concept` to a main question, given its subquestions.

    The prompt offers the concepts already recorded for the question's subtopic.
    `attempts` is forwarded as instructor's `max_retries`. `topics` is only read
    here; recording the new concept is left to `update_topic_subtopic_concepts`.

    Returns:
        A `QuestionWithConceptAndSubquestions` built with `model_construct`.
    """
    _, subtopic, _, _ = get_topic_subtopic_from_question(question, topics)
    known_concepts = [c.concept for c in subtopic.concepts]
    system_prompt = question_w_subqs_concept_message(concepts=known_concepts)
    serialized_question = question.model_dump_json(indent=2)
    user_prompt = deindent(f"Question with Subquestions:\n\n{serialized_question}")
    concept_reply = ask_cld.create(
        model=model,
        system=system_prompt,
        messages=[user_message(content=user_prompt)],  # type: ignore
        max_tokens=2048,
        response_model=Concept,
        max_retries=attempts,
        validation_context={"topics": topics},
    )
    main_fields = question.model_dump(exclude={"subquestions"})
    return QuestionWithConceptAndSubquestions.model_construct(
        **main_fields,
        concept=concept_reply.concept,
        subquestions=question.subquestions,
    )


def update_topic_subtopic_concepts(
    topics: list[Topic], question: QuestionWithConceptAndSubquestions
) -> list[Topic]:
    """Record `question.concept` and `question.id` under the question's subtopic.

    If the concept is new for that subtopic, a `ConceptWithQuestionIDs` entry is
    added. Otherwise the question's ID is appended to the existing entry, unless
    it is already listed, so re-running a cell for the same question does not
    duplicate IDs.

    `topics` is mutated in place (the subtopic returned by
    `get_topic_subtopic_from_question` is the object stored in `topics`) and
    also returned.
    """
    _, subtopic, _, _ = get_topic_subtopic_from_question(question, topics)
    existing = next(
        (entry for entry in subtopic.concepts if entry.concept == question.concept),
        None,
    )
    if existing is None:
        subtopic.concepts.append(
            ConceptWithQuestionIDs(
                concept=question.concept,
                question_ids=[question.id],
            )
        )
    elif question.id not in existing.question_ids:
        existing.question_ids.append(question.id)
    return topics