## Imports

In [47]:
%pip install langchain
%pip install ipywidgets
%pip install langchain-openai


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting ipywidgets
  Downloading ipywidgets-8.1.2-py3-none-any.whl (139 kB)
     -------------------------------------- 139.4/139.4 kB 2.1 MB/s eta 0:00:00
Collecting widgetsnbextension~=4.0.10
  Downloading widgetsnbextension-4.0.10-py3-none-any.whl (2.3 MB)
     ---------------------------------------- 2.3/2.3 MB 9.1 MB/s eta 0:00:00
Collecting jupyterlab-widgets~=3.0.10
  Downloading jupyterlab_widgets-3.0.10-py3-none-any.whl (215 kB)
     ------------------------------------- 215.0/215.0 kB 12.8 MB/s eta 0:00:00
Installing collected packages: widgetsnbextension, jupyterlab-widgets, ipywidgets
Successfully installed ipywidgets-8.1.2 jupyterlab-widgets-3.0.10 widgetsnbextension-4.0.10
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.3.1 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


## Functions

In [1]:
import json
from typing import Any, Dict, Protocol, cast, runtime_checkable

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.messages import ToolCall as LCToolCall
from langchain_core.outputs import (
    ChatGeneration,
    ChatResult,
)
from pydantic.v1 import Field
from typing_extensions import override

from inspect_ai.model import (
    ChatMessage,
    ChatMessageAssistant,
    ChatMessageSystem,
    ChatMessageTool,
    ChatMessageUser,
    Content,
    ContentImage,
    ContentText,
    GenerateConfig,
    ModelName,
    ModelOutput,
    ToolCall,
    ToolChoice,
    ToolInfo,
    ToolParam,
    get_model,
)
from inspect_ai.solver import Generate, Solver, TaskState

@runtime_checkable
class LangChainAgent(Protocol):
    async def __call__(
        self, llm: BaseChatModel, input: dict[str, Any]
    ) -> str | list[str | dict[str, Any]]:
        ...


def langchain_solver(agent: LangChainAgent) -> Solver:
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        # create the inspect model api bridge
        llm = InspectChatModel()

        # call the agent
        await agent(
            llm=llm,
            input=dict(
                input=state.user_prompt.text,
                chat_history=as_langchain_chat_history(state.messages[1:]),
            ),
        )

        # collect output from llm interface
        state.messages = llm.messages
        state.output = llm.output

        # return state
        return state

    return solve


class InspectChatModel(BaseChatModel):
    # track messages and model output so we can update
    # the inspect task state when we are complete
    messages: list[ChatMessage] = Field(default=[], exclude=True)
    output: ModelOutput = Field(default=ModelOutput(), exclude=True)

    @property
    def _llm_type(self) -> str:
        return f"Inspect ({ModelName(get_model()).api})"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            "model_name": str(ModelName(get_model()).name),
        }

    @override
    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # inspect uses async exclusively
        raise NotImplementedError

    @override
    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: AsyncCallbackManagerForLLMRun | None = None,
        **kwargs: dict[str, Any],
    ) -> ChatResult:
        # extract tools from kwargs
        tools: list[ToolInfo] = []
        tool_choice: ToolChoice | None = None
        lc_tools = cast(list[dict[str, Any]] | None, kwargs.get("tools", None))
        if lc_tools:
            tools = [
                ToolInfo(
                    name=tool["function"]["name"],
                    description=tool["function"]["description"],
                    params=as_inspect_tool_params(tool["function"]["parameters"]),
                )
                for tool in lc_tools
            ]
            tool_choice = "auto"

        # generate
        input = [as_inspect_message(message) for message in messages]
        result = await get_model().generate(
            input=input,
            tools=tools,
            tool_choice=tool_choice,
            config=GenerateConfig(stop_seqs=stop),
        )

        # track last messages / model output
        self.messages = input
        self.messages.append(result.choices[0].message)
        self.output = result

        # extract choices
        generations = [
            ChatGeneration(message=as_langchain_message(choice.message))
            for choice in result.choices
        ]

        # return
        return ChatResult(generations=generations)


