# Imports

In [None]:
# standard library imports
import os
import random
from typing import Callable

# related third party imports
import dotenv
import pandas as pd
import numpy as np
import structlog
from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain_core.output_parsers import PydanticOutputParser
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, ValidationError
from yacs.config import CfgNode
from sklearn.metrics import accuracy_score

# local application/library specific imports
from example_selector.example_selector import (
    RandomExampleSelector,
    StudentIDExampleSelector,
)
from data_loader.data_loader import DataLoader
from tools.constants import SILVER_DIR, TRAIN, VALIDATION, TEST, MODEL_STRUCTURED_OUTPUT
from prompt.few_shot_prompt import (
    df_to_listdict,
    human_format_input,
    human_format_output,
    apply_prompt_fmt,
)
from model.build import build_model

# Notebook-wide structlog logger (used for warnings about invalid model outputs).
logger = structlog.get_logger()

# Reload the variables in your '.env' file (override the existing variables).
# The path is relative, so it resolves against the current working directory.
dotenv.load_dotenv("../.env", override=True)

In [None]:
### INPUTS ###
# Model identifier. Other names noted as alternatives: "gpt-4o-mini", "llama3", "llama3.2".
MODEL_NAME = "olmo2:7b"
# Provider name. Noted alternative: "openai".
MODEL_PROVIDER = "ollama"
# Per-model flag looked up in MODEL_STRUCTURED_OUTPUT; it decides whether
# `with_structured_output` is applied to the model and whether format
# instructions are appended to the system prompt.
SUPPORTS_STRUCTURED_OUTPUT = MODEL_STRUCTURED_OUTPUT[MODEL_NAME]

In [None]:
# Configuration node handed to `build_model`.
# NOTE(review): how each key is interpreted (e.g. whether None means
# "use the provider default") is decided in `model.build` — confirm there.
model_cfg = CfgNode(
    {
        "NAME": MODEL_NAME,
        "PROVIDER": MODEL_PROVIDER,
        "TEMPERATURE": 0.5,
        "FORMAT": "json",
        "MAX_TOKENS": None,
        "TIMEOUT": None,
        "MAX_RETRIES": None,
    }
)

# Data

In [None]:
# Read the dataset and split it into train / validation / test partitions.
data_loader = DataLoader(read_dir=SILVER_DIR, dataset_name="dbe_kt22")
dataset = data_loader.split_data(train_size=0.6, test_size=0.25, seed=42)

# Apply the same human-readable input/output formatting to every partition.
df_train, df_val, df_test = [
    apply_prompt_fmt(
        df=dataset[split],
        input_fmt=human_format_input,
        output_fmt=human_format_output,
    )
    for split in (TRAIN, VALIDATION, TEST)
]

# Convert each formatted dataframe into a list of dicts.
list_train, list_val, list_test = [
    df_to_listdict(frame) for frame in (df_train, df_val, df_test)
]

In [None]:
# Preview the first rows of the formatted validation dataframe.
df_val.head()

In [None]:
# First 10 rows, all columns.
df_val.iloc[:10,:]

# Dynamic few-shot prompting

## Create example selector

NOTE: I need OpenAI credits to use the OpenAI embeddings.

In [None]:
# examples = few_shot_list
# to_vectorize = [" ".join(example.values()) for example in examples]
# embeddings = OpenAIEmbeddings()
# vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

In [None]:
# example_selector = SemanticSimilarityExampleSelector(
#     vectorstore=vectorstore,
#     k=2,
# )

# # The prompt template will load examples by passing the input to the `select_examples` method
# example_selector.select_examples({"input": "horse"})

In [None]:
# Create the selector with k=3 for 3-shot prompting; candidates come from the training split.
example_selector = RandomExampleSelector(examples=list_train, k=3)
# Quick check of the selector, called with an empty input dict.
# NOTE(review): presumably the random selector ignores its input — confirm in example_selector.
example_selector.select_examples({})

In [None]:
# # Select examples of a specific student
# example_selector = StudentIDExampleSelector(examples=list_train, k=3)
# example_selector.select_examples({"student_id": 395})

## Create prompt template

