# Imports

In [12]:
# standard library imports
import os
import random
from typing import Callable

# related third party imports
import dotenv
import pandas as pd
import numpy as np
import structlog
from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_ollama import OllamaEmbeddings
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.documents import Document
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, ValidationError
from yacs.config import CfgNode
from sklearn.metrics import accuracy_score

# local application/library specific imports
from example_selector.example_selector import (
    RandomExampleSelector,
    StudentIDRandomExampleSelector,
)
from data_loader.data_loader import DataLoader
from tools.constants import SILVER_DIR, TRAIN, VALIDATION, TEST, MODEL_STRUCTURED_OUTPUT
from prompt.few_shot_prompt import (
    df_to_listdict,
)
from model.build import build_model
from example_formatter.build import build_example_formatter

# Structured logger shared by the cells and helpers below
logger = structlog.get_logger()

# Reload the variables in the '.env' file one directory up, overriding variables
# already set in the environment (presumably provides PINECONE_API_KEY, which is
# read from os.environ later — verify)
dotenv.load_dotenv("../.env", override=True)

True

In [2]:
### INPUTS ###
# LLM to query; other options: "olmo2:7b", "gpt-4o-mini", "llama3.2"
MODEL_NAME = "llama3"
# Provider serving MODEL_NAME; other option: "openai"
MODEL_PROVIDER = "ollama"
# Per-model flag: can the model return structured output natively?
SUPPORTS_STRUCTURED_OUTPUT = MODEL_STRUCTURED_OUTPUT[MODEL_NAME]

In [3]:
# LLM configuration consumed by `build_model`
model_cfg = CfgNode(
    dict(
        NAME=MODEL_NAME,
        PROVIDER=MODEL_PROVIDER,
        TEMPERATURE=0.5,
        MAX_TOKENS=None,
        TIMEOUT=None,
        MAX_RETRIES=None,
    )
)
# Example formatter configuration consumed by `build_example_formatter`
example_formatter_cfg = CfgNode(dict(NAME="A"))

# Data

In [4]:
# load data and split it reproducibly (seed 42); per the logged split sizes,
# the remaining ~15% of interactions form the validation split
data_loader = DataLoader(
    read_dir=SILVER_DIR,
    dataset_name="dbe_kt22",
    join_key="question_id",
)
datasets = data_loader.split_data(seed=42, train_size=0.6, test_size=0.25)