def as_inspect_message(message: BaseMessage) -> ChatMessage:
    if isinstance(message, SystemMessage):
        return ChatMessageSystem(content=as_inspect_content(message.content))
    elif isinstance(message, HumanMessage):
        return ChatMessageUser(content=as_inspect_content(message.content))
    elif isinstance(message, AIMessage):
        return ChatMessageAssistant(
            content=as_inspect_content(message.content),
            tool_calls=(
                [
                    ToolCall(
                        type="function",
                        function=call["name"],
                        id=call["id"] or call["name"],
                        arguments=call["args"],
                    )
                    for call in message.tool_calls
                ]
                if message.tool_calls and len(message.tool_calls) > 0
                else None
            ),
        )
    elif isinstance(message, ToolMessage):
        return ChatMessageTool(
            content=as_inspect_content(message.content),
            tool_call_id=message.tool_call_id,
        )
    elif isinstance(message, FunctionMessage):
        return ChatMessageTool(
            content=as_inspect_content(message.content), tool_call_id=message.name
        )
    else:
        raise ValueError(f"Unexpected message type: {type(message)}")


def as_langchain_message(message: ChatMessage) -> BaseMessage:
    if isinstance(message, ChatMessageSystem):
        return SystemMessage(content=as_langchain_content(message.content))
    elif isinstance(message, ChatMessageUser):
        return HumanMessage(content=as_langchain_content(message.content))
    elif isinstance(message, ChatMessageAssistant):
        additional_kwargs: dict[str, Any] = {}
        if message.tool_calls and len(message.tool_calls) > 0:
            additional_kwargs["tool_calls"] = [
                dict(
                    id=call.id, name=call.function, arguments=json.dumps(call.arguments)
                )
                for call in message.tool_calls
            ]

        return AIMessage(
            content=as_langchain_content(message.content),
            tool_calls=(
                [
                    LCToolCall(id=call.id, name=call.function, args=call.arguments)
                    for call in message.tool_calls
                ]
                if message.tool_calls
                else []
            ),
            additional_kwargs=additional_kwargs,
        )
    elif isinstance(message, ChatMessageTool):
        return ToolMessage(
            content=as_langchain_content(message.content),
            tool_call_id=message.tool_call_id or "",
        )
    else:
        raise ValueError(f"Unexpected message type: {type(message)}")


def as_langchain_chat_history(messages: list[ChatMessage]) -> list[dict[str, Any]]:
    return [dict(role=message.role, content=message.text) for message in messages]


def as_inspect_content(
    content: str | list[str | dict[str, Any]],
) -> str | list[Content]:
    if isinstance(content, str):
        return content
    else:
        return [
            (
                ContentText(text=c)
                if isinstance(c, str)
                else (
                    ContentText(text=c["text"])
                    if c["type"] == "text"
                    else ContentImage(image=c["image"])
                )
            )
            for c in content
        ]


def as_inspect_tool_params(parameters: dict[str, Any]) -> list[ToolParam]:
    params: list[ToolParam] = []
    for key, param in parameters["properties"].items():
        params.append(
            ToolParam(
                name=key,
                type=param["type"],
                description=param.get("description", param.get("title")),
                optional=key not in parameters["required"],
            )
        )
    return params


def as_langchain_content(
    content: str | list[Content],
) -> str | list[str | dict[str, Any]]:
    if isinstance(content, str):
        return content
    else:
        return [c if isinstance(c, str) else c.model_dump() for c in content]





In [17]:
from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
import pandas as pd 

