# Imports

In [37]:
# standard library imports
import os
import random
from typing import Callable

# related third party imports
import dotenv
import pandas as pd
from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.output_parsers import PydanticOutputParser
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, ValidationError
from yacs.config import CfgNode

# local application/library specific imports
from example_selector.example_selector import (
    RandomExampleSelector,
    StudentIDExampleSelector,
)
from data_loader.data_loader import DataLoader
from tools.constants import SILVER_DIR, TRAIN, VALIDATION, TEST
from prompt_template.example_prompt import (
    df_to_listdict,
    human_format_input,
    human_format_output,
    apply_prompt_fmt,
)
from model.build import build_model

# Re-read '../.env'; values from the file replace variables that are already set
dotenv.load_dotenv(dotenv_path="../.env", override=True)

True

In [38]:
# Per-model flag: True when the model is wrapped with `with_structured_output`,
# False when the notebook falls back to format instructions plus an output parser.
MODEL_STRUCTURED_OUTPUT = dict(
    [
        ("llama3", False),
        ("llama3.2", True),
        ("olmo2:7b", False),
        ("gpt-4o", True),
        ("gpt-4o-mini", True),
    ]
)

In [39]:
### INPUTS ###
# Model to run. It must be a key of MODEL_STRUCTURED_OUTPUT: the lookup below
# raises KeyError for any name that is not listed there.
MODEL_NAME = "olmo2:7b"  # alternatives: "llama3", "gpt-4o-mini", "llama3.2"
MODEL_PROVIDER = "ollama"  # alternative: "openai"
# Flag looked up for the selected model; later cells branch on it.
SUPPORTS_STRUCTURED_OUTPUT = MODEL_STRUCTURED_OUTPUT[MODEL_NAME]

In [40]:
# Settings of the model under test, wrapped in a yacs node (later handed to `build_model`).
model_cfg = CfgNode(
    dict(
        NAME=MODEL_NAME,
        PROVIDER=MODEL_PROVIDER,
        TEMPERATURE=0.5,
        FORMAT="json",
        MAX_TOKENS=None,
        TIMEOUT=None,
        MAX_RETRIES=None,
    )
)

# Data

In [41]:
# Read the dataset and split it into train / validation / test partitions.
data_loader = DataLoader(read_dir=SILVER_DIR, dataset_name="dbe_kt22")
dataset = data_loader.split_data(train_size=0.6, test_size=0.25, seed=42)

# Apply the same input/output formatters to every split (order: train, validation, test).
df_train, df_val, df_test = (
    apply_prompt_fmt(
        df=dataset[split],
        input_fmt=human_format_input,
        output_fmt=human_format_output,
    )
    for split in (TRAIN, VALIDATION, TEST)
)

# Convert each formatted dataframe with `df_to_listdict` (same order as above).
list_train, list_val, list_test = (
    df_to_listdict(frame) for frame in (df_train, df_val, df_test)
)