In [None]:
# Pydantic
# Schema for one model response. It is passed to `with_structured_output` and to
# `PydanticOutputParser`.
# NOTE(review): pydantic includes the class docstring and the Field descriptions
# in the generated JSON schema, so editing that text presumably changes what the
# model is shown — treat it as prompt text, not as documentation.
class MCQAnswer(BaseModel):
    """Answer to a multiple-choice question."""

    # Free-text reasoning that accompanies the chosen option.
    explanation: str = Field(
        description="Misconception if incorrectly answered; motivation if correctly answered"
    )
    # The 1-4 range is stated in the description only, not enforced:
    # `validate_output` builds instances with student_answer=-1 to mark invalid outputs.
    student_answer: int = Field(
        description="The student's answer to the question, as an integer (1-4)"
    )
    # difficulty: str = Field(description="The difficulty level of the question")

In [None]:
# Few-shot part of the prompt: the selector supplies the examples and each
# example is rendered as two messages (1 human, 1 AI).
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_prompt=ChatPromptTemplate.from_messages(
        [
            ("human", "{input}"),
            ("ai", "{output}"),
        ]
    ),
    example_selector=example_selector,
    # The input variables select the values to pass to the example_selector.
    input_variables=["student_id"],  # TODO: do not hardcode
)

# Render the few-shot messages for the first validation record and inspect them.
out = few_shot_prompt.invoke(input=list_val[0]).to_messages()
print(len(out), out, sep="\n")

In [None]:
# Set up a parser (not used if model supports structured output); its format
# instructions are bound into the prompt further down.
parser = PydanticOutputParser(pydantic_object=MCQAnswer)

# System instructions, one fragment per sentence; every fragment keeps its
# trailing space so the joined text reads as a single paragraph.
system_prompt_raw = "".join(
    [
        "You are a student working on {exam_type}, containing multiple choice questions. ",
        "You are shown a set of questions that you answered earlier in the exam, together with the correct answers and your student answers. ",
        "Analyse your responses to the questions and identify the possible misconceptions that led to answering incorrectly. ",
        "Inspect the new question and think how you would answer it as a student. ",
        "If you answer incorrectly, explain which misconception leads to selecting that answer. ",
        "If you answer correctly, explain why you think the answer is correct. ",
        "Provide your answer as an integer in the range 1-4. ",
    ]
)
if not SUPPORTS_STRUCTURED_OUTPUT:
    # Without native structured output the model is told the expected format in the prompt.
    system_prompt_raw = (
        system_prompt_raw + "Wrap the output in `json` tags\n{format_instructions}"
    )

# Full prompt: system message, the few-shot examples, then the new question.
final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt_raw),
        few_shot_prompt,
        ("human", "{input}"),
    ]
).partial(
    format_instructions=parser.get_format_instructions(),
    exam_type="a database systems exam (Department of Computer Science)",
)

# To view the prompt as one string instead of a message list:
# print(final_prompt.invoke(input=list_val[0]).to_string())
out = final_prompt.invoke(input=list_val[0]).to_messages()
print(len(out), out, sep="\n")

# Model

In [None]:
# model
# Instantiate the chat model from the configuration node.
model = build_model(model_cfg=model_cfg)
if SUPPORTS_STRUCTURED_OUTPUT:
    # include_raw=True keeps the raw model message in the result; the batch
    # cell reads it back through the "raw" key.
    model = model.with_structured_output(MCQAnswer, include_raw=True)

# chain
chain = final_prompt | model
# The parser is deliberately not piped into the chain; the raw messages are
# parsed afterwards by `validate_output`.
# if not SUPPORTS_STRUCTURED_OUTPUT:
#     chain = chain.pipe(parser)

In [None]:
# # run model
# val_example = list_val[0]
# val_output = chain.invoke(val_example)
# val_output

In [None]:
# Run the chain on the first 10 validation records in one batch.
val_output = chain.batch(list_val[:10])
# Only a model wrapped with `with_structured_output(..., include_raw=True)`
# returns dicts that carry the message under "raw". Without structured output
# the chain already yields the messages themselves, and subscripting them
# with "raw" would fail.
if SUPPORTS_STRUCTURED_OUTPUT:
    val_output = [output["raw"] for output in val_output]
val_output