class FactComparator:
    def __init__(self, model):
        self.model = model
        self.parser = PydanticOutputParser(pydantic_object=ComparisonResult)

    async def process_data(self, context, answer):
        print(type(self.model))
        context_replace_pronouns = (await self.model._agenerate([HumanMessage(content=context)])).generations[0].text
        answer_replace_pronouns = (await self.model._agenerate([HumanMessage(content=answer)])).generations[0].text

        context_list = (await self.model._agenerate([HumanMessage(content=self._parse_prompt().format(text=context_replace_pronouns))])).generations[0].text
        answer_list = (await self.model._agenerate([HumanMessage(content=self._parse_prompt().format(text=answer_replace_pronouns))])).generations[0].text

        comparison_result = self.parser.parse((await self.model._agenerate([HumanMessage(content=self._compare_prompt().format(context_list=context_list, answer_list=answer_list))])).generations[0].text)

        return {
            "context_replace_pronouns": context_replace_pronouns,
            "answer_replace_pronouns": answer_replace_pronouns,
            "context_list": context_list,
            "answer_list": answer_list,
            "comparison_result": comparison_result,
        }

    def calculate_metrics(self, comparison_result):
        facts_in_both_count = len(comparison_result.facts_in_both)
        facts_only_in_answer_count = len(comparison_result.facts_only_in_answer)
        facts_only_in_context_count = len(comparison_result.facts_only_in_context)

        total_answer_facts = facts_in_both_count + facts_only_in_answer_count
        total_context_facts = facts_in_both_count + facts_only_in_context_count

        groundedness = facts_in_both_count / total_answer_facts * 100 if total_answer_facts > 0 else 0
        thoroughness = facts_in_both_count / total_context_facts * 100 if total_context_facts > 0 else 0

        return {
            "groundedness": groundedness,
            "thoroughness": thoroughness,
        }

    def process_data_list(self, data_list):
        results = []
        for data in data_list:
            context = data['context']
            answer = data['answer']
    
            result = asyncio.run(comparator.process_data(context, answer))
            metrics = comparator.calculate_metrics(result["comparison_result"])

            result_data = {
                'context': context,
                'answer': answer,
                'context_replace_pronouns': result["context_replace_pronouns"],
                'answer_replace_pronouns': result["answer_replace_pronouns"],
                'context_list': result["context_list"],
                'answer_list': result["answer_list"],
                'facts_in_both': ', '.join(result["comparison_result"].facts_in_both),
                'facts_only_in_answer': ', '.join(result["comparison_result"].facts_only_in_answer),
                'facts_only_in_context': ', '.join(result["comparison_result"].facts_only_in_context),
                'groundedness': metrics['groundedness'],
                'thoroughness': metrics['thoroughness']
            }
            results.append(result_data)

        return pd.DataFrame(results)

    @staticmethod
    def _pronoun_prompt():
        return PromptTemplate(
            input_variables=["text"],
            template="""
            Your task is to replace all the pronouns in the following text with the nouns they refer to:

            <text>
            {text}
            </text>

            The goal is to make the text more explicit and clear by replacing potentially ambiguous pronouns like "he", "she", "it", "they", "them", etc. with the specific nouns or names they refer to.

            For example:
            Original: John went to the store. He bought some milk.
            Pronoun replaced: John went to the store. John bought some milk.

            Here are the steps to complete this task:

            1. Carefully read the provided text and identify all the pronouns 
            2. For each pronoun, look back in the text to determine which noun or name it is referring to
            3. If the pronoun is part of a direct quote, do not replace it
            4. Replace each pronoun with the most recent noun or name it refers to
            5. If a pronoun does not have a clear referent noun or name, do not replace it
            6. Repeat this process until all the pronouns with clear referents have been replaced
            """,
        )

    @staticmethod
    def _parse_prompt():
        return PromptTemplate(
            input_variables=["text"],
            template="""
            Please parse the following text into a list of individual facts:

            <text>
            {text}
            </text>

            Read the text carefully. Your task is to break it down into the key facts it contains. Parse out each individual fact into a separate sentence, even if that means splitting up or rewording the original sentences. The goal is to have a clear, concise list of the core facts contained in the text.

            Output the parsed facts in a numbered list, with each fact written as a complete sentence on its own line. Use <facts> tags to demarcate the start and end of the list.
            """,
        )

    @staticmethod
    def _compare_prompt():
        return PromptTemplate(
            input_variables=["context_list", "answer_list"],
            template="""
            You will be comparing facts between a context and an answer to determine which facts are shared and which are unique to each.

            Here is the context:

            <context>
            {context_list}
            </context>

            And here is the answer: 

            <answer>
            {answer_list}
            </answer>

            Carefully analyze the facts presented in the context and answer, focusing on the semantic meaning rather than the exact wording.

            Then, output a dictionary with the following keys and corresponding lists of facts as values:

            1. "facts_in_both": A list of facts that are present in both the context and the answer

            2. "facts_only_in_answer": A list of facts that are only present in the answer 

            3. "facts_only_in_context": A list of facts that are only present in the context

            Remember, the facts do not need to be worded identically to be considered the same. Focus on whether the core meaning is shared or unique.

            Provide your results in this format:

            {{
                "facts_in_both": [
                    "Fact 1 present in both",
                    "Fact 2 present in both"
                ],
                "facts_only_in_answer": [
                    "Fact 1 only in answer",
                    "Fact 2 only in answer"  
                ],
                "facts_only_in_context": [
                    "Fact 1 only in context",
                    "Fact 2 only in context"
                ]
            }}
            """,
        )


