# Experimenting with getting history tutor to update state correctly since this can be unreliable sometimes.

In [None]:
%pip install langchain
%pip install typing



In [76]:


import json
from typing import List
from langchain_google_vertexai import VertexAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain.globals import set_debug
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

model = VertexAI(model_name="gemini-pro", temperature=0)


# Data structure for the output of the update_question_state function
class AnswerQuestionsList(BaseModel):
    questions: List[str] = Field(description="questions that have been answered")

# Copied from history_tutor.py for convenience.
def update_question_state(prompt, world_state, last_answer):        
        
        parser = JsonOutputParser(pydantic_object=AnswerQuestionsList)
        
        prompt = PromptTemplate(
            template=prompt,
            input_variables=["question", "last_message"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )

        chain = (prompt | model | parser)
        response = chain.invoke(
            {
                "question": world_state,
                "last_message": last_answer,
            }
        )
        answered_questions = response["questions"]
                
        return answered_questions
        

In [119]:
from langchain_google_vertexai import ChatVertexAI, VertexAI
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser


UPDATE_STATE_PROMPT = """
    You are a history tutor and you must mark a history test.
    
    This is the marking sheet with the question and the list of the correct answers:
    {question}
    
    The student has given the following answer:    
    ===
    {last_message}
    ===
    
    ONLY consider the text provided within the triple equals (===) as the student's answer.
    
    Think carefully, does the student's answer give any of the answers on the marking sheet? 
    
    Return ONLY the answers that the STUDENT correctly gave. 
    If the student's answer didn't contain any correct answers then return an EMPTY string.
        
    Give the FULL TEXT of each correct answer as it appears in the marking sheet.
    
    {format_instructions}
"""

world_state = {
    "question": "What were the different theories about the cause of the Black Death?",
    "answers": [
      {
        "answer": "Religion: God sent the plague as a punishment for people's sins.",
        "hasAnswered": "false"
      },
      {
        "answer": "Miasma: 'bad air' or smells caused by decaying rubbish.",
        "hasAnswered": "false"
      },
      {
        "answer": "Four Humours: most physicians believed that disease was caused by an imbalance in the Four Humours.",
        "hasAnswered": "false"
      },
      {
        "answer": "Outsiders: strangers or witches had caused the disease.",
        "hasAnswered": "false"
      }
    ]
  }
set_debug(False)

# Get one answer correct.
#print(json.dumps(update_question_state(UPDATE_STATE_PROMPT, world_state, 'Was it bad air'), indent=4))

# Answers two questions in one go.
#print(json.dumps(update_question_state(UPDATE_STATE_PROMPT, world_state, 'The four humors and the miasma theory'), indent=4))

# Answers all questions correctly.
#print(json.dumps(update_question_state(UPDATE_STATE_PROMPT, world_state, 'They thought it was a punishment from god, miasma, an imbalance in the four humors, and strangers'), indent=4))

# No correct answers.
#print(json.dumps(update_question_state(UPDATE_STATE_PROMPT, world_state, 'Bananas'), indent=4))

print(json.dumps(update_question_state(UPDATE_STATE_PROMPT, world_state, 'begin'), indent=4))


[
    "Religion: God sent the plague as a punishment for people's sins.",
    "Miasma: 'bad air' or smells caused by decaying rubbish.",
    "Four Humours: most physicians believed that disease was caused by an imbalance in the Four Humours.",
    "Outsiders: strangers or witches had caused the disease."
]


In [49]:
# Trying it with JSON instead of turning the questions into text.

prompt = """
    You are a history tutor and you must mark a history test.
    
    This JSON bundle is the marking sheet with the question and the list of the correct answers:
    
    ===
    {{
    "question": "What were the different theories about the cause of the Black Death?",
    "answers": [
      {{
        "answer": "Religion: God sent the plague as a punishment for people's sins.",
        "hasAnswered": "false"
      }},
      {{
        "answer": "Miasma: 'bad air' or smells caused by decaying rubbish.",
        "hasAnswered": "false"
      }},
      {{
        "answer": "Four Humours: most physicians believed that disease was caused by an imbalance in the Four Humours.",
        "hasAnswered": "false"
      }},
      {{
        "answer": "Outsiders: strangers or witches had caused the disease.",
        "hasAnswered": "false"
      }}
    ]
  }}
  ===
    

    The student has given the following response:    
    ===
    {last_message}
    ===

    Consider the student's response carefully. 
    ONLY consider the text provided within ===
    Think carefully, does this text give any of the answers on the marking sheet? 

    Return ONLY the answers that the STUDENT correct gave. 
    If the student's answer contains no correct answers, return an EMPTY string.

    Example:
    STUDENT RESPONSE: "start lesson"
    YOUR RESPONSE: ""

    Give the FULL TEXT of each correct answer as it appears in the marking sheet.

    {format_instructions}
"""

set_debug(False)
parser = JsonOutputParser(pydantic_object=AnswerQuestionsList)
        
template = PromptTemplate(
    template=prompt,
    input_variables=["last_message"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)

chain = (template | model | parser)

response = chain.invoke(
    {
        "last_message": "miasma, an imbalance in the four humours, strangers, and a punishment from God.",
    }
)
answered_questions = response["questions"]
print(answered_questions)



["Miasma: 'bad air' or smells caused by decaying rubbish.", 'Four Humours: most physicians believed that disease was caused by an imbalance in the Four Humours.']