[2m2025-03-18 11:59:10[0m [[32m[1minfo     [0m] [1mSet seed (42)                 [0m
[2m2025-03-18 11:59:10[0m [[32m[1minfo     [0m] [1mCreating train split          [0m [36mnum_interactions[0m=[35m6107[0m
[2m2025-03-18 11:59:10[0m [[32m[1minfo     [0m] [1mCreating validation split     [0m [36mnum_interactions[0m=[35m1528[0m
[2m2025-03-18 11:59:11[0m [[32m[1minfo     [0m] [1mCreating test split           [0m [36mnum_interactions[0m=[35m2546[0m


# Dynamic few-shot prompting

## Create example selector

NOTE: I need OpenAI credits to use the OpenAI embeddings.

In [42]:
# examples = few_shot_list
# to_vectorize = [" ".join(example.values()) for example in examples]
# embeddings = OpenAIEmbeddings()
# vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

In [43]:
# example_selector = SemanticSimilarityExampleSelector(
#     vectorstore=vectorstore,
#     k=2,
# )

# # The prompt template will load examples by passing the input to the `select_examples` method
# example_selector.select_examples({"input": "horse"})

In [44]:
# # Create the selector with k=3 for 3-shot prompting
# example_selector = RandomExampleSelector(examples=list_train, k=3)
# example_selector.select_examples({})

In [45]:
# Build a selector over the training examples with k=3, then preview what it
# returns when queried with student_id 395.
# NOTE(review): presumably this yields up to k training examples belonging to
# that student; confirm against StudentIDExampleSelector.
example_selector = StudentIDExampleSelector(examples=list_train, k=3)
example_selector.select_examples({"student_id": 395})

[{'input': 'Question:\nConsider the following two transactions <img src="http://latex.codecogs.com/gif.latex?T_{1}" border="0"/> and <img src="http://latex.codecogs.com/gif.latex?T_{2}" border="0"/>, which transfer money between\r\ndifferent accounts. If the transaction isolation level is “read uncommitted”, which of the\r\nfollowing schedules is not serializable?<br>\r\n<html>\r\n<head>\r\n<meta name="viewport" content="width=device-width, initial-scale=1">\r\n<style>\r\n* {\r\n  box-sizing: border-box;\r\n}\r\n.column {\r\n  float: left;\r\n  width: 20%;\r\n  padding:9px;\r\n  height: 400px; \r\n}\r\n.row:after {\r\n  content: "";\r\n  display: table;\r\n  clear: both;\r\n}\r\n</style>\r\n</head>\r\n<body>\r\n<div class="row"display= "table">\r\n  <div class="column" >\r\n    <h2>1</h2>\r\n   <table style="padding: 20px;text-align: center;width: 90%;" border="1">\r\n\r\n  <tr>\r\n    <th><center><img src="http://latex.codecogs.com/gif.latex?T_{1}" border="0"/></th>\r\n    <th><center

## Create prompt template

In [46]:
# Pydantic schema describing the reply expected from the model.
# It is passed to `with_structured_output` when the model supports it, and to
# `PydanticOutputParser` otherwise.
# NOTE: pydantic copies the class docstring and the Field descriptions into the
# generated JSON schema, so rewording them changes what the model is shown.
class MCQAnswer(BaseModel):
    """Answer to a multiple-choice question."""

    # Free-text reasoning behind the selected option.
    explanation: str = Field(
        description="Misconception if incorrectly answered; motivation if correctly answered"
    )
    # Selected option. The description asks for 1-4, but no range constraint is
    # declared on the field, so other integers pass validation.
    student_answer: int = Field(
        description="The student's answer to the question, as an integer (1-4)"
    )
    # difficulty: str = Field(description="The difficulty level of the question")

In [50]:
# Few-shot section of the prompt, filled with examples chosen by `example_selector`.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    # Every selected example is rendered as two messages: one human, one AI.
    example_prompt=ChatPromptTemplate.from_messages(
        [
            ("human", "{input}"),
            ("ai", "{output}"),
        ]
    ),
    example_selector=example_selector,
    # Keys of the invocation input whose values are passed to the example selector.
    input_variables=["student_id"],
)

# Sanity check: render the few-shot messages for the first validation example.
out = few_shot_prompt.invoke(list_val[0]).to_messages()
print(len(out), out, sep="\n")

6
[HumanMessage(content='Question:\nThe order of tuples in a relation is ____. The order of attributes in a relation is _________.\n\nOptions:\n1. Not important; Not important\n2. Important; Important\n3. Not important; Important\n4. Important; Not important\n\nCorrect answer: 3', additional_kwargs={}, response_metadata={}), AIMessage(content='Student answer: 3', additional_kwargs={}, response_metadata={}), HumanMessage(content='Question:\nIf A = {2, 3, 4, 5}, B = {4, 5, 6, 7}, C = {6, 7, 8, 9}, D = {8, 9, 10, 11}, then A - B=___________.\n\nOptions:\n1. {2,3}\n2. {2,3,4,5}\n3. {4,5,6,7}\n4. {6,7}\n\nCorrect answer: 1', additional_kwargs={}, response_metadata={}), AIMessage(content='Student answer: 1', additional_kwargs={}, response_metadata={}), HumanMessage(content='Question:\nIf A × B = {(p, x), (p, y), (q, x), (q, y)}, then A = _____and B = ______.\n\nOptions:\n1. A = {p, q, x} and B = { y}\n2. A = {p} and B = {q, x, y}\n3. A ={x, y} and B = {p, q}\n4. A = {p, q} and B = {x, y}\n\n

In [52]:
# System instructions shared by all models. `{exam_type}` is filled in below.
system_prompt_raw = (
    "You are a student working on {exam_type}, containing multiple choice questions. "
    "You are shown a set of questions that you answered earlier in the exam, together with the correct answers and your student answers. "
    "Analyse your responses to the questions and identify the possible misconceptions that led to answering incorrectly. "
    "Inspect the new question and think how you would answer it as a student. "
    "If you answer incorrectly, explain which misconception leads to selecting that answer. "
    "If you answer correctly, explain why you think the answer is correct. "
    "Provide your answer as an integer in the range 1-4. "
)

# Template variables that are bound once here instead of on every invocation.
partial_variables = {
    "exam_type": "a database systems exam (Department of Computer Science)",
}
if not SUPPORTS_STRUCTURED_OUTPUT:
    # Without structured-output support, describe the expected JSON in the
    # prompt and parse the reply with a Pydantic parser.
    parser = PydanticOutputParser(pydantic_object=MCQAnswer)
    system_prompt_raw += "Wrap the output in `json` tags\n{format_instructions}"
    # Only bind `format_instructions` in this branch: `parser` does not exist
    # otherwise, and the template has no such placeholder in that case.
    partial_variables["format_instructions"] = parser.get_format_instructions()

# Full prompt: system instructions, the few-shot examples, then the new question.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt_raw),
        few_shot_prompt,
        ("human", "{input}"),
    ]
).partial(**partial_variables)

# Render the prompt once for the first validation example, so the printed text
# and the message list below come from the same example selection.
prompt_value = final_prompt.invoke(list_val[0])
print(prompt_value.to_string())
out = prompt_value.to_messages()
print(len(out))
print(out)

System: You are a student working on a database systems exam (Department of Computer Science), containing multiple choice questions. You are shown a set of questions that you answered earlier in the exam, together with the correct answers and your student answers. Analyse your responses to the questions and identify the possible misconceptions that led to answering incorrectly. Inspect the new question and think how you would answer it as a student. If you answer incorrectly, explain which misconception leads to selecting that answer. If you answer correctly, explain why you think the answer is correct. Provide your answer as an integer in the range 1-4. Wrap the output in `json` tags
The output should be formatted as a JSON instance that conforms to the JSON schema below.

As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
the object {"foo": ["bar", "baz"]} is a 

# Model

In [53]:
# Build the chat model described by `model_cfg`.
model = build_model(model_cfg=model_cfg)

# Assemble the chain: prompt -> model, with the reply turned into an `MCQAnswer`
# either by the model wrapper or by the output parser.
if SUPPORTS_STRUCTURED_OUTPUT:
    model = model.with_structured_output(MCQAnswer)
    chain = final_prompt | model
else:
    chain = final_prompt | model | parser

[2m2025-03-18 11:59:11[0m [[32m[1minfo     [0m] [1mBuilding model                [0m [36mname[0m=[35molmo2:7b[0m [36mprovider[0m=[35mollama[0m


In [54]:
# run model
# Invoke the chain on the first validation example and display the result.
# NOTE(review): the model is configured with TEMPERATURE=0.5, so the answer
# presumably varies between runs; confirm before comparing outputs.
val_example = list_val[0]
val_output = chain.invoke(val_example)
val_output

MCQAnswer(explanation='The Cartesian product A \times B includes all possible pairs where the first element is from set A and the second element is from set B. Given three pairs (2, 5), (3, 7), and (4, 7), we know that there must be at least one pair for each element in A (2 and 3) and at least one pair for each element in B (5, 7). Therefore, A \times B contains the pairs: (2, 5), (3, 7), and (4, 7). The correct option includes all these pairs plus one more pair for each of the remaining elements in A (which are 1 and 4) and one pair for each element in B (which are 6 and 9). Thus, the complete set is {(2, 5), (2, 7), (3, 5), (3, 7), (4, 5), (4, 7), (1, 6), (3, 6), (5, 6), (4, 9), (7, 9)} which matches option 2.', student_answer=2)

In [None]:
def print_example(example: dict, show_output: bool = True) -> None:
    """Print a single example as banner-delimited sections.

    Parameters
    ----------
    example : dict
        Example dictionary with an 'input' key, and an 'output' key when
        ``show_output`` is True.
    show_output : bool, optional
        If True (default), print the INPUT section followed by the OUTPUT
        section. If False, print only the INPUT section; the 'output' key is
        then not accessed.

    Raises
    ------
    KeyError
        If a required key is missing from ``example``.
    """
    banner = "#" * 40
    sections = [("INPUT", example["input"])]
    if show_output:
        sections.append(("OUTPUT", example["output"]))
    # Each section is a title framed by two banner lines, followed by its body.
    text = "".join(
        f"{banner}\n{title}\n{banner}\n{body}\n" for title, body in sections
    )
    print(text)


# Display the first validation example.
print_example(list_val[0])