## Setup

In [1]:
import pandas as pd
import numpy as np
import yaml
import openai
import asyncio 
import shap
import pickle
from autogen_core.models import AssistantMessage, FunctionExecutionResult, FunctionExecutionResultMessage, UserMessage
from autogen_agentchat.agents import AssistantAgent, CodeExecutorAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.conditions import TextMentionTermination, MaxMessageTermination
from autogen_agentchat.ui import Console
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_agentchat.messages import TextMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_core import CancellationToken
import sys
import os


# Make the project-local packages (shapnarrative_metrics, ExtractorCritic, Narrator)
# importable. Relative sys.path entries are resolved against the current working
# directory, so the notebook is expected to be run from its own folder.
sys.path.extend(["../SHAPnarrative-metrics", "../script"])

from shapnarrative_metrics.llm_tools import llm_wrappers
from shapnarrative_metrics.misc_tools.manipulations import full_inversion, shap_permutation
from shapnarrative_metrics.llm_tools.generation import GenerationModel
from shapnarrative_metrics.llm_tools.extraction import ExtractionModel
from ExtractorCritic import ExtractorCriticAgent
from Narrator import NarratorAgent


  from .autonotebook import tqdm as notebook_tqdm


#### API Key & dataset

In [3]:
# Read the OpenAI key from the local keys file. The parsed YAML is bound to
# `keys_config` (not `dict`) so the builtin `dict` type stays usable in later cells.
with open("../config/keys.yaml") as f:
    keys_config = yaml.safe_load(f)
api_key = keys_config["API_keys"]["OpenAI"]

dataset_name = "student"
dataset_dir = f"../data/{dataset_name}_dataset"

# NOTE: pickle.load can execute arbitrary code -- only load these artifacts from a
# trusted source (they are expected to be produced by this project).
with open(f"{dataset_dir}/dataset_info", "rb") as f:
    ds_info = pickle.load(f)

with open(f"{dataset_dir}/RF.pkl", "rb") as f:
    trained_model = pickle.load(f)

train = pd.read_parquet(f"{dataset_dir}/train_cleaned.parquet")
test = pd.read_parquet(f"{dataset_dir}/test_cleaned.parquet")

# Index labels available for `idx` in the next cell.
print(test.index)

Index([252, 236, 275, 148, 309,  33, 248, 150, 346, 246, 221, 261,   9, 204,
       329, 207, 332,  59, 371,  11, 300, 299, 159, 280, 175, 374, 196, 189,
       188, 211, 307, 284,  36,  72, 114, 274,  82, 345,  54, 385,  26, 190,
       120, 212, 170,  24,  13,  95, 237, 285, 264,  20, 181, 107, 171, 223,
       192,  42,  48,  15, 258, 228,  41,  31,  75, 102, 173, 373, 147,  79,
       293,  91,  52, 179, 200, 199, 320, 375, 245],
      dtype='int64')


#### Choose a data instance & generate the SHAP table and story prompt

In [4]:
# Index *label* (not position) of the test instance to explain; pick one of the
# labels printed by the previous cell.
idx = 293
assert idx in test.index, f"idx={idx} is not a label in test.index"

# Number of rows of the explanation table kept in `shap_df`
# (presumably the top-ranked SHAP features -- confirm ordering in GenerationModel).
N_TOP_FEATURES = 4

# All columns but the last are model inputs; the last column is the target.
x = test[test.columns[0:-1]].loc[[idx]]
y = test[test.columns[-1]].loc[[idx]]

TEMPERATURE = 0
# When True the story prompt is generated with manipulate=True, and the later
# cells select the *_MANIP system messages.
MANIP = True

gpt = llm_wrappers.GptApi(
    api_key,
    model="gpt-4o",
    system_role="You are a teacher that explains AI predictions.",
    temperature=TEMPERATURE,
)
generator = GenerationModel(ds_info=ds_info, llm=gpt)
generator.gen_variables(trained_model, x, y, tree=True)

# shap_df is the table WITHOUT any manipulation. It is taken as an explicit copy
# *before* the prompt is built, so it cannot change even if
# generate_story_prompt(..., manipulate=True) edits the explanation in place.
shap_df = pd.DataFrame(generator.explanation_list[0].head(N_TOP_FEATURES)).copy()
prompt = generator.generate_story_prompt(0, prompt_type="long", manipulate=MANIP)
print(shap_df)

  feature_name  SHAP_value  feature_value  feature_average  \
0     freetime   -0.052233              1         3.243671   
1     failures    0.026886              0         0.360759   
2   Mjob_other    0.025752              0         0.370253   
3        goout    0.025017              2         3.098101   

                                        feature_desc  
0  Free time after school (from 1 - very low to 5...  
1        Number of past class failures (from 0 to 3)  
2        One-hot variable for mothers's job -- other  
3  Going out with friends (from 1 - very low to 5...  


## Solution: three-agent system (Narrator, ExtractorCritic and FaithfulCritic)
FaithfulCritic summarizes ExtractorCritic's feedback and clearly communicates how the narrative should be revised.

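Load the predefined system messages for each agent. The `_MANIP` variants are selected when the SHAP table is manipulated (`MANIP = True`).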

In [5]:
from system_messages import (
    system_message_narrator,
    system_message_narrator_MANIP,
    system_message_faithfulcritic,
    system_message_faithfulcritic_MANIP,
    system_message_coherence,
)

# The coherence prompt is shared by both modes; only the narrator and the
# faithful-critic prompts depend on whether the SHAP table was manipulated.
coherence_sys_msg = system_message_coherence

narrator_sys_msg, faithfulcritic_sys_msg = (
    (system_message_narrator_MANIP, system_message_faithfulcritic_MANIP)
    if MANIP
    else (system_message_narrator, system_message_faithfulcritic)
)

status_msg = "The SHAP table is manipulated" if MANIP else "The SHAP table is normal (not manipulated)"
print(status_msg)

The SHAP table is manipulated


## Run the three-agent system (narrator, extractor critic and faithful critic)

In [None]:
shap_df = shap_df

async def main() -> None:
    SEED = 42
    TEMPERATURE = 0

    # Initialize OpenAI client
    model_client = OpenAIChatCompletionClient(
        model="gpt-4o",
        api_key=api_key,
        seed=SEED,
        temperature=TEMPERATURE
    )
    
    # Create the narrator agent
    narrator = NarratorAgent(
        name="narrator",
        system_message=narrator_sys_msg, 
        model_client=model_client,
        reflect_on_tool_use=False,
    )
    
    # Create the critic agent with advanced extraction capabilities
    extractorcritic = ExtractorCriticAgent(
        name="extractorcritic",
    )
    
    # Set up the models for the critic - THIS IS IMPORTANT
    extractorcritic.set_models(
        extractor_class=ExtractionModel,
        generator_class=GenerationModel,
        ds_info=ds_info,
        llm=gpt
    )

    extractorcritic.default_shap_df = shap_df 

    faithfulcritic = AssistantAgent(
        name="faithfulcritic",
        system_message=faithfulcritic_sys_msg,
        model_client=model_client,
        reflect_on_tool_use=False,
    )

    # Create the group chat with custom processing
    termination = TextMentionTermination("TERMINATE") | MaxMessageTermination(13)
    
    group_chat = RoundRobinGroupChat(
        [narrator, extractorcritic, faithfulcritic],
        termination_condition=termination
    )
    
    # Run the chat
    stream = group_chat.run_stream(task=prompt)
    
    await Console(stream)

if __name__ == "__main__":
    await main()