In [None]:
# Work on a shallow copy: a plain assignment would alias `val_output`, so
# replacing an element for testing would also overwrite the real model output.
val_output_test = list(val_output)

In [None]:
from langchain_core.messages.ai import AIMessage

# Replace the first response with a hand-written message whose `student_answer`
# is not an integer, to exercise the invalid-output branch of `validate_output`.
val_output_test[0] = AIMessage(content="""{"explanation": "I don't know", "student_answer": "bla"}""")
val_output_test

In [None]:
from langchain_core.exceptions import OutputParserException

def validate_output(outputs: list, schema) -> list:
    """Validate the LLM outputs against the schema.

    Every output is parsed into ``schema``. An output that cannot be parsed
    is logged as a warning and replaced by a placeholder instance, so the
    returned list always has the same length and order as ``outputs``.

    Parameters
    ----------
    outputs : list
        List of AIMessages.
    schema : pydantic model class
        Pydantic schema to parse the outputs into. It must accept the keyword
        arguments ``explanation`` and ``student_answer``, which are used to
        build the placeholder for invalid outputs.

    Returns
    -------
    list
        List of validated outputs (instances of ``schema``). Invalid outputs
        are represented by ``schema(explanation="", student_answer=-1)``.
    """
    # Build the parser from `schema` so that validation follows the argument
    # instead of whatever module-level parser happens to be defined.
    schema_parser = PydanticOutputParser(pydantic_object=schema)

    outputs_validated = []
    for i, output in enumerate(outputs):
        try:
            output_validated = schema_parser.invoke(output)
        except OutputParserException as e:
            # Keep the failure reason in the structured log rather than on stdout.
            logger.warning("Invalid output", index=i, error=str(e))
            # student_answer=-1 marks the response as invalid for later counting.
            output_validated = schema(explanation="", student_answer=-1)
        outputs_validated.append(output_validated)

    return outputs_validated

# val_output_validated = validate_output(outputs=val_output, schema=MCQAnswer)
# val_output_validated

In [None]:
# Parse the (partly hand-edited) responses into MCQAnswer objects; invalid
# ones come back with student_answer=-1.
val_output_validated = validate_output(outputs=val_output_test, schema=MCQAnswer)
val_output_validated

In [None]:
# TODO: count number of invalid responses (student_answer=-1)

In [None]:
# Predicted answers as an array; invalid responses carry the -1 placeholder.
y_val_pred = np.array([output.student_answer for output in val_output_validated])
y_val_pred

In [None]:
# Answers the students actually gave, cut to the number of predictions so the
# two arrays stay aligned when the batch size changes (instead of repeating
# the literal 10).
y_val_student = dataset[VALIDATION]["student_answer"].to_numpy()[: len(y_val_pred)]
y_val_student


In [None]:
# Correct answers, cut to the number of predictions so the two arrays stay
# aligned when the batch size changes (instead of repeating the literal 10).
y_val_true = dataset[VALIDATION]["correct_answer"].to_numpy()[: len(y_val_pred)]
y_val_true

In [None]:
# Three pairwise agreements:
#   predictions vs. the students' real answers,
#   the students' real answers vs. the correct answers,
#   predictions vs. the correct answers.
acc_student_pred = accuracy_score(y_val_student, y_val_pred)
acc_true_student = accuracy_score(y_val_true, y_val_student)
acc_true_pred = accuracy_score(y_val_true, y_val_pred)

print(
    f"{acc_student_pred = }",
    f"{acc_true_student = }",
    f"{acc_true_pred = }",
    sep="\n",
)

In [None]:
# TODO: add func to only print input (also printing output can be confusing)
def print_example(example: dict) -> None:
    """Print the input and output of one example under banner headings.

    Parameters
    ----------
    example : dict
        Example dictionary with 'input' and 'output' keys.
    """
    banner = "#" * 40
    lines = []
    # One section per key: banner, title, banner, then the value itself.
    for title, key in (("INPUT", "input"), ("OUTPUT", "output")):
        lines.extend([banner, title, banner, f"{example[key]}"])
    # The text ends with its own newline; print() then adds one more.
    print("\n".join(lines) + "\n")


# Show the first validation example.
print_example(list_val[0])