# Setup
## Import libraries

In [9]:
# Import the libraries and set the paths
import sys
import os
# The notebook's working directory is used as the anchor for locating the project.
notebook_dir = os.getcwd()
# The project root is taken to be two directory levels above the working directory.
root = os.path.abspath(os.path.join(notebook_dir, '../../'))
# Guard the append so that re-running this cell does not stack duplicate
# entries for the same directory onto sys.path.
if root not in sys.path:
    sys.path.append(root)
from IPython.display import Image, display

from utils.data.csv_parsing import load_csv_as_dicts, load_csv_as_dataframe
from utils.langchain.llm_model_selector import get_llm_from_model_name
from utils.langchain.prompts import STRUCTURED_OUTPUT_PROMPT, NAIVE_ZERO_SHOT_CLASSIFICATION_PROMPT, ROBUST_ZERO_SHOT_CLASSIFICATION_PROMPT

import getpass
from pydantic import BaseModel
from typing import Annotated
from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langsmith.evaluation import evaluate
from langsmith.schemas import Example, Run



## Load Data

In [10]:
# Location of the transformed FA-KES CSV, expressed relative to the project root.
# The value starts with a slash, so it is concatenated onto `root` as plain text.
article_path = '/data/transformed/FA-KES.csv'

# Load the same file twice: once as a list of dicts, once as a DataFrame.
articles = load_csv_as_dicts(f"{root}{article_path}")
articles_df = load_csv_as_dataframe(f"{root}{article_path}")

## Load API Keys

In [11]:
def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


# Prompt for any API key that is not already present in the environment.
for env_var in ("ANTHROPIC_API_KEY", "LANGSMITH_API_KEY", "OPENAI_API_KEY"):
    _set_env(env_var)

# Tracing settings: enable LangChain tracing and name the project the runs are filed under.
os.environ.update({
    "LANGCHAIN_TRACING_V2": "true",
    "LANGCHAIN_PROJECT": "FakeNews Detection - Zero Shot",
})

# Graph Definition

In [16]:
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from typing_extensions import Annotated, TypedDict

# Utility function to get the right model
def get_llm_from_model_name(model_name: str):
    """Return a zero-temperature chat model chosen from the model name.

    The provider is picked by substring: names containing ``"claude"`` get a
    ``ChatAnthropic`` client, names containing ``"gpt"`` get a ``ChatOpenAI``
    client. ``"claude"`` is checked first.

    Note: this definition shadows the ``get_llm_from_model_name`` imported
    from ``utils.langchain.llm_model_selector`` earlier in the notebook.

    Args:
        model_name: Provider model identifier, passed through as ``model``.

    Returns:
        The instantiated chat model with ``temperature=0``.

    Raises:
        ValueError: If the name contains neither ``"claude"`` nor ``"gpt"``.
    """
    if "claude" in model_name:
        return ChatAnthropic(model=model_name, temperature=0)

    if "gpt" in model_name:
        return ChatOpenAI(model=model_name, temperature=0)

    raise ValueError(f"Unknown model provider for model: {model_name}")

# Define the state
class State(TypedDict):
    """Shared state dict passed between the nodes of the classification graph.

    Each field is declared with ``Annotated[<type>, <description string>]``;
    the second argument is a plain human-readable description of the field.
    """
    # Holds the LLM reply; detect_misinformation assigns a one-element list here.
    # NOTE(review): the imported `add_messages` reducer is not attached to this
    # field — presumably each write replaces the list; confirm against LangGraph docs.
    messages: Annotated[list, "A list to store messages exchanged with the LLM"]
    # Inputs: supplied by the caller in the initial state.
    article_title: Annotated[str, "The title of the article to be analysed"]
    article_content: Annotated[str, "The content of the article to be analysed"]
    # Outputs: written by handle_structured_output after parsing the LLM reply.
    label: Annotated[str, "The label of the classification ('Credible' or 'Fake')"]
    explanation: Annotated[str, "The reasoning behind the classification"]