class ComparisonResult(BaseModel):
    facts_in_both: list[str] = Field(default_factory=list, description="List of facts present in both context and answer")
    facts_only_in_answer: list[str] = Field(default_factory=list, description="List of facts only present in the answer")
    facts_only_in_context: list[str] = Field(default_factory=list, description="List of facts only present in the context")

## Run on First Pair of Statements

In [14]:
import asyncio


%env INSPECT_EVAL_MODEL=openai/gpt-4
%env INSPECT_MODEL_NAME=openai/gpt-4

# Create an instance of InspectChatModel with the specified model
inspect_model = InspectChatModel()

# Create an instance of FactComparator with the InspectChatModel
comparator = FactComparator(inspect_model)

context = "The quick brown fox jumps over the rock because he's happy. He was born in 2005. The hedgehog was born in 2010, but she's even happier than him."
answer = "The quick brown fox was born in 2005, and the hedgehog in 2010. The quick brown fox is not as happy as the hedgehog"

# Run the asynchronous process_data method
result = asyncio.run(comparator.process_data(context, answer))

metrics = comparator.calculate_metrics(result["comparison_result"])

print("Context with replaced pronouns:")
print(result["context_replace_pronouns"])

print("Context with replaced pronouns:")
print(result["context_replace_pronouns"])

print("\nAnswer with replaced pronouns:")
print(result["answer_replace_pronouns"])

print("\nContext list:")
print(result["context_list"])

print("\nAnswer list:")
print(result["answer_list"])

print("\nComparison result:")
print(result["comparison_result"])

print("\nMetrics:")
print(f"Groundedness: {metrics['groundedness']:.2f}%")
print(f"Thoroughness: {metrics['thoroughness']:.2f}%")

env: INSPECT_EVAL_MODEL=openai/gpt-4
env: INSPECT_MODEL_NAME=openai/gpt-4
<class '__main__.InspectChatModel'>
Context with replaced pronouns:
The fox, born in 2005, is a content and lively creature with a coat as brown as well-aged oak. His joy manifests itself in leaps of exuberance. These jumps are not dictated by necessity or an instinct for survival but are joyful bounds, spawned from his ebullient spirit. His favourite platform for these athletic displays is an old worn-out rock, almost as if the leaping is his unique way of expressing his happiness.

In 2010, another littleness of joy came into the world - a bright-eyed hedgehog. Treading along the garden, she radiates an undeniable joy that surpasses even the fox's. Despite her spiky exterior, the light in her eyes and the small curl of her mouth when she munches on her favorite treats speak a clear language: here lives one happy hedgehog. Some say it is the innocent bliss of youth, however, no one can deny that her happiness ou

