In [None]:
import shap
import pandas as pd
import numpy as np
import pickle
import random 
import yaml
import re
import dill

from shapnarrative_metrics.llm_tools import llm_wrappers
from shapnarrative_metrics.misc_tools.manipulations import full_inversion, shap_permutation
from shapnarrative_metrics.llm_tools.generation import GenerationModel
from shapnarrative_metrics.llm_tools.extraction import ExtractionModel

### Load API keys, dataset metadata, trained model and data splits

In [None]:
# Read every provider's API key from config/keys.yaml instead of hard-coding
# secrets in the notebook. The parsed config is bound to `keys_config` (not
# `dict`) so the builtin `dict` stays usable for the rest of the session.
with open("config/keys.yaml") as f:
    keys_config = yaml.safe_load(f)

api_keys = keys_config["API_keys"]
api_key = api_keys["OpenAI"]
replicate_key = api_keys["Replicate"]
anthropic_key = api_keys["Anthropic"]
mistral_key = api_keys["Mistral"]

In [None]:
# Dataset to analyse; every artefact below lives in data/<dataset_name>_dataset/.
dataset_name = "fifa"
dataset_dir = f"data/{dataset_name}_dataset"

# Dataset metadata and the trained model are pickled. Only load pickle files
# you trust: unpickling can execute arbitrary code.
with open(f"{dataset_dir}/dataset_info", "rb") as f:
    ds_info = pickle.load(f)

with open(f"{dataset_dir}/RF.pkl", "rb") as f:
    trained_model = pickle.load(f)

# Cleaned train/test splits.
train = pd.read_parquet(f"{dataset_dir}/train_cleaned.parquet")
test = pd.read_parquet(f"{dataset_dir}/test_cleaned.parquet")

In [None]:
n = 14  # NOTE(review): not used in the visible cells — confirm whether it is still needed

# Instance to explain, selected by index label from the test split.
# Previously used candidate, kept for reference: idx = 882
idx = 4
assert idx in test.index, f"idx={idx} is not an index label of the test split"

# All columns except the last are the features (x); the last column is y.
# Double brackets keep x as a one-row DataFrame and y as a one-element Series.
x = test[test.columns[:-1]].loc[[idx]]
y = test[test.columns[-1]].loc[[idx]]

 

In [None]:
# Generation settings.
TEMPERATURE = 0  # sampling temperature shared by every LLM client below
MANIP = True  # passed as `manipulate=` to the prompt/story generation calls below
SYSTEM_ROLE = "You are a teacher that explains AI predictions."

# One client per provider, all sharing the same system role and temperature.
gpt = llm_wrappers.GptApi(api_key, model="gpt-4o", system_role=SYSTEM_ROLE, temperature=TEMPERATURE)
llama_generation = llm_wrappers.LlamaAPI(api_key=replicate_key, model="llama-3-70b-instruct", system_role=SYSTEM_ROLE, temperature=TEMPERATURE)
claude_generation = llm_wrappers.ClaudeApi(api_key=anthropic_key, model="claude-3-5-sonnet-20240620", system_role=SYSTEM_ROLE, temperature=TEMPERATURE)
mistral_generation = llm_wrappers.MistralApi(api_key=mistral_key, model="mistral-large-2407", system_role=SYSTEM_ROLE, temperature=TEMPERATURE)

# Only the GPT client drives generation here. To compare providers, pass
# llama_generation, claude_generation or mistral_generation as `llm=`.
generator = GenerationModel(ds_info=ds_info, llm=gpt)


In [None]:
# Populate the generator's per-instance variables for (x, y). Presumably these are
# the model prediction and SHAP values. tree=True presumably selects a tree-based
# explainer to match the pickled RF model. Confirm both against GenerationModel.
generator.gen_variables(trained_model,x,y,tree=True)
# Preview the first four rows of the explanation table for the first instance.
generator.explanation_list[0].head(4)

In [None]:
# Build the story prompt for the first instance (iloc_pos=0) and print it for
# inspection. full_inversion is passed as the manipulation function and MANIP
# as the switch. print() is used because the prompt is presumably a multi-line
# string whose line breaks should render.
prompt=generator.generate_story_prompt(iloc_pos=0,manipulate=MANIP, manipulation_func=full_inversion)
print(prompt)

In [None]:
# Generate narratives for the selected instance and print the first one,
# one sentence per line.
# NOTE(review): the prompt preview above passes manipulation_func=full_inversion
# explicitly, but this call does not. Confirm that generate_stories' default
# manipulation matches what was previewed.
narratives = generator.generate_stories(trained_model, x, y, tree=True, manipulate=MANIP)

# Split on whitespace that follows '.', '?' or '!'. The two negative
# lookbehinds avoid splitting after abbreviations such as "e.g." or "Dr.".
# '!' is included so exclamatory sentences are not merged with the next one.
SENTENCE_BOUNDARY = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=[.?!])\s')

narrative_split = SENTENCE_BOUNDARY.split(narratives[0])
for sentence in narrative_split:
    print(sentence)

In [None]:
# Run the extraction model (same GPT client) over the generated narratives.
# Presumably it recovers the feature/rank/sign claims each narrative makes;
# confirm against ExtractionModel. Show the extraction for the first narrative.
extractor=ExtractionModel(ds_info=ds_info, llm=gpt)
extraction=extractor.generate_extractions(narratives)
extraction[0]

In [None]:
# Shown again so the generator's explanation sits right next to the extraction above.
generator.explanation_list[0].head(4)

In [None]:
# Compare the first extraction with the generator's explanation for the same
# instance. Only rank_diff and sign_diff are displayed below; value_diff,
# real_rank and extracted_rank are kept for ad-hoc inspection.
rank_diff, sign_diff , value_diff, real_rank, extracted_rank=extractor.get_diff(extraction[0],generator.explanation_list[0])

In [None]:
# Rank disagreement between the extracted and the actual explanation
# (presumably per feature — confirm against ExtractionModel.get_diff).
rank_diff

In [None]:
# Sign disagreement between the extracted and the actual explanation
# (presumably per feature — confirm against ExtractionModel.get_diff).
sign_diff