# Class for misinformation detection workflow
class MisinformationDetection:
    """Two-step zero-shot classifier: query the LLM, then parse its text reply.

    The instance holds the chat model and the system prompt. The two public
    methods are registered as graph nodes; each receives the state dict,
    updates it in place and returns it.
    """

    # Text markers that introduce each field in the model's reply.
    _LABEL_MARKER = "Label:"
    _EXPLANATION_MARKER = "Explanation:"
    # Value stored when the reply carries no (or an empty) explanation.
    _DEFAULT_EXPLANATION = "No explanation provided."

    def __init__(self, model_name: str, system_message_prefix: str):
        """Create the chat model and assemble the system prompt.

        Args:
            model_name: Model identifier handed to ``get_llm_from_model_name``.
            system_message_prefix: Task instructions placed ahead of the
                shared ``STRUCTURED_OUTPUT_PROMPT`` in the system message.
        """
        self.llm = get_llm_from_model_name(model_name)
        self.system_message = f"""
            {system_message_prefix}\n
            {STRUCTURED_OUTPUT_PROMPT}
        """

    # The state annotations below are quoted forward references, so defining
    # the class does not require ``State`` to be evaluated at that moment.
    def detect_misinformation(self, state: "State"):
        """Send the article to the LLM and store the reply in the state.

        Reads ``article_title`` and ``article_content`` from ``state``, sends
        them as the user message alongside the system prompt, and sets
        ``state["messages"]`` to a one-element list holding the response.

        Returns:
            The same ``state`` dict, updated in place.
        """
        article_title = state["article_title"]
        article_content = state["article_content"]
        input_text = f"Title: {article_title}\n\nContent: {article_content}"

        response = self.llm.invoke([
            {"role": "system", "content": self.system_message},
            {"role": "user", "content": input_text},
        ])

        # The list is replaced, not extended: only the latest reply is kept.
        state["messages"] = [response]
        return state

    def handle_structured_output(self, state: "State"):
        """Parse ``Label:`` and ``Explanation:`` out of the last LLM message.

        The label is the text between the first ``Label:`` marker and the end
        of that line. The explanation is everything after the first
        ``Explanation:`` marker. Both are stripped of surrounding whitespace.

        Missing pieces are handled explicitly instead of slicing at a bogus
        offset: a reply without a label marker yields an empty label, and a
        reply without an explanation (or with an empty one) yields
        ``"No explanation provided."``. A label on the final line of the reply
        (no trailing newline) is returned in full.

        Returns:
            The same ``state`` dict with ``label`` and ``explanation`` set.
        """
        raw_output = state["messages"][-1].content

        # partition() reports whether the marker was found; find() returns -1
        # on a miss, which index arithmetic would silently turn into a slice.
        _, label_found, after_label = raw_output.partition(self._LABEL_MARKER)
        # Taking the first line works whether or not a newline follows the label.
        label = after_label.split("\n", 1)[0].strip() if label_found else ""

        _, explanation_found, after_explanation = raw_output.partition(
            self._EXPLANATION_MARKER
        )
        explanation = after_explanation.strip() if explanation_found else ""

        state["label"] = label
        state["explanation"] = explanation or self._DEFAULT_EXPLANATION
        return state

