# Imports

In [None]:
import os
import random
from typing import Callable

import dotenv
import pandas as pd
from langchain_chroma import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotChatMessagePromptTemplate,
    PromptTemplate,
)
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field, ValidationError

from example_selector.example_selector import RandomExampleSelector
from data_loader.data_loader import DataLoader
from tools.constants import SILVER_DIR, TRAIN, VALIDATION, TEST, BRONZE_DIR

# Load the variables from the '.env' file one directory up. override=True lets
# them replace any variables already present in the process environment.
# NOTE(review): the relative path is resolved against the current working
# directory — presumably the notebook is run from a project subdirectory.
dotenv.load_dotenv(dotenv_path="../.env", override=True)

# Data

In [None]:
def human_format_input(row) -> str:
    """Format a question row as the human turn of a prompt.

    Parameters
    ----------
    row : object
        Row exposing ``q_text``, ``options_text`` (iterable of option texts)
        and ``correct_answer`` as attributes (e.g. a pandas row Series).

    Returns
    -------
    str
        The question, the options numbered from 1, and the correct answer.
    """
    # NOTE: this is flexible wrt the number of answer options
    options = "".join(
        f"{number}. {option}\n"
        for number, option in enumerate(row.options_text, start=1)
    )
    return (
        f"Question:\n{row.q_text}\n\nOptions:\n"
        f"{options}"
        f"\nCorrect answer: {row.correct_answer}"
    )


def human_format_output(row) -> str:
    """Format the row's recorded ``student_answer`` as the example's target text."""
    template = "Student answer: {}"
    return template.format(row.student_answer)


def apply_prompt_fmt(
    df: pd.DataFrame, input_fmt: Callable, output_fmt: Callable
) -> pd.DataFrame:
    """Build a two-column ("input", "output") frame by formatting each row.

    Parameters
    ----------
    df : pd.DataFrame
        Source rows; each row is passed (as a Series) to the formatters.
    input_fmt : Callable
        Maps a row to the text stored in the "input" column.
    output_fmt : Callable
        Maps a row to the text stored in the "output" column.

    Returns
    -------
    pd.DataFrame
        Frame with columns "input" and "output", indexed like ``df``. An empty
        ``df`` yields an empty frame with those two columns.
    """
    rows = [row for _, row in df.iterrows()]
    # Build the columns explicitly instead of via ``df.apply(axis=1)``. On an
    # empty frame apply cannot infer a reduction and returns a DataFrame copy,
    # which cannot then be assigned to a single column.
    return pd.DataFrame(
        {
            "input": [input_fmt(row) for row in rows],
            "output": [output_fmt(row) for row in rows],
        },
        index=df.index,
    )

In [None]:
# load data
data_loader = DataLoader(read_dir=SILVER_DIR, dataset_name="dbe_kt22")
dataset = data_loader.split_data(train_size=0.6, test_size=0.25, seed=42)

# dataframes with the formatted "input" / "output" prompt text per row
df_train = apply_prompt_fmt(dataset[TRAIN], human_format_input, human_format_output)
df_val = apply_prompt_fmt(dataset[VALIDATION], human_format_input, human_format_output)
df_test = apply_prompt_fmt(dataset[TEST], human_format_input, human_format_output)

# list of {"input": ..., "output": ...} dicts, one per row (the frames above
# contain only those two columns)
list_train = df_train.to_dict(orient="records")
list_val = df_val.to_dict(orient="records")
list_test = df_test.to_dict(orient="records")

# Dynamic few-shot prompting

## Create example selector

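NOTE: the OpenAI embeddings require OpenAI credits, so the semantic-similarity example selector below is commented out and a random example selector is used instead.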

In [None]:
# examples = list_train
# to_vectorize = [" ".join(example.values()) for example in examples]
# embeddings = OpenAIEmbeddings()
# vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=examples)

In [None]:
# example_selector = SemanticSimilarityExampleSelector(
#     vectorstore=vectorstore,
#     k=2,
# )

# # The prompt template will load examples by passing the input to the `select_examples` method
# example_selector.select_examples({"input": "horse"})

In [None]:
# Project RandomExampleSelector over the training examples; k=3 -> 3-shot prompts.
example_selector = RandomExampleSelector(k=3, examples=list_train)
# Preview one selection (empty input dict).
example_selector.select_examples({})

## Create prompt template

In [None]:
# System instructions, assembled from adjacent string literals (Python joins
# them implicitly). {exam_type} is the only template variable.
# NOTE(review): the answer range "1-4" is hard-coded, while human_format_input
# supports any number of options — confirm every question has four options.
system_prompt_template = PromptTemplate.from_template(
    "You are a student working on {exam_type}, containing multiple choice questions. "
    "You are shown a set of questions that you answered earlier in the exam, "
    "together with the correct answers and your student answers. "
    "Analyse your responses to the questions and identify the possible "
    "misconceptions that led to answering incorrectly. "
    "Inspect the new question and think how you would answer it as a student. "
    "If you answer incorrectly, explain which misconception leads to selecting that answer. "
    "If you answer correctly, explain why you think the answer is correct. "
    "Provide your answer as an integer in the range 1-4. "
)

# Fill in the exam description, producing a plain string.
# NOTE(review): this string is later reused as a chat message template, so any
# braces in exam_type would be parsed as template variables.
system_prompt_input = system_prompt_template.format(
    exam_type="a database systems exam (Department of Computer Science)"
)
system_prompt_input

In [None]:
# Pydantic schema for the structured LLM output.
# The class docstring and field descriptions are part of the generated schema
# (presumably shown to the model via with_structured_output), so treat them as
# runtime data rather than ordinary documentation.
class MCQAnswer(BaseModel):
    """Answer to a multiple-choice question."""

    # Fields are required (no default). Pydantic keeps declaration order, so
    # the explanation comes before the answer in the schema.
    explanation: str = Field(
        ...,
        description="Misconception if incorrectly answered; motivation if correctly answered",
    )
    # NOTE(review): the 1-4 range is stated only in the description, not enforced.
    student_answer: int = Field(
        ...,
        description="The student's answer to the question, as an integer (1-4)",
    )
    # difficulty: str = Field(description="The difficulty level of the question")

In [None]:
# Few-shot block: every example chosen by example_selector expands into two
# messages, a human turn ("{input}") followed by an AI turn ("{output}").
few_shot_prompt = FewShotChatMessagePromptTemplate(
    example_selector=example_selector,
    # Variables whose values are handed to the example selector.
    input_variables=["input"],
    example_prompt=ChatPromptTemplate.from_messages(
        [
            ("human", "{input}"),
            ("ai", "{output}"),
        ]
    ),
)

# Sanity check: render the few-shot messages for the first validation question.
out = few_shot_prompt.invoke({"input": list_val[0]["input"]}).to_messages()
print(len(out))
print(out)

In [None]:
# Full chat prompt: system instructions, then the selected few-shot examples,
# then the new question as the final human turn.
final_prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt_input), few_shot_prompt, ("human", "{input}")]
)

# Sanity check on the first validation question.
print(list_val[0]["input"])
out = final_prompt.invoke({"input": list_val[0]["input"]}).to_messages()
print(len(out))
print(out)

# Model

In [None]:
# load model: Ollama-served llama3.2, wrapped so replies are parsed into MCQAnswer
model = ChatOllama(temperature=0.5, model="llama3.2").with_structured_output(
    MCQAnswer
)
# prompt -> structured model
chain = final_prompt | model


In [None]:
# run the chain on the first validation example (input text only)
val_example = list_val[0]
val_output = chain.invoke({"input": val_example["input"]})
val_output

In [None]:
def print_example(example: dict, *, show_output: bool = True) -> None:
    """Print a single example as banner-separated sections.

    Parameters
    ----------
    example : dict
        Example dictionary with an 'input' key and, when ``show_output`` is
        True, an 'output' key.
    show_output : bool, optional
        If False, print only the INPUT section; the 'output' key is then not
        read. Defaults to True (print both INPUT and OUTPUT).
    """
    banner = "#" * 40
    sections = [("INPUT", example["input"])]
    if show_output:
        sections.append(("OUTPUT", example["output"]))
    # Each section is: banner, title, banner, body — each on its own line.
    print(
        "".join(
            f"{banner}\n{title}\n{banner}\n{body}\n" for title, body in sections
        )
    )


# Show the first validation example (input and expected output).
print_example(example=list_val[0])