In [None]:
import sys
sys.path.append("../../saia-finetuning")

from functools import partial
from typing import Any
from tools import parse_investigation, extract_splunk_system_data, merge_dictionaries, Formatter, InformationExtractor

from constants import MITRE_INSTRUCTIONS, FORMATTER_INSTRUCTIONS, METADATA_MAPPING
from tooling.agents.executor import Executor
from tooling.agents.state import StateGraph
# from tooling.llm_engine.llama3_1Instruct import Llama3_1Instruct as llm
from tooling.llm_engine.azure_oai import AzureGPT as llm

### Specify State Graph Actions and Compositions

In [None]:
# Build the state graph that holds every action (node) of the pipeline.
action_space = StateGraph()

# One LLM client, shared by all LLM-backed skills below.
model = llm()

# Skills whose `execute_skill` methods are wrapped into actions.
formatter = Formatter()
information_extractor = InformationExtractor()

# --- Register actions (registration order is kept as originally written) ---

# 1. Chunk the raw investigation.
action_space.add_action(
    action_name="data_chunker",
    callable=parse_investigation,
)

# 2. LLM extraction of MITRE information; result is stored under "mitre_summary".
mitre_action_name = f"{information_extractor.skill_name}_mitre_information"
mitre_extractor = partial(
    information_extractor.execute_skill,
    model=model,
    instructions=MITRE_INSTRUCTIONS,
    output_key="mitre_summary",
)
action_space.add_action(action_name=mitre_action_name, callable=mitre_extractor)

# Disabled alternative: LLM-based system-information extraction.
# Re-enabling it also requires importing SPLUNK_SYSTEM_INSTRUCTIONS, which is
# not among the names currently imported from `constants`.
# action_space.add_action(
#     action_name=f"{information_extractor.skill_name}_system_information",
#     callable=partial(information_extractor.execute_skill, model=model, instructions=SPLUNK_SYSTEM_INSTRUCTIONS, temperature=0.2, output_key="system_summary")
# )

# 3. System information via extract_splunk_system_data and METADATA_MAPPING
#    (no model is passed); result is stored under "system_summary".
system_action_name = f"{information_extractor.skill_name}_system_information"
system_extractor = partial(
    extract_splunk_system_data,
    system_data_map=METADATA_MAPPING,
    output_key="system_summary",
)
action_space.add_action(action_name=system_action_name, callable=system_extractor)

# 4. Merge both summaries into "merged_summaries".
# NOTE(review): `temperature` is forwarded to merge_dictionaries even though no
# model is passed to it — confirm the helper accepts (or needs) this argument.
summary_merger = partial(
    merge_dictionaries,
    mitre_summary="mitre_summary",
    system_summary="system_summary",
    temperature=0.2,
    output_key="merged_summaries",
)
action_space.add_action(action_name="merge_context", callable=summary_merger)

# 5. Format the merged context into the final "finding_ai_summary".
summary_formatter = partial(
    formatter.execute_skill,
    model=model,
    temperature=0.2,
    instructions=FORMATTER_INSTRUCTIONS,
    output_key="finding_ai_summary",
)
action_space.add_action(action_name=formatter.skill_name, callable=summary_formatter)

# Condition used by the compositions below to gate an edge on the current state.
def state_contains_keys(keys: list[str], context: dict[str, Any]) -> bool:
    """Return True when every entry of ``keys`` is present in ``context``.

    Args:
        keys: Names that must all be present.
        context: Mapping whose keys are checked.

    Returns:
        True if ``keys`` is a subset of the keys of ``context`` (an empty
        ``keys`` is always satisfied), otherwise False.
    """
    wanted = set(keys)
    return wanted.issubset(context)

# --- Wire the actions together (conditional compositions) ---
mitre_action_name = f"{information_extractor.skill_name}_mitre_information"
system_action_name = f"{information_extractor.skill_name}_system_information"

# Fan-out: the chunker feeds both extractors. Each argument_map pairs a
# chunker output name with the name the next action receives it under
# (presumably state key -> parameter name — confirm against StateGraph).
action_space.add_composition(
    action_name_0="data_chunker",
    action_name_1=mitre_action_name,
    argument_map={"mitre_subcontext": "context"},
)
action_space.add_composition(
    action_name_0="data_chunker",
    action_name_1=system_action_name,
    argument_map={"system_subcontext": "investigation_data"},
)
# Disabled wiring for the LLM-based system extractor:
# action_space.add_composition(
#     action_name_0="data_chunker",
#     action_name_1=f"{information_extractor.skill_name}_system_information",
#     argument_map={"system_subcontext": "context"},
# )