# Class to manage the graph
class GraphManager:
    """Builds and runs the two-node classification graph.

    Construction creates the detector, a ``StateGraph`` over ``State`` and
    immediately compiles the graph into ``self.graph``.
    """

    def __init__(self, model_name: str, system_message_prefix: str):
        """Create the detector and graph builder, then compile the graph.

        Args:
            model_name: Model identifier forwarded to ``MisinformationDetection``.
            system_message_prefix: Prompt prefix forwarded to ``MisinformationDetection``.
        """
        self.detection_system = MisinformationDetection(model_name, system_message_prefix)
        self.graph_builder = StateGraph(State)
        self.build_graph()

    def build_graph(self):
        """Register both nodes, chain them START -> detect -> handle -> END, and compile."""
        builder = self.graph_builder
        detector = self.detection_system

        builder.add_node("detect_fake_news", detector.detect_misinformation)
        builder.add_node("handle_output", detector.handle_structured_output)

        # The graph is a straight line, so each stage is linked to its successor.
        pipeline = (START, "detect_fake_news", "handle_output", END)
        for upstream, downstream in zip(pipeline, pipeline[1:]):
            builder.add_edge(upstream, downstream)

        self.graph = builder.compile()

    def run_graph_on_example(self, example: dict):
        """Classify one article and return only the label and explanation.

        Args:
            example: Mapping read with ``.get`` for ``"article_title"`` and
                ``"article_content"``; absent keys are passed on as ``None``.

        Returns:
            A dict with keys ``"label"`` and ``"explanation"`` taken from the
            final graph state (``None`` for a key the state does not contain).
        """
        outcome = self.graph.invoke({
            "messages": [],
            "article_title": example.get("article_title"),
            "article_content": example.get("article_content"),
        })
        return {field: outcome.get(field) for field in ("label", "explanation")}

# Class for running evaluation
class Evaluator:
    """Static helpers for scoring the graph on a LangSmith dataset."""

    @staticmethod
    def correct_label(root_run: "Run", example: "Example") -> dict:
        """Score 1 when the predicted label equals the reference label, else 0.

        Args:
            root_run: The run whose ``outputs`` mapping holds the predicted
                ``"label"``.
            example: The dataset example whose ``outputs`` mapping holds the
                reference ``"label"``. It is accessed by attribute, so it is
                an ``Example`` object rather than a plain dict.

        Returns:
            ``{"score": 0 or 1, "key": "correct_label"}``.
        """
        # `outputs` may be None (presumably when the target raised — confirm
        # against the LangSmith schema); treat that as "no label" instead of
        # failing with AttributeError.
        predicted_label = (root_run.outputs or {}).get("label")
        actual_label = (example.outputs or {}).get("label")
        # A missing prediction is never correct, even if the reference is missing too.
        score = predicted_label is not None and predicted_label == actual_label
        return {"score": int(score), "key": "correct_label"}

    @staticmethod
    def run_evaluation(graph_manager: "GraphManager", dataset_name: str):
        """Run the graph over a dataset with ``evaluate`` and return its results.

        Args:
            graph_manager: Provides ``run_graph_on_example`` as the evaluation
                target; the class name of its LLM is used as the experiment prefix.
            dataset_name: Passed to ``evaluate`` as ``data``.

        Returns:
            Whatever ``evaluate`` returns for the experiment.
        """
        results = evaluate(
            graph_manager.run_graph_on_example,
            data=dataset_name,
            evaluators=[Evaluator.correct_label],
            experiment_prefix=f"{graph_manager.detection_system.llm.__class__.__name__}",
            description="Evaluating graph-based misinformation detection system."
        )
        return results


In [17]:
# Experiment settings: which dataset to evaluate on and which model to use.
dataset_name = "FA-KES test"
model_name = "claude-3-haiku-20240307"

# Build the graph with the naive zero-shot prompt as the system-message prefix.
graph_manager = GraphManager(
    model_name=model_name,
    system_message_prefix=NAIVE_ZERO_SHOT_CLASSIFICATION_PROMPT,
)

# Run the evaluation and print the returned results object.
evaluation_results = Evaluator.run_evaluation(
    graph_manager=graph_manager,
    dataset_name=dataset_name,
)
print(evaluation_results)

View the evaluation results for experiment: 'ChatAnthropic-25759c41' at:
https://smith.langchain.com/o/7ade50a2-3a1f-5106-9003-6a8cfb7b3652/datasets/a26430d4-620e-4f8b-aa7e-3676c5486d8c/compare?selectedSessions=f358ee24-6d48-401d-9f15-6ef39c371649




10it [00:03,  2.66it/s]

<ExperimentResults ChatAnthropic-25759c41>