[2m2025-03-27 09:32:26[0m [[32m[1minfo     [0m] [1mSet seed (42)                 [0m
[2m2025-03-27 09:32:26[0m [[32m[1minfo     [0m] [1mCreating train split          [0m [36mnum_interactions[0m=[35m1967[0m
[2m2025-03-27 09:32:26[0m [[32m[1minfo     [0m] [1mCreating validation split     [0m [36mnum_interactions[0m=[35m492[0m
[2m2025-03-27 09:32:26[0m [[32m[1minfo     [0m] [1mCreating test split           [0m [36mnum_interactions[0m=[35m820[0m


In [5]:

# # dataframes
# df_train = apply_prompt_fmt(
#     df=dataset[TRAIN], input_fmt=human_format_input, output_fmt=human_format_output
# )
# df_val = apply_prompt_fmt(
#     df=dataset[VALIDATION], input_fmt=human_format_input, output_fmt=human_format_output
# )
# df_test = apply_prompt_fmt(
#     df=dataset[TEST], input_fmt=human_format_input, output_fmt=human_format_output
# )

# # list of dicts
# list_train = df_to_listdict(df_train)
# list_val = df_to_listdict(df_val)
# list_test = df_to_listdict(df_test)

In [6]:
# format every split into prompt-ready input/output columns (dataframes)
datasets_fmt = build_example_formatter(
    example_formatter_cfg=example_formatter_cfg, datasets=datasets
)

# convert each formatted split into a list of dicts (one dict per row)
list_train, list_val, list_test = (
    df_to_listdict(datasets_fmt[split]) for split in (TRAIN, VALIDATION, TEST)
)

[2m2025-03-27 09:32:26[0m [[32m[1minfo     [0m] [1mBuilding example formatter    [0m [36mname[0m=[35mA[0m [36msplits[0m=[35m['train', 'validation', 'test'][0m


In [7]:
# Preview the first rows of the formatted validation split
datasets_fmt[VALIDATION].head(n=5)

Unnamed: 0,input,output,student_id,question_id,interact_id,q_text
6,Question:\nWhat is the Cartesian product of A ...,Student answer: 0,86,4,787,"What is the Cartesian product of A = {1, 2} an..."
10,"Question:\nIf A × B = {(p, x), (p, y), (q, x),...",Student answer: 1,86,8,791,"If A × B = {(p, x), (p, y), (q, x), (q, y)}, t..."
11,"Question:\nIf A = {2, 3, 4, 5}, B = {4, 5, 6, ...",Student answer: 1,86,9,792,"If A = {2, 3, 4, 5}, B = {4, 5, 6, 7}, C = {6,..."
15,"Question:\nIf A = {2, 3, 4, 5}, B = {4, 5, 6, ...",Student answer: 3,86,13,796,"If A = {2, 3, 4, 5}, B = {4, 5, 6, 7}, C = {6,..."
21,Question:\nConsider a database that stores nam...,Student answer: 1,31,95,7873,"Consider a database that stores names, address..."


# Dynamic few-shot prompting

## Create example selector

NOTE: OpenAI embeddings require OpenAI API credits, so local Ollama embeddings are used below instead.

In [8]:
import time
from pinecone import Pinecone, ServerlessSpec
from langchain_ollama import OllamaEmbeddings
from langchain_pinecone import PineconeVectorStore


index_name = "llama3"  # change if desired

# Single embedding-model name used both to size the index and to embed
# documents/queries, so the index dimension always matches the vectors.
EMBEDDING_MODEL = "llama3"  # TODO: make dynamic

pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))

# Output dimension of each supported embedding model
EMBEDDINGS_DIM = {"llama3": 4096}

# Create the serverless index on first use and block until it reports ready
if not pc.has_index(index_name):
    pc.create_index(
        name=index_name,
        dimension=EMBEDDINGS_DIM[EMBEDDING_MODEL],
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)

index = pc.Index(index_name)

# Same model that determined the index dimension above
embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)

vector_store = PineconeVectorStore(index=index, embedding=embeddings, namespace="dbe_kt22")

In [None]:
# # prepare data for vector store
# vector_input_df = datasets[TRAIN].drop_duplicates(subset="question_id")

# vector_input_doc = [
#     Document(
#         page_content=row["q_text"],
#         metadata={
#             "question_id": str(row["question_id"]),
#         },
#     )
#     for _, row in vector_input_df.iterrows()
# ]
# vector_input_id = vector_input_df["question_id"].astype(str).tolist()

# print(vector_input_doc)
# len(vector_input_doc)
# print(vector_input_id)

[Document(metadata={'question_id': '3'}, page_content='The set that consists of all odd positive integers less than 10 is represented by _____________.'), Document(metadata={'question_id': '4'}, page_content='What is the Cartesian product of A = {1, 2} and B = {a, b}?'), Document(metadata={'question_id': '5'}, page_content='The Cartesian product B x A is always equal to the Cartesian product A x B. Is it true or false?'), Document(metadata={'question_id': '2'}, page_content='A __________ is a collection of distinct elements.\r\n\r\n\r\n'), Document(metadata={'question_id': '6'}, page_content='The cardinality of a set is the number of elements of the set. What is the cardinality of the set of odd positive integers less than 10?'), Document(metadata={'question_id': '7'}, page_content='Which of the following two sets are equal?'), Document(metadata={'question_id': '10'}, page_content='If A = {2, 3, 4, 5}, B = {4, 5, 6, 7}, C = {6, 7, 8, 9}, D = {8, 9, 10, 11}, then A - (A ∩ B)=________.')

In [None]:
# _ = vector_store.add_documents(documents=vector_input_doc, ids=vector_input_id)

In [None]:
# "llama-text-embed-v2"

In [None]:
# vector_store.delete(delete_all=True)

In [None]:
from langchain_core.example_selectors.base import BaseExampleSelector
from langchain_pinecone import PineconeVectorStore
from langchain_openai import OpenAIEmbeddings
from pinecone import Pinecone


def get_vector_store(
    index_name: str, embedding_name: str, namespace: str
) -> PineconeVectorStore:
    """Get the Pinecone vector store.

    Connects to Pinecone using the ``PINECONE_API_KEY`` environment variable and
    wraps an existing index with Ollama embeddings for querying.

    Parameters
    ----------
    index_name : str
        Name of an existing Pinecone index.
    embedding_name : str
        Name of the Ollama embedding model used to embed queries.
    namespace : str
        Index namespace.

    Returns
    -------
    PineconeVectorStore
        The Pinecone vector store.

    Raises
    ------
    ValueError
        If the index does not exist.
    """
    pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
    if not pc.has_index(index_name):
        raise ValueError(f"Index {index_name} does not exist.")
    index = pc.Index(index_name)
    # TODO: how to handle different embedding providers? (only Ollama for now)
    embeddings = OllamaEmbeddings(model=embedding_name)
    vector_store = PineconeVectorStore(
        index=index, embedding=embeddings, namespace=namespace
    )
    logger.info(
        "Loaded Pinecone vector store", index_name=index_name, namespace=namespace
    )
    return vector_store


# Maps an LLM name (e.g. model_cfg.NAME) to the Ollama embedding model passed to
# `get_vector_store`; add an entry for every LLM the selector should support.
EMBEDDING_NAMES = {"llama3": "llama3"}


class StudentIDSemanticExampleSelector(BaseExampleSelector):
    """Filter examples of the same student_id and select based on semantic similarity.

    Candidates are the other questions answered by the target student; among
    them, the ``k`` questions whose text is most similar to the target question
    (according to a Pinecone vector store) are selected, and one interaction of
    the student is returned per selected question.
    """

    def __init__(
        self, examples: list, k: int, index_name: str, model_name: str, namespace: str
    ) -> None:
        """Initialize the example selector.

        Parameters
        ----------
        examples : list
            Pool of candidate interactions (list of dicts). Each dict must
            contain at least the keys ``student_id`` and ``question_id``.
        k : int
            k-shot prompting
        index_name : str
            The name of the Pinecone index.
        model_name : str
            The name of the LLM; mapped to an embedding model via
            ``EMBEDDING_NAMES``.
        namespace : str
            The namespace of the Pinecone index.

        Raises
        ------
        KeyError
            If ``model_name`` has no entry in ``EMBEDDING_NAMES``.
        ValueError
            If the Pinecone index does not exist (raised by ``get_vector_store``).
        """
        self.examples = examples
        self.k = k

        embedding_name = EMBEDDING_NAMES[model_name]
        self.vectorstore = get_vector_store(
            index_name=index_name, embedding_name=embedding_name, namespace=namespace
        )

    def add_example(self, example: dict) -> None:
        """Add a single interaction to the pool of candidate examples."""
        self.examples.append(example)

    def select_examples(self, input_variables: dict) -> list[dict[str, str]]:
        """Select examples based on semantic similarity.

        Parameters
        ----------
        input_variables : dict[str, str]
            A dict containing info about a single observation; must contain
            ``student_id``, ``question_id`` and ``q_text``.

        Returns
        -------
        list[dict[str, str]]
            The selected examples: at most one interaction per selected
            question, in the order they appear in ``self.examples``.

        Raises
        ------
        NotImplementedError
            If fewer than ``k`` examples can be selected.
        """
        # information of target observation
        student_id = input_variables["student_id"]
        question_id = input_variables["question_id"]
        q_text = input_variables["q_text"]

        # find all questions answered by this student, except the current one
        q_answered = {
            interact["question_id"]
            for interact in self.examples
            if interact["student_id"] == student_id
        }
        # the vector store metadata holds question IDs as strings
        q_answered = [str(q_id) for q_id in q_answered - {question_id}]
        logger.debug("Candidate questions", q_answered=q_answered)

        # semantic search on question text, restricted to the candidates; skip
        # the query entirely when there are no candidates (empty `$in` filter)
        results = (
            self.vectorstore.similarity_search(
                query=q_text,
                k=self.k,
                filter={"question_id": {"$in": q_answered}},
            )
            if q_answered
            else []
        )
        question_ids_selected = [int(res.metadata["question_id"]) for res in results]
        logger.debug(
            "Selected questions", question_ids_selected=question_ids_selected
        )

        # find interactions of selected question_ids and student_id
        interactions_selected = [
            interact
            for interact in self.examples
            if (
                interact["question_id"] in question_ids_selected
                and interact["student_id"] == student_id
            )
        ]
        # if a Q has multiple interactions, randomly select one
        interactions_selected = self._keep_one_per_question(interactions_selected)

        if len(interactions_selected) < self.k:
            raise NotImplementedError(
                "TODO: do we randomly select interactions or leave them empty?"
            )

        # NOTE: can decide to only return input and output
        # return [
        #     {"input": interact["input"], "output": interact["output"]}
        #     for interact in interactions_selected
        # ]
        return interactions_selected

    @staticmethod
    def _keep_one_per_question(interactions: list[dict]) -> list[dict]:
        """Keep one randomly chosen interaction per question_id, preserving order."""
        idxs_by_question: dict = {}
        for idx, interact in enumerate(interactions):
            idxs_by_question.setdefault(interact["question_id"], []).append(idx)
        # decide which positions survive first, then filter once, so no index
        # is invalidated by earlier removals
        idxs_to_keep = {random.choice(idxs) for idxs in idxs_by_question.values()}
        return [
            interact
            for idx, interact in enumerate(interactions)
            if idx in idxs_to_keep
        ]

In [71]:
# 2-shot selector: same-student examples ranked by question-text similarity,
# backed by the "llama3" Pinecone index in the "dbe_kt22" namespace
example_selector = StudentIDSemanticExampleSelector(
    k=2,
    examples=list_train,
    model_name=model_cfg.NAME,
    index_name="llama3",
    namespace="dbe_kt22",
)
# Sanity check on the first validation observation
example_selector.select_examples(list_val[0])

[2m2025-03-27 11:34:28[0m [[32m[1minfo     [0m] [1mLoaded Pinecone vector store  [0m [36mindex_name[0m=[35mllama3[0m [36mnamespace[0m=[35mdbe_kt22[0m
q_answered=['2', '3', '36', '37', '6', '7', '39', '40', '10', '11', '43', '44', '14', '45', '46']
question_ids_selected=[6, 7]


[{'input': 'Question:\nThe cardinality of a set is the number of elements of the set. What is the cardinality of the set of odd positive integers less than 10?\n\nOptions:\n1. 3\n2. 5\n3. 10\n4. 20\n\nCorrect answer: 1',
  'output': 'Student answer: 1',
  'student_id': 86,
  'question_id': 6,
  'interact_id': 789,
  'q_text': 'The cardinality of a set is the number of elements of the set. What is the cardinality of the set of odd positive integers less than 10?'},
 {'input': 'Question:\nWhich of the following two sets are equal?\n\nOptions:\n1. A = {1, 2} and B = {1}\n2. A = {1, 2} and B = {1, 2, 3}\n3. A = {1, 2, 4} and B = {1, 2, 3}\n4. A = {1, 2, 3} and B = {2, 1, 3}\n\nCorrect answer: 3',
  'output': 'Student answer: 3',
  'student_id': 86,
  'question_id': 7,
  'interact_id': 790,
  'q_text': 'Which of the following two sets are equal?'}]

___

In [None]:
# examples = list_train[:10]
# to_vectorize = [example["input"] for example in examples]
# # embeddings = OpenAIEmbeddings()
# embeddings = OllamaEmbeddings(model="llama3")
# vectorstore = Chroma.from_texts(
#     texts=to_vectorize,
#     embedding=embeddings,
#     metadatas=examples,
#     persist_directory=os.path.join("output", "vectorstore", "chroma_langchain_db"),
# )

In [None]:
# NOTE: texts depend on the example formatter used

In [None]:
# vectorstore

In [None]:
# example_selector = SemanticSimilarityExampleSelector(
#     vectorstore=vectorstore,
#     k=1
# )
# example_selector.select_examples({"input": list_val[0]["input"]})

In [None]:
# example_selector = SemanticSimilarityExampleSelector(
#     vectorstore=vectorstore,
#     k=2,
# )

# # The prompt template will load examples by passing the input do the `select_examples` method
# example_selector.select_examples({"input": "horse"})

In [None]:
# # Create the selector with k=3 for 3-shot prompting
# example_selector = RandomExampleSelector(examples=list_train, k=3)
# example_selector.select_examples({})

In [None]:
# # Select examples of a specific student
# example_selector = StudentIDExampleSelector(examples=list_train, k=3)
# example_selector.select_examples({"student_id": 395})

## Create prompt template

In [None]:
# Pydantic
class MCQAnswer(BaseModel):
    """Answer to a multiple-choice question.

    Used as the schema for `with_structured_output` and `PydanticOutputParser`,
    so the field descriptions below double as instructions to the model.
    """

    # Reasoning given by the simulated student for its answer
    explanation: str = Field(
        description="Misconception if incorrectly answered; motivation if correctly answered"
    )
    # Chosen option; the 1-4 range is only stated in the description and is
    # not enforced by validation
    student_answer: int = Field(
        description="The student's answer to the question, as an integer (1-4)"
    )
    # difficulty: str = Field(description="The difficulty level of the question")
    # difficulty: str = Field(description="The difficulty level of the question")

In [None]:
# Define the few-shot prompt.
few_shot_prompt = FewShotChatMessagePromptTemplate(
    # LangChain forwards the prompt inputs to the example selector; declare every
    # key `StudentIDSemanticExampleSelector.select_examples` reads so a missing
    # key is reported by input validation rather than as a KeyError inside it.
    input_variables=["student_id", "question_id", "q_text"],
    example_selector=example_selector,
    # Define how each example will be formatted.
    # In this case, each example will become 2 messages:
    # 1 human, and 1 AI
    example_prompt=ChatPromptTemplate.from_messages(
        [("human", "{input}"), ("ai", "{output}")]
    ),
)

# Sanity check: render the few-shot messages for the first validation observation
out = few_shot_prompt.invoke(input=list_val[0]).to_messages()
print(len(out))
print(out)

In [None]:
# System instructions; `{exam_type}` is filled in via `.partial` below.
system_prompt_raw = "".join(
    [
        "You are a student working on {exam_type}, containing multiple choice questions. ",
        "You are shown a set of questions that you answered earlier in the exam, together with the correct answers and your student answers. ",
        "Analyse your responses to the questions and identify the possible misconceptions that led to answering incorrectly. ",
        "Inspect the new question and think how you would answer it as a student. ",
        "If you answer incorrectly, explain which misconception leads to selecting that answer. ",
        "If you answer correctly, explain why you think the answer is correct. ",
        "Provide your answer as an integer in the range 1-4. ",
    ]
)
# The parser only supplies JSON format instructions, which are appended for
# models without native structured output.
parser = PydanticOutputParser(pydantic_object=MCQAnswer)
if not SUPPORTS_STRUCTURED_OUTPUT:
    system_prompt_raw += "Wrap the output in `json` tags\n{format_instructions}"

# system message -> few-shot examples -> new question
final_prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt_raw), few_shot_prompt, ("human", "{input}")]
).partial(
    exam_type="a database systems exam (Department of Computer Science)",
    format_instructions=parser.get_format_instructions(),
)

# Render the full prompt for the first validation observation as a sanity check
out = final_prompt.invoke(input=list_val[0]).to_messages()
print(len(out))
print(out)

# Model

In [None]:
# model: with native structured output, each call returns a dict whose "raw"
# entry holds the unparsed model message (consumed in the prediction cell)
model = build_model(model_cfg=model_cfg)
if SUPPORTS_STRUCTURED_OUTPUT:
    model = model.with_structured_output(MCQAnswer, include_raw=True)

# chain: prompt -> model. Parsing/validation happens afterwards via
# `validate_output`, so the PydanticOutputParser is not piped in here.
chain = final_prompt.pipe(model)

In [None]:
from prompt.json_schema import validate_output

# Predict the first 10 validation observations in a single batch call
preds_raw = chain.batch(list_val[:10])
# Structured-output models wrap each prediction in a dict; keep only the raw
# message so `validate_output` receives the same kind of object on both paths
if SUPPORTS_STRUCTURED_OUTPUT:
    preds_raw = [pred["raw"] for pred in preds_raw]
preds_validated = validate_output(preds_raw, schema=MCQAnswer)

In [None]:
# Predicted option per validated output.
# NOTE(review): assumes `validate_output` returns one MCQAnswer per input, in
# order; if it drops invalid outputs, these misalign with the labels below.
y_val_pred = np.asarray([answer.student_answer for answer in preds_validated])
y_val_pred

In [None]:
# Actual student answers for the first 10 validation rows
# NOTE(review): assumes list_val keeps the row order of datasets[VALIDATION] — confirm
y_val_student = datasets[VALIDATION]["student_option_id"].iloc[:10].to_numpy()
y_val_student


In [None]:
# Correct answers for the first 10 validation rows (same ordering assumption)
y_val_true = datasets[VALIDATION]["correct_option_id"].iloc[:10].to_numpy()
y_val_true

In [None]:
# Pairwise agreement between predicted, actual-student and correct answers
acc_student_pred = accuracy_score(y_val_student, y_val_pred)  # mimicry of the student
acc_true_student = accuracy_score(y_val_true, y_val_student)  # student's own score
acc_true_pred = accuracy_score(y_val_true, y_val_pred)  # simulated student's score

print(f"{acc_student_pred = }")
print(f"{acc_true_student = }")
print(f"{acc_true_pred = }")

In [None]:
def print_example(example: dict, show_output: bool = True) -> None:
    """Print single example, optionally hiding its output.

    Parameters
    ----------
    example : dict
        Example dictionary with an 'input' key and, when ``show_output`` is
        True, an 'output' key.
    show_output : bool, optional
        Whether to also print the example's output (the student answer). Set
        to False to show only the input the model receives, since also printing
        the output can be confusing. Defaults to True.
    """
    rule = "#" * 40
    sections = [("INPUT", example["input"])]
    if show_output:
        sections.append(("OUTPUT", example["output"]))
    # each section: banner, title, banner, content
    text = "".join(f"{rule}\n{title}\n{rule}\n{body}\n" for title, body in sections)
    print(text)


print_example(example=list_val[0])