## Run on another pair of statements

In [16]:
%env INSPECT_EVAL_MODEL=openai/gpt-4
%env INSPECT_MODEL_NAME=openai/gpt-4

# Create an instance of InspectChatModel with the specified model
inspect_model = InspectChatModel()

# Create an instance of FactComparator with the InspectChatModel
comparator = FactComparator(inspect_model)

context = "To boil pasta, first bring a large pot of salted water to a rolling boil over high heat.."
answer = "To boil pasta, begin by filling a large pot with water, making sure there's enough to fully submerge the pasta. Bring the water to a rolling boil over high heat, then add salt to enhance the pasta's flavor. Once the water is boiling, carefully add the pasta, stirring gently to prevent sticking. Cook the pasta according to the package instructions or until it reaches your desired level of tenderness, usually around 8-12 minutes. To check for doneness, taste a piece of pasta—it should be tender but still slightly firm (al dente)."

# Run the asynchronous process_data method
result = asyncio.run(comparator.process_data(context, answer))

metrics = comparator.calculate_metrics(result["comparison_result"])
print("Context with replaced pronouns:")
print(result["context_replace_pronouns"])

print("\nAnswer with replaced pronouns:")
print(result["answer_replace_pronouns"])

print("\nContext list:")
print(result["context_list"])

print("\nAnswer list:")
print(result["answer_list"])

print("\nComparison result:")
print(result["comparison_result"])

print("\nMetrics:")
print(f"Groundedness: {metrics['groundedness']:.2f}%")
print(f"Thoroughness: {metrics['thoroughness']:.2f}%")

env: INSPECT_EVAL_MODEL=openai/gpt-4
env: INSPECT_MODEL_NAME=openai/gpt-4
<class '__main__.InspectChatModel'>
Context with replaced pronouns:
Once the water is boiling, add the pasta. Stir it occasionally to prevent it from sticking to the bottom of the pot. 

Cooking times will vary depending on the type of pasta you are using. For example, spaghetti usually takes around 8-10 minutes to cook, while smaller pasta types like macaroni may only take 5-7 minutes. 

Check the pasta's packaging for specific cooking instructions and times. Generally, you'll want to cook the pasta until it is al dente, or firm to the bite.

Once the pasta is cooked, drain it immediately in a colander. Avoid rinsing the pasta as this will remove the starch that helps sauce adhere to it.

Now your pasta is ready to be combined with your favorite sauce or used in any recipe! Enjoy.

Answer with replaced pronouns:
When your pasta is ready, carefully drain the water from the pot using a colander. Be careful not to 

  result = asyncio.run(comparator.process_data(context, answer))


## Run on a list of dictionaries - return DF

In [18]:
data_list = [
    {
        'context': 'The quick brown fox jumps over the rock because he\'s happy. He was born in 2005. The hedgehog was born in 2010, but she\'s even happier than him.',
        'answer': 'The quick brown fox was born in 2005, and the hedgehog in 2010. The quick brown fox is not as happy as the hedgehog'
    },
    {
        'context': 'The sun is a star at the center of our solar system. It is about 93 million miles away from Earth. The sun is a hot ball of glowing gases that provides light and warmth to Earth.',
        'answer': 'The sun is a star located approximately 93 million miles from Earth. It is the source of light and heat for our planet. The sun is not a solid object, but rather a sphere of hot glowing gases.'
    },
    {
        'context': 'Birds are warm-blooded vertebrates that lay eggs and have feathers, wings, and beaks. There are over 10,000 species of birds worldwide. Some common bird species include sparrows, pigeons, and parrots.',
        'answer': 'Birds are a diverse group of animals with feathers and wings. They are warm-blooded egg-laying vertebrates. The number of bird species globally exceeds 10,000. Pigeons, parrots, and sparrows are among the most familiar bird types.'
    },
    {
        'context': 'The Eiffel Tower is a wrought-iron lattice tower located on the Champ de Mars in Paris, France. It was constructed from 1887 to 1889 and stands at a height of 324 meters. The tower is named after Gustave Eiffel, whose company designed and built it.',
        'answer': 'The Eiffel Tower, found in Paris, France, is a lattice tower made of wrought iron. Built between 1887 and 1889, it reaches a height of 324 meters. Gustave Eiffel\'s company was responsible for the tower\'s design and construction, hence its name.'
    },
    {
        'context': 'The Great Wall of China is a series of fortifications and walls built across the historical northern borders of ancient Chinese states and Imperial China. The most well-known sections were built during the Ming dynasty, which ruled from 1368 to 1644.',
        'answer': 'The Great Wall of China, a series of walls and fortifications, was constructed along the northern borders of ancient Chinese states and Imperial China. The Ming dynasty, which lasted from 1368 to 1644, is responsible for the construction of the most famous sections of the wall.'
    }
]