# TODO need to define logic for a skill to specify multiple skills must finish prior to executing
# In this notebook we avoid this due to the BFS execution logic and early termination
# NOTE(review): `required_keys` is not referenced by any visible cell.
required_keys = ["mitre_subcontext", "system_subcontext"]

# Fan-in: merging is gated on both summaries being present in the state.
summary_keys = ["mitre_summary", "system_summary"]
# NOTE(review): this edge has no argument_map, unlike the system -> merge edge
# below — confirm that is intentional.
action_space.add_composition(
    action_name_0=mitre_action_name,
    action_name_1="merge_context",
    condition=partial(state_contains_keys, keys=summary_keys),
)
action_space.add_composition(
    action_name_0=system_action_name,
    action_name_1="merge_context",
    argument_map={"mitre_summary": "mitre_summary", "system_summary": "system_summary"},
    condition=partial(state_contains_keys, keys=summary_keys),
)

# Final hop: the merged summaries are handed to the formatter as `context`.
action_space.add_composition(
    action_name_0="merge_context",
    action_name_1=formatter.skill_name,
    argument_map={"merged_summaries": "context"},
    condition=partial(state_contains_keys, keys=["merged_summaries"]),
)

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

# Draw the state graph: labelled nodes joined by the registered compositions.
graph = action_space.graph
pos = nx.planar_layout(graph)  # planar layout: only valid for planar graphs
for draw_layer in (nx.draw_networkx_nodes, nx.draw_networkx_edges):
    draw_layer(graph, pos)
nx.draw_networkx_labels(graph, pos, font_size=8)
plt.show()

### Load Data

In [None]:
import json

# JSON files holding sample findings. Each file may contain either a single
# investigation (a JSON object) or a list of investigations (a JSON array).
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a configurable data directory.
FINDING_FILES = [
    "/mount/splunka100groupstorage/a100-fs-share1/lbetthauser/scripts/finding_summarization/sample_findings/July23_8_Investigations_Analyst_Notes.json",
    "/mount/splunka100groupstorage/a100-fs-share1/lbetthauser/scripts/finding_summarization/sample_findings/July25_10_Investigations_Analyst_Notes.json",
    "/mount/splunka100groupstorage/a100-fs-share1/lbetthauser/scripts/finding_summarization/sample_findings/noah3_investigations.json",
    "/mount/splunka100groupstorage/a100-fs-share1/lbetthauser/scripts/finding_summarization/sample_findings/sample_input.json",
]

# Flat list of every investigation found across FINDING_FILES.
investigations = []
for finding_path in FINDING_FILES:
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(finding_path, "r", encoding="utf-8") as f:
        _data = json.load(f)

    # The file is closed at this point; only the parsed payload is inspected.
    if isinstance(_data, list):
        investigations.extend(_data)
    elif isinstance(_data, dict):
        investigations.append(_data)
    else:
        # Valid JSON but neither an object nor an array (e.g. a bare string).
        print(f"could not parse file: {finding_path}")

### Run Executor

In [None]:
# Create the executor and register the state graph built above.
executor = Executor()
executor.add_state_graph(action_space)

# Seed the run with a single investigation: the first entry loaded into
# `investigations` (raises IndexError if that list is empty).
INVESTIGATION = investigations[0]
state = {
    "investigation": INVESTIGATION
}

# Hand the initial state to the executor and start at the "data_chunker" action.
# NOTE(review): presumably parse_investigation reads the "investigation" key — confirm.
executor.augment_state(state)
executor.set_entry_point(action_name="data_chunker")

# Run the graph; the result is indexed by output keys in the cells below.
terminal_state = executor.run()

In [None]:
# Final answer: the value stored under "finding_ai_summary", the output_key
# configured for the formatter action.
print(terminal_state["finding_ai_summary"])

In [None]:
# Intermediate summaries that were merged and handed to the formatter.
for summary_key in ("system_summary", "mitre_summary"):
    print(terminal_state[summary_key])