# Create an instance of InspectChatModel with the specified model
inspect_model = InspectChatModel()

# Create an instance of FactComparator with the InspectChatModel
comparator = FactComparator(inspect_model)

df = comparator.process_data_list(data_list)
df

<class '__main__.InspectChatModel'>
<class '__main__.InspectChatModel'>
<class '__main__.InspectChatModel'>
<class '__main__.InspectChatModel'>
<class '__main__.InspectChatModel'>


Unnamed: 0,context,answer,context_replace_pronouns,answer_replace_pronouns,context_list,answer_list,facts_in_both,facts_only_in_answer,facts_only_in_context,groundedness,thoroughness
0,The quick brown fox jumps over the rock becaus...,"The quick brown fox was born in 2005, and the ...","The fox and the hedgehog, despite their differ...",because he is older and has had to face more c...,<facts>\n1. The story features a fox and a hed...,<facts>\n1. The brown fox is older. \n2. The f...,"The fox is 16 years old., The hedgehog is 11 y...","The brown fox is older., The fox has faced mor...","The story features a fox and a hedgehog., Desp...",18.181818,13.333333
1,The sun is a star at the center of our solar s...,The sun is a star located approximately 93 mil...,The sun's diameter is about 109 times that of ...,The core temperature of the sun is estimated t...,<facts>\n1. The sun's diameter is approximatel...,<facts>\n1. The sun's core temperature is esti...,The sun's core is the site of nuclear fusion a...,The sun's core temperature is estimated to be ...,The sun's diameter is approximately 109 times ...,35.714286,35.714286
2,Birds are warm-blooded vertebrates that lay eg...,Birds are a diverse group of animals with feat...,"Birds play a vital role in the ecosystem, as t...",Birds have unique capabilities; their ability ...,<facts>\n1. Birds play a vital role in the eco...,<facts>\n1. Birds have the unique capability t...,Birds play a vital role in the ecosystem by co...,"Birds have the unique capability to fly., Ostr...","Birds serve as food for other animals., Birds ...",40.0,37.5
3,The Eiffel Tower is a wrought-iron lattice tow...,"The Eiffel Tower, found in Paris, France, is a...",The Eiffel Tower is renowned worldwide as a sy...,The Eiffel Tower was initially criticized by s...,<facts>\n1. The Eiffel Tower is known worldwid...,<facts>\n1. The Eiffel Tower was initially cri...,The Eiffel Tower is known worldwide as a symbo...,The Eiffel Tower was initially criticized by s...,The Tower has three levels accessible by eleva...,35.0,43.75
4,The Great Wall of China is a series of fortifi...,"The Great Wall of China, a series of walls and...",The construction of the Great Wall began as ea...,The Great Wall was initially built to protect ...,<facts>\n1. The construction of the Great Wall...,<facts>\n1. The Great Wall was initially built...,The construction of the Great Wall of China be...,"Numerous myths surround the Great Wall, includ...",These sections were mainly made from bricks an...,73.333333,73.333333
